树链剖分 算法学习

树你应该懂的吧o( ̄︶ ̄)o 学习树链剖分之前需要先学习:$dfs$、线段树(当然大佬们用树状数组代替线段树也可以O(∩_∩)O),据说一名普及+的$oier$应该都会呀

先来了解树链剖分的用处

Luogu题目传送门 已知一棵包含$N$个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
  • 操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
  • 操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
  • 操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

如果直接暴力的话,肯定会TLE(废话)。所以这时候,树链剖分闪亮登场。

什么是树链剖分

一种算法(废话),它通过分轻重边把树分割成很多链,然后利用某种数据结构维护这些链(比如上文提到的线段树、树状数组等)但前提是这种数据结构支持动态修改(你别给我整个RMQ)。本质上是一种暴力算法。 PS:树剖的复杂度约为$O(nlog^2n)$

树链剖分的基本概念

名称

概念

重儿子

父亲节点的所有儿子中子节点数目最多(sz最大)的节点

轻儿子

父亲节点除了重儿子以外的儿子

重边

父亲节点和重儿子连成的边

轻边

父亲节点和轻儿子连成的边

重链

由多条重边连成的路径

轻链

由多条轻边连成的路径

没看懂?没关系,结合下面这张图:(红色的边代表重边,黑色的边代表轻边,红色的节点代表重儿子,黑色的节点代表轻儿子) PS:这里默认树根也是重儿子。

上图的重链有:1-4,3-6。

变量声明

1
2
3
4
5
ll fir[MAXN],nxt[MAXN*2],son[MAXN*2],w[MAXN*2],tot;
struct Node{
ll sum,tag,l,r,ls,rs;
}a[2*MAXN];
ll root,n,m,r,mod,v[MAXN],cnt,fa[MAXN],dep[MAXN],sz[MAXN],c[MAXN],rk[MAXN],top[MAXN],id[MAXN];

名称

作用

$fir_x$

关于$x$的最后一条边编号

$nxt_x$

关于$x$的上一条边编号

$son_x$

第$x$条边的连向

$w_x$

其实没啥用,打着习惯了

$a_x.ls$

编号为$x$的节点的左儿子

$a_x.rs$

编号为$x$的节点的右儿子

$fa_x$

编号为$x$的节点的父亲

$c_x$

编号为$x$的节点的重儿子

$rk_x$

当前$dfs$标号在树中所对应的节点的编号

$top_x$

编号为$x$的节点所在链的顶端节点编号

$id_x$

编号为$x$的节点$dfs$后的新编号

$dep_x$

编号为$x$的节点的深度

$sz_x$

以编号为$x$的节点为根的子树的节点个数

树链剖分的实现

第一次$dfs$求出每个节点的重儿子、父亲、深度、子树大小。

PS:如果一个点的多个儿子所在子树大小相等且最大,那随便找一个当做它的重儿子就好了,叶节点没有重儿子,非叶节点有且只有一个重儿子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
inline void dfs1(ll x,ll f,ll deep){
fa[x]=f;//该节点的父亲
dep[x]=deep;//该节点深度
sz[x]=1;//该节点子树先设置为1(本身)
for(ll i=fir[x];i;i=nxt[i]){//寻找与该节点相连的边
ll to=son[i];//该边的另一个节点
if(to==f) continue ;//如果另一个节点刚好是父亲,那么continue
dfs1(to,x,deep+1);//否则dfs该节点,并且父亲为本节点,深度+1
sz[x]+=sz[to];//子树大小增加
if(sz[to]>sz[c[x]]) c[x]=to;//重儿子更新(找子树最大的)
}
}
//主函数调用
dfs1(root,0,1);

操作完以后应该是下图:

第二次$dfs$求出每个节点的链顶端节点、新编号、$dfs$编号对应的节点编号。

1
2
3
4
5
6
7
8
9
10
11
12
inline void dfs2(ll x,ll ttop){
top[x]=ttop;//链顶端编号
id[x]=++cnt;//新编号(dfs序)
rk[cnt]=x;//新编号对应节点编号
if(c[x]!=0) dfs2(c[x],ttop);//如果不是叶子节点,优先dfs重儿子,因为节点与重儿子处在同一重链,所以重儿子的重链顶端还是ttop
for(ll i=fir[x];i;i=nxt[i]){
ll to=son[i];
if(to!=c[x]&&to!=fa[x]) dfs2(to,to);//如果既不是父亲也不是重儿子,那么就是该节点的轻儿子,那么dfs,且该节点的重链顶端为它本身
}
}
//主函数调用
dfs2(root,root);

操作完以后应该是下图:

线段树等数据结构的维护

