1 条题解

  • 0
    @ 2026-5-5 21:55:44

    题意简述

    给定两个长度为 nn 的序列 AABB,需要从 nn 个元素中选出恰好 KK 个,每个元素 ii 被选中的代价为 AiA_i,且若元素 iijji<ji < j)同时被选中,则额外增加 BjB_j 的代价。求最小总代价。

    思路分析

    转化为网络流模型

    观察代价结构:

    • 选中元素 ii 本身付出 AiA_i
    • 若选中 ii 且后面选中了 jji<ji < j),则付出 BjB_j(可以理解为每个 jj 对前面所有选中的 ii 各贡献一次 BjB_j,但注意题目通常搭配是“若 ii 选且 jj 选”只算一次?需要根据原题确认。从代码看,建图方式是每个 ii 向源点连 AiA_i,向伪汇点连 BiB_i,说明总代价 = Ai+Bj\sum A_i + \sum B_j 对所有选中的 jj,且 i<ji<j 时也会重复算?实际上这是典型的“配对代价”网络流建模)

    更常见的题意:选择恰好 KK 个元素,代价为 $\sum A_i + \sum_{j \text{选}} B_j \cdot (\text{前面选中的个数})$。这正是代码建图所对应的模型。

    朴素建图

    将每个元素拆成两个点(上点 ii 和下点 ii'),源点 SS 连向 ii(容量 1,费用 AiA_i),ii' 连向汇点 TT(容量 1,费用 BiB_i),iijj'i<ji<j)连边(容量 1,费用 0)。再限制总流量为 KK。但这样边数为 O(n2)O(n^2)nn 可达 20002000,无法接受。

    优化建图:利用传递性

    注意到 iijj' 连边 等价于 ii 先到一个中间点 pip_ipip_i 再到 pi+1p_{i+1},...,最后到 jj'。因此我们引入一排中间点 1,2,,n1,2,\dots,n

    • iiii(中间点)连边(容量 1,费用 0)
    • 中间点 iii+1i+1 连边(容量无穷,费用 0)
    • 中间点 iiii' 连边(容量 1,费用 0)

    这样,若 ii 被选中,流量可以从 ii 流入中间点链,然后从某个 i\ge i 的中间点 jj 流出到 jj',表示 iijj 产生了贡献。由于链上边容量无穷,多个前驱可以共享同一条路径到同一个 jj',但每个 jj' 只能接受 1 单位流量(因为连向 TT 的边容量 1),正好对应每个 jj 只会被计算一次 BjB_j

    费用流实现

    • 源点 SS 连向每个 ii(容量 1,费用 AiA_i
    • 每个 ii 连向中间点 ii(容量 1,费用 0)
    • 中间点 ii 连向 i+1i+1(容量 ++\infty,费用 0)
    • 中间点 ii 连向伪汇点 TTTT(容量 1,费用 BiB_i
    • 伪汇点 TTTT 连向真汇点 TT(容量 KK,费用 0)

    跑最小费用最大流,最大流量为 KK 时的最小费用即为答案。

    时间复杂度

    • 点数 O(n)O(n),边数 O(n)O(n)
    • 每次 SPFA/SLF 优化后增广,实际运行效率很高,可处理 n=2000n=2000 的数据。

    AC 代码(含注释)

    #include<bits/stdc++.h>
    #define int long long
    using namespace std;
    
    int read() {
        int x = 0, f = 1; char ch = getchar();
        while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
        while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
        return x * f;
    }
    
    const int INF = 1e15;
    const int N = 5005, M = 100005;
    
    int n, K, S, TT, T, cnt;
    int a[N], b[N], d[N], head[N];
    bool inq[N], vis[N];
    deque<int> Q;
    
    struct Edge {
        int to, cap, cost, next;
    } e[M << 1];
    
    void add(int u, int v, int w, int c) {
        e[cnt] = {v, w, c, head[u]}; head[u] = cnt++;
        e[cnt] = {u, 0, -c, head[v]}; head[v] = cnt++;
    }
    
    bool spfa() {
        for (int i = 0; i <= T; i++) d[i] = INF;
        memset(inq, 0, sizeof(inq));
        d[T] = 0; Q.push_back(T); inq[T] = true;
        while (!Q.empty()) {
            int u = Q.front(); Q.pop_front(); inq[u] = false;
            for (int i = head[u]; ~i; i = e[i].next) {
                int v = e[i].to;
                if (e[i ^ 1].cap && d[v] > d[u] - e[i].cost) {
                    d[v] = d[u] - e[i].cost;
                    if (!inq[v]) {
                        inq[v] = true;
                        if (Q.empty() || d[v] >= d[Q.front()]) Q.push_back(v);
                        else Q.push_front(v);
                    }
                }
            }
        }
        return d[S] < INF;
    }
    
    int dfs(int u, int f) {
        vis[u] = true;
        if (u == T || f == 0) return f;
        int used = 0;
        for (int i = head[u]; ~i; i = e[i].next) {
            int v = e[i].to;
            if (!vis[v] && e[i].cap && d[v] == d[u] - e[i].cost) {
                int w = dfs(v, min(e[i].cap, f - used));
                e[i].cap -= w;
                e[i ^ 1].cap += w;
                used += w;
                if (used == f) return used;
            }
        }
        return used;
    }
    
    signed main() {
        n = read(), K = read();
        for (int i = 1; i <= n; i++) a[i] = read();
        for (int i = 1; i <= n; i++) b[i] = read();
    
        S = n + 1, TT = n + 2, T = n + 3;
        memset(head, -1, sizeof(head));
    
        // 源点连向每个元素
        for (int i = 1; i <= n; i++) {
            add(S, i, 1, a[i]);           // 选中 i 的代价
        }
    
        // 中间点链
        for (int i = 1; i <= n; i++) {
            add(i, i, 1, 0);              // i 进入中间点链
            if (i != n) add(i, i + 1, INF, 0); // 链上传递
            add(i, TT, 1, b[i]);          // 从中间点流出到伪汇点,贡献 b[i]
        }
    
        // 伪汇点限制总流量为 K
        add(TT, T, K, 0);
    
        int flow = 0, cost = 0;
        while (spfa()) {
            memset(vis, 0, sizeof(vis));
            int f = dfs(S, INF);
            flow += f;
            cost += d[S] * f;
        }
    
        printf("%lld\n", cost);
        return 0;
    }
    • 1

    信息

    ID
    6960
    时间
    1000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    1
    已通过
    1
    上传者