dataStructure_SegmentTree - 线段树

1 线段树原理

1.1 数组实现

1.2 线段树树形实现

/// Dynamic (lazily allocated) segment tree over a closed integer range that
/// supports point-add (`insert`) and range-sum (`getSum`) queries.
///
/// Nodes are allocated with `new` and owned by the tree: the destructor frees
/// every node reachable from `root`, therefore copying is disabled (a shallow
/// copy would free the same nodes twice). Bounds may be negative: midpoints are
/// always rounded toward negative infinity, so every child range is strictly
/// smaller than its parent's.
class SegTree {
public:
    /// A node covering [leftBound, rightBound]; curSum is the sum of all values
    /// inserted at positions inside that range.
    struct SegNode {
        long long leftBound = 0, rightBound = 0, curSum = 0;
        SegNode* lChild = nullptr;
        SegNode* rChild = nullptr;
        SegNode(long long lb, long long rb) :
            leftBound(lb),
            rightBound(rb),
            curSum(0),
            lChild(nullptr),
            rChild(nullptr) {
        }
    };

    /// Root of the tree (nullptr when empty); owned and freed by this object.
    SegNode* root = nullptr;

    /// Eagerly builds a complete tree over [left, right]; empty if left > right.
    SegTree(long long left, long long right) {
        root = build(left, right);
    }

    /// Creates an empty tree. Callers may assign `root` (e.g. a single SegNode
    /// spanning the whole value range) and let `insert` create children lazily.
    SegTree() {}

    ~SegTree() {
        destroy(root);
    }

    // The tree owns raw node pointers; forbid copies to avoid a double free.
    SegTree(const SegTree&) = delete;
    SegTree& operator=(const SegTree&) = delete;

    /// Allocates a complete subtree covering [l, r] and returns its root
    /// (nullptr for an empty range l > r).
    SegNode* build(long long l, long long r) {
        if (l > r) {
            return nullptr;
        }
        SegNode* node = new SegNode(l, r);
        if (l == r) {
            return node;
        }
        long long mid = midOf(l, r);
        node->lChild = build(l, mid);
        node->rChild = build(mid + 1, r);
        return node;
    }

    /// Adds `val` at position `tarIdx` in the subtree rooted at `node`, creating
    /// missing children on the way down. Positions outside the node's range
    /// (and a null node) are ignored instead of being counted in a wrong leaf.
    void insert(SegNode* node, long long tarIdx, long long val) {
        if (nullptr == node || tarIdx < node->leftBound || tarIdx > node->rightBound) {
            return;
        }
        node->curSum += val;
        if (node->leftBound == node->rightBound) {
            return;
        }
        long long mid = midOf(node->leftBound, node->rightBound);
        if (tarIdx <= mid) {
            if (nullptr == node->lChild) {
                node->lChild = new SegNode(node->leftBound, mid);
            }
            insert(node->lChild, tarIdx, val);
        } else {
            if (nullptr == node->rChild) {
                node->rChild = new SegNode(mid + 1, node->rightBound);
            }
            insert(node->rChild, tarIdx, val);
        }
    }

    /// Returns the sum of values inserted at positions in [left, right]
    /// within the subtree rooted at `node` (0 for a null node).
    long long getSum(SegNode* node, long long left, long long right) const {
        if (nullptr == node) {
            return 0;
        }
        // The node's range lies completely outside the query range.
        if (left > node->rightBound || right < node->leftBound) {
            return 0;
        }
        // The node's range lies completely inside the query range.
        if (left <= node->leftBound && right >= node->rightBound) {
            return node->curSum;
        }
        return getSum(node->lChild, left, right) + getSum(node->rChild, left, right);
    }

private:
    /// floor((l + r) / 2) for l <= r. Plain (l + r) / 2 truncates toward zero
    /// (e.g. (-1 + 0) / 2 == 0), which would make the left child of [-1, 0]
    /// equal to [-1, 0] itself and recurse forever. Also avoids overflowing
    /// l + r when both bounds are large with the same sign.
    static long long midOf(long long l, long long r) {
        return l + (r - l) / 2;
    }

    /// Recursively frees `node` and its whole subtree (nullptr is a no-op).
    static void destroy(SegNode* node) {
        if (nullptr == node) {
            return;
        }
        destroy(node->lChild);
        destroy(node->rChild);
        delete node;
    }
};

0327. 区间和的个数

1 题目

https://leetcode-cn.com/problems/count-of-range-sum/

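本题和逆序数对的计算方式相同:https://leetcode-cn.com/problems/shu-zu-zhong-de-ni-xu-dui-lcof/
只是做了一个小改变而已。很多统计区间值的题目(例如要求 st - ed < tar)本来是让你找满足条件的 (st, ed) 对,可以把思路转换为:
对于每一个 ed,统计满足条件的 st 的个数,也就是满足 st < ed + tar 的 st 的个数。本题中即:对每个 j,统计 i < j 且 preSum[j] - upper <= preSum[i] <= preSum[j] - lower 的 i 的个数。
而统计这样的 st 有很多方法,比如哈希表、前缀和、树状数组(BITree)、优先队列(priority_queue)等。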

2 解题思路

  • 1 求逆序对的思路:
    • 1.1 首先注意到:对于数组{5,5,2,3,6}而言,得到每个value的个数的统计:
      • index -> 1 2 3 4 5 6 7 8 9
      • value -> 0 1 1 0 2 1 0 0 0
    • 1.2 那么上述过程中,比如对于第一个5,其贡献的逆序对个数为它右侧所有小于5的数字的出现次数之和,也就是(只统计它右侧的元素时)value数组中下标5之前(下标1~4)的前缀和:0+1+1+0 = 2!
  • 2 那么如何快速获得前缀和呢?考虑使用树状数组(BIT, Binary Indexed Tree)来获取,参考:https://xychen5.github.io/2021/12/15/dataStructure-BinaryIndexedTree/
    • 2.1 整体思路如下:
      • a 使用数字在数组中的排名来代替数字(这不会对逆序数对的个数产生影响)
      • b 对数组nums中的元素nums[i]从右到左构建BITree(i 从 n-1 到 0),注意,BITree所对应的前缀和是数组里数字出现次数的和
        • 比如进行到nums[i],那么nums[i]右边的数字都已经统计了他们的出现次数,而后获取nums[i] - 1的前缀和,即可获取所有 < nums[i]的数字在nums[i:n]中的出现次数之和,也就是nums[i]贡献的逆序数对的个数
        • 之所以是逆序遍历构建BITree,是因为对于nums[i],它能够贡献的逆序数对的个数仅仅出现在它的右侧,所以需要在右侧进行
    • 2.2 额外说一下数组离散化,也就是不关心数字大小本身,只关心它们之间的相对排位
  • 3 使用线段树解题
class Solution {
public:
    /// Dynamic (lazily allocated) segment tree over a closed integer range that
    /// supports point-add (`insert`) and range-sum (`getSum`) queries.
    ///
    /// Nodes are allocated with `new` and owned by the tree: the destructor
    /// frees every node reachable from `root`, therefore copying is disabled.
    /// Bounds may be negative: midpoints are rounded toward negative infinity.
    class SegTree {
    public:
        /// A node covering [leftBound, rightBound]; curSum is the sum of all
        /// values inserted at positions inside that range.
        struct SegNode {
            long long leftBound = 0, rightBound = 0, curSum = 0;
            SegNode* lChild = nullptr;
            SegNode* rChild = nullptr;
            SegNode(long long lb, long long rb) :
                leftBound(lb),
                rightBound(rb),
                curSum(0),
                lChild(nullptr),
                rChild(nullptr) {
            }
        };

        /// Root of the tree (nullptr when empty); owned and freed by this object.
        SegNode* root = nullptr;

        /// Eagerly builds a complete tree over [left, right]; empty if left > right.
        SegTree(long long left, long long right) {
            root = build(left, right);
        }

        /// Creates an empty tree; callers may assign `root` and let `insert`
        /// create children lazily.
        SegTree() {}

        ~SegTree() {
            destroy(root);
        }

        // The tree owns raw node pointers; forbid copies to avoid a double free.
        SegTree(const SegTree&) = delete;
        SegTree& operator=(const SegTree&) = delete;

        /// Allocates a complete subtree covering [l, r] and returns its root
        /// (nullptr for an empty range l > r).
        SegNode* build(long long l, long long r) {
            if (l > r) {
                return nullptr;
            }
            SegNode* node = new SegNode(l, r);
            if (l == r) {
                return node;
            }
            long long mid = midOf(l, r);
            node->lChild = build(l, mid);
            node->rChild = build(mid + 1, r);
            return node;
        }

        /// Adds `val` at position `tarIdx` in the subtree rooted at `node`,
        /// creating missing children on the way down. Positions outside the
        /// node's range (and a null node) are ignored.
        void insert(SegNode* node, long long tarIdx, long long val) {
            if (nullptr == node || tarIdx < node->leftBound || tarIdx > node->rightBound) {
                return;
            }
            node->curSum += val;
            if (node->leftBound == node->rightBound) {
                return;
            }
            long long mid = midOf(node->leftBound, node->rightBound);
            if (tarIdx <= mid) {
                if (nullptr == node->lChild) {
                    node->lChild = new SegNode(node->leftBound, mid);
                }
                insert(node->lChild, tarIdx, val);
            } else {
                if (nullptr == node->rChild) {
                    node->rChild = new SegNode(mid + 1, node->rightBound);
                }
                insert(node->rChild, tarIdx, val);
            }
        }

        /// Returns the sum of values inserted at positions in [left, right]
        /// within the subtree rooted at `node` (0 for a null node).
        long long getSum(SegNode* node, long long left, long long right) const {
            if (nullptr == node) {
                return 0;
            }
            // The node's range lies completely outside the query range.
            if (left > node->rightBound || right < node->leftBound) {
                return 0;
            }
            // The node's range lies completely inside the query range.
            if (left <= node->leftBound && right >= node->rightBound) {
                return node->curSum;
            }
            return getSum(node->lChild, left, right) + getSum(node->rChild, left, right);
        }

    private:
        /// floor((l + r) / 2) for l <= r. Plain (l + r) / 2 truncates toward
        /// zero (e.g. (-1 + 0) / 2 == 0), which would make the left child of
        /// [-1, 0] equal to [-1, 0] itself and recurse forever.
        static long long midOf(long long l, long long r) {
            return l + (r - l) / 2;
        }

        /// Recursively frees `node` and its whole subtree (nullptr is a no-op).
        static void destroy(SegNode* node) {
            if (nullptr == node) {
                return;
            }
            destroy(node->lChild);
            destroy(node->rChild);
            delete node;
        }
    };

    /// Counts the subarrays of `nums` whose sum lies in [lower, upper], i.e.
    /// pairs i < j of prefix sums with lower <= preSum[j] - preSum[i] <= upper,
    /// equivalently preSum[j] - upper <= preSum[i] <= preSum[j] - lower.
    ///
    /// Prefix sums are processed in order: for each x we count the earlier
    /// prefix sums lying in [x - upper, x - lower], then insert x. The dynamic
    /// segment tree works directly on raw values (no discretization needed).
    int countRangeSum(std::vector<int>& nums, int lower, int upper) {
        // prefixSum[k] = nums[0] + ... + nums[k-1]; long long avoids int overflow.
        std::vector<long long> prefixSum;
        prefixSum.reserve(nums.size() + 1);
        prefixSum.push_back(0);
        for (int num : nums) {
            prefixSum.push_back(prefixSum.back() + num);
        }

        // Smallest range containing every value that is inserted or queried.
        long long minLeft = LLONG_MAX, maxRight = LLONG_MIN;
        for (long long x : prefixSum) {
            minLeft = std::min({minLeft, x, x - lower, x - upper});
            maxRight = std::max({maxRight, x, x - lower, x - upper});
        }

        SegTree tree;
        tree.root = new SegTree::SegNode(minLeft, maxRight);

        long long ans = 0;
        // prefixSum[0] == 0 is processed first so that subarrays starting at
        // index 0 are counted as well.
        for (long long x : prefixSum) {
            ans += tree.getSum(tree.root, x - upper, x - lower);
            tree.insert(tree.root, x, 1);
        }
        // The int return type is fixed by the LeetCode signature; assumes the
        // problem's constraints keep the count within int range.
        return static_cast<int>(ans);
    }
};

5 可能发生的问题

注意其中很重要的一点:

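对于线段树中插入一个节点时,需要对沿路所有节点的sum加上要插入的节点的值。找这个节点位置的时候,
需要找到root左右管辖范围的中间值mid,此时务必保证mid为 floor((left + right) / 2)(向下取整),例如使用 (left + right) >> 1 或 left + (right - left) / 2。
C++ 中整数除法 / 是向零截断的,而 >> 1 对负数是向下取整的(Python 中 // 和 >> 同样向下取整;下面的例子用 int(x / 2) 模拟 C++ 的截断除法),这里就显示出了二者的区别:
当 left = -1、right = 0 时,(left + right) / 2 == 0,左孩子的区间仍是 [-1, 0],递归永远不会结束。当然在数字都为非负数的时候二者结果相同,不会出错!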

>>> int((-1 + 0) / 2)
0
>>> int((-1 + 0) >> 1)
-1