1 条题解

  • 0
    @ 2025-10-19 19:31:01

    题解

    本题使用点分治 + 虚树 + 动态规划求解三棵树上的最长路径问题。

    算法思路:

    1. 问题建模

      • 给定三棵树 T1,T2,T3T_1, T_2, T_3,节点编号相同
      • 选择两个点 u,vu, v,使得 dis1(u,v)+dis2(u,v)+dis3(u,v)dis_1(u,v) + dis_2(u,v) + dis_3(u,v) 最大
      • 其中 disi(u,v)dis_i(u,v) 表示在树 TiT_iuuvv 的距离
    2. 点分治框架

      • 在树 T1T_1 上进行点分治
      • 找到重心 rtrt,将树分成若干子树
      • 对于每个分治中心,计算经过该中心的最优路径
    3. 虚树优化

      • 对于需要计算的点集,在树 T2T_2 上建立虚树
      • 虚树只包含关键节点和它们的 LCA
      • 使用 ST 表预处理 LCA,O(logn)O(\log n) 查询
    4. 动态规划

      • 状态 f[u][i]:表示子树 uu 中,与当前分治中心同在 T2T_2 的哪一侧(i{0,1}i \in \{0,1\})的最优点对
      • 转移:合并子树时,考虑不同侧的点对组合
      • 评价函数:$F(u,v) = d_1[u] + d_1[v] + dep_2[u] + dep_2[v] + dis_3(u,v)$
        • d1[u]d_1[u]uu 到分治中心在 T1T_1 上的距离
        • dep2[u]dep_2[u]uuT2T_2 上的深度
    5. 启发式合并

      • 使用优先队列按子树大小排序
      • 每次合并最小的两棵子树,建立虚树并计算贡献
      • 复杂度优化:启发式合并保证总复杂度可控
    6. 答案更新

      • 在虚树上 DP 时,枚举不同侧的点对更新答案
      • 递归处理所有子树

    时间复杂度O(nlog2n)O(n \log^2 n),点分治 O(logn)O(\log n) 层,每层虚树 DP O(nlogn)O(n \log n)

    这是一道综合性极强的树上问题,需要深刻理解点分治、虚树和树上 DP 的结合。

    #include<bits/stdc++.h>
    #include<ext/pb_ds/assoc_container.hpp>
    #include<ext/pb_ds/tree_policy.hpp>
    #include<ext/pb_ds/hash_policy.hpp>
    #define gt getchar
    #define pt putchar
    #define fst first
    #define scd second
    #define SZ(s) ((int)s.size())
    #define all(s) s.begin(),s.end()
    #define pb push_back
    #define eb emplace_back
    typedef long long ll;
    typedef double db;
    typedef long double ld;
    typedef unsigned long long ull;
    typedef unsigned int uint;
    const int N=1e5+5;
    using namespace std;
    using namespace __gnu_pbds;
    typedef pair<int,int> pii;
    template<class T,class I> inline bool chkmax(T &a,I b){return b>a?a=b,1:0;}
    template<class T,class I> inline bool chkmin(T &a,I b){return b<a?a=b,1:0;}
    inline bool __(char ch){return ch>=48&&ch<=57;}
    template<class T> inline void read(T &x){
    	x=0;bool sgn=0;static char ch=gt();
    	while(!__(ch)&&ch!=EOF) sgn|=(ch=='-'),ch=gt();
    	while(__(ch)) x=(x<<1)+(x<<3)+(ch&15),ch=gt();
    	if(sgn) x=-x;
    }
    template<class T,class ...I> inline void read(T &x,I &...x1){
    	read(x);
    	read(x1...);
    }
    template<class T> inline void print(T x){
    	static char stk[70];short top=0;
    	if(x<0) pt('-');
    	do{stk[++top]=x>=0?(x%10+48):(-(x%10)+48),x/=10;}while(x);
    	while(top) pt(stk[top--]);
    }
    template<class T> inline void printsp(T x){
    	print(x);
    	putchar(' ');
    }
    template<class T> inline void println(T x){
    	print(x);
    	putchar('\n');
    }
    int n;
    struct Graph{
    	struct Edge{
    		int to,nxt;
    		ll w;
    	}e[N<<1];
    	int head[N],cnt;
    	inline void add_edge(int f,int t,ll w){
    		e[++cnt].to=t;
    		e[cnt].w=w;
    		e[cnt].nxt=head[f];
    		head[f]=cnt;
    	}
    	inline void add_double(int f,int t,ll w){
    		add_edge(f,t,w);
    		add_edge(t,f,w);
    	}
    	ll dep[N];
    	int st[N][22],ti,dfn[N],fa[N];
    	void dfs(int u){
    		st[dfn[u]=++ti][0]=fa[u];
    		for(int i=head[u];i;i=e[i].nxt){
    			int v=e[i].to;
    			ll w=e[i].w;
    			if(v==fa[u]) continue;
    			fa[v]=u,dep[v]=dep[u]+w;
    			dfs(v);
    		}
    	}
    	inline int get(int u,int v){return dfn[u]<dfn[v]?u:v;}
    	inline void build_st(){
    		for(int j=1;(1<<j)<=n;++j){
    			for(int i=1;i+(1<<j)-1<=n;++i){
    				st[i][j]=get(st[i][j-1],st[i+(1<<(j-1))][j-1]);
    			}
    		}
    	}
    	inline int LCA(int u,int v){
    		if(u==v) return u;
    		u=dfn[u],v=dfn[v];
    		if(u>v) swap(u,v);
    		u++;
    		int k=__lg(v-u+1);
    		return get(st[u][k],st[v-(1<<k)+1][k]);
    	}
    	inline ll dis(int u,int v){return dep[u]+dep[v]-2*dep[LCA(u,v)];}
    	inline void input(){
    		for(ll u,v,w,i=1;i<n;++i){
    			read(u,v,w);
    			add_double(u,v,w);
    		}
    	}
    	inline void init(){dfs(1);build_st();}
    }T1,T2,T3;
    int all,rt,siz[N],mxsiz[N],anc[N];
    bool vis[N];
    void get_root(int u,int fa){
    	siz[u]=1,mxsiz[u]=0;
    	for(int i=T1.head[u];i;i=T1.e[i].nxt){
    		int v=T1.e[i].to;
    		if(vis[v]||v==fa) continue;
    		get_root(v,u);
    		siz[u]+=siz[v];
    		chkmax(mxsiz[u],siz[v]);
    	}
    	chkmax(mxsiz[u],all-siz[u]);
    	if(mxsiz[u]<mxsiz[rt]) rt=u;
    }
    vector<int> vec[N],nxt[N];
    int col[N],C[N],stk[N],top;
    ll ans,d[N];
    struct Cmp{inline bool operator()(const vector<int> &a,const vector<int> &b){return SZ(a)>SZ(b);}};
    vector<int> ed[N],pot;
    inline void ins(int u){
    	pot.eb(u);
    	if(!top){
    		stk[++top]=u;
    		return;
    	}
    	int lca=T2.LCA(u,stk[top]);
    	while(top>=2&&T2.dfn[stk[top-1]]>T2.dfn[lca]) ed[stk[top-1]].eb(stk[top]),top--;
    	if(T2.dfn[stk[top]]>T2.dfn[lca]) ed[lca].eb(stk[top--]);
    	if(lca!=stk[top]) stk[++top]=lca,pot.eb(lca);
    	stk[++top]=u;
    }
    inline void build_vt(vector<int> c){
    	for(int u:pot) ed[u].clear();
    	pot.clear();
    	if(c[0]!=1) ins(1);
    	for(int u:c) ins(u); 
    	while(top>1) ed[stk[top-1]].eb(stk[top]),top--;
    	top=0;
    }
    inline ll F(int u,int v){return d[u]+d[v]+T2.dep[u]+T2.dep[v]+T3.dis(u,v);}
    inline ll G(pii u,pii v){
    	if(!u.fst||!v.fst) return -1e18;
    	return max({F(u.fst,v.fst),F(u.fst,v.scd),F(u.scd,v.fst),F(u.scd,v.scd)});
    }
    inline pii operator+(const pii &a,const pii &b){
    	if(!a.fst) return b;
    	if(!b.fst) return a;
    	vector<int> vec={a.fst,a.scd,b.fst,b.scd};
    	pii res(0,0);
    	ll mx=0;
    	for(int i=0;i<SZ(vec);++i){
    		for(int j=i+1;j<SZ(vec);++j){
    			int u=vec[i],v=vec[j];
    			if(chkmax(mx,F(u,v))) res=pii(u,v);
    		}
    	}
    	return res;
    }
    pii f[N][2];
    void dfs(int u){
    	f[u][0]=f[u][1]={0,0};
    	if(C[u]) f[u][C[u]-1]={u,u};
    	for(int v:ed[u]){
    		dfs(v);
    		for(int i=0;i<2;++i) chkmax(ans,G(f[u][i],f[v][!i])-2*T2.dep[u]);
    		for(int i=0;i<2;++i) f[u][i]=f[u][i]+f[v][i];
    	}
    }
    void solve(int u){
    	vis[u]=1,d[u]=0;
    	vector<int> son;
    	for(int i=T1.head[u];i;i=T1.e[i].nxt){
    		int v=T1.e[i].to;
    		if(vis[v]) continue;
    		son.eb(v);
    		d[v]=T1.e[i].w;
    		auto dfs=[&](auto self,int x,int fa)->void{
    			col[x]=v;
    			for(int _=T1.head[x];_;_=T1.e[_].nxt){
    				int y=T1.e[_].to;
    				ll w=T1.e[_].w;
    				if(!vis[y]&&y!=fa) d[y]=d[x]+w,self(self,y,x);
    			}
    		};
    		dfs(dfs,v,u);
    	}
    	for(int v:vec[u]) if(v!=u) nxt[col[v]].eb(v);
    	priority_queue<vector<int>,vector<vector<int>>,Cmp> pq;
    	pq.push(vector<int>{u});
    	for(int v:son) pq.push(nxt[v]);
    	while(SZ(pq)>=2){
    		auto A=pq.top(); pq.pop();
    		auto B=pq.top(); pq.pop();
    		for(int v:A) C[v]=1;
    		for(int v:B) C[v]=2;
    		vector<int> c;
    		merge(all(A),all(B),back_inserter(c),[](int u,int v){return T2.dfn[u]<T2.dfn[v];});
    		build_vt(c);
    		dfs(1);
    		for(int v:c) C[v]=0;
    		pq.push(c);
    	}
    	for(int v:son){
    		all=siz[v],rt=0;
    		get_root(v,0);
    		get_root(rt,0);
    		swap(nxt[v],vec[rt]);
    		solve(rt);
    	}
    }
    signed main(){
    	read(n);
    	T1.input(),T2.input(),T3.input();
    	T2.init(),T3.init();
    	all=mxsiz[rt=0]=n;
    	get_root(1,0);
    	get_root(rt,0);
    	vec[rt].resize(n);
    	iota(all(vec[rt]),1);
    	sort(all(vec[rt]),[](int u,int v){return T2.dfn[u]<T2.dfn[v];});
    	solve(rt);
    	println(ans);
    	return 0;
    }
    
    • 1

    信息

    ID
    3472
    时间
    4000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    4
    已通过
    1
    上传者