树链剖分

求解问题:在树上进行区间修改区间查询问题,求lca问题,维护路径信息

主要思想:将树上的点分割成一条一条的链,每一条链的第一个点是链头(父亲),利用dfs序按照链优先的思想加上序号,这样每一条链上面的序号都是连续的,就把树上的点映射到了一条数轴上

时复:找到父亲,重子节点,子树大小(dfs1)O(N),进行一次dfs序(dfs2 O(N)),每一条路径都能被分割成最多log2n条链,因此链头数量不超过log2n,每次求lca时复logn

重链剖分

dfn数组: 每一个点映射到链上的标号

rnk数组: 每一个标号对应点的编号 (rnk[dfn[x]]=x)

dep数组: 每一个节点的深度

fa数组: 节点父亲

siz数组: 子树大小

son数组: 重孩子

top数组: 节点在链上链头的编号

以上7个数组是树链剖分的几个必要数组,根据题目不同会使用上面的某几个数组

定义:

  1. 重子节点是子树节点最多的那棵树的根节点,如果有多个随意取出一个即可
  2. 剩下的非重子节点的点都是轻子节点
  3. 从当前节点到重子节点的边是重边
  4. 从当前节点到轻子节点的边是轻边
  5. 若干条首尾相连的重边称为重链,所有落单的点也当作重链

HLD

利用以上定义可以将一棵树分成若干条链,这些链上的dfs序号是连续的

实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
void dfs1(int x){
siz[x]=1; //当前子树大小为1
son[x]=-1;
for(int i=head[x];~i;i=e[i].next){
int v=e[i].to;
if(v==fa[x]) continue;
fa[v]=x; //深搜之前一定要更新父亲
dep[v]=dep[x]+1; //往下深搜之前一定要先把深度给更新了
dfs1(v);
siz[x]=siz[v]+1; //深搜过后才能更新子树大小
if(son[x]==-1 || siz[son[x]]<siz[v]) son[x]=v; //找到重子节点
}
}
void dfs2(int x,int t){
top[x]=t; //这条链的链头
cnt++; //dfs序加一
dfn[x]=cnt; //给节点标上dfs序
rnk[cnt]=x; //返回dfs序对应的节点
if(son[x]==-1) return ;
dfs2(son[x],t); //优先遍历重子节点,到重子节点的边都是重链,所以链头不变
for(int i=head[x];~i;i=e[i].next){
int v=e[i].to;
if(v==fa[x] || v==son[x]) continue;
dfs2(v,v); //到轻子节点,链头要变化
}
}

树上单点修改区间查询

