1 条题解
-
0
「SDOI2017」相关分析 详细题解
题目大意
给定 组观测数据 ,需要支持三种操作:
- 查询操作:用最小二乘法拟合区间 的数据,求回归直线 的斜率
- 加法操作:区间 内每个 加上 ,每个 加上
- 重置操作:区间 内每个 重置为 ,每个 重置为
算法思路
1. 数学模型分析
对于区间 ,设 ,需要计算:
- 平均数:,
- 斜率:$a = \frac{\sum (x_i - \bar{x})(y_i - \bar{y})}{\sum (x_i - \bar{x})^2}$
展开公式:
$$a = \frac{\sum x_iy_i - \frac{\sum x_i \sum y_i}{len}}{\sum x_i^2 - \frac{(\sum x_i)^2}{len}} = \frac{len \cdot \sum x_iy_i - \sum x_i \sum y_i}{len \cdot \sum x_i^2 - (\sum x_i)^2} $$因此我们需要维护四个值:
2. 线段树设计
我们需要维护以下信息:
struct Node { int l, r; double sum_x, sum_y; // ∑x, ∑y double sum_xx, sum_xy; // ∑x², ∑xy double tag_x, tag_y; // 加法标记 bool cover; // 覆盖标记 };3. 区间加法更新
对于加法操作 ,推导更新公式:
设原值为 ,新值为
- $\sum x'^2 = \sum (x_i + S)^2 = \sum x_i^2 + 2S \sum x_i + len \cdot S^2$
- $\sum x'y' = \sum (x_i + S)(y_i + T) = \sum x_iy_i + S \sum y_i + T \sum x_i + len \cdot ST$
4. 区间重置更新
对于重置操作,, (其中 是原始下标)
利用公式:
- $\sum_{i=l}^r i^2 = \frac{r(r+1)(2r+1)}{6} - \frac{(l-1)l(2l-1)}{6}$
因此:
- $\sum x_i = \sum (S + i) = len \cdot S + \sum_{i=l}^r i$
- $\sum y_i = \sum (T + i) = len \cdot T + \sum_{i=l}^r i$
- $\sum x_i^2 = \sum (S + i)^2 = len \cdot S^2 + 2S \sum i + \sum i^2$
- $\sum x_iy_i = \sum (S + i)(T + i) = len \cdot ST + (S+T) \sum i + \sum i^2$
5. 懒标记处理
需要处理两种标记:
- 覆盖标记:表示该区间被重置
- 加法标记:表示需要加上的值
下传顺序:先处理覆盖标记,再处理加法标记。
代码实现
#include <bits/stdc++.h> using namespace std; const int N = 1e5 + 10; using f8 = double; namespace seg_tree { struct Node { int l, r; f8 sum_x, sum_y, sum_xx, sum_xy; f8 tag_x, tag_y; bool cover; }; inline int ls(int u) { return u << 1; } inline int rs(int u) { return u << 1 | 1; } // 计算 1² + 2² + ... + n² inline f8 sqsum(f8 n) { return n * (n + 1) * (2 * n + 1) / 6; } // 计算 l + (l+1) + ... + r inline f8 arith_sum(f8 l, f8 r) { return (l + r) * (r - l + 1) / 2; } struct SegTree { vector<Node> tr; int n; SegTree(const vector<f8> &x, const vector<f8> &y) { n = x.size(); tr.resize(n * 4); build(1, 0, n - 1, x, y); } void merge(Node &res, const Node &le, const Node &ri) { res.sum_x = le.sum_x + ri.sum_x; res.sum_y = le.sum_y + ri.sum_y; res.sum_xx = le.sum_xx + ri.sum_xx; res.sum_xy = le.sum_xy + ri.sum_xy; } void build(int u, int l, int r, const vector<f8> &x, const vector<f8> &y) { tr[u].l = l, tr[u].r = r; tr[u].tag_x = tr[u].tag_y = 0; tr[u].cover = false; if (l == r) { tr[u].sum_x = x[l]; tr[u].sum_y = y[l]; tr[u].sum_xx = x[l] * x[l]; tr[u].sum_xy = x[l] * y[l]; return; } int mid = (l + r) >> 1; build(ls(u), l, mid, x, y); build(rs(u), mid + 1, r, x, y); merge(tr[u], tr[ls(u)], tr[rs(u)]); } // 应用覆盖标记(重置为 i, i) void apply_cover(int u) { f8 l = tr[u].l + 1, r = tr[u].r + 1; // 转换为1-indexed f8 sum_i = arith_sum(l, r); f8 sum_i2 = sqsum(r) - sqsum(l - 1); tr[u].sum_x = sum_i; tr[u].sum_y = sum_i; tr[u].sum_xx = sum_i2; tr[u].sum_xy = sum_i2; tr[u].cover = true; tr[u].tag_x = tr[u].tag_y = 0; } // 应用加法标记 void apply_add(int u, f8 dx, f8 dy) { int len = tr[u].r - tr[u].l + 1; // 按照推导公式更新 tr[u].sum_xy += dx * tr[u].sum_y + dy * tr[u].sum_x + dx * dy * len; tr[u].sum_xx += 2 * dx * tr[u].sum_x + dx * dx * len; tr[u].sum_x += dx * len; tr[u].sum_y += dy * len; tr[u].tag_x += dx; tr[u].tag_y += dy; } void pushdown(int u) { if (tr[u].cover) { apply_cover(ls(u)); apply_cover(rs(u)); tr[u].cover = false; } if (tr[u].tag_x != 0 || tr[u].tag_y != 0) { apply_add(ls(u), tr[u].tag_x, tr[u].tag_y); apply_add(rs(u), tr[u].tag_x, tr[u].tag_y); tr[u].tag_x = tr[u].tag_y = 0; } } void range_add(int u, int l, int r, f8 dx, f8 dy) { if (l <= tr[u].l && tr[u].r <= r) { apply_add(u, dx, dy); return; } pushdown(u); int mid = (tr[u].l + tr[u].r) >> 1; if (l <= mid) range_add(ls(u), l, r, dx, dy); if (r > mid) range_add(rs(u), l, r, dx, dy); merge(tr[u], tr[ls(u)], tr[rs(u)]); } void range_set(int u, int l, int r, f8 s, f8 t) { if (l <= tr[u].l && tr[u].r <= r) { apply_cover(u); // 先重置为 (i, i) apply_add(u, s, t); // 再加上 (S, T) return; } pushdown(u); int mid = (tr[u].l + tr[u].r) >> 1; if (l <= mid) range_set(ls(u), l, r, s, t); if (r > mid) range_set(rs(u), l, r, s, t); merge(tr[u], tr[ls(u)], tr[rs(u)]); } Node query(int u, int l, int r) { if (l <= tr[u].l && tr[u].r <= r) { return tr[u]; } pushdown(u); int mid = (tr[u].l + tr[u].r) >> 1; if (r <= mid) return query(ls(u), l, r); if (l > mid) return query(rs(u), l, r); Node res, le = query(ls(u), l, r), ri = query(rs(u), l, r); merge(res, le, ri); return res; } // 对外接口 void range_add(int l, int r, f8 s, f8 t) { range_add(1, l, r, s, t); } void range_set(int l, int r, f8 s, f8 t) { range_set(1, l, r, s, t); } f8 range_slope(int l, int r) { Node res = query(1, l, r); int len = r - l + 1; // 使用稳定计算公式 f8 num = len * res.sum_xy - res.sum_x * res.sum_y; f8 den = len * res.sum_xx - res.sum_x * res.sum_x; return num / den; } }; } int main() { int n, m; scanf("%d%d", &n, &m); vector<f8> x(n), y(n); for (int i = 0; i < n; i++) scanf("%lf", &x[i]); for (int i = 0; i < n; i++) scanf("%lf", &y[i]); seg_tree::SegTree sgt(x, y); for (int i = 0; i < m; i++) { int op, l, r; scanf("%d%d%d", &op, &l, &r); l--; r--; if (op == 1) { printf("%.10lf\n", sgt.range_slope(l, r)); } else { f8 s, t; scanf("%lf%lf", &s, &t); if (op == 2) { sgt.range_add(l, r, s, t); } else { sgt.range_set(l, r, s, t); } } } return 0; }复杂度分析
- 时间复杂度:
- 建树:
- 每次操作:
- 空间复杂度:
注意事项
- 精度问题:使用稳定的计算公式,避免大数相减
- 标记下传:注意覆盖标记和加法标记的处理顺序
- 边界处理:注意数组下标从0还是1开始
- 公式推导:确保所有更新公式的正确性
这道题综合考察了数学推导、线段树设计和懒标记应用,是一道很好的数据结构练习题。
- 1
信息
- ID
- 4073
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 10
- 标签
- 递交数
- 6
- 已通过
- 1
- 上传者