728x90

http://www.acmicpc.net/problem/1197

 

1197번: 최소 스패닝 트리

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이

www.acmicpc.net

 

# 조건

  • 그래프가 주어졌을 때, 그래프의 최소 스패닝 트리를 구하시오
  • 최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중 가중치의 합이 최소인 트리를 말한다.
 

# 접근 방법

  • 최소 스패닝 트리 = 최소 신장 트리이다.
  • prim 또는 kruskal 알고리즘을 통해 풀어준다.

  • prim 알고리즘인 경우 하나의 정점에서 연결된 간선들 중 하나씩 선택하면서 만들어 가는 방식이다.
  • 최소 비용의 간선이 존재하는 정점을 선택하며
  • 모든 정점이 선택될 때까지 위의 과정을 반복한다.

  • kruskal의 경우
  • 최초, 모든 간선을 가중치에 따라 오름차순으로 정렬한 후
  • 가중치가 가장 낮은 간선부터 선택하면서 트리를 증가시킨다.
  • 이 후 사이클이 존재하면 다음으로 가중치가 낮은 간선을 선택한다
  • N-1개의 간선이 선택될 때까지 위를 반복
 

kruskal 풀이

 
1. -> 시간초과
import sys  
sys.stdin = open('input.txt')  
input = sys.stdin.readline  
  
  
def find_set(x):  
    while x!=rep[x]:  
        x = rep[x]  
    return x  
  
def union(x, y):  
    rep[find_set(y)] = find_set(x)  
  
  
V, E = map(int, input().split())  
  
graph = []  
  
for _ in range(E):  
    a, b, c = map(int, input().split())  
    graph.append([c,a,b])  
  
# 가중치 기준 오름차순 정렬  
graph.sort()  
  
# 대표원소 배열  
rep = [i for i in range(V+1)]  
  
# MST의 간선수 N = 정점수 -1N = V+1  
cnt = 0  
total = 0  
for w, v, u in graph:  
    if find_set(v) != find_set(u):  
        cnt += 1  
        union(u, v)  
        total += w  
        if cnt == N-1:  
            break  
print(total)
 

2. pass 코드

  • find 함수에서 경로를 압축시켜 주니 통과가 잘 되었다.
  • Kruskal 알고리즘만 알고 있더라도 잘 풀리는 문제!
import sys  
input = sys.stdin.readline  
  
  
def find_set(x):  
    if x == rep[x]:  
        return x  
    rep[x] = find_set(rep[x])  
    return rep[x]  
  
  
def union(x, y):  
    rep[find_set(y)] = find_set(x)  
  
  
V, E = map(int, input().split())  
  
graph = []  
  
for _ in range(E):  
    a, b, c = map(int, input().split())  
    graph.append([c,a,b])  
  
# 가중치 기준 오름차순 정렬  
graph.sort()  
  
# 대표원소 배열  
rep = [i for i in range(V+1)]  
  
# MST의 간선수 N = 정점수 -1N = V+1  
cnt = 0  
total = 0  
for w, v, u in graph:  
    if find_set(v) != find_set(u):  
        cnt += 1  
        union(u, v)  
        total += w  
print(total)
 

다른 사람 풀이

-> bfs + que

import sys, heapq
input=sys.stdin.readline
V, E=map(int, input().split())  #정점, 간선 개수
visited=[0]*(V+1)
lst=[[] for _ in range(V+1)]
ans=0

for e in range(E) :
    A, B, C=map(int, input().split())   # 노드, 노드, 가중치
    lst[A].append((C, B))
    lst[B].append((C, A))

que=[(0,1)]     
while que :
    node=heapq.heappop(que)
    if not visited[node[1]] :
        visited[node[1]]=1
        ans+=node[0]
        for n in lst[node[1]] :
            heapq.heappush(que, n)

print(ans)

-> prim 알고리즘

  • 인덱스 = 노드, 값 = (연결된 노드, 가중치) 인 리스트 G 생성
  • heap을 사용해서 가중치가 작은 순서대로 pop 한다.
  • 합에 해당 노드와 연결된 노드를 push 한다.
# https://velog.io/@jajubal/%ED%8C%8C%EC%9D%B4%EC%8D%AC%EB%B0%B1%EC%A4%80-1197-%EC%B5%9C%EC%86%8C-%EC%8A%A4%ED%8C%A8%EB%8B%9D-%ED%8A%B8%EB%A6%AC

from heapq import heappush, heappop
import sys
input=sys.stdin.readline

def prim(start, weight) :
    visit=[0]*(V+1)     # 정점 방문 처리
    q=[[weight, start]]     # 힙 구조를 사용하기 위해 가중치를 앞에 둠
    ans=0   # 가중치 합
    cnt=0   # 간선의 개수
    while cnt < V :     # 간선의 개수 최대는 V-1
        k, v=heappop(q)
        if visit[v] : continue # 이미 방문한 정점이면 지나감
        visit[v]=1      # 방문안했으면 방문처리
        ans+=k      # 해당 정점까지의 가중치를 더해줌
        cnt+=1      # 간선의 갯수 더해줌
        for u, w in G[v] :  # 해당 정점의 간선정보를 불러옴
            heappush(q, [w, u])     # 힙에 넣어줌
    return ans

V, E=map(int, input().split())
G=[[] for _ in range(V+1)]
for _ in range(E) :
    u, v, w=map(int, input().split())
    G[u].append([v, w])
    G[v].append([u, w])

print(prim(1, 0))
728x90