본문 바로가기
백준

[JAVA] 백준 1504 - 특정한 최단경로

by 맴썰 2025. 8. 12.

https://www.acmicpc.net/problem/1504

양방향(무방향) 그래프에서 1번 정점부터 N번 정점까지 이동할 때, 주어진 특정 정점 2개를 반드시 거치는 최단 경로의 길이를 구하는 문제이다.

처음에 다익스트라를 생각했지만 필수노드 2개를 반드시 지나는 부분을 생각하지 못했는데, 

질문게시판의 한 은인의 도움으로 목표 경로는 시작점 -> 필수노드1 -> 필수노드2 -> 끝점

또는 시작점 -> 필수노드2 -> 필수노드1 -> 끝점, 이 두 경로 중 더 짧은 쪽이라는 사실을 깨닫고

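시작점, 필수노드1, 필수노드2를 각각 출발점으로 하는 다익스트라를 한 번씩(총 3번) 돌려 필요한 구간 거리(시작점 -> 필수노드1/2, 필수노드1 <-> 필수노드2, 필수노드1/2 -> 끝점)를 모두 구해서

풀었다. 간선이 양방향이므로 필수노드1 -> 필수노드2 거리와 필수노드2 -> 필수노드1 거리는 같아서 한 번만 구하면 된다.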

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

public class Solved1504 {
    /**
     * Reads the graph from standard input and prints the length of the shortest path from
     * node 1 to node N that visits both required nodes, or -1 if no such path exists.
     *
     * <p>Input: "N E", then E lines "a b c" (an undirected edge a-b of cost c), then "v1 v2".
     * The only candidate routes are 1 -> v1 -> v2 -> N and 1 -> v2 -> v1 -> N, so three
     * single-source shortest-path runs (from 1, from v1 and from v2) supply every segment.
     *
     * @param args unused
     * @throws IOException if reading standard input fails
     */
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int[] header = readInts(br);
        int nodeCount = header[0];
        int edgeCount = header[1];

        // Index 0 is an unused placeholder so node ids can be used directly as list indices.
        List<Nodes> nodeList = new ArrayList<>(nodeCount + 1);
        for (int i = 0; i <= nodeCount; i++) {
            nodeList.add(new Nodes(i));
        }
        if (edgeCount == 0) {
            // With no edges node N (N >= 2) can never be reached from node 1.
            System.out.println(-1);
            return;
        }

        for (int i = 0; i < edgeCount; i++) {
            int[] edge = readInts(br);
            // Undirected graph: store the edge in both directions.
            nodeList.get(edge[0]).add(new Vertex(edge[0], edge[1], edge[2]));
            nodeList.get(edge[1]).add(new Vertex(edge[1], edge[0], edge[2]));
        }
        int[] required = readInts(br);
        int e1 = required[0];
        int e2 = required[1];

        dijkstra(nodeList, 1);
        int a1 = nodeList.get(e1).weight; // 1 -> e1
        int a2 = nodeList.get(e2).weight; // 1 -> e2

        dijkstra(nodeList, e1);
        int b = nodeList.get(e2).weight; // e1 -> e2 (== e2 -> e1, the graph is undirected)
        int c2 = nodeList.get(nodeCount).weight; // e1 -> N

        dijkstra(nodeList, e2);
        int c1 = nodeList.get(nodeCount).weight; // e2 -> N

        // Check reachability BEFORE adding so Integer.MAX_VALUE never overflows into a sum.
        if (!check(new int[] {a1, a2, b, c2, c1})) {
            System.out.println(-1);
        } else {
            System.out.println(Math.min(a1 + b + c1, a2 + b + c2));
        }
    }

    /**
     * Runs Dijkstra's algorithm from {@code start}. Afterwards each node's {@code weight} holds
     * its shortest distance from {@code start}, or {@link Integer#MAX_VALUE} if unreachable.
     * Assumes non-negative edge costs (Dijkstra is not correct otherwise).
     *
     * @param nodeList adjacency structure indexed by node id
     * @param start    source node id
     */
    private static void dijkstra(List<Nodes> nodeList, int start) {
        initialize(nodeList);
        // Entries are {tentative distance, node id}, ordered by distance -- not by the cost of a
        // single edge, which would turn the search into repeated re-relaxation.
        PriorityQueue<int[]> pq = new PriorityQueue<>((x, y) -> Integer.compare(x[0], y[0]));
        nodeList.get(start).weight = 0;
        pq.add(new int[] {0, start});
        while (!pq.isEmpty()) {
            int[] entry = pq.poll();
            int dist = entry[0];
            Nodes node = nodeList.get(entry[1]);
            if (dist > node.weight) {
                continue; // stale entry: this node was already settled with a shorter distance
            }
            for (Vertex edge : node.nodeList) {
                Nodes next = nodeList.get(edge.to);
                int candidate = dist + edge.value;
                if (candidate < next.weight) {
                    next.weight = candidate;
                    pq.add(new int[] {candidate, edge.to});
                }
            }
        }
    }

    /**
     * Reads one line and parses its whitespace-separated integers, tolerating repeated or
     * leading/trailing spaces.
     */
    private static int[] readInts(BufferedReader br) throws IOException {
        return Arrays.stream(br.readLine().trim().split("\\s+"))
                .mapToInt(Integer::parseInt)
                .toArray();
    }

    /**
     * Returns {@code true} if no value in {@code a} is {@link Integer#MAX_VALUE}, i.e. every
     * path segment was reachable.
     */
    static boolean check(int[] a) {
        for (int i = 0; i < a.length; i++) {
            if (a[i] == Integer.MAX_VALUE) {
                return false;
            }
        }
        return true;
    }

    /** Resets every node's distance to "unreached" before a new shortest-path run. */
    static void initialize(List<Nodes> list) {
        list.forEach(Nodes::init);
    }
}

/**
 * A graph node: its id, its outgoing edges, and the working distance written by the
 * shortest-path runs in {@code Solved1504}.
 */
class Nodes {
    /** Node id. */
    int value;
    /** Outgoing edges of this node (adjacency list). */
    List<Vertex> nodeList;
    /** Current distance from the run's source; {@code Integer.MAX_VALUE} means "not reached". */
    int weight;

    Nodes(int value) {
        this.nodeList = new ArrayList<>();
        this.weight = Integer.MAX_VALUE;
        this.value = value;
    }

    /** Appends an outgoing edge to this node's adjacency list. */
    void add(Vertex v) {
        nodeList.add(v);
    }

    /** Marks this node as not yet reached, ready for a new shortest-path run. */
    void init() {
        weight = Integer.MAX_VALUE;
    }
}

/**
 * A directed, weighted edge {@code from -> to} with cost {@code value}. Despite its name this
 * models an edge, not a vertex; the name is kept for compatibility with existing callers.
 *
 * <p>The natural ordering compares cost only, so it is inconsistent with {@code equals}.
 */
class Vertex implements Comparable<Vertex> {
    /** Source node id. */
    int from;
    /** Destination node id. */
    int to;
    /** Edge cost. */
    int value;

    Vertex(int from, int to, int value) {
        this.from = from;
        this.to = to;
        this.value = value;
    }

    /**
     * Orders edges by ascending cost. Uses {@link Integer#compare} rather than subtraction,
     * which overflows (and flips the sign) for large costs of opposite sign.
     */
    @Override
    public int compareTo(Vertex o) {
        return Integer.compare(this.value, o.value);
    }
}

'백준' 카테고리의 다른 글

[JAVA] 백준 1967 - 트리의 지름  (1) 2025.08.18
[JAVA] 백준 1753 - 최단경로  (4) 2025.08.13
[JAVA] 백준 1043 - 거짓말  (2) 2025.08.12
[JAVA] 백준 17070 - 파이프 옮기기 1  (8) 2025.08.11
백준 수학 - 1059번 : 좋은 구간  (0) 2022.03.11