I currently have a long list which is being sorted using a lambda function f. I then choose a random element from the first five elements. Something like:
f = lambda x: some_function_of(x, local_variable)
my_list.sort(key=f)
foo = choice(my_list[:4])
This is a bottleneck in my program, according to the profiler. How can I speed things up? Is there a fast, inbuilt way to retrieve the elements I want? (In theory, I shouldn't need to sort the whole list.) Thanks.
Use heapq.nlargest or heapq.nsmallest.
For example:
import heapq
from random import choice

elements = heapq.nsmallest(4, my_list, key=f)
foo = choice(elements)
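CPython's implementation keeps a heap of at most K items, so this takes O(N log K) time (where K is the number of elements returned, and N is the list size). A heapify-then-pop approach would give O(N + K log N) instead. Either way, it is faster than the O(N log N) of a normal sort when K is small relative to N.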
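As a quick sanity check, here is a small self-contained sketch showing that heapq.nsmallest gives the same elements as sorting and slicing. The data, `offset` and the key function `f` below are made-up stand-ins for the ones in the question, not the asker's real code.

```python
import heapq
import random

# Hypothetical stand-ins for the question's list and key function.
random.seed(0)
my_list = [random.randint(0, 10_000) for _ in range(1_000)]
offset = 5_000
f = lambda x: abs(x - offset)

# The four elements with the smallest keys, without sorting the whole list.
smallest = heapq.nsmallest(4, my_list, key=f)

# Same result as the sort-then-slice approach from the question.
assert smallest == sorted(my_list, key=f)[:4]

foo = random.choice(smallest)
```

Note that nsmallest leaves my_list untouched, so the original order is preserved if you need it later.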
It's actually possible in linear time (O(N)) on average.
You need a partition algorithm:
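def partition(seq, pred, start=0, end=-1):
    # Reorder seq[start:end] in place so that items satisfying pred come first.
    # Returns the index of the first item that does not satisfy pred.
    if end == -1: end = len(seq)
    while True:
        # Skip items at the front that are already on the correct side.
        while True:
            if start == end: return start
            if not pred(seq[start]): break
            start += 1
        # Skip items at the back that are already on the correct side.
        while True:
            if pred(seq[end-1]): break
            end -= 1
            if start == end: return start
        # seq[start] fails pred and seq[end-1] passes it, so swap them.
        seq[start], seq[end-1] = seq[end-1], seq[start]
        start += 1
        end -= 1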
which can be used by an nth_element algorithm:
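def nth_element(seq_in, n, key=lambda x:x):
    start, end = 0, len(seq_in)
    # Compute each key only once by pairing every item with its key.
    seq = [(x, key(x)) for x in seq_in]
    # The pivot is always parked at seq[end-1] while partitioning.
    def partition_pred(x): return x[1] < seq[end-1][1]
    while start != end:
        # Take the middle item as the pivot and move it to the end.
        pivot = (end + start) // 2
        seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
        pivot = partition(seq, partition_pred, start, end)
        # Move the pivot into its final sorted position.
        seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
        if pivot == n: break
        # Continue only in the side that contains index n.
        if pivot < n: start = pivot + 1
        else: end = pivot
    # Write the reordered items back into the caller's list.
    seq_in[:] = (x for x, k in seq)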
Given these, just replace your second (sort) line with:
nth_element(my_list, 4, key=f)
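If you would rather not reorder my_list in place, the same quickselect idea can be sketched as a function that returns the k smallest items. This is only an illustrative variant, not the code above: the name `k_smallest` is made up, it picks a random pivot instead of the middle element, and it builds new lists instead of swapping, which costs extra memory but is still linear time on average.

```python
import random

def k_smallest(items, k, key=lambda x: x):
    """Return the k items with the smallest keys, in no particular order."""
    # Decorate with (key, index) so keys are computed once and ties are unique.
    pool = [(key(x), i, x) for i, x in enumerate(items)]
    result = []
    while pool and k > 0:
        if len(pool) <= k:
            # Everything left is wanted.
            result.extend(pool)
            break
        pivot = random.choice(pool)
        pivot_rank = pivot[:2]
        lower = [t for t in pool if t[:2] < pivot_rank]
        upper = [t for t in pool if t[:2] > pivot_rank]
        if len(lower) >= k:
            # All wanted items are below the pivot.
            pool = lower
        else:
            # Keep everything below the pivot plus the pivot itself,
            # then look for the remainder above it.
            result.extend(lower)
            result.append(pivot)
            k -= len(lower) + 1
            pool = upper
    return [x for _, _, x in result]

random.seed(1)
data = [random.randint(0, 1_000_000) for _ in range(10_000)]
candidates = k_smallest(data, 4, key=lambda x: -x)  # the four largest values
foo = random.choice(candidates)
```

The returned items are not sorted among themselves, which is fine here since one of them is chosen at random anyway.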