1 条题解
-
0
算法标签
动态规划、FWT(快速沃尔什变换变换)、线性递推、矩阵快速幂、概率与期望
题解
问题分析
题目要求计算经过 K 次操作后,最终数字为 i 的所有方案中权值乘概率之和。每次操作有两种可能:以概率 p 执行 OR 操作,以概率 1-p 执行 AND 操作。权值为每次操作后数字对应的 c 值之和。
核心挑战在于:
- K 可能高达 1e9,无法直接模拟
- 状态空间为 2^n,n 最大为 17,状态数达 131072
需要利用 FWT 加速位运算相关的转移,并结合线性递推优化高次幂计算。
方法思路
-
状态表示:
- 定义两个数组:
- f[t][0][x]:经过 t 次操作后结果为 x 的概率
- f[t][1][x]:经过 t 次操作后结果为 x 的权值乘概率之和
- 定义两个数组:
-
转移方程:
- 每次操作有两种选择(OR 或 AND),需要对所有可能的 x 计算转移
- 使用 FWT 加速 OR 和 AND 卷积操作,将 O(4^n) 的转移优化为 O(n·2^n)
-
线性递推优化:
- 当 K 很大时,通过 Berlekamp-Massey 算法寻找线性递推关系
- 利用快速幂计算递推结果,将时间复杂度从 O(K) 降至 O(log K)
代码解析
#include <stdio.h> #include <algorithm> #include <vector> #include <chrono> #include <random> #include <tuple> typedef unsigned int uint; typedef unsigned long long ull; constexpr uint mod{998244353}; // 模运算辅助函数 constexpr uint plus(const uint &x, const uint &y) { return x + y >= mod ? x + y - mod : x + y; } constexpr uint minus(const uint &x, const uint &y) { return x < y ? x - y + mod : x - y; } constexpr uint power(uint x, uint y) { uint s(1); while (y > 0) { if (y & 1) s = ull(s) * x % mod; x = ull(x) * x % mod; y >>= 1; } return s; } // FWT 相关函数 - 用于加速 OR 和 AND 卷积 constexpr void FWT_or(const uint &n, uint *const f) { for (uint l = 1; l != n; l <<= 1) { for (uint *i = f; i != f + n; i += l << 1) { for (uint *j = i; j != i + l; j++) { *(j + l) = plus(*(j + l), *j); } } } } constexpr void IFWT_or(const uint &n, uint *const f) { for (uint l = 1; l != n; l <<= 1) { for (uint *i = f; i != f + n; i += l << 1) { for (uint *j = i; j != i + l; j++) { *(j + l) = minus(*(j + l), *j); } } } } constexpr void FWT_and(const uint &n, uint *const f) { for (uint l = 1; l != n; l <<= 1) { for (uint *i = f; i != f + n; i += l << 1) { for (uint *j = i; j != i + l; j++) { *j = plus(*j, *(j + l)); } } } } constexpr void IFWT_and(const uint &n, uint *const f) { for (uint l = 1; l != n; l <<= 1) { for (uint *i = f; i != f + n; i += l << 1) { for (uint *j = i; j != i + l; j++) { *j = minus(*j, *(j + l)); } } } } // Berlekamp-Massey 算法 - 用于寻找线性递推关系 constexpr uint BM(const uint &n, const uint *const f, uint *const g) { uint lst[n + 1] {}, tmp[n + 1] {}; uint len(0); uint lstpos(-1), lstlen(0); uint lstdel(0); for (uint i = 0; i != n; i++) { uint val(0); for (uint j = 1; j <= len; j++) { val = (val + ull(f[i - j]) * g[j]) % mod; } if (val == f[i]) continue; if (lstpos == -1) { len = i + 1; std::fill(g + 1, g + i + 2, 0); lstpos = i, lstlen = 0; lstdel = f[i]; continue; } const uint del(minus(f[i], val)); const uint coef(ull(del)*power(lstdel, mod - 2) % mod); std::copy(g + 1, g + len + 1, tmp + 1); for (uint j = 1; j <= lstlen; j++) { g[j + (i - lstpos)] = (g[j + (i - lstpos)] + ull(mod - coef) * lst[j]) % mod; } g[i - lstpos] = plus(g[i - lstpos], coef); std::tie(lstlen, len) = std::make_tuple(len, std::max(len, lstlen + (i - lstpos))); lstpos = i, lstdel = del; std::copy(tmp + 1, tmp + lstlen + 1, lst + 1); } return len; } // NTT 相关函数 - 用于多项式运算 constexpr uint N{8}; uint w[1 << N | 1]; constexpr uint getn(const uint &n) { return 1 << std::__lg(n - 1) + 1; } inline void init() { w[1 << N] = power(3, mod - 1 >> N + 2); for (int i = N; i != 0; i--) { w[1 << i - 1] = ull(w[1 << i]) * w[1 << i] % mod; } w[0] = 1; for (int i = 1; i != 1 << N; i++) { w[i] = ull(w[i & i - 1]) * w[i & -i] % mod; } } inline void DIF(const uint &n, uint *const f) { for (uint l = n >> 1; l != 0; l >>= 1) { for (uint *i = f, *o = w; i != f + n; i += l << 1, o++) { for (uint *j = i; j != i + l; j++) { const uint t(ull(*(j + l)) **o % mod); *(j + l) = minus(*j, t), *j = plus(*j, t); } } } } inline void DIT(const uint &n, uint *const f) { for (uint l = 1; l != n; l <<= 1) { for (uint *i = f, *o = w; i != f + n; i += l << 1, o++) { for (uint *j = i; j != i + l; j++) { const uint t(*(j + l)); *(j + l) = ull(*j + mod - t) **o % mod, *j = plus(*j, t); } } } std::reverse(f + 1, f + n); const uint t(mod - (mod - 1 >> std::__lg(n))); for (uint *i = f; i != f + n; i++) { *i = ull(*i) * t % mod; } } // 多项式运算函数 inline void multiply(const uint &n, uint *const f, uint *const g) { DIF(n, f), DIF(n, g); for (uint i = 0; i != n; i++) { f[i] = ull(f[i]) * g[i] % mod; } DIT(n, f); } inline void inverse(const uint &n, const uint *const f, uint *const g) { uint *t0(new uint[n << 1] {}), *t1(new uint[n << 1] {}); t1[0] = power(f[0], mod - 2); for (uint l = 1; l != n; l <<= 1) { std::copy(f, f + (l << 1), t0); std::fill(t0 + (l << 1), t0 + (l << 2), 0); DIF(l << 2, t0), DIF(l << 2, t1); for (uint i = 0; i != l << 2; i++) { t1[i] = (2 + ull(mod - t0[i]) * t1[i]) % mod * t1[i] % mod; } DIT(l << 2, t1); std::fill(t1 + (l << 1), t1 + (l << 2), 0); } std::copy(t1, t1 + n, g); delete[] t0, delete[] t1; } // 线性递推计算 inline void linear_recurrence(const uint &n, const uint *const f, uint *const g, const int &k) { std::vector<uint> t0, t1, t2, t3; t0.assign({0, 1}); t1.assign({1}); t2.assign(f, f + n + 1); for (int i = k; i; i >>= 1) { if (i & 1) { multiply(t1, t0, t3); modulo(t3, t2, t1); } multiply(t0, t0, t3); modulo(t3, t2, t0); } std::fill(g, g + n, 0); std::copy(t1.begin(), t1.end(), g); } // 主函数 uint wt[1 << 17 | 1], coef[1 << 17 | 1]; uint val[2][1 << 17 | 1]; uint f[85][2][1 << 17 | 1], g[2][1 << 17 | 1], h[2][1 << 17 | 1]; uint arr[85], arr1[85], arr2[85], arr3[85]; uint ans[1 << 17 | 1]; int main() { init(); int n, k, x; uint p; scanf("%d%u%d%d", &n, &p, &k, &x); // 读取权值数组 for (int i = 0; i != 1 << n; i++) { scanf("%u", wt + i); } // 预处理概率值 const uint p2(power(2, mod - 1 - n)); for (int i = 0; i != 1 << n; i++) { val[0][i] = ull(p2) * p % mod; } FWT_or(1 << n, val[0]); for (int i = 0; i != 1 << n; i++) { val[1][i] = ull(p2) * (mod + 1 - p) % mod; } FWT_and(1 << n, val[1]); // 随机系数用于线性递推 for (int i = 0; i != 1 << n; i++) { coef[i] = rand<uint>(1, mod - 1); } // 初始化状态 std::fill(f[0][0], f[0][0] + (1 << n), 0); std::fill(f[0][1], f[0][1] + (1 << n), 0); f[0][0][x] = 1; // 模拟前 80 步 for (int i = 1; i <= 80; i++) { // OR 操作转移 std::copy(f[i - 1][0], f[i - 1][0] + (1 << n), g[0]), FWT_or(1 << n, g[0]); std::copy(f[i - 1][1], f[i - 1][1] + (1 << n), g[1]), FWT_or(1 << n, g[1]); for (int j = 0; j != 1 << n; j++) { g[0][j] = ull(g[0][j]) * val[0][j] % mod; g[1][j] = ull(g[1][j]) * val[0][j] % mod; } IFWT_or(1 << n, g[0]); IFWT_or(1 << n, g[1]); std::copy(g[0], g[0] + (1 << n), h[0]); std::copy(g[1], g[1] + (1 << n), h[1]); // AND 操作转移 std::copy(f[i - 1][0], f[i - 1][0] + (1 << n), g[0]), FWT_and(1 << n, g[0]); std::copy(f[i - 1][1], f[i - 1][1] + (1 << n), g[1]), FWT_and(1 << n, g[1]); for (int j = 0; j != 1 << n; j++) { g[0][j] = ull(g[0][j]) * val[1][j] % mod; g[1][j] = ull(g[1][j]) * val[1][j] % mod; } IFWT_and(1 << n, g[0]); IFWT_and(1 << n, g[1]); // 合并结果并更新权值 for (int j = 0; j != 1 << n; j++) { h[0][j] = plus(h[0][j], g[0][j]); h[1][j] = plus(h[1][j], g[1][j]); h[1][j] = (h[1][j] + ull(h[0][j]) * wt[j]) % mod; } std::copy(h[0], h[0] + (1 << n), f[i][0]); std::copy(h[1], h[1] + (1 << n), f[i][1]); // 计算线性递推的系数 for (int j = 0; j != 1 << n; j++) { arr[i] = (arr[i] + ull(f[i][1][j]) * coef[j]) % mod; } // 如果 K 较小,直接输出 if (i == k) { for (int j = 0; j != 1 << n; j++) { printf("%u%c", f[i][1][j], j + 1 == 1 << n ? '\n' : ' '); } return 0; } } // 对于大 K,使用线性递推 arr[0] = 0; const int len(BM(81, arr, arr1)); for (int i = 1; i <= len; i++) { arr2[len - i] = minus(0, arr1[i]); } arr2[len] = 1; linear_recurrence(len, arr2, arr3, k); // 计算最终结果 std::fill(ans, ans + (1 << n), 0); for (int i = 0; i < len; i++) { for (int j = 0; j != 1 << n; j++) { ans[j] = (ans[j] + ull(f[i][1][j]) * arr3[i]) % mod; } } // 输出结果 for (int j = 0; j != 1 << n; j++) { printf("%u%c", ans[j], j + 1 == 1 << n ? '\n' : ' '); } return 0; }
复杂度分析
- 时间复杂度:O(n·2^n + 80·n·2^n + L^2·log K),其中 L 是线性递推的阶数(约 80)
- 空间复杂度:O(80·2^n),用于存储前 80 步的状态
该算法通过 FWT 加速位运算转移,结合线性递推处理大 K 值,高效解决了题目约束下的问题。
- 1
信息
- ID
- 3152
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 10
- 标签
- 递交数
- 5
- 已通过
- 1
- 上传者