接下来就是线段树、树状数组等数据结构的维护了,具体使用哪种数据结构因题目而异,这里提供模板题(上文介绍的题目)所使用的线段树(区间修改、区间询问)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
inline void pushup(ll x){
a[x].sum=(a[a[x].ls].sum+a[a[x].rs].sum)%mod;//更新求和
}
inline void build(ll l,ll r,ll x){
if(l==r){
a[x].sum=v[rk[l]];//符合所在区间,更新
a[x].l=a[x].r=l;//l、r更新
return ;
}
ll mid=l+r>>1;//线段树性质
a[x].ls=cnt++;a[x].rs=cnt++;//左右儿子节点编号
build(l,mid,a[x].ls);build(mid+1,r,a[x].rs);//分而治之
a[x].l=a[a[x].ls].l,a[x].r=a[a[x].rs].r;//区间更新
pushup(x);//sum更新
}
inline ll len(ll x){
return a[x].r-a[x].l+1;//该区间的节点数量
}
inline void pushdown(ll x){
if(a[x].tag!=0){//如果有lazy tag
a[a[x].ls].tag+=a[x].tag;a[a[x].rs].tag+=a[x].tag;//向左右儿子传递
a[a[x].ls].tag%=mod;a[a[x].rs].tag%=mod;
a[a[x].ls].sum+=a[x].tag*len(a[x].ls);a[a[x].rs].sum+=a[x].tag*len(a[x].rs);//左右儿子更新
a[a[x].ls].sum%=mod;a[a[x].rs].sum%=mod;
a[x].tag=0;//lazy tag取消
}
}
inline void update(ll l,ll r,ll c,ll x){
if(a[x].l>=l&&a[x].r<=r){
a[x].tag+=c;a[x].tag%=mod;//修改lazy tag
a[x].sum+=len(x)*c;a[x].sum%=mod;//修改sum
return ;
}
pushdown(x);//标记下传
ll mid=a[x].l+a[x].r>>1;
if(mid>=l) update(l,r,c,a[x].ls);//分而治之
if(mid<r) update(l,r,c,a[x].rs);
pushup(x);//更新sum
}
inline ll query(ll l,ll r,ll x){
if(a[x].l>=l&&a[x].r<=r) return a[x].sum;//如果符合在本区间内,那么return
pushdown(x);//标记下传
ll mid=a[x].l+a[x].r>>1,ss=0;
if(mid>=l) ss+=query(l,r,a[x].ls);ss%=mod;//分而治之
if(mid<r) ss+=query(l,r,a[x].rs);ss%=mod;
return ss;//返回
}
//主函数调用(根据上文题目)
cnt=0;build(1,n,root=cnt++);
update(id[x],id[x]+sz[x]-1,y,root);
query(id[x],id[x]+sz[x]-1,root);

根据题目需要添加操作

就比如上文的题目中还要求的操作:

  • 操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
  • 操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

与操作3、操作4不同,这里要求的是一条路径上的节点,而没有告诉我们节点的编号,所以,我们这时要求出节点编号:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
inline ll Query(ll x,ll y){
ll res=0;
while(top[x]!=top[y]){//若两点不再同一条链上
if(dep[top[x]]<dep[top[y]]) swap(x,y);
res+=query(id[top[x]],id[x],root);//ans更新
res%=mod;
x=fa[top[x]];//让x向上爬(与倍增思想类似,但有时复杂度更低)
}
if(id[x]>id[y]) swap(x,y);
res+=query(id[x],id[y],root);//在同一条链,跳到同一点,ans更新
res%=mod;
return res;
}
inline void Update(ll x,ll y,ll c){
while(top[x]!=top[y]){//两点不在同一条链
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(id[top[x]],id[x],c,root);//更新
x=fa[top[x]];//让x向上爬
}
if(id[x]>id[y]) swap(x,y);
update(id[x],id[y],c,root);//在同一链,跳到同一点,更新
}

当然,还有一个操作是非常常用的,那就是求lca(最近公共祖先)。

1
2
3
4
5
6
7
inline ll lca(ll x,ll y){
while(top[x]!=top[y]){//两点不在同一条链上肯定没有公共祖先
if(dep[top[x]]>=dep[top[y]])x=fa[top[x]];//让深度低的点向上爬,x向上爬
else y=fa[top[y]];//y向上爬
}
return dep[x]<dep[y]?x:y;//取深度低的点
}

模板题代码

