Trie 和 01 Trie
Contents
介绍
01-Trie 指将整数拆成二进制,然后将二进制以字符串的形式储存在 Trie 中。
Trie可以解决以下问题:
-
给定一些数 $a_1,a_2,…,a_n$,给定 $x$,求 $a_i$,使得 $a_i \text{ xor } x$ 最大。
-
维护异或和:给定一些数,支持全体加一,插入数字,删除数字,求全体异或和等操作。
Trie可以将数字从低位到高位储存,也可以反过来,根据具体题目而定。
为了方便,我们在写 01Trie 的时候都会把每个数补成同样的位数。
• 注意 id = 1
的是 Root,所以 id
要从 $1$ 开始。
维护最大XOR
指第一个例子,这里我们的 01Trie 将会 从高位到低位 储存。
在给定一个查询 $x$ 时,我们从 $x$ 的高位开始看,设当前到了第 $i$ 位,那么我们看第 $i$ 位的bit $a_i$,然后判断一下第 $i$ 位的值为 a[i] ^ 1
的数字是否存在即可,如果存在,往那个方向走,否则往另外一个方向走。
板子
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
int id = 1; // 注意,从 1 开始
struct Node {
int cnt = 0;
int child[2];
} trie[maxn<<4];
void insert(int x) {
int c = 1;
for (int j = maxm; j >= 0; j--) {
int k = 0;
if (x & (1<<j)) k = 1;
if (!trie[c].child[k]) trie[c].child[k] = ++id;
c = trie[c].child[k];
trie[c].cnt++;
}
}
int query(int x) {
int c = 1;
int res = 0;
for (int j = maxm; j >= 0; j--) {
int k = 0;
if (x & (1<<j)) k = 1;
k ^= 1;
if (trie[c].child[k]) {
c = trie[c].child[k];
res |= (1<<j);
} else {
c = trie[c].child[k^1];
}
}
return res;
}
维护XOR和
指第二个例子,这里我们的 01Trie 从低位到高位储存。
要维护异或和,我们只需要知道每一位上 $0$ 和 $1$ 个数的奇偶性即可。
对于每一个节点,我们记录三个量:
- 两个子节点的编号
w
:指当前这个节点,到它的parent这条边 $(p,u)$ 上,被经过的次数。每有一个数字在插入过程中经过这个边,这个w
就会加一。val
:指以当前节点为根,它子树内包含的所有数字的 XOR和。
所以维护信息的时候就有:
void push_up(int cur) {
trie[cur].val = trie[cur].w = 0; // 先清空
int c0 = trie[cur].child[0], c1 = trie[cur].child[1];
// 更新 w
if (c0) trie[cur].w += trie[c0].w;
if (c1) trie[cur].w += trie[c1].w;
// 更新 val
if (c0) {
trie[cur].val ^= (trie[c0].val << 1);
}
if (c1) {
trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1)); // 如果 c1 的 w 为奇数,就提供了 1 的贡献
}
}
解释: 注意是从低位到高位储存的,所以子树内储存的只是数字的二进制的一部分。
所以要把 val
进行左移。
也就是说,实际上只有 根节点 的 val
是真正有意义的,因为只有根节点才记录的是完整的数字的 XOR和。
插入和删除就不提了,看代码即可。注意插入删除合并起来,可以达成单点修改的效果。
最后说一下全局加一。
全局加一指的是将整个01Trie维护的所有数字都加一。
在二进制下的加一,相当于从低位到高位找第一个出现的 $0$,然后将其变成 $1$。
最后,比这一位低的所有位都将会从 $1$ 变成 $0$。
// add 1 to all numbers maintained by cur
void addall(int cur) {
swap(trie[cur].child[0], trie[cur].child[1]);
if (trie[cur].child[0]) {
addall(trie[cur].child[0]);
}
push_up(cur);
}
板子
const int M = 21; // 最大深度
int id = 0;
struct Node {
int w; // 到 parent 这条边上的个数
int val; // subtree 内的XOR和
int child[2];
} trie[maxn*(M+1)];
void push_up(int cur) {
trie[cur].val = trie[cur].w = 0; // 先清空
int c0 = trie[cur].child[0], c1 = trie[cur].child[1];
// 更新 w
if (c0) trie[cur].w += trie[c0].w;
if (c1) trie[cur].w += trie[c1].w;
// 更新 val
if (c0) {
trie[cur].val ^= (trie[c0].val << 1);
}
if (c1) {
trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1)); // 如果 c1 的 w 为奇数,就提供了 1 的贡献
}
}
// cur: index, x: 当前的数字, dep: 深度
void insert(int& cur, int x, int dep) {
if (!cur) cur = ++id;
if (dep >= M) {
trie[cur].w++; // w += 1
return;
}
int c = x & 1;
insert(trie[cur].child[c], (x>>1), dep+1);
push_up(cur);
}
void del(int& cur, int x, int dep) {
if (!cur) cur = ++id;
if (dep >= M) {
trie[cur].w--; // w -= 1
return;
}
int c = x & 1;
insert(trie[cur].child[c], (x>>1), dep+1);
push_up(cur);
}
// add 1 to all numbers maintained by cur
void addall(int cur) {
swap(trie[cur].child[0], trie[cur].child[1]);
if (trie[cur].child[0]) {
addall(trie[cur].child[0]);
}
push_up(cur);
}
可持久化01-Trie
和主席树相似,可持久化01-Trie主要是为了让我们获得任何一个区间 $[L,R]$ 内组成的 01-Trie。
使用 Node 的 cnt
来决定 01-Trie 的内容。
板子
const int maxn = 6e5+5;
const int M = 28;
struct Node {
int cnt;
int child[2];
} trie[maxn * (M+1)];
int n, q, root[maxn], a[maxn], id;
void insert(int pre, int cur, int x) {
for (int i = M; i >= 0; i--) {
int c = 0;
if (x & (1 << i)) c = 1;
trie[cur].cnt = trie[pre].cnt + 1; // cnt 加一
if (!trie[cur].child[c]) trie[cur].child[c] = ++id; // 新建节点
trie[cur].child[c^1] = trie[pre].child[c^1]; // 复制另外一个子节点
cur = trie[cur].child[c];
pre = trie[pre].child[c];
}
trie[cur].cnt = trie[pre].cnt + 1;
}
// find the maximum value for a ^ x (a is in [pre, cur])
int query(int pre, int cur, int x) {
int res = 0;
for (int i = M; i >= 0; i--) {
int c = 0;
if (x & (1 << i)) c = 1;
c ^= 1;
if (trie[trie[cur].child[c]].cnt - trie[trie[pre].child[c]].cnt) { // cnt > 0
cur = trie[cur].child[c];
pre = trie[pre].child[c];
res |= (1<<i);
} else {
cur = trie[cur].child[c^1];
pre = trie[pre].child[c^1];
}
}
return res;
}
int main() {
for (int i = 1; i <= n; i++) {
cin >> a[i]; sum[i] = sum[i-1] ^ a[i];
if (!root[i]) root[i] = ++id;
insert(root[i-1], root[i], sum[i]);
}
// 查询 [L,R] -> (l-1,r)
int res = query(root[max(0,l-1)], root[r], x);
}
例题
例1 洛谷P4551 最长异或路径
题意
给定一个 $n$ 个节点的树,边上有权值。
寻找树上的两个节点,使得路径上的边权 XOR和 最大,输出这个最大值。
其中,$n \leq 10^5$。
题解
常见套路:设 $f_x$ 为从 $1$(根节点)到 $x$ 的路径上的 XOR和。
那么 $(u,v)$ 路径上的 XOR和 就等于 $f_u \text{ xor } f_v$。因为 $1$ 到 $LCA(u,v)$ 的部分被抵消掉了。
所以就相当于建立一个 01Trie,储存所有的 $f_u$,然后对于每个 $u$,都询问一下最大的XOR即可。
代码
#include <bits/stdc++.h>
using namespace std;
const ll mod = 1e9+7;
const int maxn = 1e5+5;
const int maxm = 30;
struct Edge {
int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, n;
void addEdge(int u, int v, int w) {
Edge e = {v, head[u], w};
edges[ecnt] = e;
head[u] = ecnt++;
}
int dp[maxn];
void dfs(int u, int p) {
for (int e = head[u]; e; e = edges[e].nxt) {
int to = edges[e].to;
if (to == p) continue;
dp[to] = dp[u] ^ edges[e].w;
dfs(to, u);
}
}
int id = 1;
struct Node {
int cnt = 0;
int child[2];
} trie[maxn<<4];
void insert(int x) {
int c = 1;
for (int j = maxm; j >= 0; j--) {
int k = 0;
if (x & (1<<j)) k = 1;
if (!trie[c].child[k]) trie[c].child[k] = ++id;
c = trie[c].child[k];
trie[c].cnt++;
}
}
int query(int x) {
int c = 1;
int res = 0;
for (int j = maxm; j >= 0; j--) {
int k = 0;
if (x & (1<<j)) k = 1;
k ^= 1;
if (trie[c].child[k]) {
c = trie[c].child[k];
res |= (1<<j);
} else {
c = trie[c].child[k^1];
}
}
return res;
}
int main() {
fastio;
cin >> n;
for (int i = 1; i <= n-1; i++) {
int u,v,w; cin >> u >> v >> w;
addEdge(u,v,w); addEdge(v,u,w);
}
dfs(1, 0);
for (int i = 1; i <= n; i++) insert(dp[i]);
int ans = 0;
for (int i = 1; i <= n; i++) {
ans = max(ans, query(dp[i]));
}
cout << ans << endl;
}
例2 CF817E Choosing The Commander
题意
给定 $q$ 个询问,共有3种:
$1 ~ p$:将 $p$ 加入集合。
$2 ~ p$:将 $p$ 从集合中删除。
$3 ~ p ~ l$:询问集合中有多少个元素在 XOR $p$ 之后小于 $l$。
其中,$q \leq 10^5$。
题解
在 01-Trie 每个节点处加上一个 count 就可以了。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 27;
int Q, id = 1;
struct Node {
int child[2];
int cnt = 0;
} trie[maxn<<3];
void insert(int x) {
int c = 1;
for (int j = maxm; j >= 0; j--) {
int k = 0;
if (x & (1<<j)) k = 1;
if (!trie[c].child[k]) trie[c].child[k] = ++id;
c = trie[c].child[k];
trie[c].cnt++;
}
}
void del(int x) {
int c = 1;
for (int j = maxm; j >= 0; j--) {
int k = 0;
if (x & (1<<j)) k = 1;
if (!trie[c].child[k]) trie[c].child[k] = ++id;
c = trie[c].child[k];
trie[c].cnt--;
}
}
int query(int p, int l) {
int c = 1;
int res = 0;
for (int j = maxm; j >= 0; j--) {
if (l & (1<<j)) { // 如果这一位有,说明要往这一位为 1 的方向走
int k = ((p & (1<<j)) > 0);
res += trie[trie[c].child[k]].cnt;
if (trie[trie[c].child[k^1]].cnt) c = trie[c].child[k^1];
else break;
} else { // 如果这一位没有,就往 0 的方向走
int k = ((p & (1<<j)) > 0);
if (trie[trie[c].child[k]].cnt) c = trie[c].child[k];
else break;
}
}
return res;
}
int main() {
fastio;
cin >> Q;
while (Q--) {
int op, p, l;
cin >> op >> p;
if (op == 1) {
insert(p);
} else if (op == 2) {
del(p);
} else {
cin >> l;
int res = query(p,l);
cout << res << "\n";
}
}
}
例3 洛谷P6018 [Ynoi2010] Fusion tree
题意
给定一棵 $n$ 个节点的树,每个节点 $i$ 上有权值,初始权值为 $a_i$。
给出 $m$ 个询问,询问的类型有3种:
$1 ~ x$:将所有与节点 $x$ 距离为 $1$ 的节点的权值加一。
$2 ~ x ~ v$:将节点 $x$ 的权值减去 $v$。
$3 ~ x$:询问所有与节点 $x$ 距离为 $1$ 的节点权值的XOR和。
其中,$n,m \leq 5 \times 10^5$,初始权值 $\leq 10^5$。
题解
这里是 01-Trie 用来维护XOR和。
我们可以对于每一个节点,都把它的孩子(距离为 $1$)的XOR和 维护在一个 01Trie里面。
因为 01-Trie 是动态开点的,所以总共复杂度也就 $n\log(10^5)$。
然后因为每个节点都只有一个 parent,就单独维护即可。
然后对于操作 $1 ~ x$,我们需要做的几件事:
- 单独更新 $x$ 的parent。
- 给 $x$ 打上一个懒标记,代表它要给它的孩子加上 $lazy$ 的值。
- 在 $x$ 维护的 01-Trie 上进行全体加一操作。
然后在每次操作之前,都先检查一下 $parent$ 的懒标记,将其下放。
当然这里的下放不太一样,一个节点可能有非常多个孩子,所以我们选择给每一个节点 $x$ 都定义一个值 add[x]
,代表它已经从parent那里得到了多少下放的懒标记。
然后就给 $x$ 的权值加上 (lazy[p] - add[x])
。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
const int M = 21; // 最大深度
int id = 0;
struct Node {
int w; // 到 parent 这条边上的个数
int val; // subtree 内的XOR和
int child[2];
} trie[maxn*(M+1)];
int root[maxn], lazy[maxn], add[maxn];
void push_up(int cur) {
trie[cur].val = trie[cur].w = 0; // 先清空
int c0 = trie[cur].child[0], c1 = trie[cur].child[1];
// 更新 w
if (c0) trie[cur].w += trie[c0].w;
if (c1) trie[cur].w += trie[c1].w;
// 更新 val
if (c0) {
trie[cur].val ^= (trie[c0].val << 1);
}
if (c1) {
trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1)); // 如果 c1 的 w 为奇数,就提供了 1 的贡献
}
}
// cur: index, x: 当前的数字, dep: 深度
void insert(int& cur, int x, int dep) {
if (!cur) cur = ++id;
if (dep >= M) {
trie[cur].w++;
return;
}
int c = x & 1;
insert(trie[cur].child[c], (x>>1), dep+1);
push_up(cur);
}
void del(int& cur, int x, int dep) {
if (!cur) cur = ++id;
if (dep >= M) {
trie[cur].w--;
return;
}
int c = x & 1;
insert(trie[cur].child[c], (x>>1), dep+1);
push_up(cur);
}
// add 1 to all numbers maintained by cur
void addall(int cur) {
swap(trie[cur].child[0], trie[cur].child[1]);
if (trie[cur].child[0]) {
addall(trie[cur].child[0]);
}
push_up(cur);
}
int n, m, a[maxn], par[maxn];
struct Edge {
int to, nxt;
} edges[maxn<<1];
int head[maxn], ecnt = 1;
void addEdge(int u, int v) {
Edge e = {v, head[u]};
head[u] = ecnt;
edges[ecnt++] = e;
}
void dfs(int u) {
for (int e = head[u]; e; e = edges[e].nxt) {
int to = edges[e].to;
if (to == par[u]) continue;
par[to] = u;
insert(root[u], a[to], 0);
dfs(to);
}
}
// modify u to have the value val
void modify(int u, int val) {
// a[u] is previous value
int p = par[u];
if (p) {
del(root[p], a[u], 0);
insert(root[p], val, 0);
}
a[u] = val;
}
int main() {
fastio;
cin >> n >> m;
for (int i = 1; i <= n-1; i++) {
int u,v; cin >> u >> v;
addEdge(u,v); addEdge(v,u);
}
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
dfs(1);
while (m--) {
int op, x; cin >> op >> x;
if (op == 1) {
addall(root[x]);
int p = par[x];
if (p) {
a[p] = a[p] + lazy[par[p]] - add[p];
add[p] = lazy[par[p]];
modify(par[x], a[p] + 1);
}
lazy[x]++;
}
if (op == 2) {
int v; cin >> v;
a[x] = a[x] + lazy[par[x]] - add[x];
add[x] = lazy[par[x]];
modify(x, a[x] - v);
}
if (op == 3) {
int res = trie[root[x]].val;
int p = par[x];
if (p) {
a[p] = a[p] + lazy[par[p]] - add[p];
add[p] = lazy[par[p]];
res ^= a[p];
}
cout << res << "\n";
}
}
}
例4 洛谷P4735 最大异或和
题意
给定一个非负整数序列 $a_1,a_2,…,a_n$。
现给定 $m$ 个操作,有以下两种操作类型:
$A ~ x$:在序列末尾添加一个数 $x$。
$Q ~ L ~ R ~ x$:输出一个位置 $p$,满足 $p \in [L,R]$,使得 $a_p \bigoplus a_{p+1} \bigoplus … \bigoplus a_n \bigoplus x$ 最大。
其中,$n,m \leq 3 \times 10^5, a_i \in [0,10^7]$。
题解
可持久化 01-Trie 的模版题。
首先,先考虑前缀XOR数组 $s_i = a_1 \bigoplus a_2 \bigoplus … \bigoplus a_i$。
那么
$$a_p \bigoplus a_{p+1} \bigoplus … \bigoplus a_n \bigoplus x = (s_n \bigoplus s_{p-1}) \bigoplus x$$
所以现在问题就变成:
每次询问 $Q ~ L ~ R ~ x$,输出一个位置 $p$,满足 $p \in [L-1,R-1]$,使得 $(s_n \bigoplus s_{p-1}) \bigoplus x$ 最大。
那就很简单了,令 $y = s_n \bigoplus x$,剩下的就是在一个区间 $[L-1,R-1]$ 内找出一个元素 $s_p$ 使得 $s_p \bigoplus y$ 最大。
对于每一个区间都可以用 01-Trie 解决的问题,就是可持久化 01-Trie了。
最后需要注意一下这个样例:
4 1
2 4 8 16
Q 1 4 1
答案应为 $31$,但我输出 $29$。
这是因为我们要单独处理一下 $s_0$ 的情况。换而言之我们要将 $s_0 = 0$ 插入到 root[1]
当中。
所以 root[1]
需要包含 $2$ 个元素:$s_0,s_1$。
对于一个版本插入 $2$ 个元素,只要在第二次插入时使用 insert(root[1], root[1], 0)
即可。
最后再注意一下一些边界条件如 $[L-1,R-1]=[0,0]$ 即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 6e5+5;
const int M = 28;
struct Node {
int cnt;
int child[2];
} trie[maxn * (M+1)];
int sum[maxn]; // sum[n] 代表 a[1] ^ a[2] ... ^ a[n]
int n, q, root[maxn], a[maxn], id;
void insert(int pre, int cur, int x) {
for (int i = M; i >= 0; i--) {
int c = 0;
if (x & (1 << i)) c = 1;
trie[cur].cnt = trie[pre].cnt + 1; // cnt 加一
if (!trie[cur].child[c]) trie[cur].child[c] = ++id; // 新建节点
trie[cur].child[c^1] = trie[pre].child[c^1]; // 复制另外一个子节点
cur = trie[cur].child[c];
pre = trie[pre].child[c];
}
trie[cur].cnt = trie[pre].cnt + 1;
}
// find the maximum value for a ^ x (a is in [pre, cur])
int query(int pre, int cur, int x) {
int res = 0;
for (int i = M; i >= 0; i--) {
int c = 0;
if (x & (1 << i)) c = 1;
c ^= 1;
if (trie[trie[cur].child[c]].cnt - trie[trie[pre].child[c]].cnt) { // cnt > 0
cur = trie[cur].child[c];
pre = trie[pre].child[c];
res |= (1<<i);
} else {
cur = trie[cur].child[c^1];
pre = trie[pre].child[c^1];
}
}
return res;
}
int main() {
fastio;
cin >> n >> q;
for (int i = 1; i <= n; i++) {
cin >> a[i]; sum[i] = sum[i-1] ^ a[i];
if (!root[i]) root[i] = ++id;
insert(root[i-1], root[i], sum[i]);
if (i == 1) insert(root[1], root[1], 0);
}
while (q--) {
char op; cin >> op;
if (op == 'A') {
int x; cin >> x;
a[++n] = x;
sum[n] = sum[n-1] ^ a[n];
if (!root[n]) root[n] = ++id;
insert(root[n-1], root[n], sum[n]);
} else {
int l,r,x; cin >> l >> r >> x;
l--; r--;
x ^= (sum[n]);
if (l == 0 && r == 0) {
cout << x << "\n"; continue;
}
int res = query(root[max(0,l-1)], root[r], x);
cout << res << "\n";
}
}
}
例5 CF241B Friends
题意
给定一个长度为 $n$ 的非负整数数组 $a_1,a_2,…,a_n$。
现在选择 $m$ 个pair $(i,j)$,使得这 $m$ 个pair 对应的 $a_i \text{ xor } a_j$ 值的 sum 最大。
其中,$n \leq 5 \times 10^4, a_i \leq 10^9$,答案对 $10^9+7$ 取模。
• $(i,j)$ 和 $(j,i)$ 算作一个 pair。
题解
神仙题。
本题要求的是前 $m$ 大的XOR pair的和。
我们先列出本题的过程:
- 求出第 $m$ 大。
- 求出 XOR 小于等于 $x$ 的所有pair的和。
- 对于一个区间 $[L,R]$,在 $O(1)$ 求出 $\sum\limits_{i=L}^R y \text{ xor } a_i$。
第一步:求第 $m$ 大
• 虽然 $(i,j)$ 和 $(j,i)$ 算作一个 pair,但我们只要把所有东西都乘以 $2$ 就行了。所以我们求第 $2m$ 大。
最简单粗暴的方法当然是二分最终的结果 $x$,然后对于每一个 $a_i$,判断一下在 Trie 中有多少个数字和它 XOR 起来 $\geq x$,把这些 count 累加起来然后查看 count >= 2m
与否,来调整 $x$ 的值。
这样的时间复杂度是 $O(n\log^2 10^9)$,本题可过,但是另外一个版本不可过。
于是我们思考一下能不能直接在 Trie 上模拟这个二分的过程来构建最终的结果 $x$。
发现是可以的,我们从高位到低位,枚举每一位取 $1$ 还是 取 $0$。
然后我们对于每一个 $a_i$ 都判断一下 count 的总和,如果 count >= 2m
,则说明这一位可以取 $1$。否则的话需要取 $0$,同时问题变成寻找第 $2m - count$ 大(有点类似于主席树寻找第 $k$ 大的思路)。
• 这个过程是在所有 $a_i$ 上进行的,所以我们维护 $n$ 个指针 ptr[]
代表 $a_i$ 目前在 Trie 的什么位置。
复杂度为 $O(n\log 10^9)$。
第二步:求出 XOR 小于等于 $x$ 的所有pair的和
如果这个问题不是求出和,而是求出个数的话,就是例 $2$ 的问题了。
可惜不是,那怎么处理?
一样,我们从每个 $a_i$ 开始进行 check,对于每一个 Trie 上的节点,我们可以维护一个 cnt[31]
,其中 cnt[j]
代表当前Trie节点的子树里有多少个数字满足第 $j$ 位为 $1$。
这样当然可以算,但是空间复杂度 $O(n\log^2 10^9)$,本题或许可过,但另外一个版本仍然不可过。
我们这里给出一个非常优秀的结论:
如果 $a_i$ 是 sorted 的,那么 Trie 上任意一个节点的子树 对应 $a_i$ 上的一段连续区间。
这个也不难理解,因为 Trie 的一个子树对应的是拥有一个 “特定前缀” (如 “1101”)的所有数字的集合。
有了这个结论以后,我们在枚举 $a_i$ 时,假设我们目前到了第 $j$ 位,就分两种情况:
- $x$ 的第 $j$ 位为 $1$:那么我们就往 $1$ 的方向走。
- $x$ 的第 $j$ 位为 $0$:我们把 $a_i$ XOR 上 $1$ 的那个子树,求出sum,给答案加上,然后往 $0$ 的方向走。
注意到 $a_i$ XOR 上一个子树的 sum,可以转化成一个区间上的问题,那么就有第三步:
第三步:对于一个区间 $[L,R]$,在 $O(1)$ 求出 $\sum\limits_{i=L}^R y \text{ xor } a_i$。
由于 $a_i$ 总共有 $n$ 个,所以不能直接求前缀和。
但我们可以把区间内的每个数都 拆成二进制。
然后对于拆开的二进制,每一位分别 XOR 上 $0$ 和 $1$,然后都进行一个前缀和。
这样给定一个数字 $y$,就可以把 $y$ 拆成二进制,然后对于每一位 $j$,把 $[L,R]$ 内对应的和用 $O(1)$ 时间内求出来即可。
• 所以我们要做的事就是先 sort 一下整个 $a_i$ 数组,然后在构建 Trie 的时候记录一下每个子树对应哪个区间即可。
最后要注意,由于第 $2m$ 大有可能 等于 第 $2m+1, 2m+2 …$ 大,所以要处理掉额外加上的部分。
答案除以 $2$ 即可。
代码
#include <bits/stdc++.h>
using namespace std;
const ll mod = 1e9+7;
const int maxn = 5e4+5;
const int M = 31;
struct Node {
int cnt = 0;
int child[2];
int L = 1e9, R = -1; // left and right segments
} trie[maxn * (M+1)];
int id = 1;
void insert(int idx, ll x) {
int c = 1;
for (int i = M; i >= 0; i--) {
int k = ((x & (1LL<<i)) ? 1 : 0);
if (!trie[c].child[k]) trie[c].child[k] = ++id;
c = trie[c].child[k];
trie[c].cnt++;
trie[c].L = min(trie[c].L, idx); // 维护子树对应的区间
trie[c].R = max(trie[c].R, idx);
}
}
ll a[maxn];
ll n,m;
// 寻找第 k 大 (本题是找到第 2m 大)
int ptr[maxn]; // 每一个 ai 对应的 Trie 上的 ptr
ll find_kmax(ll t) {
ll res = 0;
fill(ptr+1, ptr+n+1, 1); // 开始都在 1
for (int j = M; j >= 0; j--) {
ll cnt = 0;
// 找这一位 >= 1 的有多少个
for (int i = 1; i <= n; i++) {
int k = ((a[i] & (1LL<<j)) ? 1 : 0);
k ^= 1;
cnt += trie[trie[ptr[i]].child[k]].cnt;
}
int nxt;
if (cnt >= t) {
res |= (1LL<<j);
nxt = 1;
} else {
t -= cnt;
nxt = 0;
}
for (int i = 1; i <= n; i++) {
int k = ((a[i] & (1LL<<j)) ? 1 : 0);
k ^= nxt; // 由这一步决定的哪一个来判断每一个 ptr 的更新位置
ptr[i] = trie[ptr[i]].child[k];
}
}
return res;
}
int sum[maxn][M+1][2];
// 计算 a[L,R] ^ x 之和
ll sum_range(ll l, ll r, ll x) {
ll ans = 0;
for (int j = M; j >= 0; j--) {
int k = ((x & (1LL<<j)) ? 1 : 0);
ll s = sum[r][j][k] - sum[l-1][j][k];
ans += s * (1LL << j);
ans %= mod;
}
return ans;
}
ll qpow(ll a, ll b) {
ll res = 1;
while (b) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
// 获得 ai ^ aj >= x 的所有数之和
ll get_sum(ll x) {
// 先把每个数拆成二进制,做前缀和
for (int j = M; j >= 0; j--) {
for (int i = 1; i <= n; i++) {
for (int k = 0; k <= 1; k++) {
int c = ((a[i] & (1LL<<j)) ? 1 : 0);
c ^= k;
sum[i][j][k] = sum[i-1][j][k] + c;
}
}
}
ll ans = 2 * x % mod;
ll cnt = 0; // 使用的对数
for (int i = 1; i <= n; i++) {
int c = 1;
for (int j = M; j >= 0; j--) {
int k = ((a[i] & (1LL<<j)) ? 1 : 0);
int d = ((x & (1LL<<j)) ? 1 : 0);
if (d == 0) {
// 加上所有 xor 起来为 1 的部分, 然后走到 0
int p = trie[c].child[k^1];
if (p) {
ans += sum_range(trie[p].L, trie[p].R, a[i]);
cnt += (trie[p].R - trie[p].L + 1);
ans %= mod;
}
c = trie[c].child[k];
} else {
c = trie[c].child[k^1];
}
}
}
cnt += 2;
ans = (ans - ((cnt - 2 * m) % mod * x % mod) + mod) % mod; // 减去重复的部分
ans = ans * qpow(2, mod-2) % mod;
return ans;
}
int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
if (!m) {
cout << 0 << endl;
return 0;
}
sort(a+1, a+n+1);
for (int i = 1; i <= n; i++) insert(i, a[i]);
ll x = find_kmax(2 * m);
ll ans = get_sum(x);
cout << ans << endl;
}
例6 CF1625D Binary Spiders
题意
给定 $n$ 个非负整数,和一个非整数 $k$。
求 $n$ 个非负整数的一个 subset,使得这个 subset 最大,并且 subset内 每两个数之间的 XOR $\geq k$。
其中,$n \leq 3 \times 10^5, k,a_i \in [0,2^{30}-1]$。
题解
首先我们设 $k$ 的最高位为 $j$。
那么我们将所有的 $a_i$ 分成三个部分:
- 最高位 $\geq j+1$ 的。
- 最高位 $= j$ 的。
- 最高位 $\leq j-1$ 的。
那么我们可以发现,对于 Case 2, Case 3,各只能最多选出 $1$ 个数来。
那么对于 Case 1 呢?
不如我们只考虑它们 bitmask的 prefix,即 $[j+1,30]$ 的这一部分。
我们注意到,如果两个 prefix 不同的数,一定是可以共存的。
如果两个数的 prefix 相同呢?是 有可能 可以共存的!
更准确的来说,如果两个数的 prefix 相同,我们可以把这部分抵消掉,那么问题就变成了最高位 $\leq j$ 的一些数可以最多选多少个出来共存。
这实际上就是 Case 2 和 Case 3 的并集,我们很容易发现这个并集内,最多可以选 $2$ 个数。
所以问题就简化了,我们只要根据 prefix 分类这些数,然后在每个分类中,选出最多两个数来即可。
而选择最多两个数判断是否 $\geq k$,用 01-Trie 可以轻松解决。
实现的过程:
- 先找出 $k$ 的最高位 $j$。
- 将所有的数字按照 $[j+1,30]$ 这一部分的prefix分类。
- 对于每一类,都构建 01-Trie 判断是否存在两个数的 XOR $\geq k$,存在就选择这两个数,否则只选一个。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+5;
const int maxm = 30;
int id = 1; // 注意,从 1 开始
struct Node {
int cnt = 0;
int child[2];
} trie[maxn<<4];
void insert(int x, int delta) {
int c = 1;
for (int j = maxm; j >= 0; j--) {
int k = 0;
if (x & (1<<j)) k = 1;
if (!trie[c].child[k]) trie[c].child[k] = ++id;
c = trie[c].child[k];
trie[c].cnt += delta;
}
}
// maximum query
int query(int x) {
int c = 1;
int res = 0;
for (int j = maxm; j >= 0; j--) {
int k = 0;
if (x & (1<<j)) k = 1;
k ^= 1;
if (trie[trie[c].child[k]].cnt) {
c = trie[c].child[k];
res |= (1<<j);
} else {
c = trie[c].child[k^1];
// res |= (1<<j);
}
}
return res;
}
int n,k;
struct Nd {
int val, id;
} a[maxn];
int b[maxn];
vector<int> ans;
vector<pii> vec[maxn];
int main() {
fastio;
cin >> n >> k;
for (int i = 1; i <= n; i++) cin >> a[i].val, a[i].id = i;
sort(a+1, a+n+1, [](auto a, auto b) {
return a.val > b.val;
});
if (k == 0) {
cout << n << endl;
for (int i = 1; i <= n; i++) cout << i << " ";
cout << endl;
return 0;
}
int p = 0;
for (int j = 0; j <= 30; j++) {
if (k & (1<<j)) p = j;
}
int ptr = 0;
for (int i = 1; i <= n; i++) {
b[i] = a[i].val;
for (int j = 0; j <= p; j++) {
if (b[i] & (1<<j)) b[i] ^= (1<<j);
}
if (!(i > 1 && b[i] == b[i-1])) {
++ptr;
}
vec[ptr].push_back({a[i].val, a[i].id});
}
for (int i = 1; i <= ptr; i++) {
if (vec[i].size() >= 2) {
bool ok = 0;
for (int j = 0; j < vec[i].size(); j++) {
if (!ok && query(vec[i][j].first) >= k) {
for (int a = 0; a < j; a++) {
if ((vec[i][j].first ^ vec[i][a].first) >= k) {
ans.push_back(vec[i][j].second);
ans.push_back(vec[i][a].second);
ok = 1;
break;
}
}
}
insert(vec[i][j].first, 1);
}
if (!ok) {
ans.push_back(vec[i][0].second);
}
for (int j = 0; j < vec[i].size(); j++) {
insert(vec[i][j].first, -1);
}
} else if (vec[i].size() == 1) {
ans.push_back(vec[i][0].second);
}
}
if (ans.size() <= 1) cout << -1 << endl;
else {
cout << ans.size() << endl;
for (int j : ans) cout << j << " ";
cout << endl;
}
}