树状数组学习笔记

高效又好写的数据结构

简介

树状数组二叉索引树(英语:Binary Indexed Tree),又以其发明者命名为Fenwick树,最早由Peter M. Fenwick于1994年以A New Data Structure for Cumulative Frequency Tables为题发表在SOFTWARE PRACTICE AND EXPERIENCE。其初衷是解决数据压缩里的累积频率(Cumulative Frequency)的计算问题,现多用于高效计算数列的前缀和, 区间和。它可以以 {\displaystyle O(\log n)} 的时间得到任意前缀和 {\displaystyle \sum _{i=1}^{j}A[i],1\le j\le N} !,并同时支持在 {\displaystyle O(\log n)} 时间内支持动态单点值的修改。空间复杂度 {\displaystyle O(n)}

——Wikipedia

简单的说,树状数组就是一个便于在 O(\log n) 时间内维护一个数列 / 矩阵的前缀和,可以支持单点修改、查询,区间修改、查询的数据结构。

依据支持操作的不同(包含关系),我这里把它分为六类:

  • 支持序列单点加减、区间和查询的树状数组
  • 支持序列区间加减、单点查询的树状数组
  • 支持序列区间加减、区间和查询的树状数组
  • 支持矩阵单点加减、子矩阵和查询的树状数组
  • 支持矩阵的子矩阵加减、单点查询的树状数组
  • 支持矩阵的子矩阵加减、子矩阵和查询的树状数组

这些会一个一个的讲。

序列操作

单点加减、区间和查询

这个是最基础的树状数组,应该没有人不会吧……

原理就是通过维护前缀和,修改的时候像暴力维护前缀和一样一个一个往后加,不过每次增长的值不是1而是lowbit,其中“一个数取lowbit能跳到哪”这个关系连边后就形成了一个二叉搜索树。

按照Peter M. Fenwick的说法,正如所有的整数都可以表示成2的幂和,我们也可以把一串序列表示成一系列子序列的和。采用这个想法,我们可将一个前缀和划分成多个子序列的和,而划分的方法与数的2的幂和具有极其相似的方式。一方面,子序列的个数是其二进制表示中1的个数,另一方面,子序列代表的f[i]的个数也是2的幂。

——Wikipedia

KAb2Se.png

比如说这一棵就是八个元素的树状数组,对照下面的表可以发现上面的连边规律(点下面的是编号,请自动忽略根节点 9 以及那条边)。

1
2
3
4
5
6
7
8
9
10
1's lowbit = 1, 1 + lowbit = 2
2's lowbit = 2, 2 + lowbit = 4
3's lowbit = 1, 3 + lowbit = 4
4's lowbit = 4, 4 + lowbit = 8
5's lowbit = 1, 5 + lowbit = 6
6's lowbit = 2, 6 + lowbit = 8
7's lowbit = 1, 7 + lowbit = 8
8's lowbit = 8, 8 + lowbit = 16
9's lowbit = 1, 9 + lowbit = 10
10's lowbit = 2, 10 + lowbit = 12

那么代码就很容易写出来了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int n, tree[MAX_SIZE];
// n 为元素个数,tree[] 为树状数组维护的前缀和

int lowbit(int x) { return (x) & (-x); }

void Modify(int pos, int x) {
// 将 pos 位置的数加上 x
for (; pos <= n; pos += lowbit(pos)) tree[pos] += x;
}

int Query(int pos) {
// 查询 [1,pos] 之间的数的和
int ret = 0;
for (; pos >= 1; pos -= lowbit(pos)) ret += tree[pos];
return ret;
}

int rangeQuery(int l, int r) {
// 查询 [l,r] 之间的数的和
return Query(r) - Query(l - 1);
}

区间加减、单点查询

不知道你们有没有听说过一个东西叫做「差分」

定义差分数组 d[i] = a[i] - a[i - 1],其中 a[] 表示原数列
那么对 d[i] 求一个前缀和就可以得出 a[i]的值了
举个例子:

1
2
3
数组下标从 0 开始,元素存储从 1 开始,a[0] = d[0] = 0
a[] = {0, 1, 3, 4, 2}
d[] = {/, 1, 2, 2, -2}

