线段树

线段树学习


简介

常用的用来维护 区间信息 的数据结构。

线段树可以在 $O(\log N)$ 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

实现过程即为将每个长度不为 $1$ 的片段分成左右两个片段,不断递归下去,把区间分为树形结构,通过合并两端来求值


模板

单点修改

操作 $1$ 单点修改
操作 $2$ 查询区间最小值

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5+7;
#define ls now<<1
#define rs now<<1|1

int n,m;
int a[N];

struct Node{
    int minv;
}tr[N<<2];

void update(int now){
    tr[now].minv=min(tr[ls].minv,tr[rs].minv);
}

void build(int now,int l,int r){
    if(l==r){
        tr[now].minv=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    update(now);
}

void change(int now,int s,int t,int pos,int val){
    if(s==t){
        tr[now].minv=val;
        return;
    }
    int mid=(s+t)>>1;
    if(pos<=mid)
        change(ls,s,mid,pos,val);
    else 
        change(rs,mid+1,t,pos,val);
    update(now);
}

int query(int now,int l,int r,int s,int t){
    if(l<=s&&r>=t)return tr[now].minv;
    int mid=(s+t)>>1;
    if(r<=mid)
        return query(ls,l,r,s,mid);
    else if(l>mid)
        return query(rs,l,r,mid+1,t);
    else 
        return min(query(ls,l,r,s,mid),query(rs,l,r,mid+1,t));
}

int main(){
    cin>>n>>m;
    for(int i=1;i<=n;i++)cin>>a[i];
    build(1,1,n);
    while(m--){
        int op,x,y;
        cin>>op>>x>>y;
        if(op==1){
            change(1,1,n,x,y);
        }
        else cout<<query(1,x,y,1,n)<<endl;
    }
    return 0;
}

区间合并

不光记录最小值,同时记录最小值出现次数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5+7;
#define ls now<<1
#define rs now<<1|1

int n,m;
int a[N];

struct Node{
    int minv,cnt;
}tr[N<<2];

Node operator + (const Node &l,const Node &r){
    Node a;
    a.minv=min(l.minv,r.minv);
    if(l.minv==r.minv)a.cnt=l.cnt+r.cnt;
    else if(l.minv<r.minv)a.cnt=l.cnt;
    else a.cnt=r.cnt;
    return a;
}

void update(int now){
    tr[now]=tr[ls]+tr[rs];
}

void build(int now,int l,int r){
    if(l==r){
        tr[now]={a[l],1};
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    update(now);
}

void change(int now,int s,int t,int pos,int val){
    if(s==t){
        tr[now]={val,1};;
        return;
    }
    int mid=(s+t)>>1;
    if(pos<=mid)
        change(ls,s,mid,pos,val);
    else 
        change(rs,mid+1,t,pos,val);
    update(now);
}

Node query(int now,int l,int r,int s,int t){
    if(l<=s&&r>=t)return tr[now];
    int mid=(s+t)>>1;
    if(r<=mid)
        return query(ls,l,r,s,mid);
    else if(l>mid)
        return query(rs,l,r,mid+1,t);
    else 
        return (query(ls,l,r,s,mid)+query(rs,l,r,mid+1,t));
}

int main(){
    cin>>n>>m;
    for(int i=1;i<=n;i++)cin>>a[i];
    build(1,1,n);
    while(m--){
        int op,x,y;
        cin>>op>>x>>y;
        if(op==1){
            change(1,1,n,x,y);
        }
        else cout<<query(1,x,y,1,n).minv<<" "<<query(1,x,y,1,n).cnt<<endl;
    }
    return 0;
}

最大子段和

数组片段中有正有负,需要统计的数据变多,合并时可能涉及到左右片段相接部分

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5+7;
#define ls now<<1
#define rs now<<1|1

int n,m;
int a[N];

struct Node{
    int mss,mpre,msuf,s;
}tr[N<<2];

Node operator + (const Node &l,const Node &r){
    Node a;
    a.mss=max({l.mss,r.mss,l.msuf+r.mpre});
    a.mpre=max(l.mpre,l.s+r.mpre);
    a.msuf=max(r.msuf,r.s+l.msuf);
    a.s=l.s+r.s;
    return a;
}

void update(int now){
    tr[now]=tr[ls]+tr[rs];
}

void build(int now,int l,int r){
    if(l==r){
        tr[now]={a[l],a[l],a[l],a[l]};
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    update(now);
}

void change(int now,int s,int t,int pos,int val){
    if(s==t){
        tr[now]={val,val,val,val};
        return;
    }
    int mid=(s+t)>>1;
    if(pos<=mid)
        change(ls,s,mid,pos,val);
    else 
        change(rs,mid+1,t,pos,val);
    update(now);
}

Node query(int now,int l,int r,int s,int t){
    if(l<=s&&r>=t)return tr[now];
    int mid=(s+t)>>1;
    if(r<=mid)
        return query(ls,l,r,s,mid);
    else if(l>mid)
        return query(rs,l,r,mid+1,t);
    else 
        return (query(ls,l,r,s,mid)+query(rs,l,r,mid+1,t));
}

int main(){
    cin>>n>>m;
    for(int i=1;i<=n;i++)cin>>a[i];
    build(1,1,n);
    while(m--){
        int op,x,y;
        cin>>op>>x>>y;
        if(op==1){
            change(1,1,n,x,y);
        }
        else cout<<query(1,x,y,1,n).mss<<endl;
    }
    return 0;
}

懒惰标记

对片段修改,用懒惰标记标记当前段,之后再取到时,将标记传到子段

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 7;
#define ls (now << 1)
#define rs (now << 1 | 1)

int n, m;
int a[N];

struct Node {
    int len, sum, tag;
} tr[N << 2];

Node operator + (const Node &l, const Node &r) {
    Node a;
    a.sum = l.sum + r.sum;
    a.tag = 0;
    a.len = l.len + r.len;
    return a;
}

void update(int now) {
    tr[now] = tr[ls] + tr[rs];
}

void settag(int now, int k) {
    tr[now].tag += k;
    tr[now].sum += tr[now].len * k;
}

void pushdown(int now) {
    if (tr[now].tag) {
        settag(ls, tr[now].tag);
        settag(rs, tr[now].tag);
        tr[now].tag = 0;
    }
}

void build(int now, int l, int r) {
    if (l == r) {
        tr[now] = {1, a[l], 0};
        return;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    update(now);
}

void modify(int now, int l, int r, int s, int t, int val) {
    if (l <= s && r >= t) {
        settag(now, val);
        return;
    }
    pushdown(now);
    int mid = (s + t) >> 1;
    if (l <= mid) modify(ls, l, r, s, mid, val);
    if (r > mid) modify(rs, l, r, mid + 1, t, val);
    update(now);
}

int query(int now, int l, int r, int s, int t) {
    if (l <= s && r >= t) return tr[now].sum;
    pushdown(now);
    int mid = (s + t) >> 1;
    int ans = 0;
    if (l <= mid) ans += query(ls, l, r, s, mid);
    if (r > mid) ans += query(rs, l, r, mid + 1, t);
    return ans;
}

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    build(1, 1, n);
    while (m--) {
        int op, x, y, z;
        cin >> op >> x >> y;
        if (op == 1) {
            cin >> z;
            modify(1, x, y, 1, n, z);
        } else {
            cout << query(1, x, y, 1, n) << endl;
        }
    }
    return 0;
}

懒惰标记 2

不光有区间加,还有区间乘

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5+7;
#define ls now<<1
#define rs now<<1|1

int n, mod, q;
int a[N];

struct Tag {
    int mul, add;
};

struct Node {
    int len, sum;
    Tag tag;
} tr[N<<2];

Tag operator + (const Tag &a, const Tag &b) {
    return {a.mul * b.mul % mod, (a.add * b.mul + b.add) % mod};
}

Node operator + (const Node &l, const Node &r) {
    Node a;
    a.sum = (l.sum + r.sum) % mod;
    a.tag = {1, 0};
    a.len = l.len + r.len;
    return a;
}

void update(int now) {
    tr[now] = tr[ls] + tr[rs];
}

void settag(int now, Tag t) {
    tr[now].tag = tr[now].tag + t;
    tr[now].sum = (tr[now].sum * t.mul + tr[now].len * t.add) % mod;
}

void pushdown(int now) {
    if (tr[now].tag.mul != 1 || tr[now].tag.add) {
        settag(ls, tr[now].tag);
        settag(rs, tr[now].tag);
        tr[now].tag = {1, 0};
    }
}

void build(int now, int l, int r) {
    if (l == r) {
        tr[now] = {1, a[l], {1, 0}};
        return;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    update(now);
}

void modify(int now, int l, int r, int s, int t, Tag val) {
    if (l <= s && r >= t) {
        settag(now, val);
        return;
    }
    pushdown(now);
    int mid = (s + t) >> 1;
    if (l <= mid) modify(ls, l, r, s, mid, val);
    if (r > mid) modify(rs, l, r, mid + 1, t, val);
    update(now);
}

int query(int now, int l, int r, int s, int t) {
    if (l <= s && r >= t) return tr[now].sum;
    pushdown(now);
    int mid = (s + t) >> 1;
    int ans = 0;
    if (l <= mid) ans = (ans + query(ls, l, r, s, mid)) % mod;
    if (r > mid) ans = (ans + query(rs, l, r, mid + 1, t)) % mod;
    return ans;
}

int main() {
    cin >> n >> q >> mod;
    for (int i = 1; i <= n; i++) cin >> a[i];
    build(1, 1, n);
    while (q--) {
        int op, x, y;
        cin >> op >> x >> y;
        Tag z = {1, 0};
        if (op == 1) {
            cin >> z.mul;
            modify(1, x, y, 1, n, z);
        } else if (op == 2) {
            cin >> z.add;
            modify(1, x, y, 1, n, z);
        } else {
            cout << query(1, x, y, 1, n) << endl;
        }
    }
    return 0;
}

线段树上二分

重点改的是查询部分

找到区间中第一个大于等于 $d$ 的位置

但是 $l == s$, $r == t$ 时不返回,继续递归

时间复杂度仍然是 $O(\log n)$

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 7;
#define ls now << 1
#define rs now << 1 | 1

int n, m;
int a[N];

struct Node
{
    int mx;
} tr[N << 2];

Node operator+(const Node &l, const Node &r)
{
    Node a;
    a.mx = max(l.mx, r.mx);
    return a;
}

void update(int now)
{
    tr[now] = tr[ls] + tr[rs];
}

void build(int now, int l, int r)
{
    if (l == r)
    {
        tr[now] = {a[l]};
        return;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    update(now);
}

void change(int now, int s, int t, int pos, int val)
{
    if (s == t)
    {
        tr[now] = {val};
        return;
    }
    int mid = (s + t) >> 1;
    if (pos <= mid)
        change(ls, s, mid, pos, val);
    if (pos > mid)
        change(rs, mid + 1, t, pos, val);
    update(now);
}

int search(int now, int l, int r, int s, int t, int d)
{
    if (l == s && r == t)
    {
        if (tr[now].mx < d)
            return -1;
        if (l == r)
            return l;
        int mid = (l + r) >> 1;
        if (tr[ls].mx >= d)
            return search(ls, l, mid, s, mid, d);
        return search(rs, mid + 1, r, mid + 1, t, d);
    }
    int mid = (s + t) >> 1;
    if (l <= mid)
        return search(ls, l, r, s, mid, d);
    if (r > mid)
        return search(rs, l, r, mid + 1, t, d);
    int pos = search(ls, l, mid, s, mid, d);
    if (pos != -1)
        return pos;
    return search(rs, mid + 1, r, mid + 1, t, d);
}

int main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    while (m--)
    {
        int op, x, y, d;
        cin >> op >> x >> y;
        if (op == 1)
        {
            change(1, 1, n, x, y);
        }
        else
        {
            cin >> d;
            cout << search(1, x, y, 1, n, d) << endl;
        }
    }
    return 0;
}
发表了95篇文章 · 总计15万2千字
永远相信美好的事情即将发生。
使用 Hugo 构建
主题 StackJimmy 设计