并查集-下

Posted by Marlin on August 18, 2025

并查集-下

并查集的小扩展：可以在并查集上维护额外的定制信息，例如当前一共有多少个集合，以及给每个集合打上标签信息。

题目 1 移除最多的同行或同列石头

// Disjoint-set forest over the elements 0..n-1 that also keeps a running
// count of how many disjoint sets currently exist.
struct DSU {
    vector<int> fa;  // fa[i] is the parent of i; i is a root when fa[i] == i
    int sets;        // number of disjoint sets; starts at n, drops by one per successful merge

    DSU(int n) : fa(n), sets(n) {
        int label = 0;
        for (int &parent : fa) {
            parent = label++;  // every element starts as its own root
        }
    }

    // Root of x's set; every node visited is re-pointed straight at the root.
    int find(int x) {
        if (fa[x] != x) {
            fa[x] = find(fa[x]);
        }
        return fa[x];
    }

    // Joins the sets of x and y by hanging x's root under y's root.
    // Returns true when two different sets were joined, false if already joined.
    bool merge(int x, int y) {
        const int rootX = find(x);
        const int rootY = find(y);
        const bool joined = rootX != rootY;
        if (joined) {
            fa[rootX] = rootY;
            --sets;
        }
        return joined;
    }
};
int removeStones(vector<vector<int>> &stones) {
    int n = stones.size();
    map<int, int> col;
    map<int, int> row;
    DSU dsu(n);
    for (int i = 0; i < n; i++) {
        int temp_col = stones[i][0];
        int temp_row = stones[i][1];
        if (col.count(temp_col)) {
            int idx = col[temp_col];
            dsu.merge(i, idx);
        }
        if (row.count(temp_row)) {
            int idx = row[temp_row];
            dsu.merge(i, idx);
        }
        col[temp_col] = i;
        row[temp_row] = i;
    }
    return n - dsu.sets;
}

题目 2 知晓秘密的专家

// Disjoint-set forest over people 0..n-1 where each root carries a flag
// telling whether anyone in its set knows the secret.
struct DSU {
    vector<int> secret;  // secret[r] != 0 when root r's set contains a secret holder
    vector<int> fa;      // fa[i] is the parent of i; i is a root when fa[i] == i

    // Initializers are listed in declaration order (secret, then fa) to match
    // the actual construction order and avoid -Wreorder.
    DSU(int n) : secret(n, 0), fa(n) {
        for (int i = 0; i < n; i++) {
            fa[i] = i;
        }
        // Person 0 starts out knowing the secret; guarded so n == 0 does not
        // write past the end of an empty vector.
        if (n > 0) {
            secret[0] = true;
        }
    }

    // Root of x's set, with path compression.
    int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }

    // Hangs x's root under y's root and ORs x's secret flag into the new root.
    // Returns false if x and y were already in the same set.
    bool merge(int x, int y) {
        int fx = find(x);
        int fy = find(y);
        if (fx == fy) {
            return false;
        }
        fa[fx] = fy;
        secret[fy] |= secret[fx];
        return true;
    }
};
// Returns, in increasing order, every person who knows the secret after all
// meetings. Person 0 and firstPerson start with it. A secret passes between
// everyone connected by meetings that happen at the same time.
// Note: `meetings` is sorted in place by time.
vector<int> findAllPeople(int n, vector<vector<int>> &meetings,
                          int firstPerson) {
    sort(meetings.begin(), meetings.end(),
         [](vector<int> &lhs, vector<int> &rhs) { return lhs[2] < rhs[2]; });

    DSU dsu(n);
    dsu.merge(0, firstPerson);

    const int total = meetings.size();
    int begin = 0;
    while (begin < total) {
        // [begin, end) is the run of meetings sharing the same time.
        const int time = meetings[begin][2];
        int end = begin;
        while (end < total && meetings[end][2] == time) {
            ++end;
        }

        for (int k = begin; k < end; ++k) {
            dsu.merge(meetings[k][0], meetings[k][1]);
        }

        // Undo step: any participant whose group still has no secret holder is
        // detached again, so this time slot's links do not leak into later ones.
        for (int k = begin; k < end; ++k) {
            for (int side = 0; side < 2; ++side) {
                const int person = meetings[k][side];
                if (!dsu.secret[dsu.find(person)]) {
                    dsu.fa[person] = person;
                }
            }
        }
        begin = end;
    }

    vector<int> knowers;
    for (int person = 0; person < n; ++person) {
        if (dsu.secret[dsu.find(person)]) {
            knowers.push_back(person);
        }
    }
    return knowers;
}

题目 3 好路径的数目

vector<int> fa;      // fa[i] is the parent of node i; i is a root when fa[i] == i
vector<int> maxcnt;  // per-root counter of nodes holding the set's maximum value