发现了什么?
MATHJAX-SSR-45

如何修改 \text{[L,R]}+x
先给结论:在 \text{L} +x ,在 \text{R+1} -x
直观理解:

1
2
3
4
5
6
7
下标从 1 开始。
原数列:0 0 0 0 0 0
按照上面的方法 [2,4]+x
0 x 0 0 -x 0
看看前缀和之后会发生什么……
0 x x x 0 0
!!!!!

而维护前缀和这种事情,树状数组最在行了

可以写出代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
int n, a[MAXN], bit[MAXN];  
// n 为元素个数,a 为原数组,bit 为差分数组

void Modify(int pos, int x) {
for (; pos <= n; pos += lowbit(pos)) bit[pos] += x;
}

void rangeModify(int l, int r, int x) {
Modify(l, x); Modify(r + 1, -x);
}

int Query(int pos) {
int ret = 0;
for (; pos >= 1; pos -= lowbit(pos)) ret += bit[pos];
return ret;
}

int main() {
...
for (i = 1 to n increase 1) {
read a[i]
rangeModify(i, i, a[i]);
}
...
read x, read y, read k
rangeModify(x, y, k);
// 将 [x,y] 区间内的数加上 k
...
read k
printf("%d\n", Query(k));
// 查询 k 位置的元素
}

区间加减、区间和查询

线段树天下第一
但是线段树难写、难调,常数还大,占空间还多。。。

如果你只需要区间加减、区间和查询,树状数组无疑是你最好的选择


区间加减维护一下差分数组就行了

考虑区间和本质是

\sum_{a = 1}^{p}\sum_{i = 1}^{a}d_i

计算一下每个 d_i 被算的次数,顺便把式子变换一下

\sum_{a = 1}^{p}d_a \times (p - a + 1)

1
2
3
4
5
6
7
8
9
举个例子
比如说 p = 5 时,可以发现
ans =
d[1] +
d[1] + d[2] +
d[1] + d[2] + d[3] +
d[1] + d[2] + d[3] + d[4] +
d[1] + d[2] + d[3] + d[4] + d[5]
找一找规律就可以搞出上面的式子了

拆一下 \sum ,可以变换成

(p + 1)\sum_{a = 1}^{p}d_a - \sum_{a = 1}^{p}d_a \times a

这样的话,只需要分别维护两个差分数组,一个记 d_a ,一个记 d_a \times a 就行

修改 \text{[L,R] + }x 的时候,像上面区间加减、单点查询一样,把 \text{[L]} + x,\text{[R+1]} - x (对两个数组进行的修改可以合并到 Modify() 函数中,具体见代码)
查询的时候像上面单点加减、区间和查询一样,是前缀和作差

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// 代码没有经过提交,仅进行了一些小样例测试!

typedef long long int lli;

int n, m;
lli ss[MAXN];
lli biti[MAXN], bitpi[MAXN];
// ss 表示原数组,biti 表示维护 d[a] 的数组, bitpi 表示维护 d[a] * a 的数组

void Modify(int pos, lli x) {
int dx = pos;
for (; pos <= n; pos += lowbit(pos)) {
// 为了方便,可以把 rangeModify() 里的乘法挪到 Modify() 里面
biti[pos] += x; bitpi[pos] += x * 1ll * dx;
}
}
void rangeModify(int l, int r, lli x) {
// 这是把括号里的乘法挪到 Modify() 里面的写法
Modify(l, x); Modify(r + 1, -x);
}
lli Query(int pos) {
lli ret = 0, dx = pos;
while (pos >= 1) { ret += (dx + 1) * 1ll * biti[pos] - bitpi[pos]; pos -= lowbit(pos); }
return ret;
}
lli rangeQuery(int l, int r) { return Query(r) - Query(l - 1); }

矩阵操作

一维的操作都讲完了,那能不能把它推广到二维上面呢?答案是肯定的。
提前说一句,以下操作从访问 n 个元素变成了 nm 个元素,时间复杂度变为 O(\log(nm))

单点加减、子矩阵和查询

