介绍

这篇博客主要介绍一些并查集的高级应用。

目前我所知的并查集应用有:

  1. 权值并查集
  2. 可撤销并查集
  3. 可持久化并查集(还没学)

权值并查集

略,目前我所知的方法就只有开多倍空间,维护不同元素之间的关系,经典例题见 NOI2001 食物链

可撤销并查集

意思如其名,可以将之前合并的操作撤销。

思想并不难,只要拿一个额外的 stack 记录一下每次操作 更改的信息,然后回退的时候将 stack 内的信息 pop 出来,恢复一下当时的状态即可。

有几个需要注意的点:

  1. 并查集 不能路径压缩,否则会破坏结构。为了保证时间复杂度,需要使用启发式合并。
  2. 无论 unions(u,v) 是否成功,都要记录在栈中。因为我们回退时并不知道这一步合并当时是否成功了。
模版
ll ans = 0;
ll cal(ll x) {
    return x*(x-1) / 2;
}

struct State {
    int u, v, szu, szv;
} st[maxm];  // 注意这里是 maxm,应该是边的数量
struct DSU {
    int par[maxn], sz[maxn], tail = 0;
    inline void init() {
        for (int i = 1; i <= n; i++) par[i] = i, sz[i] = 1;
        tail = 0;
    }
    int finds(int u) {
        if (par[u] == u) return u;
        return finds(par[u]);  // 无路径压缩
    }
    void unions(int u, int v) {
        u = finds(u), v = finds(v);
        if (sz[u] < sz[v]) swap(u,v);  // sz[u] >= sz[v]
        st[++tail] = {u, v, sz[u], sz[v]};   // 无论是否 union 成功都要push到 stack 里
        if (u == v) return;
        par[v] = u;
        ans = ans - cal(sz[u]) - cal(sz[v]) + cal(sz[u] + sz[v]);
        sz[u] += sz[v];
    }
    void cancel() {
        if (tail > 0) {
            int u = st[tail].u, v = st[tail].v;
            par[v] = v;
            if (sz[u] != st[tail].szu) {
                ans = ans - cal(sz[u]) + cal(st[tail].szu) + cal(st[tail].szv);
            }
            sz[u] = st[tail].szu;
            sz[v] = st[tail].szv;
            tail--;
        }
    }
} dsu;

并查集维护仅删除的单向链表/树

并查集还可以用于维护一个只有删除操作的单向链表 / 树,维护的信息是一个节点的 “下一个xx” 相关信息。

例如在一棵树上,支持两种操作:

  1. 删除一个节点 $u$,将其所有 child 连到节点 $u$ 的parent上。
  2. 查询一个节点的parent。

就可以将它看作并查集问题,每个节点 $u$ 的并查集中,par[u] 记录的是:在它的上方,第一个还存在的节点是谁。

所以当一个节点还存在时,par[u] = u,当它被删除时,就 unions(u, p),其中 $p$ 是 $u$ 在树上的 parent。

• 需要注意,unions(u,p) 时不能随便 unions,我们需要保证 par[finds(u)] = finds(p)

例题

例1 [HNOI2016]最小公倍数

题意

给定一个 $n$ 个节点,$m$ 条边的无向图。每个边有权值,权值都可以被表示为 $2^a \times 3^b$ 的形式。

现在给出 $q$ 个询问,每次询问 $u ~ v ~ a ~ b$,我们需要回答在 $u,v$ 之间,是否存在一条路径使得路径上所有边权的 LCM 等于 $2^a \times 3^b$?

路径可以不为简单路径(可以经过重复的边)。

其中,$1 \leq n,q \leq 5 \times 10^4, 1 \leq m \leq 10^5, 0 \leq a_i,b_i \leq 10^9$。

题解

题目实际上就想要我们找,每条边的权值表示为 $(a_i,b_i)$,每次询问 $u ~ v ~ a ~ b$,回答是否存在路径使得路径上的 $\max \{a_i\} = a, \max \{b_i\} = b$?

我们先用最暴力的思路来想:

