简介

树状数组可以看做简化版的线段树,它进行单点修改和区间查询的常数是比线段树更优的。

定义:对于原数组 a[n]a[n],树状数组 c[n]c[n] 是一个等长的数组,并且对于任意 c[i]c[i] 表示以 a[i]a[i] 结尾且长度为 lowbit(i)\text{lowbit}(i) 的区间和。

例如,对于 c[i],i=(10100)2c[i],i=(10100)_2 它能覆盖的区间是 (10000,10100](10000,10100],它的直接子结点就会是 c[10010],c[10011]c[10010],c[10011];间接子结点就是 1000110001,这是因为 c[10001]c[10001] 是包含于 c[10010]c[10010] 的。

因而可以看出,计算父结点的方式就是 i+lowbit(i)i+\text{lowbit}(i),计算出与当前区间无缝连接的上一个区间的方式就是 ilowbit(i)i-\text{lowbit}(i)

模板题

已知一个数列,你需要进行下面两种操作:

  • 将某一个数加上 xx

  • 求出某区间每一个数的和

题目链接:P3374

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <cstdio>
#include <algorithm>
using namespace std;

const int N = 500010;
int tr[N], n, m;

#define lowbit(x) ((x)&(-x))

void add(int p, int v) {
for (; p < N; p += lowbit(p))
tr[p] += v;
}

int query(int p) {
int res = 0;
for (; p; p -= lowbit(p))
res += tr[p];
return res;
}

int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
int v; scanf("%d", &v);
add(i, v);
}

while (m--) {
int op, x, y;
scanf("%d%d%d", &op, &x, &y);
if (op == 1) add(x, y);
else printf("%d\n", query(y) - query(x-1));
}
return 0;
}

逆序对

定义:i<ji<jai>aja_i>a_j 就称为一个逆序对,统计逆序对数目。

题目链接:P1908

本题可以用归并排序那样的分治算法,并且它更好,但是这里我们用树状数组来解决这个问题。

首先注意到值域比较大,所以需要离散化。当枚举到 aia_i 时,我们需要知道前面有多少个数大于 aia_i,如果我们用树状数组来统计每个数字出现的次数,也就是求一下 [ai+1,n][a_i+1,n] 的区间和,其中 nn 是离散化后的最大值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;

typedef long long LL;
const int N = 500010;
int n, a[N];
LL tr[N];
vector<int> nums;

#define lowbit(x) ((x)&(-x))

void add(int p, int v) {
for (; p < N; p += lowbit(p))
tr[p] += v;
}

LL query(int p) {
LL res = 0;
for (; p; p -= lowbit(p))
res += tr[p];
return res;
}

LL query(int l, int r) {
return query(r) - query(l-1);
}

int find(int x) {
return lower_bound(nums.begin(), nums.end(), x) - nums.begin() + 1;
}

int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
nums.push_back(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
LL res = 0;
for (int i = 1; i <= n; i++) {
int t = find(a[i]);
res += query(t+1, nums.size());
add(t, 1);
}
printf("%lld\n", res);
return 0;
}

The Battle of Chibi (LIS)

简单题意:给 TT 组数据,长度为 nn 的数列 aa 中,找出长度为 mm 的严格上升子序列的个数,答案对 1e9+7 取模。

题目链接:UVA12983UVA12983(Luogu)

这是一道 DP 题,但是可以用树状数组来加速。

  • 状态表示 f(i,j)f(i,j):长度为 iiaja_j 结尾的最长上升子序列的个数。

  • 状态转移:

    f(i,j)=ak<aj,k<jf(i1,k)f(i,j)=\sum_{a_k<a_j, k<j} f(i-1,k)

    只要我们在循环到 jj 时把之前的所有 f(i1,k)f(i-1,k)ak<aja_k<a_j 的值累加起来即可,这可以用树状数组优化,树状数组的索引是 aia_i 的值,值是每个 f(i1,k)f(i-1,k) 每次求的都是 [0,ai1][0,a_i-1] 间所有满足要求的值之和。

    由于牵扯到值域的问题,这里就需要用到离散化。

  • 初始化:f(1,j)=1f(1,j)=1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;

const int N = 1010, mod = 1e9+7;
vector<int> nums;
int tr[N], n, m;
int a[N], f[N][N];
#define lowbit(x) ((x)&(-x))

inline void add(int p, int v) {
for (; p < N; p += lowbit(p)) {
tr[p] = (tr[p] + v) % mod;
}
}

inline int query(int p) {
int res = 0;
for (; p; p -= lowbit(p))
res = (res + tr[p]) % mod;
return res;
}

int find(int x) {
return lower_bound(nums.begin(), nums.end(), x) - nums.begin() + 1;
}

int solve() {
// clear f
for (int j = 1; j <= n; j++) f[1][j] = 1;
for (int i = 2; i <= m; i++) {
memset(tr, 0, sizeof tr);
for (int j = 1; j <= n; j++) {
f[i][j] = query(a[j]-1);
add(a[j], f[i-1][j]);
}
}
int res = 0;
for (int j = 1; j <= n; j++) res = (res + f[m][j]) % mod;
return res;
}

int main() {
int T; scanf("%d", &T);
for (int C = 1; C <= T; C++) {
nums.clear();
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
nums.push_back(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
for (int i = 1; i <= n; i++) {
a[i] = find(a[i]);
}
printf("Case #%d: %d\n", C, solve());
}
return 0;
}