堆结构和堆排序

heap structure and heap sort

Posted by Marlin on July 26, 2025

堆结构

最大堆: 父节点的值大于等于其子节点的值 最小堆: 父节点的值小于等于其子节点的值 本文以最大堆为例。

堆中最重要的操作为向上调整向下调整

  • 向上调整(Heapify Up):将一个节点向上移动,直至其父节点的值小于等于其子节点的值。 0-based 数组下标时,i 的父节点为 (i-1)/2。

  • 向下调整(Heapify Down):将一个节点向下移动,直至其子节点的值都小于等于其父节点的值。 0-based 数组下标时,i 的左子节点为 2i+1,右子节点为 2i+2。

1. 插入(Insert)

步骤: 1. 将新元素添加到数组末尾(完全二叉树的最后一个位置)。

  1. 向上调整(Heapify Up):将新元素与父节点比较,若不满足堆属性则交换,重复此过程直至根节点或满足属性。
  • 时间复杂度:$O(log n)$(树的高度为 log n)。
1
2
3
4
5
6
7
8
9
10
11
void heap_insert(int arr[], int n, int key){
    int i = n;
    arr[i] = key;
    int fa = (i - 1) / 2;

    while (arr[i] > arr[fa]) {
        swap(arr[i], arr[fa]);
        i = fa;
        fa = (i - 1) / 2;
    }
}

2. 修改堆顶元素

步骤: 1. 将新元素代替为堆顶元素。

  1. 向下调整(Heapify Down):将堆顶元素与其子节点比较,若不满足堆属性则交换,重复此过程直至叶节点或满足属性。
  • 时间复杂度:$O(log n)$(树的高度为 log n)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void heap_modify(int arr[], int n, int i) {
    while (1) {
        int max = i;
        // 每次循环重新计算子节点坐标
        int lchil = 2 * i + 1;
        int rchil = 2 * i + 2;

        if (lchil < n && arr[lchil] > arr[max]) {
            max = lchil;
        }
        if (rchil < n && arr[rchil] > arr[max]) {
            max = rchil;
        }
        if (max == i) {
            break;
        }
        swap(arr[i], arr[max]);
        i = max; // 更新当前位置
    }
}

堆排序

步骤: 1. 构造最大堆 $O(n log n)$。

  1. 交换堆顶元素与当前堆的最后一个元素(将最大值归位)。
  2. 重复步骤 2,直至堆大小为 1。
  • 时间复杂度:$O(n log n)$。
1
2
3
4
5
6
7
8
9
10
11
12
void heap_sort(int arr[], int n){
    // 自顶向下(逐个插入):O(nlogn)
    for(int i = 0; i < n; i++){
        heap_insert(arr, n, arr[i]); // 向上调整 构造最大堆
    }
    int size = n;
    while(size > 1){
        swap(arr[0], arr[size - 1]); // 交换堆顶元素与最后一个元素
        size--; // 堆的大小减一
        heap_modify(arr, size, 0); // 向下调整
    }
}

优化: 自顶向下 $O(n log n)$ -> 自底向上建堆 $O(n)$ 堆排序 $O(n log n)$ 步骤: 1. 构造最大堆 $O(n)$。 2. 交换堆顶元素与当前堆的最后一个元素(将最大值归位)。 3. 重复步骤 2,直至堆大小为 1。

  • 时间复杂度:$O(n log n)$。
1
2
3
4
5
6
7
8
9
10
11
12
13
void heap_sort(int arr[], int n) {
    // 自底向上(Floyd算法):O(n)
    for (int i = (n - 1 - 1) / 2; i >= 0; i--) {
        // i从最后一个非叶子节点(最后一个节点的父节点)开始向上建堆
        heap_modify(arr, n, i); // 向下调整 构建最大堆
    }
    int size = n;
    while (size > 1) {
        swap(arr[0], arr[size - 1]); // 交换堆顶元素与最后一个元素
        size--;                      // 堆的大小减一
        heap_modify(arr, size, 0);   // 向下调整
    }
}

对顶堆

可以动态维护一个序列上第 k 大的数,k 值会发生变化。比写线段树或 BST 简单。
对顶堆由一个大根堆与一个小根堆组成,小根堆维护前 k 大的数(包含第 k 个),大根堆维护比第 k 大数小的数。

步骤

  1. 插入:若插入的元素小根堆堆顶元素,则将其插入小根堆,否则将其插入大根堆。
  2. 维护:当小根堆的大小> k时,不断将小根堆堆顶元素取出关插入大根堆,直到小根堆的大小等于 k。
    当小根堆的大小< k时,不断将大根堆堆顶元素取出并插入小根堆,直到小根堆的大小等于 k。
  3. 查询第 k 大元素:小根堆堆顶元素。
  4. 删除第 k 大元素:删除小根堆堆顶元素。
  • 时间复杂度:$O(log k)$。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
priority_queue<int> a; // 大根堆(维护较小的n-k个元素)
priority_queue<int, vector<int>, greater<int>> b; // 小根堆(维护最大的k个元素)

void solve() {
    int n, k;
    cin >> n >> k;
    vector<int> nums(n);
    for (int i = 0; i < n; i++) cin >> nums[i];
    for (int i = 0; i < n; i++) {
        // 插入新元素
        if (b.empty() || nums[i] >= b.top()) {
            b.push(nums[i]);
        } else {
            a.push(nums[i]);
        }

        // 调整堆大小:保持b中恰好k个元素
        while (b.size() > k) {
            a.push(b.top());
            b.pop();
        }
        while (b.size() < k && !a.empty()) {
            b.push(a.top());
            a.pop();
        }

        // 输出当前第k大(当i >= k-1时)
        if (b.size() >= k) {
            cout << b.top() << endl;
        } else {
            cout << "Not enough elements" << endl;
        }
    }
}

例题:黑匣子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
int m, n;
priority_queue<int, vector<int>, greater<>> a; // 小,维护剩余元素
priority_queue<int> b;                         // 大,维护前k小元素
int add_num[N];
int get_num[N];

void solve() {
    cin >> m >> n;
    for (int i = 1; i <= m; i++) {
        cin >> add_num[i];
    }
    for (int i = 1; i <= n; i++) {
        cin >> get_num[i];
    }
    int idx = 1;
    for (int i = 1; i <= n; i++) {
        while (idx <= get_num[i]) {
            if (b.empty() || add_num[idx] <= b.top()) {
                b.push(add_num[idx]);
            } else {
                a.push(add_num[idx]);
            }
            idx++;
        }
        int k = i;
        while (b.size() > k) {
            a.push(b.top());
            b.pop();
        }
        while (!a.empty() && b.size() < k) {
            b.push(a.top());
            a.pop();
        }
        cout << b.top() << endl;
    }
}

例题:中位数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
int n, m;
priority_queue<int, vector<int>, greater<int>> small;
priority_queue<int> big;
void solve() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        int num;
        cin >> num;
        if (big.empty() || num <= big.top()) {
            big.push(num);
        } else {
            small.push(num);
        }
        if (i % 2 == 0) {
            continue;
        }
        int k = i / 2 + 1;
        if (big.size() > k) {
            small.push(big.top());
            big.pop();
        }
        if (big.size() < k) {
            big.push(small.top());
            small.pop();
        }
        cout << big.top() << endl;
    }
}