ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 중급반 8주차 풀이 - 세그먼트 트리
    ALOHA - 2022/ALOHA - 중급반 (2022 1학기) 2022. 5. 31. 13:05

    1. [2042] 구간 합 구하기 (Gold I)

    더보기

    Segment Tree (Point Update & Range Sum Query) 연습 문제입니다.

    const int N = 1048576;
    ll seg[2097160];
    
    void upd(int pos, ll val){ pos += N-1;
    	seg[pos] = val; pos >>= 1;
    	while (pos){
    		seg[pos] = seg[pos<<1] + seg[pos<<1|1]; pos >>= 1;
    	}
    }
    ll qry(int st, int ed){ st += N-1; ed += N-1;
    	ll res = 0;
    	while (st <= ed){
    		if (st & 1){ res += seg[st]; st += 1; }
    		if (~ed & 1){ res += seg[ed]; ed -= 1; }
    		if (st > ed){ break; }
    		st >>= 1; ed >>= 1;
    	}
    	return res;
    }
    
    void Main(){
    	int n, m, k; cin >> n >> m >> k; int q = m+k;
    	for (int i = 1; i <= n; i++){ cin >> seg[i+N-1]; }
    	for (int i = N-1; i >= 1; i--){ seg[i] = seg[i<<1] + seg[i<<1|1]; }
    	while (q--){
    		int typ; cin >> typ;
    		if (typ == 1){ int pos; ll val; cin >> pos >> val; upd(pos, val); }
    		if (typ == 2){ int st, ed; cin >> st >> ed; cout << qry(st, ed) << endl; }
    	}
    }

    2. [2357] 최솟값과 최댓값 (Gold I)

    더보기

    Segment Tree (Range Minimum/Maximum Query) 연습 문제입니다.

    const int INF = 2e9;
    const int N = 131072;
    pi2 seg[262150];
    
    pi2 qry(int st, int ed){ st += N-1; ed += N-1;
    	pi2 res = {INF, 0};
    	while (st <= ed){
    		if (st & 1){
    			res.fr = min(res.fr, seg[st].fr); res.sc = max(res.sc, seg[st].sc);
    			st += 1;
    		}
    		if (~ed & 1){
    			res.fr = min(res.fr, seg[ed].fr); res.sc = max(res.sc, seg[ed].sc);
    			ed -= 1;
    		}
    		if (st > ed){ break; }
    		st >>= 1; ed >>= 1;
    	}
    	return res;
    }
    
    void Main(){
    	int n, q; cin >> n >> q;
    	for (int i = 1; i <= n; i++){ cin >> seg[i+N-1].fr; seg[i+N-1].sc = seg[i+N-1].fr; }
    	for (int i = N-1; i >= 1; i--){
    		seg[i].fr = min(seg[i<<1].fr, seg[i<<1|1].fr);
    		seg[i].sc = max(seg[i<<1].sc, seg[i<<1|1].sc);
    	}
    	while (q--){
    		int st, ed; cin >> st >> ed;
    		pi2 res = qry(st, ed); cout << res.fr << ' ' << res.sc << endl;
    	}
    }

    3. [10090] Counting Inversions (Platinum V)

    더보기

    Inversion을 좀 더 명확히 정의하면, \( i < j \)이면서 \( A_{i} > A_{j} \)인 \( (i, j) \) 쌍을 의미합니다.

    그럼 우리가 \( j \)를 기준으로 돌리면서 앞에 나온 값들 중 \( A_{j} \)보다 더 큰 값의 개수를 찾아주면 됩니다.

     

    근데 이건 어떻게 구할까요?

    \( B_{i} := \) \( A_{k} = i \)인 \( k < j \)의 개수 (0 or 1)를 정의하면,

    \( A_{j} \)보다 더 큰 값의 개수 = \( \sum_{x=A_{j}+1}^{\infty} B_{x} \)가 됩니다.

    즉, Range Sum Query로 잘 해결할 수 있는 문제가 됩니다.

     

    추가적으로, 위 쿼리를 수행한 뒤에는 지금 보고 있던 \( A_{j} \)가 \( A_{i} \)로 들어가야 하니

    \( B_{A_{i}} \)에 1을 더해줘야 합니다.

    즉, Point Update까지 들어가게 됩니다.

     

    그러니까 이 문제는 완벽한 세그먼트 트리 문제가 됩니다.

    const int N = 1048576;
    int seg[2097160];
    
    void upd(int pos, int val){ pos += N-1;
    	seg[pos] += val; pos >>= 1;
    	while (pos){
    		seg[pos] = seg[pos<<1] + seg[pos<<1|1]; pos >>= 1;
    	}
    }
    int qry(int st, int ed){ st += N-1; ed += N-1;
    	ll res = 0;
    	while (st <= ed){
    		if (st & 1){ res += seg[st]; st += 1; }
    		if (~ed & 1){ res += seg[ed]; ed -= 1; }
    		if (st > ed){ break; }
    		st >>= 1; ed >>= 1;
    	}
    	return res;
    }
    
    void Main(){
    	int n; cin >> n;
    	ll ans = 0;
    	for (int i = 1; i <= n; i++){
    		int x; cin >> x;
    		ans += qry(x+1, N);
    		upd(x, 1);
    	}
    	cout << ans;
    }
    더보기

    중급반 2주차 때 만났던 "[1517] 버블 소트" 문제를 기억하시나요?

    버블 소트는 1회 swap이 Inversion의 1 감소와 동일하기 때문에

    버블 소트의 swap 횟수 = Inversion의 개수가 됩니다.

    즉, 같은 문제입니다.

    int arr[1000020], tmp[1000020];
    
    ll ans = 0;
    void f(int st, int ed){ int mid = st+ed >> 1;
    	if (st == ed){ return; }
    	f(st, mid); f(mid+1, ed);
    	int p1 = st, p2 = mid+1, ptr = st;
    	while (p1 <= mid || p2 <= ed){
    		if (p1 > mid){ tmp[ptr] = arr[p2]; p2 += 1; ptr += 1; }
    		else if (p2 > ed){ tmp[ptr] = arr[p1]; p1 += 1; ptr += 1; }
    		else{
    			if (arr[p1] <= arr[p2]){ tmp[ptr] = arr[p1]; p1 += 1; ptr += 1; }
    			else{
    				ans += mid-p1 + 1;
    				tmp[ptr] = arr[p2]; p2 += 1; ptr += 1;
    			}
    		}
    	}
    	for (int i = st; i <= ed; i++){ arr[i] = tmp[i]; }
    }
    
    void Main(){
    	int n; cin >> n;
    	for (int i = 1; i <= n; i++){ cin >> arr[i]; }
    	f(1, n);
    	cout << ans;
    }

    4. [11505] 구간 곱 구하기 (Gold I)

    더보기

    Range Multiplication Query는 자주 나오지는 않지만, 세그먼트 트리가 다양한 연산에 적용될 수 있음을 보이기에 좋은 문제입니다.

    Range Sum과의 차이점은 합 대신 곱을 계산한다는 점과 Query에서의 기본값이 1로 바뀐다는 점밖에 없습니다.

    + modular 연산을 곱할 때마다 해줘야 합니다!

    const ll mod = 1e9 + 7;
    const int N = 1048576;
    ll seg[2097160];
    
    void upd(int pos, ll val){ pos += N-1;
    	seg[pos] = val; pos >>= 1;
    	while (pos){
    		seg[pos] = seg[pos<<1] * seg[pos<<1|1] % mod; pos >>= 1;
    	}
    }
    ll qry(int st, int ed){ st += N-1; ed += N-1;
    	ll res = 1;
    	while (st <= ed){
    		if (st & 1){ res *= seg[st]; res %= mod; st += 1; }
    		if (~ed & 1){ res *= seg[ed]; res %= mod; ed -= 1; }
    		if (st > ed){ break; }
    		st >>= 1; ed >>= 1;
    	}
    	return res;
    }
    
    void Main(){
    	int n, m, k; cin >> n >> m >> k; int q = m+k;
    	for (int i = 1; i <= n; i++){ cin >> seg[i+N-1]; }
    	for (int i = N-1; i >= 1; i--){ seg[i] = seg[i<<1] * seg[i<<1|1] % mod; }
    	while (q--){
    		int typ; cin >> typ;
    		if (typ == 1){ int pos; ll val; cin >> pos >> val; upd(pos, val); }
    		if (typ == 2){ int st, ed; cin >> st >> ed; cout << qry(st, ed) << endl; }
    	}
    }

    5. [7578] 공장 (Platinum V)

    더보기

    어떤 두 케이블이 교차한다는 건 무슨 소리일까요?

    A열과 B열의 케이블의 연결 상태를 \( f(i) \)라고 놓으면

    A열에서의 두 케이블이 \( i, j \) 위치에 있다고 하면 B열에서는 \( f(i), f(j) \)위치에 있게 됩니다.

    이 두 케이블이 교차하려면 \( i < j \)인데 \( f(i) > f(j) \)여야 하고, 이는 위에서 봤던 Inversion과 동일한 정의입니다.

    즉, \( f(i) \)만 구하면 이 문제는 그냥 Inversion Counting 문제가 됩니다.

    const int N = 524288;
    int seg[1048580];
    
    void upd(int pos, int val){ pos += N-1;
    	seg[pos] += val; pos >>= 1;
    	while (pos){
    		seg[pos] = seg[pos<<1] + seg[pos<<1|1]; pos >>= 1;
    	}
    }
    int qry(int st, int ed){ st += N-1; ed += N-1;
    	ll res = 0;
    	while (st <= ed){
    		if (st & 1){ res += seg[st]; st += 1; }
    		if (~ed & 1){ res += seg[ed]; ed -= 1; }
    		if (st > ed){ break; }
    		st >>= 1; ed >>= 1;
    	}
    	return res;
    }
    
    int pos[1000020];
    
    void Main(){
    	int n; cin >> n;
    	for (int i = 1; i <= n; i++){ int x; cin >> x; pos[x] = i; }
    	ll ans = 0;
    	for (int i = 1; i <= n; i++){
    		int x; cin >> x; x = pos[x];
    		ans += qry(x+1, N);
    		upd(x, 1);
    	}
    	cout << ans;
    }

    6. [3653] 영화 수집 (Platinum IV)

    더보기

    잠시 DVD가 중력을 무시한다고 생각해봅시다.

    그러면, 초기 상태에는 DVD가 1번 위치부터 N번 위치까지 쌓여있는 형태가 됩니다.

    이 상태에서 영화를 한 편 보면, 그 영화는 N+1번 위치에 올라가게 되고,

    그 상태에서 영화를 또 한 편 보면, 그 영화는 N+2번 위치에 올라가게 되고,

    ... 가 반복됩니다.

     

    영화를 볼 때마다 그 위에 있는 DVD의 개수는 나보다 더 높은 위치에 있는 영화의 개수가 되니, 이는 Range Sum Query가 되고, 영화의 위치를 바꾸는 건 "현재 i번 영화의 위치"를 저장하는 배열을 추가적으로 관리해주면서 Point Update를 해주면 됩니다.

     

    이제 세그먼트 트리로 풀면 되는 문제가 되었으니 신나게 써주면 됩니다.

    const int N = 262144;
    int seg[524290];
    
    void upd(int pos, int val){ pos += N-1;
    	seg[pos] = val; pos >>= 1;
    	while (pos){
    		seg[pos] = seg[pos<<1] + seg[pos<<1|1]; pos >>= 1;
    	}
    }
    int qry(int st, int ed){ st += N-1; ed += N-1;
    	ll res = 0;
    	while (st <= ed){
    		if (st & 1){ res += seg[st]; st += 1; }
    		if (~ed & 1){ res += seg[ed]; ed -= 1; }
    		if (st > ed){ break; }
    		st >>= 1; ed >>= 1;
    	}
    	return res;
    }
    
    int ptr[100020];
    
    void Main(){
    	int t; cin >> t; while (t--){
    		int n, q; cin >> n >> q;
    		memset(seg, 0, sizeof(seg));
    		for (int i = 1; i <= n; i++){ ptr[i] = n+1 - i; seg[i+N-1] = 1; }
    		for (int i = N-1; i >= 1; i--){ seg[i] = seg[i<<1] + seg[i<<1|1]; }
    		for (int i = 1; i <= q; i++){
    			int x; cin >> x; int p = ptr[x];
    			cout << qry(p+1, N) << ' '; upd(p, 0); upd(n+i, 1);
    			ptr[x] = n+i;
    		}
    		cout << endl;
    	}
    }

    7. [1306] 달려라 홍준 (Platinum V)

    더보기

    너무 자명하게 Range Maxmimum Query입니다.

    const int N = 1048576;
    int seg[2097160];
    
    int qry(int st, int ed){ st += N-1; ed += N-1;
    	int res = 0;
    	while (st <= ed){
    		if (st & 1){ res = max(res, seg[st]); st += 1; }
    		if (~ed & 1){ res = max(res, seg[ed]); ed -= 1; }
    		if (st > ed){ break; }
    		st >>= 1; ed >>= 1;
    	}
    	return res;
    }
    
    void Main(){
    	int n, m; cin >> n >> m;
    	for (int i = 1; i <= n; i++){ cin >> seg[i+N-1]; }
    	for (int i = N-1; i >= 1; i--){ seg[i] = max(seg[i<<1], seg[i<<1|1]); }
    	for (int i = m; i+m-1 <= n; i++){
    		int st = i-m+1, ed = i+m-1;
    		cout << qry(st, ed) << ' ';
    	}
    }
    더보기

    Update가 없는 Range Maximum Query는 세그먼트 트리 없이도 풀 수 있습니다.

     

    각 위치 \( i \)마다 \( B_{i,k} = \min_{i \le p \lt i+2^k} A_{p} \)를 저장한 뒤

    \( [s, e] \) 구간의 쿼리는 곧 길이 \( l = e-s+1 \)의 구간의 쿼리이므로

    \( l \ge 2^k \)인 가장 큰 \( k \)를 찾아준 뒤 \( [s, s+2^k) \) 구간과 \( (e-2^k, e] \) 구간의 max를 찾아주면 됩니다.

    int arr[1000020];
    int spr[22][1000020];
    
    void Main(){
    	int n, m; cin >> n >> m; int l = 2*m-1;
    	int k = 0; while ((1<<k) <= l){ k += 1; } k -= 1;
    	for (int i = 1; i <= n; i++){ cin >> spr[0][i]; }
    	for (int j = 1; j <= 20; j++){
    		for (int i = 1; i <= n; i++){
    			spr[j][i] = spr[j-1][i]; int p = i + (1 << j-1);
    			if (p <= n){ spr[j][i] = max(spr[j][i], spr[j-1][p]); }
    		}
    	}
    	for (int i = m; i+m-1 <= n; i++){
    		int st = i-m+1, ed = i+m-1;
    		cout << max(spr[k][st], spr[k][ed-(1<<k)+1]) << ' ';
    	}
    }
    더보기

    Update가 없고 구간의 길이가 고정된 Range Maximum Query는 이렇게도 풀 수 있습니다.

     

    우선, 고정된 구간의 길이를 \( l \)이라고 합시다. 이제, 수열을 \( [1, 2l], [l+1, 3l], [2l+1, 4l], \ldots \)로 나눠봅시다.

    지금부터는 편의상 \( [1, 2l] \) 부분만 설명하겠습니다. 다른 구간에서도 동일하게 적용해주면 됩니다.

     

    우선 \( [1, l] \) 구간에서 다음을 정의합니다. \( S_{i} = \max_{i \le p \le l} A_{p} \)

    그리고 \( [l+1, 2l] \) 구간에서 다음을 정의합니다. \( P_{i} = \max_{l+1 \le p \le i} A_{p} \)

    이제 구간 \( [1, 2l] \)에 완벽히 속하는 모든 \( [i, i+l) \)에 대한 답은 \( \max(S_{i}, P_{i+l-1}) \)이 됩니다.

    int arr[3000020];
    int prf[3000020], suf[3000020];
    
    void Main(){
    	int n, m; cin >> n >> m; m = 2*m - 1;
    	for (int i = 1; i <= n; i++){ cin >> arr[i]; }
    	int k = n;
    	for (int i = 1; i <= k; i++){
    		if (i%m == 1%m){ prf[i] = arr[i]; }
    		else{ prf[i] = max(arr[i], prf[i-1]); }
    	}
    	for (int i = k; i >= 1; i--){
    		if (i%m == 0){ suf[i] = arr[i]; }
    		else{ suf[i] = max(arr[i], suf[i+1]); }
    	}
    	for (int i = 1; i <= n-m+1; i++){
    		cout << max(suf[i], prf[i+m-1]) << ' ';
    	}
    }

    8. [1280] 나무 심기 (Platinum IV)

    더보기

    좌표 \( x \)에 심어진 나무의 수를 \( C_{x} \)라고 한 뒤, 좌표 \( p \)에 새로운 나무를 심어봅시다.

    이 때 생기는 비용은 \( \sum_{x=0}^{X} |p-x| \cdot C_{x} \)가 됩니다.

    이 식을 약간 뜯어보면, \( x < p \)인 동안에는 \( pC_{x} - xC_{x} \)를 더하고, \( x > p \)인 동안에는 \( pC_{x} - xC_{x} \)를 빼게 됩니다.

    \( C_{x} \)의 합과 \( xC_{x} \)의 합을 저장하는 세그먼트 트리를 만들면, \( C_{x} \times p - xC_{x} \)를 특정 구간에 쿼리를 날린 뒤 받은 값을 연산해서 계산할 수 있음을 볼 수 있습니다.

     

    (세그먼트 트리를 1-indexed로 구현했다면) 이 문제에서의 좌표가 0이 들어올 수 있음에 주의해주세요.

    const ll mod = 1e9 + 7;
    
    const int N = 262144;
    pl2 seg[524290];
    
    void upd(int idx){ int pos = idx+N;
    	seg[pos].fr += 1; seg[pos].sc += idx; pos >>= 1;
    	while (pos){
    		seg[pos].fr = seg[pos<<1].fr + seg[pos<<1|1].fr; seg[pos].fr %= mod;
    		seg[pos].sc = seg[pos<<1].sc + seg[pos<<1|1].sc; seg[pos].sc %= mod;
    		pos >>= 1;
    	}
    }
    pl2 qry(int st, int ed){ st += N; ed += N;
    	pl2 res = {0, 0};
    	while (st <= ed){
    		if (st & 1){
    			res.fr += seg[st].fr; res.fr %= mod;
    			res.sc += seg[st].sc; res.sc %= mod;
    			st += 1;
    		}
    		if (~ed & 1){
    			res.fr += seg[ed].fr; res.fr %= mod;
    			res.sc += seg[ed].sc; res.sc %= mod;
    			ed -= 1;
    		}
    		if (st > ed){ break; }
    		st >>= 1; ed >>= 1;
    	}
    	return res;
    }
    
    void Main(){
    	int n; cin >> n;
    	ll ans = 1;
    	for (int i = 1; i <= n; i++){
    		int p; cin >> p;
    		if (i > 1){
    			pl2 p1 = qry(0, p-1), p2 = qry(p+1, N);
    			ll res = (p*p1.fr - p1.sc) % mod - (p*p2.fr - p2.sc) % mod;
    			res = (res+mod) % mod;
    			ans *= res; ans %= mod;
    		}
    		upd(p);
    	}
    	cout << ans;
    }
    더보기

    Range Triangular Sum Query를 구현해서 쿼리를 날릴 때마다 각 위치의 가중치를 넣어줄 수도 있고,

    Segment Tree with Lazy Propagation + 직선의 방정식을 사용해서 풀 수도 있습니다.

    [TODO] 이 풀이들 자세한 설명 + 코드 짜기

Designed by Tistory.