介绍

01-Trie 指将整数拆成二进制,然后将二进制以字符串的形式储存在 Trie 中。

img

Trie可以解决以下问题:

  1. 给定一些数 $a_1,a_2,…,a_n$,给定 $x$,求 $a_i$,使得 $a_i \text{ xor } x$ 最大。

  2. 维护异或和:给定一些数,支持全体加一,插入数字,删除数字,求全体异或和等操作。

Trie可以将数字从低位到高位储存,也可以反过来,根据具体题目而定。

为了方便,我们在写 01Trie 的时候都会把每个数补成同样的位数。

• 注意 id = 1 的是 Root,所以 id 要从 $1$ 开始。

维护最大XOR

指第一个例子,这里我们的 01Trie 将会 从高位到低位 储存。

在给定一个查询 $x$ 时,我们从 $x$ 的高位开始看,设当前到了第 $i$ 位,那么我们看第 $i$ 位的bit $a_i$,然后判断一下第 $i$ 位的值为 a[i] ^ 1 的数字是否存在即可,如果存在,往那个方向走,否则往另外一个方向走。

板子
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;

int id = 1;  // 注意,从 1 开始
struct Node {
    int cnt = 0;
    int child[2];
} trie[maxn<<4];
void insert(int x) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;

        c = trie[c].child[k];
        trie[c].cnt++;
    }
}

int query(int x) {
    int c = 1;
    int res = 0;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        k ^= 1;
        if (trie[c].child[k]) {
            c = trie[c].child[k];
            res |= (1<<j);
        } else {
            c = trie[c].child[k^1];
        }
    }
    return res;
}

维护XOR和

指第二个例子,这里我们的 01Trie 从低位到高位储存。

要维护异或和,我们只需要知道每一位上 $0$ 和 $1$ 个数的奇偶性即可。

对于每一个节点,我们记录三个量:

  1. 两个子节点的编号
  2. w:指当前这个节点,到它的parent这条边 $(p,u)$ 上,被经过的次数。每有一个数字在插入过程中经过这个边,这个 w 就会加一。
  3. val:指以当前节点为根,它子树内包含的所有数字的 XOR和

所以维护信息的时候就有:

void push_up(int cur) {
    trie[cur].val = trie[cur].w = 0;  // 先清空
    int c0 = trie[cur].child[0], c1 = trie[cur].child[1];

    // 更新 w
    if (c0) trie[cur].w += trie[c0].w;
    if (c1) trie[cur].w += trie[c1].w;

    // 更新 val

    if (c0) {
        trie[cur].val ^= (trie[c0].val << 1);
    }
    if (c1) {
        trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1));  // 如果 c1 的 w 为奇数,就提供了 1 的贡献
    }
}

解释: 注意是从低位到高位储存的,所以子树内储存的只是数字的二进制的一部分。

所以要把 val 进行左移。

也就是说,实际上只有 根节点val 是真正有意义的,因为只有根节点才记录的是完整的数字的 XOR和。


插入和删除就不提了,看代码即可。注意插入删除合并起来,可以达成单点修改的效果。

最后说一下全局加一。

全局加一指的是将整个01Trie维护的所有数字都加一。

在二进制下的加一,相当于从低位到高位找第一个出现的 $0$,然后将其变成 $1$。

最后,比这一位低的所有位都将会从 $1$ 变成 $0$。

// add 1 to all numbers maintained by cur
void addall(int cur) {
    swap(trie[cur].child[0], trie[cur].child[1]);
    if (trie[cur].child[0]) {
        addall(trie[cur].child[0]);
    }
    push_up(cur);
}
板子
const int M = 21;  // 最大深度
int id = 0;
struct Node {
    int w;  // 到 parent 这条边上的个数
    int val;  // subtree 内的XOR和
    int child[2];
} trie[maxn*(M+1)];

void push_up(int cur) {
    trie[cur].val = trie[cur].w = 0;  // 先清空
    int c0 = trie[cur].child[0], c1 = trie[cur].child[1];

    // 更新 w
    if (c0) trie[cur].w += trie[c0].w;
    if (c1) trie[cur].w += trie[c1].w;

    // 更新 val

    if (c0) {
        trie[cur].val ^= (trie[c0].val << 1);
    }
    if (c1) {
        trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1));  // 如果 c1 的 w 为奇数,就提供了 1 的贡献
    }
}

