1 条题解

  • 0
    @ 2025-11-27 10:41:33

    题目分析

    题目要求计算猴群团结度的期望乘以 ( n! ) 后对 ( 998244353 ) 取模的结果。核心在于将问题转化为组合数学与多项式运算,通过分析树的高度排列对应的连通块结构,推导出递推关系并利用多项式求逆求解。

    解题思路

    1. 问题转化:每棵树的高度是随机排列,绳索连接的条件等价于形成以最大值为根的连通块。团结度的期望可转化为对所有排列的贡献求和,再除以 ( n! )。
    2. 多项式建模:设 ( f(n) ) 为答案,则递推关系可表示为多项式形式。通过构造生成函数,利用多项式乘法与求逆运算求解。
    3. NTT优化:使用快速数论变换(NTT)加速多项式乘法,结合多项式求逆算法高效计算生成函数的逆,最终得到结果。

    代码实现(带注释)

    #include <bits/stdc++.h>
    using namespace std;
    
    constexpr int mod = 998244353;
    typedef long long LL;
    
    // 快速幂:计算 x^y mod mod
    int Pow(int x, LL y) {
        int res = 1;
        for (; y; y >>= 1, x = (LL)x * x % mod)
            if (y & 1)
                res = (LL)res * x % mod;
        return res;
    }
    
    // 多项式结构体
    struct Poly {
        vector<int> p;  // 系数数组,p[i] 表示 x^i 的系数
        int deg;        // 多项式次数(最高次项的次数)
    
        // 空多项式
        Poly() : deg(-1), p() {};
        // 初始化次数为n的多项式,系数全0
        explicit Poly(int n) : deg(n), p(n + 1, 0) {};
        // 用系数数组初始化多项式
        explicit Poly(const vector<int> &coefficients) {
            if (coefficients.empty()) {
                deg = -1;
                p = vector<int>();
            } else {
                p = coefficients;
                deg = static_cast<int>(p.size()) - 1;
                normalize();  // 去除高位零
            }
        }
    
        // 标准化多项式:去除高位零,更新次数
        void normalize() {
            if (p.empty()) {
                deg = -1;
                return;
            }
            int new_deg = static_cast<int>(p.size()) - 1;
            while (new_deg >= 0 && p[new_deg] == 0)
                new_deg--;
            if (new_deg < 0)
                deg = -1, p.clear();
            else
                deg = new_deg, p.resize(deg + 1);
        }
    
        // 重载[]运算符,修改系数
        int &operator[](int i) {
            if (i < 0 || i > deg)
                throw out_of_range("Poly index out of range");
            return p[i];
        }
    
        // 重载[]运算符,访问系数(只读)
        int operator[](int i) const {
            if (i < 0 || i > deg)
                return 0;
            return p[i];
        }
    
        // NTT变换:type=1为正变换,type=0为逆变换
        Poly NTT(int len, int type) const {
            if (deg == -1)
                return Poly();
            vector<int> rev(len, 0);
            // 预处理位逆序
            for (int i = 0; i < len; i++)
                rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
            Poly res(len - 1);
            // 复制系数并按位逆序排列
            for (int i = 0; i < len; i++)
                res.p[rev[i]] = (*this)[i];
            // 分治进行NTT
            for (int mid = 1; mid < len; mid <<= 1) {
                int wn = Pow(3, (mod - 1) / (mid << 1));  // 单位根
                if (type == 0)
                    wn = Pow(wn, mod - 2);  // 逆变换用逆元
                for (int i = 0; i < len; i += (mid << 1)) {
                    int w = 1;
                    for (int j = 0; j < mid; j++, w = (LL)w * wn % mod) {
                        int x = res.p[i + j], y = (LL)w * res.p[i + j + mid] % mod;
                        res.p[i + j] = (x + y) % mod;
                        res.p[i + j + mid] = (x - y + mod) % mod;
                    }
                }
            }
            // 逆变换需乘以逆元
            if (type == 0) {
                int inv = Pow(len, mod - 2);
                for (int i = 0; i < len; i++)
                    res.p[i] = (LL)res.p[i] * inv % mod;
            }
            return res;
        }
    
        // 多项式加法
        Poly operator+(const Poly &P) const {
            if (deg == -1)
                return P;
            if (P.deg == -1)
                return *this;
            int new_deg = max(deg, P.deg);
            Poly res(new_deg);
            for (int i = 0; i <= new_deg; i++)
                res[i] = ((*this)[i] + P[i]) % mod;
            res.normalize();
            return res;
        }
    
        // 多项式减法
        Poly operator-(const Poly &P) const {
            if (deg == -1)
                return P;
            if (P.deg == -1)
                return *this;
            int new_deg = max(deg, P.deg);
            Poly res(new_deg);
            for (int i = 0; i <= new_deg; i++)
                res[i] = ((*this)[i] - P[i] + mod) % mod;
            res.normalize();
            return res;
        }
    
        // 多项式乘法(NTT优化)
        Poly operator*(const Poly &P) const {
            if (deg == -1 || P.deg == -1)
                return Poly();
            int new_deg = deg + P.deg;
            int len = 1;
            while (len <= new_deg)
                len <<= 1;  // 取不小于new_deg的最小2的幂
            Poly A = this->NTT(len, 1);
            Poly B = P.NTT(len, 1);
            Poly C(len - 1);
            // 点值相乘
            for (int i = 0; i < len; i++)
                C.p[i] = (LL)A.p[i] * B.p[i] % mod;
            // 逆变换得到系数
            Poly res = C.NTT(len, 0);
            res.deg = len - 1;
            res.normalize();
            return res;
        }
    
        // 多项式求逆(声明)
        Poly Inv(int d);
    };
    
    // 多项式求逆:计算模x^(d+1)的逆元
    Poly Poly::Inv(int d) {
        assert(deg != -1 && this->p[0]);  // 常数项非零
        Poly res(vector<int>({Pow(this->p[0], mod - 2)}));  // 初始逆元为常数项的逆
        int len = 1;
        while (len < d + 1) {
            len <<= 1;
            // 取前len项
            Poly This(vector<int>(this->p.begin(), this->p.begin() + min(this->deg + 1, len)));
            // NTT变换
            Poly this_t = This.NTT(len << 1, 1), res_t = res.NTT(len << 1, 1);
            // Newton迭代:res = res * (2 - This * res) mod x^len
            for (int i = 0; i < (len << 1); i++)
                res_t[i] = (LL)res_t[i] * (2 + mod - (LL)res_t[i] * this_t[i] % mod) % mod;
            // 逆变换
            res = res_t.NTT(len << 1, 0);
            fill(res.p.begin() + len, res.p.end(), 0);
            res.normalize();
        }
        res.p.resize(d + 1);
        res.normalize();
        return res;
    }
    
    int main() {
        int n;
        cin >> n;
        // 构造阶乘多项式 Fac(x) = sum_{i=0}^n i! x^i
        Poly Fac(n);
        Fac[0] = 1;
        for (int i = 1; i <= n; i++)
            Fac[i] = (LL)Fac[i - 1] * i % mod;
        // FFac为Fac的副本,用于后续求逆
        Poly FFac = Fac;
        Fac[0] = 0;  // Fac(x) 变为 sum_{i=1}^n i! x^i
        // 计算 coef = Fac * FFac^{-1} mod x^{n+1}
        Poly coef = Fac * FFac.Inv(n);
        coef.p.resize(n + 1);
        coef.normalize();
        // 调整系数:coef[i] = -i * coef[i]
        for (int i = 0; i <= coef.deg; i++)
            coef[i] = (mod - (LL)coef[i] * i % mod) % mod;
        coef[0] = (coef[0] + 1) % mod;  // 常数项加1
        // 求逆后取x^n的系数即为答案
        cout << coef.Inv(n)[n] << endl;
        return 0;
    }
    

    代码解释

    1. 多项式结构体:封装多项式的存储、标准化、运算(加、减、乘)及NTT变换。
    2. NTT变换:通过分治实现快速数论变换,加速多项式乘法。
    3. 多项式求逆:使用牛顿迭代法,在模 ( x^{d+1} ) 意义下求多项式的逆元。
    4. 主逻辑:构造阶乘多项式,通过多项式乘法与求逆得到目标多项式,最终取 ( x^n ) 的系数作为答案。
    • 1