对于每次询问,我们只要把所有满足 $a_i \leq a, b_i \leq b$ 的边全都加进去,然后判断一下 $u,v$ 是否联通,并且这个联通块内 $\max \{a_i\} = a, \max \{b_i\} = b$ 是否成立即可。这可以用并查集实现。

接下来考虑优化。


对于这种拥有 $2$ 个条件的题目,一种套路是:

将第一种权值 sort 一下,然后分块。

每个块内(或者多个块一起)对第二种权值 sort。然后把所有询问放进对应的块内进行处理。

对于本题,我们先把所有的边按照 $a_i$ 的大小进行 sort,然后进行分块。

对于每一条边 $(u,v,a,b)$,我们找到一个块 $B$ 满足,在 $[1,B-1]$ 这些块内,最大的 $a_i$ 都 $\leq a$,且这个 $B$ 尽可能大。

之后,我们对于每一个块进行处理。

假设我们现在在块 $B$ 内,那么我们将 $[1,B-1]$ 这些块看作一个整体,然后对这个整体,按照 第二种权值 $b_i$ 进行 sort

然后我们将块 $B$ 内的所有询问按照 $b_i$ 进行 sort

然后我们开始回答询问 $(u,v,a,b)$。

首先注意到此时,$[1,B-1]$ 这些块内的元素一定满足 $a_i \leq a$,所以不用关心前面这些块内的 $a_i$。

然后考虑 $[1,B-1]$ 这些块内的元素的 $b_i$,注意到它们此时是按照 $b_i$ sort 好的,而我们的询问也是按照 $b$ sort 的,所以我们可以直接用一个指针维护一下现在我们在 $[1,B-1]$ 这些块内,有哪些元素的 $b_i$ 是 $\leq b$ 的。

这样,$[1,B-1]$ 这些块就都处理好了,只剩下当前这个块 $B$ 了。

对于当前块 $B$,我们只要暴力向并查集内加入当前块内,所有满足 $a_i \leq a, b_i \leq b$ 的边即可。

在回答完这个询问后,把当前块 $B$ 的所有边都从并查集中撤销,然后开始回答下一个询问。


总时间复杂度:设块的大小为 $B$,那么总共有 $\frac{m}{B}$ 个块。

每一个块的处理:首先要对前面的块进行sort,所以是 $O(m\log m)$。

对于每一个询问,我们都要暴力遍历一个块内的所有边,总共有 $B$ 条边,每次加入操作需要 $O(\log m)$(因为没有路径压缩),所以是 $O(B \log m)$。

于是最终的复杂度就是:

$$T(n) = \frac{m}{B} m\log m + qB \log m$$

取 $B = \sqrt m$ 或者 $\frac{m}{\sqrt q}$ 之类的都可以过。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e4+5;
const int maxm = 1e5+5;

struct Query {
    int u, v, a, b, be, id;
} query[maxn];
vector<Query> vec[maxn];

struct Edge {
    int from, to, a, b;
} edges[maxm];
bool cmpA(const Edge& e1, const Edge& e2) {
    return e1.a < e2.a;
}
bool cmpB(const Edge& e1, const Edge& e2) {
    return e1.b < e2.b;
}

int ecnt = 0, n, m, B, blockcnt, Q;
void addEdge(int u, int v, int a, int b) {
    Edge e = {u, v, a, b};
    edges[++ecnt] = e;
}

struct State {
    int u, v, szu, amaxu, bmaxu;
} st[maxm];
int tail = 0;

int par[maxn], sz[maxn], amax[maxn], bmax[maxn];
inline void init() {
    for (int i = 1; i <= n; i++) par[i] = i, sz[i] = 1, amax[i] = bmax[i] = -1;
    tail = 0;
}

int finds(int u) {
    if (par[u] == u) return u;
    return finds(par[u]);
}
void unions(int u, int v, int a, int b) {
    u = finds(u), v = finds(v);
    if (sz[u] < sz[v]) swap(u,v);
    st[++tail] = {u, v, sz[u], amax[u], bmax[u]};
    amax[u] = max(amax[u], a);   // 注意不管是否成功,都要更改 amax, bmax
    bmax[u] = max(bmax[u], b);
    if (u == v) return;
    par[v] = u;
    sz[u] += sz[v];
    amax[u] = max(amax[u], amax[v]);
    bmax[u] = max(bmax[u], bmax[v]);
}

