1 条题解

  • 0
    @ 2025-10-15 15:41:53

    算法标签

    动态规划、FWT(快速沃尔什变换变换)、线性递推、矩阵快速幂、概率与期望

    题解

    问题分析

    题目要求计算经过 K 次操作后,最终数字为 i 的所有方案中权值乘概率之和。每次操作有两种可能:以概率 p 执行 OR 操作,以概率 1-p 执行 AND 操作。权值为每次操作后数字对应的 c 值之和。

    核心挑战在于:

    1. K 可能高达 1e9,无法直接模拟
    2. 状态空间为 2^n,n 最大为 17,状态数达 131072

    需要利用 FWT 加速位运算相关的转移,并结合线性递推优化高次幂计算。

    方法思路

    1. 状态表示

      • 定义两个数组:
        • f[t][0][x]:经过 t 次操作后结果为 x 的概率
        • f[t][1][x]:经过 t 次操作后结果为 x 的权值乘概率之和
    2. 转移方程

      • 每次操作有两种选择(OR 或 AND),需要对所有可能的 x 计算转移
      • 使用 FWT 加速 OR 和 AND 卷积操作,将 O(4^n) 的转移优化为 O(n·2^n)
    3. 线性递推优化

      • 当 K 很大时,通过 Berlekamp-Massey 算法寻找线性递推关系
      • 利用快速幂计算递推结果,将时间复杂度从 O(K) 降至 O(log K)

    代码解析

    #include <stdio.h>
    #include <algorithm>
    #include <vector>
    #include <chrono>
    #include <random>
    #include <tuple>
    typedef unsigned int uint;
    typedef unsigned long long ull;
    constexpr uint mod{998244353};
    
    // 模运算辅助函数
    constexpr uint plus(const uint &x, const uint &y) {
        return x + y >= mod ? x + y - mod : x + y;
    }
    constexpr uint minus(const uint &x, const uint &y) {
        return x < y ? x - y + mod : x - y;
    }
    constexpr uint power(uint x, uint y) {
        uint s(1);
        while (y > 0) {
            if (y & 1) s = ull(s) * x % mod;
            x = ull(x) * x % mod;
            y >>= 1;
        }
        return s;
    }
    
    // FWT 相关函数 - 用于加速 OR 和 AND 卷积
    constexpr void FWT_or(const uint &n, uint *const f) {
        for (uint l = 1; l != n; l <<= 1) {
            for (uint *i = f; i != f + n; i += l << 1) {
                for (uint *j = i; j != i + l; j++) {
                    *(j + l) = plus(*(j + l), *j);
                }
            }
        }
    }
    constexpr void IFWT_or(const uint &n, uint *const f) {
        for (uint l = 1; l != n; l <<= 1) {
            for (uint *i = f; i != f + n; i += l << 1) {
                for (uint *j = i; j != i + l; j++) {
                    *(j + l) = minus(*(j + l), *j);
                }
            }
        }
    }
    constexpr void FWT_and(const uint &n, uint *const f) {
        for (uint l = 1; l != n; l <<= 1) {
            for (uint *i = f; i != f + n; i += l << 1) {
                for (uint *j = i; j != i + l; j++) {
                    *j = plus(*j, *(j + l));
                }
            }
        }
    }
    constexpr void IFWT_and(const uint &n, uint *const f) {
        for (uint l = 1; l != n; l <<= 1) {
            for (uint *i = f; i != f + n; i += l << 1) {
                for (uint *j = i; j != i + l; j++) {
                    *j = minus(*j, *(j + l));
                }
            }
        }
    }
    
    // Berlekamp-Massey 算法 - 用于寻找线性递推关系
    constexpr uint BM(const uint &n, const uint *const f, uint *const g) {
        uint lst[n + 1] {}, tmp[n + 1] {};
        uint len(0);
        uint lstpos(-1), lstlen(0);
        uint lstdel(0);
    
        for (uint i = 0; i != n; i++) {
            uint val(0);
            for (uint j = 1; j <= len; j++) {
                val = (val + ull(f[i - j]) * g[j]) % mod;
            }
            if (val == f[i]) continue;
            
            if (lstpos == -1) {
                len = i + 1;
                std::fill(g + 1, g + i + 2, 0);
                lstpos = i, lstlen = 0;
                lstdel = f[i];
                continue;
            }
    
            const uint del(minus(f[i], val));
            const uint coef(ull(del)*power(lstdel, mod - 2) % mod);
            std::copy(g + 1, g + len + 1, tmp + 1);
            
            for (uint j = 1; j <= lstlen; j++) {
                g[j + (i - lstpos)] = (g[j + (i - lstpos)] + ull(mod - coef) * lst[j]) % mod;
            }
            
            g[i - lstpos] = plus(g[i - lstpos], coef);
            std::tie(lstlen, len) = std::make_tuple(len, std::max(len, lstlen + (i - lstpos)));
            lstpos = i, lstdel = del;
            std::copy(tmp + 1, tmp + lstlen + 1, lst + 1);
        }
        return len;
    }
    
    // NTT 相关函数 - 用于多项式运算
    constexpr uint N{8};
    uint w[1 << N | 1];
    constexpr uint getn(const uint &n) {
        return 1 << std::__lg(n - 1) + 1;
    }
    inline void init() {
        w[1 << N] = power(3, mod - 1 >> N + 2);
        for (int i = N; i != 0; i--) {
            w[1 << i - 1] = ull(w[1 << i]) * w[1 << i] % mod;
        }
        w[0] = 1;
        for (int i = 1; i != 1 << N; i++) {
            w[i] = ull(w[i & i - 1]) * w[i & -i] % mod;
        }
    }
    inline void DIF(const uint &n, uint *const f) {
        for (uint l = n >> 1; l != 0; l >>= 1) {
            for (uint *i = f, *o = w; i != f + n; i += l << 1, o++) {
                for (uint *j = i; j != i + l; j++) {
                    const uint t(ull(*(j + l)) **o % mod);
                    *(j + l) = minus(*j, t), *j = plus(*j, t);
                }
            }
        }
    }
    inline void DIT(const uint &n, uint *const f) {
        for (uint l = 1; l != n; l <<= 1) {
            for (uint *i = f, *o = w; i != f + n; i += l << 1, o++) {
                for (uint *j = i; j != i + l; j++) {
                    const uint t(*(j + l));
                    *(j + l) = ull(*j + mod - t) **o % mod, *j = plus(*j, t);
                }
            }
        }
        std::reverse(f + 1, f + n);
        const uint t(mod - (mod - 1 >> std::__lg(n)));
        for (uint *i = f; i != f + n; i++) {
            *i = ull(*i) * t % mod;
        }
    }
    
    // 多项式运算函数
    inline void multiply(const uint &n, uint *const f, uint *const g) {
        DIF(n, f), DIF(n, g);
        for (uint i = 0; i != n; i++) {
            f[i] = ull(f[i]) * g[i] % mod;
        }
        DIT(n, f);
    }
    inline void inverse(const uint &n, const uint *const f, uint *const g) {
        uint *t0(new uint[n << 1] {}), *t1(new uint[n << 1] {});
        t1[0] = power(f[0], mod - 2);
        for (uint l = 1; l != n; l <<= 1) {
            std::copy(f, f + (l << 1), t0);
            std::fill(t0 + (l << 1), t0 + (l << 2), 0);
            DIF(l << 2, t0), DIF(l << 2, t1);
            for (uint i = 0; i != l << 2; i++) {
                t1[i] = (2 + ull(mod - t0[i]) * t1[i]) % mod * t1[i] % mod;
            }
            DIT(l << 2, t1);
            std::fill(t1 + (l << 1), t1 + (l << 2), 0);
        }
        std::copy(t1, t1 + n, g);
        delete[] t0, delete[] t1;
    }
    
    // 线性递推计算
    inline void linear_recurrence(const uint &n, const uint *const f, uint *const g, const int &k) {
        std::vector<uint> t0, t1, t2, t3;
        t0.assign({0, 1});
        t1.assign({1});
        t2.assign(f, f + n + 1);
        for (int i = k; i; i >>= 1) {
            if (i & 1) {
                multiply(t1, t0, t3);
                modulo(t3, t2, t1);
            }
            multiply(t0, t0, t3);
            modulo(t3, t2, t0);
        }
        std::fill(g, g + n, 0);
        std::copy(t1.begin(), t1.end(), g);
    }
    
    // 主函数
    uint wt[1 << 17 | 1], coef[1 << 17 | 1];
    uint val[2][1 << 17 | 1];
    uint f[85][2][1 << 17 | 1], g[2][1 << 17 | 1], h[2][1 << 17 | 1];
    uint arr[85], arr1[85], arr2[85], arr3[85];
    uint ans[1 << 17 | 1];
    
    int main() {
        init();
        int n, k, x;
        uint p;
        scanf("%d%u%d%d", &n, &p, &k, &x);
        
        // 读取权值数组
        for (int i = 0; i != 1 << n; i++) {
            scanf("%u", wt + i);
        }
        
        // 预处理概率值
        const uint p2(power(2, mod - 1 - n));
        for (int i = 0; i != 1 << n; i++) {
            val[0][i] = ull(p2) * p % mod;
        }
        FWT_or(1 << n, val[0]);
        
        for (int i = 0; i != 1 << n; i++) {
            val[1][i] = ull(p2) * (mod + 1 - p) % mod;
        }
        FWT_and(1 << n, val[1]);
        
        // 随机系数用于线性递推
        for (int i = 0; i != 1 << n; i++) {
            coef[i] = rand<uint>(1, mod - 1);
        }
        
        // 初始化状态
        std::fill(f[0][0], f[0][0] + (1 << n), 0);
        std::fill(f[0][1], f[0][1] + (1 << n), 0);
        f[0][0][x] = 1;
        
        // 模拟前 80 步
        for (int i = 1; i <= 80; i++) {
            // OR 操作转移
            std::copy(f[i - 1][0], f[i - 1][0] + (1 << n), g[0]), FWT_or(1 << n, g[0]);
            std::copy(f[i - 1][1], f[i - 1][1] + (1 << n), g[1]), FWT_or(1 << n, g[1]);
            for (int j = 0; j != 1 << n; j++) {
                g[0][j] = ull(g[0][j]) * val[0][j] % mod;
                g[1][j] = ull(g[1][j]) * val[0][j] % mod;
            }
            IFWT_or(1 << n, g[0]);
            IFWT_or(1 << n, g[1]);
            std::copy(g[0], g[0] + (1 << n), h[0]);
            std::copy(g[1], g[1] + (1 << n), h[1]);
            
            // AND 操作转移
            std::copy(f[i - 1][0], f[i - 1][0] + (1 << n), g[0]), FWT_and(1 << n, g[0]);
            std::copy(f[i - 1][1], f[i - 1][1] + (1 << n), g[1]), FWT_and(1 << n, g[1]);
            for (int j = 0; j != 1 << n; j++) {
                g[0][j] = ull(g[0][j]) * val[1][j] % mod;
                g[1][j] = ull(g[1][j]) * val[1][j] % mod;
            }
            IFWT_and(1 << n, g[0]);
            IFWT_and(1 << n, g[1]);
            
            // 合并结果并更新权值
            for (int j = 0; j != 1 << n; j++) {
                h[0][j] = plus(h[0][j], g[0][j]);
                h[1][j] = plus(h[1][j], g[1][j]);
                h[1][j] = (h[1][j] + ull(h[0][j]) * wt[j]) % mod;
            }
            
            std::copy(h[0], h[0] + (1 << n), f[i][0]);
            std::copy(h[1], h[1] + (1 << n), f[i][1]);
            
            // 计算线性递推的系数
            for (int j = 0; j != 1 << n; j++) {
                arr[i] = (arr[i] + ull(f[i][1][j]) * coef[j]) % mod;
            }
            
            // 如果 K 较小,直接输出
            if (i == k) {
                for (int j = 0; j != 1 << n; j++) {
                    printf("%u%c", f[i][1][j], j + 1 == 1 << n ? '\n' : ' ');
                }
                return 0;
            }
        }
        
        // 对于大 K,使用线性递推
        arr[0] = 0;
        const int len(BM(81, arr, arr1));
        for (int i = 1; i <= len; i++) {
            arr2[len - i] = minus(0, arr1[i]);
        }
        arr2[len] = 1;
        linear_recurrence(len, arr2, arr3, k);
        
        // 计算最终结果
        std::fill(ans, ans + (1 << n), 0);
        for (int i = 0; i < len; i++) {
            for (int j = 0; j != 1 << n; j++) {
                ans[j] = (ans[j] + ull(f[i][1][j]) * arr3[i]) % mod;
            }
        }
        
        // 输出结果
        for (int j = 0; j != 1 << n; j++) {
            printf("%u%c", ans[j], j + 1 == 1 << n ? '\n' : ' ');
        }
        
        return 0;
    }
    

    复杂度分析

    • 时间复杂度:O(n·2^n + 80·n·2^n + L^2·log K),其中 L 是线性递推的阶数(约 80)
    • 空间复杂度:O(80·2^n),用于存储前 80 步的状态

    该算法通过 FWT 加速位运算转移,结合线性递推处理大 K 值,高效解决了题目约束下的问题。

    • 1

    信息

    ID
    3152
    时间
    1000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    5
    已通过
    1
    上传者