Fast way to get N Min or Max elements from a list in Python

Developer https://www.devze.com 2022-12-21 00:46 Source: Internet
I currently have a long list which is being sorted using a lambda function f. I then choose a random element from the first four elements. Something like:

from random import choice

f = lambda x: some_function_of(x, local_variable)
my_list.sort(key=f)
foo = choice(my_list[:4])

This is a bottleneck in my program, according to the profiler. How can I speed things up? Is there a fast, built-in way to retrieve the elements I want (in theory I shouldn't need to sort the whole list)? Thanks.


Use heapq.nlargest or heapq.nsmallest.

For example:

import heapq
from random import choice

elements = heapq.nsmallest(4, my_list, key=f)
foo = choice(elements)

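This takes O(N log K) time (where K is the number of elements returned and N is the list size), because heapq.nsmallest scans the list once while keeping a heap of at most K items. That beats the O(N log N) of a full sort when K is small relative to N; the heapq documentation notes that for larger K, sorted() is more efficient.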
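To see where a bound like O(N log K) comes from, here is a minimal sketch of the bounded-heap idea: keep only the K best items seen so far in a max-heap, so each new element costs at most O(log K). The helper name k_smallest is hypothetical, and it assumes numeric keys because Python's heapq only provides a min-heap, which is turned into a max-heap by negating the keys. In real code, heapq.nsmallest already does this job.

```python
import heapq
import itertools
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")


def k_smallest(k: int, items: Iterable[T], key: Callable[[T], float]) -> List[T]:
    """Hypothetical helper: the k items with the smallest keys, ascending.

    Assumes numeric keys, because the max-heap is simulated by negating them.
    """
    if k <= 0:
        return []
    counter = itertools.count()
    # heapq is a min-heap; storing (-key, -index, item) puts the largest key
    # (latest-seen among ties) at heap[0], i.e. the current worst survivor.
    # The unique index also means items themselves are never compared.
    heap: List[tuple] = []
    for item in items:
        entry = (-key(item), -next(counter), item)
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif entry[0] > heap[0][0]:
            # Strictly smaller key than the worst survivor: evict it, O(log k)
            heapq.heapreplace(heap, entry)
    # Order survivors by ascending key, then by first appearance (stable)
    heap.sort(key=lambda e: (-e[0], -e[1]))
    return [item for _, _, item in heap]
```

Used on the question's data it would read foo = choice(k_smallest(4, my_list, key=f)). Ties are broken by first appearance, so the result should match sorted(my_list, key=f)[:4].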


It's actually possible in linear time (O(N)) on average.
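The intuition, as a rough estimate: unlike quicksort, this approach only continues into the side of the partition that contains the wanted position, so the work shrinks geometrically. Assuming each pivot splits the remaining range roughly in half, and writing c for an assumed constant cost per element examined in one partition pass:

```latex
T(N) \approx cN + c\frac{N}{2} + c\frac{N}{4} + \cdots = cN \sum_{i \ge 0} 2^{-i} = 2cN = O(N)
```

Any pivot that discards a constant fraction of the range on average leads to the same O(N) conclusion, whereas sorting does O(N log N) work in general.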

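You need a partition function (the same step quicksort uses) that moves every element satisfying a predicate to the front of a range and returns the boundary index: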

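def partition(seq, pred, start=0, end=None):
    """Reorder seq[start:end] in place so that items satisfying pred come
    first; return the index of the first item that does not satisfy it."""
    if end is None:
        end = len(seq)
    while True:
        # Advance start past items that already belong in the front part
        while True:
            if start == end:
                return start
            if not pred(seq[start]):
                break
            start += 1
        # Retreat end past items that already belong in the back part
        while True:
            if pred(seq[end - 1]):
                break
            end -= 1
            if start == end:
                return start
        # seq[start] fails pred and seq[end - 1] satisfies it: swap them
        seq[start], seq[end - 1] = seq[end - 1], seq[start]
        start += 1
        end -= 1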

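which can then be used by an nth_element function (a quickselect). After the call, seq[n] holds the element that a full sort would put there, everything before it has a smaller-or-equal key, and everything after it has a greater-or-equal key: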

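def nth_element(seq_in, n, key=lambda x: x):
    """Partially sort seq_in in place so that seq_in[n] is the element a full
    sort by key would put there, with no larger keys before it and no smaller
    keys after it (quickselect, O(N) on average)."""
    start, end = 0, len(seq_in)
    # Compute each key only once; work on (item, key) pairs
    seq = [(x, key(x)) for x in seq_in]

    # Compares against the pivot, which sits at seq[end - 1] while partitioning
    def partition_pred(x):
        return x[1] < seq[end - 1][1]

    while start != end:
        # Use the middle element as pivot and park it at the end of the range
        pivot = (end + start) // 2
        seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
        # Items with smaller keys than the pivot move to the front
        pivot = partition(seq, partition_pred, start, end)
        # Move the pivot into its final sorted position
        seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
        if pivot == n:
            break
        # Keep searching only in the side that contains index n
        if pivot < n:
            start = pivot + 1
        else:
            end = pivot

    seq_in[:] = (x for x, k in seq)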

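Given these, just replace your sort line (my_list.sort(key=f)) with:

nth_element(my_list, 4, key=f)

After that call, my_list[:4] holds the four smallest elements by f (in no particular order), so the line foo = choice(my_list[:4]) works unchanged.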
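If you don't need the list rearranged in place, the same expected-linear idea can be sketched more compactly as a randomized three-way quickselect. This is a different variant from the nth_element above (list comprehensions instead of in-place swaps, a random pivot instead of the middle one); the name k_smallest_quickselect is hypothetical, and keys are assumed to support < and ==.

```python
import random
from typing import Any, Callable, List, Sequence, TypeVar

T = TypeVar("T")


def k_smallest_quickselect(items: Sequence[T], k: int,
                           key: Callable[[T], Any]) -> List[T]:
    """Hypothetical helper: k items with the smallest keys, in no set order.

    Randomized quickselect with a three-way split; expected O(N) time.
    """
    # Compute each key once and carry it alongside its item
    pairs = [(key(x), x) for x in items]
    result = []
    while pairs and k > 0:
        pivot = random.choice(pairs)[0]
        less = [p for p in pairs if p[0] < pivot]
        equal = [p for p in pairs if p[0] == pivot]
        if k <= len(less):
            # Every wanted item is strictly below the pivot: keep only that part
            pairs = less
        elif k <= len(less) + len(equal):
            # The cut falls inside the run of keys equal to the pivot
            result.extend(less)
            result.extend(equal[:k - len(less)])
            k = 0
        else:
            # Take everything up to the pivot, then keep looking to the right
            result.extend(less)
            result.extend(equal)
            k -= len(less) + len(equal)
            pairs = [p for p in pairs if p[0] > pivot]
    return [x for _, x in result]
```

With it, the sort and choice lines collapse to foo = choice(k_smallest_quickselect(my_list, 4, key=f)). Putting all keys equal to the pivot in one bucket means a list full of duplicate keys is finished in a single pass, whereas a two-way partition can end up splitting off one element at a time; the price is the extra lists built on each round.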
