程序设计#5 - 二分查找

13 min

如果要在一个有序数组里寻找第一个大于等于 target 的元素下标,最直觉的写法就是遍历,但是这样的复杂度是 O(n),在大数据下较劣。原因是,没有利用到数组有序这个性质。

二分查找(binary search)是一种能在有序数组中快速找到指定元素的算法。它每次将搜索范围减小一半,因此非常高效。时间复杂度通常为 O(log n)

但这么高效的算法是陷阱密布的。接下来我想用尽可能简单的语言避开二分查找的每一个坑点。

在这之前必须了解几个常用的库函数,它们都在 <algorithm> 头文件中。

vector<int> v;

/* 一些初始化代码 */

auto i = lower_bound(v.begin(), v.end(), elem);
auto j = upper_bound(v.begin(), v.end(), elem);

也可以是

auto i = ranges::lower_bound(v, elem);
auto j = ranges::upper_bound(v, elem);

或者

auto [i, j] = ranges::equal_range(v, elem);

其中 ij 会返回一个迭代器,分别代表数组中第一个大于或等于 elem 的位置和第一个严格大于 elem 的位置。

我们希望自己实现一个二分查找,因此这里仅作了解,下面是正式部分。

红蓝染色法

我们在本篇所说的二分查找都是在一个非严格递增顺序的数组中寻找大于等于 target 的数,简单记为 lowerbound()

二分查找的写法按照区间来划分可以有三种:闭区间左闭右开区间开区间

所谓闭区间,意思是在 [l, r] 这个下标区间内,所有的元素都无法确定大小。左闭右开区间是 [l, r),开区间则是 (l, r)。这里先说闭区间。

标题中的红蓝染色法是一种二分查找常用的标记元素的方法。先看下面的例子。

LC34在排序数组中查找元素的第一个和最后一个位置

给你一个按照非递减顺序排列的整数数组 nums,和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。

如果数组中不存在目标值 target,返回 [-1, -1]

你必须设计并实现时间复杂度为 O(log n) 的算法解决此问题。

这是原题给的样例。

输入:nums = [5,7,7,8,8,10], target = 8
输出:[3,4]

红蓝染色法简单来说,就是将所有严格小于 target 的部分染成红色块,大于或等于 target 的部分染成蓝色块,这样最终第一个蓝色块就是 lowerbound() 的返回值了。

下面是这个流程的图示。

_binary_closed

这就是我们所说的红蓝染色法。

同时闭区间 [l, r] 以外的元素都是已经确定好会被排除的元素。**换言之,我们写闭区间的二分查找,本质上就是保证 [r+1, nums.size()] 中的元素大于等于 target,而 [0,l-1] 中的元素小于 target。**上一句话很重要,这就是确保二分查找不出现死循环的关键。

(几乎所有死循环的二分查找代码,都是因为没有保证区间从头到尾都不变。其中有一些闭区间写着写着变成左开右闭区间了)

同时注意,C++ 中如果 nums 的长度到了 int 类型最大值,使用 m = (l+r) / 2 就会导致溢出,解决方案是换成 m = l + (r - l) / 2

闭区间写法

class Solution {
    int lowerbound(vector<int>& nums, int target) {
        int l = 0, r = nums.size() - 1;
        while (l <= r) { // 保证区间不为空
            int m = l + (r - l) / 2;     // 防止溢出,Python 可以写为 m = (r + l) // 2
            if (nums[m] >= target) { // 大于或等于 target,更新 r
                r = m - 1;
            } else { // 严格小于 target,更新 l
                l = m + 1;
            }
        }
        return l;
    }

public:
    vector<int> searchRange(vector<int>& nums, int target) {
        // 实际上就是返回第一个大于等于 target 的元素对应的下标
        // 还有第一个严格大于 target 的元素对应的下标减去 1
        int start = lowerbound(nums, target);
        if (start == nums.size() || nums[start] != target) { // 前者为了防空数组,后者防数组中不存在 target
            return {-1, -1};
        }
        int end = lowerbound(nums, target + 1) - 1;
        return {start, end};
    }
};

nums 中寻找最后一个大于等于 target 元素对应的下标,等价于寻找第一个严格大于 target 的元素对应的下标减去 1。

以此类推,可以写出一张针对不同情况的二分查找调用方式。

需求调用方式……不存在
>= x 的第一个元素的下标lowerbound(nums, x)nums.size()
> x 的第一个元素的下标lowerbound(nums, x+1)nums.size()
<= x 的最后一个元素的下标lowerbound(nums, x+1)-1-1
< x 的最后一个元素的下标lowerbound(nums, x)-1-1

当然,除了闭区间以外,还有两种写法:左闭右开区间开区间。这三种写法实际使用没有区别,看更喜欢哪一种写法就好。

左闭右开区间写法

int lowerbound_1(vector<int>& nums, int target) { // 左闭右开区间 [l,r) 写法
        int l = 0, r = nums.size();
        while (l < r) {
            int m = l + (r - l) / 2;
            if (nums[m] < target) {
                l = m + 1; // 继续二分 [m+1, r)
            } else {
                r = m; // 继续二分 [l, m)
            }
        }
        // 最后 l 和 r 会重叠,此时区间 [l, r) 内没有元素,循环结束
        // 返回 l 或 r 都可以
        return r;
    }

