介绍

这篇博客主要介绍一些并查集的高级应用。

目前我所知的并查集应用有:

  1. 权值并查集
  2. 可撤销并查集
  3. 可持久化并查集(还没学)

权值并查集

略,目前我所知的方法就只有开多倍空间,维护不同元素之间的关系,经典例题见 NOI2001 食物链

可撤销并查集

意思如其名,可以将之前合并的操作撤销。

思想并不难,只要拿一个额外的 stack 记录一下每次操作 更改的信息,然后回退的时候将 stack 内的信息 pop 出来,恢复一下当时的状态即可。

有几个需要注意的点:

  1. 并查集 不能路径压缩,否则会破坏结构。为了保证时间复杂度,需要使用启发式合并。
  2. 无论 unions(u,v) 是否成功,都要记录在栈中。因为我们回退时并不知道这一步合并当时是否成功了。

例题

例1 [HNOI2016]最小公倍数

题意

给定一个 $n$ 个节点,$m$ 条边的无向图。每个边有权值,权值都可以被表示为 $2^a \times 3^b$ 的形式。

现在给出 $q$ 个询问,每次询问 $u ~ v ~ a ~ b$,我们需要回答在 $u,v$ 之间,是否存在一条路径使得路径上所有边权的 LCM 等于 $2^a \times 3^b$?

路径可以不为简单路径(可以经过重复的边)。

其中,$1 \leq n,q \leq 5 \times 10^4, 1 \leq m \leq 10^5, 0 \leq a_i,b_i \leq 10^9$。

题解

题目实际上就想要我们找,每条边的权值表示为 $(a_i,b_i)$,每次询问 $u ~ v ~ a ~ b$,回答是否存在路径使得路径上的 $\max \{a_i\} = a, \max \{b_i\} = b$?

我们先用最暴力的思路来想:

对于每次询问,我们只要把所有满足 $a_i \leq a, b_i \leq b$ 的边全都加进去,然后判断一下 $u,v$ 是否联通,并且这个联通块内 $\max \{a_i\} = a, \max \{b_i\} = b$ 是否成立即可。这可以用并查集实现。

接下来考虑优化。


对于这种拥有 $2$ 个条件的题目,一种套路是:

将第一种权值 sort 一下,然后分块。

每个块内(或者多个块一起)对第二种权值 sort。然后把所有询问放进对应的块内进行处理。

对于本题,我们先把所有的边按照 $a_i$ 的大小进行 sort,然后进行分块。

对于每一条边 $(u,v,a,b)$,我们找到一个块 $B$ 满足,在 $[1,B-1]$ 这些块内,最大的 $a_i$ 都 $\leq a$,且这个 $B$ 尽可能大。

之后,我们对于每一个块进行处理。

假设我们现在在块 $B$ 内,那么我们将 $[1,B-1]$ 这些块看作一个整体,然后对这个整体,按照 第二种权值 $b_i$ 进行 sort

然后我们将块 $B$ 内的所有询问按照 $b_i$ 进行 sort

然后我们开始回答询问 $(u,v,a,b)$。

首先注意到此时,$[1,B-1]$ 这些块内的元素一定满足 $a_i \leq a$,所以不用关心前面这些块内的 $a_i$。

然后考虑 $[1,B-1]$ 这些块内的元素的 $b_i$,注意到它们此时是按照 $b_i$ sort 好的,而我们的询问也是按照 $b$ sort 的,所以我们可以直接用一个指针维护一下现在我们在 $[1,B-1]$ 这些块内,有哪些元素的 $b_i$ 是 $\leq b$ 的。

这样,$[1,B-1]$ 这些块就都处理好了,只剩下当前这个块 $B$ 了。

对于当前块 $B$,我们只要暴力向并查集内加入当前块内,所有满足 $a_i \leq a, b_i \leq b$ 的边即可。

在回答完这个询问后,把当前块 $B$ 的所有边都从并查集中撤销,然后开始回答下一个询问。


总时间复杂度:设块的大小为 $B$,那么总共有 $\frac{m}{B}$ 个块。

