문제
1부터 N까지의 번호를 가진 N명의 사람이 있다. 각 사람들은 1부터 N 사이의 임의의 수 Ci가 쓰여있는 카드를 한 장씩 가지고 있다.
사람들 간에는 총 M쌍의 친구 관계가 있다. 모든 친구 관계는 양방향이라서, a번 사람과 b번 사람이 친구라면 b번 사람과 a번 사람도 친구이다. 서로 친구 관계에 있는 두 사람끼리는 서로 들고 있는 카드를 원하는 만큼 교환할 수 있다.
모든 사람들은 각자가 들고 있는 카드에 적힌 수가 자신의 번호와 최대한 비슷하기를 원한다. 어떤 한 사람의 불만족도를 그 사람이 들고 있는 카드에 적힌 수와 그 사람의 번호와의 차이로 정의하고, 전체 불만족도는 모든 사람의 불만족도의 합으로 정의한다.
카드 교환이 적절하게 이루어졌을 때, 가능한 전체 불만족도의 최솟값을 구하라.
입력 조건 등은 문제 링크에서 확인할 수 있다.
분석
문제에서 눈에 띄는 키워드는 친구 관계와 불만족도이다.
우선 친구 관계에서 가장 중요한 부분은 서로 친구인 두 사람끼리는 카드를 원하는 만큼 교환할 수 있다는 말이다. 이는 즉 a와 b가 친구, b와 c가 친구, c와 d가 친구라면 a, b, c, d끼리는 마음대로 카드를 교환할 수 있다는 뜻이다. 이 문장을 보고 바로 든 아이디어는 사람들을 친구 그룹으로 묶는 것이다. 이렇게 되면 그룹 내의 사람들과는 카드를 마음대로 교환할 수 있게 된다.
그 다음 단계는 불만족도를 계산하는 단계이다. 문제에서 불만족도는 |(본인 번호) - (카드 번호)|
로 정의하고 있다. 위 단계에서 이미 그룹화를 마쳤으므로, 각 그룹에서 불만족도의 합을 최소화하는 방법을 찾으면 된다.
풀이
필자의 개인적인 풀이이며, 더 좋은 풀이가 있을 수 있음을 참고 바란다.
그룹화
이 부분은 다양한 방법으로 구현할 수 있지만, 같은 친구 그룹에 속하는 지를 빠르게 확인하기 위해 Union-Find를 사용하였다.
친구 관계를 입력받을 때마다, union을 진행해 바로 바로 집합을 업데이트하면 빠르게 그룹화를 할 수 있다.
필자는 Union-Find를 마친 후 반복문을 돌며 같은 root를 가진 사람들을 리스트로 저장해두었다.
불만족도 계산
이전에 백준 2012번과 같은 문제를 풀어보았다면 감이 오겠지만, 그리디로 쉽게 해결할 수 있다.
작은 번호를 가진 사람이 작은 숫자의 카드를 가질 수 있도록, 카드를 정렬한 후 할당하면 된다.
증명은 위 백준 문제의 질문 게시판에 정리해놓으신 분이 계셔 링크를 남겨놓는다.
시간 복잡도
위 방식대로 구현하면 최악의 경우(모두가 친구 관계로 이어져있을 때)의 시간 복잡도는 정렬에 의해 O(NlogN)이 된다.
소스 코드
import sys
from collections import defaultdict
from functools import reduce
input = sys.stdin.readline
sys.setrecursionlimit(200001)
def find(target):
if tree[target] == target:
return target
tree[target] = find(tree[target])
return tree[target]
def union(u, v):
u = find(u)
v = find(v)
tree[min(u, v)] = max(u, v)
n, m = map(int, input().split())
arr = list(map(int, input().split()))
tree = [i for i in range(n+1)]
roots = defaultdict(list)
cards = defaultdict(list)
# 그룹 만들기
for _ in range(m):
a, b = map(int, input().split())
union(a, b)
for i in range(1,n+1):
root = find(i)
roots[root].append(i)
cards[root].append(arr[i-1])
# 불만족도 계산
result = 0
for root in roots.keys():
group = sorted(roots[root])
card = sorted(cards[root])
for i in range(len(group)):
result += abs(card[i] - group[i])
print(result)