본문 바로가기
Algorithm/Union Find

Union-Find 알고리즘 (feat. Disjoint Set)

by wch_t 2024. 6. 27.

알고리즘을 풀다가 분리집합, union-find 관련 문제를 자주 접하여 관련 개념들을 정리하고자 글을 적게 되었다.

 


 

1. Union-Find와 Disjoint Set

`Union-Find, Disjoint Set` 두 단어 모두 알고리즘을 풀다 보면 한 번씩 들어보게 되는 용어이다.

 

처음에는 '2개의 문제 유형이 다른가?' 하고 생각을 했는데 같은 문제로 보면 될 것 같다.

정확히 말하면 그래프 탐색에 dfs, bfs 알고리즘을 사용하듯이, 분리집합에는 Union-Find 알고리즘을 사용한다.

하지만 분리집합을 관리하기 위한 알고리즘으로 Union-Find 알고리즘이 효율이 매우 좋아 다른 알고리즘이 필요하지 않는다고 보면 될 것 같다.

 

그럼 간략하게나마 각 용어의 개념을 정리해보자.

 

Disjoint Set

: 서로소 집합이라고 하며, 서로 겹치지 않는 여러 개의 집합들로 구성된 자료구조이다.

즉, 각 집합은 고유한 원소들로 구성되어 있고 원소들은 다른 집합의 원소와 중복되지 않는다.

 

Union-Find

: 분리집합을 구현하는 알고리즘 기법이다.

 

Union - 두 집합을 합친다.

Find - 특정 원소가 속한 집합의 root 원소를 찾는다.

 

 


 

2. 문제 예시

*BOJ - 1717. 집합의 표현

 

1) Union 단계

여기에서의 목적은 "x와 y를 하나의 집합으로 통일하는 것"이다.

따라서 x, y의 root 원소를 찾고 둘 중에 어느 한 root 원소로 통일시켜 주면 된다.

 

union은 별도의 기준에 맞춰 진행할 수 있는데 2)와 3) 예시를 보면 된다.

2)는 'root 원소의 크기'를 비교하여 union을 하도록 구현했다.

3)은 'rank의 크기'를 비교하여 union을 하도록 구현했다.

이 때 rank는 해당 집합의 크기로, 집합의 크기가 매우 커질 때 하위 원소들의 root가 바뀔 때 해당 연산을 최소화 하도록 할 수 있다.

def union(x, y):
    rootX = find(x)
    rootY = find(y)
	
    parent[rootX] = rootY
    


def union2(x, y):
    rootX = find(x)
    rootY = find(y)
	
    if rootX < rootY:
        parent[rootY] = rootX
    else:
        parent[rootX] = rootY



def union3(x, y):
    findX = find(x)
    findY = find(y)

    if rank[findX] >= rank[findY]:
        parents[findY] = findX
        rank[findX] += rank[findY]
    else:
        parents[findX] = findY
        rank[findY] += rank[findX]

 

 

 

2) Find 단계

일단 find를 하기 전에 기본 개념에 대해 인지할 필요가 있다.

1) 각 원소에 대한 초기 root 원소는 자기 자신으로 설정한다.

2) root 원소의 부모 원소는 그 자신 root 원소이다.

 

다시 돌아와서 find의 목적을 살펴보면 "x 원소에 대한 root 원소를 찾는 것"이다.

그럼 부모 노드를 타고 올라가 root 원소를 만났을 때, 즉 2번 조건에 해당했을 때 return 할 수 있게끔 한다.

 

여기서 주의해야 할 점이 있다.

`return find(parent[x])` 로 바로 root 원소를 return 하는 것이 아니라, '경로 압축'를 하여 union 했을 때 root 노드가 바뀌었던 것을 바로 적용하며 아래와 같이 트리 노드를 최적화해주도록 한다.

 

 

def find(x):
    if x == parent[x]:
        return x

    parent[x] = find(parent[x]) # 경로 압축 최적화
    return parent[x]
    
    # X
    # return find(parent[x])

 

3) 코드

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)

def find(x):
    if x == parent[x]:
        return x

    parent[x] = find(parent[x]) # 경로 압축 최적화
    return parent[x]

def union(x, y):
    rootX = find(x)
    rootY = find(y)
	
    parent[rootX] = rootY


# m : 입력으로 주어지는 연산의 개수
N, M = map(int, input().split())

# 초기 루트 노드는 자기 자신으로 초기화
parent = [0 for _ in range(N+1)]
for i in range(N+1):
    parent[i] = i

for _ in range(M):
    op, a, b, = map(int, input().split())

    # 합집합 연산
    if op == 0:
        # 루트 노드가 다르면
        if find(a) != find(b):
            union(a, b)

    # 두 원소가 같은 집합에 포함되어 있는지
    # 즉, 같은 루트 노드를 가지고 있는지
    else:
        if find(a) != find(b):
            print("NO")
        else:
            print("YES")

 

 


 

 

