伸展树

例题:维护序列

题目描述

请写一个程序,要求维护一个数列,支持以下 $6$ 种操作:

编号 名称 格式 说明
1 插入 $\operatorname{INSERT}\ posi \ tot \ c_1 \ c_2 \cdots c_{tot}$ 在当前数列的第 $posi$ 个数字后插入 $tot$ 个数字:$c_1, c_2 \cdots c_{tot}$;若在数列首插入,则 $posi$ 为 $0$
2 删除 $\operatorname{DELETE} \ posi \ tot$ 从当前数列的第 $posi$ 个数字开始连续删除 $tot$ 个数字
3 修改 $\operatorname{MAKE-SAME} \ posi \ tot \ c$ 从当前数列的第 $posi$ 个数字开始的连续 $tot$ 个数字统一修改为 $c$
4 翻转 $\operatorname{REVERSE} \ posi \ tot$ 取出从当前数列的第 $posi$ 个数字开始的 $tot$ 个数字,翻转后放入原来的位置
5 求和 $\operatorname{GET-SUM} \ posi \ tot$ 计算从当前数列的第 $posi$ 个数字开始的 $tot$ 个数字的和并输出
6 求最大子列和 $\operatorname{MAX-SUM}$ 求出当前数列中和最大的一段子列,并输出最大和

输入格式

第一行包含两个整数 $N$ 和 $M$,$N$ 表示初始时数列中数的个数,$M$ 表示要进行的操作数目。

第二行包含 $N$ 个数字,描述初始时的数列。以下 $M$ 行,每行一条命令,格式参见问题描述中的表格。

输出格式

对于输入数据中的 $\operatorname{GET-SUM}$ 和 $\operatorname{MAX-SUM}$ 操作,向输出文件依次打印结果,每个答案(数字)占一行。

样例输入 #1

1
2
3
4
5
6
7
8
9
10
9 8 
2 -6 3 5 1 -5 -3 6 3
GET-SUM 5 4
MAX-SUM
INSERT 8 3 -5 7 2
DELETE 12 1
MAKE-SAME 3 3 2
REVERSE 3 6
GET-SUM 5 4
MAX-SUM

样例输出 #1

1
2
3
4
-1
10
1
10

数据规模与约定

  • 你可以认为在任何时刻,数列中至少有 $1$ 个数。
  • 输入数据一定是正确的,即指定位置的数在数列中一定存在。
  • 对于 $50%$ 的数据,任何时刻数列中最多含有 $3 \times 10^4$ 个数。
  • 对于 $100%$ 的数据,任何时刻数列中最多含有 $5 \times 10^5$ 个数,任何时刻数列中任何一个数字均在 $[-10^3, 10^3]$ 内,$1 \le M \le 2 \times 10^4$,插入的数字总数不超过 $4 \times 10^6$。

一、伸展树

伸展树$(Splay Tree)$,也叫分裂树,是一种二叉排序树,它能在$O(logN)$内完成插入、查找和删除操作。

在伸展树上的一般操作都基于伸展操作:
假设想要对一个二叉查找树执行一系列的查找操作,为了使整个查找时间更小,被查频率高的那些条目就应当经常处于靠近树根的位置。于是想到设计一个简单方法, 在每次查找之后对树进行重构,把被查找的条目搬移到离树根近一些的地方。伸展树应运而生。伸展树是一种自调整形式的二叉查找树,它会沿着从某个节点到树根之间的路径,通过一系列的旋转把这个节点搬移到树根去。

伸展树的核心思想为:每操作一个结点,就将改结点旋转至树根


二、左旋和右旋

a1c93ac051306a6624d317137bc1021.png
这里可以统一为自旋
以下图为例子,更新红色的边的关系即可
3cc8a6c7d3f0617fd5bfde3ffdfa89c.png

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 自旋函数
void rotate(int x)
{
int y = tr[x].p, z = tr[y].p;
// k = 0表示x是y的左儿子,k = 1表示x是y的右儿子
int k = tr[y].s[1] == x;

// 以k = 0的情况写即可
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;

// 旋转后需pushup
pushup(y), pushup(x);
}

三、旋转方式:

对于结点$x$,其存在两种不同的树结构,对应不同的旋转方式
9e3b27ca09ef705df8e5cdabb0e9bc6.png
8f3f193afbba38f9ee0d3a7dc2d03a2.png
即:
直链状: 先转$y$,再转$x$
非直链状:连续转两次$x$

Splay操作(核心)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 核心函数,将结点x旋转至k结点的下面
// k = 0时,表示将x旋转至根
void splay(int x, int k)
{
while(tr[x].p != k)
{
int y = tr[x].p, z = tr[y].p;

// z不是k时,需要旋转两次
if(z != k)
// 如果z, y, x的关系不是直链时,先旋转x
// 否则先旋转y
if((tr[z].s[0] == y) != (tr[y].s[0] == x)) rotate(x);
else rotate(y);

// 最后再旋转x
rotate(x);
}
// k = 0时, 表示将x旋转至根
if(!k) root = x;
}

四、$pushup$和$pushdown$

$pushup函数$在旋转结点后调用,使用子结点的信息更新当前结点的信息
$pushdown函数$在递归之前调用,使用当前结点的信息更新子结点的信息


五、完整代码

时间复杂度:$O(MlogN)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>

using namespace std;

const int N = 500010, M = 21, INF = 2e9;

struct node {
int s[2], p;
int val;
int size;
int sum, lsum, rsum, msum;
int same, rev;

void init(int _val, int _p)
{
s[0] = s[1] = 0, p = _p;
size = 1;
same = rev = 0;
val = sum = msum = _val;
lsum = rsum = max(0, _val);
}

} tr[N];

int root;
int nodes[N], top; // 内存回收机制, 类似栈

