https://www.acmicpc.net/problem/1753
위 문제는 기본적인 다익스트라 문제이다.
그런데 다음과 같이 풀면 틀린다.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
public class Main {
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static StringTokenizer st;
public static void main(String[] args) throws IOException {
st = new StringTokenizer(br.readLine());
int V = Integer.parseInt(st.nextToken());
int E = Integer.parseInt(st.nextToken());
int K = Integer.parseInt(br.readLine());
List<List<Node>> list = new ArrayList<>();
for (int i = 0; i <= V; i++) {
list.add(new ArrayList<>());
}
for (int i = 0; i < E; i++) {
st = new StringTokenizer(br.readLine());
int start = Integer.parseInt(st.nextToken());
int end = Integer.parseInt(st.nextToken());
int cost = Integer.parseInt(st.nextToken());
list.get(start).add(new Node(end, cost));
}
int[] dist = new int[V + 1];
Arrays.fill(dist, Integer.MAX_VALUE);
boolean[] visited = new boolean[V + 1];
Queue<Node> que = new PriorityQueue<>((o1, o2) -> o1.cost - o2.cost);
que.add(new Node(K, 0));
dist[K] = 0;
visited[K] = true; // (1) 시작점 방문처리
while (!que.isEmpty()) {
Node now = que.poll();
for (Node next : list.get(now.end)) {
if (visited[next.end]) continue; // (2) 방문했으면 패스
if (dist[next.end] > dist[now.end] + next.cost) {
dist[next.end] = dist[now.end] + next.cost;
que.add(new Node(next.end, dist[next.end]));
visited[next.end] = true; // (3) 방문처리
}
}
}
for (int i = 1; i <= V; i++) {
System.out.println((dist[i] == Integer.MAX_VALUE) ? "INF" : dist[i]);
}
}
private static class Node {
int end;
int cost;
public Node(int end, int cost) {
this.end = end;
this.cost = cost;
}
}
}
문제는 방문처리 시점이다.
while (!que.isEmpty()) {
Node now = que.poll();
for (Node next : list.get(now.end)) {
if (visited[next.end]) continue; // (2) 방문했으면 패스
if (dist[next.end] > dist[now.end] + next.cost) {
dist[next.end] = dist[now.end] + next.cost;
que.add(new Node(next.end, dist[next.end]));
visited[next.end] = true; // (3) 방문처리
}
}
}
예를 들어, 노드와 간선이 다음과 같다고 가정할 때, 1번 노드에서 3번 노드까지의 최단 거리는 3이다.
그런데 위 코드와 같이 방문처리를 하면 1번 노드에서 3번 노드까지의 최단 거리가 4로 계산된다.
- (1)에서 시작 노드인 1을 방문처리한다.
- Node(1, 0)를 poll하고
- next가 Node(2, 2)일 때, dist[2]를 2로 갱신한다. 그리고 2번 노드를 방문처리한다(3).
- 이어서 next가 Node(3, 4)일 때, dist[3]을 4로 갱신하고, 3번 노드를 방문처리한다(3).
- 이어서 Node(2, 2)를 poll하고
- next가 Node(3, 1)일 때, 3번 노드에 대해 방문처리가 되어 있기 때문에 (2)번 부분에서 패스된다.
- 따라서 1번 노드에서 3번 노드까지의 거리가 3이 최단 거리임에도 이미 4의 거리로 방문했기 때문에 패스된다.
즉, a에서 k로 가는 경로를 반영하는 시점에(que.add()할 때) k에 방문처리를 하면 안된다. 왜냐하면 그 다음 노드인 b에서 k로 가는 비용이 a에서 k로 가는 비용보다 적을 수 있기 때문이다. 만약 a에서 k로 가는 경로를 반영하는 시점에 k에 방문처리를 하면, b에서 k로 가는 비용이 더 적은데 이때 visited[k]가 true이기 때문에 b에서 k로 가는 경로는 반영되지 않는다.
while (!que.isEmpty()) {
Node now = que.poll();
for (Node next : list.get(now.end)) {
// (2) visited[k]가 true이기 때문에 패스되어 b에서 k로 가는 경로는 반영되지 않음
if (visited[next.end]) continue;
if (dist[next.end] > dist[now.end] + next.cost) {
dist[next.end] = dist[now.end] + next.cost;
que.add(new Node(next.end, dist[next.end]));
// (1) a에서 k로 가는 경로를 반영할 때 k에 방문처리
visited[next.end] = true;
}
}
}
따라서 k에 대한 방문처리는 특정 노드에서 k로 가는 경로를 반영하는 시점이 아닌, k에서 출발하는 시점에 해줘야 한다.
while (!que.isEmpty()) {
Node now = que.poll();
if (visited[now.end]) continue; // 방문했으면 패스
visited[now.end] = true; // 방문처리
for (Node next : list.get(now.end)) {
if (dist[next.end] > dist[now.end] + next.cost) {
dist[next.end] = dist[now.end] + next.cost;
que.add(new Node(next.end, dist[next.end]));
}
}
}
그리고 이때 que.poll()한 시점에 방문처리를 해도 되는 이유는 큐가 우선순위 큐이기 때문이다. 큐에는 특정 노드(end)로 가는 비용(cost)를 담고 있다. 그리고 cost가 적은 순으로 poll된다. 따라서 시작 노드인 1번에서 3번 노드로 가는 경로가 다음과 같이 두개가 있을 때
- 1 → 2 → 3 (총 비용: 3)
- 1 → 3 (총 비용: 4)
큐에는 Node(3, 3)과 Node(3, 4)가 들어있다. 그리고 우선순위 큐이기 때문에 Node(3, 3)이 먼저 poll된다. 이때 3번 노드에 대해 방문처리를 하면, 이후에 Node(3, 4)를 poll할 때는 이미 3번 노드를 방문했기 때문에 poll한 후 continue된다.
그런데 우선순위 큐로 3번 노드로 가는데 더 적은 비용이 드는 경로인 Node(3, 3)을 먼저 poll하고 처리해줬기 때문에 그 이후에 3번 노드로 가는데 더 큰 비용이 드는 경로인 Node(3, 4)는 continue되어 처리되지 않아도 된다.
전체 코드는 다음과 같다.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
public class Main {
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static StringTokenizer st;
public static void main(String[] args) throws IOException {
st = new StringTokenizer(br.readLine());
int V = Integer.parseInt(st.nextToken());
int E = Integer.parseInt(st.nextToken());
int K = Integer.parseInt(br.readLine());
List<List<Node>> list = new ArrayList<>();
for (int i = 0; i <= V; i++) {
list.add(new ArrayList<>());
}
for (int i = 0; i < E; i++) {
st = new StringTokenizer(br.readLine());
int start = Integer.parseInt(st.nextToken());
int end = Integer.parseInt(st.nextToken());
int cost = Integer.parseInt(st.nextToken());
list.get(start).add(new Node(end, cost));
}
int[] dist = new int[V + 1];
Arrays.fill(dist, Integer.MAX_VALUE);
boolean[] visited = new boolean[V + 1];
Queue<Node> que = new PriorityQueue<>((o1, o2) -> o1.cost - o2.cost);
que.add(new Node(K, 0));
dist[K] = 0;
while (!que.isEmpty()) {
Node now = que.poll();
if (visited[now.end]) continue;
visited[now.end] = true;
for (Node next : list.get(now.end)) {
if (visited[next.end]) continue;
if (dist[next.end] > dist[now.end] + next.cost) {
dist[next.end] = dist[now.end] + next.cost;
que.add(new Node(next.end, dist[next.end]));
}
}
}
for (int i = 1; i <= V; i++) {
System.out.println((dist[i] == Integer.MAX_VALUE) ? "INF" : dist[i]);
}
}
private static class Node {
int end;
int cost;
public Node(int end, int cost) {
this.end = end;
this.cost = cost;
}
}
}
'CS > Algorism' 카테고리의 다른 글
[백준] 5567. 결혼식 (0) | 2025.02.11 |
---|---|
[백준] 1058. 친구 (0) | 2025.01.10 |
[프로그래머스] 등대 (1) | 2024.10.18 |
다익스트라 알고리즘 - 우선순위 큐와 방문 처리에 대해 (0) | 2024.09.21 |
[백준] 10713. 기차 여행 (0) | 2024.07.05 |