1 条题解
-
0
题解
思路概述
- 约束
a_i
形成了一张有向图:若排列中位置k
放元素p[k]
,则之前出现的元素里不能有值等于p[k]
,等价于“点p[k]
的前驱必须在p[k]
被放到序列后才出现”。这意味着图中必须无环,否则根本不存在合法排列。 - 首先检测是否有环:对入度做拓扑排序,若无法遍历完所有点,直接输出
-1
。 - 拓扑序存在时,可以把合法排列理解为把每个强连通分量缩为一个点后按拓扑序合并。代码使用优先队列维护“当前可选的节点集合”,优先选平均权值最小的那个集合(
val=sum/cnt
)。 - 依次弹出最小
val
的节点u
,把它合并到其父结点a[u]
:ans += cnt[a[u]] * sum[u]
,同时更新父结点的sum
、cnt
并重新入队。合并顺序恰好对应合法排列的逆序。 - 最终
ans
就是最大权值。
复杂度
- 拓扑排序 + 并查集合并,时间
O(n log n)
;空间O(n)
。
#include <bits/stdc++.h> #define INF 1e18 #define eps 1e-9 #define scanf(...) assert(scanf(__VA_ARGS__)) using namespace std; using ll = long long; using ld = long double; using ull = unsigned long long; namespace FastIO { static char buf[1000000], *p1 = buf, *p2 = buf; #define gc (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin), p1 == p2) ? EOF : *p1++) inline ll read() { ll res = 0; int w = 0, c = gc; for (; !isdigit(c); c = gc) { ((c == '-') && (w = 1)); } for (; isdigit(c); c = gc) { res = (res << 1) + (res << 3) + (c ^ 48); } return (w ? -res : res); } inline char readC() { int c = gc; while (c == '\n' || c == '\r' || c == ' ') { c = gc; } return c; } inline string readS() { string res = ""; char c = gc; for (; (c == '\n' || c == '\r' || c == ' ' || c == EOF); c = gc); for (; !(c == '\n' || c == '\r' || c == ' ' || c == EOF); c = gc) { res += c; } return res; } inline double readF() { double res = 0, tmp = 0.1; int w = 0; char c = gc; for (; !isdigit(c); c = gc) { ((c == '-') && (w = 1)); } for (; isdigit(c); c = gc) { res = (res * 10) + (c ^ 48); } if (c == '.') { c = gc; for (; isdigit(c); c = gc) { res = res + tmp * (c ^ 48); tmp *= 0.1; } } return (w ? -res : res); } inline void write(ll x, char c = '\n') { ((x < 0) && (putchar('-'), x *= -1)); static int sta[50], top = 0; do { sta[top++] = x % 10, x /= 10; } while (x); while (top) { putchar(sta[--top] + 48); } putchar(c); } }; using namespace FastIO; const int N = 5e5 + 5; int n, a[N], w[N]; int fa[N], cnt[N]; ll sum[N], ans; int find(int x) { return (x == fa[x] ? x : fa[x] = find(fa[x])); } void merge(int x, int y) { ans += cnt[x] * sum[y]; sum[x] += sum[y], cnt[x] += cnt[y]; fa[y] = x; } ld val(int x) { return (ld)sum[x] / cnt[x]; } priority_queue<tuple<ld, int>, vector<tuple<ld, int>>, greater<tuple<ld, int>>> qu; int d[N], vis[N]; queue<int> q; bool topo() { d[0] = n + 1; for (int i = 1; i <= n; ++i) { if (!d[i]) { q.push(i); } } int num = 0; while (!q.empty()) { int u = q.front(); q.pop(); ++num; --d[a[u]]; if (!d[a[u]]) { q.push(a[u]); } } return num == n; } int main() { #ifdef LOCAL assert(freopen("test.in", "r", stdin)); assert(freopen("test.out", "w", stdout)); #endif n = read(); for (int i = 1; i <= n; ++i) { a[i] = read(); ++d[a[i]]; } if (!topo()) { puts("-1"); return 0; } cnt[0] = 1; for (int i = 1; i <= n; ++i) { sum[i] = read(); cnt[i] = 1; fa[i] = i; qu.push({val(i), i}); } while (!qu.empty()) { int u = get<1>(qu.top()); qu.pop(); if (vis[u]) { continue; } vis[u] = 1; int tp = find(a[u]); merge(tp, u); if (tp) { qu.push({val(tp), tp}); } } write(ans); return 0; }
- 约束
- 1
信息
- ID
- 3372
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 8
- 标签
- 递交数
- 1
- 已通过
- 1
- 上传者