// Resets the union-find state for n nodes: every node is its own root and
// every counter starts at 1.
void init(int n) {
    fa.resize(n);
    maxcnt.assign(n, 1);
    for (int node = 0; node < n; ++node) {
        fa[node] = node;
    }
}
// Returns the root of x's set and re-points every node on the walked path
// directly at that root (path compression), done iteratively in two passes.
int find(int x) {
    int root = x;
    while (fa[root] != root) {
        root = fa[root];
    }
    while (fa[x] != root) {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
// Joins the sets containing x and y. The root with the larger value becomes
// the new root, so a root always holds its set's maximum value, and
// maxcnt[root] counts how many nodes in the set carry that maximum.
// Returns the number of new good paths created by this union: when both
// maxima are equal, every pairing of max-valued nodes across the two sets.
// (This count is meaningful when edges are fed in increasing order of their
// larger endpoint value, as numberOfGoodPaths does.)
int merge(int x, int y, vector<int> &vals) {
    int fx = find(x);
    int fy = find(y);
    if (fx == fy) {
        // Already connected: no new paths. Without this guard the equal-value
        // branch would return maxcnt^2 and double the root's counter.
        return 0;
    }
    int path = 0;
    if (vals[fx] > vals[fy]) {
        fa[fy] = fx;  // larger value stays the root
    } else if (vals[fx] < vals[fy]) {
        fa[fx] = fy;
    } else {
        path = maxcnt[fx] * maxcnt[fy];
        fa[fy] = fx;
        maxcnt[fx] += maxcnt[fy];
    }
    return path;
}

int numberOfGoodPaths(vector<int> &vals, vector<vector<int>> &edges) {
    int n = vals.size();
    int ans = n;
    init(n);
    sort(edges.begin(), edges.end(),
            [&vals](vector<int> &a, vector<int> &b) {
                return max(vals[a[0]], vals[a[1]]) <
                    max(vals[b[0]], vals[b[1]]);
            });
    for (auto edge : edges) {
        ans += merge(edge[0], edge[1], vals);
    }
    return ans;
}

题目 4 尽量减少恶意软件的传播 II

// Union-find state for "minimize malware spread II". Every table is sized per
// call in init(), so there is no fixed MAXN capacity to overflow.
// The component-size table is named setSize: a global named `size` is
// ambiguous with std::size under `using namespace std;` in C++17.
vector<int> fa;       // fa[i] is the parent of i; i is a root when fa[i] == i
vector<int> cnts;     // cnts[v]: clean nodes saved if infected node v is removed
vector<int> infect;   // per clean root: -1 no infected neighbour, -2 several, else the unique one
vector<int> setSize;  // per root: number of nodes in its set
vector<char> virus;   // virus[i] != 0 when node i is in `initial`

// Resets all tables for n nodes and marks the initially infected ones.
void init(int n, vector<int> &initial) {
    fa.assign(n, 0);
    for (int i = 0; i < n; i++) {
        fa[i] = i;
    }
    cnts.assign(n, 0);
    infect.assign(n, -1);
    setSize.assign(n, 1);
    virus.assign(n, 0);
    for (int i : initial) {
        virus[i] = 1;
    }
}

// Root of x's set, with path compression.
int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }

// Hangs x's root under y's root and accumulates the set size there.
// Returns false if x and y were already in the same set.
bool merge(int x, int y) {
    int fx = find(x);
    int fy = find(y);
    if (fx == fy) {
        return false;
    }
    fa[fx] = fy;
    setSize[fy] += setSize[fx];
    return true;
}

// Returns the initially infected node whose removal saves the most clean
// nodes. Ties go to the smallest index. Returns -1 if `initial` is empty.
// `initial` is not modified.
int minMalwareSpread(vector<vector<int>> &graph, vector<int> &initial) {
    if (initial.empty()) {
        return -1;  // nothing to remove; the original indexed initial[0] here (UB)
    }
    int n = static_cast<int>(graph.size());
    init(n, initial);

    // 1. Union only clean (non-infected) nodes with each other.
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            if (graph[i][j] == 1 && !virus[i] && !virus[j]) {
                merge(i, j);
            }
        }
    }

    // 2. For each clean component, record which infected node touches it:
    //    a single source, or -2 when several different sources do.
    for (int sick : initial) {
        for (int neighbor = 0; neighbor < n; neighbor++) {
            if (sick != neighbor && !virus[neighbor] &&
                graph[sick][neighbor] == 1) {
                int fn = find(neighbor);
                if (infect[fn] == -1) {
                    infect[fn] = sick;
                } else if (infect[fn] != sick) {
                    infect[fn] = -2;
                }
            }
        }
    }

    // 3. A component reached by exactly one source is saved by removing it.
    for (int i = 0; i < n; i++) {
        if (i == find(i) && infect[i] >= 0) {
            cnts[infect[i]] += setSize[i];
        }
    }

    // 4. Largest saving wins; break ties by smaller index. This matches the
    //    old sort-then-scan result without reordering the caller's vector.
    int ans = initial[0];
    for (int v : initial) {
        if (cnts[v] > cnts[ans] || (cnts[v] == cnts[ans] && v < ans)) {
            ans = v;
        }
    }
    return ans;
}