树链剖分浅析

前言

前几天看了树剖有点忘了…

mmp树剖的元素太多了我还是用 lct 好了…
侵删


树链剖分

引言

树链剖分,顾名思义,就是把一棵树分割成很多条链
在实际操作中,我们经常需要把一棵树剖分成很多条链,如果要进行查询或修改路径,用线段树维护
那么我们如何实现呢?

剖分与线段树结合

树链剖分离不开剖分操作,而剖分操作是靠 dfs 完成的

我们先来看一下概念

  • 重儿子:节点 i 的儿子节点中拥有子节点最多的点
  • 轻儿子:节点 i 除去重儿子的其他儿子节点
  • 重边:节点 i 与它的重儿子连成的边
  • 轻边:节点 i 与它的轻儿子连成的边
  • 重链:由重边连成的路径
  • 轻链:由轻边连成的路径

一个点只有一个重儿子
注意儿子节点和子节点概念不同

还有以下我们设置一些数组的表示

  • size( i )表示根为 i 的点的子节点数目
  • son ( i )表示节点 i 的儿子节点中 size 最大的点
  • dep ( i ) 表示节点 i 的深度(一棵树总是有深度
  • fa ( i ) 表示节点 i 的父亲节点
  • top ( i )表示的是节点 i 所在链的顶端(在求LCA问题中用处很大

树剖就是把一棵树剖分成重链和轻链,来看这样一棵树

照以上的描述,标记为黑色粗边的为重链,每一条黑色粗边为重边
标记为红色的点为每一条链的 top ,细边的为轻链

以上是剖分的基本概念,如果要和线段树相结合,我们用

  • id( i ) 表示点 i 与 fa ( i ) 的连边在线段树上的编号
  • wt( i ) 表示这条连边的值

再次参考上面的图,每条边旁边的数字就是该边在线段树上的编号

那为什么我们要用线段树呢

线段树可以在 log 的级别询问修改区间,但是如果我们要在一棵树上进行路径的询问和修改,那么只凭线段树是不可行的,所以我们要在树上先进行树链剖分,然后能够满足线段树的操作
总的来说在树剖里用线段树,线段树的写法和一般的线段树是差不多的…我们维护的是拆分后的树链,对于一棵树的询问修改就变成了 log 级别的了

算法实现

我们先用一次 dfs 来求出各个点的 siz , fa , dep , son
这个实现其实很简单

1
2
3
4
5
6
7
8
9
10
11
inline void dfs1(int x,int f,int deep){
dep[x]=deep;fa[x]=f;siz[x]=1;
int maxson=-1;
for(Rint i=beg[x];i;i=nex[i]){
int y=to[i];
if(y==f)continue;
dfs1(y,x,deep+1);
siz[x]+=siz[y];
if(siz[y]>maxson)son[x]=y,maxson=siz[y];
}
}

我们用再做一次 dfs 来进行 id ,wt , top 的赋值

1
2
3
4
5
6
7
8
9
10
inline void dfs2(int x,int topf){ 
id[x]=++cnt;wt[cnt]=w[x];top[x]=topf;
if(!son[x])return;
dfs2(son[x],topf);
for(Rint i=beg[x];i;i=nex[i]){
int y=to[i];
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);
}
}

讲述也比较麻烦就直接看代码了也是很好理解的

树剖解LCA问题

现在我们先来看一下用树剖解LCA问题

询问两个点 x ,y ,我们要求他们的公共祖先。
对于每一条树链我们已经求出了这一条树链中深度最小的点,那么如果 top ( x )==top ( y ) ,那么 x 和 y 的最近公共祖先就是 min( x , y )
如果 top( x )! = top ( y ) ,我们就往上找祖先,取 top( x ) , top ( y ) 的最小值min,再判断 top ( fa ( min ) )与另一个点的 top 是不是相等,一直这样做,直到两点的 top 相等即可
我们来看一下代码实现

1
2
3
4
5
while(top[x]!=top[y]){
if(dep[top[x]]>=dep[top[y]]) x=fa[top[x]];
else y=fa[top[y]];
}
printf("%d\n",dep[x]<dep[y]?x:y);

在题目中应用,时间跑的应该不是很慢的,还是比较优越的…

树链在线段树中的维护

我们再来看一下树剖后在线段树中的应用

我们将树的各边的权值在线段树中维护更新

如果我们要进行修改操作,比如将 u 与 v 路径上的每条边的权值都加上 c ,我们可以求出 u 到 v 的 lca,然后慢慢修改 u , v 到公共祖先的边
但是我们有更加优越的算法
记f1 = top(u),f2 = top(v)。
当 f1 != f2 时:不妨设 dep(f1) >= dep(f2),那么就更新 u 到 f1 的父边的权值(logn),并使 u = fa(f1)
当 f1 == f2 时:u 与 v 在同一条重链上,若 u 与 v 不是同一点,就更新 u 到 v 路径上的边的权值(logn),否则修改完成
重复上述操作直至完成
在进行求和的操作时,类似修改操作,但是不更新边权

我们依旧用那张图进行演示

当要修改11到10的路径时。
第一次迭代:u = 11,v = 10,f1 = 2,f2 = 10。此时 dep(f1) < dep(f2),因此修改线段树中的5号点,v = 4, f2 = 1;
第二次迭代:dep(f1) > dep(f2),修改线段树中10–11号点。u = 2,f1 = 2;
第三次迭代:dep(f1) > dep(f2),修改线段树中9号点。u = 1,f1 = 1;
第四次迭代:f1 = f2 且 u = v,修改结束。

(其实我自己是不会画图的…逃

我们来看一下模板题


题目

题目描述

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

输入输出格式

输入格式:

第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。

接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)

接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:

操作1: 1 x y z

操作2: 2 x y

操作3: 3 x z

操作4: 4 x

输出格式:

输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模)