void cancel() {
    if (tail > 0) {
        int u = st[tail].u, v = st[tail].v;
        par[v] = v;
        sz[u] = st[tail].szu;
        amax[u] = st[tail].amaxu;
        bmax[u] = st[tail].bmaxu;
        tail--;
    }
}

bool ans[maxn];
int main() {
    cin >> n >> m;
    for (int i = 1; i <= m; i++) {
        int u,v,a,b; cin >> u >> v >> a >> b;
        addEdge(u,v,a,b);
    }
    sort(edges+1, edges+ecnt+1, cmpA);
    B = sqrt(n);  // block size
    blockcnt = (ecnt + B - 1) / B;

    cin >> Q;

    // block 1: [1 ... B], block 2: [B+1, ... , 2B], 
    for (int i = 1; i <= Q; i++) {
        cin >> query[i].u >> query[i].v >> query[i].a >> query[i].b;
        query[i].id = i;
        query[i].be = 1;
        for (int b = 1; b <= blockcnt-1; b++) {
            if (edges[b * B].a <= query[i].a) {
                query[i].be = b+1;
            }
        }
        vec[query[i].be].push_back(query[i]);
    }

    for (int b = 1; b <= blockcnt; b++) {
        if (!vec[b].size()) continue;  // 若这个块为空
        sort(vec[b].begin(), vec[b].end(), [](auto a, auto b) {
            return a.b < b.b;  // 把所有询问按照 b 的大小 sort 一下
        });
        init();
        if (b > 1) {
            int R = (b-1) * B;  // 前一个块的右端点
            sort(edges+1, edges+R+1, cmpB);

            // 开始处理询问
            int j = 0;
            for (Query que : vec[b]) {
                int qa = que.a, qb = que.b;

                // 先将前面的块满足条件的都加进去
                while (j + 1 <= R && edges[j+1].b <= qb) {
                    unions(edges[j+1].from, edges[j+1].to, edges[j+1].a, edges[j+1].b);
                    j++;
                }

                // 然后开始处理当前块
                int curt = tail;  // 记录当前 stack 的 tail 位置
                for (int i = R+1; i <= min(b*B, ecnt); i++) {   // 当前块
                    if (edges[i].a <= qa && edges[i].b <= qb) {
                        unions(edges[i].from, edges[i].to, edges[i].a, edges[i].b);
                    }
                }

                int u = finds(que.u), v = finds(que.v);
                if (u == v && amax[u] == qa && bmax[u] == qb) {
                    ans[que.id] = 1;
                }

                while (tail != curt) cancel();  // 退回当前块的所有操作
            }
        }
    }

    for (int i = 1; i <= Q; i++) {
        cout << (ans[i] ? "Yes" : "No") << "\n";
    }
}

注意事项

  1. unions(u,v,a,b) 时,无论 u,v 是否已经联通,都要更新 a,b
  2. 可撤销并查集的栈只记录被更改的信息,其他的就不用记了。

例2 BZOJ4668 冷战

题意

给定 $n$ 个点 和 $m$ 次询问。

每次询问有两种格式:

$0 ~ u ~ v$:将 $u,v$ 用一条边链接起来。

$1 ~ u ~ v$:询问最早在加入哪条边以后,$u,v$ 在同一个联通块中。如果此时尚未联通,则输出 $0$。

所有询问强制在线。

其中,$n,m \leq 5 \times 10^5$。

题解

启发式合并并查集的题。

第一反应是每次链接的时候,新建一个 parent 节点储存这是第几条边,然后把 $u’, v'$ 作为它的两个 child。($u’,v'$ 是在 finds(u), finds(v) 得到的根节点)。类似于 kruskal 重构树的思路。

