1 条题解

  • 0
    @ 2026-5-18 22:58:08

    一、题目重述

    给定长度为 nn 的数组 aa,常数 xx0x<2200 \le x < 2^{20}),以及 mm 个操作:

    1. 单点修改:ai:=ya_i := y
    2. 区间查询:求 [l,r][l, r] 内有多少个子数组 [L,R][L, R] 满足子数组的按位或结果 x\ge x

    需要在线回答所有查询。


    二、核心难点

    • 直接枚举子数组 O(n2)O(n^2) 不可行
    • 按位或运算不可逆,不能直接用前缀和或差分
    • 需要支持单点修改,不能离线

    三、关键性质

    性质 1:对于固定左端点 LL,当 RR 向右移动时,OR(L,R)OR(L, R) 的值单调不降(二进制位只会从 0 变 1,不会变回 0)。

    性质 2:由于 ai<220a_i < 2^{20},每个数的二进制位最多 2020 位。
    因此,对于任意起始位置,OROR 值的变化次数最多 2020 次(每次至少新增加一个 1 位)。

    性质 3:这个性质可以推广到任意区间:
    一个区间内,所有前缀 OR 的不同值个数 ≤ 20所有后缀 OR 的不同值个数 ≤ 20


    四、算法思路:线段树 + 分治

    我们使用线段树维护每个区间 [L,R][L, R] 的信息:

    • left:该区间所有前缀 OR 值及其出现次数(从 LL 开始向右的 OR 值变化段)
    • right:该区间所有后缀 OR 值及其出现次数(从 RR 开始向左的 OR 值变化段)
    • cnt:该区间内所有子数组中 OR 值 x\ge x 的个数

    4.1 叶子节点(单个元素 vv

    • left = {(v, 1)}
    • right = {(v, 1)}
    • cnt = 1 如果 vxv \ge x,否则 00

    4.2 合并两个子区间

    设左子区间为 LL,右子区间为 RR,合并后区间为 MM

    1. 跨越中点的子数组统计

    跨越中点的子数组 = 左区间的某个后缀 + 右区间的某个前缀
    其 OR 值 = 后缀 OR \mid 前缀 OR

    我们枚举左后缀的每个 OR 值 orlor_l(出现次数 clc_l),以及右前缀的每个 OR 值 orror_r(出现次数 crc_r),若 (orlorr)x(or_l \mid or_r) \ge x,则贡献 cl×crc_l \times c_r

    直接枚举是 O(B2)O(B^2)B=20B = 20,可接受。

    优化:固定 orlor_l,右前缀的 OR 值 orror_r 单调不降(因为前缀越长 OR 越大)。
    我们可以对右前缀预处理后缀和,然后用双指针快速统计,复杂度 O(B2)O(B^2)

    2. 合并 left 信息

    MM 的前缀 OR = LL 的前缀 OR,再拼接上 RR 的前缀 OR(但需要整体 OR 上 LL 整个区间的 OR 值)。

    更具体地:
    LL 的 left 列表为 (or1,cnt1),(or2,cnt2),(or_1, cnt_1), (or_2, cnt_2), \dots
    RR 的 left 列表为 (or1,cnt1),(or'_1, cnt'_1), \dots

    MM 的 left 列表 = (or1,cnt1),(or2,cnt2),(or_1, cnt_1), (or_2, cnt_2), \dots,然后依次将 RR 的每个 (orj,cntj)(or'_j, cnt'_j)LL总 OR 值做 OR 后合并到末尾(若相同则合并计数,否则新增)。

    3. 合并 right 信息

    对称处理,MM 的后缀 OR = RR 的后缀 OR,再拼接上 LL 的后缀 OR(整体 OR 上 RR 的总 OR 值)。

    4. 合并 cnt

    cntM=cntL+cntR+跨中点的贡献cnt_M = cnt_L + cnt_R + \text{跨中点的贡献}

    五、数据结构实现

    由于 B=20B = 20 很小,我们可以用固定大小的数组存储 leftright,避免动态内存分配带来的常数开销。

    每个 left / right 最多 B+1B+1 个元素,因此用 array<pair<int,int>, B+2> 存储,并用 left_sz 记录实际长度。

    这样 merge 操作中完全无动态分配,常数极小。


    六、复杂度分析

    • 建树:每个节点合并 O(B2)O(B^2),共 O(n)O(n) 个节点,总 O(nB2)O(n \cdot B^2)
    • 单点修改:沿路径更新 O(logn)O(\log n) 个节点,每个节点合并 O(B2)O(B^2),总 O(lognB2)O(\log n \cdot B^2)
    • 区间查询:将查询区间分解为 O(logn)O(\log n) 个节点,依次合并,总 O(lognB2)O(\log n \cdot B^2)

    B=20    B2=400B = 20 \implies B^2 = 400logn17\log n \approx 17
    总操作次数 ≈ $400 \times 17 \times (n + m) \approx 6.8 \times 10^6 \times 17$?实际更小,因为不是所有节点都合并 B2B^2

    实测能通过 n,m105n, m \le 10^5 的极限数据。


    七、最终代码(标程)

    #include <bits/stdc++.h>
    using namespace std;
    
    const int N = 1e5 + 5;
    const int B = 20;
    
    int n, m, x;
    int a[N];
    
    struct Node {
        array<pair<int, int>, B + 2> left, right;
        int left_sz, right_sz;
        long long cnt;
    } tr[N << 2];
    
    Node merge(const Node& L, const Node& R) {
        Node res;
        res.cnt = L.cnt + R.cnt;
    
        const auto& lvals = L.right;
        const auto& rvals = R.left;
        int lsz = L.right_sz, rsz = R.left_sz;
    
        // 后缀和优化
        long long suffix[B + 2] = {0};
        for (int i = rsz - 1; i >= 0; --i) {
            suffix[i] = suffix[i + 1] + rvals[i].second;
        }
    
        int j = 0;
        for (int i = 0; i < lsz; ++i) {
            int or_l = lvals[i].first;
            int cnt_l = lvals[i].second;
            while (j < rsz && (or_l | rvals[j].first) < x) ++j;
            if (j < rsz) {
                res.cnt += 1LL * cnt_l * suffix[j];
            }
        }
    
        // 合并 left
        res.left_sz = 0;
        for (int i = 0; i < L.left_sz; ++i) {
            res.left[res.left_sz++] = L.left[i];
        }
        int last = res.left[res.left_sz - 1].first;
        for (int i = 0; i < R.left_sz; ++i) {
            int cur = last | R.left[i].first;
            if (cur == res.left[res.left_sz - 1].first) {
                res.left[res.left_sz - 1].second += R.left[i].second;
            } else {
                res.left[res.left_sz++] = {cur, R.left[i].second};
            }
            last = cur;
        }
    
        // 合并 right
        res.right_sz = 0;
        for (int i = 0; i < R.right_sz; ++i) {
            res.right[res.right_sz++] = R.right[i];
        }
        last = res.right[res.right_sz - 1].first;
        for (int i = 0; i < L.right_sz; ++i) {
            int cur = last | L.right[i].first;
            if (cur == res.right[res.right_sz - 1].first) {
                res.right[res.right_sz - 1].second += L.right[i].second;
            } else {
                res.right[res.right_sz++] = {cur, L.right[i].second};
            }
            last = cur;
        }
    
        return res;
    }
    
    void build(int u, int l, int r) {
        if (l == r) {
            tr[u].left[0] = {a[l], 1};
            tr[u].right[0] = {a[l], 1};
            tr[u].left_sz = tr[u].right_sz = 1;
            tr[u].cnt = (a[l] >= x ? 1 : 0);
            return;
        }
        int mid = (l + r) >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        tr[u] = merge(tr[u << 1], tr[u << 1 | 1]);
    }
    
    void update(int u, int l, int r, int pos, int val) {
        if (l == r) {
            a[pos] = val;
            tr[u].left[0] = {val, 1};
            tr[u].right[0] = {val, 1};
            tr[u].left_sz = tr[u].right_sz = 1;
            tr[u].cnt = (val >= x ? 1 : 0);
            return;
        }
        int mid = (l + r) >> 1;
        if (pos <= mid) update(u << 1, l, mid, pos, val);
        else update(u << 1 | 1, mid + 1, r, pos, val);
        tr[u] = merge(tr[u << 1], tr[u << 1 | 1]);
    }
    
    Node query(int u, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) return tr[u];
        int mid = (l + r) >> 1;
        if (qr <= mid) return query(u << 1, l, mid, ql, qr);
        if (ql > mid) return query(u << 1 | 1, mid + 1, r, ql, qr);
        Node left = query(u << 1, l, mid, ql, qr);
        Node right = query(u << 1 | 1, mid + 1, r, ql, qr);
        return merge(left, right);
    }
    
    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
    
        cin >> n >> m >> x;
        for (int i = 1; i <= n; ++i) cin >> a[i];
    
        build(1, 1, n);
    
        while (m--) {
            int op;
            cin >> op;
            if (op == 1) {
                int i, y;
                cin >> i >> y;
                update(1, 1, n, i, y);
            } else {
                int l, r;
                cin >> l >> r;
                Node res = query(1, 1, n, l, r);
                cout << res.cnt << '\n';
            }
        }
    
        return 0;
    }
    

    • 1

    信息

    ID
    7243
    时间
    1000ms
    内存
    256MiB
    难度
    7
    标签
    递交数
    2
    已通过
    1
    上传者