// cur: index, x: 当前的数字, dep: 深度
void insert(int& cur, int x, int dep) {
    if (!cur) cur = ++id;
    if (dep >= M) {
        trie[cur].w++;  // w += 1
        return;
    }
    int c = x & 1;
    insert(trie[cur].child[c], (x>>1), dep+1);
    push_up(cur);
}

void del(int& cur, int x, int dep) {
    if (!cur) cur = ++id;
    if (dep >= M) {
        trie[cur].w--;  // w -= 1
        return;
    }
    int c = x & 1;
    insert(trie[cur].child[c], (x>>1), dep+1);
    push_up(cur);
}

// add 1 to all numbers maintained by cur
void addall(int cur) {
    swap(trie[cur].child[0], trie[cur].child[1]);
    if (trie[cur].child[0]) {
        addall(trie[cur].child[0]);
    }
    push_up(cur);
}

可持久化01-Trie

和主席树相似,可持久化01-Trie主要是为了让我们获得任何一个区间 $[L,R]$ 内组成的 01-Trie。

使用 Node 的 cnt 来决定 01-Trie 的内容。

板子
const int maxn = 6e5+5;
const int M = 28;
struct Node {
    int cnt;
    int child[2];
} trie[maxn * (M+1)];
int n, q, root[maxn], a[maxn], id;

void insert(int pre, int cur, int x) {
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        trie[cur].cnt = trie[pre].cnt + 1;  // cnt 加一
        if (!trie[cur].child[c]) trie[cur].child[c] = ++id;  // 新建节点
        trie[cur].child[c^1] = trie[pre].child[c^1];  // 复制另外一个子节点
        cur = trie[cur].child[c];
        pre = trie[pre].child[c];
    }
    trie[cur].cnt = trie[pre].cnt + 1;
}

// find the maximum value for a ^ x (a is in [pre, cur])
int query(int pre, int cur, int x) {
    int res = 0;
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        c ^= 1;
        if (trie[trie[cur].child[c]].cnt - trie[trie[pre].child[c]].cnt) {  // cnt > 0
            cur = trie[cur].child[c];
            pre = trie[pre].child[c];
            res |= (1<<i);
        } else {
            cur = trie[cur].child[c^1];
            pre = trie[pre].child[c^1];
        }
    }
    return res;
}
int main() {
    for (int i = 1; i <= n; i++) {
        cin >> a[i]; sum[i] = sum[i-1] ^ a[i];
        if (!root[i]) root[i] = ++id;
        insert(root[i-1], root[i], sum[i]);
    }
    // 查询 [L,R] -> (l-1,r)
    int res = query(root[max(0,l-1)], root[r], x);
}

例题

例1 洛谷P4551 最长异或路径

题意

给定一个 $n$ 个节点的树,边上有权值。

寻找树上的两个节点,使得路径上的边权 XOR和 最大,输出这个最大值。

其中,$n \leq 10^5$。

题解

常见套路:设 $f_x$ 为从 $1$(根节点)到 $x$ 的路径上的 XOR和。

那么 $(u,v)$ 路径上的 XOR和 就等于 $f_u \text{ xor } f_v$。因为 $1$ 到 $LCA(u,v)$ 的部分被抵消掉了。

所以就相当于建立一个 01Trie,储存所有的 $f_u$,然后对于每个 $u$,都询问一下最大的XOR即可。

代码
#include <bits/stdc++.h>
using namespace std;
const ll mod = 1e9+7;
const int maxn = 1e5+5;
const int maxm = 30;

struct Edge {
    int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, n;
void addEdge(int u, int v, int w) {
    Edge e = {v, head[u], w};
    edges[ecnt] = e;
    head[u] = ecnt++;
}
int dp[maxn];
void dfs(int u, int p) {
    for (int e = head[u]; e; e = edges[e].nxt) {
        int to = edges[e].to;
        if (to == p) continue;
        dp[to] = dp[u] ^ edges[e].w;
        dfs(to, u);
    }
}

int id = 1;
struct Node {
    int cnt = 0;
    int child[2];
} trie[maxn<<4];
void insert(int x) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;

