1 条题解

  • 0
    @ 2025-10-27 16:07:11

    「广义线段树的区间分解与距离求和」题解

    题目分析

    本题围绕广义线段树展开,核心任务是处理两类操作:

    1. 构建广义线段树(根据给定的划分规则);
    2. 对每个查询,计算指定节点到“区间分解结果”中所有节点的距离之和。

    关键概念:

    • 广义线段树:非叶子节点可任意划分区间(只要满足l ≤ m < r),节点按先序遍历编号,总节点数为2n-1
    • 区间分解S[l,r]:用最少的线段树节点覆盖区间[l,r],节点间互不重叠。
    • 节点距离d(u,v):两节点在树上的最短路径边数,公式为d(u,v) = depth[u] + depth[v] - 2×depth[LCA(u,v)](LCA为最近公共祖先)。

    解法步骤

    步骤1:构建广义线段树

    需记录每个节点的关键信息:区间[L, R]、左孩子、右孩子、父节点。

    • 构建方式:用栈模拟先序遍历(避免递归栈溢出),按输入的n-1个划分位置依次拆分非叶子节点。
    • 细节
      • 根节点为1号,对应区间[1, n]
      • 每个非叶子节点按输入的划分位置m拆分为[L, m](左孩子)和[m+1, R](右孩子)。
      • 栈中先压右孩子再压左孩子(保证左孩子先处理,符合先序遍历)。

    步骤2:预处理深度和LCA

    距离计算依赖节点深度和LCA,需提前预处理:

    1. 节点深度:用BFS遍历树,根节点深度为0,孩子深度=父节点深度+1。
    2. LCA(最近公共祖先):用二进制 lifting(倍增法) 预处理:
      • 定义up[k][u]为节点u向上跳2^k步的祖先。
      • 初始化up[0][u] = 父节点[u],再通过up[k][u] = up[k-1][up[k-1][u]]递推更高层。
      • 查询LCA时,先将两节点拉至同一深度,再共同上跳至相遇。

    步骤3:区间分解(求S[l,r])

    对查询的[l,r],用栈迭代分解:

    • 栈中存储(当前节点, 待匹配区间[cur_l, cur_r])
    • 若当前节点区间完全包含于[cur_l, cur_r],则加入S[l,r]。
    • 若部分重叠,递归处理左右孩子(先压右孩子再压左孩子)。

    步骤4:计算距离和

    对S[l,r]中的每个节点v,用距离公式d(u,v) = depth[u] + depth[v] - 2×depth[LCA(u,v)]计算距离,累加后输出。

    代码实现

    import sys
    from collections import deque
    
    def main():
        sys.setrecursionlimit(1 << 25)
        n = int(sys.stdin.readline())
        m_list = list(map(int, sys.stdin.readline().split()))  # n-1个划分位置
        total_nodes = 2 * n - 1  # 总节点数
    
        # 初始化节点信息:L[u], R[u]为区间;left/right_child[u]为左右孩子;parent[u]为父节点
        L = [0] * (total_nodes + 2)  # 节点编号从1开始
        R = [0] * (total_nodes + 2)
        left_child = [0] * (total_nodes + 2)
        right_child = [0] * (total_nodes + 2)
        parent = [0] * (total_nodes + 2)
    
        # 用栈构建广义线段树(先序遍历)
        idx = 1  # 当前节点编号
        m_ptr = 0  # 划分位置的指针
        # 栈元素:(当前区间L, 当前区间R, 父节点, 是否为左孩子)
        stack = [(1, n, 0, False)]
        while stack:
            cur_L, cur_R, p, is_left = stack.pop()
            u = idx
            idx += 1
            L[u] = cur_L
            R[u] = cur_R
            # 记录父节点
            if p != 0:
                parent[u] = p
                if is_left:
                    left_child[p] = u
                else:
                    right_child[p] = u
            # 叶子节点无需拆分
            if cur_L == cur_R:
                continue
            # 非叶子节点,按m_list拆分
            m = m_list[m_ptr]
            m_ptr += 1
            # 先压右孩子(后处理),再压左孩子(先处理)
            stack.append((m + 1, cur_R, u, False))  # 右孩子
            stack.append((cur_L, m, u, True))       # 左孩子
    
        # 预处理节点深度(BFS)
        depth = [0] * (total_nodes + 2)
        q = deque()
        q.append(1)  # 根节点
        while q:
            u = q.popleft()
            if left_child[u]:
                depth[left_child[u]] = depth[u] + 1
                q.append(left_child[u])
            if right_child[u]:
                depth[right_child[u]] = depth[u] + 1
                q.append(right_child[u])
    
        # 预处理LCA的二进制lifting表
        max_level = 20  # 足够覆盖2^20 > 2e5
        up = [[0] * (total_nodes + 2) for _ in range(max_level)]
        for u in range(1, total_nodes + 1):
            up[0][u] = parent[u] if parent[u] != 0 else u  # 根节点的父节点设为自己
        for k in range(1, max_level):
            for u in range(1, total_nodes + 1):
                up[k][u] = up[k-1][up[k-1][u]]
    
        # 计算LCA
        def lca(u, v):
            if depth[u] < depth[v]:
                u, v = v, u
            # 拉平深度
            for k in range(max_level-1, -1, -1):
                if depth[u] - (1 << k) >= depth[v]:
                    u = up[k][u]
            if u == v:
                return u
            # 共同上跳
            for k in range(max_level-1, -1, -1):
                if up[k][u] != up[k][v]:
                    u = up[k][u]
                    v = up[k][v]
            return parent[u]
    
        # 区间分解:返回S[l,r]
        def decompose(l, r):
            res = []
            stack = [(1, l, r)]  # (当前节点, 待匹配l, 待匹配r)
            while stack:
                u, cur_l, cur_r = stack.pop()
                # 节点区间与待匹配区间无重叠
                if R[u] < cur_l or L[u] > cur_r:
                    continue
                # 节点区间完全包含于待匹配区间
                if cur_l <= L[u] and R[u] <= cur_r:
                    res.append(u)
                    continue
                # 部分重叠,处理左右孩子
                if right_child[u]:
                    stack.append((right_child[u], cur_l, cur_r))
                if left_child[u]:
                    stack.append((left_child[u], cur_l, cur_r))
            return res
    
        # 处理查询
        m = int(sys.stdin.readline())
        for _ in range(m):
            u, l, r = map(int, sys.stdin.readline().split())
            s = decompose(l, r)
            total = 0
            for v in s:
                ancestor = lca(u, v)
                total += depth[u] + depth[v] - 2 * depth[ancestor]
            print(total)
    
    if __name__ == "__main__":
        main()
    

    复杂度分析

    • 构建线段树:O(n),每个节点处理一次。
    • 预处理深度:O(n),BFS遍历所有节点。
    • 预处理LCA:O(n log H),H为树深(log H约20)。
    • 每个查询:分解区间的节点数为O(log n)(广义线段树性质),每个节点的LCA查询为O(log H),故总复杂度为O(m log n log H)。

    整体可应对n, m ≤ 2×10^5的规模。

    • 1

    信息

    ID
    3653
    时间
    2000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    1
    已通过
    1
    上传者