본문 바로가기

알고리즘/백준

[백준] 5676번 : 음주코딩

728x90

5676번: 음주 코딩

 

5676번: 음주 코딩

각 테스트 케이스마다 곱셈 명령의 결과를 한 줄에 모두 출력하면 된다. 출력하는 i번째 문자는 i번째 곱셈 명령의 결과이다. 양수인 경우에는 +, 음수인 경우에는 -, 영인 경우에는 0을 출력한다.

www.acmicpc.net

📄 문제개요

구간곱을 구하는 문제로 쿼리의 개수가 10^5개로 N의 개수가 매우 크다면, 일반적인 쿼리로 문제를 해결 할 수 없습니다.

“팬윅트리” 혹은 **“세그먼트 트리”**를 이용하여 문제를 해결 할 수 있습니다.

수열이 주어졌을때, 해당 곱이 “양수”인지 “음수”인지 “0”인지 판별해야합니다.

🤔 문제분석

세그먼트 트리로 문제 풀기.

세그먼트 트리로 문제를 풀경우 메모리는 최소 2^(log(N)+1), 최대 4N의 배열이 필요합니다.

쿼리와 업데이트는 log(N)으로 해결 할 수 있습니다.

위의 문제는 세그먼트트리로 더 쉽게 문제를 해결 할 수 있습니다. 일반적으로 구간곱을 구하는 문제는 팬윅트리로는 구현하기가 조금 복잡합니다. 팬윅트리는 1~i 까지의 구간곱은 구할 수 있으나, i~j 까지의 구간곱은 구현이 까다롭습니다.

팬윅트리로 문제 풀기.

팬윅트리로 문제를 풀경우에는 메모리는 최소 N+1의 배열이 필요합니다.

쿼리와 업데이트는 세그먼트 트리와 마찬가지로 log(N)으로 해결 할 수 있습니다.

위의 문제는 팬윅트리를 조금 응용하여 문제를 해결 할 수 있습니다. 문제에서 “양수”인지 “음수”인지 “0인지” 것을 판별하는게 중요하기때문에, 사실상 곱의 결과는 필요하지 않습니다. 따라서 곱셈의 성질을 이용하여 문제를 해결 할 수 있습니다.

  • 0의 개수와 음수의 개수를 카운팅하여 0의 개수가 구간에 하나라도 존재한다면 0 이다.
  • 0의 개수가 없고, 음수의 개수가 짝수라면 양의 정수이다. 홀수라면 음의 정수이다.

📝 의사코드

세그먼트 트리로 문제 풀기.

  1. 세그먼트를 트리를 초기화 해준다.
  2. 쿼리와 업데이트를 분기하여 로직을 작성한다.

팬윅트리로 문제 풀기.

  1. 팬윅트리를 초기화 해준다.
  2. 쿼리와 업데이트를 분기하여 로직을 작성한다.

💻 코드

세그먼트 트리로 문제 풀기.

# <https://www.acmicpc.net/problem/5676>
import sys
import math

input = sys.stdin.readline

def parse(number):
    if number > 0:
        return 1
    elif number < 0:
        return -1
    else:
        return 0

def init(start, end, node):
    if start == end:
        tree[node] = parse(arr[start])
    else:
        mid = (start + end) // 2
        tree[node] = init(start, mid, 2*node) * init(mid+1, end, 2*node + 1)
    return tree[node]
 
def query(start, end, node, left, right):
    if left > end or right < start:
        return 1
    if left <= start and end <= right:
        return tree[node]
 
    mid = (start + end) // 2
    return query(start, mid, 2*node, left, right) * query(mid+1, end, 2*node + 1, left, right)
 
def update(start, end, node, where, diff):
    if where < start or end < where:
        return
 
    if start == end:
        tree[node] = parse(diff)
    else:
        mid = (start + end) // 2
        update(start, mid, node*2, where, diff)
        update(mid+1, end, node*2 + 1, where, diff)
        tree[node] = tree[2*node] * tree[2*node + 1]
    
while True:
    try:
        N, K = map(int, input().split())
        arr = list(map(int, input().split()))
        tree = [0] * (1 << (int(math.ceil(math.log2(N))) + 1))
        init(0, N-1, 1)
        ans = []
        for _ in range(K):
            temp = list(map(str, input().split()))
            if temp[0] == 'C':
                a, b = map(int, temp[1:])
                update(0, N-1, 1, a-1, b)
            else:
                a, b = map(int, temp[1:])
                temp = query(0, N-1, 1, a-1, b-1)
                if temp == 1:
                    ans.append('+')
                elif temp == -1:
                    ans.append('-')
                else:
                    ans.append('0')
                
        print(''.join(ans))
    except:
        break

