1 条题解

  • 0
    @ 2025-10-19 16:33:13

    题解

    思路概述

    • 树上查询要求统计路径上每个结点深度的 k 次幂和,模 998244353。设根为 1 且深度从 0 开始,路径 x ↔ y 的贡献可以拆成两段前缀:
      ans = Σ_{u∈path(x,y)} depth(u)^k = prefix(x,k) + prefix(y,k) - 2*prefix(lca,k) + depth(lca)^k
    • 因此只需一次 DFS 预处理 prefix[u][k] = Σ depth(v)^k(根到 u 的路径)。k ≤ 50,直接对每个结点存 50 个值即可。
    • 为了求 LCA,本题使用倍增法:fa[u][j] 表示从 u 向上跳 2^j 层的祖先,配合深度即可在 O(log n) 时间求 LCA。

    实现要点

    • DFS 时传入父结点与当前深度,先写好 fast_pow(depth, k) 对每个 k 计算深度的 k 次幂,再累加到父亲的前缀和即可。
    • 倍增预处理时从 fa[][0] 往上推,其它层没有祖先则保持 0。查询阶段先把较深的点跳到同一深度,再一起上跳直到相遇。
    • 需要注意所有加减都取模,尤其在做 ans - 2*prefix[lca][k] 时要加模再取模防止出现负值。

    复杂度

    • 预处理 DFS 为 O(n * K),倍增表是 O(n log n),单次查询 O(log n),总复杂度满足题目范围。
    #include <bits/stdc++.h>
    using namespace std;
    using ll = long long;
    
    const int N = 3e5 + 9;
    const int LOGN = 20;
    const int MAXK = 51;
    const int MOD = 998244353;
    
    int n, q;
    int dep[N], fa[N][LOGN];
    vector<int> adj[N];
    ll prefix_sum[N][MAXK];
    
    ll mod_pow(ll base, ll exp) {
        ll res = 1;
        base %= MOD;
        while (exp) {
            if (exp & 1) res = res * base % MOD;
            base = base * base % MOD;
            exp >>= 1;
        }
        return res;
    }
    
    void dfs(int u, int parent, int depth) {
        dep[u] = depth;
        fa[u][0] = parent;
        for (int k = 1; k <= 50; ++k) {
            ll val = mod_pow(depth, k);
            prefix_sum[u][k] = (prefix_sum[parent][k] + val) % MOD;
        }
        for (int v : adj[u]) {
            if (v != parent) dfs(v, u, depth + 1);
        }
    }
    
    void build_lca() {
        for (int j = 1; j < LOGN; ++j) {
            for (int i = 1; i <= n; ++i) {
                if (fa[i][j - 1]) fa[i][j] = fa[fa[i][j - 1]][j - 1];
            }
        }
    }
    
    int lca(int x, int y) {
        if (dep[x] < dep[y]) swap(x, y);
        for (int j = LOGN - 1; j >= 0; --j) {
            if (dep[x] - (1 << j) >= dep[y]) x = fa[x][j];
        }
        if (x == y) return x;
        for (int j = LOGN - 1; j >= 0; --j) {
            if (fa[x][j] && fa[x][j] != fa[y][j]) {
                x = fa[x][j];
                y = fa[y][j];
            }
        }
        return fa[x][0];
    }
    
    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
    
        cin >> n;
        for (int i = 1; i < n; ++i) {
            int u, v;
            cin >> u >> v;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
    
        dfs(1, 0, 0);
        build_lca();
    
        cin >> q;
        while (q--) {
            int x, y, k;
            cin >> x >> y >> k;
            int g = lca(x, y);
            ll ans = (prefix_sum[x][k] + prefix_sum[y][k]) % MOD;
            ans = (ans - 2 * prefix_sum[g][k] % MOD + MOD) % MOD;
            ans = (ans + mod_pow(dep[g], k)) % MOD;
            cout << ans << '\n';
        }
        return 0;
    }
    
    • 1

    信息

    ID
    3385
    时间
    1000ms
    内存
    256MiB
    难度
    6
    标签
    递交数
    1
    已通过
    1
    上传者