我怎样才能检查一个二维NumPy数组里面是否包含一个特定的数值模式?
我有一个大的NumPy.array
field_array
和一个较小的数组match_array
,都包含int
值。 使用以下示例,如何检查field_array的任何match_array形状段field_array
包含与field_array
中的值完全对应的值?
import numpy raw_field = ( 24, 25, 26, 27, 28, 29, 30, 31, 23, \ 33, 34, 35, 36, 37, 38, 39, 40, 32, \ -39, -38, -37, -36, -35, -34, -33, -32, -40, \ -30, -29, -28, -27, -26, -25, -24, -23, -31, \ -21, -20, -19, -18, -17, -16, -15, -14, -22, \ -12, -11, -10, -9, -8, -7, -6, -5, -13, \ -3, -2, -1, 0, 1, 2, 3, 4, -4, \ 6, 7, 8, 4, 5, 6, 7, 13, 5, \ 15, 16, 17, 8, 9, 10, 11, 22, 14) field_array = numpy.array(raw_field, int).reshape(9,9) match_array = numpy.arange(12).reshape(3,4)
这些例子应该返回True
因为match_array
所描述的模式与[6:9,3:7]
alignment。
方法#1
这种方法来源于a solution
, Implement Matlab's im2col 'sliding' in python
中Implement Matlab's im2col 'sliding' in python
, a solution
旨在将rearrange sliding blocks from a 2D array into columns
。 因此,为了解决我们这里的情况,那些来自field_array
滑动块可以堆积为列,并与match_array
列向量版本进行match_array
。
这里是重新排列/堆叠函数的正式定义 –
def im2col(A,BLKSZ): # Parameters M,N = A.shape col_extent = N - BLKSZ[1] + 1 row_extent = M - BLKSZ[0] + 1 # Get Starting block indices start_idx = np.arange(BLKSZ[0])[:,None]*N + np.arange(BLKSZ[1]) # Get offsetted indices across the height and width of input array offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent) # Get all actual indices & index into input array for final output return np.take (A,start_idx.ravel()[:,None] + offset_idx.ravel())
为了解决我们的情况,这里是基于im2col
的实现 –
# Get sliding blocks of shape same as match_array from field_array into columns # Then, compare them with a column vector version of match array. col_match = im2col(field_array,match_array.shape) == match_array.ravel()[:,None] # Shape of output array that has field_array compared against a sliding match_array out_shape = np.asarray(field_array.shape) - np.asarray(match_array.shape) + 1 # Now, see if all elements in a column are ONES and reshape to out_shape. # Finally, find the position of TRUE indices R,C = np.where(col_match.all(0).reshape(out_shape))
问题中给定样本的输出将是 –
In [151]: R,C Out[151]: (array([6]), array([3]))
方法#2
鉴于opencv已经有模板匹配function,可以做差异的平方,你可以使用它,并寻找零差异,这将是你的匹配位置。 所以,如果你有权访问cv2(opencv模块),实现将看起来像这样 –
import cv2 from cv2 import matchTemplate as cv2m M = cv2m(field_array.astype('uint8'),match_array.astype('uint8'),cv2.TM_SQDIFF) R,C = np.where(M==0)
给我们 –
In [204]: R,C Out[204]: (array([6]), array([3]))
标杆
本节比较所有build议解决问题的方法的运行时间。 本节列出的各种方法的功劳归功于他们的贡献者。
方法定义 –
def seek_array(search_in, search_for, return_coords = False): si_x, si_y = search_in.shape sf_x, sf_y = search_for.shape for y in xrange(si_y-sf_y+1): for x in xrange(si_x-sf_x+1): if numpy.array_equal(search_for, search_in[x:x+sf_x, y:y+sf_y]): return (x,y) if return_coords else True return None if return_coords else False def skimage_based(field_array,match_array): windows = view_as_windows(field_array, match_array.shape) return (windows == match_array).all(axis=(2,3)).nonzero() def im2col_based(field_array,match_array): col_match = im2col(field_array,match_array.shape)==match_array.ravel()[:,None] out_shape = np.asarray(field_array.shape) - np.asarray(match_array.shape) + 1 return np.where(col_match.all(0).reshape(out_shape)) def cv2_based(field_array,match_array): M = cv2m(field_array.astype('uint8'),match_array.astype('uint8'),cv2.TM_SQDIFF) return np.where(M==0)
运行时testing –
案例#1(来自问题的示例数据):
In [11]: field_array Out[11]: array([[ 24, 25, 26, 27, 28, 29, 30, 31, 23], [ 33, 34, 35, 36, 37, 38, 39, 40, 32], [-39, -38, -37, -36, -35, -34, -33, -32, -40], [-30, -29, -28, -27, -26, -25, -24, -23, -31], [-21, -20, -19, -18, -17, -16, -15, -14, -22], [-12, -11, -10, -9, -8, -7, -6, -5, -13], [ -3, -2, -1, 0, 1, 2, 3, 4, -4], [ 6, 7, 8, 4, 5, 6, 7, 13, 5], [ 15, 16, 17, 8, 9, 10, 11, 22, 14]]) In [12]: match_array Out[12]: array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) In [13]: %timeit seek_array(field_array, match_array, return_coords = False) 1000 loops, best of 3: 465 µs per loop In [14]: %timeit skimage_based(field_array,match_array) 10000 loops, best of 3: 97.9 µs per loop In [15]: %timeit im2col_based(field_array,match_array) 10000 loops, best of 3: 74.3 µs per loop In [16]: %timeit cv2_based(field_array,match_array) 10000 loops, best of 3: 30 µs per loop
案例#2(更大的随机数据):
In [17]: field_array = np.random.randint(0,4,(256,256)) In [18]: match_array = field_array[100:116,100:116].copy() In [19]: %timeit seek_array(field_array, match_array, return_coords = False) 1 loops, best of 3: 400 ms per loop In [20]: %timeit skimage_based(field_array,match_array) 10 loops, best of 3: 54.3 ms per loop In [21]: %timeit im2col_based(field_array,match_array) 10 loops, best of 3: 125 ms per loop In [22]: %timeit cv2_based(field_array,match_array) 100 loops, best of 3: 4.08 ms per loop
NumPy中没有内置的searchfunction,但在NumPy中可以做到这一点
只要你的数组不是太大 *,你可以使用滚动窗口的方法:
from skimage.util import view_as_windows windows = view_as_windows(field_array, match_array.shape)
view_as_windows
函数view_as_windows
是用NumPy编写的,所以如果你没有使用skimage你总是可以从这里复制代码。
然后为了查看子数组是否出现在较大的数组中,可以这样写:
>>> (windows == match_array).all(axis=(2,3)).any() True
要find子arrays左上angular的匹配索引,可以这样写:
>>> (windows == match_array).all(axis=(2,3)).nonzero() (array([6]), array([3]))
这种方法也适用于更高维数组。
*虽然数组windows
占用额外的内存(只改变步幅和形状来创build新的数据视图),但是写入windows == match_array
会创build一个大小为( windows == match_array
的布尔数组, 504字节的内存。 如果你正在处理非常大的数组,这种方法可能是不可行的。
一种解决scheme是一次search整个search_in
数组(一个'block'是一个search_for
shaped切片),直到find匹配的数据段或者search_for
数组耗尽。 我可以用它来获得匹配块的坐标,或者只是一个bool
结果发送True
或False
的return_coords
可选参数…
def seek_array(search_in, search_for, return_coords = False): """Searches for a contiguous instance of a 2d array `search_for` within a larger `search_in` 2d array. If the optional argument return_coords is True, the xy coordinates of the zeroeth value of the first matching segment of search_in will be returned, or None if there is no matching segment. If return_coords is False, a boolean will be returned. * Both arrays must be sent as two-dimensional!""" si_x, si_y = search_in.shape sf_x, sf_y = search_for.shape for y in xrange(si_y-sf_y+1): for x in xrange(si_x-sf_x+1): if numpy.array_equal(search_for, search_in[x:x+sf_x, y:y+sf_y]): return (x,y) if return_coords else True # don't forget that coordinates are transposed when viewing NumPy arrays! return None if return_coords else False
我不知道NumPy
是否还没有一个可以做同样事情的函数,尽pipe…
要添加已经发布的答案,我想添加一个考虑到浮点精度的错误,以防matrix来自例如image processing,例如数字受浮点运算影响。
您可以recursion更大matrix的索引,search更小的matrix。 然后,您可以提取匹配较小matrix大小的较大matrix的子matrix。
如果两者的内容,“大”和“小”matrix的子matrix匹配,则匹配。
以下示例显示如何返回发现匹配的大matrix中位置的第一个索引。 如果这个意图是扩展这个函数来返回一个匹配的数组,那将是微不足道的。
import numpy as np def find_submatrix(a, b): """ Searches the first instance at which 'b' is a submatrix of 'a', iterates rows first. Returns the indexes of a at which 'b' was found, or None if 'b' is not contained within 'a'""" a_rows=a.shape[0] a_cols=a.shape[1] b_rows=b.shape[0] b_cols=b.shape[1] row_diff = a_rows - b_rows col_diff = a_cols - b_cols for idx_row in np.arange(row_diff): for idx_col in np.arange(col_diff): row_indexes = [idx + idx_row for idx in np.arange(b_rows)] col_indexes = [idx + idx_col for idx in np.arange(b_cols)] submatrix_indexes = np.ix_(row_indexes, col_indexes) a_submatrix = a[submatrix_indexes] are_equal = np.allclose(a_submatrix, b) # allclose is used for floating point numbers, if they # are close while comparing, they are considered equal. # Useful if your matrices come from operations that produce # floating point numbers. # You might want to fine tune the parameters to allclose() if (are_equal): return[idx_col, idx_row] return None
使用上面的函数可以运行下面的例子:
large_mtx = np.array([[1, 2, 3, 7, 4, 2, 6], [4, 5, 6, 2, 1, 3, 11], [10, 4, 2, 1, 3, 7, 6], [4, 2, 1, 3, 7, 6, -3], [5, 6, 2, 1, 3, 11, -1], [0, 0, -1, 5, 4, -1, 2], [10, 4, 2, 1, 3, 7, 6], [10, 4, 2, 1, 3, 7, 6] ]) # Example 1: An intersection at column 2 and row 1 of large_mtx small_mtx_1 = np.array([[4, 2], [2,1]]) intersect = find_submatrix(large_mtx, small_mtx_1) print "Example 1, intersection (col,row): " + str(intersect) # Example 2: No intersection small_mtx_2 = np.array([[-14, 2], [2,1]]) intersect = find_submatrix(large_mtx, small_mtx_2) print "Example 2, intersection (col,row): " + str(intersect)
哪个会打印:
例1,交点:[1,2] 例2,交点:无