        c = trie[c].child[k];
        trie[c].cnt++;
    }
}

int query(int x) {
    int c = 1;
    int res = 0;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        k ^= 1;
        if (trie[c].child[k]) {
            c = trie[c].child[k];
            res |= (1<<j);
        } else {
            c = trie[c].child[k^1];
        }
    }
    return res;
}


int main() {
    fastio;
    cin >> n;
    for (int i = 1; i <= n-1; i++)     {
        int u,v,w; cin >> u >> v >> w;
        addEdge(u,v,w); addEdge(v,u,w);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; i++) insert(dp[i]);
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        ans = max(ans, query(dp[i]));
    }
    cout << ans << endl;
}

例2 CF817E Choosing The Commander

题意

给定 $q$ 个询问,共有3种:

$1 ~ p$:将 $p$ 加入集合。

$2 ~ p$:将 $p$ 从集合中删除。

$3 ~ p ~ l$:询问集合中有多少个元素在 XOR $p$ 之后小于 $l$。

其中,$q \leq 10^5$。

题解

在 01-Trie 每个节点处加上一个 count 就可以了。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 27;
 
int Q, id = 1;
struct Node {
    int child[2];
    int cnt = 0;
} trie[maxn<<3];
void insert(int x) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;
        c = trie[c].child[k];
        trie[c].cnt++;
    }
}
void del(int x) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;
        c = trie[c].child[k];
        trie[c].cnt--;
    }
}
 
int query(int p, int l) {
    int c = 1;
    int res = 0;
    for (int j = maxm; j >= 0; j--) {
        if (l & (1<<j)) {  // 如果这一位有,说明要往这一位为 1 的方向走
            int k = ((p & (1<<j)) > 0);
            res += trie[trie[c].child[k]].cnt;
            if (trie[trie[c].child[k^1]].cnt) c = trie[c].child[k^1];
            else break;
        } else {  // 如果这一位没有,就往 0 的方向走
            int k = ((p & (1<<j)) > 0);
            if (trie[trie[c].child[k]].cnt) c = trie[c].child[k];
            else break;
        }
    }
    return res;
}
 
int main() {
    fastio;
    cin >> Q;
    while (Q--) {
        int op, p, l;
        cin >> op >> p;
        if (op == 1) {
            insert(p);
        } else if (op == 2) {
            del(p);
        } else {
            cin >> l;
            int res = query(p,l);
            cout << res << "\n";
        }
    }
}

例3 洛谷P6018 [Ynoi2010] Fusion tree

题意

给定一棵 $n$ 个节点的树,每个节点 $i$ 上有权值,初始权值为 $a_i$。

给出 $m$ 个询问,询问的类型有3种:

$1 ~ x$:将所有与节点 $x$ 距离为 $1$ 的节点的权值加一。

$2 ~ x ~ v$:将节点 $x$ 的权值减去 $v$。

$3 ~ x$:询问所有与节点 $x$ 距离为 $1$ 的节点权值的XOR和。

其中,$n,m \leq 5 \times 10^5$,初始权值 $\leq 10^5$。

题解

这里是 01-Trie 用来维护XOR和。

我们可以对于每一个节点,都把它的孩子(距离为 $1$)的XOR和 维护在一个 01Trie里面。

因为 01-Trie 是动态开点的,所以总共复杂度也就 $n\log(10^5)$。

然后因为每个节点都只有一个 parent,就单独维护即可。

然后对于操作 $1 ~ x$,我们需要做的几件事:

  1. 单独更新 $x$ 的parent。
  2. 给 $x$ 打上一个懒标记,代表它要给它的孩子加上 $lazy$ 的值。
  3. 在 $x$ 维护的 01-Trie 上进行全体加一操作。

然后在每次操作之前,都先检查一下 $parent$ 的懒标记,将其下放。

当然这里的下放不太一样,一个节点可能有非常多个孩子,所以我们选择给每一个节点 $x$ 都定义一个值 add[x],代表它已经从parent那里得到了多少下放的懒标记。

然后就给 $x$ 的权值加上 (lazy[p] - add[x])

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
const int M = 21;  // 最大深度

