1 条题解
-
0
题解
结构化观察
- 左到右依次插入
∧ / ∨
运算,且初值为 0。对任意一个比特位上的数列x = (a_{1j}, a_{2j}, …, a_{nj})
,最终结果只取决于x
和运算符的选择。不同比特之间互不干扰,因此我们可以把每一列看成一个长度为n
的 0-1 向量。 - 把所有列按
(a_{1j}, a_{2j}, …, a_{nj})
的字典序从小到大排序(代码中通过逐行稳定划分实现)。排序后有一个重要性质:对任意一组运算符,输出的二进制串必然形如若干个 0 后跟若干个 1,即“阈值串”。换句话说,存在一个分界点,使得排序后编号不超过分界的列输出 0,之后的列输出 1。 - 决定这条分界线只与运算符的位置有关。把一列视作对应的二进制数
v_j = Σ a_{ij} · 2^{n-i}
,再把某组运算符视作二进制数w
(∨
为 1,∧
为 0,同样从上到下拼接),可以证明排序后“第一列为 1”恰好是w
的上界:若v_j < w
则第j
列输出 1,否则输出 0。于是每条阈值串都对应一个半开区间(v_L, v_R]
。 - 因此对于查询串
r
,若它不是阈值串(0 和 1 混杂顺序错误),答案为 0。否则设Mx
为需要输出 0 的列中排名最大的编号,Mn
为需要输出 1 的列中排名最小的编号,则所有满足条件的运算符正好对应整数集合(v_{Mx}, v_{Mn}]
,答案就是区间长度v_{Mn} - v_{Mx}
。
实现细节
- 排序:维护列的编号数组
ord
。对每一行执行一次稳定划分,把当前位置的 0 放前面、1 放后面,即可在O(nm)
时间内得到字典序排列。 - 列值编码:对原始列计算对应的二进制值
val[j]
(从下到上拼接),并在末尾补上一个全 1 列的取值val[m+1]=2^n
,方便表示“全部为 1”的情况。 - 查询:读取目标串
r
,通过rk[i]
(列i
在排序后的位置)找出Mx
与Mn
:- 若存在某个 1 的排名在某个 0 之前,即
rk[Mn] < rk[Mx]
,说明r
不是阈值串,答案为 0; - 否则输出
(val[Mn] - val[Mx]) mod 1e9+7
。
- 若存在某个 1 的排名在某个 0 之前,即
所有运算都按列进行,时间复杂度
O(nm + q m)
,在题目的1000 × 5000 × 5000
范围内可行。参考实现
#include <bits/stdc++.h> #define For(i, j, k) for(int i = (j); i <= (k); ++i) #define ForDown(i, j, k) for(int i = (j); i >= (k); --i) #define debug(fmt, args...) fprintf(stderr, fmt, ##args), fflush(stderr) using namespace std; typedef long long ll; const int N = 1005, M = 5005, mod = 1e9 + 7; int n, m, q, a[M][N], ord[M], rk[M]; ll val[M]; char s[M]; int L[M], R[M]; inline int read() { char ch = getchar(); int x = 0,f = 1; while (ch < '0' || ch > '9') {if (ch == '-') f = -1; ch = getchar();} while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar(); return x * f; } int main() { n = read(); m = read(); q = read(); For(i, 1, m) ord[i] = i; For(i, 1, n) { scanf("%s", s + 1); int c1 = 0, c2 = 0; For(j, 1, m) { if(s[ord[j]] == '0') L[++c1] = ord[j]; else R[++c2] = ord[j]; } For(j, 1, c1) ord[j] = L[j]; For(j, 1, c2) ord[j + c1] = R[j]; For(j, 1, m) a[j][i] = s[j] - '0'; } For(j, 1, m) ForDown(i, n, 1) val[j] = (val[j] << 1 | a[j][i]) % mod; For(i, 1, m) rk[ord[i]] = i; rk[m + 1] = ord[m + 1] = m + 1; For(i, 1, n) val[m + 1] = (val[m + 1] << 1 | 1) % mod; val[m + 1]++; For(i, 1, q) { scanf("%s", s + 1); int Mx = 0, Mn = m + 1; For(i, 1, m) if(s[i] == '0') Mx = (rk[Mx] < rk[i] ? i : Mx); else Mn = (rk[Mn] > rk[i] ? i : Mn); if(rk[Mn] < rk[Mx]) puts("0"); else printf("%lld\n", (val[Mn] - val[Mx] + mod) % mod); } return 0; }
- 左到右依次插入
- 1
信息
- ID
- 3390
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 7
- 标签
- 递交数
- 1
- 已通过
- 1
- 上传者