然后每次询问 $1 ~ u ~ v$ 的时候,直接找到 LCA 位置的那个节点即可。


这样是正确的,不过有个更优雅的写法:

注意到有效的链接只有 $(n-1)$ 次,普通的启发式合并刚好生成的就是一棵树,也不用新建节点了。直接启发式合并的时候,把当前是第几条边,记录在链接用的那一条树边上。

这样,$(n-1)$ 条树边,刚好对应的就是 $(n-1)$ 次启发式合并。

每次询问 $1 ~ u ~ v$ 的时候,直接暴力往上跳到 LCA,然后取边上的最大值即可。(就相当于 $u,v$ 之间链上的边权最大值)。

由于启发式合并的性质,暴力跳 LCA 的时间是 $O(\log n)$。

所以复杂度就是 $O(n \log n)$。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;

int n,m, dep[maxn], par[maxn], sz[maxn], val[maxn];
int finds(int u) {
    if (par[u] == u) {
        dep[u] = 0;
        return u;
    }
    int res = finds(par[u]);
    dep[u] = dep[par[u]] + 1;
    return res;
}
void unions(int u, int v, int id) {
    u = finds(u), v = finds(v);
    if (u == v) return;
    if (sz[u] < sz[v]) swap(u,v);
    sz[u] += sz[v];
    par[v] = u;
    val[v] = id;  // 链接 (v,u) 的边权为 id
}
// 寻找 u,v 链上最大的值
int find_max(int u, int v) {
    if (finds(u) != finds(v)) return 0;
    if (dep[u] < dep[v]) swap(u,v);
    int res = 0;
    while (dep[u] > dep[v]) {
        res = max(res, val[u]);
        u = par[u];
    }
    if (u == v) return res;
    while (u != v) {
        res = max({res, val[u], val[v]});
        u = par[u];
        v = par[v];
    }
    return res;
}

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        par[i] = i;
        sz[i] = 1;
    }
    int last = 0, id = 0;
    while (m--) {
        int op,u,v; cin >> op >> u >> v;
        u ^= last, v ^= last;
        if (op == 0) {
            unions(u, v, ++id);
        } else {
            int res = find_max(u, v);
            cout << res << "\n";
            last = res;
        }
    }
}

例3 Do Not Try This Problem

题意

给定一个长度为 $n$ 的 string $s$,给定 $q$ 个操作,每次操作的格式是 $i~a~k~c$,将从index $i$ 开始,将 $s_i,s_{i+a},…,s_{i+ka}$ 替换为字符 $c$。

输出所有操作结束后的string。

其中,$n, q \leq 10^5, i \in [1,n], a \in [1, n), k \in [0,n), i+ka \leq n$。

题解

由于询问离线,所以可以从后向前处理。因为后面的操作优先级更高,当一个后面的操作更改了某个位置,说明之后再也不用改这个位置了。

然后考虑分块,对于 $a > \sqrt n$ 的操作,直接暴力修改即可。

对于 $a \leq \sqrt n$ 的操作,我们利用 并查集维护仅带删除的单向链表

由于一个位置被更改以后,就再也不用改这个位置了,把它当作被修改后,直接删除这个位置,那么我们维护 $\sqrt n$ 种链表,第 $j$ 种链表将会维护 $j$ 个跨度为 $j$ 的链表,例如 $j=3$,那么维护的就是 $1 \rightarrow 4 \rightarrow 7 \rightarrow …$ 和 $2 \rightarrow 5 \rightarrow 8 \rightarrow…$ 和 $3 \rightarrow 6 \rightarrow 9 \rightarrow …$。

而每一种链表就是一个并查集,每一个元素的parent是它 下一个存在的元素

在删除一个元素时,更新其 parent 即可。

按理说,删除一个元素时,应该更新其所有的parent,即更新 $\sqrt n$ 种链表,如:

void del(int p, char c) {  // 删除位置p
    if (vis[p]) return;
    vis[p] = 1;
    s[p] = c;
    for (int i = 1; i < m; i++) {
        uf[i].unions(p, min(p+i, n+1));
    }
}

