
How to select not only the maximum of a `numpy.ndarray` but the top 3 maximal values in python?

Developer https://www.devze.com 2023-03-11 06:40 Source: Internet

I have a list of float values (positive and negative ones) stored in a variable row of type <type 'numpy.ndarray'>.

max_value = max(row)

gives me the maximal value of row. Is there an elegant way to select the top 3 (5, 10,...) values?

I came up with

  1. selecting the maximum value from row
  2. deleting the maximal value in row
  3. selecting the maximum value from row
  4. deleting the maximal value in row
  5. and so on

But that's certainly an ugly style and not pythonic at all. What do the pythonistas say to that? :)
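For comparison, the select-and-delete loop described above can be written compactly. This sketch swaps in one change: instead of deleting each found maximum, it masks it with `-inf` in a float copy, so the indices it reports still point into the original `row`. The helper name `top_k_by_repeated_max` and the sample data are hypothetical.

```python
import numpy as np

def top_k_by_repeated_max(row, k):
    """Return (indices, values) of the k largest entries, largest first.

    Hypothetical helper: repeatedly takes argmax and masks the winner with -inf
    on a float copy, so `row` itself is left untouched and indices stay valid.
    Assumes 1 <= k <= len(row).
    """
    work = np.array(row, dtype=float)  # float copy so -inf can be stored
    indices = []
    for _ in range(k):
        j = int(work.argmax())
        indices.append(j)
        work[j] = -np.inf              # "delete" the current maximum
    indices = np.array(indices)
    return indices, np.asarray(row)[indices]

# Hypothetical sample data standing in for the question's `row`
row = np.array([1.5, -3.0, 4.25, 0.0, 2.75])
idx, vals = top_k_by_repeated_max(row, 3)
print(idx, vals)
```

This costs k full passes over the array, which is fine for small k but is exactly the pattern the answers below replace with sorting.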


Edit

I need not only the three maximal values but also their positions (indices in row). Sorry, I forgot to mention that...


I would use np.argsort:

import numpy as np

a = np.arange(10)
a[np.argsort(a)[-3:]]   # -> array([7, 8, 9])

EDIT: To also get the positions, use:

ii = np.argsort(a)[-3:] # positions
vals = a[ii]            # values
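Since np.argsort sorts in ascending order, the slice [-3:] lists the top three from smallest to largest. If you want them largest first, reverse the slice. A small sketch with a sample array (the variable names are only illustrative):

```python
import numpy as np

a = np.array([1, 5, 4, 6, 7, 2, 3, 9])
k = 3

ii = np.argsort(a)[-k:][::-1]  # positions of the k largest values, largest first
vals = a[ii]                   # the corresponding values

print(ii)    # [7 4 3]
print(vals)  # [9 7 6]
```

If the array can contain ties and their order matters, passing kind="stable" to np.argsort makes the ordering of equal values deterministic.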


Why not just sort the numpy array and then read off the values you need:

In [33]: np.sort(np.array([1,5,4,6,7,2,3,9]))[-3:]
Out[33]: array([6, 7, 9])

EDIT: seeing as the question has now changed and you need the positions as well as values, use numpy.argsort to obtain the indices instead of values:

In [43]: a=np.array([1,5,4,6,7,2,3,9])

In [44]: idx=np.argsort(a)

In [45]: topvals=idx[-3:]

In [46]: print(topvals)
[3 4 7]

In [47]: print(a[topvals])
[6 7 9]


This admittedly ugly trick is somewhat faster than argsort()[-3:], at least with NumPy 1.5.1 on my old Mac PPC.
argpartsort in Bottleneck (a collection of NumPy array functions written in Cython) should be much faster still.

#!/bin/sh

python -mtimeit -s '
import numpy as np

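def max3( A ):
   # Take the largest element, then mask it with -inf so the next argmax finds the runner-up.
   # Requires a float array: -inf cannot be stored in an integer array.
   j = A.argmax();  aj = A[j];  A[j] = - np.inf
   j2 = A.argmax();  aj2 = A[j2];  A[j2] = - np.inf
   j3 = A.argmax()
   # Restore the two masked entries so A is left unchanged.
   A[j] = aj
   A[j2] = aj2
   return [j, j2, j3]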

N = '${N-1e6}'
A = np.arange(N)
' '
j3 = A.argsort()[-3:]   # N 1e6: 405 msec per loop
# j3 = max3( A )        # N 1e6: 105 msec per loop
'
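Plain NumPy (1.8 and later) has a comparable partial-sort routine, np.argpartition, which moves the indices of the k largest entries to the end in linear time without fully sorting the array. A sketch, assuming a NumPy version that provides it; the helper name top_k is hypothetical, and only the k selected candidates are sorted afterwards so the result comes out largest first.

```python
import numpy as np

def top_k(a, k):
    """Hypothetical helper: indices and values of the k largest entries, largest first.

    Assumes a 1-D array and 1 <= k <= a.size.
    """
    a = np.asarray(a)
    # After argpartition with kth=-k, the last k positions hold the indices of
    # the k largest values, in no particular order.
    part = np.argpartition(a, -k)[-k:]
    # Sort just those k candidates (descending) so the result is ordered.
    order = np.argsort(a[part])[::-1]
    idx = part[order]
    return idx, a[idx]

a = np.array([1, 5, 4, 6, 7, 2, 3, 9])
idx, vals = top_k(a, 3)
print(idx, vals)  # [7 4 3] [9 7 6]
```

Unlike max3, this does not modify the input and works for any k and for integer arrays.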