每一个块的处理:首先要对前面的块进行sort,所以是 $O(m\log m)$。

对于每一个询问,我们都要暴力遍历一个块内的所有边,总共有 $B$ 条边,每次加入操作需要 $O(\log m)$(因为没有路径压缩),所以是 $O(B \log m)$。

于是最终的复杂度就是:

$$T(n) = \frac{m}{B} m\log m + qB \log m$$

取 $B = \sqrt m$ 或者 $\frac{m}{\sqrt q}$ 之类的都可以过。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e4+5;
const int maxm = 1e5+5;

struct Query {
    int u, v, a, b, be, id;
} query[maxn];
vector<Query> vec[maxn];

struct Edge {
    int from, to, a, b;
} edges[maxm];
bool cmpA(const Edge& e1, const Edge& e2) {
    return e1.a < e2.a;
}
bool cmpB(const Edge& e1, const Edge& e2) {
    return e1.b < e2.b;
}

int ecnt = 0, n, m, B, blockcnt, Q;
void addEdge(int u, int v, int a, int b) {
    Edge e = {u, v, a, b};
    edges[++ecnt] = e;
}

struct State {
    int u, v, szu, amaxu, bmaxu;
} st[maxm];
int tail = 0;

int par[maxn], sz[maxn], amax[maxn], bmax[maxn];
inline void init() {
    for (int i = 1; i <= n; i++) par[i] = i, sz[i] = 1, amax[i] = bmax[i] = -1;
    tail = 0;
}

int finds(int u) {
    if (par[u] == u) return u;
    return finds(par[u]);
}
void unions(int u, int v, int a, int b) {
    u = finds(u), v = finds(v);
    if (sz[u] < sz[v]) swap(u,v);
    st[++tail] = {u, v, sz[u], amax[u], bmax[u]};
    amax[u] = max(amax[u], a);   // 注意不管是否成功,都要更改 amax, bmax
    bmax[u] = max(bmax[u], b);
    if (u == v) return;
    par[v] = u;
    sz[u] += sz[v];
    amax[u] = max(amax[u], amax[v]);
    bmax[u] = max(bmax[u], bmax[v]);
}

void cancel() {
    if (tail > 0) {
        int u = st[tail].u, v = st[tail].v;
        par[v] = v;
        sz[u] = st[tail].szu;
        amax[u] = st[tail].amaxu;
        bmax[u] = st[tail].bmaxu;
        tail--;
    }
}

bool ans[maxn];
int main() {
    cin >> n >> m;
    for (int i = 1; i <= m; i++) {
        int u,v,a,b; cin >> u >> v >> a >> b;
        addEdge(u,v,a,b);
    }
    sort(edges+1, edges+ecnt+1, cmpA);
    B = sqrt(n);  // block size
    blockcnt = (ecnt + B - 1) / B;

    cin >> Q;

    // block 1: [1 ... B], block 2: [B+1, ... , 2B], 
    for (int i = 1; i <= Q; i++) {
        cin >> query[i].u >> query[i].v >> query[i].a >> query[i].b;
        query[i].id = i;
        query[i].be = 1;
        for (int b = 1; b <= blockcnt-1; b++) {
            if (edges[b * B].a <= query[i].a) {
                query[i].be = b+1;
            }
        }
        vec[query[i].be].push_back(query[i]);
    }

    for (int b = 1; b <= blockcnt; b++) {
        if (!vec[b].size()) continue;  // 若这个块为空
        sort(vec[b].begin(), vec[b].end(), [](auto a, auto b) {
            return a.b < b.b;  // 把所有询问按照 b 的大小 sort 一下
        });
        init();
        if (b > 1) {
            int R = (b-1) * B;  // 前一个块的右端点
            sort(edges+1, edges+R+1, cmpB);

            // 开始处理询问
            int j = 0;
            for (Query que : vec[b]) {
                int qa = que.a, qb = que.b;

                // 先将前面的块满足条件的都加进去
                while (j + 1 <= R && edges[j+1].b <= qb) {
                    unions(edges[j+1].from, edges[j+1].to, edges[j+1].a, edges[j+1].b);
                    j++;
                }

                // 然后开始处理当前块
                int curt = tail;  // 记录当前 stack 的 tail 位置
                for (int i = R+1; i <= min(b*B, ecnt); i++) {   // 当前块
                    if (edges[i].a <= qa && edges[i].b <= qb) {
                        unions(edges[i].from, edges[i].to, edges[i].a, edges[i].b);
                    }
                }

                int u = finds(que.u), v = finds(que.v);
                if (u == v && amax[u] == qa && bmax[u] == qb) {
                    ans[que.id] = 1;
                }

                while (tail != curt) cancel();  // 退回当前块的所有操作
            }
        }
    }

    for (int i = 1; i <= Q; i++) {
        cout << (ans[i] ? "Yes" : "No") << "\n";
    }
}