前面说过,树状数组是利用前缀和的思想进行实现的,既然二维也有前缀和,何不照葫芦画瓢把而为树状数组搞出来呢?


先来复习一下。

\sum_{i = l}^{r} a_i = \sum_{i = 1}^{r} a_i - \sum_{i = 1}^{l - 1} a_i

为了方便,定义 MATHJAX-SSR-50

\sum_{i = x_1}^{x_2}\sum_{j = y_1}^{y_2}a_{i,j}=f(x_2,y_2)-f(x_1 - 1,y_2)-f(x_2,y_1-1)+f(x_1-1,y_1-1)

直观来看,

KE3Zkj.png

定义 \text{Sum}(a,b,c,d) 为以 (a,b) 为左下角, (c,d) 为右上角(对于矩阵是反着的)的矩阵元素之和,那么很显然能看出 \text{Sum}(5,4,7,5)=\text{Sum}(1,1,7,5)-\text{Sum}(1,1,7,3)-\text{Sum}(1,1,4,5)+\text{Sum}(1,1,4,3) ,也就是四边形 \text{ABCD}-\text{ABGI}-\text{AHFD}+\text{AHEI} 元素的值

二维树状数组和一位的除了多了一维之外没多大区别,手法从一维前缀和换到了二维前缀和

看代码就知道了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// 代码没有经过提交,仅进行了一些小样例测试!

int n, m, q, bit[MAXN][MAXN];

// 是不是和一维的手法差不多(逃
void Modify(int x, int y, int w) {
for (; x <= n; x += lowbit(x)) {
for (int fy = y; fy <= m; fy += lowbit(fy)) {
// 说一个坑:这里不要对 y 进行直接修改
// 因为下一次循环 x 的时候需要用 y
// 我当初在这里栽坑调了快 10min。。。
bit[x][fy] += w;
}
}
}
int Query(int x, int y) {
int ans = 0;
for (; x >= 1; x -= lowbit(x)) {
for (int fy = y; fy >= 1; fy -= lowbit((fy))) {
ans += bit[x][fy];
}
}
return ans;
}
int matrixQuery(int x1, int y1, int x2, int y2) {
// x1 <= x2, y1 <= y2
int a = Query(x2, y2);
int b = Query(x1 - 1, y2);
int c = Query(x2, y1 - 1);
int d = Query(x1 - 1, y1 - 1);
return a - b - c + d;
}

子矩阵加减、单点查询

还记得区间加减、单点查询吗?
接下来把它推广到二维!


查询手法一样的,二维前缀和

如何修改 (x_1,y_1)\text{ to }(x_2,y_2)
先说结论:
d[x_1][y_1] + x,d[x_1][y_2+1]-x,d[x_2+1][y_1]-x,d[x_2+1][y_2+1]+x
直观理解:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
下标从 1 开始
1 2 3 4 5
1 0 0 0 0 0
2 0 0 0 0 0
3 0 0 0 0 0
4 0 0 0 0 0
修改(1,2)->(4,3) + x
1 2 3 4 5
1 0 0 0 0 0
2 x 0 0 0 -x
3 0 0 0 0 0
4 -x 0 0 0 x
前缀和:
1 2 3 4 5
1 0 0 0 0 0
2 x x x x 0
3 x x x x 0
4 0 0 0 0 0

放代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// 代码没有经过提交,仅进行了一些小样例测试!

int n, m, q, bit[MAXN][MAXN];
// n,m 为矩阵大小,bit 为差分数组

void Modify(int x, int y, int w) {
for (int i = x; i <= n; i += lowbit(i)) {
for (int j = y; j <= m; j += lowbit(j)) {
bit[i][j] += w;
}
}
}
void matrixModify(int x1, int y1, int x2, int y2, int w) {
// x1 <= x2
Modify(x1, y1, w); Modify(x1, y2 + 1, -w);
Modify(x2 + 1, y1, -w); Modify(x2 + 1, y2 + 1, w);
}
int Query(int x, int y) {
int ret = 0;
for (int i = x; i >= 1; i -= lowbit(i)) {
for (int j = y; j >= 1; j -= lowbit(j)) {
ret += bit[i][j];
}
}
return ret;
}

