1 条题解

  • 0
    @ 2025-10-19 17:11:16

    题目解法:强连通缩点 + 树链剖分 + 区间合并

    问题分析

    题目要求计算每次游行可能经过的城市数量,关键点在于:

    1. C 国原有道路满足特殊性质:若 xzx \Rightarrow zyzy \Rightarrow z,则 xyx \Rightarrow yyxy \Rightarrow x

    2. 每次游行最多添加 22 条临时边

    3. 需要计算从起点到终点能到达的所有城市

    关键性质

    原图强连通缩点后形成有向无环图 (DAG),且该 DAG 具有特殊结构:

    • 对于任意两个节点,如果它们都能到达某个节点,则其中一个能到达另一个

    • 这意味着缩点后的 DAG 实际上是一棵树(更准确地说,是外向树)

    算法步骤

    1. 强连通分量分解

    使用 Tarjan 算法将原图缩点为 DAG

    每个 SCC 的大小记录在 csz[] 中

    2. 构建树结构

    在缩点后的 DAG 上,由于特殊性质,可以找到唯一的外向树结构

    通过拓扑排序构建树,fa[i] 记录父节点

    3. 树链剖分预处理

    进行 DFS 序编号,便于处理路径

    计算 DFS 序前缀和 z[],用于快速计算区间内城市数量

    4. 查询处理

    对于每个查询 (s,t,k(s, t, k 条临时边))

    1.映射到缩点树:将起点、终点和临时边的端点映射到对应的 SCC 节点

    2.构建辅助图:包含原树边和临时边,共 2k+22k+2 个节点(起点、终点、临时边端点)

    3.可达性分析:

    从起点 DFS 标记能到达的节点 (vs[])

    从终点反向 DFS 标记能到达它的节点 (vt[])

    4.路径合并:

    对于所有 vs[x] && vt[y] && g[x][y] 的节点对,将其在树上的路径加入序列

    这些路径的并集就是可能经过的 SCC 节点

    5. 区间合并计算:

    将路径按 DFS 序排序后合并区间

    使用前缀和计算区间内城市总数

    复杂度分析

    • Tarjan 缩点:O(n+m)O(n + m)

    • 树链剖分:O(n)O(n)

    • 每次查询:O(k2+klogn)O(k^2 + k \log n),主要在于路径合并排序

    • 总复杂度:O(n+m+qk2logn)O(n + m + qk^2 \log n),可过 n,q3×105n,q \leq 3\times 10^5

    代码亮点

    1. 高效缩点:使用 Tarjan 算法,栈优化

    2. 树链剖分:快速判断祖先关系和提取路径

    3. 区间合并:将多条路径合并为不重叠区间,用前缀和快速求和

    4. 临时边处理:将临时边融入辅助图,统一处理可达性

    AC代码

    #include <bits/stdc++.h>
    using namespace std;
    
    #define fi first
    #define se second
    #define ll long long
    #define gc getchar()
    #define pb push_back
    #define pii pair <int,int>
    #define mem(x,v) memset(x,v,sizeof(x))
    
    inline int read() {
        int x = 0;
        char s = gc;
    
        while (!isdigit(s))
            s = gc;
    
        while (isdigit(s))
            x = (x << 1) + (x << 3) + s - '0', s = gc;
    
        return x;
    }
    
    const int N = 3e5 + 5;
    const int K = 20;
    
    struct EDGE {
        int cnt, hd[N], nxt[N << 1], to[N << 1];
        void add(int u, int v) {
            nxt[++cnt] = hd[u], hd[u] = cnt, to[cnt] = v;
        }
    } E, G;
    
    int tdn, cn, ttop, tdfn[N], low[N], stc[N], col[N], csz[N], vis[N];
    void tarjan(int x) {
        vis[x] = 1, tdfn[x] = low[x] = ++tdn, stc[++ttop] = x;
    
        for (int i = E.hd[x]; i; i = E.nxt[i]) {
            int it = E.to[i];
    
            if (!tdfn[it])
                tarjan(it), low[x] = min(low[x], low[it]);
            else if (vis[it])
                low[x] = min(low[x], tdfn[it]);
        }
    
        if (tdfn[x] == low[x]) {
            col[x] = ++cn, vis[x] = 0, csz[cn] = 1;
    
            while (stc[ttop] != x)
                vis[stc[ttop]] = 0, col[stc[ttop--]] = cn, csz[cn]++;
    
            ttop--;
        }
    }
    
    int n, m, Q, k, r, deg[N];
    vector <int> e[N];
    
    int dn, sz[N], dep[N], son[N], dfn[N], fa[N], top[N];
    void dfs1(int id, int d) {
        dep[id] = d++, sz[id] = 1;
    
        for (int it : e[id]) {
            dfs1(it, d);
    
            if (sz[son[id]] < sz[it])
                son[id] = it;
    
            sz[id] += sz[it];
        }
    }
    void dfs2(int id, int t) {
        top[id] = t, dfn[id] = ++dn;
    
        if (son[id])
            dfs2(son[id], t);
    
        for (int it : e[id])
            if (it != son[id])
                dfs2(it, it);
    }
    bool anc(int x, int y) {
        if (dep[x] > dep[y])
            return 0;
    
        while (dep[top[y]] > dep[x])
            y = fa[top[y]];
    
        return top[x] == top[y];
    }
    
    pii seq[5000];
    int cnt, p[6], g[6][6], vs[6], vt[6], z[N];
    void check(int i, int j) {
        if (!p[i] || !p[j])
            return;
    
        g[i][j] = anc(p[i], p[j]);
    }
    void dfs(int id, int *v, bool tp) {
        v[id] = 1;
    
        for (int i = 0; i < 6; i++)
            if ((tp == 0 && g[id][i] || tp == 1 && g[i][id]) && !v[i])
                dfs(i, v, tp);
    }
    void add(int x, int y) {
        if (!vs[x] || !vt[y] || !g[x][y])
            return;
    
        x = p[x], y = p[y];
    
        if (dep[x] > dep[y])
            swap(x, y);
    
        while (dep[top[y]] > dep[x])
            seq[++cnt] = {dfn[top[y]], dfn[y]}, y = fa[top[y]];
        seq[++cnt] = {dfn[x], dfn[y]};
    }
    
    int main() {
        //freopen("celebration.in", "r", stdin);
        //freopen("celebration.out", "w", stdout);
        cin >> n >> m >> Q >> k;
    
        for (int i = 1, u, v; i <= m; i++)
            u = read(), v = read(), E.add(u, v);
    
        for (int i = 1; i <= n; i++)
            if (!tdfn[i])
                tarjan(i);
    
        for (int i = 1; i <= n; i++)
            for (int j = E.hd[i]; j; j = E.nxt[j]) {
                int to = E.to[j];
    
                if (col[i] != col[to])
                    G.add(col[i], col[to]), deg[col[to]]++;
            }
    
        queue <int> q;
        n = cn;
    
        for (int i = 1; i <= n; i++)
            if (!deg[i])
                q.push(i);
    
        while (!q.empty()) {
            int t = q.front();
            q.pop();
    
            for (int i = G.hd[t]; i; i = G.nxt[i]) {
                int to = G.to[i];
    
                if (!--deg[to])
                    e[t].pb(to), fa[to] = t, q.push(to);
            }
        }
    
        for (int i = 1; i <= n; i++)
            if (!fa[i])
                r = i;
    
        dfs1(r, 1), dfs2(r, r);
    
        for (int i = 1; i <= n; i++)
            z[dfn[i]] = csz[i];
    
        for (int i = 1; i <= n; i++)
            z[i] += z[i - 1];
    
        for (int i = 1; i <= Q; i++, cnt = 0, mem(g, 0), mem(vs, 0), mem(vt, 0)) {
            p[0] = col[read()], p[1] = col[read()];
    
            if (k)
                p[2] = col[read()], p[3] = col[read()], g[2][3] = 1;
    
            if (k == 2)
                p[4] = col[read()], p[5] = col[read()], g[4][5] = 1;
    
            check(0, 1), check(0, 2), check(0, 4);
            check(1, 2), check(1, 4);
            check(3, 1), check(3, 2), check(3, 4);
            check(5, 1), check(5, 2), check(5, 4);
    
            for (int i = 0; i < 2 * k + 2; i++)
                for (int j = 0; j < 2 * k + 2; j++)
                    if (p[i] == p[j])
                        g[i][j] = g[j][i] = 1;
    
            dfs(0, vs, 0), dfs(1, vt, 1);
            add(0, 1), add(0, 2), add(0, 4);
            add(1, 2), add(1, 4);
            add(3, 1), add(3, 2), add(3, 4);
            add(5, 1), add(5, 2), add(5, 4);
    
            if (cnt == 0) {
                puts("0");
                continue;
            }
    
            sort(seq + 1, seq + cnt + 1);
            int l = seq[1].fi, r = seq[1].se, ans = 0;
    
            for (int i = 2; i <= cnt; i++) {
                if (seq[i].fi > r + 1)
                    ans += z[r] - z[l - 1], l = seq[i].fi, r = seq[i].se;
                else
                    r = seq[i].se;
            }
    
            printf("%d\n", ans + z[r] - z[l - 1]);
        }
    
        return 0;
    }
    
    • 1

    信息

    ID
    3391
    时间
    2000ms
    内存
    1024MiB
    难度
    10
    标签
    递交数
    7
    已通过
    1
    上传者