본문 바로가기

PS/Problems

[Java] 백준 1761번 정점들의 거리 - LCA 응용

https://www.acmicpc.net/problem/1761

 

1761번: 정점들의 거리

첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 M이 주어지고, 다음 M개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩

www.acmicpc.net

기존의 LCA 문제에서 나아가, 간선들의 길이의 합에 대한 정보도 필요하다.

처음에는 세그먼트 트리를 생각해보았지만 간선이 연결되는 각 경우마다 트리를 만들 수는 없으므로

각 노드의 2^i 위의 부모에 대한 정보를 담은 parents 배열처럼

2^i 위의 부모까지의 길이정보를 담는 distances 배열을 새로 고안했다.

예술이네

 

import java.util.*;
import java.io.*;

public class p5_1761 {
    static FastScanner fs = new FastScanner();
    static PrintWriter pw = new PrintWriter(System.out);
    static int n, dep = 0;
    static List<Map<Integer, Integer>> list = new ArrayList<>(); 
    static int[] deps;
    static int[][] parents, distances;

    public static void main(String[] args) {
        n = fs.nextInt();
        for (int i=0;i<=n;i++) list.add(new TreeMap<>());
        for (int i=0;i<n-1;i++){
            int a = fs.nextInt(), b = fs.nextInt(), d = fs.nextInt();
            list.get(a).put(b, d);
            list.get(b).put(a, d);
        }

        deps = new int[n + 1];
        int temp = 1;
        while (temp < n){
            temp *= 2;
            dep++;
        }
        
        parents = new int[n + 1][dep];
        distances = new int[n + 1][dep];
        dfs(1, 1);
        fill();

        int t = fs.nextInt();
        while (t-- > 0){
            int a = fs.nextInt(), b = fs.nextInt();
            int as = a, bs = b;

            if (deps[a] < deps[b]){
                temp = a;
                a = b;
                b = temp;
            }

            for (int i=dep-1;i>=0;i--){
                if (Math.pow(2, i) <= deps[a] - deps[b]){
                    a = parents[a][i];
                }
            }

            for (int i=dep-1;i>=0;i--){
                if (parents[a][i] != parents[b][i]){
                    a = parents[a][i];
                    b = parents[b][i];
                }
            }

            int lca = (a == b) ? a : parents[a][0];
            int ans = 0;
            // pw.println(as + " " + bs + " " + lca);
            for (int i=dep-1;i>=0;i--){
                if (Math.pow(2, i) <= deps[as] - deps[lca]){
                    ans += distances[as][i];
                    as = parents[as][i];
                }
                if (Math.pow(2, i) <= deps[bs] - deps[lca]){
                    ans += distances[bs][i];
                    bs = parents[bs][i];
                }
            }
            pw.println(ans);
        }

        pw.close();
    }

    static void fill(){
        for (int i=1;i<dep;i++){
            for (int j=1;j<=n;j++){
                parents[j][i] = parents[parents[j][i - 1]][i - 1];
                distances[j][i] = distances[j][i - 1] + distances[parents[j][i - 1]][i - 1];
            }
        }
    }

    static void dfs(int node, int depth){
        deps[node] = depth;
        
        for (Integer i : list.get(node).keySet()){
            if (deps[i] == 0){
                dfs(i, depth + 1);
                parents[i][0] = node;
                distances[i][0] = list.get(i).get(node);
            }
        }
    }

    // ----------input function----------

    static void sort(int[] a) {
        ArrayList<Integer> L = new ArrayList<>();
        for (int i : a)
            L.add(i);
        Collections.sort(L);
        for (int i = 0; i < a.length; i++)
            a[i] = L.get(i);
    }

    static class FastScanner {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer("");

        String next() {
            while (!st.hasMoreTokens()) {
                try {
                    st = new StringTokenizer(br.readLine());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            return st.nextToken();
        }

        int nextInt() {
            return Integer.parseInt(next());
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        int[] readArray(int n) {
            int[] a = new int[n];
            for (int i = 0; i < n; i++)
                a[i] = nextInt();
            return a;
        }

        long nextLong() {
            return Long.parseLong(next());
        }
    }
}