1 条题解
-
0
解题思路
题目要求计算数组 的逆序对数,其中 。
将 分成 个块,每个块长度为 。块内逆序对:每个块内元素为 ,比较时 为正常数,故大小关系完全由 决定。由于 是一个排列,块内逆序对数等于排列 的逆序对数 。总共有 个块,因此块内逆序对总数为 。
块间逆序对:考虑两个不同块 ,需要统计满足 的所有 对。由于 是 的排列,指数对 取遍所有有序对,因此块间逆序对总数与 的具体顺序无关。我们可以假设 ,此时每个块内元素为 且严格递增()。这样块内逆序对为 ,原问题的答案变为
[ \text{ans}=n\cdot\operatorname{inv}(q)+\text{块间逆序对数量}. ]现在问题转化为:给定序列 (奇数, 的排列),对每个 构造递增序列 ,求所有 时 与 之间形成的逆序对总数(即所有 且 的数对)。
两个块之间的逆序对计数
固定两个不同的值 ( 来自左边块, 来自右边块),设 。我们需要计算
[ F(x,y)=#{(a,b)\mid 0\le a,b\le m,\ x\cdot2^a > y\cdot2^b}. ]因为 是奇数, 和 的大小关系可以通过比较 与 以及 确定。
情况 1:
令 ,则 。
可以证明(通过归并排序的视角): [ F(x,y)=\begin{cases} \frac{(m-z)(m-z+1)}{2}, & 0\le z\le m,\ 0, & z>m. \end{cases} ] 当 时,,从而所有 ,逆序对数为 。
当 时,公式简化后为 。情况 2:
令 ,则 。
此时 [ F(x,y)=\begin{cases} (z+1)(m+1)+\sum_{i=z+1}^{m} i, & 0\le z\le m,\ (m+1)^2, & z>m. \end{cases} ] 当 时,求和为 ,因此 [ F(x,y)=(z+1)(m+1)+\frac{(m+z+1)(m-z)}{2}. ] 当 时,所有 恒成立,逆序对数为 。
快速计算所有块间逆序对
遍历 数组(按原顺序),维护一个值域树状数组(BIT),记录左侧已经出现过的值。对于当前值 ,需要累加所有左侧 与 的 。
由于 和 的取值范围是 (奇数),而 ,我们可以直接以数值为下标。
注意到 只需要枚举到 (约 ),因为更大的 对应的区间会超出值域或 值为常数。对于 :
枚举 ,区间 内的 均满足 。
对应的贡献为 (若 ,否则为 )。
区间端点取整: [ L_z = \left\lfloor\frac{y}{2^{z+1}}\right\rfloor+1,\qquad R_z = \left\lfloor\frac{y-1}{2^z}\right\rfloor. ] 若 且 ,则查询 BIT 中 内 的个数,乘以贡献累加。对于 :
枚举 ,区间 内的 满足 。
贡献 按上述公式计算。区间端点: [ L_z = 2^z y+1,\qquad R_z = 2^{z+1}y-1. ] 若 ,则查询 内的 个数,乘以 累加。对于每个 , 的枚举次数为 ,每次 BIT 查询 ,总复杂度 ,可以通过( 总和 )。
最后答案加上块内逆序对 并取模 。
实现细节
- 使用
long long计算中间结果,取模时注意正数。 - 树状数组大小设为 。
- 预处理 的幂(
1<<z)。 - 计算 可以用 BIT 或归并排序。
- 注意 可能为 ,公式仍然适用。
代码框架
#include <bits/stdc++.h> using namespace std; using ll = long long; const int MOD = 998244353; struct BIT { int n; vector<int> t; BIT(int _n) : n(_n), t(_n + 2, 0) {} void add(int i, int v) { for (; i <= n; i += i & -i) t[i] += v; } int sum(int i) { int s = 0; for (; i > 0; i -= i & -i) s += t[i]; return s; } int range(int l, int r) { if (l > r) return 0; return sum(r) - sum(l - 1); } }; ll inv_perm(vector<int>& q) { int k = q.size(); BIT bit(k); ll inv = 0; for (int i = k - 1; i >= 0; --i) { inv += bit.sum(q[i] + 1); bit.add(q[i] + 1, 1); } return inv % MOD; } void solve() { int n, k; cin >> n >> k; vector<int> p(n); for (int i = 0; i < n; ++i) cin >> p[i]; vector<int> q(k); for (int i = 0; i < k; ++i) cin >> q[i]; // 块内逆序对 ll inv_q = inv_perm(q); ll ans = (ll)n * inv_q % MOD; int m = k - 1; int maxVal = 2 * n - 1; BIT bit(maxVal); // 预处理 C(z) 和 D(z) 对于 z=0..20 (足够) const int L = 20; vector<ll> C(L + 2, 0), D(L + 2, 0); for (int z = 0; z <= L; ++z) { if (z <= m) { ll t = m - z; C[z] = t * (t + 1) / 2 % MOD; D[z] = ((ll)(z + 1) * (m + 1) + (ll)(m + z + 1) * (m - z) / 2) % MOD; } else { C[z] = 0; D[z] = (ll)(m + 1) * (m + 1) % MOD; } } // 遍历 p 数组 for (int i = 0; i < n; ++i) { int y = p[i]; // x < y for (int z = 0; ; ++z) { ll divisor = 1LL << (z + 1); ll Lz = y / divisor + 1; ll Rz = (y - 1) / (1LL << z); if (Lz > Rz || Rz < 1) break; if (Lz > maxVal) break; if (Rz > maxVal) Rz = maxVal; int cnt = bit.range(Lz, Rz); if (cnt) { ans = (ans + (ll)cnt * C[z]) % MOD; } } // x > y for (int z = 0; ; ++z) { ll Lz = (1LL << z) * y + 1; if (Lz > maxVal) break; ll Rz = (1LL << (z + 1)) * y - 1; if (Rz > maxVal) Rz = maxVal; int cnt = bit.range(Lz, Rz); if (cnt) { ans = (ans + (ll)cnt * D[z]) % MOD; } } bit.add(y, 1); } cout << ans << '\n'; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int t; cin >> t; while (t--) solve(); return 0; } - 使用
- 1
信息
- ID
- 7164
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 2
- 标签
- 递交数
- 1
- 已通过
- 1
- 上传者