int id = 0;
struct Node {
    int w;  // 到 parent 这条边上的个数
    int val;  // subtree 内的XOR和
    int child[2];
} trie[maxn*(M+1)];
int root[maxn], lazy[maxn], add[maxn];

void push_up(int cur) {
    trie[cur].val = trie[cur].w = 0;  // 先清空
    int c0 = trie[cur].child[0], c1 = trie[cur].child[1];

    // 更新 w
    if (c0) trie[cur].w += trie[c0].w;
    if (c1) trie[cur].w += trie[c1].w;

    // 更新 val

    if (c0) {
        trie[cur].val ^= (trie[c0].val << 1);
    }
    if (c1) {
        trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1));  // 如果 c1 的 w 为奇数,就提供了 1 的贡献
    }
}

// cur: index, x: 当前的数字, dep: 深度
void insert(int& cur, int x, int dep) {
    if (!cur) cur = ++id;
    if (dep >= M) {
        trie[cur].w++;
        return;
    }
    int c = x & 1;
    insert(trie[cur].child[c], (x>>1), dep+1);
    push_up(cur);
}

void del(int& cur, int x, int dep) {
    if (!cur) cur = ++id;
    if (dep >= M) {
        trie[cur].w--;
        return;
    }
    int c = x & 1;
    insert(trie[cur].child[c], (x>>1), dep+1);
    push_up(cur);
}

// add 1 to all numbers maintained by cur
void addall(int cur) {
    swap(trie[cur].child[0], trie[cur].child[1]);
    if (trie[cur].child[0]) {
        addall(trie[cur].child[0]);
    }
    push_up(cur);
}

int n, m, a[maxn], par[maxn];
struct Edge {
    int to, nxt;
} edges[maxn<<1];
int head[maxn], ecnt = 1;
void addEdge(int u, int v) {
    Edge e = {v, head[u]};
    head[u] = ecnt;
    edges[ecnt++] = e;
}

void dfs(int u) {
    for (int e = head[u]; e; e = edges[e].nxt) {
        int to = edges[e].to;
        if (to == par[u]) continue;
        par[to] = u;
        insert(root[u], a[to], 0);
        dfs(to);
    }
}

// modify u to have the value val
void modify(int u, int val) {
    // a[u] is previous value
    int p = par[u];
    if (p) {
        del(root[p], a[u], 0);
        insert(root[p], val, 0);
    }
    a[u] = val;
}

int main() {
    fastio;
    cin >> n >> m;
    for (int i = 1; i <= n-1; i++) {
        int u,v; cin >> u >> v;
        addEdge(u,v); addEdge(v,u);
    }
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    dfs(1);

    while (m--) {
        int op, x; cin >> op >> x;
        if (op == 1) {
            addall(root[x]);
            int p = par[x];
            if (p) {
                a[p] = a[p] + lazy[par[p]] - add[p];
                add[p] = lazy[par[p]];
                modify(par[x], a[p] + 1);
            }
            lazy[x]++;
        }
        if (op == 2) {
            int v; cin >> v;
            a[x] = a[x] + lazy[par[x]] - add[x];
            add[x] = lazy[par[x]];
            modify(x, a[x] - v);
        }
        if (op == 3) {
            int res = trie[root[x]].val;
            int p = par[x];
            if (p) {
                a[p] = a[p] + lazy[par[p]] - add[p];
                add[p] = lazy[par[p]];
                res ^= a[p];
            }
            cout << res << "\n";
        }
    }
}

例4 洛谷P4735 最大异或和

题意

给定一个非负整数序列 $a_1,a_2,…,a_n$。

现给定 $m$ 个操作,有以下两种操作类型:

$A ~ x$:在序列末尾添加一个数 $x$。

$Q ~ L ~ R ~ x$:输出一个位置 $p$,满足 $p \in [L,R]$,使得 $a_p \bigoplus a_{p+1} \bigoplus … \bigoplus a_n \bigoplus x$ 最大。

其中,$n,m \leq 3 \times 10^5, a_i \in [0,10^7]$。

题解

可持久化 01-Trie 的模版题。

首先,先考虑前缀XOR数组 $s_i = a_1 \bigoplus a_2 \bigoplus … \bigoplus a_i$。

那么

$$a_p \bigoplus a_{p+1} \bigoplus … \bigoplus a_n \bigoplus x = (s_n \bigoplus s_{p-1}) \bigoplus x$$