3. 문제 적용

1) BOJ - 20040. 사이클 게임

 

[코드]

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)

def union(x, y):
    findX = find(x)
    findY = find(y)

    if rank[findX] >= rank[findY]: # x 중심
        parents[findY] = findX
        rank[findX] += rank[findY]
    else:
        parents[findX] = findY
        rank[findY] += rank[findX]


def find(x):
    if parents[x] == x:
        return x

    parents[x] = find(parents[x])
    return parents[x]


def judgeCycle():
    count = 0
    for x, y in number:
        count += 1

        # 두 점 모두 부모가 같다면 사이클이 형성된 것
        if find(x) == find(y):
            return count

        else:
            union(x, y)

    return 0

# 점의 개수, 진행된 차례의 수
N, M = map(int, input().split())
number = [list(map(int, input().split())) for _ in range(M)]

# 부모는 자기 자신
parents = [i for i in range(N)]
rank = [1 for i in range(N)]

result = judgeCycle()
print(result)

 

[풀이]

 

 

 


 

 

2) BOJ - 4195. 친구 네트워크

 

[코드]

import sys
input = sys.stdin.readline

T = int(input())
def find(k):
    if parent[k] == k:
        return k

    parent[k] = find(parent[k]) # 부모 갱신, 경로 압축 최적화
    return parent[k]


def union(a, b):
    rootA = find(a) # a 원소의 최상위 루트 값
    rootB = find(b) # b 원소의 최상위 루트 값

    if rootA != rootB:
        parent[rootB] = rootA # B가 A의 하위로 들어감
        number[rootA] += number[rootB] # A 원소 개수 늘어나기

    print(number[rootA])


for _ in range(T):
    parent = dict() # (name, parent node) 저장
    number = dict() # (rootK, cnt) / 루트 원소의 집합의 개수 파악

    F = int(input()) # 친구 관계의 수

    # 친구 관계는 두 사용자의 아이디로 이루어져 있다.
    for z in range(F):
        A, B = map(str, input().split())

        # 새로운 이름이면 parent 에 추과
        # 초기 parent는 자기 자신
        if A not in parent:
            parent[A] = A
            number[A] = 1
            
        if B not in parent:
            parent[B] = B
            number[B] = 1

        union(A, B)

 

[풀이]

 

 


 

 

3) BOJ - 20303. 할로윈의 양아치

 

[코드]

import sys
from collections import defaultdict
input = sys.stdin.readline

def find(x):
    if x == parent[x]:
        return x

    parent[x] = find(parent[x])
    return parent[x]


def union(x, y):
    rootX = find(x)
    rootY = find(y)

    parent[rootX] = rootY


# O(M * a(N))
def solve_unionFind():
    global result
    global parent

    result = 0
    parent = [i for i in range(N + 1)]
    for _ in range(M):
        a, b = map(int, input().split())

        if find(a) != find(b):
            union(a, b)

    # union을 진행한 후, 하위 노드에 대한 부모도 같이 변경해주기 위함.
    for i in range(1, N + 1):
        find(i)

    # value로 list를 사용하기 위해 defaultdict을 사용함
    # 같은 친구 관계에 있는 것들끼리, 집합을 만들었다.
    dicts = defaultdict(list)
    for i in range(1, N + 1):
        dicts[parent[i]].append(candy[i])

    return dicts


# O(분리집합 갯수 * 최대로 훔칠 수 있는 아이들 수)
def solve_dp():

    # i번째 집합의 j번째 사탕 갯수에서 가질 수 있는 최댓값 저장을 위한 1차원 dp 테이블을 정의한다
    dp = [0] * K

    for friends in dicts.values():
        w = len(friends)  # 친구 몇 명?
        v = sum(friends)  # 친구 사탕을 모두 뺏을 때 가질 수 있는 사탕은 몇 개?

        # 1차원 dp로 해결하기 위해서, 뒤에서부터 dp 값을 갱신한다.
        for j in range(K-1, w-1, -1):  # 현재 뺏을 수 있는 최대 아이들
            # dp[i][j] = max(dp[i - 1][j], dp[i - 1][j - w] + v)
            dp[j] = max(dp[j], dp[j - w] + v)

    return dp[K-1]


# 아이들 수, 아이들의 친구 관계 수, 울음소리가 공명하기 위한 아이 수
N, M, K = map(int, input().split())
candy = [0] + list(map(int, input().split()))

# 친구관계를 분리 집합으로 파악
dicts = solve_unionFind()

# 4명의 아이들 집합 사탕 vs 2명의 아이들 집합 사탕 + 2명의 아이들 집합 사탕
# 어느 것이 더 많이 훔칠 수 있는지 파악해야 하므로 "배낭문제" dp를 사용한다.
result = solve_dp()

print(result)

 

[풀이]

'Algorithm > Union Find' 카테고리의 다른 글

[백준] 1717번 집합의 표현 _ Python  (1) 2024.04.12