1 条题解

  • 0
    @ 2025-10-19 16:17:19

    题解

    思路概述

    • 约束 a_i 形成了一张有向图:若排列中位置 k 放元素 p[k],则之前出现的元素里不能有值等于 p[k],等价于“点 p[k] 的前驱必须在 p[k] 被放到序列后才出现”。这意味着图中必须无环,否则根本不存在合法排列。
    • 首先检测是否有环:对入度做拓扑排序,若无法遍历完所有点,直接输出 -1
    • 拓扑序存在时,可以把合法排列理解为把每个强连通分量缩为一个点后按拓扑序合并。代码使用优先队列维护“当前可选的节点集合”,优先选平均权值最小的那个集合(val=sum/cnt)。
    • 依次弹出最小 val 的节点 u,把它合并到其父结点 a[u]ans += cnt[a[u]] * sum[u],同时更新父结点的 sumcnt 并重新入队。合并顺序恰好对应合法排列的逆序。
    • 最终 ans 就是最大权值。

    复杂度

    • 拓扑排序 + 并查集合并,时间 O(n log n);空间 O(n)
    #include <bits/stdc++.h>
    #define INF 1e18
    #define eps 1e-9
    #define scanf(...) assert(scanf(__VA_ARGS__))
    using namespace std;
    
    using ll = long long;
    using ld = long double;
    using ull = unsigned long long;
    
    namespace FastIO {
    static char buf[1000000], *p1 = buf, *p2 = buf;
    #define gc (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin), p1 == p2) ? EOF : *p1++)
    inline ll read() {
        ll res = 0;
        int w = 0, c = gc;
    
        for (; !isdigit(c); c = gc) {
            ((c == '-') && (w = 1));
        }
    
        for (; isdigit(c); c = gc) {
            res = (res << 1) + (res << 3) + (c ^ 48);
        }
    
        return (w ? -res : res);
    }
    inline char readC() {
        int c = gc;
    
        while (c == '\n' || c == '\r' || c == ' ') {
            c = gc;
        }
    
        return c;
    }
    inline string readS() {
        string res = "";
        char c = gc;
    
        for (; (c == '\n' || c == '\r' || c == ' ' || c == EOF); c = gc);
    
        for (; !(c == '\n' || c == '\r' || c == ' ' || c == EOF); c = gc) {
            res += c;
        }
    
        return res;
    }
    inline double readF() {
        double res = 0, tmp = 0.1;
        int w = 0;
        char c = gc;
    
        for (; !isdigit(c); c = gc) {
            ((c == '-') && (w = 1));
        }
    
        for (; isdigit(c); c = gc) {
            res = (res * 10) + (c ^ 48);
        }
    
        if (c == '.') {
            c = gc;
    
            for (; isdigit(c); c = gc) {
                res = res + tmp * (c ^ 48);
                tmp *= 0.1;
            }
        }
    
        return (w ? -res : res);
    }
    inline void write(ll x, char c = '\n') {
        ((x < 0) && (putchar('-'), x *= -1));
        static int sta[50], top = 0;
    
        do {
            sta[top++] = x % 10, x /= 10;
        } while (x);
    
        while (top) {
            putchar(sta[--top] + 48);
        }
    
        putchar(c);
    }
    };
    using namespace FastIO;
    
    const int N = 5e5 + 5;
    
    int n, a[N], w[N];
    
    int fa[N], cnt[N];
    ll sum[N], ans;
    int find(int x) {
        return (x == fa[x] ? x : fa[x] = find(fa[x]));
    }
    void merge(int x, int y) {
        ans += cnt[x] * sum[y];
        sum[x] += sum[y], cnt[x] += cnt[y];
        fa[y] = x;
    }
    ld val(int x) {
        return (ld)sum[x] / cnt[x];
    }
    
    priority_queue<tuple<ld, int>, vector<tuple<ld, int>>, greater<tuple<ld, int>>> qu;
    
    int d[N], vis[N];
    queue<int> q;
    bool topo() {
        d[0] = n + 1;
    
        for (int i = 1; i <= n; ++i) {
            if (!d[i]) {
                q.push(i);
            }
        }
    
        int num = 0;
    
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            ++num;
            --d[a[u]];
    
            if (!d[a[u]]) {
                q.push(a[u]);
            }
        }
    
        return num == n;
    }
    
    int main() {
    #ifdef LOCAL
        assert(freopen("test.in", "r", stdin));
        assert(freopen("test.out", "w", stdout));
    #endif
    
        n = read();
    
        for (int i = 1; i <= n; ++i) {
            a[i] = read();
            ++d[a[i]];
        }
    
        if (!topo()) {
            puts("-1");
        	return 0;
        }
    
        cnt[0] = 1;
        for (int i = 1; i <= n; ++i) {
            sum[i] = read();
            cnt[i] = 1;
            fa[i] = i;
            qu.push({val(i), i});
        }
    
        while (!qu.empty()) {
            int u = get<1>(qu.top());
            qu.pop();
    
            if (vis[u]) {
                continue;
            }
            vis[u] = 1;
    
            int tp = find(a[u]);
            merge(tp, u);
            if (tp) {
                qu.push({val(tp), tp});
            }
        }
    
        write(ans);
    
    
        return 0;
    }
    
    • 1

    信息

    ID
    3372
    时间
    1000ms
    内存
    256MiB
    难度
    8
    标签
    递交数
    1
    已通过
    1
    上传者