所以现在问题就变成:

每次询问 $Q ~ L ~ R ~ x$,输出一个位置 $p$,满足 $p \in [L-1,R-1]$,使得 $(s_n \bigoplus s_{p-1}) \bigoplus x$ 最大。

那就很简单了,令 $y = s_n \bigoplus x$,剩下的就是在一个区间 $[L-1,R-1]$ 内找出一个元素 $s_p$ 使得 $s_p \bigoplus y$ 最大。

对于每一个区间都可以用 01-Trie 解决的问题,就是可持久化 01-Trie了。


最后需要注意一下这个样例:

4 1
2 4 8 16
Q 1 4 1

答案应为 $31$,但我输出 $29$。

这是因为我们要单独处理一下 $s_0$ 的情况。换而言之我们要将 $s_0 = 0$ 插入到 root[1] 当中。

所以 root[1] 需要包含 $2$ 个元素:$s_0,s_1$。

对于一个版本插入 $2$ 个元素,只要在第二次插入时使用 insert(root[1], root[1], 0) 即可。

最后再注意一下一些边界条件如 $[L-1,R-1]=[0,0]$ 即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 6e5+5;
const int M = 28;

struct Node {
    int cnt;
    int child[2];
} trie[maxn * (M+1)];
int sum[maxn];  // sum[n] 代表 a[1] ^ a[2] ... ^ a[n]
int n, q, root[maxn], a[maxn], id;

void insert(int pre, int cur, int x) {
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        trie[cur].cnt = trie[pre].cnt + 1;  // cnt 加一
        if (!trie[cur].child[c]) trie[cur].child[c] = ++id;  // 新建节点
        trie[cur].child[c^1] = trie[pre].child[c^1];  // 复制另外一个子节点
        cur = trie[cur].child[c];
        pre = trie[pre].child[c];
    }
    trie[cur].cnt = trie[pre].cnt + 1;
}

// find the maximum value for a ^ x (a is in [pre, cur])
int query(int pre, int cur, int x) {
    int res = 0;
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        c ^= 1;
        if (trie[trie[cur].child[c]].cnt - trie[trie[pre].child[c]].cnt) {  // cnt > 0
            cur = trie[cur].child[c];
            pre = trie[pre].child[c];
            res |= (1<<i);
        } else {
            cur = trie[cur].child[c^1];
            pre = trie[pre].child[c^1];
        }
    }
    return res;
}


int main() {
    fastio;
    cin >> n >> q;

    for (int i = 1; i <= n; i++) {
        cin >> a[i]; sum[i] = sum[i-1] ^ a[i];
        if (!root[i]) root[i] = ++id;
        insert(root[i-1], root[i], sum[i]);
        if (i == 1) insert(root[1], root[1], 0);
    }
    while (q--) {
        char op; cin >> op;
        if (op == 'A') {
            int x; cin >> x;
            a[++n] = x;
            sum[n] = sum[n-1] ^ a[n];
            if (!root[n]) root[n] = ++id;
            insert(root[n-1], root[n], sum[n]);
        } else {
            int l,r,x; cin >> l >> r >> x;
            l--; r--;
            x ^= (sum[n]);
            if (l == 0 && r == 0) {
                cout << x << "\n"; continue;
            }
            int res = query(root[max(0,l-1)], root[r], x);
            cout << res << "\n";
        }
    }
}

例5 CF241B Friends

题意

给定一个长度为 $n$ 的非负整数数组 $a_1,a_2,…,a_n$。

现在选择 $m$ 个pair $(i,j)$,使得这 $m$ 个pair 对应的 $a_i \text{ xor } a_j$ 值的 sum 最大。

其中,$n \leq 5 \times 10^4, a_i \leq 10^9$,答案对 $10^9+7$ 取模。

• $(i,j)$ 和 $(j,i)$ 算作一个 pair。

题解

神仙题。

本题要求的是前 $m$ 大的XOR pair的和。

我们先列出本题的过程:

  1. 求出第 $m$ 大。
  2. 求出 XOR 小于等于 $x$ 的所有pair的和。
  3. 对于一个区间 $[L,R]$,在 $O(1)$ 求出 $\sum\limits_{i=L}^R y \text{ xor } a_i$。

