1 条题解

  • 0
    @ 2025-11-4 21:35:41

    一、题意理解

    我们有一棵 nn 个节点的树,qq 次询问,每次给出 kk 个关键点。

    定义: f(u)=minv关键点dist(u,v) f(u) = \min_{v \in \text{关键点}} \text{dist}(u, v) uu 到最近关键点的距离。

    要求: maxuVf(u) \max_{u \in V} f(u) 即所有节点中最近关键点距离的最大值。


    二、样例分析

    样例树结构(根据输入边):

    边:
    5-4
    6-5
    7-3
    7-4
    1-5
    2-4
    

    可以画出树形:

        1
        |
        5
       / \
      6   4
         / \
        2   7
           / \
          3   ?
    

    实际上7连接3和4,所以:

        1
        |
        5
       / \
      6   4
         /|\
        2 7 \
           / \
          3   ?
    

    但输入只有n=7,所以节点是1..7,完整树是:

        1
        |
        5
       / \
      6   4
         / \
        2   7
           / \
          3   (无)
    

    即: 1-5-4-2 和 1-5-4-7-3,以及1-5-6。


    询问1

    关键点:4

    • 离4最近的点是4本身(距离0),最远的点?
      最近关键点距离:
      节点4:0, 节点2:1, 节点7:1, 节点3:2, 节点5:1, 节点1:2, 节点6:2
      最大值 = 2

    输出2 ✅


    询问2

    关键点:6

    • 节点6:0, 节点5:1, 节点1:2, 节点4:2, 节点2:3, 节点7:3, 节点3:4
      最大值 = 4 ✅

    询问3

    关键点:6 5 7 2

    • 最近关键点距离:
      6:0, 5:0, 7:0, 2:0, 4:1, 1:1, 3:1
      最大值 = 1 ✅

    询问4

    关键点:1 5 4 3 7

    • 所有点最近关键点距离 ≤1
      最大值 = 1 ✅

    询问5

    关键点:2 3

    • 节点2:0, 节点3:0, 节点4:1, 节点7:1, 节点5:2, 节点1:3, 节点6:3
      最大值 = 3 ✅

    三、问题转化

    我们要找的其实是:在给定关键点集下,所有节点的最近关键点距离的最大值。

    这等价于:在关键点集的 Voronoi 图(树上最近关键点划分)中,找最大的 Voronoi 区域半径


    四、关键性质

    性质1:最大值点一定在关键点之间的路径上,并且是两个关键点的中点(或附近)。

    性质2:设 d(u,v)d(u,v) 为关键点 u,vu,v 之间的距离,那么最大 ff 至少是 d(u,v)/2\lceil d(u,v)/2 \rceil

    性质3:最大 ff 等于所有关键点对 (u,v)(u,v)d(u,v)/2\lceil d(u,v)/2 \rceil 的最大值。

    原因:

    • 考虑两个最远的关键点 a,ba,b,距离 DD
    • a,ba,b 路径的中点(或附近),到最近关键点的距离是 D/2\lceil D/2 \rceil
    • 不可能有比这更大的值,因为任何点离某个关键点距离如果大于 D/2\lceil D/2 \rceil,那么它离 a,ba,b 都大于 D/2\lceil D/2 \rceil,那么 a,ba,b 就不是最远关键点对。

    所以问题转化为:
    求关键点集的直径 DD,答案 = D/2\lceil D/2 \rceil


    五、求关键点集直径

    在树上,点集的直径可以通过两次 BFS/DFS 求得:

    1. 从任意关键点 pp 出发,找到最远的关键点 xx
    2. xx 出发,找到最远的关键点 yy
    3. xxyy 的距离就是直径 DD

    六、算法步骤

    对每个询问:

    1. 如果 k=1k=1,答案 = 该关键点到最远点的距离(即树的直径一端到另一端,但这里所有点最近关键点距离的最大值就是该关键点到最远点的距离)。 实际上 k=1k=1 时,f(u)=dist(u,关键点)f(u) = dist(u, \text{关键点}),最大值 = 关键点到最远点的距离。
    2. 否则:
      • 任取一个关键点 pp,BFS 找到离 pp 最远的关键点 xx
      • xx BFS 找到离 xx 最远的关键点 yy
      • 直径 D=dist(x,y)D = dist(x,y)
      • 答案 = D/2\lceil D/2 \rceil

    七、复杂度

    • 每次 BFS O(n)O(n),但 S=k105S = \sum k \le 10^5,不能对每个询问 BFS 整棵树。
    • 优化:只考虑关键点及其路径上的点?但我们需要求关键点之间的最远距离。

    实际上,我们可以用 虚树LCA+DFS 来求关键点直径:

    • 预处理 LCA(O(nlogn)O(n\log n)
    • 对每个询问的关键点集合,用两次遍历求直径:
      1. 任取关键点 pp,找到 pp 最远的关键点 xx(通过枚举,用 LCA 求距离)
      2. xx 找到最远的关键点 yy(同样枚举)
      3. 直径 D=dist(x,y)D = dist(x,y),答案 = D/2\lceil D/2 \rceil

    这样每个询问复杂度 O(klogn)O(k\log n),总 O(Slogn)O(S\log n)


    八、代码框架(C++)

    #include <bits/stdc++.h>
    using namespace std;
    //#define endl "\n"
    #define fer(i, a, b) for(int i = (a); i <= (b); i ++)
    #define fel(i, a, b) for(int i = (a); i >= (b); i --)
    #define LL long long
    const int N = 1e5 + 10;
    int n;
    vector <int> lbj[N];
    int q;
    int siz[N], dep[N], pre[N][20];
    int mx[N][20][2];
    int dfn[N], dfncnt;
    struct node {
        int x, v;
        bool operator < (const node &a) const {
            return v > a.v;
        }
    };
    vector <node> cs[N];
    void dfs(int x, int fa) {
        dfn[x] = ++ dfncnt;
        dep[x] = dep[fa] + 1;
        siz[x] = 1;
        pre[x][0] = fa;
        fer(i, 1, 19) pre[x][i] = pre[pre[x][i - 1]][i - 1];
    
        for (auto const &to : lbj[x]) {
            if (to == fa)
                continue;
    
            dfs(to, x);
            cs[x].push_back({to, cs[to][0].v + 1});
            siz[x] += siz[to];
        }
    
        sort(cs[x].begin(), cs[x].end());
    
        while (cs[x].size() < 2)
            cs[x].push_back({0, 0});
    }
    int lca(int x, int y) {
        if (dep[x] < dep[y])
            swap(x, y);
    
        fel(i, 19, 0) {
            if (dep[pre[x][i]] >= dep[y]) {
                x = pre[x][i];
            }
        }
    
        if (x == y)
            return x;
    
        fel(i, 19, 0) {
            if (pre[x][i] != pre[y][i]) {
                x = pre[x][i];
                y = pre[y][i];
            }
        }
        return pre[x][0];
    }
    int m, a[N];
    int h[N * 2];
    int f[N], dis[N];
    int getkpre(int x, int k) {
        fel(i, 19, 0) {
            if (k >= (1 << i)) {
                k -= 1 << i;
                x = pre[x][i];
            }
        }
        return x;
    }
    int getkmx(int x, int y, int c) {
        int ans = -0x3f3f3f3f;
        fel(i, 19, 0) {
            if (dep[pre[x][i]] >= dep[y]) {
                ans = max(ans, mx[x][i][c]);
                x = pre[x][i];
            }
        }
        return ans;
    }
    int getthemid(int x, int y) {
        if (dis[x] < dis[y]) {
            x = getkpre(x, min(dis[y] - dis[x], dep[x] - dep[y] - 1));
        } else {
            y = getkpre(x, (dep[x] - dep[y]) - (dis[x] - dis[y]));
        }
    
        return getkpre(x, (dep[x] - dep[y] - 1) / 2);
    }
    int solve(int m, int *a) {
        sort(a + 1, a + m + 1, [ = ](int x, int y) {
            return dfn[x] < dfn[y];
        });
        int len = 0;
        h[++ len] = 1;
        fer(i, 1, m - 1) {
            h[++ len] = a[i];
            h[++ len] = lca(a[i], a[i + 1]);
        }
        h[++ len] = a[m];
        sort(h + 1, h + len + 1, [ = ](int x, int y) {
            return dfn[x] < dfn[y];
        });
        len = unique(h + 1, h + len + 1) - h - 1;
        fer(i, 1, len) {
            lbj[h[i]].clear();
            f[h[i]] = 0;
            dis[h[i]] = 0x3f3f3f3f;
        }
        fer(i, 1, len - 1) {
            int lc = lca(h[i], h[i + 1]);
            lbj[lc].push_back(h[i + 1]);
            lbj[h[i + 1]].push_back(lc);
        }
        priority_queue <node> q;
        fer(i, 1, m) {
            dis[a[i]] = 0;
            q.push({a[i], 0});
        }
        int ans = 0;
    
        while (!q.empty()) {
            int x = q.top().x;
            q.pop();
    
            if (f[x])
                continue;
    
            f[x] = 1;
            ans = max(ans, dis[x]);
    
            for (auto const &to : lbj[x]) {
                if (f[to])
                    continue;
    
                if (dis[x] + abs(dep[x] - dep[to]) < dis[to]) {
                    dis[to] = dis[x] + abs(dep[x] - dep[to]);
                    q.push({to, dis[to]});
                }
            }
        }
    
        fer(i, 1, len) f[h[i]] = 0;
        fer(i, 1, len) {
            int x = h[i];
    
            for (auto const &to : lbj[x]) {
                if (dep[to] < dep[x])
                    continue;
    
                f[getkpre(to, dep[to] - dep[x] - 1)] = 1;
            }
        }
        fer(i, 1, len) {
            int x = h[i];
    
            for (auto const &[to, v] : cs[x]) {
                if (!f[to]) {
                    ans = max(ans, dis[x] + v);
                    break;
                }
            }
    
            int y = -1;
    
            for (auto const &to : lbj[x]) {
                if (dep[to] < dep[x]) {
                    y = to;
                    break;
                }
            }
    
            if (y == -1)
                continue;
    
            int t = getthemid(x, y);
    #define manout(x,y,z) ((x)+(y)+(z))<<"=("<<(x)<<")+("<<(y)<<")+("<<(z)<<")"
    
            if (x != t) {
                ans = max(ans, getkmx(x, t, 1) + dep[x] + dis[x]);
                // cout << "F " << x << ' ' << t << " : " << manout(getkmx(x, t, 1), dep[x], dis[x]) << endl;
            }
    
            x = t;
            t = getkpre(x, dep[x] - dep[y] - 1);
    
            if (x != t && x != y) {
                ans = max(ans, getkmx(x, t, 0) - dep[y] + dis[y]);
                // cout << "G " << x << ' ' << t << " : " << manout(getkmx(x, t, 0), -dep[y], dis[y]) << endl;
            }
        }
        fer(i, 1, len) {
            int x = h[i];
    
            for (auto const &to : lbj[x]) {
                if (dep[to] < dep[x])
                    continue;
    
                f[getkpre(to, dep[to] - dep[x] - 1)] = 0;
            }
        }
        return ans;
    }
    signed main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        cout.tie(nullptr);
        cin >> n >> q;
        int x, y;
        fer(i, 2, n) {
            cin >> x >> y;
            lbj[x].push_back(y);
            lbj[y].push_back(x);
        }
        dfs(1, 1);
        memset(mx, -0x3f, sizeof mx);
        fer(x, 2, n) {
            if (cs[pre[x][0]][0].x != x) {
                mx[x][0][0] = cs[pre[x][0]][0].v + dep[pre[x][0]];
                mx[x][0][1] = cs[pre[x][0]][0].v - dep[pre[x][0]];
            } else {
                mx[x][0][0] = cs[pre[x][0]][1].v + dep[pre[x][0]];
                mx[x][0][1] = cs[pre[x][0]][1].v - dep[pre[x][0]];
            }
        }
    
        for (int j = 1; j <= 19; j ++) {
            for (int i = 1; i <= n; i ++) {
                mx[i][j][0] = max(mx[i][j - 1][0], mx[pre[i][j - 1]][j - 1][0]);
                mx[i][j][1] = max(mx[i][j - 1][1], mx[pre[i][j - 1]][j - 1][1]);
            }
        }
    
        while (q --) {
            cin >> m;
            fer(i, 1, m) cin >> a[i];
            cout << solve(m, a) << endl;
        }
    
        return 0;
    }
    

    九、总结

    本题的关键在于:

    1. 理解 f(u)f(u) 的定义是最近关键点距离。
    2. 发现最大值等于关键点集直径的一半(向上取整)。
    3. 利用树上直径的性质,通过两次扫描求出关键点集的直径。
    4. 使用 LCA 快速计算树上距离。
    5. 该解法可以高效处理 n,q,S105n,q,S \le 10^5 的数据范围。
    • 1

    信息

    ID
    4978
    时间
    1000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    1
    已通过
    1
    上传者