Heap Sort - 堆排序

堆排序通常基於二元堆 實現,以大根堆(根結點為最大值)爲例,堆排序的實現過程分爲兩個子過程。第一步爲取出大根堆的根節點(當前堆的最大值), 由於取走了一個節點,故需要對餘下的元素重新建堆。重新建堆後繼續取根節點,循環直至取完所有節點,此時數組已經有序。基本思想就是這樣,不過實現上還是有些小技巧的。

堆的操作

以大根堆爲例,堆的常用操作如下。

  1. 最大堆調整(Max_Heapify):將堆的末端子節點作調整,使得子節點永遠小於父節點
  2. 創建最大堆(Build_Max_Heap):將堆所有數據重新排序
  3. 堆排序(HeapSort):移除位在第一個數據的根節點,並做最大堆調整的遞歸運算

其中步驟1是給步驟2和3用的。

Heapsort-example

建堆時可以自頂向下,也可以採取自底向上,以下先採用自底向上的思路分析。我們可以將數組的後半部分節點想象爲堆的最下面的那些節點,由於是單個節點,故顯然滿足二叉堆的定義,於是乎我們就可以從中間節點向上逐步構建二叉堆,每前進一步都保證其後的節點都是二叉堆,這樣一來前進到第一個節點時整個數組就是一個二叉堆了。下面用 C++/Java 實現一個堆的類。C++/Java 中推薦使用 PriorityQueue 來使用堆。

堆排在空間比較小(嵌入式設備和手機)時特別有用,但是因爲現代系統往往有較多的快取,堆排序無法有效利用快取,數組元素很少和相鄰的其他元素比較,故快取未命中的機率遠大於其他在相鄰元素間比較的算法。但是在大數據的排序下又重新發揮了重要作用,因爲它在插入操作和刪除最大元素的混合動態場景中能保證對數級別的運行時間。

C++

  1. #include <iostream>
  2. #include <vector>
  3. using namespace std;
  4. class HeapSort {
  5. // get the parent node index
  6. int parent(int i) {
  7. return (i - 1) / 2;
  8. }
  9. // get the left child node index
  10. int left(int i) {
  11. return 2 * i + 1;
  12. }
  13. // get the right child node index
  14. int right(int i) {
  15. return 2 * i + 2;
  16. }
  17. // build max heap
  18. void build_max_heapify(vector<int> &nums, int heap_size) {
  19. for (int i = heap_size / 2; i >= 0; --i) {
  20. max_heapify(nums, i, heap_size);
  21. }
  22. print_heap(nums, heap_size);
  23. }
  24. // build min heap
  25. void build_min_heapify(vector<int> &nums, int heap_size) {
  26. for (int i = heap_size / 2; i >= 0; --i) {
  27. min_heapify(nums, i, heap_size);
  28. }
  29. print_heap(nums, heap_size);
  30. }
  31. // adjust the heap to max-heap
  32. void max_heapify(vector<int> &nums, int k, int len) {
  33. // int len = nums.size();
  34. while (k < len) {
  35. int max_index = k;
  36. // left leaf node search
  37. int l = left(k);
  38. if (l < len && nums[l] > nums[max_index]) {
  39. max_index = l;
  40. }
  41. // right leaf node search
  42. int r = right(k);
  43. if (r < len && nums[r] > nums[max_index]) {
  44. max_index = r;
  45. }
  46. // node after k are max-heap already
  47. if (k == max_index) {
  48. break;
  49. }
  50. // keep the root node the largest
  51. int temp = nums[k];
  52. nums[k] = nums[max_index];
  53. nums[max_index] = temp;
  54. // adjust not only just current index
  55. k = max_index;
  56. }
  57. }
  58. // adjust the heap to min-heap
  59. void min_heapify(vector<int> &nums, int k, int len) {
  60. // int len = nums.size();
  61. while (k < len) {
  62. int min_index = k;
  63. // left leaf node search
  64. int l = left(k);
  65. if (l < len && nums[l] < nums[min_index]) {
  66. min_index = l;
  67. }
  68. // right leaf node search
  69. int r = right(k);
  70. if (r < len && nums[r] < nums[min_index]) {
  71. min_index = r;
  72. }
  73. // node after k are min-heap already
  74. if (k == min_index) {
  75. break;
  76. }
  77. // keep the root node the largest
  78. int temp = nums[k];
  79. nums[k] = nums[min_index];
  80. nums[min_index] = temp;
  81. // adjust not only just current index
  82. k = min_index;
  83. }
  84. }
  85. public:
  86. // heap sort
  87. void heap_sort(vector<int> &nums) {
  88. int len = nums.size();
  89. // init heap structure
  90. build_max_heapify(nums, len);
  91. // heap sort
  92. for (int i = len - 1; i >= 0; --i) {
  93. // put the largest number int the last
  94. int temp = nums[0];
  95. nums[0] = nums[i];
  96. nums[i] = temp;
  97. // reconstruct heap
  98. build_max_heapify(nums, i);
  99. }
  100. print_heap(nums, len);
  101. }
  102. // print heap between [0, heap_size - 1]
  103. void print_heap(vector<int> &nums, int heap_size) {
  104. for (int i = 0; i < heap_size; ++i) {
  105. cout << nums[i] << ", ";
  106. }
  107. cout << endl;
  108. }
  109. };
  110. int main(int argc, char *argv[])
  111. {
  112. int A[] = {19, 1, 10, 14, 16, 4, 7, 9, 3, 2, 8, 5, 11};
  113. vector<int> nums;
  114. for (int i = 0; i < sizeof(A) / sizeof(A[0]); ++i) {
  115. nums.push_back(A[i]);
  116. }
  117. HeapSort sort;
  118. sort.print_heap(nums, nums.size());
  119. sort.heap_sort(nums);
  120. return 0;
  121. }

