题意:给出一棵带边权树,询问有多少点对的距离小于等于\(k\)
本题解参考 lyd 的算法竞赛进阶指南,讲解的十分清晰,比网上那些讲的乱七八糟的好多了
不过写起来还是困难重重(史诗巨作
打完多校更详细做法
对于所有路径,以某个节点 u 来看分为两种情况
1.经过 u 的路径
2.不经过 u 的路径
能对答案有贡献的肯定是 1 类型,对 2 我们处理完 1 后递归求解
于是条件变为
以 u 为根的树中,其联通块中的点对符合
距离 u 之和小于等于 k ①
且位于各不同的 u 的子树(u 单独处理) ②
不妨先维护①再维护②
每一次选定根 u 后预处理联通块内的点相对于 u 的距离 dis,以及归属于哪个子树 belong,
并且把它们放入数组 vec 中,清空存在的 cnt(cnt[t]表示归属于子树 t 的点的个数有多少个,其中 cnt[u]设为单独节点并不包含其它子树)
处理完后对 vec 数组的节点按 dis 排序,这样做的目的是为了\(O(n)\)处理贡献
贡献该怎么算?如果在 vec 中设两个指针 L,R,满足最小的 L 和最大的 R 符合 \(dis[vec[L]]+dis[vec[R]]≤k\),这就满足了①条件
那么②呢?显然是在统计完 \([L+1,R]\) 范围内的 \(cnt[belong[vec[i]]]\) 后求 \(R-(L+1)-1-cnt[belong[vec[L]]]\)
这个时候 vec 对 dis 排序就显得很有用,k 是恒定的,随着 L 指针的右移,R 指针也一定是左移,直到\(L=R\)就表明以后无论怎么移动都不会符合条件①了
在指针移动的过程中我们顺便完成了对 cnt 的统计更新,所以只需 L 从头到尾扫一遍,每个元素至多被指针遍历 4 次,整个操作是\(O(n)\)的 (虽然排序是 nlogn 的
由此我们完成了对 u 有关的第一个情况的所有路径统计
第二个情况只需把 u 删去 (block 标记,删去也意味着所有经过 u 的路径不会再遍历,因为不会再有任何贡献了) 递归接下来的子树即可
为了避免单链型的 \(O(n)\) 次递归,我们选择将每一个子树的重心作为根进行处理,以达到 \(O(logn)\) 的最坏情况
具体的先 dfs 一下更新子树大小得出联通块大小 V 再不断比对删除节点后的最大联通块的最小值就能找到了
总而言之我们在 \(O(nlog^2n)\) 的时间完成了传说中的点分治啦
八分之三的男人达成√
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<string>
#include<vector>
#include<stack>
#include<queue>
#include<set>
#include<map>
//#include<unordered_map>
#define rep(i,j,k) for(register int i=j;i<=k;i++)
#define rrep(i,j,k) for(register int i=j;i>=k;i--)
#define erep(i,u) for(register int i=head[u];~i;i=nxt[i])
#define print(a) printf("%lld",(ll)a)
#define println(a) printf("%lld\n",(ll)a)
using namespace std;
const int MAXN = 1e5+11;
const int INF = 0x3f3f3f3f;
const double EPS = 1e-7;
typedef long long ll;
ll read(){
ll x=0,f=1;register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n,k;
bool vis[MAXN],block[MAXN];
int dis[MAXN],belong[MAXN],cnt[MAXN],sz[MAXN];
int to[MAXN<<1],nxt[MAXN<<1],cost[MAXN<<1],head[MAXN],tot;
vector<int> vec;
void init(){
memset(head,-1,sizeof head);
memset(vis,0,sizeof vis);
memset(block,0,sizeof block);
tot=0;
}
void add(int u,int v,int w){
to[tot]=v;cost[tot]=w;
nxt[tot]=head[u];head[u]=tot++;
}
int mxsize,mxson,V;//V==N // each V of subtrees can obtain using sz
int getroot(int u,int fa){
sz[u]=1;
int tmp=0;//max size of subtrees while u deleted
erep(i,u){
int v=to[i];
if(v==fa) continue;
if(block[v]) continue;
getroot(v,u);
sz[u]+=sz[v];
tmp=max(tmp,sz[v]);
}
tmp=max(tmp,V-sz[u]);
if(tmp<mxsize){
mxsize=tmp;
mxson=u;// top and down compared
}
return mxson;
}
bool cmp(int a,int b){return dis[a]<dis[b];}
void prepare(int u,int fa,int d,int son,int rt){
dis[u]=d; vec.push_back(u);
belong[u]=son; cnt[son]=0;
erep(i,u){
int v=to[i],w=cost[i];
if(v==fa) continue;
if(block[v]) continue;
prepare(v,u,d+w,son==rt?v:son,rt);
}
}
int getV(int u,int fa){
sz[u]=1;
erep(i,u){
int v=to[i];
if(v==fa) continue;
if(block[v]) continue;
getV(v,u);
sz[u]+=sz[v];
}
return sz[u];
}
ll solve(int u,int fa){
if(vis[u]||u<1) return 0;
vis[u]=1;
ll ans=0;
vec.clear();
prepare(u,fa,0,u,u);
sort(vec.begin(),vec.end(),cmp);
int L,R=vec.size();R--;
for(int i=0;i<vec.size();i++) cnt[belong[vec[i]]]++;
bool flag=0;
for(int i=0;i+1<(int)vec.size();i++){//enum L
L=i;cnt[belong[vec[L]]]--;
while(dis[vec[L]]+dis[vec[R]]>k){
if(L>=R) break;
cnt[belong[vec[R]]]--;
R--;
}
if(L>=R) break;
ans+=(ll)R-L-cnt[belong[vec[L]]];
}
while(L<vec.size()) cnt[belong[vec[L++]]]=0;
block[u]=1;
erep(i,u){
int v=to[i];
if(v==fa) continue;
if(block[v]) continue;
V=getV(v,u);
mxson=v,mxsize=INF;
int rt=getroot(v,u);
ans+=solve(rt,u);
}
return ans;
}
int main(){
while(cin>>n>>k){
if(n==0&&k==0) break;
init();
rep(i,1,n-1){
int u=read();
int v=read();
int w=read();
add(u,v,w);
add(v,u,w);
}
mxsize=INF;mxson=1;V=n;
int rt=getroot(1,-1);
println(solve(rt,-1));
}
return 0;
}