注意事项

  1. unions(u,v,a,b) 时,无论 u,v 是否已经联通,都要更新 a,b
  2. 可撤销并查集的栈只记录被更改的信息,其他的就不用记了。

例2 BZOJ4668 冷战

题意

给定 $n$ 个点 和 $m$ 次询问。

每次询问有两种格式:

$0 ~ u ~ v$:将 $u,v$ 用一条边链接起来。

$1 ~ u ~ v$:询问最早在加入哪条边以后,$u,v$ 在同一个联通块中。如果此时尚未联通,则输出 $0$。

所有询问强制在线。

其中,$n,m \leq 5 \times 10^5$。

题解

启发式合并并查集的题。

第一反应是每次链接的时候,新建一个 parent 节点储存这是第几条边,然后把 $u’, v'$ 作为它的两个 child。($u’,v'$ 是在 finds(u), finds(v) 得到的根节点)。类似于 kruskal 重构树的思路。

然后每次询问 $1 ~ u ~ v$ 的时候,直接找到 LCA 位置的那个节点即可。


这样是正确的,不过有个更优雅的写法:

注意到有效的链接只有 $(n-1)$ 次,普通的启发式合并刚好生成的就是一棵树,也不用新建节点了。直接启发式合并的时候,把当前是第几条边,记录在链接用的那一条树边上。

这样,$(n-1)$ 条树边,刚好对应的就是 $(n-1)$ 次启发式合并。

每次询问 $1 ~ u ~ v$ 的时候,直接暴力往上跳到 LCA,然后取边上的最大值即可。(就相当于 $u,v$ 之间链上的边权最大值)。

由于启发式合并的性质,暴力跳 LCA 的时间是 $O(\log n)$。

所以复杂度就是 $O(n \log n)$。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;

int n,m, dep[maxn], par[maxn], sz[maxn], val[maxn];
int finds(int u) {
    if (par[u] == u) {
        dep[u] = 0;
        return u;
    }
    int res = finds(par[u]);
    dep[u] = dep[par[u]] + 1;
    return res;
}
void unions(int u, int v, int id) {
    u = finds(u), v = finds(v);
    if (u == v) return;
    if (sz[u] < sz[v]) swap(u,v);
    sz[u] += sz[v];
    par[v] = u;
    val[v] = id;  // 链接 (v,u) 的边权为 id
}
// 寻找 u,v 链上最大的值
int find_max(int u, int v) {
    if (finds(u) != finds(v)) return 0;
    if (dep[u] < dep[v]) swap(u,v);
    int res = 0;
    while (dep[u] > dep[v]) {
        res = max(res, val[u]);
        u = par[u];
    }
    if (u == v) return res;
    while (u != v) {
        res = max({res, val[u], val[v]});
        u = par[u];
        v = par[v];
    }
    return res;
}

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        par[i] = i;
        sz[i] = 1;
    }
    int last = 0, id = 0;
    while (m--) {
        int op,u,v; cin >> op >> u >> v;
        u ^= last, v ^= last;
        if (op == 0) {
            unions(u, v, ++id);
        } else {
            int res = find_max(u, v);
            cout << res << "\n";
            last = res;
        }
    }
}