线段树区间查询之最大子段和 线段树 区间查询 区间合并 单点修改 T3 245. 你能回答这些问题吗 - AcWing题库
给定长度为 的数列 ,以及 条指令,每条指令可能是以下两种之一:
1 x y
,查询区间 中的最大连续子段和,即 {}。2 x y
,把 改成 。
对于每个查询指令,输出一个整数表示答案。
输入格式
第一行两个整数 。
第二行 个整数 。
接下来 行每行 个整数 , 表示查询(此时如果 ,请交换 ), 表示修改。
输出格式
对于每个查询指令输出一个整数表示答案。
每个答案占一行。
数据范围
,
输入样例:
5 3
1 2 -3 4 5
1 2 3
2 2 -1
1 3 2
输出样例:
2
-1
思路
单点修改+区间查询最大字段和。
因为是单点修改, 故不需要用lazy标记, 直接改就行。最大子段和的查询需要注意对于 u 节点查最大字段和时, 有可能答案是横跨两个子区间之间。故需要进行额外的处理, 类似于 Hotel区间合并 的处理方式, 定义 lsum 和 rsum 代表最大前缀以及最大后缀, sum为当前左右子区间的最大和, 这样 当前区间的最大子序和tsum就从以下三个值中选:(lson代表左子节点, rson代表右子节点)
- lson.sum
- rson.sum
- lson.rsum + rson.lsum
这样就可以确保答案完全考虑到 最大子序和 分别在左区间, 右区间, 还是横跨两个区间。
怎么在pushup中更新 lsum 和 rsum 呢? u.lsum = max(lson.lsum, lson.sum + rson.lsum) 左子节点的lsum或者整个左子节点+右子节点的左前缀和
当我们用 lson.sum + rson.lsum 更新时, 有没有可能 lson.sum 不是从左边第一个数开始加的呢?并不会, 我们更新 sum 值时是把区间内所有的值都加上, 不重不漏的。
有个小技巧是利用多态特性, 定义两个pushup, 一个单独一个u表示更新u节点, 另一个为 Node &u, Node &l, Node &r 表示更细化地用 l,r 两个子节点更新 u 节点, 方便在输出结果时得到 res。
void pushup(Node &u, Node &l, Node &r) {}
void pushup(int u)
{
pushup(t[u], t[u << 1], t[u << 1 | 1]);
}
代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <string>
#include <cmath>
#define lson l, mid, u << 1
#define rson mid + 1, r, u << 1 | 1
using namespace std;
const int N = 5e5 + 10;
int a[N];
int n,m;
struct Node {
int lsum, rsum, sum, tsum;
}t[N * 4];
void pushup(Node &u, Node &l, Node &r)
{
u.sum = l.sum + r.sum;
u.lsum = max(l.lsum, l.sum + r.lsum);
u.rsum = max(r.rsum, r.sum + l.rsum);
u.tsum = max(max(l.tsum, r.tsum), l.rsum + r.lsum);
}
void pushup(int u)
{
pushup(t[u], t[u << 1], t[u << 1 | 1]);
}
void build(int l , int r, int u = 1)
{
if(l == r)
{
t[u] = {a[l], a[l], a[l], a[l]};
return;
}
int mid = l + r >> 1;
build(lson); build(rson);
pushup(u);
}
void update(int pos, int val, int l = 1, int r = n, int u = 1)
{
if(l == pos && r == pos)
{
t[u] = {val,val, val, val};
return;
}
int mid = l + r >> 1;
if(pos <= mid) update(pos, val, lson);
else update(pos, val, rson);
pushup(u);
}
Node query(int L, int R, int l = 1, int r = n, int u = 1)
{
if(L <= l && r <= R)
return t[u];
int mid = l + r >> 1;
int v = -0x3f3f3f3f, sum = 0;
if(L > mid) return query(L,R,rson);
else if(R <= mid) return query(L,R,lson);
else {
auto left = query(L, R, lson);
auto right = query(L,R,rson);
Node res;
pushup(res, left, right);
return res;
}
pushup(u);
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, n);
while(m--)
{
int t, a, b;
cin >> t >> a >> b;
if(t == 1)
{
if(a > b) swap(a,b);
cout << query(a, b).tsum << endl;
}
else
{
update(a, b);
}
}
return 0;
}