前情提要
一个朋友询问我一道某大厂的面试题,看到题目的第一眼想到了 O(n×k) 的动态规划算法来解决,朋友反馈说记得 n 和 k 至少 105,后来发现只要维护最后 k 个元素的最大值即可,直接塞到 set
里面就能优化到 O(nlogn) 或 O(nlogk),应该可以 cover 掉数据范围,朋友也对复杂度满意了。但我总觉得可以 O(n) 线性解决,在为朋友写代码时,突然隐隐约约觉得单调队列似乎就是这样子,去 OI-Wiki 查询果然是这样,遂写此篇题解博客记录不断思考优化算法复杂度的心路历程。
题目大意
给定长度为 n 一个数组,有正有负的数。问你在其中取哪些数,可以使得这些数的和最大。但是条件是,每 k 个连续的数,都至少要取一个出来。输入是n、k、数组;输出是最大的这个和。
样例
输入
5 3
-4 -100 -9 -100 -4
输出
-9
题解
令 dp[i] 表示取了第 i 个数的情况下,取到的数的最大值。则状态转移方程为:
dp[i]=maxj∈[1,min(k,i)](dp[i−j])+arr[i]
最终答案为:
maxj∈[0,k−1](dp[n−j])
直接按照数学公式计算即可得到答案,时间复杂度为 O(n×k)。
我们注意到,我们需要一直维护数组中最后 k 个元素的最大值,这些最大值具有连续性(i 每次增加,都需要删掉一个元素,然后增加一个新元素),因此可以使用增加、删除以及查询最大值复杂度均为 O(logn) 的数据结构来维护。
在 C++
中我们可以使用 STL
中基于红黑树的 set
容器来维护,因为容器中最多有 k 个元素,所以时间复杂度为 O(logk);同理,在 Python
中我们可以使用基于小根堆的 heapq
优先队列来维护,因为队列中最多有 n 个元素(元素均为非负数),所以时间复杂度为 O(logn)。
我们注意到,在优先队列中,其实某些元素永远不会出队,而这些元素共同的性质便是比队头早入队 k 个或以上且比队头的元素小。那么我们能不能在当前队头元素入队时,将上述元素都踢出队列呢?优先队列显然是做不到的:因为上述元素一般集中在堆底。但如果我们尝试换成普通的队列呢?先不管如何查询最值,每次遇到一个元素,我们如果发现队尾元素比该元素更小,那么其实队尾元素在未来永远也用不到了(因为会比该元素早离开最后 k 个元素的范围),不断重复这个过程直到遇到一个比该元素大的队尾,此时再将该元素插入队尾——这时候让我们回过头来看,会惊奇地发现,整个队列中的元素竟然是单调递减的!那么想要查询最大值只要从队头开始找属于最后 k 个元素范围内的第一个元素即可。与此同时我们注意到,如果队头元素已经超出最后 k 个元素的范围,那么该元素未来也不可能会用到了,所以同样可以将符合这一条件的队头元素踢出队列。于是乎,我们总是可以直接查询队头元素来得到当前最后 k 个元素的最大值。因为这个过程中,每个元素只进队和出队一次,而查询最值也是 O(1) 的复杂度,因此总时间复杂度为 O(n)。这种允许在一端插入元素、两端都删除元素的数据结构叫做双端队列 deque
,在 Python
中位于标准库 collections
中,在 C++
中位于 STL
中。而本题中的双端队列永远保持单调递减,所以全程都具有这种单调性质的队列叫做“单调队列”,因而这种对于动态规划的优化方法叫做“单调队列优化”。
代码
O(n×k) 解法:
def solve(arr, n, k): dp = [0] for i in range(1, n + 1): dp.append(max(dp[i - min(k, i):]) + arr[i - 1]) # 也可以像下面这样写: # dp.append(dp[i - 1]) # for j in range(2, min(k, i) + 1): # dp[i] = max(dp[i], dp[i - j]) # dp[i] += arr[i - 1] return max(dp[n - min(k, n - 1):]) if __name__ == '__main__': n, k = map(int, input().split()) arr = list(map(int, input().split())) print(solve(arr, n, k))
O(n×logn) 解法:
import heapq def solve(arr, n, k): Q = [(0, 0)] for i in range(1, n + 1): heapq.heappush(Q, (Q[0][0] - arr[i - 1], i)) while Q[0][1] <= i - k: heapq.heappop(Q) return -Q[0][0] if __name__ == '__main__': n, k = map(int, input().split()) arr = list(map(int, input().split())) print(solve(arr, n, k))
O(n) 解法:
保留 dp 数组:from collections import deque def solve(arr, n, k): dp = [0] q = deque([0]) for i in range(1, n + 1): dp.append(dp[q[0]] + arr[i - 1]) while q and dp[q[-1]] <= dp[i]: q.pop() # 将比当前值小的元素全部弹出 q.append(i) while q[0] <= i - k: q.popleft() # 判断队首元素是否在窗口内 return dp[q[0]] if __name__ == '__main__': n, k = map(int, input().split()) arr = list(map(int, input().split())) print(solve(arr, n, k))
不保留 dp 数组:
from collections import deque def solve(arr, n, k): q = deque([(0, 0)]) for i in range(1, n + 1): t = q[0][0] + arr[i - 1] while q and q[-1][0] <= t: q.pop() # 将比当前值小的元素全部弹出 q.append((t, i)) while q[0][1] <= i - k: q.popleft() # 判断队首元素是否在窗口内 return q[0][0] if __name__ == '__main__': n, k = map(int, input().split()) arr = list(map(int, input().split())) print(solve(arr, n, k))