珂朵莉树 学习笔记

什么是珂朵莉树?

珂朵莉树,又称$ODT(Old Driver Tree)$,是一个基于$std::set$的暴力、玄学数据结构。

什么时候使用珂朵莉树?

如果有一题涉及到区间赋值(即把区间内所有的数全部赋值成同一个量)的操作,且数据随机,就可以考虑使用珂朵莉树。

下面以一道例题CF896来详解珂朵莉树。

Description

给你$n$个数,要求进行$m$次操作。

  1. 区间加
  2. 区间赋值
  3. 求区间第$k$小
  4. 求区间每个数的幂的和

保证数据随机

Solution

珂朵莉树的板子

首先定义珂朵莉树

1
2
3
4
5
6
7
8
#define It set<node>::iterator//由于set的指针定义实在是太长了,这里先define下
struct node{
int l,r;//每一个区间的范围
mutable LL val;//该区间内所有的数字都是val
node(int L,int R=-1,LL V=0):l(L),r(R),val(V){}//生成一个node
bool operator<(const node& q)const{return l<q.l;}//set按照l红黑树
};
set<node> s;

不知道$mutable$?点击此了解详情

珂朵莉树操作

核心操作——split

如果要修改某一个区间,那么肯定要把区间拆出来。

1
2
3
4
5
6
7
8
9
I It split(int x){
It it=s.lower_bound(node(x));//寻找第一个不小于x的点
if(it!=s.end()&&it->l==x) return it;//如果该区间刚好以x作为左端点,那么直接返回
it--;//否则一定在前一个
int L=it->l,R=it->r;LL V=it->val;//[L,R]区间值为V
s.erase(it);//删除
s.insert(node(L,x-1,V));//插入左边的区间
return s.insert(node(x,R,V)).first;//返回有右边需要的区间
}
推平操作——assign
1
2
3
4
5
I void assign(int l,int r,LL val){
It it2=split(r+1),it1=split(l);//求出区间指针
s.erase(it1,it2);//全部删除
s.insert(node(l,r,val));//新建一个推平区间
}

珂朵莉树是靠推平操作来减小复杂度的。由于数据随机,就有约$\frac{1}{4}$的操作是推平操作,使得$set$大小飞速下降,从而保证了复杂度。

那么为啥要先$split(r+1)$呢?因为如果先$split(l)$根据$split$中的$erase$操作,迭代器$it1$可能会失效。(因为$it1$所属的节点可能被删除了)

区间加
1
2
3
4
I void add(int l,int r,LL val){
It it2=split(r+1),it1=split(l);//求出区间指针
for(;it1!=it2;it1++) it1->val+=val;//暴力扫一次直接加
}
求区间第k小
1
2
3
4
5
6
7
8
9
10
I LL grank(int l,int r,int x){
It it2=split(r+1),it1=split(l);//求出区间指针
vector<pair<LL,int> > tmp;tmp.clear();//临时定义个vector
for(;it1!=it2;it1++) tmp.push_back(make_pair(it1->val,it1->r-it1->l+1));//将区间内所有数字插入vector
sort(tmp.begin(),tmp.end());//将所有数字排序
for(vector<pair<LL,int> >::iterator it=tmp.begin();it!=tmp.end();it++){//扫一次找第k小
x-=it->second;//因为区间上有r-l+1个相同的数,所以一次减去r-l+1
if(x<=0) return it->first;
}
}
求区间每个数的幂的和
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
inline LL fpow(LL a,LL b,LL mod){//快速幂
LL s=1;a%=mod;
while(b){
if(b&1) s*=a,s%=mod;
a*=a,a%=mod;
b>>=1;
}
return s;
}
I LL sum(int l,int r,LL x,LL y){
It it2=split(r+1),it1=split(l);//求出区间指针
LL s=0;
for(;it1!=it2;it1++) s+=(LL)(it1->r-it1->l+1)*fpow(it1->val,x,y)%y,s%=y;//暴力扫每个数求和,别忘记一个区间应该乘上r-l+1
return s;
}

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include<bits/stdc++.h>
#define I inline
#define LL long long
#define It set<node>::iterator
using namespace std;
inline LL read(){
LL res=0,f=1;char ch=getchar();
while(ch<'0'ch>'9'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9') res=res*10+ch-'0',ch=getchar();
return res*f;
}
inline void write(LL x){
if(x<0) putchar('-'),x=-x;
if(x<10) putchar(x+'0');
else{
write(x/10);
putchar(x%10+'0');
}
}
struct node{
int l,r;
mutable LL val;
node(int L,int R=-1,LL V=0):l(L),r(R),val(V){}
bool operator<(const node& q)const{return l<q.l;}
};
set<node> s;
inline LL fpow(LL a,LL b,LL mod){
LL s=1;a%=mod;
while(b){
if(b&1) s*=a,s%=mod;
a*=a,a%=mod;
b>>=1;
}
return s;
}
I It split(int x){
It it=s.lower_bound(node(x));
if(it!=s.end()&&it->l==x) return it;
it--;
int L=it->l,R=it->r;LL V=it->val;
s.erase(it);
s.insert(node(L,x-1,V));
return s.insert(node(x,R,V)).first;
}
I void assign(int l,int r,LL val){
It it2=split(r+1),it1=split(l);
s.erase(it1,it2);
s.insert(node(l,r,val));
}
I void add(int l,int r,LL val){
It it2=split(r+1),it1=split(l);
for(;it1!=it2;it1++) it1->val+=val;
}
I LL grank(int l,int r,int x){
It it2=split(r+1),it1=split(l);
vector<pair<LL,int> > tmp;tmp.clear();
for(;it1!=it2;it1++) tmp.push_back(make_pair(it1->val,it1->r-it1->l+1));
sort(tmp.begin(),tmp.end());
for(vector<pair<LL,int> >::iterator it=tmp.begin();it!=tmp.end();it++){
x-=it->second;
if(x<=0) return it->first;
}
}
I LL sum(int l,int r,LL x,LL y){
It it2=split(r+1),it1=split(l);
LL s=0;
for(;it1!=it2;it1++) s+=(LL)(it1->r-it1->l+1)*fpow(it1->val,x,y)%y,s%=y;
return s;
}
int n,m;
LL seed,vmax,a[100010];
I LL rnd(){
LL res=seed;
seed=(seed*7+13)%1000000007;
return res;
}
int main(){
n=read(),m=read(),seed=read(),vmax=read();
for(int i=1;i<=n;i++){
a[i]=rnd()%vmax+1;
s.insert(node(i,i,a[i]));
}
s.insert(node(n+1,n+1,0));
for(int i=1;i<=m;i++){
LL op,l,r,x,y;
op=rnd()%4+1,l=rnd()%n+1,r=rnd()%n+1;
if(l>r) swap(l,r);
if(op==3) x=rnd()%(r-l+1)+1;
else x=rnd()%vmax+1;
if(op==4) y=rnd()%vmax+1;
if(op==1) add(l,r,x);
else if(op==2) assign(l,r,x);
else if(op==3) write(grank(l,r,x)),putchar('\n');
else if(op==4) write(sum(l,r,x,y)),putchar('\n');
}
return 0;
}