Java

  1. import java.util.*;
  2. public class HeapSort {
  3. // sign = 1 ==> min-heap, sign = -1 ==> max-heap
  4. private void siftDown(int[] nums, int k, int size, int sign) {
  5. int half = (size >>> 1);
  6. while (k < half) {
  7. int index = k;
  8. // left leaf node search
  9. int l = (k << 1) + 1;
  10. if (l < size && (sign * nums[l]) < (sign * nums[index])) {
  11. index = l;
  12. }
  13. // right leaf node search
  14. int r = l + 1;
  15. if (r < size && (sign * nums[r]) < (sign * nums[index])) {
  16. index = r;
  17. }
  18. // already heapify
  19. if (k == index) break;
  20. // keep the root node the smallest/largest
  21. int temp = nums[k];
  22. nums[k] = nums[index];
  23. nums[index] = temp;
  24. // adjust next index
  25. k = index;
  26. }
  27. }
  28. private void heapify(int[] nums, int size, int sign) {
  29. for (int i = size / 2; i >= 0; i--) {
  30. siftDown(nums, i, size, sign);
  31. }
  32. }
  33. private void minHeap(int[] nums, int size) {
  34. heapify(nums, size, 1);
  35. }
  36. private void maxHeap(int[] nums, int size) {
  37. heapify(nums, size, -1);
  38. }
  39. public void sort(int[] nums, boolean ascending) {
  40. if (ascending) {
  41. // build max heap
  42. maxHeap(nums, nums.length);
  43. // heap sort
  44. for (int i = nums.length - 1; i >= 0; i--) {
  45. int temp = nums[0];
  46. nums[0] = nums[i];
  47. nums[i] = temp;
  48. // reconstruct max heap
  49. maxHeap(nums, i);
  50. }
  51. } else {
  52. // build min heap
  53. minHeap(nums, nums.length);
  54. // heap sort
  55. for (int i = nums.length - 1; i >= 0; i--) {
  56. int temp = nums[0];
  57. nums[0] = nums[i];
  58. nums[i] = temp;
  59. // reconstruct min heap
  60. minHeap(nums, i);
  61. }
  62. }
  63. }
  64. public static void main(String[] args) {
  65. int[] A = new int[]{19, 1, 10, 14, 16, 4, 4, 7, 9, 3, 2, 8, 5, 11};
  66. HeapSort heapsort = new HeapSort();
  67. heapsort.sort(A, true);
  68. for (int i : A) {
  69. System.out.println(i);
  70. }
  71. }
  72. }

複雜度分析

從程式碼中可以發現堆排最費時間的地方在於構建二叉堆的過程。

上述構建大根堆和小根堆都是自底向上的方法,建堆過程時間複雜度爲 O(2N), 堆化過程(可結合圖形分析,最多需要調整的層數爲最大深度)時間複雜度爲 \log i, 故堆排過程中總的時間複雜度爲 O(N \log N).

先看看建堆的過程,畫圖分析(比如以8個節點爲例)可知在最壞情況下,每次都需要調整之前已經成爲堆的節點,那麼就意味着有二分之一的節點向下比較了一次,四分之一的節點向下比較了兩次,八分之一的節點比較了三次… 等差等比數列求和,具體過程可參考下面的連結。

Reference