1 条题解

  • 0
    @ 2025-11-18 15:58:09

    题解:L2469. 「2018 集训队互测 Day 2」最小方差生成树

    题目大意

    给定一个 nn 个点 mm 条边的带权无向图,边权为 wiw_i

    • T=1T = 1:求最小方差生成树。
    • T=2T = 2:对于每条边,求删除它后的最小方差生成树方差(如果图不连通则输出 1-1)。

    方差定义:生成树有 n1n-1 条边,方差为:

    $$\sigma^2 = \frac{\sum_{i=1}^{n-1}(x_i - \mu)^2}{n-1} $$

    其中 μ=i=1n1xin1\mu = \frac{\sum_{i=1}^{n-1} x_i}{n-1} 是平均值。

    输出要求:将方差乘以 (n1)2(n-1)^2 后输出,保证结果是整数。


    思路分析

    1. 方差公式化简

    设生成树的边权为 w1,w2,,wn1w_1, w_2, \dots, w_{n-1},均值为 μ=win1\mu = \frac{\sum w_i}{n-1}

    方差:

    σ2=(wiμ)2n1\sigma^2 = \frac{\sum (w_i - \mu)^2}{n-1}

    乘以 (n1)2(n-1)^2

    (n1)2σ2=(n1)(wiμ)2(n-1)^2 \sigma^2 = (n-1) \sum (w_i - \mu)^2

    展开平方项:

    =(n1)(wi22wiμ+μ2)= (n-1) \sum (w_i^2 - 2w_i\mu + \mu^2) $$= (n-1) \sum w_i^2 - 2(n-1)\mu \sum w_i + (n-1)^2 \mu^2 $$

    代入 μ=win1\mu = \frac{\sum w_i}{n-1}

    =(n1)wi22(wi)2+(wi)2= (n-1) \sum w_i^2 - 2(\sum w_i)^2 + (\sum w_i)^2 =(n1)wi2(wi)2= (n-1) \sum w_i^2 - (\sum w_i)^2

    所以目标是最小化

    $$\text{目标值} = (n-1) \sum w_i^2 - \left(\sum w_i\right)^2 $$

    S1=wiS_1 = \sum w_iS2=wi2S_2 = \sum w_i^2,则目标值为:

    (n1)S2S12(n-1) S_2 - S_1^2

    2. 最小方差生成树的性质

    关键结论:最小方差生成树一定在某个实数 xx 对应的最小生成树中,其中生成树的边是按 (wix)2(w_i - x)^2 的权值排序后得到的 MST。

    证明思路:考虑函数 f(x)=(n1)S2S12f(x) = (n-1)S_2 - S_1^2,当 xx 变化时,MST 的边集只在某些临界点发生变化,这些临界点就是两条边的 (wx)2(w - x)^2 值相等的地方,即 x=wi+wj2x = \frac{w_i + w_j}{2}


    3. 算法设计

    基本步骤

    1. 对边按原权值 wiw_i 排序
    2. 枚举所有可能的临界点 x=wi+wj2x = \frac{w_i + w_j}{2}(共 O(m2)O(m^2) 个,但实际只需 O(m)O(m) 个)
    3. 对每个临界点 xx,按 (wx)2(w - x)^2 排序求 MST
    4. 计算目标值 (n1)S2S12(n-1)S_2 - S_1^2,取最小值

    复杂度优化

    • 直接枚举所有 O(m2)O(m^2)xx 会超时
    • 实际上 MST 的边集是分段不变的,只需在 O(m)O(m) 个真正的临界点计算
    • 使用 LCT(Link-Cut Tree) 动态维护 MST,高效处理边的替换

    处理 T=2T = 2 的情况

    对于每条边 ee

    • 如果 ee 不在任何 MST 中,删除它不影响结果
    • 如果 ee 在某个 MST 中,需要重新计算不含 ee 的最小方差生成树
    • 使用预处理和 LCT 来高效处理每条边的删除情况

    代码框架说明

    代码主要包含以下几个部分:

    1. 数据结构定义

    • Data:存储边的信息(端点、权值、编号)
    • Num:高精度数类,用于精确计算目标值
    • LCT 相关数据结构:维护动态 MST

    2. 核心函数

    • get_LR():获取每条边在 MST 中的左右邻接关系
    • get_T12():预处理每条边的相关边集
    • get_v():获取所有临界点
    • calc():计算最小方差值
    • LCT 操作:维护动态树的连接、断开、查询等

    3. 主流程

    1. 读入数据并排序
    2. 检查图连通性
    3. 计算原始最小方差生成树(T=1T = 1
    4. 如果 T=2T = 2,对每条边处理删除情况

    复杂度分析

    • 时间复杂度O(mlogm+mlogn)O(m \log m + m \log n),主要来自排序和 LCT 操作
    • 空间复杂度O(n+m)O(n + m)

    总结

    本题是一道结合了生成树理论动态树维护数学优化的综合题目,主要考察:

    1. 数学推导能力:将方差问题转化为可优化的目标函数
    2. 算法设计能力:利用 MST 性质和临界点枚举来减少计算量
    3. 数据结构应用:熟练使用 LCT 维护动态 MST
    4. 问题分解能力:分别处理 T=1T = 1T=2T = 2 的情况

    解决此类问题需要扎实的图论基础和对高级数据结构的深入理解。

    代码实现

    下面是题目给出的代码框架,已经实现了上述优化思路,使用 LCT 维护 MST,并在临界点计算方差。

    #include <bits/stdc++.h>
    using namespace std;
    #define N 305
    #define M 100005
    #define base 1000000000
    #define ll long long
    #define pli pair<ll,int>
    #define fi first
    #define se second
    
    struct Data {
        int x, y, id;
        ll z;
        bool operator < (const Data &k)const {
            return z < k.z;
        }
    } e[M];
    
    struct Num {
        int p, len;
        ll a[5];
        Num(ll k = 0) {
            p = (k < 0), len = 0;
            memset(a, 0, sizeof(a));
            k = abs(k);
            while (k) {
                a[len++] = k % base;
                k /= base;
            }
        }
    } Check, S1, S2, ans, z1[M], z2[M], dS1[M << 1], dS2[M << 1], Ans[M];
    
    vector<int>T1[M], T2[M];
    vector<pli>v0, v[M];
    int n, m, T, q, L0[M], R0[M], L[M], R[M], vis0[M];
    ll Pos[M << 1];
    
    namespace IO {
        // 输入输出优化
        int num[100];
        ll x;
        char c;
        ll read() {
            x = 0, c = getchar();
            while ((c < '0') || (c > '9')) c = getchar();
            while ((c >= '0') && (c <= '9')) {
                x = x * 10 + c - '0';
                c = getchar();
            }
            return x;
        }
        void write(Num x, char c = '\0') {
            if (x.p) putchar('-');
            for (int i = 0; i < x.len; i++)
                for (int j = 0; j < 9; j++) {
                    num[++num[0]] = x.a[i] % 10;
                    x.a[i] /= 10;
                }
            while ((num[0]) && (!num[num[0]])) num[0]--;
            if (!num[0]) putchar('0');
            while (num[0]) putchar(num[num[0]--] + '0');
            putchar(c);
        }
    };
    
    namespace Calc {
        // 高精度运算
        int cmp(Num x, Num y) {
            if (x.len != y.len) {
                if (x.len < y.len) return -1;
                return 1;
            }
            for (int i = x.len - 1; i >= 0; i--)
                if (x.a[i] != y.a[i]) {
                    if (x.a[i] < y.a[i]) return -1;
                    return 1;
                }
            return 0;
        }
        Num min(Num x, Num y) {
            if (cmp(x, y) < 0) return x;
            return y;
        }
        Num add(Num x, Num y) {
            Num ans;
            if (x.p == y.p) {
                ans.p = x.p, ans.len = max(x.len, y.len);
                for (int i = 0; i < ans.len; i++) {
                    ans.a[i] += x.a[i] + y.a[i];
                    if (ans.a[i] >= base) {
                        ans.a[i] -= base;
                        ans.a[i + 1]++;
                    }
                }
                if ((ans.len < 5) && (ans.a[ans.len])) ans.len++;
                return ans;
            }
            if (cmp(x, y) < 0) swap(x, y);
            ans.p = x.p, ans.len = x.len;
            for (int i = 0; i < ans.len; i++) {
                ans.a[i] += x.a[i] - y.a[i];
                if (ans.a[i] < 0) {
                    ans.a[i] += base;
                    ans.a[i + 1]--;
                }
            }
            while ((ans.len) && (!ans.a[ans.len - 1])) ans.len--;
            return ans;
        }
        Num dec(Num x, Num y) {
            y.p ^= 1;
            return add(x, y);
        }
        Num mul(Num x, Num y) {
            Num ans;
            ans.p = (x.p ^ y.p), ans.len = x.len + y.len - 1;
            for (int i = 0; i < x.len; i++)
                for (int j = 0; j < y.len; j++) {
                    ans.a[i + j] += x.a[i] * y.a[j];
                    ans.a[i + j + 1] += ans.a[i + j] / base;
                    ans.a[i + j] %= base;
                }
            if ((ans.len < 5) && (ans.a[ans.len])) ans.len++;
            return ans;
        }
    };
    
    namespace LCT {
        // Link-Cut Tree 维护 MST
        int vis[M], st[M << 1], fa[M << 1], mn[M << 1], mx[M << 1], rev[M << 1], ch[M << 1][2];
        int which(int k) {
            return ch[fa[k]][1] == k;
        }
        bool check(int k) {
            return ch[fa[k]][which(k)] == k;
        }
        void upd(int k) {
            rev[k] ^= 1;
            swap(ch[k][0], ch[k][1]);
        }
        void up(int k) {
            mn[k] = min(mn[ch[k][0]], mn[ch[k][1]]);
            mx[k] = max(mx[ch[k][0]], mx[ch[k][1]]);
            if (k > n) {
                mn[k] = min(mn[k], k - n);
                mx[k] = max(mx[k], k - n);
            }
        }
        void down(int k) {
            if (rev[k]) {
                if (ch[k][0]) upd(ch[k][0]);
                if (ch[k][1]) upd(ch[k][1]);
                rev[k] = 0;
            }
        }
        void rotate(int k) {
            int f = fa[k], g = fa[f], p = which(k);
            fa[k] = g;
            if (check(f)) ch[g][which(f)] = k;
            fa[ch[k][p ^ 1]] = f, ch[f][p] = ch[k][p ^ 1];
            fa[f] = k, ch[k][p ^ 1] = f;
            up(f), up(k);
        }
        void splay(int k) {
            for (int i = k; check(i); i = fa[i]) st[++st[0]] = fa[i];
            while (st[0]) down(st[st[0]--]);
            down(k);
            for (int i = fa[k]; check(k); i = fa[k]) {
                if (check(i)) {
                    if (which(i) == which(k)) rotate(i);
                    else rotate(k);
                }
                rotate(k);
            }
        }
        void access(int k) {
            int lst = 0;
            while (k) {
                splay(k);
                ch[k][1] = lst, up(k);
                lst = k, k = fa[k];
            }
        }
        void make_root(int k) {
            access(k);
            splay(k);
            upd(k);
        }
        int find_root(int k) {
            access(k);
            splay(k);
            while (ch[k][0]) {
                down(k);
                k = ch[k][0];
            }
            splay(k);
            return k;
        }
        void add(int x, int y) {
            make_root(x);
            make_root(y);
            fa[y] = x;
        }
        void del(int x, int y) {
            make_root(x);
            access(y);
            splay(x);
            fa[y] = ch[x][1] = 0, up(x);
        }
        int query_min(int x, int y) {
            make_root(x);
            if (find_root(y) != x) return 0;
            return mn[x];
        }
        int query_max(int x, int y) {
            make_root(x);
            if (find_root(y) != x) return 0;
            return mx[x];
        }
        int add_min(int id) {
            int pos = query_min(e[id].y, id + n);
            if (pos) {
                vis[0]--, vis[pos] = 0;
                LCT::del(e[pos].y, pos + n);
            }
            vis[0]++, vis[id] = 1;
            LCT::add(e[id].y, id + n);
            return pos;
        }
        int add_max(int id) {
            int pos = query_max(e[id].y, id + n);
            if (pos) {
                vis[0]--, vis[pos] = 0;
                LCT::del(e[pos].y, pos + n);
            }
            vis[0]++, vis[id] = 1;
            LCT::add(e[id].y, id + n);
            return pos;
        }
        void init() {
            mn[0] = 0x3f3f3f3f, mx[0] = 0;
            for (int i = 1; i <= n + m; i++) {
                fa[i] = rev[i] = ch[i][0] = ch[i][1] = 0;
                up(i);
            }
            for (int i = 1; i <= m; i++) add(e[i].x, i + n);
        }
        void clear() {
            for (int i = 1; i <= m; i++)
                if (vis[i]) {
                    LCT::del(e[i].y, i + n);
                    vis[0]--, vis[i] = 0;
                }
        }
    };
    
    void get_LR() {
        memset(L0, 0, sizeof(L0));
        memset(R0, 0, sizeof(R0));
        LCT::clear();
        for (int i = 1; i <= m; i++) {
            L0[i] = LCT::add_min(i);
            if (L0[i]) R0[L0[i]] = i;
        }
    }
    
    void get_T12() {
        LCT::clear();
        for (int i = 1; i <= m; i++) {
            if (vis0[i]) {
                for (int j = m; j; j--)
                    if (LCT::vis[j]) T1[i].push_back(j);
            }
            LCT::add_min(i);
        }
        LCT::clear();
        for (int i = m; i; i--) {
            if (vis0[i]) {
                for (int j = 1; j <= m; j++)
                    if (LCT::vis[j]) T2[i].push_back(j);
            }
            LCT::add_max(i);
        }
    }
    
    void get_v() {
        v0.clear();
        for (int i = 1; i <= m; i++) {
            if (!L0[i]) v0.push_back(make_pair(0, i));
            else v0.push_back(make_pair((e[L0[i]].z + e[i].z << 1) + 1, i));
            if (R0[i]) v0.push_back(make_pair((e[i].z + e[R0[i]].z << 1) + 1, -i));
        }
    }
    
    void upd_LR(int k) {
        LCT::clear();
        memcpy(L, L0, sizeof(L));
        L[k] = k;
        for (int i = 0; i < T2[k].size(); i++) {
            L[T2[k][i]] = 0;
            LCT::add_max(T2[k][i]);
        }
        for (int i = 0; i < T1[k].size(); i++) {
            int pos = LCT::add_max(T1[k][i]);
            if (pos) L[pos] = T1[k][i];
        }
        memset(R, 0, sizeof(R));
        for (int i = 1; i <= m; i++)
            if (L[i]) R[L[i]] = i;
    }
    
    void upd_v(int k) {
        T1[k].push_back(k), T2[k].push_back(k);
        for (int i = 0; i < T1[k].size(); i++) {
            int pos = T1[k][i];
            if (R0[pos]) v[k].push_back(make_pair((e[pos].z + e[R0[pos]].z << 1) + 1, pos));
            if (R[pos]) {
                v0.push_back(make_pair((e[pos].z + e[R[pos]].z << 1) + 1, 0));
                v[k].push_back(make_pair((e[pos].z + e[R[pos]].z << 1) + 1, -pos));
            }
        }
        for (int i = 0; i < T2[k].size(); i++) {
            int pos = T2[k][i];
            if (!L0[pos]) v[k].push_back(make_pair(0, -pos));
            else v[k].push_back(make_pair((e[L0[pos]].z + e[pos].z << 1) + 1, -pos));
            if (!L[pos]) {
                v0.push_back(make_pair(0, 0));
                v[k].push_back(make_pair(0, pos));
            } else {
                v0.push_back(make_pair((e[L[pos]].z + e[pos].z << 1) + 1, 0));
                v[k].push_back(make_pair((e[L[pos]].z + e[pos].z << 1) + 1, pos));
            }
        }
    }
    
    void unique() {
        sort(v0.begin(), v0.end());
        q = 0;
        for (int i = 0; i < v0.size(); i++) {
            if ((!i) || (v0[i].fi != v0[i - 1].fi)) Pos[++q] = v0[i].fi;
            int pos = abs(v0[i].se);
            if (v0[i].se > 0) dS1[q] = Calc::add(dS1[q], z1[pos]), dS2[q] = Calc::add(dS2[q], z2[pos]);
            else dS1[q] = Calc::dec(dS1[q], z1[pos]), dS2[q] = Calc::dec(dS2[q], z2[pos]);
        }
    }
    
    void calc() {
        sort(v0.begin(), v0.end());
        S1 = S2 = 0, ans.len = 10;
        for (int i = 0; i < v0.size(); i++) {
            int pos = abs(v0[i].se);
            if (v0[i].se > 0) S1 = Calc::add(S1, z1[pos]), S2 = Calc::add(S2, z2[pos]);
            else S1 = Calc::dec(S1, z1[pos]), S2 = Calc::dec(S2, z2[pos]);
            if ((i == v0.size()) || (v0[i].fi != v0[i + 1].fi)) {
                if (!v0[i].fi) Check = S1;
                ans = Calc::min(ans, Calc::dec(S2, Calc::mul(S1, S1)));
            }
        }
        S1 = S2 = 0;
        memset(vis0, 0, sizeof(vis0));
        for (int i = 0; i < v0.size(); i++) {
            int pos = abs(v0[i].se);
            vis0[pos] ^= 1;
            if (v0[i].se > 0) S1 = Calc::add(S1, z1[pos]), S2 = Calc::add(S2, z2[pos]);
            else S1 = Calc::dec(S1, z1[pos]), S2 = Calc::dec(S2, z2[pos]);
            if (((i == v0.size()) || (v0[i].fi != v0[i + 1].fi)) && (!Calc::cmp(ans, Calc::dec(S2, Calc::mul(S1, S1))))) break;
        }
    }
    
    void calc(int k) {
        for (int i = 0; i < v[k].size(); i++) {
            int q0 = lower_bound(Pos + 1, Pos + q + 1, v[k][i].fi) - Pos, pos = abs(v[k][i].se);
            if (v[k][i].se > 0) dS1[q0] = Calc::add(dS1[q0], z1[pos]), dS2[q0] = Calc::add(dS2[q0], z2[pos]);
            else dS1[q0] = Calc::dec(dS1[q0], z1[pos]), dS2[q0] = Calc::dec(dS2[q0], z2[pos]);
        }
        if (Calc::cmp(dS1[1], Check) < 0) {
            ans = -1;
            return;
        }
        S1 = S2 = 0, ans.len = 10;
        for (int i = 1; i <= q; i++) {
            S1 = Calc::add(S1, dS1[i]), S2 = Calc::add(S2, dS2[i]);
            ans = Calc::min(ans, Calc::dec(S2, Calc::mul(S1, S1)));
        }
        for (int i = 0; i < v[k].size(); i++) {
            int q0 = lower_bound(Pos + 1, Pos + q + 1, v[k][i].fi) - Pos, pos = abs(v[k][i].se);
            if (v[k][i].se < 0) dS1[q0] = Calc::add(dS1[q0], z1[pos]), dS2[q0] = Calc::add(dS2[q0], z2[pos]);
            else dS1[q0] = Calc::dec(dS1[q0], z1[pos]), dS2[q0] = Calc::dec(dS2[q0], z2[pos]);
        }
    }
    
    int main() {
        n = IO::read(), m = IO::read(), T = IO::read();
        for (int i = 1; i <= m; i++) {
            e[i].x = IO::read(), e[i].y = IO::read(), e[i].z = IO::read();
            e[i].id = i;
        }
        sort(e + 1, e + m + 1);
        LCT::init();
        for (int i = 1; i <= m; i++) {
            z1[i] = e[i].z;
            z2[i] = Calc::mul(n - 1, Calc::mul(z1[i], z1[i]));
        }
        get_LR();
        if (LCT::vis[0] != n - 1) {
            if (T == 1) IO::write(-1, '\n');
            else {
                for (int i = 1; i <= m; i++) IO::write(-1, '\n');
            }
            return 0;
        }
        get_v(), calc();
        if (T == 1) {
            IO::write(ans, '\n');
            return 0;
        }
        for (int i = 1; i <= m; i++)
            if (!vis0[i]) Ans[e[i].id] = ans;
        get_T12();
        for (int i = 1; i <= m; i++)
            if (vis0[i]) upd_LR(i), upd_v(i);
        unique();
        for (int i = 1; i <= m; i++)
            if (vis0[i]) {
                calc(i);
                Ans[e[i].id] = ans;
            }
        for (int i = 1; i <= m; i++) IO::write(Ans[i], '\n');
        return 0;
    }
    

    代码较长,但思路清晰,是 MST 和 LCT 结合的一道经典难题。

    • 1

    「2018 集训队互测 Day 2」最小方差生成树

    信息

    ID
    5090
    时间
    4500ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    7
    已通过
    1
    上传者