第一步:求第 $m$ 大

• 虽然 $(i,j)$ 和 $(j,i)$ 算作一个 pair,但我们只要把所有东西都乘以 $2$ 就行了。所以我们求第 $2m$ 大。

最简单粗暴的方法当然是二分最终的结果 $x$,然后对于每一个 $a_i$,判断一下在 Trie 中有多少个数字和它 XOR 起来 $\geq x$,把这些 count 累加起来然后查看 count >= 2m 与否,来调整 $x$ 的值。

这样的时间复杂度是 $O(n\log^2 10^9)$,本题可过,但是另外一个版本不可过。

于是我们思考一下能不能直接在 Trie 上模拟这个二分的过程来构建最终的结果 $x$。

发现是可以的,我们从高位到低位,枚举每一位取 $1$ 还是 取 $0$。

然后我们对于每一个 $a_i$ 都判断一下 count 的总和,如果 count >= 2m,则说明这一位可以取 $1$。否则的话需要取 $0$,同时问题变成寻找第 $2m - count$ 大(有点类似于主席树寻找第 $k$ 大的思路)。

• 这个过程是在所有 $a_i$ 上进行的,所以我们维护 $n$ 个指针 ptr[] 代表 $a_i$ 目前在 Trie 的什么位置。

复杂度为 $O(n\log 10^9)$。


第二步:求出 XOR 小于等于 $x$ 的所有pair的和

如果这个问题不是求出和,而是求出个数的话,就是例 $2$ 的问题了。

可惜不是,那怎么处理?

一样,我们从每个 $a_i$ 开始进行 check,对于每一个 Trie 上的节点,我们可以维护一个 cnt[31],其中 cnt[j] 代表当前Trie节点的子树里有多少个数字满足第 $j$ 位为 $1$。

这样当然可以算,但是空间复杂度 $O(n\log^2 10^9)$,本题或许可过,但另外一个版本仍然不可过。

我们这里给出一个非常优秀的结论:

如果 $a_i$ 是 sorted 的,那么 Trie 上任意一个节点的子树 对应 $a_i$ 上的一段连续区间。

这个也不难理解,因为 Trie 的一个子树对应的是拥有一个 “特定前缀” (如 “1101”)的所有数字的集合。

有了这个结论以后,我们在枚举 $a_i$ 时,假设我们目前到了第 $j$ 位,就分两种情况:

  1. $x$ 的第 $j$ 位为 $1$:那么我们就往 $1$ 的方向走。
  2. $x$ 的第 $j$ 位为 $0$:我们把 $a_i$ XOR 上 $1$ 的那个子树,求出sum,给答案加上,然后往 $0$ 的方向走。

注意到 $a_i$ XOR 上一个子树的 sum,可以转化成一个区间上的问题,那么就有第三步:


第三步:对于一个区间 $[L,R]$,在 $O(1)$ 求出 $\sum\limits_{i=L}^R y \text{ xor } a_i$。

由于 $a_i$ 总共有 $n$ 个,所以不能直接求前缀和。

但我们可以把区间内的每个数都 拆成二进制

然后对于拆开的二进制,每一位分别 XOR 上 $0$ 和 $1$,然后都进行一个前缀和。

这样给定一个数字 $y$,就可以把 $y$ 拆成二进制,然后对于每一位 $j$,把 $[L,R]$ 内对应的和用 $O(1)$ 时间内求出来即可。

• 所以我们要做的事就是先 sort 一下整个 $a_i$ 数组,然后在构建 Trie 的时候记录一下每个子树对应哪个区间即可。


最后要注意,由于第 $2m$ 大有可能 等于 第 $2m+1, 2m+2 …$ 大,所以要处理掉额外加上的部分。

答案除以 $2$ 即可。

代码
#include <bits/stdc++.h>
using namespace std;

const ll mod = 1e9+7;
const int maxn = 5e4+5;
const int M = 31;
struct Node {
    int cnt = 0;
    int child[2];
    int L = 1e9, R = -1;  // left and right segments
} trie[maxn * (M+1)];
int id = 1;
void insert(int idx, ll x) {
    int c = 1;
    for (int i = M; i >= 0; i--) {
        int k = ((x & (1LL<<i)) ? 1 : 0);
        if (!trie[c].child[k]) trie[c].child[k] = ++id;
        c = trie[c].child[k];
        trie[c].cnt++;
        trie[c].L = min(trie[c].L, idx);  // 维护子树对应的区间
        trie[c].R = max(trie[c].R, idx);
    }
}

