1 条题解
-
0
针对这个问题,有几种或简单或复杂的组合解法,但我会描述一种动态规划的方法,我认为它更容易理解和实现。
假设我们已经固定了组成子序列
"abc"的三个字符的位置,分别记为 、 和 。那么,有多少个字符串在这些位置上包含所需的子序列呢?显然,如果其中某个位置上的字符已经确定(不是问号),并且与该位置期望的字符不匹配,那么包含该子序列的字符串数量就是 。否则,由于我们已经固定了三个字符,其他位置上的所有问号都可以任意取值——因此这样的字符串数量为 ,其中 是除了 、、 之外的其他位置上问号的数量。
这启发我们可以写出一个 的解法:枚举 、、,对于每一组三元组,计算在这些位置上包含所需子序列的字符串数量。
但是这样太慢了。我们注意到,对于每个这样的子序列,包含它的字符串数量为 ,其中 是字符串中问号的总数, 是 、、 这三个位置中包含问号的个数。因此,对于每个整数 (),我们可以计算恰好包含 个问号的、匹配
"abc"的子序列的数量,这样就能更快地解决问题。如何计算每个 对应的子序列数量呢?在我看来,最简单的方法是动态规划:
设 表示:在所有以位置 结尾的子序列中,匹配
"abc"的前 个字符,且包含恰好 个问号的子序列数量。这个 DP 的转移如果直接写是平方复杂度(因为需要枚举子序列中的下一个或上一个位置),但如果我们重新定义 为:在所有结尾位置不超过 的子序列中,匹配
"abc"的前 个字符,且包含恰好 个问号的子序列数量,就可以将转移加速到线性。因为每次转移要么取当前字符,要么跳过当前字符,因此可以在 时间内完成。最终,这个 DP 解法的时间复杂度为 。针对这个问题,有几种或简单或复杂的组合解法,但我会描述一种动态规划的方法,我认为它更容易理解和实现。假设我们已经固定了组成子序列
"abc"的三个字符的位置,分别记为 、 和 。那么,有多少个字符串在这些位置上包含所需的子序列呢?显然,如果其中某个位置上的字符已经确定(不是问号),并且与该位置期望的字符不匹配,那么包含该子序列的字符串数量就是 。否则,由于我们已经固定了三个字符,其他位置上的所有问号都可以任意取值——因此这样的字符串数量为 ,其中 是除了 、、 之外的其他位置上问号的数量。
这启发我们可以写出一个 的解法:枚举 、、,对于每一组三元组,计算在这些位置上包含所需子序列的字符串数量。
但是这样太慢了。我们注意到,对于每个这样的子序列,包含它的字符串数量为 ,其中 是字符串中问号的总数, 是 、、 这三个位置中包含问号的个数。因此,对于每个整数 (),我们可以计算恰好包含 个问号的、匹配
"abc"的子序列的数量,这样就能更快地解决问题。如何计算每个 对应的子序列数量呢?在我看来,最简单的方法是动态规划:
设 表示:在所有以位置 结尾的子序列中,匹配
"abc"的前 个字符,且包含恰好 个问号的子序列数量。这个 DP 的转移如果直接写是平方复杂度(因为需要枚举子序列中的下一个或上一个位置),但如果我们重新定义 为:在所有结尾位置不超过 的子序列中,匹配
"abc"的前 个字符,且包含恰好 个问号的子序列数量,就可以将转移加速到线性。因为每次转移要么取当前字符,要么跳过当前字符,因此可以在 时间内完成。最终,这个 DP 解法的时间复杂度为 。#include <bits/stdc++.h> using namespace std; const int MOD = int(1e9) + 7; const int N = 200043; const int K = 4; int add(int x, int y) { x += y; while(x >= MOD) x -= MOD; while(x < 0) x += MOD; return x; } int mul(int x, int y) { return (x * 1ll * y) % MOD; } int n; string s; int dp[N][K][K]; char buf[N]; int pow3[N]; int main() { scanf("%d", &n); scanf("%s", buf); s = buf; int cntQ = 0; for(auto c : s) if(c == '?') cntQ++; pow3[0] = 1; for(int i = 1; i < N; i++) pow3[i] = mul(pow3[i - 1], 3); dp[0][0][0] = 1; for(int i = 0; i < n; i++) for(int j = 0; j <= 3; j++) for(int k = 0; k <= 3; k++) { if(!dp[i][j][k]) continue; dp[i + 1][j][k] = add(dp[i + 1][j][k], dp[i][j][k]); if(j < 3 && (s[i] == '?' || s[i] - 'a' == j)) { int nk = (s[i] == '?' ? k + 1 : k); dp[i + 1][j + 1][nk] = add(dp[i + 1][j + 1][nk], dp[i][j][k]); } } int ans = 0; for(int i = 0; i <= 3; i++) if(cntQ >= i) ans = add(ans, mul(dp[n][3][i], pow3[cntQ - i])); printf("%d\n", ans); }
- 1
信息
- ID
- 6813
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 10
- 标签
- 递交数
- 1
- 已通过
- 1
- 上传者