tarjan离线最小公共祖先LCA DFS T3 并查集

1171. 距离 - AcWing题库 给出 个点的一棵树,多次询问两点之间的最短距离。

注意:

  • 边是无向的。
  • 所有节点的编号是

输入格式

第一行为两个整数 表示点数, 表示询问次数;

下来 行,每行三个整数 ,表示点 和点 之间存在一条边长度为

再接下来 行,每行两个整数 ,表示询问点 到点 的最短距离。

树中结点编号从

输出格式

行,对于每次询问,输出一行询问结果。

数据范围

,
,
,

输入样例1:

2 2 
1 2 100 
1 2 
2 1

输出样例1:

100
100

输入样例2:

3 2
1 2 10
3 1 15
1 2
3 2

输出样例2:

10
25

思路

求两点之间的最短路, 时间复杂度来说不能用图论最短路来解决。

预处理每个点到根节点的距离 , 对于 两点间的距离, 就是 , 其中 的最小公共祖先。

若采用倍增LCA来求, 时间复杂度为 , 可以通过。不过这里介绍下更快的做法:tarjan离线求LCA。离线是指把所有查询存下来后一并处理且输出, 在线则是读入一个查询处理一个查询。

原理为:

  1. DFS搜索整个图, 此时所有点有三种类型:
    • 已经搜过且回溯的点, 即已经搜过的点
    • 已经搜过但还没回溯的点, 即当前正在搜索的点及其路径上的点
    • 还未搜过的点
  2. 所有已经搜过的点, 且正在搜索的点中存在其父节点 时, 设这种点为 , 那么对于所有比 深的正在搜索的点 , 的最小公共祖先总会是

具体步骤是

  1. 先用 二维pair数组 存入所有询问, 就是内存的就是和 有关的点和询问操作的
  2. 然后DFS求所有点相距根节点的距离。
  3. 再进行tarjan算法:
    1. 数组 标记当前点是否被枚举过或正在枚举, 0为未搜索, 1为正在搜索, 2为已经搜过
    2. 先递归搜索到最低部, 然后在返回时将当前点的父节点设为上级节点
    3. 一个点搜索完成之后, 判断是否有当前点 是否存在询问 , 若有, 则判断对应的 点是否已经被搜索完成 , 如果成立, 则说明其父节点 就是 的最小公共祖先。
    4. 处理完后回溯时将该点设置为 已被搜索

代码

#include <iostream>
#include <cstring>
#include <algorithm>
#include <string>
#include <cmath>
#include <vector>
using namespace std;
typedef pair<int,int> PII;
const int N = 10010, M = N * 2;
int h[N], w[M], e[M], ne[M], idx;
int dist[N];
int res[M];
int f[N];
vector<PII> query[N];
int n,m;
int st[N];
 
void add(int a, int b, int c) {e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;}
 
void dfs(int u, int fa)
{
    for(int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if(j == fa) continue;
        dist[j] = dist[u] + w[i];
        dfs(j,u);
    }
}
 
int find(int x) {
    if(f[x] != x) f[x] = find(f[x]);
    return f[x];
}
 
void tarjan(int u)
{
    st[u] = 1;
    for(int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if(!st[j])
        {
            tarjan(j);
            f[j] = u;
        }
    }
 
    for(auto item : query[u])
    {
        int a = item.first, id = item.second;
        if(st[a] == 2)
        {
            int anc = find(a);
            res[id] = dist[u] + dist[a] - 2 * dist[anc];
            //cout << res[id] << endl;
        }
    }
    
    st[u] = 2;
}
 
int main()
{
    memset(h, -1, sizeof h);
    cin >> n >> m;
    for(int i = 1; i < n; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        add(a,b,c), add(b,a,c);
    }
    for(int i = 0; i < m; i++)
    {
        int a, b;
        cin >> a >> b;
        if(a != b)
        {
            query[a].push_back({b,i});
            query[b].push_back({a,i});
        }
    }
    for(int i = 1; i <= n; i++) f[i] = i;
    dfs(1,-1);
    tarjan(1);
    
    for(int i = 0; i < m; i++) cout << res[i] << "\n";
    
    return 0;
}