ULISIA's Developer Life

[Backjoon] 1197 최소 스패닝 트리 Python 문제 풀이 본문

Algorithm/Backjoon

[Backjoon] 1197 최소 스패닝 트리 Python 문제 풀이

ULISIA 2023. 7. 18. 16:45

출처:1197번 최소 스패닝 트리

문제해석

최소 신장 트리 (minimum spanning tree)

전체 요소들을 연결할때 사용한다. Kruskal 알고리즘, Prim 알고리즘이 있다.

Kruskal

  1. 간선들을 정렬
  2. 간선이 잇는 두 정점의 root를 찾는다.
  3. 다르다면 하나의 root를 바꾸어 연결 시켜준다.

Prim

  1. 임의의 정점을 선택
  2. 해당 정점에서 갈 수 있는 간선을 minheap(최소힙)에 넣는다.
  3. 최소값을 뽑아 해당 정점을 방문하지 않았다면 선택한다.

Kruskal 알고리즘은 간선들을 정렬해야하기 때문에 간선이 적다면 Kruskal, 많다면 Prim을 선택한다.

비슷한 알고리즘으로 다익스트라 알고리즘이 있다.
다익스트라는 전체 요소를 연결하는 것이 아닌 한 정점에서 다른 정점으로 가는 가장 짧은 방법을 구할 때 사용한다.

다익스트라

  1. 최단 거리 배열을 무한대로 초기화한다. 방문여부 배열은 False로 초기화 한다.
  2. 출발 노드는 방문했다고 체크 후 heap에 넣는다.
  3. 아직 방문하지 않은 노드들 중, 최단 거리 테이블 값이 가장 작은 노드를 선택한다.
  4. 저장된 최단거리 값과 현재노드에 가중치를 더한 거리 값 중 더 작은 값으로 update한다.

해답코드

1. Kruskal

  1. root를 저장하는 vRoot 배열 생성. (여기서 root는 연결요소 중 가장 작은 값, 처음에는 자기 자신을 저장)
  2. 간선들(eList)를 가중치(gain) 기준으로 정렬.
  3. 간선들이 이은 두 정점을 find 함수를 통해 두 root(sRoot, eRoot)를 찾는다.
  4. 두 root가 다르다면 큰 root값을 작은 root값으로 만들어 연결되게 해준다.
  5. 가중치를 더한다.
# 최소 스패닝 트리

#1. Kruskal
import sys
input = sys.stdin.readline

v, e = map(int, input().split())
vRoot = [i for i in range(v+1)]
eList = []
for _ in range(e):
    eList.append(list(map(int,input().split())))

eList.sort(key=lambda x: x[2])

def find(x):
    if x != vRoot[x]:
        vRoot[x] = find(vRoot[x])
    return vRoot[x]

ans=0

for start, end, gain in eList:
    startRoot = find(start)
    endRoot = find(end)
    if startRoot != endRoot:
        if startRoot>endRoot:
            vRoot[startRoot] = endRoot
        else:
            vRoot[endRoot] = startRoot
        ans += gain

print(ans)

2. Prim

  1. isVisited : 방문여부 확인
  2. eList : 간선 저장
  3. heap : 현재 그래프에서 짧은 경로를 선택
  • 현재 그래프에서 가장 짧은 간선을 골라 방문하지 않은 정점이라면 선택한다.
#2. Prim
import sys
import heapq
input = sys.stdin.readline

v, e = map(int, input().split())
isVisited = [False]*(v+1)
eList = [[] for _ in range(v+1)]
heap = [[0, 1]]

for _ in range(e):
    start, end, gain = map(int, input().split())
    eList[start].append([gain,end])
    eList[end].append([gain,start])

ans = 0
cnt = 0
while heap:
    if cnt == v:
        break
    gain, start = heapq.heappop(heap)
    if not isVisited[start]:
        isVisited[start] = True
        ans += gain
        cnt += 1
        for i in eList[start]:
            heapq.heappush(heap, i)

print(ans)