Maintaining sum of k largest items in a dynamic set
Maintain a multiset that supports adding and removing elements, as well as querying the sum of its k largest items (or of all its items if there are fewer than k).
Use two heaps
The idea is to use two heaps: a min-heap ‘large’, whose top element is its smallest item, and a max-heap ‘small’, whose top element is its largest item.
large:             small:
[800, 400, 350]    [200, 100]
           ^top     ^top
<------k------>
The invariant is that if the set contains fewer than k items, then it is stored entirely in ‘large’, while ‘small’ is empty. Otherwise ‘large’ stores the k largest items of the set, and ‘small’ all the others.
Maintaining the invariant simply consists of moving items from one heap to the other whenever it is violated: if ‘large’ holds more than k items, its smallest item moves to ‘small’; if it holds fewer than k items and ‘small’ is not empty, the largest item of ‘small’ moves to ‘large’.
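To make the invariant concrete, here is a minimal insert-only sketch that uses plain heapq lists instead of lazy heaps. The helper name add and the sample values are ours, chosen to reproduce the picture above with k = 3; removal is left out because it needs the lazy heaps used further down.

```python
from heapq import heappush, heappop

def add(large, small, k, value):
    """Insert value, keeping the k largest items in `large` (a min-heap)
    and all the others in `small` (negated, so it behaves as a max-heap)."""
    heappush(large, value)
    if len(large) > k:
        # large overflows: demote its smallest item to small
        heappush(small, -heappop(large))

large, small = [], []
for v in [100, 800, 350, 200, 400]:
    add(large, small, 3, v)

print(sorted(large), large[0])                  # the 3 largest items, and the top of large
print(sorted(-x for x in small), -small[0])     # the other items, and the top of small
print(sum(large))                               # sum of the 3 largest items
```

Every insertion goes through ‘large’ first, so a new value that is smaller than everything in ‘large’ is immediately demoted again by the overflow test.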
Application
In the above-mentioned problem, we are given n weighted intervals and need to find a set of up to k intervals, all intersecting, that maximizes the total weight. This can be solved by a sweep line algorithm: scan the interval endpoints from left to right, adding a weight to the dynamic set when its interval starts and removing it when the interval ends, and record the largest k-largest sum seen along the way. Sorting the 2n endpoints and performing O(n) heap operations gives an overall time complexity of O(n log n).
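For testing the sweep line on small inputs, a quadratic brute force can serve as a reference. The sketch below is not part of the original solution; the function name and the (weight, start, end) tuple layout are our assumptions, with closed intervals. It relies on the fact that a set of intervals sharing a common point also shares the start point of one of its intervals.

```python
def best_k_overlapping(k, intervals):
    """Brute force: intervals is a list of (weight, start, end) triples
    with closed endpoints. Returns the best total weight of at most k
    intervals that all contain a common point."""
    best = 0
    for _, point, _ in intervals:
        # weights of all intervals covering this start point, largest first
        weights = sorted((w for w, s, e in intervals if s <= point <= e),
                         reverse=True)
        best = max(best, sum(weights[:k]))
    return best

sample = [(5, 1, 4), (3, 2, 6), (4, 5, 8), (1, 3, 3)]
print(best_k_overlapping(2, sample))
```

It runs in O(n² log n) time, so it is only meant for cross-checking the O(n log n) solution on random small instances.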
Implementation in Python
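Here we use our implementation of lazy heaps, explained here. The code relies on a class lazyheap providing the methods push, pop, top and remove, as well as the attributes n (number of items) and sum (sum of the items).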
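The lazyheap class itself does not appear in the listing. The following is a minimal sketch of what such a class could look like, assuming only the interface used by the code (push, pop, top, remove, n, sum) and the usual lazy-deletion trick; it is our reconstruction, not necessarily the original implementation. The names pending and _prune are ours.

```python
from collections import Counter
from heapq import heappush, heappop

class lazyheap:
    """Min-heap with deferred deletions. A removed value stays in the
    underlying list until it reaches the top, where it is discarded."""
    def __init__(self):
        self.heap = []
        self.pending = Counter()   # value -> number of deletions not yet applied
        self.n = 0                 # number of live items
        self.sum = 0               # sum of live items

    def _prune(self):
        # discard items at the top that were removed earlier
        while self.heap and self.pending[self.heap[0]] > 0:
            self.pending[self.heap[0]] -= 1
            heappop(self.heap)

    def push(self, x):
        heappush(self.heap, x)
        self.n += 1
        self.sum += x

    def remove(self, x):
        # x is assumed to be present; the deletion is only recorded here
        self.pending[x] += 1
        self.n -= 1
        self.sum -= x

    def top(self):
        self._prune()
        return self.heap[0]

    def pop(self):
        self._prune()
        x = heappop(self.heap)
        self.n -= 1
        self.sum -= x
        return x
```

With this design, push, pop and remove all run in amortized O(log n) time, since every item is pushed to and popped from the underlying list at most once.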
from sys import stdin
from collections import Counter
from heapq import *
def readint(): return int(stdin.readline())
def readints(): return list(map(int, stdin.readline().strip().split()))
def readstr(): return stdin.readline().strip()
class dynset:
    """Maintains a multiset and keeps track of the sum of its k largest elements.
    large is the minheap containing the up to k largest elements.
    small is the maxheap containing the smaller elements;
    in fact we use a minheap but negate the values.
    """
    def __init__(self, k):
        self.k = k
        self.large = lazyheap()
        self.small = lazyheap()

    def balance(self):
        """maintains invariant on heap sizes
        """
        if self.large.n > self.k:
            # too many items in large: demote its smallest one
            self.small.push(-self.large.pop())
        if self.large.n < self.k and self.small.n > 0:
            # room left in large: promote the largest item of small
            self.large.push(-self.small.pop())

    def add(self, value):
        # always insert into large; balance() demotes the smallest
        # item to small (negated) if large then exceeds k items
        self.large.push(value)
        self.balance()

    def remove(self, value):
        # from which heap should we remove?
        # (value is assumed to be present in the set)
        if value < self.large.top():
            self.small.remove(-value)
        else:
            self.large.remove(value)
        self.balance()
def solve(k, n, h, s, e):
    """scan time line by processing an event list.
    +hi means we enter interval i, -hi means we leave it.
    """
    D = dynset(k)
    events = [(s[i], +h[i]) for i in range(n)]
    # interval i still covers e[i], so we leave it at time e[i] + 1
    events += [(e[i] + 1, -h[i]) for i in range(n)]
    best = 0
    # at equal times, departures (negative delta) sort before arrivals
    for (time, delta) in sorted(events):
        if delta < 0:
            D.remove(-delta)
        else:
            D.add(delta)
        best = max(best, D.large.sum)
    return best
for test in range(readint()):
    d, n, k = readints()
    h = [0] * n
    s = [0] * n
    e = [0] * n
    for i in range(n):
        h[i], s[i], e[i] = readints()
    print('Case #%d: %s' % (test + 1, solve(k, n, h, s, e)))