但是这样每次删除一个元素都要更新,虽然复杂度还是 $O(n \log n)$ ,但常数略高会T。

我们改成在询问 $< \sqrt n$ 的情况下才更新,并且只更新对应的 $a$,详情见代码。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+3;
 
struct UnionFind {
    int par[maxn];
    int finds(int u) {
        if (par[u] == u) return u;
        return par[u] = finds(par[u]);
    }
    void unions(int u, int v) {
        u = finds(u), v = finds(v);
        if (u < v) swap(u, v);
        par[v] = u;
    }
} uf[320];
 
struct Query {
    int i, a, k;
    char c;
} q[maxn];
int n, m, Q;
string s;
 
bool vis[maxn];
void del(int p, char c) {  // 删除位置p
    if (vis[p]) return;
    vis[p] = 1;
    s[p] = c;
}
int main() {
    fastio;
    cin >> s;
    n = s.size();
    s = "#" + s;
    m = sqrt(n);
    for (int i = 1; i < m; i++) {
        for (int j = 1; j <= n+1; j++) uf[i].par[j] = j;
    }
    cin >> Q;
    for (int i = 1; i <= Q; i++) cin >> q[i].i >> q[i].a >> q[i].k >> q[i].c;
    for (int i = Q; i >= 1; i--) {
        auto [st, a, k, c] = q[i];
        if (a >= m) {
            for (int j = st; j <= st+a*k; j += a) {
                del(j, c);
            }
        } else {
            int j = st;
            while (j <= st+a*k) {
                del(j, c);
                uf[a].unions(j, min(j+a, n+1));
                j = uf[a].finds(j);
            }
        }
    }
    for (int i = 1; i <= n; i++) cout << s[i];
    cout << "\n";
}

例4 TJOI2016

题意

给定一棵树,根为 $1$,现在有两个操作:

C u:给节点 $u$ 打标记(一个节点可能被打上多次标记)。

Q u:询问节点 $u$ 打了标记的最近祖先(算上它自己)。

一开始 $1$ 被标记了,其他所有节点无标记。

其中,$n,q \leq 10^5$。

题解

离线操作,将问题反过来看:

先给所有节点打上标记,然后利用并查集来维护支持删点的树。并查集上维护每一个节点上方的第一个打了标记的点。

然后将所有询问从后往前处理,遇到一个标记操作,就将这个节点的标记次数 -1(因为一个节点可以被标记多次),当一个节点的标记次数归0时,将这个节点从树中删除。

这样就可以 $O(n)$ 解决这个问题了。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
pair<char, int> q[maxn];
vector<int> adj[maxn];
int f[maxn];
int par[maxn], dep[maxn];

struct UnionFind {
    int par[maxn];
    int finds(int u) {
        if (par[u] == u) return u;
        return par[u] = finds(par[u]);
    }
    void unions(int u, int v) {
        u = finds(u), v = finds(v);
        if (dep[u] > dep[v]) swap(u, v);
        par[v] = u;
    }
} uf;

void dfs(int u, int p) {
    par[u] = p;
    dep[u] = dep[p] + 1;
    if (!f[u]) {
        uf.unions(u, p);
    }
    for (int v : adj[u]) {
        if (v == p) continue;
        dfs(v, u);
    }
}
int main() {
    int n, Q; cin >> n >> Q;
    for (int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    f[1] = 1;
    for (int i = 1; i <= Q; i++) {
        cin >> q[i].first >> q[i].second;
        if (q[i].first == 'C') {
            f[q[i].second]++;
        }
    }
    for (int i = 1; i <= n; i++) uf.par[i] = i;

    dfs(1, 0);
    vector<int> ans;
    for (int i = Q; i >= 1; i--) {
        int u = q[i].second;
        if (q[i].first == 'C') {
            f[u]--;
            if (!f[u])
                uf.unions(u, par[u]);
        } else {
            ans.push_back(uf.finds(u));
        }
    }
    reverse(ans.begin(), ans.end());
    for (int i : ans) cout << i << "\n";
}