# 树链剖分

求解问题:在树上进行区间修改区间查询问题,求 lca 问题,维护路径信息

主要思想:将树上的点分割成一条一条的链,每一条链的第一个点是链头 (父亲),利用 dfs 序按照链优先的思想加上序号,这样每一条链上面的序号都是连续的,就把树上的点映射到了一条数轴上

时复:找到父亲,重子节点,子树大小(dfs1)O (N),进行一次 dfs 序 (dfs2 O (N)),每一条路径都能被分割成最多 log2n 条链,因此链头数量不超过 log2n,每次求 lca 时复 logn

# 重链剖分

dfn 数组:每一个点映射到链上的标号

rnk 数组:每一个标号对应点的编号 (rnk [dfn [x]]=x)

dep 数组:每一个节点的深度

fa 数组:节点父亲

siz 数组:子树大小

son 数组:重孩子

top 数组:节点在链上链头的编号

以上 7 个数组是树链剖分的几个必要数组,根据题目不同会使用上面的某几个数组

定义:

  1. 重子节点是子树节点最多的那棵树的根节点,如果有多个随意取出一个即可
  2. 剩下的非重子节点的点都是轻子节点
  3. 从当前节点到重子节点的边是重边
  4. 从当前节点到轻子节点的边是轻边
  5. 若干条首尾相连的重边称为重链,所有落单的点也当作重链
HLD

利用以上定义可以将一棵树分成若干条链,这些链上的 dfs 序号是连续的

# 实现

void dfs1(int x){
	siz[x]=1;  // 当前子树大小为 1
    son[x]=-1;  
    for(int i=head[x];~i;i=e[i].next){
        int v=e[i].to;
        if(v==fa[x]) continue;
        fa[v]=x;  // 深搜之前一定要更新父亲
        dep[v]=dep[x]+1;  // 往下深搜之前一定要先把深度给更新了
        dfs1(v);
        siz[x]=siz[v]+1;  // 深搜过后才能更新子树大小
        if(son[x]==-1 || siz[son[x]]<siz[v]) son[x]=v;  // 找到重子节点
    }
}
void dfs2(int x,int t){
	top[x]=t;  // 这条链的链头
    cnt++;  //dfs 序加一
    dfn[x]=cnt;  // 给节点标上 dfs 序
    rnk[cnt]=x;  // 返回 dfs 序对应的节点
    if(son[x]==-1) return ;
    dfs2(son[x],t);  // 优先遍历重子节点,到重子节点的边都是重链,所以链头不变
    for(int i=head[x];~i;i=e[i].next){
		int v=e[i].to;
        if(v==fa[x] || v==son[x]) continue;
        dfs2(v,v);  // 到轻子节点,链头要变化
    }
}

# 树上单点修改区间查询

求两个点的区间值时,在找 lca 的过程中保存信息,最后输出答案即可

