1 条题解

  • 0
    @ 2025-10-22 15:53:36

    题目分析

    给定一棵 nn 个节点的树,边权为 1。有 qq 个查询,每个查询给出 kk 个点,需要计算:

    1. 这些点两两之间路径长度的总和
    2. 这些点两两之间路径长度的最小值
    3. 这些点两两之间路径长度的最大值

    直接计算所有点对的时间复杂度为 O(k2)O(k^2),在 kk 较大时不可行。

    解题思路

    使用虚树技术优化:

    1. 树链剖分预处理O(n)O(n) 预处理,支持 O(logn)O(\log n) 求 LCA
    2. 虚树构建:对每个查询的点集构建虚树,只保留关键点和它们的 LCA
    3. 树形DP:在虚树上一次性计算所有需要的统计量

    关键算法

    虚树构建步骤:

    1. 将查询点按 DFS 序排序
    2. 依次加入相邻点的 LCA
    3. 按 DFS 序连接形成虚树

    树形DP状态:

    • siz[u]:子树中关键点数量
    • minn[u]:子树中到最近关键点的距离
    • maxx[u]:子树中到最远关键点的距离

    完整代码

    #include <bits/stdc++.h>
    #define pii pair<int,int>
    #define pb emplace_back
    #define int long long
    #define mk make_pair
    #define se second
    #define fi first
    #ifdef int
    #define inf (int)1e18+10
    #else
    #define inf (int)1e9+10
    #endif
    using namespace std;
    
    const int Max = 1e6 + 10;
    const int mod = 998244353;
    
    inline int read() {
        int res = 0, v = 1;
        char c = getchar();
        while (c < '0' || c > '9') {
            v = (c == '-' ? -1 : 1);
            c = getchar();
        }
        while (c >= '0' && c <= '9') {
            res = (res << 3) + (res << 1) + (c ^ 48);
            c = getchar();
        }
        return res * v;
    }
    
    int dis[Max];
    vector<pii> v[Max];
    
    struct tree {
        int siz, son, dep, top, id, fa, rk;
    } bbb[Max];
    
    #define siz(x) bbb[x].siz
    #define son(x) bbb[x].son
    #define dep(x) bbb[x].dep
    #define top(x) bbb[x].top
    #define id(x) bbb[x].id
    #define fa(x) bbb[x].fa
    #define rk(x) bbb[x].rk
    
    int Res;
    
    void dfs1(int now, int fa) {
        fa(now) = fa;
        dep(now) = dep(fa) + 1;
        siz(now) = 1;
        for (auto [to, val] : v[now]) {
            if (to == fa) continue;
            dis[to] = dis[now] + val;
            dfs1(to, now);
            siz(now) += siz(to);
            if (siz(to) > siz(son(now))) {
                son(now) = to;
            }
        }
    }
    
    void dfs2(int now, int top) {
        top(now) = top;
        id(now) = ++Res;
        rk(Res) = now;
        if (son(now)) {
            dfs2(son(now), top);
        }
        for (auto [to, val] : v[now]) {
            if (id(to) == 0) {
                dfs2(to, to);
            }
        }
    }
    
    int LCA(int x, int y) {
        while (top(x) != top(y)) {
            if (dep(top(x)) < dep(top(y))) {
                swap(x, y);
            }
            x = fa(top(x));
        }
        return dep(x) < dep(y) ? x : y;
    }
    
    vector<int> p, tmp;
    vector<pii> a[Max];
    
    bool cmp(int a, int b) {
        return id(a) < id(b);
    }
    
    int GetDis(int x, int y) {
        return dis[x] + dis[y] - 2 * dis[LCA(x, y)];
    }
    
    int siz[Max], minn[Max], maxx[Max], vis[Max];
    int sum, mnn, mxx, num;
    
    void dfs(int now, int fa) {
        if (vis[now]) {
            minn[now] = maxx[now] = 0;
            siz[now] = 1;
        } else {
            minn[now] = inf;
            maxx[now] = -inf;
            siz[now] = 0;
        }
    
        for (auto [to, val] : a[now]) {
            if (to == fa) continue;
            dfs(to, now);
            siz[now] += siz[to];
            sum += siz[to] * (num - siz[to]) * val;
            mnn = min(mnn, minn[now] + minn[to] + val);
            mxx = max(mxx, maxx[now] + maxx[to] + val);
            minn[now] = min(minn[now], minn[to] + val);
            maxx[now] = max(maxx[now], maxx[to] + val);
        }
    }
    
    void build() {
        mxx = -inf;
        mnn = inf;
        sum = 0;
        num = p.size();
    
        for (auto x : p) {
            vis[x] = 1;
        }
    
        p.pb(1);
        sort(p.begin(), p.end(), cmp);
        p.erase(unique(p.begin(), p.end()), p.end());
        
        tmp = p;
        for (int i = 0; i < (int)p.size() - 1; i++) {
            int lca = LCA(p[i], p[i + 1]);
            tmp.pb(lca);
        }
        
        sort(tmp.begin(), tmp.end(), cmp);
        tmp.erase(unique(tmp.begin(), tmp.end()), tmp.end());
        
        for (auto x : tmp) {
            a[x].clear();
        }
        
        for (int i = 0; i < (int)tmp.size() - 1; i++) {
            int lca = LCA(tmp[i], tmp[i + 1]);
            int val = GetDis(lca, tmp[i + 1]);
            a[lca].pb(mk(tmp[i + 1], val));
            a[tmp[i + 1]].pb(mk(lca, val));
        }
    
        dfs(1, 0);
    
        for (auto x : p) {
            vis[x] = 0;
        }
        tmp.clear();
        p.clear();
    }
    
    signed main() {
        int n = read();
        for (int i = 1; i < n; i++) {
            int x = read(), y = read();
            v[x].pb(mk(y, 1));
            v[y].pb(mk(x, 1));
        }
    
        dfs1(1, 0);
        dfs2(1, 1);
        
        int q = read();
        for (int i = 1; i <= q; i++) {
            int k = read();
            for (int j = 1; j <= k; j++) {
                p.pb(read());
            }
            build();
            cout << sum << ' ' << mnn << ' ' << mxx << "\n";
        }
    
        return 0;
    }
    

    代码说明

    主要函数:

    • dfs1, dfs2:树链剖分预处理
    • LCA:求最近公共祖先
    • GetDis:计算两点间距离
    • build:构建虚树
    • dfs:在虚树上进行DP统计

    关键变量:

    • sum:所有路径长度总和
    • mnn:最小路径长度
    • mxx:最大路径长度
    • siz[u]:子树关键点数量
    • minn[u]:到最近关键点距离
    • maxx[u]:到最远关键点距离

    复杂度分析

    • 预处理O(n)O(n)
    • 每个查询O(klogk)O(k \log k)
    • 总复杂度O(n+klogk)O(n + \sum k \log k)
    • 1

    信息

    ID
    3687
    时间
    3000ms
    内存
    512MiB
    难度
    10
    标签
    递交数
    10
    已通过
    1
    上传者