好久没更新了,之前在忙SMP 2023的那个比赛和实验室的一些事情。好吧,其实也没那么忙,只是感觉很疲惫然后很想休息,一刷题就犯困了。今天要回顾的主题是快速排序,它很有诱惑力的地方是它既是一个平均复杂度在\(\Theta(n \log n)\)的算法,而且在实现时又不需要额外申请数组,可以就地排序。在实际应用中,它比大多数排序算法都要快,是名副其实的“快速”算法。在这篇文章里,我们会回顾快速排序的原理,也会进行一些防御性操作阻止快速排序陷入最坏条件中,最后也会把之前的两个算法题用快速排序解决一下。

快速排序的基本原理

快速排序是一种分治排序的算法,它将一个数组分成两个子数组,并将两个子数组独立地进行排序。快速排序与归并排序既有区别又有联系,它们的相同点在于都是应用分治法对数组进行排序;不同点在于归并排序是先分后治最后合并,而快速排序则是先治后分,当子数组有序时整个数组就有序了,因此快速排序不需要额外的合并操作了。

快速排序的核心操作在于“切分”,也即我们将实现的partition函数,这个函数将目标元素放在一个合适的位置j,在这个位置左边的数均不大于a[j],在这个位置右边的数均不小于a[j]。一旦我们实现了这个函数,那只需要知道被切分的元素放置的位置j,对于左边和右边的子数组再次应用快速排序方法就可以完成排序了。用Python实现这个排序算法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# 快速排序
from typing import List

def partition(a: List[int], left: int, right: int):
"""快速排序的核心,切分方法。
该方法使得j左边的元素均不大于a[j],使得j右边的元素均不小于a[j]
最后该方法会返回j的位置以便于后面继续排序整个数组"""
i, j = left, right + 1
# 枢轴元素位置,其可以是任意的一个元素,但当它是最左边元素时比较好进行双指针操作
# 因此即使要随机选取枢轴,我们也会先让它和最左边元素进行交换
pivot = a[left]

# 左右双指针实现左边元素不大于枢轴值,右边元素不小于枢轴值
# 一旦小于就进行交换
while 1:
# 左边指针一直移动到某个值不小于枢轴量的地方
i += 1
while a[i] < pivot:
if i == right:
break
i += 1

# 同样右边指针一直移动到某个值不大于枢轴量的地方
j -= 1
while a[j] > pivot:
# 冗余的边界检查,因为j==left时a[j] > pivot不成立
if j == left:
break
j -= 1

# 检查左右指针是否已经相遇或者错过
if i >= j:
break
# 通过检查后交换两元素
a[i], a[j] = a[j], a[i]

# 将枢轴元素与j所处位置元素交换
a[left], a[j] = a[j], a[left]
# 返回已经放入正确位置的元素的位置
return j

def quicksort(a: List[int], left: int, right: int):
# 边界检查,一个元素及以下的数组不用排序
if left >= right:
return None
j = partition(a, left, right)
# 分治,继续处理左边和右边的排序子问题
quicksort(a, left, j-1)
quicksort(a, j+1, right)



a = [9, 7, 5, 6, 4, 3, 1, 2, 8]
print(f"排序前: {a}")
quicksort(a, 0, len(a)-1)
print(f"排序后: {a}")
# 排序前: [9, 7, 5, 6, 4, 3, 1, 2, 8]
# 排序后: [1, 2, 3, 4, 5, 6, 7, 8, 9]

这样我们就得到了一个简易版的快速排序算法,它可以解决很多问题了。但是同时它也是脆弱的,它很容易在一些测试用例中陷入最坏情况从而具有\(O(n^2)\)复杂度。

快速排序初始实现的弱点及改进

特意构造的切分不均衡

