numpy数组的高效阈值filter
我需要过滤一个数组来删除低于特定阈值的元素。 我目前的代码是这样的:
threshold = 5 a = numpy.array(range(10)) # testing data b = numpy.array(filter(lambda x: x >= threshold, a))
问题是,这将创build一个临时列表,使用lambda函数(慢)的filter。
由于这是一个相当简单的操作,也许有一个numpy函数以高效的方式执行,但我一直无法find它。
我认为,另一种方法来实现这一点可能是sorting数组,find阈值的索引,并从该索引开始返回一个切片,但即使这将是更快的小input(它不会明显无论如何),随着input规模的增长,其确实渐近地变得不太有效。
有任何想法吗? 谢谢!
更新 :我也进行了一些测量,当input是100.000.000条目时,sorting+分片仍然比纯pythonfilter快两倍。
In [321]: r = numpy.random.uniform(0, 1, 100000000) In [322]: %timeit test1(r) # filter 1 loops, best of 3: 21.3 s per loop In [323]: %timeit test2(r) # sort and slice 1 loops, best of 3: 11.1 s per loop In [324]: %timeit test3(r) # boolean indexing 1 loops, best of 3: 1.26 s per loop
b = a[a>threshold]
应该这样做
我testing如下:
import numpy as np, datetime # array of zeros and ones interleaved lrg = np.arange(2).reshape((2,-1)).repeat(1000000,-1).flatten() t0 = datetime.datetime.now() flt = lrg[lrg==0] print datetime.datetime.now() - t0 t0 = datetime.datetime.now() flt = np.array(filter(lambda x:x==0, lrg)) print datetime.datetime.now() - t0
我有
$ python test.py 0:00:00.028000 0:00:02.461000
http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays
添加到
@yosukesabai
答案,它的重要使用不同的variables,因为这将返回一个空的数组:
im=im[im>167]
不能解释为什么虽然,也许是因为我太累了,以为:(