int n, m;
int w[N];

void pushup(int x)
{
auto &u = tr[x], &l = tr[tr[x].s[0]], &r = tr[tr[x].s[1]];
// 记得考虑u本身
u.size = l.size + r.size + 1;
u.sum = l.sum + r.sum + u.val;
u.msum = max(max(l.msum, r.msum), l.rsum + u.val + r.lsum);
u.lsum = max(l.lsum, l.sum + u.val + r.lsum);
u.rsum = max(r.rsum, r.sum + u.val + l.rsum);
}

void pushdown(int x)
{
auto &u = tr[x], &l = tr[tr[x].s[0]], &r = tr[tr[x].s[1]];

if(u.same)
{
if(u.s[0]) l.same = 1, l.val = u.val, l.sum = u.val * l.size;
if(u.s[1]) r.same = 1, r.val = u.val, r.sum = u.val * r.size;

if(u.val > 0)
{
if(u.s[0]) l.msum = l.lsum = l.rsum = l.sum;
if(u.s[1]) r.msum = r.lsum = r.rsum = r.sum;
}
else
{
if(u.s[0]) l.msum = u.val, l.lsum = l.rsum = 0;
if(u.s[1]) r.msum = u.val, r.lsum = r.rsum = 0;
}

u.same = u.rev = 0; // 整个区间都变成同一个数,rev标签没用意义,也需要清空
}

if(u.rev)
{
if(u.s[0]) l.rev ^= 1, swap(l.s[0], l.s[1]), swap(l.lsum, l.rsum);
if(u.s[1]) r.rev ^= 1, swap(r.s[0], r.s[1]), swap(r.lsum, r.rsum);
u.rev = 0;
}
}

void rotate(int x)
{
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;

tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;

pushup(y), pushup(x);
}

void splay(int x, int k)
{
while(tr[x].p != k)
{
int y = tr[x].p, z = tr[y].p;
if(z != k)
if((tr[z].s[1] == y) == (tr[y].s[1] == x)) rotate(y);
else rotate(x);
rotate(x);
}
if(!k) root = x;
}

int get_k(int k)
{
int u = root;
while(1)
{
pushdown(u);
if(tr[tr[u].s[0]].size >= k) u = tr[u].s[0];
else if(tr[tr[u].s[0]].size + 1 == k) return u;
else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
}
return -1;
}

// 递归建立splay, p为父结点信息
int build(int l, int r, int p)
{
int mid = l + r >> 1;
int u = nodes[top --]; // 分配结点
tr[u].init(w[mid], p); // 结点初始化
// 递归创建儿子结点
if(l < mid) tr[u].s[0] = build(l, mid - 1, u); // 左区间为[l, mid - 1]
if(r > mid) tr[u].s[1] = build(mid + 1, r, u); // 右区间为[mid + 1, r]
// 记得pushup
pushup(u);
return u;
}

// 递归回收根结点为u的子树
void dfs(int u)
{
if(tr[u].s[0]) dfs(tr[u].s[0]);
if(tr[u].s[1]) dfs(tr[u].s[1]);
nodes[++ top] = u;
}

int main()
{
for(int i = 1; i < N; i ++) nodes[++ top] = i; // 初始化结点存储栈
scanf("%d%d", &n, &m);
// 初始化哨兵
w[0] = w[n + 1] = -INF;
// 0号点是空结点的编号,结点初始化tr[0]信息
tr[0].msum = -INF;
for(int i = 1; i <= n; i ++) scanf("%d", &w[i]);
root = build(0, n + 1, 0);

while(m --)
{
char op[M];
scanf("%s", op);

if(strcmp(op, "INSERT") == 0)
{
int posi, tot;
scanf("%d%d", &posi, &tot);
for(int i = 1; i <= tot; i ++) scanf("%d", &w[i]);
int l = get_k(posi + 1), r = get_k(posi + 2);
splay(l, 0), splay(r, l);
tr[r].s[0] = build(1, tot, r);
pushup(r), pushup(l);
}
else if(strcmp(op, "DELETE") == 0)
{
int posi, tot;
scanf("%d%d", &posi, &tot);
int l = get_k(posi), r = get_k(posi + tot + 1);
splay(l, 0), splay(r, l);
dfs(tr[r].s[0]);
tr[r].s[0] = 0;
pushup(r), pushup(l);
}
else if(strcmp(op, "MAKE-SAME") == 0)
{
int posi, tot, c;
scanf("%d%d%d", &posi, &tot, &c);
int l = get_k(posi), r = get_k(posi + tot + 1);
splay(l, 0), splay(r, l);
auto &u = tr[tr[r].s[0]];
u.same = 1;
u.val = c, u.sum = c * u.size;
if(c > 0) u.msum = u.lsum = u.rsum = u.sum;
else u.msum = c, u.lsum = u.rsum = 0;
pushup(r), pushup(l);
}
else if(strcmp(op, "REVERSE") == 0)
{
int posi, tot;
scanf("%d%d", &posi, &tot);
int l = get_k(posi), r = get_k(posi + tot + 1);
splay(l, 0), splay(r, l);
auto &u = tr[tr[r].s[0]];
u.rev ^= 1;
swap(u.s[0], u.s[1]);
swap(u.lsum, u.rsum);
pushup(r), pushup(l);
}
else if(strcmp(op, "GET-SUM") == 0)
{
int posi, tot;
scanf("%d%d", &posi, &tot);
int l = get_k(posi), r = get_k(posi + tot + 1);
splay(l, 0), splay(r, l);
printf("%d\n", tr[tr[r].s[0]].sum);
}
else if(strcmp(op, "MAX-SUM") == 0)
{
printf("%d\n", tr[root].msum);
}
}
return 0;
}