上面的快速排序算法每次都将数组中的第一个元素放在合适的位置。那可以想象,如果这本身就是一个排序好的数组或者是一个倒序的数组,每次地切分将只产生一个数组。这种极其不均衡地切分使得排序的复杂度来到\(O(n^2)\)。但这是可以预防的,一个防止人为构造这样数组的方法就是在快速排序中引入随机性。随机性可以有两种产生方式,一是在排序之前我们就特意随机打乱数组,二是在选择枢轴元素时从数组中随机选择一个元素而不是坚持选择左边的元素。

第一种方式的实现只需要在原来的实现上套一层打乱该数组即可,实现如下:

1
2
3
4
5
import random

def quicksort_random(a: List[int]):
random.shuffle(a)
quicksort(a, 0, len(a)-1)

第二种实现也只需要随机选择一个元素,然后让这个元素和最左边元素交换一下位置即可,实现如下:

1
2
3
4
5
6
import random

def partition_random(a: List[int], left: int, right: int):
random_pos = random.randint(left, right)
a[left], a[random_pos] = a[random_pos], a[left]
partition(a, left, right)

三取样切分

之前谈到能够人为构造快速排序的最坏情况是因为在切分的时候左右子数组不均衡。从原理上讲最好的切分应该是数组的中位数,这样切分的左右子数组是均衡的,然而代价是需要计算中位数。计算中位数的算法也可以由切分函数演化而来,一般来讲找到中位数也需要线性时间复杂度,最坏情况下需要平方级别的复杂度。这样的复杂度用来选择最佳切分点可能是得不偿失的。

但人们发现将取样大小设置为3并用居中元素切分效果最好,还可以将取样元素(即3个元素中的中位数)放在末尾作为“哨兵”来去掉partition函数中的数组边界测试。我们一步步地思考一下这样子的操作是如何使得我们可以免除左右边界的检查的:

  1. 首先,我们将三个元素里的中位数移到末尾,那左指针(即i)肯定是不会越界了。因为它最终会指向那个取样元素,自己不可能小于自己,因此左指针此时不会越界。这里的原因类似于前文右指针不会越界的原因。
  2. 之后,我们讨论一下为啥右指针(即j)不会越界。因为取样元素是开头三个元素的中位数,这里可能有两种情况,第一种情况是原来的几个元素恰好按升序排列,这样右指针一定会遇到左边的比取样元素小的那个元素从而发觉i>=j了。另一种情况是这三取样的数在原数组的顺序是反过来的,那这样左右指针相遇前一定会完成一次交换,那这样就回到第一种情况了。总结来说就是,左指针不会越界,然后左指针去过的地方一定有一个值是小于等于取样元素的,这时候可以发现i>=j从而终止算法。
  3. 最后,我们讨论一下左右指针的初始位置。因为取样元素被移到最后了,所以这里左指针实际是从第一个元素出发,右指针实际是从倒数第二个元素出发。然后最后我们应该和左指针交换位置,有种左右指针和之前反过来了的感觉。

之后,我们对这个思路进行实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def partition_3_samples(a: List[int], left: int, right: int):
"""三取样切分,把三取样元素挪到最后作为哨兵从而去除边界检查"""
# 先看看有没有三个元素
if right - left == 0:
return right
elif right - left == 1:
if a[left] > a[right]:
# 两个元素而且位置不对就交换一下就行了
a[left], a[right] = a[right], a[left]
return right

# 这里是有至少三个元素的,先从前三个元素里找取样元素
# 因为元素很少,用冒泡排序(选择排序)、插入排序都可
for m in range(left, left+3):
min_pos = m
for n in range(m+1, left+3):
if a[n] < a[min_pos]:
min_pos = n
a[min_pos], a[m] = a[m], a[min_pos]

# 这样前三个元素有序了,中间元素就是取样元素
# 交换第二个元素和最后一个元素
a[left+1], a[right] = a[right], a[left+1]
pivot = a[right]
# 左指针从最左边元素开始,右指针从倒数第二个元素开始。考虑到要先i+=1和j-=1,所以应该预防性-1和+1
i, j = left-1, right