对对对,就是上文提到的题目。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#include<bits/stdc++.h>
#define MAXN 200010
#define ll long long
using namespace std;
ll fir[MAXN],nxt[MAXN*2],son[MAXN*2],w[MAXN*2],tot;
struct Node{
ll sum,tag,l,r,ls,rs;
}a[2*MAXN];
ll root,n,m,r,mod,v[MAXN],cnt,fa[MAXN],dep[MAXN],sz[MAXN],c[MAXN],rk[MAXN],top[MAXN],id[MAXN];
inline void dfs1(ll x,ll f,ll deep){
fa[x]=f;
dep[x]=deep;
sz[x]=1;
for(ll i=fir[x];i;i=nxt[i]){
ll to=son[i];
if(to==f) continue ;
dfs1(to,x,deep+1);
sz[x]+=sz[to];
if(sz[to]>sz[c[x]]) c[x]=to;
}
}
inline void dfs2(ll x,ll ttop){
top[x]=ttop;
id[x]=++cnt;
rk[cnt]=x;
if(c[x]!=0) dfs2(c[x],ttop);
for(ll i=fir[x];i;i=nxt[i]){
ll to=son[i];
if(to!=c[x]&&to!=fa[x]) dfs2(to,to);
}
}
inline void pushup(ll x){
a[x].sum=(a[a[x].ls].sum+a[a[x].rs].sum)%mod;
}
inline void build(ll l,ll r,ll x){
if(l==r){
a[x].sum=v[rk[l]];
a[x].l=a[x].r=l;
return ;
}
ll mid=l+r>>1;
a[x].ls=cnt++;a[x].rs=cnt++;
build(l,mid,a[x].ls);build(mid+1,r,a[x].rs);
a[x].l=a[a[x].ls].l,a[x].r=a[a[x].rs].r;
pushup(x);
}
inline ll len(ll x){
return a[x].r-a[x].l+1;
}
inline void pushdown(ll x){
if(a[x].tag!=0){
a[a[x].ls].tag+=a[x].tag;a[a[x].rs].tag+=a[x].tag;
a[a[x].ls].tag%=mod;a[a[x].rs].tag%=mod;
a[a[x].ls].sum+=a[x].tag*len(a[x].ls);a[a[x].rs].sum+=a[x].tag*len(a[x].rs);
a[a[x].ls].sum%=mod;a[a[x].rs].sum%=mod;
a[x].tag=0;
}
}
inline void update(ll l,ll r,ll c,ll x){
if(a[x].l>=l&&a[x].r<=r){
a[x].tag+=c;a[x].tag%=mod;
a[x].sum+=len(x)*c;a[x].sum%=mod;
return ;
}
pushdown(x);
ll mid=a[x].l+a[x].r>>1;
if(mid>=l) update(l,r,c,a[x].ls);
if(mid<r) update(l,r,c,a[x].rs);
pushup(x);
}
inline ll lca(ll x,ll y){
while(top[x]!=top[y]){
if(dep[top[x]]>=dep[top[y]])x=fa[top[x]];
else y=fa[top[y]];
}
return dep[x]<dep[y]?x:y;
}
inline ll query(ll l,ll r,ll x){
if(a[x].l>=l&&a[x].r<=r) return a[x].sum;
pushdown(x);
ll mid=a[x].l+a[x].r>>1,ss=0;
if(mid>=l) ss+=query(l,r,a[x].ls);ss%=mod;
if(mid<r) ss+=query(l,r,a[x].rs);ss%=mod;
return ss;
}
inline ll Query(ll x,ll y){
ll res=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
res+=query(id[top[x]],id[x],root);
res%=mod;
x=fa[top[x]];
}
if(id[x]>id[y]) swap(x,y);
res+=query(id[x],id[y],root);
res%=mod;
return res;
}
inline void Update(ll x,ll y,ll c){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(id[top[x]],id[x],c,root);
x=fa[top[x]];
}
if(id[x]>id[y]) swap(x,y);
update(id[x],id[y],c,root);
}
inline ll read(){
char ch=getchar();ll res=0,f=1;
while(ch<'0'ch>'9'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9') res=res*10+ch-'0',ch=getchar();
return res*f;
}
inline void write(ll x){
if(x<10) putchar(x+'0');
else{
write(x/10);
putchar(x%10+'0');
}
}
inline void add(ll x,ll y){
++tot;
son[tot]=y;
nxt[tot]=fir[x];
fir[x]=tot;
}
int main(){
n=read();m=read();r=read();mod=read();
for(ll i=1;i<=n;i++) v[i]=read();
for(ll x,y,i=1;i<n;i++){
x=read(),y=read();
add(x,y);add(y,x);
}
cnt=0;dfs1(r,0,1);
dfs2(r,r);
cnt=0;build(1,n,root=cnt++);
for(ll op,x,y,k,i=1;i<=m;i++){
op=read();
if(op==1){
x=read();y=read();k=read();
Update(x,y,k);
}else if(op==2){
x=read();y=read();
write(Query(x,y));putchar('\n');
}else if(op==3){
x=read();y=read();
update(id[x],id[x]+sz[x]-1,y,root);
}else if(op==4){
x=read();
write(query(id[x],id[x]+sz[x]-1,root));putchar('\n');
}
}
return 0;
}

完美撒花✿✿ヽ(°▽°)ノ✿