题意:给出一棵带边权树,询问有多少点对的距离小于等于\(k\)

本题解参考 lyd 的算法竞赛进阶指南,讲解的十分清晰,比网上那些讲的乱七八糟的好多了

不过写起来还是困难重重(史诗巨作

打完多校更详细做法

对于所有路径,以某个节点 u 来看分为两种情况

1.经过 u 的路径

2.不经过 u 的路径

能对答案有贡献的肯定是 1 类型,对 2 我们处理完 1 后递归求解

于是条件变为

以 u 为根的树中,其联通块中的点对符合

距离 u 之和小于等于 k ①

且位于各不同的 u 的子树(u 单独处理) ②

不妨先维护①再维护②

每一次选定根 u 后预处理联通块内的点相对于 u 的距离 dis,以及归属于哪个子树 belong,

并且把它们放入数组 vec 中,清空存在的 cnt(cnt[t]表示归属于子树 t 的点的个数有多少个,其中 cnt[u]设为单独节点并不包含其它子树)

处理完后对 vec 数组的节点按 dis 排序,这样做的目的是为了\(O(n)\)处理贡献

贡献该怎么算?如果在 vec 中设两个指针 L,R,满足最小的 L 和最大的 R 符合 \(dis[vec[L]]+dis[vec[R]]≤k\),这就满足了①条件

那么②呢?显然是在统计完 \([L+1,R]\) 范围内的 \(cnt[belong[vec[i]]]\) 后求 \(R-(L+1)-1-cnt[belong[vec[L]]]\)

这个时候 vec 对 dis 排序就显得很有用,k 是恒定的,随着 L 指针的右移,R 指针也一定是左移,直到\(L=R\)就表明以后无论怎么移动都不会符合条件①了

在指针移动的过程中我们顺便完成了对 cnt 的统计更新,所以只需 L 从头到尾扫一遍,每个元素至多被指针遍历 4 次,整个操作是\(O(n)\)的 (虽然排序是 nlogn 的

由此我们完成了对 u 有关的第一个情况的所有路径统计

第二个情况只需把 u 删去 (block 标记,删去也意味着所有经过 u 的路径不会再遍历,因为不会再有任何贡献了) 递归接下来的子树即可

为了避免单链型的 \(O(n)\) 次递归,我们选择将每一个子树的重心作为根进行处理,以达到 \(O(logn)\) 的最坏情况

具体的先 dfs 一下更新子树大小得出联通块大小 V 再不断比对删除节点后的最大联通块的最小值就能找到了

总而言之我们在 \(O(nlog^2n)\) 的时间完成了传说中的点分治啦

八分之三的男人达成√

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<string>
#include<vector>
#include<stack>
#include<queue>
#include<set>
#include<map>
//#include<unordered_map>
#define rep(i,j,k) for(register int i=j;i<=k;i++)
#define rrep(i,j,k) for(register int i=j;i>=k;i--)
#define erep(i,u) for(register int i=head[u];~i;i=nxt[i])
#define print(a) printf("%lld",(ll)a)
#define println(a) printf("%lld\n",(ll)a)
using namespace std;
const int MAXN = 1e5+11;
const int INF = 0x3f3f3f3f;
const double EPS = 1e-7;
typedef long long ll;
ll read(){
    ll x=0,f=1;register char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int n,k;
bool vis[MAXN],block[MAXN];
int dis[MAXN],belong[MAXN],cnt[MAXN],sz[MAXN];
int to[MAXN<<1],nxt[MAXN<<1],cost[MAXN<<1],head[MAXN],tot;
vector<int> vec;
void init(){
    memset(head,-1,sizeof head);
    memset(vis,0,sizeof vis);
    memset(block,0,sizeof block);
    tot=0;
}
void add(int u,int v,int w){
    to[tot]=v;cost[tot]=w;
    nxt[tot]=head[u];head[u]=tot++;
}
int mxsize,mxson,V;//V==N // each V of subtrees can obtain using sz
int getroot(int u,int fa){
    sz[u]=1;
    int tmp=0;//max size of subtrees while u deleted
    erep(i,u){
        int v=to[i];
        if(v==fa) continue;
        if(block[v]) continue;
        getroot(v,u);
        sz[u]+=sz[v];
        tmp=max(tmp,sz[v]);
    }
    tmp=max(tmp,V-sz[u]);
    if(tmp<mxsize){
        mxsize=tmp;
        mxson=u;// top and down compared
    }
    return mxson;
}
bool cmp(int a,int b){return dis[a]<dis[b];}
void prepare(int u,int fa,int d,int son,int rt){
    dis[u]=d; vec.push_back(u);
    belong[u]=son; cnt[son]=0;
    erep(i,u){
        int v=to[i],w=cost[i];
        if(v==fa) continue;
        if(block[v]) continue;
        prepare(v,u,d+w,son==rt?v:son,rt);
    }
}
int getV(int u,int fa){
    sz[u]=1;
    erep(i,u){
        int v=to[i];
        if(v==fa) continue;
        if(block[v]) continue;
        getV(v,u);
        sz[u]+=sz[v];
    }
    return sz[u];
}

ll solve(int u,int fa){
    if(vis[u]||u<1) return 0;
    vis[u]=1;
    
    ll ans=0;
    vec.clear();
    prepare(u,fa,0,u,u);
    sort(vec.begin(),vec.end(),cmp);
    int L,R=vec.size();R--;
    for(int i=0;i<vec.size();i++) cnt[belong[vec[i]]]++;
    bool flag=0;
    for(int i=0;i+1<(int)vec.size();i++){//enum L
        L=i;cnt[belong[vec[L]]]--;
        while(dis[vec[L]]+dis[vec[R]]>k){
            if(L>=R) break;
            cnt[belong[vec[R]]]--;
            R--;
        }
        if(L>=R) break;
        ans+=(ll)R-L-cnt[belong[vec[L]]];
    }
    while(L<vec.size()) cnt[belong[vec[L++]]]=0;
    block[u]=1;
    erep(i,u){
        int v=to[i];
        if(v==fa) continue;
        if(block[v]) continue;
        V=getV(v,u);
        mxson=v,mxsize=INF;
        int rt=getroot(v,u);
        ans+=solve(rt,u);
    }
    return ans;
}
int main(){
    while(cin>>n>>k){
        if(n==0&&k==0) break;
        init();
        rep(i,1,n-1){
            int u=read();
            int v=read();
            int w=read();
            add(u,v,w);
            add(v,u,w);
        }
        mxsize=INF;mxson=1;V=n;
        int rt=getroot(1,-1);
        println(solve(rt,-1));
    }
    return 0;
}