while 1:
# 左边指针一直移动到某个值不小于枢轴量的地方
i += 1
while a[i] < pivot:
i += 1

# 同样右边指针一直移动到某个值不大于枢轴量的地方
j -= 1
while a[j] > pivot:
j -= 1

# 检查左右指针是否已经相遇或者错过
if i >= j:
break
# 通过检查后交换两元素
a[i], a[j] = a[j], a[i]

# 将枢轴元素与i所处位置元素交换
a[right], a[i] = a[i], a[right]
# 返回已经放入正确位置的元素的位置
return i


def quicksort3(a: List[int], left: int, right: int):
# 边界检查,一个元素及以下的数组不用排序
if left >= right:
return None
j = partition_3_samples(a, left, right)
# 分治,继续处理左边和右边的排序子问题
quicksort3(a, left, j-1)
quicksort3(a, j+1, right)


a = [9, 7, 5, 6, 4, 3, 1, 2, 8]
print(f"排序前: {a}")
quicksort3(a, 0, len(a)-1)
print(f"排序后: {a}")
# 排序前: [9, 7, 5, 6, 4, 3, 1, 2, 8]
# 排序后: [1, 2, 3, 4, 5, 6, 7, 8, 9]

好极了,现在通过三采样切分移除了每次指针移动的边界检查,所付出的代价仅仅是对前3个元素进行排序,后者在3次比较和交换之后就可以完成。

熵最优的排序

实际应用种经常会出现含有大量重复元素的数组,比如按性别排序等。在这些情况下,上面的快速排序性能尚可,但还有巨大的改进空间,比如一个元素全部重复的子数组就不需要排序了,但在之前的实现种还会继续切分为更小的子数组并排序。在有大量重复元素的情况下,快速排序的递归性会使重复元素的子数组经常出现,这就有很大的改进潜力,将当前实现的线性对数级别的性能提升到线性级别(将\(\Theta(n\log n)\)提升到\(\Theta(n)\),这还是挺诱人的)。

一个简单的想法是将数组切分为三部分,分别对于小于、等于和大于切分元素的数组元素,这种方法也被称为三向切分。具体来说,将从左到右遍历数组一次,维护一个指针lt使得a[left: lt]中的元素都小于枢轴值,另一个指针gt使得a[gt+1: right]中的元素都大于枢轴值,一个指针i使得a[lt: i-1]中的元素都等于枢轴值,而a[i: gt]中的元素还未确定。一开始ileft相等,而ltgt分别在最左边和最右边。对a[i]有如下三种情况:

  • a[i]<pivot: 将a[lt]和a[i]交换,将lt和i加一
  • a[i]>pivot: 将a[gt]和a[i]交换,将gt减一
  • a[i]==pivot: 将i加一

i>gt时,循环结束,此时已经完成了三向切分。

具体实现如下(居然感觉三向切分的方法还比较简洁!):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def partition_3_way(a: List[int], left: int, right: int):
"""三向切分的实现"""
lt, gt = left, right
i = left
# 确定枢轴值,这边还是默认最左边的值
# 实际如果在进入快速排序前打乱数组选谁都有随机性
pivot = a[left]

while i <= gt:
if a[i] < pivot:
a[lt], a[i] = a[i], a[lt]
lt += 1
i += 1
elif a[i] > pivot:
a[gt], a[i] = a[i], a[gt]
gt -= 1
else:
i += 1

return lt, gt

def quicksort_3_way(a: List[int], left: int, right: int):
if left >= right:
return None

lt, gt = partition_3_way(a, left, right)
quicksort_3_way(a, left, lt-1)
quicksort_3_way(a, gt+1, right)

a = [3, 3, 5, 5, 6, 6, 1, 1, 1, 3]
print(f"排序前: {a}")
quicksort_3_way(a, 0, len(a)-1)
print(f"排序后: {a}")
# 排序前: [3, 3, 5, 5, 6, 6, 1, 1, 1, 3]
# 排序后: [1, 1, 1, 3, 3, 3, 5, 5, 6, 6]

