bzoj 3653 谈笑风生 题解

Description

题目链接

设T 为一棵有根树,我们做如下的定义:

设a和b为T 中的两个不同节点。如果a是b的祖先,那么称“a比b不知道高明到哪里去了”。

设a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定常数x,那么称“a 与b 谈笑风生”。

给定一棵n个节点的有根树T,节点的编号为1 到 n,根节点为1号节点。你需要回答q 个询问,询问给定两个整数p和k,问有多少个有序三元组(a;b;c)满足:

  1. a、b和 c为 T 中三个不同的点,且 a为p 号节点;
  2. a和b 都比 c不知道高明到哪里去了;
  3. a和b 谈笑风生。这里谈笑风生中的常数为给定的 k。

Solution

分类讨论+dfs序+主席树

若$b$在$a$上方,$ans=min( dep[a]-1,k )\times ( sz[a] -1 )$

若$b$在$a$下方,主席树最外层$dfs$序,内层套线段树求区间和,维护子树大小。在$dep[a]+1$到$dep[a]+k$范围内求和$sz[TT]$。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read(){
int res=0,f=1;char ch=getchar();
while(!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
while(isdigit(ch)) res=(res<<3)+(res<<1)+(ch&15),ch=getchar();
return res*f;
}
int n,q,g[300010],cnt,dep[300010],cnd,rt[300010];
LL sz[300010];
int fir[300010],nxt[300010<<1],son[300010<<1],tot;
pair<int,int> id[300010];
struct node{
int l,r;
LL sum;
}tr[300010<<5];
inline void add(int x,int y){++tot;nxt[tot]=fir[x];fir[x]=tot;son[tot]=y;}
inline void dfs(int x,int fa){
++cnt;
id[x].first=cnt;
dep[x]=dep[fa]+1;
sz[x]=1;
g[cnt]=x;
for(int to,i=fir[x];i;i=nxt[i]){
to=son[i];
if(to==fa) continue ;
dfs(to,x);
sz[x]+=sz[to];
}
id[x].second=cnt;
}
inline void insert(int &x,int l,int r,int pos,int v){
++cnd;
tr[cnd]=tr[x];
x=cnd;
tr[x].sum+=(LL)v;
if(l==r) return ;
int mid=l+r>>1;
if(pos<=mid) insert(tr[x].l,l,mid,pos,v);
else insert(tr[x].r,mid+1,r,pos,v);
}
inline LL query(int x,int y,int l,int r,int L,int R){
if(L<=l&&r<=R) return tr[y].sum-tr[x].sum;
int mid=l+r>>1;
LL res=0;
if(L<=mid) res+=query(tr[x].l,tr[y].l,l,mid,L,R);
if(R>mid) res+=query(tr[x].r,tr[y].r,mid+1,r,L,R);
return res;
}
int main(){
n=read(),q=read();
for(int u,v,i=1;i<n;i++){
u=read(),v=read();
add(u,v),add(v,u);
}
dfs(1,0);
for(int i=1;i<=n;i++) rt[i]=rt[i-1],insert(rt[i],1,n,dep[g[i]],sz[g[i]]-1);
for(int p,k,i=1;i<=q;i++){
p=read(),k=read();
LL ans=query(rt[id[p].first],rt[id[p].second],1,n,dep[p]+1,dep[p]+k);
ans+=(LL)min(k,dep[p]-1)*(sz[p]-1);
printf("%lld\n",ans);
}
}