728x90
11505번: 구간 곱 구하기
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 곱을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄
www.acmicpc.net
📄 문제개요
수열이 주어지고, 업데이트, 쿼리가 주어졌을때 쿼리를 했을때에 업데이트된 구간 곱을 출력하는 문제이다.
🤔 문제분석
세그먼트 트리를 사용하는 구간곱 문제이고, 기존 구간 합과는 0을 곱해버리면 트리가 다음 업데이트일때 어떠한 수를 곱해도 0이 되기때문에 diff를 사용하지 않는다.
update함수는 leaf 노드까지 방문했다가, leaf노드를 해당 값으로 초기화 시킨뒤, 트리를 모두 다 갱신해준다.
query함수는 구간합을 구하는 문제와 마찬가지로 3가지 상태로 분기시킨다.
- 구간의 범위를 넘어서는경우 → return 1
- 구간의 범위안에 포함되는 경우 → return tree[node]
- 걸치는 경우 → node << 1 과 node << 1 | 1 연산을 통하여 자식노드들을 호출한다.
📝 의사코드
- 세그먼트 트리 초기화
- 쿼리와, 업데이트 분기
- 결과 출력
💻 코드
# <https://www.acmicpc.net/problem/11505>
import sys
DIV = 1000000007
input = sys.stdin.readline
N, M, K = map(int, input().split())
arr = [int(input()) for _ in range(N)]
tree = [0 for _ in range(4*N)]
zero_tree = [False for _ in range(4*N)]
def init(start, end, node):
if start == end:
tree[node] = arr[start]
return tree[node]
else:
mid = ( start + end ) // 2
tree[node] = init(start, mid, node * 2) * init(mid + 1, end, node * 2 + 1) % DIV
return tree[node]
def query(start, end, node, left, right):
if left > end or right < start:
return 1
if left <= start and right >= end:
return tree[node]
mid = ( start + end ) // 2
lmul = query(start, mid, node * 2, left, right)
rmul = query(mid + 1, end, node * 2 + 1, left, right)
return lmul * rmul % DIV
def update(start, end, node, idx, diff):
if idx < start or idx > end:
return
if start == end:
tree[node] = diff
return tree[node]
mid = ( start + end ) // 2
update(start, mid, node * 2, idx, diff)
update(mid + 1, end, node * 2 + 1, idx, diff)
tree[node] = tree[node*2] * tree[node*2+1] % DIV
return tree[node]
init(0, N-1, 1)
for _ in range(M+K):
a, b, c = map(int, input().split())
if a == 1:
update(0, N-1, 1, b-1, c)
arr[b-1] = c
else:
print(query(0, N-1, 1, b-1, c-1))
🎯 피드백 및 개선사항
초기에는 트리를 하나 더 만들어서 0일경우의 예외사항을 처리하려고 하다보니, 문제가 너무 복잡해졌다.
구간 합처럼 diff 값으로 문제를 해결 하는것이 아닌, leaf노드까지 방문하여 업데이트하고, 각각의 부모노드들을 다시 재 업데이트 한다.
❓다른사람은 어떻게 풀었을까?
아래와 같이 문제를 해결 하신 분이 있었습니다.
from sys import stdin
input = stdin.readline
SIZE = 10**9 + 7
def solve():
n, nUpdate, nQuery = map(int, input().split())
tree = [0] * (2 * n)
for i in range(n):
tree[n + i] = int(input())
for i in range(n - 1, 0, -1):
tree[i] = tree[2*i] * tree[2*i + 1] % SIZE
for _ in range(nUpdate + nQuery):
query, *temp = map(int, input().split())
if query == 1:
node = n + (temp[0] - 1)
tree[node] = temp[1]
node //= 2
while node:
tree[node] = tree[2*node] * tree[2*node + 1] % SIZE
node //= 2
else:
left = n + (temp[0] - 1)
right = n + (temp[1] - 1)
value = 1
while left <= right:
if left % 2:
value = value * tree[left] % SIZE
left += 1
left //= 2
if not right % 2:
value = value * tree[right] % SIZE
right -= 1
right //= 2
print(value)
solve()
728x90
'알고리즘 > 백준' 카테고리의 다른 글
[백준] 6549번 : 히스토그램에서 가장 큰 직사각형 (1) | 2024.01.07 |
---|---|
[백준] 7578번 : 공장 (1) | 2024.01.06 |
[백준] 10999번 : 구간 합 구하기2 (1) | 2024.01.03 |
[백준] 5676번 : 음주코딩 (1) | 2024.01.01 |
[백준] 18500번 : 미네랄 2 (1) | 2023.12.30 |