求两个点的区间值时,在找lca的过程中保存信息,最后输出答案即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#include <bits/stdc++.h>
//#pragma G++ optimize(2)
//#pragma G++ optimize(3,"Ofast","inline")
#define endl '\n'
#define debug freopen("in.txt","r",stdin); freopen("out.txt","w",stdout)
#define ios ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int MAXN=1e6+100;
const int MOD=1e9+7;
const int INF=0x3f3f3f3f;
const int SUB=-0x3f3f3f3f;
const double eps=1e-4;
const double E=exp(1);
const double pi=acos(-1);
struct edge{
int to,next;
}e[MAXN];
int head[MAXN],w[MAXN];
int tot;
void add(int u,int v){
e[tot]={v,head[u]};
head[u]=tot++;
}
int n,a,b,q;
int dfn[MAXN],siz[MAXN],dep[MAXN],fa[MAXN];
int son[MAXN],top[MAXN],rnk[MAXN];
int cnt;
void dfs(int x,int las){ //求出dep、son、siz、fa
son[x]=-1;
siz[x]=1;
for(int i=head[x];~i;i=e[i].next){
int v=e[i].to;
if(v==las) continue;
fa[v]=x;
dep[v]=dep[x]+1;
dfs(v,x);
siz[x]+=siz[v];
if(son[x]==-1 || siz[v]>siz[son[x]]) son[x]=v;
}
}
void dfs2(int x,int t){ //求出dfn、rnk、top
++cnt;
dfn[x]=cnt;
rnk[cnt]=x;
top[x]=t;
if(son[x]!=-1) dfs2(son[x],t);
for(int i=head[x];~i;i=e[i].next){
int v=e[i].to;
if(v==fa[x] || v==son[x]) continue;
dfs2(v,v);
}
}
#define lson (u<<1)
#define rson (u<<1|1)
struct node{
int l,r,mx,sum;
}tr[MAXN<<2];
void pushup(int u){
tr[u].mx=max(tr[lson].mx,tr[rson].mx);
tr[u].sum=tr[lson].sum+tr[rson].sum;
}
void build(int u,int l,int r){
if(l==r){
tr[u]={l,r,w[rnk[l]],w[rnk[l]]};
return ;
}
tr[u]={l,r};
int mid=l+r>>1;
build(lson,l,mid);
build(rson,mid+1,r);
pushup(u);
}
void update(int u,int pos,int val){
if(tr[u].l==tr[u].r){
tr[u].mx=val;
tr[u].sum=val;
return ;
}
int mid=tr[u].l+tr[u].r>>1;
if(pos<=mid) update(lson,pos,val);
else update(rson,pos,val);
pushup(u);
}
int qmax(int u,int ql,int qr){
if(tr[u].l>=ql && tr[u].r<=qr) return tr[u].mx;
int ret=-INF;
int mid=tr[u].l+tr[u].r>>1;
if(ql<=mid) ret=max(ret,qmax(lson,ql,qr));
if(mid+1<=qr) ret=max(ret,qmax(rson,ql,qr));
return ret;
}
int query_mx(int u,int x,int y){
int l=x,r=y;
int ret=-INF;
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]){
ret=max(ret,qmax(1,dfn[top[x]],dfn[x]));
x=fa[top[x]];
}
else{
ret=max(ret,qmax(1,dfn[top[y]],dfn[y]));
y=fa[top[y]];
}
}
if(dfn[x]>dfn[y]) ret=max(ret,qmax(1,dfn[y],dfn[x]));
else ret=max(ret,qmax(1,dfn[x],dfn[y]));
return ret;
}
int qsum(int u,int ql,int qr){
if(tr[u].l>=ql && tr[u].r<=qr) return tr[u].sum;
int ret=0;
int mid=tr[u].l+tr[u].r>>1;
if(ql<=mid) ret+=qsum(lson,ql,qr);
if(qr>=mid+1) ret+=qsum(rson,ql,qr);
return ret;
}
int query_sum(int u,int x,int y){
int ret=0;
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]){
ret+=qsum(1,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
else{
ret+=qsum(1,dfn[top[y]],dfn[y]);
y=fa[top[y]];
}
}
if(dfn[x]>dfn[y]) ret+=qsum(1,dfn[y],dfn[x]);
else ret+=qsum(1,dfn[x],dfn[y]);
return ret;
}
int main(){
// debug;
ios;
memset(head,-1,sizeof head);
cin>>n;
for(int i=1;i<=n-1;i++){
cin>>a>>b;
add(a,b);
add(b,a);
}
dfs(1,-1);
dfs2(1,1);
for(int i=1;i<=n;i++) cin>>w[i];
build(1,1,n);
// cout<<tr[1].mx<<endl;
// cout<<qmax(1,1,4)<<endl;
// for(int i=1;i<=n;i++) cout<<dfn[i]<<" ";
// cout<<'\n';
// for(int i=1;i<=n;i++) cout<<son[i]<<" ";
// cout<<endl;
// for(int i=1;i<=n;i++) cout<<dep[i]<<" ";
// cout<<endl;
// for(int i=1;i<=n;i++) cout<<rnk[i]<<" ";
// cout<<endl;
cin>>q;
while(q--){
string s;
int u,v;
cin>>s>>u>>v;
if(s=="CHANGE") update(1,dfn[u],v);
else if(s=="QMAX") cout<<query_mx(1,u,v)<<endl;
else cout<<query_sum(1,u,v)<<endl;
// if(u==3 && v==6) cout<<qmax(1,3,4)<<"ss"<<endl;
}
return 0;
}

/*

*/