본문 바로가기

자료구조

느리게 갱신되는 세그먼트 트리 (Segment Tree Lazy Propagation)

728x90

개념

최대한 늦게 세그먼트 트리를 갱신시키는 방법으로, Lazy Value 값을 노드에 저장시켜놓은뒤, 노드의 방문을 필요때마다, 업데이트를 시켜준다.

쿼리나, 다른 업데이트가 필요할때, 그때 노드를 방문함으로서, 업데이트를 최대한 미루는 방식이다.

Lazy Value 대상은 자식노드들이 모두 업데이트가 필요할 경우이다. 모든 자식노드가 업데이트가 필요할경우 해당 부모노드에서 자식노드까지의 방문을 멈추고, 해당 부모노드에 Lazy Value 값을 갱신시킨다. 추후 나중에 필요할때 Lazy Value 값을 적용 자식들에게 가중치를 전파한다.

가령, 원소는 (0, 9)가 존재한다고 하고, (3, 7)구간을 업데이트 한다고 가정한다.

빨간색과, 흰색을 제외한 부분은 다 업데이트를 해야하는 부분이며, 하늘색 부분은 기존의 세그먼트 업데이트 방식 그대로 업데이트한다. 3-4노드, 5-7노드는 자식노드가 모두 업데이트가 필요하기때문에 해당노드에 Lazy Value 값을 갱신시킨뒤, 자식노드들을 업데이트 하지 않는다.

느리게 갱신되는 세그먼트 트리는 기존세그먼트에서 아래의 기능이 추가된다.

구현

  1. 지연된 노드를 구하는방법

가령, left, right가 주어진다고 하자, 함수의 인자로는 다음과 같이 필요하다.

  • ( start, end, left, right, value )

start와 end는 배열의 크기를 나타내고, left와 right는 범위를 나타내고, value(는 가중치를 나타낸다.

지연된 값을 저장할 노드를 구하는 방법은 start ≤ left and right ≥ end 안에 속하는 값이다. 그 값을 제외한 나머지는 업데이트를 진행 해주어야한다.

  1. 지연된 노드를 방문할때 자식노드에게 전파하는 방법

쿼리나 업데이트를 진행할때, 노드를 방문하게 된다. 해당 노드에 Lazy Value가 있다면, 현재 노드에 값을 계산해준뒤, 자식노드에게 전파 해주어야한다.

현재 노드값을 계산할때는 아래와 같은 식으로 계산한다.

구간합을 구할때 → tree[node] = (end - start + 1 ) * lazy[node] → 구간 갯수만큼 업데이트

자기자신이 리프노드가 아닐경우 자식에게 전파를 해야한다.

코드구현

import sys
input = sys.stdin.readline

N, M, K = map(int, input().split())

tree_size = 4 * N

arr = [int(input()) for _ in range(N)]

tree = [0 for _ in range(tree_size)]
lazy = [0 for _ in range(tree_size)]

def init(start, end, node):
    if start == end:
        tree[node] += arr[start]
    else:
        mid = (start + end) // 2
        init(start, mid, node * 2)
        init(mid+1, end, node * 2 + 1)
        tree[node] = tree[node * 2] + tree[node * 2 + 1]

def update_lazy(start, end, node):
    if lazy[node] != 0:
        tree[node] += lazy[node] * (end - start + 1)
        if start != end:
            lazy[node * 2] += lazy[node]
            lazy[node * 2 + 1] += lazy[node]
        lazy[node] = 0
    
def update(start, end, left, right, node, diff):
    update_lazy(start, end, node)
    if left > end or right < start:
        return
    
    if start >= left and end <= right:
        tree[node] += (end - start + 1) * diff
        if start != end:
            lazy[node * 2] += diff
            lazy[node * 2 + 1] += diff
        return
    mid = ( start + end ) // 2
    update(start, mid, left, right, node * 2, diff)
    update(mid + 1, end, left, right, node * 2 + 1, diff)
    tree[node] = tree[node * 2] + tree[node * 2 + 1]
    
def query(start, end, left, right, node):
    update_lazy(start, end, node)
    if left > end or right < start:
        return 0
    if start >= left and end <= right:
        return tree[node]
    mid = ( start + end ) // 2
    lquery = query(start, mid, left, right, node * 2)
    rquery = query(mid + 1, end, left, right, node * 2 + 1)
    return lquery + rquery
    

init(0, N-1, 1)

for _ in range(M+K):
    temp = list(map(int, input().split()))
    if temp[0] == 1:
        b, c, d = temp[1:]
        update(0, N-1, b-1, c-1, 1, d)
    else:
        b, c = temp[1:]
        print(query(0, N-1, b-1, c-1, 1))
728x90

'자료구조' 카테고리의 다른 글

[자료구조] 트라이  (0) 2024.05.14
[Java] 자료구조  (0) 2024.02.26