팬윅트리로 문제 풀기.

# <https://www.acmicpc.net/problem/5676>
import sys

input = sys.stdin.readline

def update(node, value, dif):
    def execute(node, tree, dif):
        while node <= N:
            tree[node] += dif
            node += (node & -node)
        
    if value < 0:
        execute(node, minus_fwt, dif)
    elif value == 0:
        execute(node, zero_fwt, dif)
        
def query_cnt(node, tree):
    cnt = 0
    while node >= 1:
        cnt += tree[node]
        node -= (node & -node)

    return cnt

while True:
    try:
        N, K = map(int, input().split())
        arr = list(map(int, input().split()))
        minus_fwt = [0] * (N+1)
        zero_fwt = [0] * (N+1)
        
        for i in range(1, N+1):
            update(i, arr[i-1], 1)
                
        ans = []
        for _ in range(K):
            temp = list(map(str, input().split()))
            if temp[0] == 'C':
                a, b = map(int, temp[1:])
                update(a, arr[a-1], -1)
                update(a, b, 1)
                arr[a-1] = b
            else:
                a, b = map(int, temp[1:])
                a = a-1
                if query_cnt(b, zero_fwt) - query_cnt(a, zero_fwt) > 0:
                    ans.append('0')
                    continue
                
                minus_cnt = query_cnt(b, minus_fwt) - query_cnt(a, minus_fwt)
                
                if minus_cnt % 2 == 0:
                    ans.append('+')
                else:
                    ans.append('-')
                
        print(''.join(ans))
    except Exception as e:
        #print(f"An exception occurred: {e}")
        break

🎯 피드백 및 개선사항

기존에 세그먼트 트리와 팬윅트리를 학습을 했었으나, 구현하는 방법을 까먹어 버렸습니다. 이번기회에 세그먼트 트리와 팬윅트리를 좀더 깊게 알 수 있게 되었습니다.

❓다른사람은 어떻게 풀었을까?

저는 처음에 세그먼트 트리로 문제를 접근하였고, 기존에 팬윅트리로 문제를 푸신분이 존재하여, 그 풀이를 참고하여 제가 따로 팬윅트리 방법으로 문제를 구현해보았습니다. 참고한 코드는 아래와 같습니다.

import sys
input = sys.stdin.readline

def update_query(idx, v):
    idx += 1
    while idx < N + 1:
        fwt[idx] += v
        idx += idx & -idx

def update_zero_query(idx, v):
    idx += 1
    while idx < N + 1:
        zero_fwt[idx] += v
        idx += idx & -idx

def get_query(idx):
    idx += 1
    temp = 0
    while 0 < idx:
        temp += fwt[idx]
        idx &= idx - 1
    return temp

def get_zero_query(idx):
    idx += 1
    temp = 0
    while 0 < idx:
        temp += zero_fwt[idx]
        idx &= idx - 1
    return temp

while True:
    temp = input()
    if len(temp) <= 1:
        break
    N, K = map(int, temp.split())
    arr = list(map(int, input().split()))
    fwt = [0] * (N + 1)
    zero_fwt = [0] * (N + 1)

    for i in range(N):
        if arr[i] < 0:
            update_query(i, 1)
        elif arr[i] == 0:
            update_zero_query(i, 1)

    ans = []
    for i in range(K):
        order, a, b = input().split()
        if order == 'C':
            idx, v = int(a) - 1, int(b)
            if arr[idx] <= 0 and 0 < v:
                if arr[idx] == 0:
                    update_zero_query(idx, -1)
                else:
                    update_query(idx, -1)
                arr[idx] = 1
            elif 0 <= arr[idx] and v < 0:
                if arr[idx] == 0:
                    update_zero_query(idx, -1)
                else:
                    update_query(idx, 1)
                arr[idx] = -1
            elif v == 0:
                if arr[idx] < 0:
                    update_query(idx, -1)
                if arr[idx] != 0:
                    update_zero_query(idx, 1)
                arr[idx] = 0
        else:
            a, b = int(a) - 2, int(b) - 1
            t1, t2 = get_zero_query(a), get_zero_query(b)
            if t2 - t1 > 0:
                ans.append("0")
                continue
            t1, t2 = get_query(a), get_query(b)
            if (t2 - t1) % 2 == 1:
                ans.append("-")
            else:
                ans.append("+")

    print("".join(ans))
728x90