子矩阵加减、子矩阵和查询

最后一种操作,也是最难的操作

……其实并不难,如果你把前面都学懂了。

区间加减、区间和查询一样,先看看查询操作的本质

\sum_{i=1}^{x} \sum_{j=1}^{y} \sum_{k=1}^{i} \sum_{h=1}^{j} d[h][k]

先统计一下 d[i][j] 被访问了多少次,然后稍微整理一下式子,变成

{\sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j] \times(x+1-i) \times(y+1-j)} \\= {(x+1)(y+1) \times \sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j]} \\ {-(y+1) \times \sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j] \times i} \\ {-(x+1) \times \sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j] \times j} \\ {\quad+\sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j] \times i \times j}

所以,实现区修区查需要维护四个差分数组!

  • 第一个:维护 d[i][j]
  • 第二个:维护 d[i][j]\times i
  • 第三个:维护 d[i][j]\times j
  • 第四个:维护 d[i][j]\times i\times j

接下来是完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//  
// Created by HandwerSTD on 2019/10/17.
//

// 洛谷 P4514 《上帝造题的七分钟》
// 常数略大。。开O2过的

#include <algorithm>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdio>
#include <vector>

#define FILE_IN(__fname) freopen(__fname, "r", stdin)
#define FILE_OUT(__fname) freopen(__fname, "w", stdout)
#define rap(a,s,t,i) for (int a = s; a <= t; a += i)
#define basketball(a,t,s,i) for (int a = t; a > s; a -= i)
#define countdown(s) while (s --> 0)
#define IMPROVE_IO() std::ios::sync_with_stdio(false)
#define lowbit(x) ((x & (-x)))

using std::cin;
using std::cout;
using std::endl;

typedef long long int lli;

int getint() { int x; scanf("%d", &x); return x; }
lli getll() { long long int x; scanf("%lld", &x); return x; }

const int MAXN = 2048 + 10;

int n, m, q;

namespace BIT {
int d[MAXN][MAXN], di[MAXN][MAXN];
int dj[MAXN][MAXN], dij[MAXN][MAXN];

void Modify(int x, int y, int w) {
for (int i = x; i <= n; i += lowbit(i)) {
for (int j = y; j <= m; j += lowbit(j)) {
d[i][j] += w; di[i][j] += w * x;
dj[i][j] += w * y; dij[i][j] += w * x * y;
}
}
}
void matrixModify(int x1, int y1, int x2, int y2, int w) {
Modify(x1, y1, w); Modify(x2 + 1, y2 + 1, w);
Modify(x1, y2 + 1, -w); Modify(x2 + 1, y1, -w);
}
int Query(int x, int y) {
int ret = 0;
for (int i = x; i >= 1; i -= lowbit(i)) {
for (int j = y; j >= 1; j -= lowbit(j)) {
ret += d[i][j] * (x + 1) * (y + 1)
- (y + 1) * di[i][j]
- (x + 1) * dj[i][j]
+ dij[i][j];
}
}
return ret;
}
int matrixQuery(int x1, int y1, int x2, int y2) {
int a = Query(x2, y2);
int b = Query(x1 - 1, y1 - 1);
int c = Query(x1 - 1, y2);
int d = Query(x2, y1 - 1);
return a - c - d + b;
}
}

int main() {
std::ios::sync_with_stdio(false);
std::string _s; cin >> _s;
cin >> n >> m;
// rap (i, 1, n, 1) {
// rap (j, 1, m, 1) {
// int fx = 0;
// scanf("%d", &fx);
// BIT::matrixModify(i, j, i, j, fx);
// }
// }
char ch = 0;
while (cin >> ch) {
int a = 0, b = 0, c = 0, d = 0;
cin >> a >> b >> c >> d;
if (ch == 'L') {
int delta = 0;
cin >> delta;
BIT::matrixModify(a, b, c, d, delta);
} else {
// scanf("\n");
// printf("%d\n", BIT::matrixQuery(a, b, c, d));
cout << BIT::matrixQuery(a, b, c, d) << endl;
}
// getchar();
}
return 0;
}