1 条题解

  • 0
    @ 2025-11-4 9:11:25

    题解

    问题分析

    题目要求处理树上两种异或最大值查询:

    1. 子树查询:节点 ( x ) 的子树中所有节点的权值与 ( y ) 的异或最大值。
    2. 路径查询:节点 ( x ) 到 ( y ) 的路径上所有节点的权值与 ( z ) 的异或最大值。

    核心挑战是高效处理大规模(( n, Q \leq 10^5 ))树上范围的异或最大值查询,需结合字典树(Trie)、可持久化技术和树结构算法。

    关键思路

    1. 异或最大值的经典解法
      利用字典树(Trie)存储数字的二进制位(从高到低),查询时从最高位开始,优先选择与目标数当前位不同的路径,最大化异或结果。

    2. 子树查询的转化

      • 通过 DFS 序 将子树转化为区间:记录每个节点的进入时间(( in_time ))和离开时间(( out_time )),子树中所有节点对应区间 ([in_time[x], out_time[x]])。
      • 可持久化 Trie 维护区间信息:按 DFS 序插入节点权值,每个版本对应前缀区间,查询时通过区间两端的版本差获取子树内的权值集合,再计算异或最大值。
    3. 路径查询的转化

      • 利用 LCA(最近公共祖先) 分解路径:将 ( x \to y ) 的路径拆分为 ( x \to \text{LCA} ) 和 ( y \to \text{LCA} )(去重)。
      • 可持久化 Trie 维护根到节点的路径:每个节点的 Trie 版本包含从根到该节点的所有权值,查询时通过根到 ( x )、根到 ( y )、根到 (\text{LCA}) 及根到 (\text{LCA}) 父节点的 Trie 版本差,获取路径上的权值集合,再计算异或最大值。

    代码实现

    import sys
    from sys import stdin
    sys.setrecursionlimit(1 << 25)
    
    MOD = 10**9 + 7
    MAX_BIT = 30  # 权值 <= 2^30,需31位(0~30)
    
    # 字典树节点:[左孩子(0), 右孩子(1), 计数]
    nodes = []
    def init_trie():
        global nodes
        nodes = [[-1, -1, 0]]  # 初始根节点
    
    def insert(old_root, x):
        """插入x到旧根为old_root的Trie,返回新根"""
        new_root = len(nodes)
        nodes.append(nodes[old_root].copy())  # 复制旧根
        nodes[new_root][2] += 1  # 计数+1
        curr = new_root
        for i in range(MAX_BIT, -1, -1):
            bit = (x >> i) & 1
            # 当前节点的bit孩子
            child = nodes[curr][bit]
            if child == -1:
                # 新建孩子节点
                new_child = len(nodes)
                nodes.append([-1, -1, 0])
                nodes[curr][bit] = new_child
                curr = new_child
            else:
                # 复制孩子节点并更新计数
                new_child = len(nodes)
                nodes.append(nodes[child].copy())
                nodes[new_child][2] += 1
                nodes[curr][bit] = new_child
                curr = new_child
        return new_root
    
    def query_sub(root_out, root_in, y):
        """查询区间[in, out]内与y的异或最大值(root_out是out版本,root_in是in-1版本)"""
        max_xor = 0
        curr_out = root_out
        curr_in = root_in
        for i in range(MAX_BIT, -1, -1):
            bit = (y >> i) & 1
            desired = 1 - bit  # 优先选择的位
            # 计算desired方向的计数差
            cnt_out = nodes[nodes[curr_out][desired]][2] if nodes[curr_out][desired] != -1 else 0
            cnt_in = nodes[nodes[curr_in][desired]][2] if nodes[curr_in][desired] != -1 else 0
            if cnt_out - cnt_in > 0:
                max_xor |= (1 << i)
                curr_out = nodes[curr_out][desired]
                curr_in = nodes[curr_in][desired] if nodes[curr_in][desired] != -1 else -1
            else:
                # 选择原bit方向
                cnt_out = nodes[nodes[curr_out][bit]][2] if nodes[curr_out][bit] != -1 else 0
                cnt_in = nodes[nodes[curr_in][bit]][2] if nodes[curr_in][bit] != -1 else 0
                if cnt_out - cnt_in > 0:
                    curr_out = nodes[curr_out][bit]
                    curr_in = nodes[curr_in][bit] if nodes[curr_in][bit] != -1 else -1
                else:
                    return 0  # 无数据(理论上不会发生)
        return max_xor
    
    def query_path(root_x, root_y, root_lca, root_plca, z):
        """查询x到y路径上与z的异或最大值(root_plca是LCA父节点的版本)"""
        max_xor = 0
        curr_x = root_x
        curr_y = root_y
        curr_lca = root_lca
        curr_plca = root_plca if root_plca != -1 else 0  # 根的父节点用0(空)
        for i in range(MAX_BIT, -1, -1):
            bit = (z >> i) & 1
            desired = 1 - bit
            # 计算desired方向的计数差:x + y - lca - plca
            cnt_x = nodes[nodes[curr_x][desired]][2] if nodes[curr_x][desired] != -1 else 0
            cnt_y = nodes[nodes[curr_y][desired]][2] if nodes[curr_y][desired] != -1 else 0
            cnt_lca = nodes[nodes[curr_lca][desired]][2] if nodes[curr_lca][desired] != -1 else 0
            cnt_plca = nodes[nodes[curr_plca][desired]][2] if nodes[curr_plca][desired] != -1 else 0
            total = cnt_x + cnt_y - cnt_lca - cnt_plca
            if total > 0:
                max_xor |= (1 << i)
                curr_x = nodes[curr_x][desired] if nodes[curr_x][desired] != -1 else -1
                curr_y = nodes[curr_y][desired] if nodes[curr_y][desired] != -1 else -1
                curr_lca = nodes[curr_lca][desired] if nodes[curr_lca][desired] != -1 else -1
                curr_plca = nodes[curr_plca][desired] if nodes[curr_plca][desired] != -1 else -1
            else:
                # 选择原bit方向
                cnt_x = nodes[nodes[curr_x][bit]][2] if nodes[curr_x][bit] != -1 else 0
                cnt_y = nodes[nodes[curr_y][bit]][2] if nodes[curr_y][bit] != -1 else 0
                cnt_lca = nodes[nodes[curr_lca][bit]][2] if nodes[curr_lca][bit] != -1 else 0
                cnt_plca = nodes[nodes[curr_plca][bit]][2] if nodes[curr_plca][bit] != -1 else 0
                total = cnt_x + cnt_y - cnt_lca - cnt_plca
                if total > 0:
                    curr_x = nodes[curr_x][bit] if nodes[curr_x][bit] != -1 else -1
                    curr_y = nodes[curr_y][bit] if nodes[curr_y][bit] != -1 else -1
                    curr_lca = nodes[curr_lca][bit] if nodes[curr_lca][bit] != -1 else -1
                    curr_plca = nodes[curr_plca][bit] if nodes[curr_plca][bit] != -1 else -1
                else:
                    return 0  # 无数据(理论上不会发生)
        return max_xor
    
    # LCA预处理
    def preprocess_lca(n, adj, root):
        LOG = 20
        up = [[-1] * (n + 1) for _ in range(LOG)]
        depth = [0] * (n + 1)
        # 迭代DFS求up[0]和depth
        stack = [(root, -1)]
        while stack:
            u, parent = stack.pop()
            up[0][u] = parent
            if parent != -1:
                depth[u] = depth[parent] + 1
            for v in adj[u]:
                if v != parent:
                    stack.append((v, u))
        # 填充up表
        for k in range(1, LOG):
            for u in range(1, n + 1):
                if up[k-1][u] != -1:
                    up[k][u] = up[k-1][up[k-1][u]]
        return up, depth
    
    def get_lca(u, v, up, depth):
        if depth[u] < depth[v]:
            u, v = v, u
        # 提u到与v同深度
        LOG = len(up)
        for k in range(LOG-1, -1, -1):
            if depth[u] - (1 << k) >= depth[v]:
                u = up[k][u]
        if u == v:
            return u
        # 同时上移找LCA
        for k in range(LOG-1, -1, -1):
            if up[k][u] != -1 and up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return up[0][u]
    
    def main():
        input = sys.stdin.read().split()
        ptr = 0
        n, Q = int(input[ptr]), int(input[ptr+1])
        ptr +=2
        v = list(map(int, input[ptr:ptr+n]))
        ptr +=n
        v = [0] + v  # 1-based
        # 建树
        adj = [[] for _ in range(n+1)]
        for _ in range(n-1):
            x = int(input[ptr])
            y = int(input[ptr+1])
            adj[x].append(y)
            adj[y].append(x)
            ptr +=2
    
        # 预处理LCA
        up, depth = preprocess_lca(n, adj, 1)
    
        # 1. 处理子树查询:DFS序 + 可持久化Trie
        in_time = [0]*(n+1)
        out_time = [0]*(n+1)
        time = 1
        # 迭代DFS获取in和out时间
        stack = [(1, -1, False)]
        while stack:
            u, parent, visited = stack.pop()
            if visited:
                out_time[u] = time -1
                continue
            in_time[u] = time
            time +=1
            stack.append((u, parent, True))
            # 孩子逆序入栈,保证DFS序正确
            for v_node in reversed(adj[u]):
                if v_node != parent:
                    stack.append((v_node, u, False))
    
        # 构建sub_trie:按in_time顺序插入
        init_trie()
        sub_versions = [0]*(n+1)  # sub_versions[i]是前i个in_time的Trie根
        for i in range(1, n+1):
            # 找到in_time为i的节点u
            u = -1
            for candidate in range(1, n+1):
                if in_time[candidate] == i:
                    u = candidate
                    break
            sub_versions[i] = insert(sub_versions[i-1], v[u])
    
        # 2. 处理路径查询:根到节点的可持久化Trie
        init_trie()  # 重置Trie
        path_versions = [0]*(n+1)  # path_versions[u]是根到u的Trie根
        # BFS构建path_versions(保证父节点先处理)
        from collections import deque
        q = deque()
        q.append(1)
        path_versions[1] = insert(0, v[1])  # 根节点的父版本是0(空)
        parent = [-1]*(n+1)
        parent[1] = -1
        while q:
            u = q.popleft()
            for v_node in adj[u]:
                if v_node != parent[u] and parent[v_node] == -1:
                    parent[v_node] = u
                    path_versions[v_node] = insert(path_versions[u], v[v_node])
                    q.append(v_node)
    
        # 处理查询
        for _ in range(Q):
            parts = input[ptr:ptr+4]
            if parts[0] == '1':
                # 子树查询:1 x y
                x = int(parts[1])
                y = int(parts[2])
                ptr +=3
                l = in_time[x]
                r = out_time[x]
                root_out = sub_versions[r]
                root_in = sub_versions[l-1]
                print(query_sub(root_out, root_in, y))
            else:
                # 路径查询:2 x y z
                x = int(parts[1])
                y = int(parts[2])
                z = int(parts[3])
                ptr +=4
                l = get_lca(x, y, up, depth)
                pl = up[0][l] if l != 1 else -1  # LCA的父节点
                root_x = path_versions[x]
                root_y = path_versions[y]
                root_l = path_versions[l]
                root_pl = path_versions[pl] if pl != -1 else -1
                print(query_path(root_x, root_y, root_l, root_pl, z))
    
    if __name__ == "__main__":
        main()
    

    复杂度分析

    • 预处理

      • LCA 倍增预处理:( O(n \log n) )。
      • DFS 序与子树 Trie 构建:( O(n \cdot 31) )(31 为二进制位数)。
      • 路径 Trie 构建:( O(n \cdot 31) )。
    • 查询

      • 子树查询:( O(31) )(Trie 遍历)。
      • 路径查询:( O(31 + \log n) )(LCA 查询 + Trie 遍历)。

    总复杂度为 ( O(n \log n + Q \cdot 31) ),可满足 ( n, Q \leq 10^5 ) 的数据范围。

    • 1

    信息

    ID
    4925
    时间
    3000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    1
    已通过
    1
    上传者