1 条题解

  • 0
    @ 2025-10-19 16:46:07

    题解

    思路概述

    • 题目要求计算最小编号黑球 A 的期望,并给出多项式 F(A) 的取值。设两次随机着色的分布转化为“最左侧黑球在位置 x”的概率 prob[x],则答案是 Σ F(x)·prob[x]
    • 为了高效求出这组概率,代码使用组合数与容斥公式,将“前 x-1 个位置均未被涂黑、位置 x 被涂黑”转化为若干组合项,并借助卷积与 NTT 进行快速计算。
    • 多项式系数都是模 998244353 的值;程序先对输入 (F(0)…F(m)) 进行差分求出多项式系数,再通过两次 NTT 做卷积,把公式化为多项式乘法;最后根据公式累加求出 Σ F(x)·prob[x] 并乘以题目指定的组合系数。

    复杂度

    • 归约到一次 O(m log m) 的 NTT 卷积及若干线性遍历,足以应对 m ≤ 10^6
    • 代码中 init()DIT/DIF 等函数实现了模数 998244353 下的 NTT 变换。
    #include <bits/stdc++.h>
    using namespace std;
    
    static const long long MOD = 998244353;
    static const long long G = 3;
    
    long long modpow(long long a, long long e){
        long long r=1;
        while(e){
            if(e&1) r=r*a%MOD;
            a=a*a%MOD;
            e>>=1;
        }
        return r;
    }
    
    void ntt(vector<long long>& a, bool invert){
        int n = (int)a.size();
        static vector<int> rev;
        static vector<long long> roots{0,1};
        if((int)rev.size()!=n){
            int k=__builtin_ctz(n);
            rev.assign(n,0);
            for(int i=0;i<n;i++)
                rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
        }
        if((int)roots.size()<n){
            int k=__builtin_ctz(roots.size());
            roots.resize(n);
            while((1<<k)<n){
                long long e = modpow(G,(MOD-1)>>(k+1));
                for(int i=1<<(k-1); i<(1<<k); i++){
                    roots[2*i]=roots[i];
                    roots[2*i+1]=roots[i]*e%MOD;
                }
                k++;
            }
        }
        for(int i=0;i<n;i++){
            if(i<rev[i]) swap(a[i],a[rev[i]]);
        }
        for(int len=1; len<n; len<<=1){
            for(int i=0;i<n;i+=2*len){
                for(int j=0;j<len;j++){
                    long long u=a[i+j];
                    long long v=a[i+j+len]*roots[len+j]%MOD;
                    a[i+j]=(u+v)%MOD;
                    a[i+j+len]=(u-v+MOD)%MOD;
                }
            }
        }
        if(invert){
            reverse(a.begin()+1,a.end());
            long long inv_n = modpow(n, MOD-2);
            for(long long &x:a) x=x*inv_n%MOD;
        }
    }
    
    vector<long long> convolution(vector<long long> a, vector<long long> b){
        int n=1;
        int need=(int)a.size()+(int)b.size()-1;
        while(n<need) n<<=1;
        a.resize(n); b.resize(n);
        ntt(a,false); ntt(b,false);
        for(int i=0;i<n;i++) a[i]=a[i]*b[i]%MOD;
        ntt(a,true);
        a.resize(need);
        return a;
    }
    
    // batch inverse for v[], allowing zeros
    vector<long long> batch_inverse(const vector<long long>& v){
        int n=v.size();
        vector<long long> pref(n), suf(n), inv(n,0);
        long long prod=1;
        for(int i=0;i<n;i++){
            pref[i]=prod;
            if(v[i]!=0) prod=prod*v[i]%MOD;
        }
        long long inv_prod = (prod==0?0:modpow(prod,MOD-2));
        prod=inv_prod;
        for(int i=n-1;i>=0;i--){
            suf[i]=prod;
            if(v[i]!=0) prod=prod*v[i]%MOD;
        }
        for(int i=0;i<n;i++){
            if(v[i]==0) inv[i]=0;
            else inv[i]=pref[i]*suf[i]%MOD;
        }
        return inv;
    }
    
    int main(){
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
    
        long long n; int m;
        cin>>n>>m;
        vector<long long> a(m+1);
        for(int i=0;i<=m;i++){
            cin>>a[i];
            a[i]%=MOD;
        }
    
        if(n < m){
            cout<<0<<"\n";
            return 0;
        }
    
        int d = 3*m - 1;
        int MAXF = max(d+2, m+2);
    
        // factorials up to MAXF
        vector<long long> fact(MAXF), ifact(MAXF);
        fact[0]=1;
        for(int i=1;i<MAXF;i++) fact[i]=fact[i-1]*i%MOD;
        ifact[MAXF-1]=modpow(fact[MAXF-1],MOD-2);
        for(int i=MAXF-1;i>0;i--) ifact[i-1]=ifact[i]*i%MOD;
    
        // ---- Step1: compute forward differences b_k = Δ^k F(-1) by binomial transform ----
        vector<long long> A(m+1), C(m+1);
        for(int j=0;j<=m;j++){
            A[j]=a[j]*ifact[j]%MOD;
            C[j]=ifact[j];
            if(j&1) C[j]=(MOD-C[j])%MOD;
        }
        auto conv1 = convolution(A,C); // length <=2m+1
    
        vector<long long> b(m+1);
        for(int k=0;k<=m;k++){
            b[k]=fact[k]*conv1[k]%MOD;
        }
    
        // ---- Step2: compute F(0..d+1) using Newton series in convolution form ----
        vector<long long> B(m+1), Dseq(d+2);
        for(int k=0;k<=m;k++){
            B[k]=b[k]*ifact[k]%MOD; // b_k / k!
        }
        for(int t=0;t<=d+1;t++){
            Dseq[t]=ifact[t]; // 1/t!
        }
        auto conv2 = convolution(B, Dseq); // S_t
    
        vector<long long> Fvals(d+2); // F(0..d+1)
        for(int aidx=0;aidx<=d+1;aidx++){
            long long S = conv2[aidx+1];          // t=a+1
            Fvals[aidx]=fact[aidx+1]*S%MOD;      // F(a)
        }
        long long F_minus1 = a[0]; // F(-1)
    
        // ---- Step3: ΔF(a) for a=0..d ----
        vector<long long> dF(d+1);
        for(int i=0;i<=d;i++){
            long long prev = (i==0? F_minus1 : Fvals[i-1]);
            dF[i]=(Fvals[i]-prev+MOD)%MOD;
        }
    
        // ---- Step4: compute comb[a]=C(n-a, m) for a=0..d ----
        vector<long long> denom(d+1);
        for(int aidx=0;aidx<=d;aidx++){
            long long v = (n - aidx) % MOD;
            if(v<0) v+=MOD;
            denom[aidx]=v;
        }
        auto invDen = batch_inverse(denom);
    
        vector<long long> comb(d+1,0);
    
        // comb[0] = C(n,m) via falling factorial
        long long numer=1;
        for(int i=0;i<m;i++){
            long long term = (n - i) % MOD;
            if(term<0) term+=MOD;
            numer = numer*term%MOD;
        }
        comb[0]=numer*ifact[m]%MOD;
    
        for(int aidx=1;aidx<=d;aidx++){
            long long k1 = n - (aidx-1); // previous top = n-a+1
            if(k1 < m){
                comb[aidx]=0;
                continue;
            }
            long long num = (k1 - m) % MOD;
            if(num<0) num+=MOD;
            comb[aidx]=comb[aidx-1]*num%MOD*invDen[aidx-1]%MOD;
        }
    
        vector<long long> H(d+1);
        for(int i=0;i<=d;i++){
            H[i]=comb[i]*comb[i]%MOD;
        }
    
        // ---- Step5: P(a)=ΔF(a)*H(a), prefix T ----
        vector<long long> P(d+1), T(d+1);
        for(int i=0;i<=d;i++){
            P[i]=dF[i]*H[i]%MOD;
            T[i]=P[i];
            if(i) T[i]=(T[i]+T[i-1])%MOD;
        }
    
        long long L = n - m; // sum up to a=L
        auto eval_prefix = [&](long long x)->long long{
            if(x<=d) return T[(int)x];
            // Lagrange on points 0..d
            vector<long long> pre(d+2), suf(d+2);
            pre[0]=1;
            for(int i=0;i<=d;i++){
                long long v=(x - i)%MOD; if(v<0) v+=MOD;
                pre[i+1]=pre[i]*v%MOD;
            }
            suf[d+1]=1;
            for(int i=d;i>=0;i--){
                long long v=(x - i)%MOD; if(v<0) v+=MOD;
                suf[i]=suf[i+1]*v%MOD;
            }
            long long ans=0;
            for(int i=0;i<=d;i++){
                long long num = pre[i]*suf[i+1]%MOD;
                long long den = ifact[i]*ifact[d-i]%MOD;
                if((d-i)&1) den = (MOD-den)%MOD;
                long long li = num*den%MOD;
                ans = (ans + T[i]*li)%MOD;
            }
            return ans;
        };
    
        long long prefixSum = eval_prefix(L);
    
        long long Cnm2 = comb[0]*comb[0]%MOD; // C(n,m)^2
        long long S_ans = (Fvals[0]*Cnm2 + prefixSum - P[0])%MOD;
        if(S_ans<0) S_ans+=MOD;
    
        cout<<S_ans<<"\n";
        return 0;
    }
    
    
    • 1

    「2018 集训队互测 Day 5」小 H 爱染色

    信息

    ID
    3403
    时间
    1000ms
    内存
    256MiB
    难度
    9
    标签
    递交数
    7
    已通过
    1
    上传者