开区间写法

因为到最后区间 (l, r) 中没有元素,所以在二分时要确保 lr 中间至少隔了一位,也就是 l+1 < r

int lowerbound_2(vector<int>& nums, int target) { // 开区间 (l,r) 写法
        int l = -1, r = nums.size();
        while (l + 1 < r) {
            int m = l + (r - l) / 2;
            if (nums[m] < target) {
                l = m; // 继续二分 (m, r)
            } else {
                r = m; // 继续二分 (l, m)
            }
        }
        return r;
    }

LC2529正整数和负整数的最大计数就可以转化为寻找 0 的起始点和结束点。

class Solution {
    int lowerbound(vector<int>& nums, int target) { // 闭区间 [l,r] 写法
        int l = 0, r = nums.size() - 1;
        while (l <= r) {             // 保证区间不为空
            int m = l + (r - l) / 2; // 防止溢出,Python 可以写为 m = (r + l) // 2
            if (nums[m] >= target) { // 大于或等于 target,更新 r
                r = m - 1;           // 继续二分 [l, m-1]
            } else {                 // 严格小于 target,更新 l
                l = m + 1;           // 继续二分 [m+1, r]
            }
        }
        return l;
    }

    int max(int a, int b) {
        if (a > b) {
            return a;
        }
        return b;
    }

public:
    int maximumCount(vector<int>& nums) {
        int l = lowerbound(nums, 0);
        int r = lowerbound(nums, 1);
        return max(l, nums.size()-r);
    }
};

二分查找的本质在于用 O(log n) 的时间换一个新的条件,而且因为二分查找只能在有序数组里面用,因此常常和快速排序(时间复杂度为 O(n log n))结合起来。

看一道先排序再二分答案的题目。

LC1385两个数组间的距离值

给你两个整数数组 arr1arr2 和一个整数 d ,请你返回两个数组之间的 距离值

距离值」 定义为符合此距离要求的元素数目:对于元素 arr1[i] ,不存在任何元素 arr2[j] 满足 |arr1[i]-arr2[j]| <= d

假如 arr2 是有序的,那么对于 arr1 中的每一个元素 arr1[i],只要证明 arr2 中没有区间 [arr1[i]-d, arr1[i]+d] 中的元素就可以了。

这就是所说的排序+二分查找,代码不难,直接给出。

class Solution {
    int lowerbound(vector<int>& nums, int target) {
        int l = 0, r = nums.size();
        while (l < r) {
            int m = l + (r - l) / 2;
            if (nums[m] >= target) {
                r = m;
            } else {
                l = m + 1;
            }
        }
        return l;
    }

public:
    int findTheDistanceValue(vector<int>& arr1, vector<int>& arr2, int d) {
        int res = 0;
        sort(arr2.begin(), arr2.end()); // C++ 排序库函数(快速排序),在头文件 <algorithm> 中
        for (int i : arr1) {
            int k = lowerbound(arr2, i - d);
            if (k == arr2.size() || arr2[k] > i + d) {
                res++;
            }
        }
        return res;
    }
};

二分查找也可以结合起前缀和。

前缀和(全称前缀和数组)是一种算法技巧,可以快速求出原数组的任意子数组和。对于前缀和数组 pre 和原数组 nums,有这样的性质:

pre[i] - pre[j] == nums[j] + nums[j+1] + ... + nums[i]

其中 i >= j。容易发现 pre[0] == 0pre.size() - nums.size() == 1

C++ 标准库 <numeric> 提供了计算前缀和的函数 partial_sum()。但计算得到的结果第一个元素不是 0 而是原数组的第一个元素。

LC2389和有限的最长子序列

给你一个长度为 n 的整数数组 nums ,和一个长度为 m 的整数数组 queries

返回一个长度为 m 的数组 answer ,其中 answer[i]nums 中 元素之和小于等于 queries[i]子序列最大 长度 。

子序列是指从一个原始序列中通过去除某些元素而不改变剩余元素的相对顺序所形成的新序列。子序列在原数组中不一定是连续的。(PS:这一段我改了,力扣原文写得不知所云)

因为是求子序列和,所以完全不必在意按什么顺序挑选元素。也就是说,挑选出原数组中最小的几个元素,让它们求和恰好不大于 queries[i] 就可以了。

class Solution {
public:
    vector<int> answerQueries(vector<int>& nums, vector<int>& queries) {
        ranges::sort(nums); // 先排序,准备计算前缀和
        partial_sum(nums.begin(), nums.end(), nums.begin()); // 直接存到原数组里,节约空间,下同
        for (int& q : queries) {
            q = ranges::upper_bound(nums, q) - nums.begin(); // upper_bound 求的是第一个严格大于 q 的数 k,说明 k-1 一定是不大于 q 的。
            // 并且,下标从 0 开始,二分查找得到的结果恰好就是子序列长度
        }
        return queries;
    }
};

二分查找给我写麻了,一想到后面还有二分答案最长递增子序列这两位神仙就头大,就写到这了。🤪