对于存在大量重复元素的数组,这段代码可以避免很多不必要的子数组排序,因而比标准快速排序效率高得多。

当然,这段代码并不是最优的,在重复元素不多的情况下它比标准的二向切分多了很多次交换。有人找到了一个聪明的方法解决这个问题,即快速三向切分,这里不再继续深入。如果对这个快速三向切分有兴趣,可以参加《算法(第四版)》的练习2.3.22。

补充一下三向切分在存在重复主键时将带来线性复杂度:

命题 M。不存在任何基于比较的排序算法能够保证在 NH-N 次比较之内将 N 个元素排序,其中 H 为由主键值出现频率定义的香农信息量。

命题 N。对于大小为 N 的数组,三向切分的快速排序需要 ~(2ln2)NH 次比较。其中 H 为由主键 值出现频率定义的香农信息量。

这两个性质一起说明了三向切分是信息量最优的,即对于任意分布的输入,最优的基于比较的算法平均所需的比较次数和三向切分的快速排序平均所需的比较次数相互处于常数因子范围之内。对于包含大量重复元素的数组,它将排序时间从线性对数级别降低到了线性级别。因为包含大量重复元素数组的排序案例很常见,这使得三向切分的快速排序成为排序库函数的最佳算法选择。

快速排序的应用

我们简单地看一下牛客的两道题目,并使用快速排序的切分思想完成问题。

BM 46 最小的K个数

url:牛客 BM46

考察知识点:堆、快速排序

我们之前用堆做了这道题,现在用快速排序的切分思想来完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#
# 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
#
#
# @param input int整型一维数组
# @param k int整型
# @return int整型一维数组
#
class Solution:
def partition(self, nums: List[int], left: int, right: int):
i, j = left, right+1
pivot = nums[left]

while 1:
i += 1
while nums[i] < pivot:
if i == right:
break
i += 1

j -= 1
while nums[j] > pivot:
j -= 1

if i >= j:
break
nums[i], nums[j] = nums[j], nums[i]

# 交换第一个元素和right_point
nums[j], nums[left] = nums[left], nums[j]
return j

def sort_k(self, nums: List[int], left: int, right: int, k: int):
if left >= right:
return None
j = self.partition(nums, left, right)

if j > k:
# 左边需要精排
self.sort_k(nums, 0, j-1, k)
elif j < k:
# 右边排一下
self.sort_k(nums, j+1, right, k)
else:
# 相等就说明找到了
return None


def GetLeastNumbers_Solution(self , input: List[int], k: int) -> List[int]:
# write code here
self.sort_k(input, 0, len(input)-1, k)
return input[:k]

BM 47 寻找第K大

url:牛客 BM47

考察知识点:堆、快速排序

这个也是之前用堆解决过的问题,我们用快速排序实现一次。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#
# 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
#
#
# @param a int整型一维数组
# @param n int整型
# @param K int整型
# @return int整型
#
class Solution:
def partition(self, a: List[int], left: int, right: int):
i = left
j = right+1
pivot = a[left]

while 1:
i += 1
while a[i] < pivot:
if i == right:
break
i += 1

j -= 1
while a[j] > pivot:
j -= 1

if i >= j:
break
a[i], a[j] = a[j], a[i]

a[left], a[j] = a[j], a[left]
return j

def findKth(self , a: List[int], n: int, K: int) -> int:
# write code here
left, right = 0, n-1
# 寻找第K大就是寻找排序完成后数组的第n-K+1个元素
K = n - K
j = self.partition(a, left, right)
while j != K:
if j < K:
left = j+1
j = self.partition(a, left, right)
else:
right = j-1
j = self.partition(a, left, right)
print(j)
return a[j]

参考资料

  1. 《算法 (第4版)》
  2. 牛客编程题