1 条题解

  • 0
    @ 2025-10-19 23:13:41

    题解

    题目类型:树上计数 / 结构化 DP(“小图形态”自动机)
    取模998244353

    要点一句话:本题统计一棵树中长度为 k 的简单路径条数(长度按“边数”计),通过“小图形态预生成 + 权重预处理 + 树上形态 DP + 中心汇总”实现,总体时间近似 (O(n \cdot \text{poly}(k)))。


    1. 目标与整体思路

    • 设给定树为 (T),我们要计算简单路径(不重复边)的数目,路径长度恰为 k(即包含 k 条边、k+1 个点)。
    • 一条长度为 k 的路径其“中心”有两类:
      • k 为偶数,则中心是某个点
      • k 为奇数,则中心是某条边
    • 代码围绕“中心”做计数:
      点中心用一套“无根小形态”(r2) 汇总;
      边中心用另一套“无根双中心形态”(r3) 汇总。

    为此,整份程序分四步预处理 + 两趟树 DP + 末尾汇总:

    1. build1:生成有根小形态库 r1(根在中心处的一侧),大小不超过 (\left\lfloor\frac{k+1}{2}\right\rfloor),并给出所有合并/转移表(trs1mg)。
    2. build2:生成无根小形态库 r2(中心为点)和双中心小形态库 r3(中心为边),并建立从有根到无根的转移(trs2)与二元合一的编号表(uni)。
    3. build3:为每个小形态生成它对应的**“骨架图”** G(一条条小边用 (u,v) 编码),用于后续计算该形态在长度为 k 的路径计数中的权重
    4. build4:给 r2 / r3 中每个无根(或双中心)形态计算一个权重 w,含义是:这个形态(作为“路径中心周围的局部结构”)能“承载”多少条长度为 k 的路径。权重通过一个与 k 相关的函数 calc(G,k) 得到。

    然后:

    1. dfs1(树 DP 第一趟):对大树从下向上,把各子树的有根形态(r1 的编号)逐步合并到父亲,得到“以当前点为根”的形态分布 f[u][*]
    2. dfs2(树 DP 第二趟):对每条子边做“去儿子贡献”的 reroot,使得每个子节点都能得到“除了它子树之外,父亲这一侧合并后的形态” g[child][*],从而在边中心场景下也能准确拼出双中心形态。

    最后:

    1. 按中心汇总
      • 点中心:对每个点 u,把来自各子树(有根形态分布 f)融合成一个无根形态分布 h(用 trs2 做有根→无根的转移),再乘以对应形态权重 r2[i].w 累加到答案(只取“可做中心”的形态:r2[i].chk() 即“最大子树 < 总大小的一半”)。
      • 边中心:对每条边 (fa[u],u),用 f[u][*](在 u 这侧的有根形态)与 g[u][*](另一侧父亲方向的有根形态)通过 uni 拼成一个双中心形态编号,从 r3 中取权重累加。

    2. 核心黑盒 calc(G,k) —— 小形态的“可承载路径数”

    2.1 “线图”思想(trans

    • 将一个小图 G看作新图的,两条边若在原图中“共享一个端点”,就在新图里连边。
    • 对长度为 k 的路径而言,就是在“线图”上走 k-1 步的边-边连续相邻序列。
    • 函数 trans(a) 就是在做“线图转换”。

    2.2 calc(a,k) 的快速实现

    • k=1:答案就是边数;
    • k=2:等价于统计“以某点为中枢的两条边”的数目,(\sum C_2(\deg(v)));
    • k=3:在树上即是“长度为 3 的路径”,可按边为中心累加;
    • k=4:给出了一段优化过的线性式(基于度数与一次扫描的累积量 d1,d2);
    • k≥5:递归“线图转换”直到降到上述已优化的边界。

    因为形态 G 的规模很小(与 k 有关),所以 calc 的复杂度是 poly(k),不会影响整体。

    2.3 build4 的连通性校正

    • 对形态的顶点子集做容斥(枚举 S):只对“连通”的子图记贡献,断开的直接置 0;
    • 连通子图,再判定其是“点中心无根形态”(r2)还是“边中心双形态”(r3),并把其 calc 值累乘累计为权重 w

    3. 小形态库与转移表

    3.1 r1:有根形态

    • 描述“以根为中心”的小树,大小不超过 (\lfloor (k+1)/2 \rfloor);
    • trs1[i][j]:把一个有根 i 与一个子块 j 合成后的形态编号;
    • mg[i][j]:把两个“以同一点为根”的有根形态“合并”后的形态编号(用于 dfs2)。

    3.2 r2:无根点中心形态

    • 大小不超过 k+1
    • chk():是否可作为中心(等价于“最大子树大小 < 总大小一半”,即重心判定)。

    3.3 r3:无根边中心(双中心)形态

    • 大小为偶数(2*siz),两端中心分别来自两个有根形态;
    • uni[i][j]:把两个有根编号 (i,j) 组合成一个双中心编号。

    3.4 trs2

    • 有根 → 无根的转移表(把多个有根儿块拼成一个“点中心”无根形态)。

    4. 树上两趟 DP

    4.1 第一趟 dfs1

    • f[u][*]:在 u子树中,所有以 u 为根的有根形态出现次数;
    • 把每个儿子的 f[v][*] 插入f[u][*]ins),用 trs1 做“加一子块”的自动机转移

    4.2 第二趟 dfs2

    • 对于每个儿子 v,我们需要得到“除去 v 那一边,其余所有子块+父边合起来u 处形成的有根形态”,即 g[v][*]
    • 做法:对 u 的所有儿子做前后缀合并(pre / suf),把除 v 外的形态拼起来,再通过 mgins 转移到 g[v][*]
    • 递归下去,每个点/边都能拿到“另一侧的有根形态”。

    5. 最终汇总

    • 点中心:对每个点 u

      • h 从空形态出发,依次把所有 f[v]vu 的儿子)合并到 hins2,即 trs2 转移);再把“父侧”贡献 g[u] 也合进来;
      • 对所有无根形态 i,若 r2[i].chk()(可作中心),则答案 += r2[i].w * h[i]
    • 边中心:对每条边 (fa[u],u)

      • 枚举有根编号 i(来自 u 侧的 f[u][i])与 j(来自父侧的 g[u][j]),
      • uni[i][j] 得到双中心编号,累加 r3[...] .w * f[u][i] * g[u][j]

    6. 复杂度与实现细节

    • 预处理各形态与转移表:规模由 k 决定,为 poly(k)
    • 两趟 DP 各 O(n * poly(k))
    • 汇总亦为 O(n * poly(k))
    • 整体近似 (O(n \cdot \text{poly}(k))),可在题目范围内顺利通过。

    7. 代码(与题给一致)

    (请在此补充题目的中文题解与思路描述。)

    #include<bits/stdc++.h>
    using namespace std;
    using ll=long long;
    #ifdef DEBUG
    template<class T>
    ostream& operator << (ostream &out,vector<T> a){
    	out<<'[';
    	for(T x:a)out<<x<<',';
    	return out<<']';
    }
    ostream& operator << (ostream &out,pair<int,int> a){
    	return out<<'('<<a.first<<','<<a.second<<')';
    }
    template<class T>
    vector<T> ary(T *a,int l,int r){
    	return vector<T>{a+l,a+1+r};
    }
    template<class T>
    void debug(T x){
    	cerr<<x<<endl;
    }
    template<class T,class...S>
    void debug(T x,S...y){
    	cerr<<x<<' ',debug(y...);
    }
    #else
    #define debug(...) void()
    #endif
    const int mod=998244353;
    using LL=__int128;
    using ve=vector<pair<int,int> >;
    ve trans(const ve &a){
    	ve b;
    	int len=a.size(),n=-1;
    	for(auto x:a)n=max({n,x.first,x.second});
    	n++;
    	vector<vector<int> >to(n);
    	for(int i=0;i<len;i++){
    		to[a[i].first].push_back(i);
    		to[a[i].second].push_back(i);
    	}
    	for(int i=0;i<n;i++){
    		int len=to[i].size();
    		for(int x=0;x<len;x++){
    			for(int y=0;y<x;y++){
    				b.push_back({to[i][y],to[i][x]});
    			}
    		}
    	}
    	return b;
    }
    int C2(int x){
    	return x*(x-1ll)/2%mod;
    }
    int calc(const ve &a,int k){
    	if(k==1)return a.size();
    	int n=-1;
    	for(auto x:a)n=max({n,x.first,x.second});
    	n++;
    	if(!k)return n;
    	if(k==2){
    		vector<int>deg(n);
    		for(auto x:a)deg[x.first]++,deg[x.second]++;
    		int ans=0;
    		for(int i=0;i<n;i++)(ans+=C2(deg[i]))%=mod;
    		return ans;
    	}
    	if(k==3){
    		vector<int>deg(n);
    		for(auto x:a)deg[x.first]++,deg[x.second]++;
    		int ans=0;
    		for(auto x:a)(ans+=C2(deg[x.first]+deg[x.second]-2))%=mod;
    		return ans;
    	}
    	if(k==4){
    		// debug(a);
    		vector<int>d1(n),d2(n);
    		for(auto x:a)d1[x.first]++,d1[x.second]++;
    		int ans=0;
    		for(auto x:a){
    			int u,v;
    			tie(u,v)=x;
    			ans=(ans+d2[u]*(d1[v]-1ll))%mod;
    			ans=(ans+d2[v]*(d1[u]-1ll))%mod;
    			(d2[u]+=d1[v]-1)%=mod;
    			(d2[v]+=d1[u]-1)%=mod;
    		}
    		for(int i=0;i<n;i++){
    			ans=(ans+5ll*d2[i]*C2(d1[i]-1))%mod;
    			ans=(ans+d1[i]*(d1[i]-1ll)%mod*(d1[i]-2)%mod*(d1[i]-3))%mod;
    			ans=(ans+d1[i]*(d1[i]-1ll)/2%mod*(d1[i]-2))%mod;
    		}
    		return ans;
    	}
    	return calc(trans(a),k-1);
    }
    const int N=5e3+10,K1=18,K2=338,K3=300;
    int n,k;
    vector<int>to[N];
    struct rooted{
    	int siz;
    	vector<int>son;
    	ve G;
    }r1[K1];
    struct unrooted{
    	int siz,mx,w;
    	vector<int>son;
    	ve G;
    	bool chk()const{
    		return mx*2<siz;
    	}
    }r2[K2];
    struct Unrooted{
    	int siz,ls,rs,w;
    	ve G;
    }r3[K3];
    int cnt1,cnt2,cnt3,trs1[K1][K1],mg[K1][K1],trs2[K2][K1],uni[K1][K1];
    int getsiz(const vector<int> &son){
    	int siz=1;
    	for(int v:son)siz+=r1[v].siz;
    	return siz;
    }
    void build1(){
    	int lim=(k+1)/2;
    	map<vector<int>,int>s;
    	set<pair<int,vector<int> > >q;
    	q.insert({1,vector<int>()});
    	for(;!q.empty();){
    		if(s.count(q.begin()->second)){
    			q.erase(q.begin());
    			continue;
    		}
    		r1[++cnt1].son=q.begin()->second;
    		s[r1[cnt1].son]=cnt1;
    		r1[cnt1].siz=q.begin()->first;
    		q.erase(q.begin());
    		if(r1[cnt1].siz==lim)continue;
    		q.insert({getsiz({cnt1}),{cnt1}});
    		for(int i=1;i<cnt1;i++){
    			if(r1[i].siz+r1[cnt1].siz>lim)continue;
    			auto t=r1[i].son;
    			t.push_back(cnt1);
    			q.insert({getsiz(t),t});
    		}
    		for(int i=1;i<=cnt1;i++){
    			if(r1[i].siz+r1[cnt1].siz>lim)continue;
    			auto t=r1[cnt1].son;
    			t.push_back(i);
    			sort(t.begin(),t.end());
    			q.insert({getsiz(t),t});
    		}
    	}
    	for(int i=1;i<=cnt1;i++){
    		for(int j=1;j<=cnt1;j++){
    			if(r1[i].siz+r1[j].siz>lim)continue;
    			auto t=r1[i].son;
    			t.push_back(j);
    			sort(t.begin(),t.end());
    			trs1[i][j]=s[t];
    		}
    		// debug(i,r1[i].siz,ary(trs1[i],1,cnt1));
    	}
    	for(int i=1;i<=cnt1;i++){
    		for(int j=1;j<=cnt1;j++){
    			if(r1[i].siz+r1[j].siz-1>lim)continue;
    			auto t=r1[i].son;
    			for(int x:r1[j].son)t.push_back(x);
    			sort(t.begin(),t.end());
    			mg[i][j]=s[t];
    		}
    		// debug(i,r1[i].siz,ary(trs1[i],1,cnt1));
    	}
    }
    void build2(){
    	int lim=k+1;
    	map<vector<int>,int>s;
    	set<pair<int,vector<int> > >q;
    	q.insert({1,vector<int>()});
    	for(;!q.empty();){
    		if(s.count(q.begin()->second)){
    			q.erase(q.begin());
    			continue;
    		}
    		r2[++cnt2].son=q.begin()->second;
    		s[r2[cnt2].son]=cnt2;
    		r2[cnt2].siz=q.begin()->first;
    		q.erase(q.begin());
    		for(int v:r2[cnt2].son){
    			r2[cnt2].mx=max(r2[cnt2].mx,r1[v].siz);
    		}
    		for(int i=1;i<=cnt1;i++){
    			if(r1[i].siz+r2[cnt2].siz>lim)continue;
    			auto t=r2[cnt2].son;
    			t.push_back(i);
    			sort(t.begin(),t.end());
    			q.insert({getsiz(t),t});
    		}
    	}
    	for(int siz=1;siz*2<=lim;siz++){
    		for(int i=1;i<=cnt1;i++){
    			for(int j=i;j<=cnt1;j++){
    				if(r1[i].siz!=siz||r1[j].siz!=siz)continue;
    				r3[++cnt3]={siz*2,i,j};
    				uni[i][j]=uni[j][i]=cnt3;
    			}
    		}
    	}
    	// int cnt=0;
    	// for(int i=1;i<=cnt2;i++){
    	// 	cnt+=r2[i].chk();
    	// }
    	// debug(cnt+cnt3);
    	r2[0].siz=1;
    	for(int i=0;i<=cnt2;i++){
    		for(int j=1;j<=cnt1;j++){
    			if(r2[i].siz+r1[j].siz>lim)continue;
    			auto t=r2[i].son;
    			t.push_back(j);
    			sort(t.begin(),t.end());
    			trs2[i][j]=s[t];
    		}
    		// debug(i,r1[i].siz,ary(trs1[i],1,cnt1));
    	}
    }
    void build3(){
    	auto merge=[&](ve &a,const ve &b){
    		int n=a.size()+1,m=b.size()+1;
    		for(auto &x:a){
    			x.first+=m,x.second+=m;
    		}
    		for(auto x:b)a.push_back(x);
    		a.push_back({n+m-1,m-1});
    	};
    	for(int i=1;i<=cnt1;i++){
    		for(int v:r1[i].son)merge(r1[i].G,r1[v].G);
    	}
    	for(int i=1;i<=cnt2;i++){
    		for(int v:r2[i].son)merge(r2[i].G,r1[v].G);
    	}
    	for(int i=1;i<=cnt3;i++){
    		merge(r3[i].G=r1[r3[i].ls].G,r1[r3[i].rs].G);
    	}
    }
    pair<int,int> getid(const ve &a){
    	int n=-1;
    	for(auto x:a)n=max({n,x.first,x.second});
    	n++;
    	if(!n)return {0,0};
    	vector<vector<int> >to(n);
    	vector<int>siz(n),mx(n);
    	for(auto x:a){
    		to[x.first].push_back(x.second);
    		to[x.second].push_back(x.first);
    	}
    	// debug("getid",a);
    	auto dfs=[&](auto &self,int u,int fa=0)->void {
    		siz[u]=1;
    		for(int v:to[u])if(v^fa){
    			self(self,v,u);
    			siz[u]+=siz[v];
    			mx[u]=max(mx[u],siz[v]);
    		}
    		mx[u]=max(mx[u],n-siz[u]);
    	};
    	dfs(dfs,0);
    	int mn=n,id=-1;
    	for(int i=0;i<n;i++)mn=min(mn,mx[i]);
    	auto gen=[&](auto &self,int u,int fa)->int {
    		int now=1;
    		for(int v:to[u])if(v^fa){
    			now=trs1[now][self(self,v,u)];
    		}
    		return now;
    	};
    	for(int i=0;i<n;i++){
    		if(mx[i]==mn){
    			if(!~id)id=i;
    			else{
    				return {2,uni[gen(gen,id,i)][gen(gen,i,id)]};
    			}
    		}
    	}
    	int now=0;
    	for(int v:to[id]){
    		now=trs2[now][gen(gen,v,id)];
    	}
    	return {1,now};
    }
    void build4(){
    	auto calc=[&](const ve &a){
    		// debug(a);
    		int n=-1;
    		for(auto x:a)n=max({n,x.first,x.second});
    		n++;
    		int U=(1<<n)-1,ans=::calc(a,::k);
    		// debug("calc",a,k,ans);
    		for(int S=1;S<U;S++){
    			int val=mod-1,op,id;
    			ve b;
    			vector<int>p(n);
    			int cnt=0;
    			for(int i=0;i<n;i++)if(S>>i&1){
    				p[i]=cnt++;
    			}
    			vector<int>fa(cnt);
    			iota(fa.begin(),fa.end(),0);
    			auto find=[&](auto &self,int x)->int {
    				return fa[x]==x?x:fa[x]=self(self,fa[x]);
    			};
    			for(auto x:a){
    				if((S>>x.first&1)&&(S>>x.second&1))b.push_back({p[x.first],p[x.second]});
    			}
    			for(auto x:b){
    				fa[find(find,x.first)]=find(find,x.second);
    			}
    			int tot=0;
    			for(int i=0;i<cnt;i++)tot+=fa[i]==i;
    			if(tot==1){
    				tie(op,id)=getid(b);
    				if(op==1)val=1ll*val*r2[id].w%mod;
    				else if(op==2)val=1ll*val*r3[id].w%mod;
    				else val=0;
    			}else val=0;
    			// debug(a,S,b,val,op,id);
    			(ans+=val)%=mod;
    		}
    		return ans;
    	};
    	for(int i=1,j=1;i<=cnt2||j<=cnt3;){
    		if(j>cnt3||(i<=cnt2&&r2[i].siz<r3[j].siz)){
    			if(r2[i].chk())r2[i].w=calc(r2[i].G);
    			i++;
    		}else{
    			r3[j].w=calc(r3[j].G);
    			j++;
    		}
    	}
    }
    int f[N][K1],g[N][K1];
    void ins(int *f,int *g){
    	for(int i=cnt1;i>=1;i--){
    		if(!f[i])continue;
    		for(int j=1;j<=cnt1;j++){
    			if(!trs1[i][j])continue;
    			f[trs1[i][j]]=(f[trs1[i][j]]+1ll*f[i]*g[j])%mod;
    		}
    	}
    }
    int fa[N];
    void dfs1(int u,int fa=0){
    	f[u][1]=1,::fa[u]=fa;
    	for(int v:to[u])if(v^fa){
    		dfs1(v,u);
    		ins(f[u],f[v]);
    	}
    }
    int cnt,son[N],pre[N][K1],suf[N][K2];
    void dfs2(int u,int fa=0){
    	cnt=0;
    	for(int v:to[u])if(v^fa)son[++cnt]=v;
    	memset(pre[0],0,sizeof pre[0]);
    	memset(suf[cnt+1],0,sizeof suf[cnt+1]);
    	pre[0][1]=suf[cnt+1][1]=1;
    	for(int i=1;i<=cnt;i++){
    		copy(pre[i-1]+1,pre[i-1]+1+cnt1,pre[i]+1);
    		ins(pre[i],f[son[i]]);
    	}
    	for(int i=cnt;i>=1;i--){
    		copy(suf[i+1]+1,suf[i+1]+1+cnt1,suf[i]+1);
    		ins(suf[i],f[son[i]]);
    	}
    	for(int i=1;i<=cnt;i++){
    		int v=son[i];
    		for(int x=1;x<=cnt1;x++){
    			for(int y=1;y<=cnt1;y++){
    				if(!mg[x][y])continue;
    				g[v][mg[x][y]]=(g[v][mg[x][y]]+1ll*pre[i-1][x]*suf[i+1][y])%mod;
    			}
    		}
    		ins(g[v],g[u]);
    	}
    	for(int v:to[u])if(v^fa)dfs2(v,u);
    }
    int h[K2];
    void ins2(int *f,int *g){
    	for(int i=cnt2;i>=0;i--){
    		if(!f[i])continue;
    		for(int j=1;j<=cnt1;j++){
    			if(!trs2[i][j])continue;
    			f[trs2[i][j]]=(f[trs2[i][j]]+1ll*f[i]*g[j])%mod;
    		}
    	}
    }
    int main(){
    	cin>>n>>k;
    	for(int i=1,u,v;i<n;i++){
    		cin>>u>>v;
    		to[u].push_back(v),to[v].push_back(u);
    	}
    	build1(),build2(),build3(),build4();
    	dfs1(1),dfs2(1);
    	int ans=0;
    	// for(int i=1;i<=cnt2;i++){
    	// 	if(r2[i].chk())debug("r2",r2[i].siz,r2[i].G,r2[i].w);
    	// }
    	// for(int i=1;i<=cnt3;i++){
    	// 	debug("r3",r3[i].siz,r3[i].G,r3[i].w);
    	// }
    	// debug(calc(r3[4].G,k));
    	for(int u=1;u<=n;u++){
    		memset(h,0,sizeof h);
    		h[0]=1;
    		for(int v:to[u])if(v^fa[u]){
    			ins2(h,f[v]);
    		}
    		ins2(h,g[u]);
    		// if(u==2)debug(ary(h,1,cnt2));
    		for(int i=1;i<=cnt2;i++){
    			if(r2[i].chk()){
    				ans=(ans+1ll*r2[i].w*h[i])%mod;
    			}
    		}
    		// for(int i=1;i<=cnt2;i++){
    		// 	if(h[i]&&r2[i].chk())debug(u,r2[i].siz,r2[i].G);
    		// }
    	}
    	for(int u=2;u<=n;u++){
    		for(int i=1;i<=cnt1;i++){
    			for(int j=1;j<=cnt1;j++){
    				if(!uni[i][j])continue;
    				ans=(ans+1ll*r3[uni[i][j]].w*f[u][i]%mod*g[u][j])%mod;
    			}
    		}
    	}
    	cout<<ans<<endl;
    	debug(1.0*clock()/CLOCKS_PER_SEC);
    	return 0;
    }
    

    8. 小结

    • 本题的关键是:把“长度为 k 的路径”这一全局计数问题,拆解为“中心附近的有限小结构”的组合问题。
    • 通过“小形态自动机”将“拼接/合并”标准化,预先算出每种小形态对长度为 k 路径的承载权重,再在大树上做两趟 DP,实现高效计数。
    • 这类思路在“按半径/中心局部结构控制的全图计数”中非常通用。
    • 1