1 条题解
-
0
题解
题目类型:树上计数 / 结构化 DP(“小图形态”自动机)
取模:998244353
要点一句话:本题统计一棵树中长度为
k
的简单路径条数(长度按“边数”计),通过“小图形态预生成 + 权重预处理 + 树上形态 DP + 中心汇总”实现,总体时间近似 (O(n \cdot \text{poly}(k)))。
1. 目标与整体思路
- 设给定树为 (T),我们要计算简单路径(不重复边)的数目,路径长度恰为
k
(即包含k
条边、k+1
个点)。 - 一条长度为
k
的路径其“中心”有两类:- 若
k
为偶数,则中心是某个点; - 若
k
为奇数,则中心是某条边。
- 若
- 代码围绕“中心”做计数:
对点中心用一套“无根小形态”(r2
) 汇总;
对边中心用另一套“无根双中心形态”(r3
) 汇总。
为此,整份程序分四步预处理 + 两趟树 DP + 末尾汇总:
- build1:生成有根小形态库
r1
(根在中心处的一侧),大小不超过 (\left\lfloor\frac{k+1}{2}\right\rfloor),并给出所有合并/转移表(trs1
、mg
)。 - build2:生成无根小形态库
r2
(中心为点)和双中心小形态库r3
(中心为边),并建立从有根到无根的转移(trs2
)与二元合一的编号表(uni
)。 - build3:为每个小形态生成它对应的**“骨架图”**
G
(一条条小边用(u,v)
编码),用于后续计算该形态在长度为k
的路径计数中的权重。 - build4:给
r2 / r3
中每个无根(或双中心)形态计算一个权重w
,含义是:这个形态(作为“路径中心周围的局部结构”)能“承载”多少条长度为k
的路径。权重通过一个与k
相关的函数calc(G,k)
得到。
然后:
- dfs1(树 DP 第一趟):对大树从下向上,把各子树的有根形态(
r1
的编号)逐步合并到父亲,得到“以当前点为根”的形态分布f[u][*]
。 - dfs2(树 DP 第二趟):对每条子边做“去儿子贡献”的 reroot,使得每个子节点都能得到“除了它子树之外,父亲这一侧合并后的形态”
g[child][*]
,从而在边中心场景下也能准确拼出双中心形态。
最后:
- 按中心汇总:
- 点中心:对每个点
u
,把来自各子树(有根形态分布f
)融合成一个无根形态分布h
(用trs2
做有根→无根的转移),再乘以对应形态权重r2[i].w
累加到答案(只取“可做中心”的形态:r2[i].chk()
即“最大子树 < 总大小的一半”)。 - 边中心:对每条边
(fa[u],u)
,用f[u][*]
(在u
这侧的有根形态)与g[u][*]
(另一侧父亲方向的有根形态)通过uni
拼成一个双中心形态编号,从r3
中取权重累加。
- 点中心:对每个点
2. 核心黑盒
calc(G,k)
—— 小形态的“可承载路径数”2.1 “线图”思想(
trans
)- 将一个小图
G
的边看作新图的点,两条边若在原图中“共享一个端点”,就在新图里连边。 - 对长度为
k
的路径而言,就是在“线图”上走k-1
步的边-边连续相邻序列。 - 函数
trans(a)
就是在做“线图转换”。
2.2
calc(a,k)
的快速实现k=1
:答案就是边数;k=2
:等价于统计“以某点为中枢的两条边”的数目,(\sum C_2(\deg(v)));k=3
:在树上即是“长度为 3 的路径”,可按边为中心累加;k=4
:给出了一段优化过的线性式(基于度数与一次扫描的累积量d1,d2
);k≥5
:递归“线图转换”直到降到上述已优化的边界。
因为形态
G
的规模很小(与k
有关),所以calc
的复杂度是poly(k)
,不会影响整体。2.3
build4
的连通性校正- 对形态的顶点子集做容斥(枚举
S
):只对“连通”的子图记贡献,断开的直接置 0; - 对连通子图,再判定其是“点中心无根形态”(
r2
)还是“边中心双形态”(r3
),并把其calc
值累乘累计为权重w
。
3. 小形态库与转移表
3.1
r1
:有根形态- 描述“以根为中心”的小树,大小不超过 (\lfloor (k+1)/2 \rfloor);
trs1[i][j]
:把一个有根i
与一个子块j
合成后的形态编号;mg[i][j]
:把两个“以同一点为根”的有根形态“合并”后的形态编号(用于 dfs2)。
3.2
r2
:无根点中心形态- 大小不超过
k+1
; chk()
:是否可作为中心(等价于“最大子树大小 < 总大小一半”,即重心判定)。
3.3
r3
:无根边中心(双中心)形态- 大小为偶数(
2*siz
),两端中心分别来自两个有根形态; uni[i][j]
:把两个有根编号(i,j)
组合成一个双中心编号。
3.4
trs2
- 有根 → 无根的转移表(把多个有根儿块拼成一个“点中心”无根形态)。
4. 树上两趟 DP
4.1 第一趟
dfs1
f[u][*]
:在u
的子树中,所有以u
为根的有根形态出现次数;- 把每个儿子的
f[v][*]
插入到f[u][*]
(ins
),用trs1
做“加一子块”的自动机转移。
4.2 第二趟
dfs2
- 对于每个儿子
v
,我们需要得到“除去 v 那一边,其余所有子块+父边合起来在u
处形成的有根形态”,即g[v][*]
; - 做法:对
u
的所有儿子做前后缀合并(pre / suf
),把除v
外的形态拼起来,再通过mg
、ins
转移到g[v][*]
; - 递归下去,每个点/边都能拿到“另一侧的有根形态”。
5. 最终汇总
-
点中心:对每个点
u
:- 用
h
从空形态出发,依次把所有f[v]
(v
为u
的儿子)合并到h
(ins2
,即trs2
转移);再把“父侧”贡献g[u]
也合进来; - 对所有无根形态
i
,若r2[i].chk()
(可作中心),则答案+= r2[i].w * h[i]
。
- 用
-
边中心:对每条边
(fa[u],u)
:- 枚举有根编号
i
(来自u
侧的f[u][i]
)与j
(来自父侧的g[u][j]
), - 用
uni[i][j]
得到双中心编号,累加r3[...] .w * f[u][i] * g[u][j]
。
- 枚举有根编号
6. 复杂度与实现细节
- 预处理各形态与转移表:规模由
k
决定,为poly(k)
; - 两趟 DP 各
O(n * poly(k))
; - 汇总亦为
O(n * poly(k))
; - 整体近似 (O(n \cdot \text{poly}(k))),可在题目范围内顺利通过。
7. 代码(与题给一致)
(请在此补充题目的中文题解与思路描述。)
#include<bits/stdc++.h> using namespace std; using ll=long long; #ifdef DEBUG template<class T> ostream& operator << (ostream &out,vector<T> a){ out<<'['; for(T x:a)out<<x<<','; return out<<']'; } ostream& operator << (ostream &out,pair<int,int> a){ return out<<'('<<a.first<<','<<a.second<<')'; } template<class T> vector<T> ary(T *a,int l,int r){ return vector<T>{a+l,a+1+r}; } template<class T> void debug(T x){ cerr<<x<<endl; } template<class T,class...S> void debug(T x,S...y){ cerr<<x<<' ',debug(y...); } #else #define debug(...) void() #endif const int mod=998244353; using LL=__int128; using ve=vector<pair<int,int> >; ve trans(const ve &a){ ve b; int len=a.size(),n=-1; for(auto x:a)n=max({n,x.first,x.second}); n++; vector<vector<int> >to(n); for(int i=0;i<len;i++){ to[a[i].first].push_back(i); to[a[i].second].push_back(i); } for(int i=0;i<n;i++){ int len=to[i].size(); for(int x=0;x<len;x++){ for(int y=0;y<x;y++){ b.push_back({to[i][y],to[i][x]}); } } } return b; } int C2(int x){ return x*(x-1ll)/2%mod; } int calc(const ve &a,int k){ if(k==1)return a.size(); int n=-1; for(auto x:a)n=max({n,x.first,x.second}); n++; if(!k)return n; if(k==2){ vector<int>deg(n); for(auto x:a)deg[x.first]++,deg[x.second]++; int ans=0; for(int i=0;i<n;i++)(ans+=C2(deg[i]))%=mod; return ans; } if(k==3){ vector<int>deg(n); for(auto x:a)deg[x.first]++,deg[x.second]++; int ans=0; for(auto x:a)(ans+=C2(deg[x.first]+deg[x.second]-2))%=mod; return ans; } if(k==4){ // debug(a); vector<int>d1(n),d2(n); for(auto x:a)d1[x.first]++,d1[x.second]++; int ans=0; for(auto x:a){ int u,v; tie(u,v)=x; ans=(ans+d2[u]*(d1[v]-1ll))%mod; ans=(ans+d2[v]*(d1[u]-1ll))%mod; (d2[u]+=d1[v]-1)%=mod; (d2[v]+=d1[u]-1)%=mod; } for(int i=0;i<n;i++){ ans=(ans+5ll*d2[i]*C2(d1[i]-1))%mod; ans=(ans+d1[i]*(d1[i]-1ll)%mod*(d1[i]-2)%mod*(d1[i]-3))%mod; ans=(ans+d1[i]*(d1[i]-1ll)/2%mod*(d1[i]-2))%mod; } return ans; } return calc(trans(a),k-1); } const int N=5e3+10,K1=18,K2=338,K3=300; int n,k; vector<int>to[N]; struct rooted{ int siz; vector<int>son; ve G; }r1[K1]; struct unrooted{ int siz,mx,w; vector<int>son; ve G; bool chk()const{ return mx*2<siz; } }r2[K2]; struct Unrooted{ int siz,ls,rs,w; ve G; }r3[K3]; int cnt1,cnt2,cnt3,trs1[K1][K1],mg[K1][K1],trs2[K2][K1],uni[K1][K1]; int getsiz(const vector<int> &son){ int siz=1; for(int v:son)siz+=r1[v].siz; return siz; } void build1(){ int lim=(k+1)/2; map<vector<int>,int>s; set<pair<int,vector<int> > >q; q.insert({1,vector<int>()}); for(;!q.empty();){ if(s.count(q.begin()->second)){ q.erase(q.begin()); continue; } r1[++cnt1].son=q.begin()->second; s[r1[cnt1].son]=cnt1; r1[cnt1].siz=q.begin()->first; q.erase(q.begin()); if(r1[cnt1].siz==lim)continue; q.insert({getsiz({cnt1}),{cnt1}}); for(int i=1;i<cnt1;i++){ if(r1[i].siz+r1[cnt1].siz>lim)continue; auto t=r1[i].son; t.push_back(cnt1); q.insert({getsiz(t),t}); } for(int i=1;i<=cnt1;i++){ if(r1[i].siz+r1[cnt1].siz>lim)continue; auto t=r1[cnt1].son; t.push_back(i); sort(t.begin(),t.end()); q.insert({getsiz(t),t}); } } for(int i=1;i<=cnt1;i++){ for(int j=1;j<=cnt1;j++){ if(r1[i].siz+r1[j].siz>lim)continue; auto t=r1[i].son; t.push_back(j); sort(t.begin(),t.end()); trs1[i][j]=s[t]; } // debug(i,r1[i].siz,ary(trs1[i],1,cnt1)); } for(int i=1;i<=cnt1;i++){ for(int j=1;j<=cnt1;j++){ if(r1[i].siz+r1[j].siz-1>lim)continue; auto t=r1[i].son; for(int x:r1[j].son)t.push_back(x); sort(t.begin(),t.end()); mg[i][j]=s[t]; } // debug(i,r1[i].siz,ary(trs1[i],1,cnt1)); } } void build2(){ int lim=k+1; map<vector<int>,int>s; set<pair<int,vector<int> > >q; q.insert({1,vector<int>()}); for(;!q.empty();){ if(s.count(q.begin()->second)){ q.erase(q.begin()); continue; } r2[++cnt2].son=q.begin()->second; s[r2[cnt2].son]=cnt2; r2[cnt2].siz=q.begin()->first; q.erase(q.begin()); for(int v:r2[cnt2].son){ r2[cnt2].mx=max(r2[cnt2].mx,r1[v].siz); } for(int i=1;i<=cnt1;i++){ if(r1[i].siz+r2[cnt2].siz>lim)continue; auto t=r2[cnt2].son; t.push_back(i); sort(t.begin(),t.end()); q.insert({getsiz(t),t}); } } for(int siz=1;siz*2<=lim;siz++){ for(int i=1;i<=cnt1;i++){ for(int j=i;j<=cnt1;j++){ if(r1[i].siz!=siz||r1[j].siz!=siz)continue; r3[++cnt3]={siz*2,i,j}; uni[i][j]=uni[j][i]=cnt3; } } } // int cnt=0; // for(int i=1;i<=cnt2;i++){ // cnt+=r2[i].chk(); // } // debug(cnt+cnt3); r2[0].siz=1; for(int i=0;i<=cnt2;i++){ for(int j=1;j<=cnt1;j++){ if(r2[i].siz+r1[j].siz>lim)continue; auto t=r2[i].son; t.push_back(j); sort(t.begin(),t.end()); trs2[i][j]=s[t]; } // debug(i,r1[i].siz,ary(trs1[i],1,cnt1)); } } void build3(){ auto merge=[&](ve &a,const ve &b){ int n=a.size()+1,m=b.size()+1; for(auto &x:a){ x.first+=m,x.second+=m; } for(auto x:b)a.push_back(x); a.push_back({n+m-1,m-1}); }; for(int i=1;i<=cnt1;i++){ for(int v:r1[i].son)merge(r1[i].G,r1[v].G); } for(int i=1;i<=cnt2;i++){ for(int v:r2[i].son)merge(r2[i].G,r1[v].G); } for(int i=1;i<=cnt3;i++){ merge(r3[i].G=r1[r3[i].ls].G,r1[r3[i].rs].G); } } pair<int,int> getid(const ve &a){ int n=-1; for(auto x:a)n=max({n,x.first,x.second}); n++; if(!n)return {0,0}; vector<vector<int> >to(n); vector<int>siz(n),mx(n); for(auto x:a){ to[x.first].push_back(x.second); to[x.second].push_back(x.first); } // debug("getid",a); auto dfs=[&](auto &self,int u,int fa=0)->void { siz[u]=1; for(int v:to[u])if(v^fa){ self(self,v,u); siz[u]+=siz[v]; mx[u]=max(mx[u],siz[v]); } mx[u]=max(mx[u],n-siz[u]); }; dfs(dfs,0); int mn=n,id=-1; for(int i=0;i<n;i++)mn=min(mn,mx[i]); auto gen=[&](auto &self,int u,int fa)->int { int now=1; for(int v:to[u])if(v^fa){ now=trs1[now][self(self,v,u)]; } return now; }; for(int i=0;i<n;i++){ if(mx[i]==mn){ if(!~id)id=i; else{ return {2,uni[gen(gen,id,i)][gen(gen,i,id)]}; } } } int now=0; for(int v:to[id]){ now=trs2[now][gen(gen,v,id)]; } return {1,now}; } void build4(){ auto calc=[&](const ve &a){ // debug(a); int n=-1; for(auto x:a)n=max({n,x.first,x.second}); n++; int U=(1<<n)-1,ans=::calc(a,::k); // debug("calc",a,k,ans); for(int S=1;S<U;S++){ int val=mod-1,op,id; ve b; vector<int>p(n); int cnt=0; for(int i=0;i<n;i++)if(S>>i&1){ p[i]=cnt++; } vector<int>fa(cnt); iota(fa.begin(),fa.end(),0); auto find=[&](auto &self,int x)->int { return fa[x]==x?x:fa[x]=self(self,fa[x]); }; for(auto x:a){ if((S>>x.first&1)&&(S>>x.second&1))b.push_back({p[x.first],p[x.second]}); } for(auto x:b){ fa[find(find,x.first)]=find(find,x.second); } int tot=0; for(int i=0;i<cnt;i++)tot+=fa[i]==i; if(tot==1){ tie(op,id)=getid(b); if(op==1)val=1ll*val*r2[id].w%mod; else if(op==2)val=1ll*val*r3[id].w%mod; else val=0; }else val=0; // debug(a,S,b,val,op,id); (ans+=val)%=mod; } return ans; }; for(int i=1,j=1;i<=cnt2||j<=cnt3;){ if(j>cnt3||(i<=cnt2&&r2[i].siz<r3[j].siz)){ if(r2[i].chk())r2[i].w=calc(r2[i].G); i++; }else{ r3[j].w=calc(r3[j].G); j++; } } } int f[N][K1],g[N][K1]; void ins(int *f,int *g){ for(int i=cnt1;i>=1;i--){ if(!f[i])continue; for(int j=1;j<=cnt1;j++){ if(!trs1[i][j])continue; f[trs1[i][j]]=(f[trs1[i][j]]+1ll*f[i]*g[j])%mod; } } } int fa[N]; void dfs1(int u,int fa=0){ f[u][1]=1,::fa[u]=fa; for(int v:to[u])if(v^fa){ dfs1(v,u); ins(f[u],f[v]); } } int cnt,son[N],pre[N][K1],suf[N][K2]; void dfs2(int u,int fa=0){ cnt=0; for(int v:to[u])if(v^fa)son[++cnt]=v; memset(pre[0],0,sizeof pre[0]); memset(suf[cnt+1],0,sizeof suf[cnt+1]); pre[0][1]=suf[cnt+1][1]=1; for(int i=1;i<=cnt;i++){ copy(pre[i-1]+1,pre[i-1]+1+cnt1,pre[i]+1); ins(pre[i],f[son[i]]); } for(int i=cnt;i>=1;i--){ copy(suf[i+1]+1,suf[i+1]+1+cnt1,suf[i]+1); ins(suf[i],f[son[i]]); } for(int i=1;i<=cnt;i++){ int v=son[i]; for(int x=1;x<=cnt1;x++){ for(int y=1;y<=cnt1;y++){ if(!mg[x][y])continue; g[v][mg[x][y]]=(g[v][mg[x][y]]+1ll*pre[i-1][x]*suf[i+1][y])%mod; } } ins(g[v],g[u]); } for(int v:to[u])if(v^fa)dfs2(v,u); } int h[K2]; void ins2(int *f,int *g){ for(int i=cnt2;i>=0;i--){ if(!f[i])continue; for(int j=1;j<=cnt1;j++){ if(!trs2[i][j])continue; f[trs2[i][j]]=(f[trs2[i][j]]+1ll*f[i]*g[j])%mod; } } } int main(){ cin>>n>>k; for(int i=1,u,v;i<n;i++){ cin>>u>>v; to[u].push_back(v),to[v].push_back(u); } build1(),build2(),build3(),build4(); dfs1(1),dfs2(1); int ans=0; // for(int i=1;i<=cnt2;i++){ // if(r2[i].chk())debug("r2",r2[i].siz,r2[i].G,r2[i].w); // } // for(int i=1;i<=cnt3;i++){ // debug("r3",r3[i].siz,r3[i].G,r3[i].w); // } // debug(calc(r3[4].G,k)); for(int u=1;u<=n;u++){ memset(h,0,sizeof h); h[0]=1; for(int v:to[u])if(v^fa[u]){ ins2(h,f[v]); } ins2(h,g[u]); // if(u==2)debug(ary(h,1,cnt2)); for(int i=1;i<=cnt2;i++){ if(r2[i].chk()){ ans=(ans+1ll*r2[i].w*h[i])%mod; } } // for(int i=1;i<=cnt2;i++){ // if(h[i]&&r2[i].chk())debug(u,r2[i].siz,r2[i].G); // } } for(int u=2;u<=n;u++){ for(int i=1;i<=cnt1;i++){ for(int j=1;j<=cnt1;j++){ if(!uni[i][j])continue; ans=(ans+1ll*r3[uni[i][j]].w*f[u][i]%mod*g[u][j])%mod; } } } cout<<ans<<endl; debug(1.0*clock()/CLOCKS_PER_SEC); return 0; }
8. 小结
- 本题的关键是:把“长度为
k
的路径”这一全局计数问题,拆解为“中心附近的有限小结构”的组合问题。 - 通过“小形态自动机”将“拼接/合并”标准化,预先算出每种小形态对长度为
k
路径的承载权重,再在大树上做两趟 DP,实现高效计数。 - 这类思路在“按半径/中心局部结构控制的全图计数”中非常通用。
- 设给定树为 (T),我们要计算简单路径(不重复边)的数目,路径长度恰为
- 1
信息
- ID
- 3560
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 10
- 标签
- 递交数
- 2
- 已通过
- 1
- 上传者