1 条题解
-
0
题解
问题分析
题目要求处理树上两种异或最大值查询:
- 子树查询:节点 ( x ) 的子树中所有节点的权值与 ( y ) 的异或最大值。
- 路径查询:节点 ( x ) 到 ( y ) 的路径上所有节点的权值与 ( z ) 的异或最大值。
核心挑战是高效处理大规模(( n, Q \leq 10^5 ))树上范围的异或最大值查询,需结合字典树(Trie)、可持久化技术和树结构算法。
关键思路
-
异或最大值的经典解法:
利用字典树(Trie)存储数字的二进制位(从高到低),查询时从最高位开始,优先选择与目标数当前位不同的路径,最大化异或结果。 -
子树查询的转化:
- 通过 DFS 序 将子树转化为区间:记录每个节点的进入时间(( in_time ))和离开时间(( out_time )),子树中所有节点对应区间 ([in_time[x], out_time[x]])。
- 用 可持久化 Trie 维护区间信息:按 DFS 序插入节点权值,每个版本对应前缀区间,查询时通过区间两端的版本差获取子树内的权值集合,再计算异或最大值。
-
路径查询的转化:
- 利用 LCA(最近公共祖先) 分解路径:将 ( x \to y ) 的路径拆分为 ( x \to \text{LCA} ) 和 ( y \to \text{LCA} )(去重)。
- 用 可持久化 Trie 维护根到节点的路径:每个节点的 Trie 版本包含从根到该节点的所有权值,查询时通过根到 ( x )、根到 ( y )、根到 (\text{LCA}) 及根到 (\text{LCA}) 父节点的 Trie 版本差,获取路径上的权值集合,再计算异或最大值。
代码实现
import sys from sys import stdin sys.setrecursionlimit(1 << 25) MOD = 10**9 + 7 MAX_BIT = 30 # 权值 <= 2^30,需31位(0~30) # 字典树节点:[左孩子(0), 右孩子(1), 计数] nodes = [] def init_trie(): global nodes nodes = [[-1, -1, 0]] # 初始根节点 def insert(old_root, x): """插入x到旧根为old_root的Trie,返回新根""" new_root = len(nodes) nodes.append(nodes[old_root].copy()) # 复制旧根 nodes[new_root][2] += 1 # 计数+1 curr = new_root for i in range(MAX_BIT, -1, -1): bit = (x >> i) & 1 # 当前节点的bit孩子 child = nodes[curr][bit] if child == -1: # 新建孩子节点 new_child = len(nodes) nodes.append([-1, -1, 0]) nodes[curr][bit] = new_child curr = new_child else: # 复制孩子节点并更新计数 new_child = len(nodes) nodes.append(nodes[child].copy()) nodes[new_child][2] += 1 nodes[curr][bit] = new_child curr = new_child return new_root def query_sub(root_out, root_in, y): """查询区间[in, out]内与y的异或最大值(root_out是out版本,root_in是in-1版本)""" max_xor = 0 curr_out = root_out curr_in = root_in for i in range(MAX_BIT, -1, -1): bit = (y >> i) & 1 desired = 1 - bit # 优先选择的位 # 计算desired方向的计数差 cnt_out = nodes[nodes[curr_out][desired]][2] if nodes[curr_out][desired] != -1 else 0 cnt_in = nodes[nodes[curr_in][desired]][2] if nodes[curr_in][desired] != -1 else 0 if cnt_out - cnt_in > 0: max_xor |= (1 << i) curr_out = nodes[curr_out][desired] curr_in = nodes[curr_in][desired] if nodes[curr_in][desired] != -1 else -1 else: # 选择原bit方向 cnt_out = nodes[nodes[curr_out][bit]][2] if nodes[curr_out][bit] != -1 else 0 cnt_in = nodes[nodes[curr_in][bit]][2] if nodes[curr_in][bit] != -1 else 0 if cnt_out - cnt_in > 0: curr_out = nodes[curr_out][bit] curr_in = nodes[curr_in][bit] if nodes[curr_in][bit] != -1 else -1 else: return 0 # 无数据(理论上不会发生) return max_xor def query_path(root_x, root_y, root_lca, root_plca, z): """查询x到y路径上与z的异或最大值(root_plca是LCA父节点的版本)""" max_xor = 0 curr_x = root_x curr_y = root_y curr_lca = root_lca curr_plca = root_plca if root_plca != -1 else 0 # 根的父节点用0(空) for i in range(MAX_BIT, -1, -1): bit = (z >> i) & 1 desired = 1 - bit # 计算desired方向的计数差:x + y - lca - plca cnt_x = nodes[nodes[curr_x][desired]][2] if nodes[curr_x][desired] != -1 else 0 cnt_y = nodes[nodes[curr_y][desired]][2] if nodes[curr_y][desired] != -1 else 0 cnt_lca = nodes[nodes[curr_lca][desired]][2] if nodes[curr_lca][desired] != -1 else 0 cnt_plca = nodes[nodes[curr_plca][desired]][2] if nodes[curr_plca][desired] != -1 else 0 total = cnt_x + cnt_y - cnt_lca - cnt_plca if total > 0: max_xor |= (1 << i) curr_x = nodes[curr_x][desired] if nodes[curr_x][desired] != -1 else -1 curr_y = nodes[curr_y][desired] if nodes[curr_y][desired] != -1 else -1 curr_lca = nodes[curr_lca][desired] if nodes[curr_lca][desired] != -1 else -1 curr_plca = nodes[curr_plca][desired] if nodes[curr_plca][desired] != -1 else -1 else: # 选择原bit方向 cnt_x = nodes[nodes[curr_x][bit]][2] if nodes[curr_x][bit] != -1 else 0 cnt_y = nodes[nodes[curr_y][bit]][2] if nodes[curr_y][bit] != -1 else 0 cnt_lca = nodes[nodes[curr_lca][bit]][2] if nodes[curr_lca][bit] != -1 else 0 cnt_plca = nodes[nodes[curr_plca][bit]][2] if nodes[curr_plca][bit] != -1 else 0 total = cnt_x + cnt_y - cnt_lca - cnt_plca if total > 0: curr_x = nodes[curr_x][bit] if nodes[curr_x][bit] != -1 else -1 curr_y = nodes[curr_y][bit] if nodes[curr_y][bit] != -1 else -1 curr_lca = nodes[curr_lca][bit] if nodes[curr_lca][bit] != -1 else -1 curr_plca = nodes[curr_plca][bit] if nodes[curr_plca][bit] != -1 else -1 else: return 0 # 无数据(理论上不会发生) return max_xor # LCA预处理 def preprocess_lca(n, adj, root): LOG = 20 up = [[-1] * (n + 1) for _ in range(LOG)] depth = [0] * (n + 1) # 迭代DFS求up[0]和depth stack = [(root, -1)] while stack: u, parent = stack.pop() up[0][u] = parent if parent != -1: depth[u] = depth[parent] + 1 for v in adj[u]: if v != parent: stack.append((v, u)) # 填充up表 for k in range(1, LOG): for u in range(1, n + 1): if up[k-1][u] != -1: up[k][u] = up[k-1][up[k-1][u]] return up, depth def get_lca(u, v, up, depth): if depth[u] < depth[v]: u, v = v, u # 提u到与v同深度 LOG = len(up) for k in range(LOG-1, -1, -1): if depth[u] - (1 << k) >= depth[v]: u = up[k][u] if u == v: return u # 同时上移找LCA for k in range(LOG-1, -1, -1): if up[k][u] != -1 and up[k][u] != up[k][v]: u = up[k][u] v = up[k][v] return up[0][u] def main(): input = sys.stdin.read().split() ptr = 0 n, Q = int(input[ptr]), int(input[ptr+1]) ptr +=2 v = list(map(int, input[ptr:ptr+n])) ptr +=n v = [0] + v # 1-based # 建树 adj = [[] for _ in range(n+1)] for _ in range(n-1): x = int(input[ptr]) y = int(input[ptr+1]) adj[x].append(y) adj[y].append(x) ptr +=2 # 预处理LCA up, depth = preprocess_lca(n, adj, 1) # 1. 处理子树查询:DFS序 + 可持久化Trie in_time = [0]*(n+1) out_time = [0]*(n+1) time = 1 # 迭代DFS获取in和out时间 stack = [(1, -1, False)] while stack: u, parent, visited = stack.pop() if visited: out_time[u] = time -1 continue in_time[u] = time time +=1 stack.append((u, parent, True)) # 孩子逆序入栈,保证DFS序正确 for v_node in reversed(adj[u]): if v_node != parent: stack.append((v_node, u, False)) # 构建sub_trie:按in_time顺序插入 init_trie() sub_versions = [0]*(n+1) # sub_versions[i]是前i个in_time的Trie根 for i in range(1, n+1): # 找到in_time为i的节点u u = -1 for candidate in range(1, n+1): if in_time[candidate] == i: u = candidate break sub_versions[i] = insert(sub_versions[i-1], v[u]) # 2. 处理路径查询:根到节点的可持久化Trie init_trie() # 重置Trie path_versions = [0]*(n+1) # path_versions[u]是根到u的Trie根 # BFS构建path_versions(保证父节点先处理) from collections import deque q = deque() q.append(1) path_versions[1] = insert(0, v[1]) # 根节点的父版本是0(空) parent = [-1]*(n+1) parent[1] = -1 while q: u = q.popleft() for v_node in adj[u]: if v_node != parent[u] and parent[v_node] == -1: parent[v_node] = u path_versions[v_node] = insert(path_versions[u], v[v_node]) q.append(v_node) # 处理查询 for _ in range(Q): parts = input[ptr:ptr+4] if parts[0] == '1': # 子树查询:1 x y x = int(parts[1]) y = int(parts[2]) ptr +=3 l = in_time[x] r = out_time[x] root_out = sub_versions[r] root_in = sub_versions[l-1] print(query_sub(root_out, root_in, y)) else: # 路径查询:2 x y z x = int(parts[1]) y = int(parts[2]) z = int(parts[3]) ptr +=4 l = get_lca(x, y, up, depth) pl = up[0][l] if l != 1 else -1 # LCA的父节点 root_x = path_versions[x] root_y = path_versions[y] root_l = path_versions[l] root_pl = path_versions[pl] if pl != -1 else -1 print(query_path(root_x, root_y, root_l, root_pl, z)) if __name__ == "__main__": main()复杂度分析
-
预处理:
- LCA 倍增预处理:( O(n \log n) )。
- DFS 序与子树 Trie 构建:( O(n \cdot 31) )(31 为二进制位数)。
- 路径 Trie 构建:( O(n \cdot 31) )。
-
查询:
- 子树查询:( O(31) )(Trie 遍历)。
- 路径查询:( O(31 + \log n) )(LCA 查询 + Trie 遍历)。
总复杂度为 ( O(n \log n + Q \cdot 31) ),可满足 ( n, Q \leq 10^5 ) 的数据范围。
- 1
信息
- ID
- 4925
- 时间
- 3000ms
- 内存
- 256MiB
- 难度
- 10
- 标签
- 递交数
- 1
- 已通过
- 1
- 上传者