CCF CSP 202112-5 极差路径

题意:给你一棵树,定义一条路径(x,y)被推荐的,当且仅当

\displaystyle \min (x,y) - k_1 \leq \min P(x,y) \leq \max P(x,y) \leq \max (x,y) + k_2

\min P(x,y) 表示从xy的简单路径上的编号最小值,\max P(x,y)同理。求被推荐的路径条数,(x,y)(y,x)被视为同一条路径。

容易想到这种路径计数还带\min \max的无非就三种做法:并查集、树上DP/按秩合并、点分治。前面两种做法局限性比较强,而这里的条件比较多,难以维护,于是我们考虑直接用点分治来做。假如我们目前处理的根节点为u,我们需要统计所有经过u的合法路径。从u开始dfs一遍之后我们得到了从u出发到各个点的路径的\min \max,我们用三元组去表示,记作(v,minv,maxv)。然后合法路径的条数就是满足v_1 < v_2, v_1 - k_1 \leq min(minv_1,minv_2),max(maxv_1,maxv_2) \leq v_2 + k_2的三元组(v_1,minv_1,maxv_1), (v_2,minv_2,maxv_2)的对数。这就是一个三维数点问题了,可以用CDQ在O(n \log^2 n)时间或者用可持久化线段树在O(n \log n)时间复杂度内解决。如果套上点分治的话,CDQ复杂度就有点高了,于是我选择了可持久化线段树。

具体如何使用可持久化线段树实现这个三维数点,方法有很多。我这里是按照点的编号将三元组排序后,按编号从小到大依次处理。v_1 - k_1 \leq min(minv_1,minv_2)可以拆分成v_1 - k_1 \leq minv_1v_1 - k_1 \leq minv_2,前面的式子可以直接判断,后面的式子的处理将会在最后说明。同样的,我们将max(maxv_1,maxv_2) \leq v_2 + k_2拆分成maxv_1 \leq v_2 + k_2maxv_2 \leq v_2 + k_2,后面的式子可以直接判断。

最后我们还剩下两个限制v_1 - k_1 \leq minv_2maxv_1 \leq v_2 + k_2。只需要在root[v_1-k_1]线段树的maxv_1位置进行+1操作。然后枚举到v_2时,对应的以v_2为大端的路径条数就是root[minv_2]中区间1 \cdots v_2+k_2的和。由于v_i是从小到大枚举的,所以可以直接维护。

#include<iostream>
#include<cstdio>
#include<vector>
#include<algorithm>
#pragma GCC optimize(3)
using namespace std;
const int maxn = 5e5+10;
const int INF = 1e9 + 10;
vector<int> g[maxn];
int S, Mx, K1, K2, n, root;
int sm[maxn], mxson[maxn], vis[maxn];
char buf[1<<23],*p1=buf,*p2=buf,obuf[1<<23],*O=obuf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
inline int rd() {
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch)) x=x*10+(ch^48),ch=getchar();
    return x*f;
}
void getrt(int u, int fa){
    sm[u] = 1;
    mxson[u] = 0;
    for (int v : g[u]) if (!vis[v] && v != fa){
        getrt(v, u);
        sm[u] += sm[v];
        mxson[u] = max(mxson[u], sm[v]);
    }
    mxson[u] = max(mxson[u], S - sm[u]);
    if (mxson[u] < Mx){
        root = u;
        Mx = mxson[u];
    }
}
void get(int u, int fa, vector<int> & nodes, pair<int, int> *value, int mn, int mx) {
    nodes.push_back(u);
    value[u].first = mn;
    value[u].second = mx;
    for (int v : g[u]) if (!vis[v] && v != fa){
        get(v, u, nodes, value, min(mn, v), max(mx, v));
    }
}
int rt[maxn], sz[maxn * 20], ch[maxn * 20][2], top = 0;
int newnode(int x){
    sz[++top] = sz[x];
    ch[top][0] = ch[x][0];
    ch[top][1] = ch[x][1];
    return top;
}
void ins(int &rt, int l, int r, int val) {
    rt = newnode(rt);
    sz[rt]++;
    int t = rt; 
    while (l < r){
        int mid = l + r >> 1;
        if (val <= mid){
            ch[t][0] = newnode(ch[t][0]), t = ch[t][0], sz[t]++, r = mid;   
        }else{
            ch[t][1] = newnode(ch[t][1]), t = ch[t][1], sz[t]++, l = mid + 1;
        }
    }
}
int get(int rt, int l, int r, int x){
    int cnt = 0;
    while (l < r){
        int mid = l + r >> 1;
        if (x <= mid) rt = ch[rt][0], r = mid;
        else cnt += sz[ch[rt][0]], rt = ch[rt][1], l = mid + 1;
    }
    cnt += sz[rt];
    return cnt;
}
long long solve(int v, int mn, int mx){
    vector<int> nodes;
    static int w[maxn];
    static pair<int, int> value[maxn];
    get(v, 0, nodes, value, min(mn, v), max(mx, v));
    for (int i = 0; i < nodes.size(); i++) w[i] = nodes[i];
    sort(w, w + nodes.size());
    long long cnt = 0;
    top = 0;
    rt[0] = 0;
    for (int i = 0; i < nodes.size(); i++) {
        auto p = value[w[i]];
        if (i) rt[i] = rt[i - 1];
        if (w[i] - K1 <= p.first) ins(rt[i], 1, n, p.second);
        if (w[i] + K2 >= p.second) {
            int nv = p.first + K1;
            //int pos = min(int(upper_bound(w, w + nodes.size(), nv) - w - 1), i);
            int l = -1, r = nodes.size() - 1;
            while (l < r) {
                int mid = l + r + 1 >> 1;
                if (w[mid] > nv) r = mid - 1; else l = mid; 
            }
            int pos = min(l, i);
            if (pos >= 0) {
                cnt += get(rt[pos], 1, n, w[i] + K2);
            }
        }
    }
    return cnt;
}
long long ans = 0;
void Divide(int rt){
    ans += solve(rt, INF, 0);
    vis[rt] = 1;
    for (int v : g[rt]) if (!vis[v]){
        ans -= solve(v, rt, rt);
        S = sm[v];
        root = 0;
        Mx = INF;
        getrt(v, 0);
        Divide(root);
    }
}
int main(){
    n = rd();
    K1 = rd();
    K2 = rd(); 
    for (int i = 1; i < n; i++){
        int u, v;
        u = rd();
        v = rd();
        g[u].push_back(v);
        g[v].push_back(u);
    }
    Mx = INF;
    S = n;
    getrt(1, 0);
    Divide(root);
    cout << ans << endl;
}

背景:

报名了今年3月份的csp认证,于是尝试模拟了一下上次的csp认证试题。

前面4题没啥好说的,但这题的难度突然上升,思前想后找到了一个点分+三维数点的做法,写完后直接TLE。然后经历了2个小时的卡常过后,从84分卡到了96分,还有4分实在卡不下去了(卡吐了)。