「TOP K」问题的几种解法

序言

TOPKTOP K 问题应该是面试必问的算法题了,通常是问数组 numsnums 中第 KK 大或者第 KK 小的元素。
通常有如下几种解法:

  • 直接排序
  • 优先队列(大根堆/小根堆)
  • 快速选择(基于快排的划分)

本文以 LC 215. 数组中的第K个最大元素为例分析。

直接排序

这里贴一个快排模板,需要注意的是,每次划分中pivot的选取对快排性能影响很大,极端情况下(数组逆序)时间复杂度会退化到 O(n2)O(n^2)。常见的解决方案是 三者取中,即首尾和中间,取其中位数。

class Solution {
    public int findKthLargest(int[] nums, int k) {
        quick_sort(nums, 0, nums.length-1);
        return nums[nums.length-k];
    }
    int partition(int[] nums, int left, int right) {
        int piv = nums[left];
        while ( left < right ) {
            // right找比piv小的
            while ( left < right && nums[right] >= piv ) right--;
            nums[left] = nums[right];

            while ( left < right && nums[left] <= piv ) left++;
            nums[right] = nums[left];
        }
        // piv放入最终位置
        nums[left] = piv;
        return left;
    }
    void quick_sort(int[] nums, int left, int right) {
        if ( left < right ) {
            int piv_index = partition(nums, left, right);
            quick_sort(nums, left, piv_index-1);
            quick_sort(nums, piv_index+1, right);
        }
    }
}
  • 时间复杂度:O(nlogn)O(nlogn),极端情况下退化到 O(n2)O(n^2)

  • 空间复杂度:O(logn)O(logn),递归调用栈的开销

优先队列

优先队列,也就是堆。对于 TOPKTOP K 大的元素,维护大小为 KK 的小根堆,不断将 numsnums 中的元素送进去,最终堆顶元素就是 TOPKTOP K 大的元素。

  • 为什么是小根堆而不是大根堆?
  • 如果用大根堆,就要对全量元素进行全部排序;反之,维护一个大小为k的小根堆,最终留下的堆顶,就是topk大
class Solution {
    public int findKthLargest(int[] nums, int k) {
        PriorityQueue<Integer> pq = new PriorityQueue<>();
        for ( int x : nums ) {
            pq.offer(x);
            if ( pq.size() > k ) {
                pq.poll();
            }
        }
        return pq.peek();
    }
}
  • 时间复杂度:O(nlogk)O(nlogk),遍历元素 O(n)O(n),堆调整 O(logk)O(logk)

  • 空间复杂度:O(k)O(k)

快速选择

快速排序的每次划分,会确定一个pivot元素的最终位置,在pivot左边的元素都比pivot小;右边的都比pivot大(但都不一定有序)。

快速选择基于快速排序的划分思想,每次划分之后,得到pivot元素的下标,并不像快速排序那样递归地对左右两个子区间进行分治处理。

而是判断当前pivot元素的下标 pivot_idxpivot\_idx 和待求 topk 元素的下标 k_idx:=nums.lengthkk\_idx := nums.length-k 作比较,如果:

  • pivot_idx==k_idxpivot\_idx == k\_idx,直接返回 nums[pivot_idx]nums[pivot\_idx]
  • pivot_idx>k_idxpivot\_idx > k\_idx,说明目标的K应该在左半区间,则递归在左半区间进行处理
  • pivot_idx<k_idxpivot\_idx < k\_idx,说明目标的K应该在右半区间,则递归在右半区间进行处理

可以看出,相对于朴素快排,快速选择相当于进行了一次“二分”剪枝处理,每次划分完之后,并没有像快速排序那样对左右两个子区间都进行递归处理。

class Solution {
    public int findKthLargest(int[] nums, int k) {
        quick_select(nums, 0, nums.length-1, nums.length-k);
        return nums[nums.length-k];
    }
    int partition(int[] nums, int left, int right) {
        int piv = nums[left];
        while ( left < right ) {
            // right找比piv小的
            while ( left < right && nums[right] >= piv ) right--;
            nums[left] = nums[right];

            while ( left < right && nums[left] <= piv ) left++;
            nums[right] = nums[left];
        }
        // piv放入最终位置
        nums[left] = piv;
        return left;
    }

    void quick_select(int[] nums, int left, int right, int k_index) {
        if ( left < right ) {
            int piv_idx = partition(nums, left, right);
            if ( k_index == piv_idx ) {
                return;
            }
            // topk 索引比piv索引大
            else if ( k_index > piv_idx ) {
                // 往右搜索
                quick_select(nums, piv_idx+1, right, k_index);
            } else {
                quick_select(nums, left, piv_idx-1, k_index);
            }
        }
    }
}
  • 时间复杂度:O(n)O(n),具体证明参见《算法导论》

  • 空间复杂度:O(logn)O(logn),递归调用栈的开销

写在后面

TOPKTOP K 问题作为算法题来考察,主要还是考察「快速选择」的,因此务必要掌握。除此之外,如果涉及到海量数据处理,可能需要用到外部排序,MapReduce处理等。

最后,再贴一个堆排序和归并排序的模板。

堆排序模板

class Solution {
    public int[] sortArray(int[] nums) {
        int n = nums.length;
        int[] res = new int[n+1];
        System.arraycopy(nums, 0, res, 1, n);

        buildMaxHeap(res, n);
        for ( int i = n; i > 1; i-- ) { // n-1趟交换和建堆
            swap(res, i, 1);            // 堆顶和堆底交换(输出堆顶)
            adjust(res, 1, i-1);        // 把剩下的i-1个元素调整成堆
        }
        return Arrays.copyOfRange(res, 1, n+1);
    }
    // 初始建堆
    void buildMaxHeap(int[] nums, int len) {
        for ( int i = len/2; i > 0; i-- ) { // 从 [n/2] ~ 1,调整堆
            adjust(nums, i, len);
        }
    }
    // 调整k为根的子树为大根堆
    void adjust(int[] nums, int k, int len) {
        nums[0] = nums[k];      // nums[0]暂存当前根节点
        for ( int i = 2*k; i <= len; i *= 2 ) {
            if ( i < len && nums[i] < nums[i+1] ) {
                i++;
            }
            if ( nums[0] >= nums[i] ) {
                break;
            } else {
                nums[k] = nums[i];
                k = i;
            }
        }
        nums[k] = nums[0];
    }
    void swap(int[] nums, int i, int j) {
        int tmp = nums[i];
        nums[i] = nums[j];
        nums[j] = tmp;
    }
}

归并排序模板

class Solution {
    void merge_sort(int[] nums, int low, int high) {
        if ( low < high ) {
            int mid = low + (high-low)/2;
            merge_sort(nums, low, mid);
            merge_sort(nums, mid+1, high);
            // 合并
            merge(nums, low, mid, high);
        }
    }
    // 合并两个有序子数组[low,mid] 和 [mid+1, high]
    void merge(int[] nums, int low, int mid, int high) {
        int[] tmp = new int[high-low+1];
        int i = low, j = mid+1, idx = 0;
        while ( i <= mid && j <= high ) {
            if ( nums[i] <= nums[j] ) {
                tmp[idx++] = nums[i++];
            } else {
                tmp[idx++] = nums[j++];
                // 左边的大于右边的,说明 mid+1 ~ i 之间的都是逆序对
            }
        }
        // 剩余的直接复制过去
        while ( i <= mid ) {
            tmp[idx++] = nums[i++];
        }
        while ( j <= high ) {
            tmp[idx++] = nums[j++];
        }
        System.arraycopy(tmp, 0, nums, low, tmp.length);
    }
}