05-算法模板/03-最近公共祖先LCA:修订间差异

来自三三百科
跳转到导航 跳转到搜索
->Importer
批量导入三三文档
 
33DAI留言 | 贡献
导入1个版本
(没有差异)

2026年5月20日 (三) 16:25的版本

给定一棵 [math]\displaystyle{ n }[/math] 个点的树,[math]\displaystyle{ s }[/math] 是根节点,[math]\displaystyle{ m }[/math] 个询问。每次询问两个点的 LCA。

倍增法(推荐)

[math]\displaystyle{ O(n\log n) }[/math] 预处理,[math]\displaystyle{ O(\log n) }[/math] 查询。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 500'000;
int n, m, s;
vector<int> e[MAXN + 5];
int f[MAXN + 5][25];  // f[i][j] 记录 i 的 2^j 级祖先
int dep[MAXN + 5];    // 节点深度

void dfs(int u, int fa)
{
    f[u][0] = fa;
    for (int v : e[u])
    {
        if (v == fa)
            continue;
        dep[v] = dep[u] + 1;
        dfs(v, u);
    }
}

int lca(int u, int v)
{
    if (dep[v] < dep[u])
        swap(u, v);
    // 拉到同样深度
    for (int j = 20; j >= 0; j--)
        if (dep[v] - dep[u] >= (1 << j))
            v = f[v][j];
    if (u == v)
        return u;
    // 同步往上跳
    for (int j = 20; j >= 0; j--)
        if (f[u][j] != f[v][j])
            u = f[u][j], v = f[v][j];
    return f[u][0];
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m >> s;
    for (int i = 1; i <= n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dep[s] = 0;
    dfs(s, 0);
    for (int j = 1; (1LL << j) <= n; j++)
        for (int i = 1; i <= n; i++)
            f[i][j] = f[f[i][j - 1]][j - 1];
    while (m--)
    {
        int u, v;
        cin >> u >> v;
        cout << lca(u, v) << "\n";
    }
    return 0;
}

树链剖分

[math]\displaystyle{ O(n) }[/math] 预处理,[math]\displaystyle{ O(\log n) }[/math] 查询。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 500000 + 5;
int n, m, s;
vector<int> e[MAXN];
int fa[MAXN], dep[MAXN], siz[MAXN], hson[MAXN];

void dfs_build(int u, int fat)
{
    hson[u] = 0;
    siz[hson[u]] = 0;
    siz[u] = 1;
    for (int v : e[u])
    {
        if (v == fat)
            continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs_build(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[hson[u]])
            hson[u] = v;
    }
}

int tot, top[MAXN], dfn[MAXN], rnk[MAXN];
void dfs_div(int u, int fat)
{
    dfn[u] = ++tot;
    rnk[tot] = u;
    if (hson[u])
    {
        top[hson[u]] = top[u];
        dfs_div(hson[u], u);
        for (int v : e[u])
        {
            if (v == fat || v == hson[u])
                continue;
            top[v] = v;
            dfs_div(v, u);
        }
    }
}

int lca(int u, int v)
{
    while (top[u] != top[v])
    {
        if (dep[top[u]] > dep[top[v]])
            u = fa[top[u]];
        else
            v = fa[top[v]];
    }
    return dep[u] > dep[v] ? v : u;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m >> s;
    for (int i = 1; i <= n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dep[s] = 1;
    fa[s] = 0;
    dfs_build(s, 0);
    tot = 0;
    top[s] = s;
    dfs_div(s, 0);
    while (m--)
    {
        int u, v;
        cin >> u >> v;
        cout << lca(u, v) << "\n";
    }
    return 0;
}

Tarjan(离线 + 并查集)

[math]\displaystyle{ O(n + m) }[/math],需要提前读入所有询问。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 500000, MAXM = 500000;
int n, m, s;
vector<int> e[MAXN + 5];
vector<pair<int, int>> ask[MAXN + 5];
int ans[MAXM + 5];
bool vis[MAXN + 5];
int fa[MAXN + 5];

int findFa(int x)
{
    return fa[x] == x ? x : fa[x] = findFa(fa[x]);
}

void dfs(int u, int from)
{
    for (auto [v, id] : ask[u])
        if (vis[v])
            ans[id] = findFa(v);
    vis[u] = true;
    for (int v : e[u])
    {
        if (v == from)
            continue;
        dfs(v, u);
        fa[v] = u;
    }
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m >> s;
    for (int i = 1; i <= n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    for (int i = 1; i <= m; i++)
    {
        int u, v;
        cin >> u >> v;
        if (u == v)
            ans[i] = u;
        else
        {
            ask[u].push_back({v, i});
            ask[v].push_back({u, i});
        }
    }
    for (int i = 1; i <= n; i++)
        fa[i] = i, vis[i] = false;
    dfs(s, 0);
    for (int i = 1; i <= m; i++)
        cout << ans[i] << "\n";
    return 0;
}