用(n-1)d数组索引n维数组
什么是最优雅的方式来访问一个n维数组与一个给定维度(n-1)维数组,如虚拟示例
a = np.random.random_sample((3,4,4)) b = np.random.random_sample((3,4,4)) idx = np.argmax(a, axis=0)
我现在如何访问idx a
来获取最大值,就好像我已经使用了a.max(axis=0)
? 或者如何检索由idx
中的idx
指定的值?
我想过使用np.meshgrid
但我认为这是一个矫枉过正的问题。 请注意,维度axis
可以是任何有用的轴(0,1,2),并且事先不知道。 有没有一个优雅的方式来做到这一点?
利用advanced-indexing
–
m,n = a.shape[1:] I,J = np.ogrid[:m,:n] a_max_values = a[idx, I, J] b_max_values = b[idx, I, J]
一般情况下:
def argmax_to_max(arr, argmax, axis): """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)""" new_shape = list(arr.shape) del new_shape[axis] grid = np.ogrid[tuple(map(slice, new_shape))] grid.insert(axis, argmax) return arr[tuple(grid)]
不幸的是,这样的自然操作应该会更尴尬。
为了用一个(n-1) dim
数组索引一个n dim
arrays,我们可以简化一下,为我们提供所有轴的索引网格,
def all_idx(idx, axis): grid = np.ogrid[tuple(map(slice, idx.shape))] grid.insert(axis, idx) return tuple(grid)
因此,使用它来索引input数组 –
axis = 0 a_max_values = a[all_idx(idx, axis=axis)] b_max_values = b[all_idx(idx, axis=axis)]