ll a[maxn];
ll n,m;

// 寻找第 k 大 (本题是找到第 2m 大)
int ptr[maxn];  // 每一个 ai 对应的 Trie 上的 ptr
ll find_kmax(ll t) {
    ll res = 0;
    fill(ptr+1, ptr+n+1, 1);  // 开始都在 1
    for (int j = M; j >= 0; j--) {
        ll cnt = 0;
        // 找这一位 >= 1 的有多少个
        for (int i = 1; i <= n; i++) {
            int k = ((a[i] & (1LL<<j)) ? 1 : 0);
            k ^= 1;
            cnt += trie[trie[ptr[i]].child[k]].cnt;
        }
        int nxt;
        if (cnt >= t) {
            res |= (1LL<<j);
            nxt = 1;
        } else {
            t -= cnt;
            nxt = 0;
        }

        for (int i = 1; i <= n; i++) {
            int k = ((a[i] & (1LL<<j)) ? 1 : 0);
            k ^= nxt;  // 由这一步决定的哪一个来判断每一个 ptr 的更新位置
            ptr[i] = trie[ptr[i]].child[k];
        }
    }
    return res;
}


int sum[maxn][M+1][2];
// 计算 a[L,R] ^ x 之和
ll sum_range(ll l, ll r, ll x) {
    ll ans = 0;
    for (int j = M; j >= 0; j--) {
        int k = ((x & (1LL<<j)) ? 1 : 0);
        ll s = sum[r][j][k] - sum[l-1][j][k];
        ans += s * (1LL << j);
        ans %= mod;
    }
    return ans;
}

ll qpow(ll a, ll b) {
    ll res = 1;
    while (b) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// 获得 ai ^ aj >= x 的所有数之和
ll get_sum(ll x) {
    // 先把每个数拆成二进制,做前缀和
    for (int j = M; j >= 0; j--) {
        for (int i = 1; i <= n; i++) {
            for (int k = 0; k <= 1; k++) {
                int c = ((a[i] & (1LL<<j)) ? 1 : 0);
                c ^= k;
                sum[i][j][k] = sum[i-1][j][k] + c;
            }
        }
    }

    ll ans = 2 * x % mod;
    ll cnt = 0;  // 使用的对数
    for (int i = 1; i <= n; i++) {
        int c = 1;
        for (int j = M; j >= 0; j--) {
            int k = ((a[i] & (1LL<<j)) ? 1 : 0);
            int d = ((x & (1LL<<j)) ? 1 : 0);

            if (d == 0) {
                // 加上所有 xor 起来为 1 的部分, 然后走到 0
                int p = trie[c].child[k^1];
                if (p) {
                    ans += sum_range(trie[p].L, trie[p].R, a[i]);
                    cnt += (trie[p].R - trie[p].L + 1);
                    ans %= mod;
                }
                c = trie[c].child[k];
            } else {
                c = trie[c].child[k^1];
            }
        }
    }
    cnt += 2;
    ans = (ans - ((cnt - 2 * m) % mod * x % mod) + mod) % mod;  // 减去重复的部分

    ans = ans * qpow(2, mod-2) % mod;
    return ans;
}

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    if (!m) {
        cout << 0 << endl;
        return 0;
    }

    sort(a+1, a+n+1);
    for (int i = 1; i <= n; i++) insert(i, a[i]);

    ll x = find_kmax(2 * m);
    ll ans = get_sum(x);
    cout << ans << endl;
}

例6 CF1625D Binary Spiders

题意

给定 $n$ 个非负整数,和一个非整数 $k$。

求 $n$ 个非负整数的一个 subset,使得这个 subset 最大,并且 subset内 每两个数之间的 XOR $\geq k$。

其中,$n \leq 3 \times 10^5, k,a_i \in [0,2^{30}-1]$。

题解

首先我们设 $k$ 的最高位为 $j$。

