알고리즘을 풀다가 분리집합, union-find 관련 문제를 자주 접하여 관련 개념들을 정리하고자 글을 적게 되었다.
1. Union-Find와 Disjoint Set
`Union-Find, Disjoint Set` 두 단어 모두 알고리즘을 풀다 보면 한 번씩 들어보게 되는 용어이다.
처음에는 '2개의 문제 유형이 다른가?' 하고 생각을 했는데 같은 문제로 보면 될 것 같다.
정확히 말하면 그래프 탐색에 dfs, bfs 알고리즘을 사용하듯이, 분리집합에는 Union-Find 알고리즘을 사용한다.
하지만 분리집합을 관리하기 위한 알고리즘으로 Union-Find 알고리즘이 효율이 매우 좋아 다른 알고리즘이 필요하지 않는다고 보면 될 것 같다.
그럼 간략하게나마 각 용어의 개념을 정리해보자.
Disjoint Set
: 서로소 집합이라고 하며, 서로 겹치지 않는 여러 개의 집합들로 구성된 자료구조이다.
즉, 각 집합은 고유한 원소들로 구성되어 있고 원소들은 다른 집합의 원소와 중복되지 않는다.
Union-Find
: 분리집합을 구현하는 알고리즘 기법이다.
Union - 두 집합을 합친다.
Find - 특정 원소가 속한 집합의 root 원소를 찾는다.
2. 문제 예시
*BOJ - 1717. 집합의 표현
1) Union 단계
여기에서의 목적은 "x와 y를 하나의 집합으로 통일하는 것"이다.
따라서 x, y의 root 원소를 찾고 둘 중에 어느 한 root 원소로 통일시켜 주면 된다.
union은 별도의 기준에 맞춰 진행할 수 있는데 2)와 3) 예시를 보면 된다.
2)는 'root 원소의 크기'를 비교하여 union을 하도록 구현했다.
3)은 'rank의 크기'를 비교하여 union을 하도록 구현했다.
이 때 rank는 해당 집합의 크기로, 집합의 크기가 매우 커질 때 하위 원소들의 root가 바뀔 때 해당 연산을 최소화 하도록 할 수 있다.
def union(x, y):
rootX = find(x)
rootY = find(y)
parent[rootX] = rootY
def union2(x, y):
rootX = find(x)
rootY = find(y)
if rootX < rootY:
parent[rootY] = rootX
else:
parent[rootX] = rootY
def union3(x, y):
findX = find(x)
findY = find(y)
if rank[findX] >= rank[findY]:
parents[findY] = findX
rank[findX] += rank[findY]
else:
parents[findX] = findY
rank[findY] += rank[findX]
2) Find 단계
일단 find를 하기 전에 기본 개념에 대해 인지할 필요가 있다.
1) 각 원소에 대한 초기 root 원소는 자기 자신으로 설정한다.
2) root 원소의 부모 원소는 그 자신 root 원소이다.
다시 돌아와서 find의 목적을 살펴보면 "x 원소에 대한 root 원소를 찾는 것"이다.
그럼 부모 노드를 타고 올라가 root 원소를 만났을 때, 즉 2번 조건에 해당했을 때 return 할 수 있게끔 한다.
여기서 주의해야 할 점이 있다.
`return find(parent[x])` 로 바로 root 원소를 return 하는 것이 아니라, '경로 압축'를 하여 union 했을 때 root 노드가 바뀌었던 것을 바로 적용하며 아래와 같이 트리 노드를 최적화해주도록 한다.
def find(x):
if x == parent[x]:
return x
parent[x] = find(parent[x]) # 경로 압축 최적화
return parent[x]
# X
# return find(parent[x])
3) 코드
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)
def find(x):
if x == parent[x]:
return x
parent[x] = find(parent[x]) # 경로 압축 최적화
return parent[x]
def union(x, y):
rootX = find(x)
rootY = find(y)
parent[rootX] = rootY
# m : 입력으로 주어지는 연산의 개수
N, M = map(int, input().split())
# 초기 루트 노드는 자기 자신으로 초기화
parent = [0 for _ in range(N+1)]
for i in range(N+1):
parent[i] = i
for _ in range(M):
op, a, b, = map(int, input().split())
# 합집합 연산
if op == 0:
# 루트 노드가 다르면
if find(a) != find(b):
union(a, b)
# 두 원소가 같은 집합에 포함되어 있는지
# 즉, 같은 루트 노드를 가지고 있는지
else:
if find(a) != find(b):
print("NO")
else:
print("YES")
3. 문제 적용
1) BOJ - 20040. 사이클 게임
[코드]
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)
def union(x, y):
findX = find(x)
findY = find(y)
if rank[findX] >= rank[findY]: # x 중심
parents[findY] = findX
rank[findX] += rank[findY]
else:
parents[findX] = findY
rank[findY] += rank[findX]
def find(x):
if parents[x] == x:
return x
parents[x] = find(parents[x])
return parents[x]
def judgeCycle():
count = 0
for x, y in number:
count += 1
# 두 점 모두 부모가 같다면 사이클이 형성된 것
if find(x) == find(y):
return count
else:
union(x, y)
return 0
# 점의 개수, 진행된 차례의 수
N, M = map(int, input().split())
number = [list(map(int, input().split())) for _ in range(M)]
# 부모는 자기 자신
parents = [i for i in range(N)]
rank = [1 for i in range(N)]
result = judgeCycle()
print(result)
[풀이]
2) BOJ - 4195. 친구 네트워크
[코드]
import sys
input = sys.stdin.readline
T = int(input())
def find(k):
if parent[k] == k:
return k
parent[k] = find(parent[k]) # 부모 갱신, 경로 압축 최적화
return parent[k]
def union(a, b):
rootA = find(a) # a 원소의 최상위 루트 값
rootB = find(b) # b 원소의 최상위 루트 값
if rootA != rootB:
parent[rootB] = rootA # B가 A의 하위로 들어감
number[rootA] += number[rootB] # A 원소 개수 늘어나기
print(number[rootA])
for _ in range(T):
parent = dict() # (name, parent node) 저장
number = dict() # (rootK, cnt) / 루트 원소의 집합의 개수 파악
F = int(input()) # 친구 관계의 수
# 친구 관계는 두 사용자의 아이디로 이루어져 있다.
for z in range(F):
A, B = map(str, input().split())
# 새로운 이름이면 parent 에 추과
# 초기 parent는 자기 자신
if A not in parent:
parent[A] = A
number[A] = 1
if B not in parent:
parent[B] = B
number[B] = 1
union(A, B)
[풀이]
3) BOJ - 20303. 할로윈의 양아치
[코드]
import sys
from collections import defaultdict
input = sys.stdin.readline
def find(x):
if x == parent[x]:
return x
parent[x] = find(parent[x])
return parent[x]
def union(x, y):
rootX = find(x)
rootY = find(y)
parent[rootX] = rootY
# O(M * a(N))
def solve_unionFind():
global result
global parent
result = 0
parent = [i for i in range(N + 1)]
for _ in range(M):
a, b = map(int, input().split())
if find(a) != find(b):
union(a, b)
# union을 진행한 후, 하위 노드에 대한 부모도 같이 변경해주기 위함.
for i in range(1, N + 1):
find(i)
# value로 list를 사용하기 위해 defaultdict을 사용함
# 같은 친구 관계에 있는 것들끼리, 집합을 만들었다.
dicts = defaultdict(list)
for i in range(1, N + 1):
dicts[parent[i]].append(candy[i])
return dicts
# O(분리집합 갯수 * 최대로 훔칠 수 있는 아이들 수)
def solve_dp():
# i번째 집합의 j번째 사탕 갯수에서 가질 수 있는 최댓값 저장을 위한 1차원 dp 테이블을 정의한다
dp = [0] * K
for friends in dicts.values():
w = len(friends) # 친구 몇 명?
v = sum(friends) # 친구 사탕을 모두 뺏을 때 가질 수 있는 사탕은 몇 개?
# 1차원 dp로 해결하기 위해서, 뒤에서부터 dp 값을 갱신한다.
for j in range(K-1, w-1, -1): # 현재 뺏을 수 있는 최대 아이들
# dp[i][j] = max(dp[i - 1][j], dp[i - 1][j - w] + v)
dp[j] = max(dp[j], dp[j - w] + v)
return dp[K-1]
# 아이들 수, 아이들의 친구 관계 수, 울음소리가 공명하기 위한 아이 수
N, M, K = map(int, input().split())
candy = [0] + list(map(int, input().split()))
# 친구관계를 분리 집합으로 파악
dicts = solve_unionFind()
# 4명의 아이들 집합 사탕 vs 2명의 아이들 집합 사탕 + 2명의 아이들 집합 사탕
# 어느 것이 더 많이 훔칠 수 있는지 파악해야 하므로 "배낭문제" dp를 사용한다.
result = solve_dp()
print(result)
[풀이]
'Algorithm > Union Find' 카테고리의 다른 글
[백준] 1717번 집합의 표현 _ Python (1) | 2024.04.12 |
---|