从一道面试题到单调队列优化 DP

前情提要

一个朋友询问我一道某大厂的面试题,看到题目的第一眼想到了 O(n×k) 的动态规划算法来解决,朋友反馈说记得 nk 至少 105,后来发现只要维护最后 k 个元素的最大值即可,直接塞到 set 里面就能优化到 O(nlogn)O(nlogk),应该可以 cover 掉数据范围,朋友也对复杂度满意了。但我总觉得可以 O(n) 线性解决,在为朋友写代码时,突然隐隐约约觉得单调队列似乎就是这样子,去 OI-Wiki 查询果然是这样,遂写此篇题解博客记录不断思考优化算法复杂度的心路历程。

题目大意

给定长度为 n 一个数组,有正有负的数。问你在其中取哪些数,可以使得这些数的和最大。但是条件是,每 k 个连续的数,都至少要取一个出来。输入是nk、数组;输出是最大的这个和。

样例

输入

5 3
-4 -100 -9 -100 -4

输出

-9

题解

dp[i] 表示取了第 i 个数的情况下,取到的数的最大值。则状态转移方程为:
dp[i]=maxj[1,min(k,i)](dp[ij])+arr[i]
最终答案为:
maxj[0,k1](dp[nj])

直接按照数学公式计算即可得到答案,时间复杂度为 O(n×k)

我们注意到,我们需要一直维护数组中最后 k 个元素的最大值,这些最大值具有连续性(i 每次增加,都需要删掉一个元素,然后增加一个新元素),因此可以使用增加、删除以及查询最大值复杂度均为 O(logn) 的数据结构来维护。

C++ 中我们可以使用 STL 中基于红黑树的 set 容器来维护,因为容器中最多有 k 个元素,所以时间复杂度为 O(logk);同理,在 Python 中我们可以使用基于小根堆的 heapq 优先队列来维护,因为队列中最多有 n 个元素(元素均为非负数),所以时间复杂度为 O(logn)

我们注意到,在优先队列中,其实某些元素永远不会出队,而这些元素共同的性质便是比队头早入队 k 个或以上且比队头的元素小。那么我们能不能在当前队头元素入队时,将上述元素都踢出队列呢?优先队列显然是做不到的:因为上述元素一般集中在堆底。但如果我们尝试换成普通的队列呢?先不管如何查询最值,每次遇到一个元素,我们如果发现队尾元素比该元素更小,那么其实队尾元素在未来永远也用不到了(因为会比该元素早离开最后 k 个元素的范围),不断重复这个过程直到遇到一个比该元素大的队尾,此时再将该元素插入队尾——这时候让我们回过头来看,会惊奇地发现,整个队列中的元素竟然是单调递减的!那么想要查询最大值只要从队头开始找属于最后 k 个元素范围内的第一个元素即可。与此同时我们注意到,如果队头元素已经超出最后 k 个元素的范围,那么该元素未来也不可能会用到了,所以同样可以将符合这一条件的队头元素踢出队列。于是乎,我们总是可以直接查询队头元素来得到当前最后 k 个元素的最大值。因为这个过程中,每个元素只进队和出队一次,而查询最值也是 O(1) 的复杂度,因此总时间复杂度为 O(n)。这种允许在一端插入元素、两端都删除元素的数据结构叫做双端队列 deque,在 Python 中位于标准库 collections 中,在 C++ 中位于 STL 中。而本题中的双端队列永远保持单调递减,所以全程都具有这种单调性质的队列叫做“单调队列”,因而这种对于动态规划的优化方法叫做“单调队列优化”。

代码

  • O(n×k) 解法:

    def solve(arr, n, k):
        dp = [0]
        for i in range(1, n + 1):
            dp.append(max(dp[i - min(k, i):]) + arr[i - 1])
            # 也可以像下面这样写:
            # dp.append(dp[i - 1])
            # for j in range(2, min(k, i) + 1):
            #     dp[i] = max(dp[i], dp[i - j])
            # dp[i] += arr[i - 1]
        return max(dp[n - min(k, n - 1):])
    if __name__ == '__main__':
        n, k = map(int, input().split())
        arr = list(map(int, input().split()))
        print(solve(arr, n, k))
    
  • O(n×logn) 解法:

    import heapq
    def solve(arr, n, k):
        Q = [(0, 0)]
        for i in range(1, n + 1):
            heapq.heappush(Q, (Q[0][0] - arr[i - 1], i))
            while Q[0][1] <= i - k: heapq.heappop(Q)
        return -Q[0][0]
    if __name__ == '__main__':
        n, k = map(int, input().split())
        arr = list(map(int, input().split()))
        print(solve(arr, n, k))
    
  • O(n) 解法:
    保留 dp 数组:

    from collections import deque
    def solve(arr, n, k):
        dp = [0]
        q = deque([0])
        for i in range(1, n + 1):
            dp.append(dp[q[0]] + arr[i - 1])
            while q and dp[q[-1]] <= dp[i]: q.pop() # 将比当前值小的元素全部弹出
            q.append(i)
            while q[0] <= i - k: q.popleft() # 判断队首元素是否在窗口内
        return dp[q[0]]
    if __name__ == '__main__':
        n, k = map(int, input().split())
        arr = list(map(int, input().split()))
        print(solve(arr, n, k))
    

    不保留 dp 数组:

    from collections import deque
    def solve(arr, n, k):
        q = deque([(0, 0)])
        for i in range(1, n + 1):
            t = q[0][0] + arr[i - 1]
            while q and q[-1][0] <= t: q.pop() # 将比当前值小的元素全部弹出
            q.append((t, i))
            while q[0][1] <= i - k: q.popleft() # 判断队首元素是否在窗口内
        return q[0][0]
    if __name__ == '__main__':
        n, k = map(int, input().split())
        arr = list(map(int, input().split()))
        print(solve(arr, n, k))
    
喜欢 Issue Page
评论加载中...
Menu