#include <bits/stdc++.h>
//#pragma G++ optimize(2)
//#pragma G++ optimize(3,"Ofast","inline")
#define endl '\n'
#define debug freopen("in.txt","r",stdin); freopen("out.txt","w",stdout)
#define ios ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int MAXN=1e6+100;
const int MOD=1e9+7;
const int INF=0x3f3f3f3f;
const int SUB=-0x3f3f3f3f;
const double eps=1e-4;
const double E=exp(1);
const double pi=acos(-1);
struct edge{
	int to,next;
}e[MAXN];
int head[MAXN],w[MAXN];
int tot;
void add(int u,int v){
	e[tot]={v,head[u]};
	head[u]=tot++;
}
int n,a,b,q;
int dfn[MAXN],siz[MAXN],dep[MAXN],fa[MAXN];
int son[MAXN],top[MAXN],rnk[MAXN];
int cnt;
void dfs(int x,int las){ // 求出 dep、son、siz、fa
	son[x]=-1;
	siz[x]=1;
	for(int i=head[x];~i;i=e[i].next){
		int v=e[i].to;
		if(v==las) continue;
		fa[v]=x;
		dep[v]=dep[x]+1;
		dfs(v,x);
		siz[x]+=siz[v];
		if(son[x]==-1 || siz[v]>siz[son[x]]) son[x]=v;
	}
}
void dfs2(int x,int t){ // 求出 dfn、rnk、top
	++cnt;
	dfn[x]=cnt;
	rnk[cnt]=x;
	top[x]=t;
	if(son[x]!=-1) dfs2(son[x],t);
	for(int i=head[x];~i;i=e[i].next){
		int v=e[i].to;
		if(v==fa[x] || v==son[x]) continue;
		dfs2(v,v);
	}
}
#define lson (u<<1)
#define rson (u<<1|1)
struct node{
	int l,r,mx,sum;
}tr[MAXN<<2];
void pushup(int u){
	tr[u].mx=max(tr[lson].mx,tr[rson].mx);
	tr[u].sum=tr[lson].sum+tr[rson].sum;
}
void build(int u,int l,int r){
	if(l==r){
		tr[u]={l,r,w[rnk[l]],w[rnk[l]]};
		return ;
	}
	tr[u]={l,r};
	int mid=l+r>>1;
	build(lson,l,mid);
	build(rson,mid+1,r);
	pushup(u);
}
void update(int u,int pos,int val){
	if(tr[u].l==tr[u].r){
		tr[u].mx=val;
		tr[u].sum=val;
		return ;
	}
	int mid=tr[u].l+tr[u].r>>1;
	if(pos<=mid) update(lson,pos,val);
	else update(rson,pos,val);
	pushup(u);
}
int qmax(int u,int ql,int qr){
	if(tr[u].l>=ql && tr[u].r<=qr) return tr[u].mx;
	int ret=-INF;
	int mid=tr[u].l+tr[u].r>>1;
	if(ql<=mid) ret=max(ret,qmax(lson,ql,qr));
	if(mid+1<=qr) ret=max(ret,qmax(rson,ql,qr));
	return ret;
}
int query_mx(int u,int x,int y){
	int l=x,r=y;
	int ret=-INF;
	while(top[x]!=top[y]){
		if(dep[top[x]]>dep[top[y]]){
			ret=max(ret,qmax(1,dfn[top[x]],dfn[x]));
			x=fa[top[x]];
		}
		else{
			ret=max(ret,qmax(1,dfn[top[y]],dfn[y]));
			y=fa[top[y]];
		}
	}
	if(dfn[x]>dfn[y]) ret=max(ret,qmax(1,dfn[y],dfn[x]));
	else ret=max(ret,qmax(1,dfn[x],dfn[y]));
	return ret;
}
int qsum(int u,int ql,int qr){
	if(tr[u].l>=ql && tr[u].r<=qr) return tr[u].sum;
	int ret=0;
	int mid=tr[u].l+tr[u].r>>1; 
	if(ql<=mid) ret+=qsum(lson,ql,qr);
	if(qr>=mid+1) ret+=qsum(rson,ql,qr);
	return ret;
}
int query_sum(int u,int x,int y){
	int ret=0;
	while(top[x]!=top[y]){
		if(dep[top[x]]>dep[top[y]]){
			ret+=qsum(1,dfn[top[x]],dfn[x]);
			x=fa[top[x]];
		}
		else{
			ret+=qsum(1,dfn[top[y]],dfn[y]);
			y=fa[top[y]];
		}
	}
	if(dfn[x]>dfn[y]) ret+=qsum(1,dfn[y],dfn[x]);
	else ret+=qsum(1,dfn[x],dfn[y]);
	return ret;
}
int main(){
	// debug;
	ios;
	memset(head,-1,sizeof head);
	cin>>n;
	for(int i=1;i<=n-1;i++){
		cin>>a>>b;
		add(a,b);
		add(b,a);
	}
	dfs(1,-1);
	dfs2(1,1);
	for(int i=1;i<=n;i++) cin>>w[i];
	build(1,1,n);
	// cout<<tr[1].mx<<endl;
	// cout<<qmax(1,1,4)<<endl;
	// for(int i=1;i<=n;i++) cout<<dfn[i]<<" ";
	// cout<<'\n';
	// for(int i=1;i<=n;i++) cout<<son[i]<<" ";
	// cout<<endl;
	// for(int i=1;i<=n;i++) cout<<dep[i]<<" ";
	// cout<<endl;
	// for(int i=1;i<=n;i++) cout<<rnk[i]<<" ";
	// cout<<endl;
	cin>>q;
	while(q--){
		string s;
		int u,v;
		cin>>s>>u>>v;
		if(s=="CHANGE") update(1,dfn[u],v);
		else if(s=="QMAX") cout<<query_mx(1,u,v)<<endl;
		else cout<<query_sum(1,u,v)<<endl;
		// if(u==3 && v==6) cout<<qmax(1,3,4)<<"ss"<<endl;
	}
	return 0;
}
/*
*/
更新于

请我喝[茶]~( ̄▽ ̄)~*

PocketCat 微信支付

微信支付

PocketCat 支付宝

支付宝