那么我们将所有的 $a_i$ 分成三个部分:

  1. 最高位 $\geq j+1$ 的。
  2. 最高位 $= j$ 的。
  3. 最高位 $\leq j-1$ 的。

那么我们可以发现,对于 Case 2, Case 3,各只能最多选出 $1$ 个数来。

那么对于 Case 1 呢?

不如我们只考虑它们 bitmask的 prefix,即 $[j+1,30]$ 的这一部分。

我们注意到,如果两个 prefix 不同的数,一定是可以共存的。

如果两个数的 prefix 相同呢?是 有可能 可以共存的!

更准确的来说,如果两个数的 prefix 相同,我们可以把这部分抵消掉,那么问题就变成了最高位 $\leq j$ 的一些数可以最多选多少个出来共存。

这实际上就是 Case 2 和 Case 3 的并集,我们很容易发现这个并集内,最多可以选 $2$ 个数。

所以问题就简化了,我们只要根据 prefix 分类这些数,然后在每个分类中,选出最多两个数来即可。

而选择最多两个数判断是否 $\geq k$,用 01-Trie 可以轻松解决。


实现的过程:

  1. 先找出 $k$ 的最高位 $j$。
  2. 将所有的数字按照 $[j+1,30]$ 这一部分的prefix分类。
  3. 对于每一类,都构建 01-Trie 判断是否存在两个数的 XOR $\geq k$,存在就选择这两个数,否则只选一个。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+5;
const int maxm = 30;

int id = 1;  // 注意,从 1 开始
struct Node {
    int cnt = 0;
    int child[2];
} trie[maxn<<4];
void insert(int x, int delta) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;

        c = trie[c].child[k];
        trie[c].cnt += delta;
    }
}

// maximum query
int query(int x) {
    int c = 1;
    int res = 0;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        k ^= 1;
        if (trie[trie[c].child[k]].cnt) {
            c = trie[c].child[k];
            res |= (1<<j);
        } else {
            c = trie[c].child[k^1];
            // res |= (1<<j);
        }
    }
    return res;
}

int n,k;
struct Nd {
    int val, id;
} a[maxn];
int b[maxn];

vector<int> ans;
vector<pii> vec[maxn];
int main() {
    fastio;
    cin >> n >> k;
    for (int i = 1; i <= n; i++) cin >> a[i].val, a[i].id = i;
    sort(a+1, a+n+1, [](auto a, auto b) {
        return a.val > b.val;
    });

    if (k == 0) {
        cout << n << endl;
        for (int i = 1; i <= n; i++) cout << i << " ";
        cout << endl;
        return 0;
    }
    int p = 0;
    for (int j = 0; j <= 30; j++) {
        if (k & (1<<j)) p = j;
    }

    int ptr = 0;
    for (int i = 1; i <= n; i++) {
        b[i] = a[i].val;
        for (int j = 0; j <= p; j++) {
            if (b[i] & (1<<j)) b[i] ^= (1<<j);
        }
        if (!(i > 1 && b[i] == b[i-1])) {
            ++ptr;
        }
        vec[ptr].push_back({a[i].val, a[i].id});
    }

    for (int i = 1; i <= ptr; i++) {
        if (vec[i].size() >= 2) {
            bool ok = 0;
            for (int j = 0; j < vec[i].size(); j++) {
                if (!ok && query(vec[i][j].first) >= k) {
                    for (int a = 0; a < j; a++) {
                        if ((vec[i][j].first ^ vec[i][a].first) >= k) {
                            ans.push_back(vec[i][j].second);
                            ans.push_back(vec[i][a].second);
                            ok = 1;
                            break;
                        }
                    }
                }
                insert(vec[i][j].first, 1);
            }
            if (!ok) {
                ans.push_back(vec[i][0].second);
            }
            for (int j = 0; j < vec[i].size(); j++) {
                insert(vec[i][j].first, -1);
            }
        } else if (vec[i].size() == 1) {
            ans.push_back(vec[i][0].second);
        }
    }


    if (ans.size() <= 1) cout << -1 << endl;
    else {
        cout << ans.size() << endl;
        for (int j : ans) cout << j << " ";
        cout << endl;
    }
}

参考链接

  1. https://oi-wiki.org/string/trie/#_6
  2. https://oi-wiki.org/ds/persistent-trie/