1 条题解

  • 0
    @ 2025-10-19 20:30:37

    「PKUWC2018」猎人杀 题解

    问题描述

    nn 个猎人,第 ii 个猎人的仇恨度为 wiw_i。游戏规则如下:

    1. 初始时你开枪,随机选择一个猎人(概率与仇恨度成正比)。
    2. 每个死亡的猎人会立即开枪,目标是当前活着的猎人(概率仍与仇恨度成正比)。
    3. 所有猎人最终都会死亡,求 1 号猎人是最后一个死亡的概率,结果对 998244353998244353 取模。

    核心思路

    1. 问题转化:1 号猎人最后死亡等价于所有其他猎人都在 1 号之前死亡。
    2. 容斥原理:利用容斥原理推导概率公式,将问题转化为多项式系数的求和。
    3. 多项式乘法:通过生成函数(多项式)高效计算容斥项,结合快速数论变换(NTT)优化多项式乘法。

    详细分析

    1. 概率公式推导

    设其他猎人的集合为 U={2,3,,n}U = \{2, 3, \ldots, n\},其总仇恨度为 S=iUwiS = \sum_{i \in U} w_i,1 号猎人的仇恨度为 w1w_1

    关键结论:1 号猎人最后死亡的概率可表示为:

    $$\text{答案} = w_1 \cdot \sum_{S' \subseteq U} \frac{(-1)^{|S'|}}{w_1 + \sum_{i \in S'} w_i} $$

    其中 SS'UU 的子集,iSwi\sum_{i \in S'} w_i 是子集中猎人的仇恨度之和,S|S'| 是子集大小。

    推导依据

    • 利用容斥原理,事件“1 号最后死亡”等价于“所有其他猎人都在 1 号之前死亡”。
    • 对每个子集 SUS' \subseteq U,通过容斥项 (1)S(-1)^{|S'|} 修正概率,最终求和得到总概率。
    2. 生成函数与多项式乘法

    公式中的求和项可通过生成函数计算:

    • 每个猎人 iUi \in U 对应多项式 (1xwi)(1 - x^{w_i}),其展开式中 xkx^k 的系数为 1-1(若 k=wik = w_i)或 11(若 k=0k = 0),其余为 00
    • 所有多项式的乘积为 f(x)=iU(1xwi)f(x) = \prod_{i \in U} (1 - x^{w_i}),其展开式中 xsx^s 的系数 f[s]f[s] 恰好是 SU,wi=s(1)S\sum_{S' \subseteq U, \sum w_i = s} (-1)^{|S'|},即公式中分子的总和。

    因此,求和项 $\sum_{S' \subseteq U} \frac{(-1)^{|S'|}}{w_1 + \sum w_i}$ 等价于 s=0Sf[s]w1+s\sum_{s=0}^S \frac{f[s]}{w_1 + s}

    3. 高效计算与模运算
    • 多项式乘法:使用 NTT 优化多项式乘法,时间复杂度为 O(MlogM)O(M \log M)MM 为多项式最高次数,即 SS)。
    • 逆元计算:预处理 11w1+Sw_1 + S 的模逆元,快速计算 1w1+smod998244353\frac{1}{w_1 + s} \mod 998244353

    代码解析

    1. 输入处理与特殊情况
      读入 nnww,若 n=1n = 1,直接输出 11(唯一猎人必最后死亡)。

    2. 多项式构造
      对每个 i2i \geq 2,构造多项式 (1xwi)(1 - x^{w_i}),即系数数组中 00 处为 11wiw_i 处为 1-1(模 998244353998244353 处理为 998244352998244352)。

    3. 多项式合并
      使用优先队列(最小堆)合并所有多项式,每次合并两个多项式(通过 NTT 实现高效乘法),最终得到乘积多项式 f(x)f(x)

    4. 求和与结果计算
      预处理逆元,计算 s=0Sf[s]w1+s\sum_{s=0}^S \frac{f[s]}{w_1 + s},乘以 w1w_1 后取模,得到答案。

    复杂度分析

    • 时间复杂度:多项式最高次数为 S=i=2nwiS = \sum_{i=2}^n w_i,合并 n1n-1 个多项式的总时间为 O(Slog2S)O(S \log^2 S)(NTT 每次乘法为 O(MlogM)O(M \log M),合并次数为 O(nlogn)O(n \log n))。
    • 空间复杂度O(S)O(S),用于存储多项式系数和逆元数组。

    代码实现

    #include <bits/stdc++.h>
    using namespace std;
    using ll = long long;
    const int MOD = 998244353;
    const int G = 3; // 原根
    
    ll modpow(ll a, ll e = MOD - 2) {
        ll r = 1;
        while (e) {
            if (e & 1) r = r * a % MOD;
            a = a * a % MOD;
            e >>= 1;
        }
        return r;
    }
    
    // NTT 实现
    void ntt(vector<int>& a, bool invert) {
        int n = a.size();
        static vector<int> rev, roots{0, 1};
        if ((int)rev.size() != n) {
            rev.assign(n, 0);
            for (int i = 0; i < n; i++)
                rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
        }
        for (int i = 0; i < n; i++)
            if (i < rev[i]) swap(a[i], a[rev[i]]);
        if ((int)roots.size() < n) {
            int k = __builtin_ctz(roots.size());
            roots.resize(n);
            while ((1 << k) < n) {
                ll z = modpow(G, (MOD - 1) >> (k + 1));
                for (int i = 1 << (k - 1); i < (1 << k); ++i) {
                    roots[2 * i] = roots[i];
                    roots[2 * i + 1] = roots[i] * z % MOD;
                }
                ++k;
            }
        }
        for (int len = 1; len < n; len <<= 1) {
            for (int i = 0; i < n; i += 2 * len) {
                for (int j = 0; j < len; j++) {
                    int u = a[i + j], v = 1LL * a[i + j + len] * roots[len + j] % MOD;
                    a[i + j] = (u + v) % MOD;
                    a[i + j + len] = (u - v + MOD) % MOD;
                }
            }
        }
        if (invert) {
            reverse(a.begin() + 1, a.end());
            ll inv_n = modpow(n);
            for (int& x : a) x = x * inv_n % MOD;
        }
    }
    
    // 多项式乘法(NTT 优化)
    vector<int> multiply_ntt(const vector<int>& a, const vector<int>& b, int need) {
        if (a.empty() || b.empty()) return {};
        int as = a.size(), bs = b.size();
        if (min(as, bs) < 64) { // 小规模直接乘法
            vector<int> c(min(need, as + bs - 1));
            for (int i = 0; i < as; i++)
                for (int j = 0; j < bs && i + j < need; j++)
                    c[i + j] = (c[i + j] + 1LL * a[i] * b[j]) % MOD;
            return c;
        }
        int sz = 1;
        while (sz < as + bs - 1) sz <<= 1;
        vector<int> fa(a.begin(), a.end()), fb(b.begin(), b.end());
        fa.resize(sz), fb.resize(sz);
        ntt(fa, false), ntt(fb, false);
        for (int i = 0; i < sz; i++) fa[i] = 1LL * fa[i] * fb[i] % MOD;
        ntt(fa, true);
        fa.resize(min(need, as + bs - 1));
        return fa;
    }
    
    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        int n;
        cin >> n;
        vector<int> w(n + 1);
        for (int i = 1; i <= n; i++) cin >> w[i];
        if (n == 1) {
            cout << 1 << "\n";
            return 0;
        }
        int w1 = w[1];
        int S = 0;
        for (int i = 2; i <= n; i++) S += w[i];
        int need = S + 1;
    
        // 构造多项式 (1 - x^w_i) 并加入优先队列
        struct Poly { vector<int> v; };
        auto cmp = [](const Poly& A, const Poly& B) { return A.v.size() > B.v.size(); };
        priority_queue<Poly, vector<Poly>, decltype(cmp)> pq(cmp);
        for (int i = 2; i <= n; i++) {
            vector<int> p(min(need, w[i] + 1), 0);
            p[0] = 1;
            if (w[i] < need) p[w[i]] = MOD - 1; // -1 mod MOD
            pq.push({p});
        }
    
        // 合并所有多项式
        while (pq.size() > 1) {
            auto A = pq.top(); pq.pop();
            auto B = pq.top(); pq.pop();
            auto C = multiply_ntt(A.v, B.v, need);
            for (int& x : C) if (x < 0) x += MOD;
            pq.push({C});
        }
    
        vector<int> f = pq.top().v;
        f.resize(need, 0);
    
        // 预处理逆元
        int M = w1 + S;
        vector<int> inv(M + 1);
        inv[1] = 1;
        for (int i = 2; i <= M; i++)
            inv[i] = MOD - 1LL * (MOD / i) * inv[MOD % i] % MOD;
    
        // 计算求和项
        ll sum = 0;
        for (int s = 0; s <= S; s++) {
            if (f[s] == 0) continue;
            int denom = w1 + s;
            sum = (sum + 1LL * f[s] * inv[denom]) % MOD;
        }
    
        ll ans = 1LL * w1 % MOD * sum % MOD;
        cout << ans << "\n";
        return 0;
    }
    

    总结

    本题通过容斥原理将概率计算转化为多项式系数求和,结合 NTT 高效处理多项式乘法,最终在模意义下求解。核心在于将抽象的概率问题转化为具体的代数运算,利用生成函数和数论变换实现高效计算。

    • 1

    信息

    ID
    3503
    时间
    1000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    2
    已通过
    1
    上传者