输入输出样例

输入样例#1:
5 5 2 24
7 3 7 8 0 
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3
输出样例#1:
2
21

说明

时空限制:1s,128M

数据规模:

对于30%的数据: N ≤ 10, M ≤ 10

对于70%的数据: N ≤ 103, M ≤ 103

对于100%的数据: N≤10 5 , M ≤ 105

解题思路

树剖后套上线段树…
写了那么多简析的我不想多说什么了…

AC代码352ms

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// luogu-judger-enable-o2
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#define Rint register int
#define mem(a,b) memset(a,(b),sizeof(a))
#define Temp template<typename T>

using namespace std;

typedef long long LL;
Temp inline void read(T &x){
x=0;T w=1,ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')w=-1,ch=getchar();
while(isdigit(ch))x=(x<<3)+(x<<1)+(ch^'0'),ch=getchar();
x=x*w;
}

#define mid ((l+r)>>1)
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define len (r-l+1)

const int maxn=200000+10;
int n,m,r,mod;
int e,beg[maxn],nex[maxn],to[maxn],w[maxn],wt[maxn];
int a[maxn<<2],laz[maxn<<2];
int son[maxn],id[maxn],fa[maxn],cnt,dep[maxn],siz[maxn],top[maxn];
int res=0;

inline void add(int x,int y){to[++e]=y;nex[e]=beg[x];beg[x]=e;}

inline void pushdown(int rt,int lenn){
laz[rt<<1]+=laz[rt];laz[rt<<1|1]+=laz[rt];
a[rt<<1]+=laz[rt]*(lenn-(lenn>>1));
a[rt<<1|1]+=laz[rt]*(lenn>>1);
a[rt<<1]%=mod;a[rt<<1|1]%=mod;
laz[rt]=0;
}

inline void build(int rt,int l,int r){
if(l==r){
a[rt]=wt[l];
if(a[rt]>mod)a[rt]%=mod;
return;
}
build(lson);build(rson);
a[rt]=(a[rt<<1]+a[rt<<1|1])%mod;
}

inline void query(int rt,int l,int r,int L,int R){
if(L<=l&&r<=R){res+=a[rt];res%=mod;return;}
else{
if(laz[rt])pushdown(rt,len);
if(L<=mid)query(lson,L,R);
if(R>mid)query(rson,L,R);
}
}

inline void update(int rt,int l,int r,int L,int R,int k){
if(L<=l&&r<=R){
laz[rt]+=k;
a[rt]+=k*len;
}
else{
if(laz[rt])pushdown(rt,len);
if(L<=mid)update(lson,L,R,k);
if(R>mid)update(rson,L,R,k);
a[rt]=(a[rt<<1]+a[rt<<1|1])%mod;
}
}

inline int qRange(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=0;
query(1,1,n,id[top[x]],id[x]);
ans+=res;ans%=mod;
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
res=0;
query(1,1,n,id[x],id[y]);
ans+=res;
return ans%mod;
}

inline void updRange(int x,int y,int k){
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,1,n,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
update(1,1,n,id[x],id[y],k);
}

inline int qSon(int x){
res=0;
query(1,1,n,id[x],id[x]+siz[x]-1);
return res;
}

inline void updSon(int x,int k){
update(1,1,n,id[x],id[x]+siz[x]-1,k);
}

inline void dfs1(int x,int f,int deep){
dep[x]=deep;fa[x]=f;siz[x]=1;
int maxson=-1;
for(Rint i=beg[x];i;i=nex[i]){
int y=to[i];
if(y==f)continue;
dfs1(y,x,deep+1);
siz[x]+=siz[y];
if(siz[y]>maxson)son[x]=y,maxson=siz[y];
}
}

inline void dfs2(int x,int topf){
id[x]=++cnt;wt[cnt]=w[x];top[x]=topf;
if(!son[x])return;
dfs2(son[x],topf);
for(Rint i=beg[x];i;i=nex[i]){
int y=to[i];
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);
}
}

int main(){
read(n);read(m);read(r);read(mod);
for(Rint i=1;i<=n;i++)read(w[i]);
for(Rint i=1;i<n;i++){
int a,b;
read(a);read(b);
add(a,b);add(b,a);
}
dfs1(r,0,1);dfs2(r,r);build(1,1,n);
while(m--){
int k,x,y,z;
read(k);
if(k==1){read(x);read(y);read(z);updRange(x,y,z);}
else if(k==2){read(x);read(y);printf("%d\n",qRange(x,y));}
else if(k==3){read(x);read(y);updSon(x,y);}
else{read(x);printf("%d\n",qSon(x));}
}
}

后记

代码其实不是很长(嘿嘿嘿光是模板就一百五十行?!
后面一段接近口糊我不是很好意思…(逃
要不是做题目的时候有点生疏了,我也不会把写这篇博客提上日程…

欢迎指正联系方式如下…

qq: 953559040

微博: IncinblePan

洛谷: SherlockPan