了解NumPy的寄语
我正在努力弄清楚einsum
是如何工作的。 我已经看过这个文档和一些例子,但是它并不是一成不变的。
下面是我们在课堂上学习的一个例子:
C = np.einsum("ij,jk->ki", A, B)
对于两个数组A
和B
我认为这将需要A^T * B
,但是我不确定(它正在转换其中一个对吗?)。 任何人都可以通过我在这里发生的事情(一般在使用einsum
)吗?
(注意:这个答案是基于我刚才写的一篇关于einsum
的简短博客文章 。)
einsum
做什么?
想象一下,我们有两个multidimensional array, A
和B
现在让我们假设我们想…
- 以特定的方式将
A
与B
相乘以创build新的产品arrays; 然后也许 - 将这个新的数组沿着特定的轴求和 ; 然后也许
- 按照特定顺序转置新arrays的轴。
有一个很好的机会, einsum
将帮助我们更快,更有效地记忆NumPy函数的组合,如multiply
, sum
和transpose
将允许。
einsum
如何工作?
这是一个简单的(但不是完全微不足道的)例子。 采取以下两个数组:
A = np.array([0, 1, 2]) B = np.array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]])
我们将A
和B
乘以元素,然后沿新数组的行进行求和。 在“正常”的NumPy中,我们写道:
>>> (A[:, np.newaxis] * B).sum(axis=1) array([ 0, 22, 76])
所以在这里, A
上的索引操作将两个数组的第一个轴排成一列,以便可以广播乘法运算。 然后将产品arrays的行相加以返回答案。
现在,如果我们想要使用einsum
,我们可以这样写:
>>> np.einsum('i,ij->i', A, B) array([ 0, 22, 76])
签名string'i,ij->i'
是这里的关键,需要一点解释。 你可以把它想成两半。 在左侧( ->
左侧),我们标记了两个input数组。 在->
的右边,我们标记了我们想要结束的数组。
接下来会发生什么:
-
A
有一个轴; 我们已经贴上了标签。 而B
有两个轴; 我们将轴0标记为i
,将轴1标记为j
。 -
通过在两个input数组中重复标签
i
,我们告诉einsum
这两个轴应该相乘 。 换句话说,我们将数组A
与数组B
每一列相乘,就像A[:, np.newaxis] * B
一样。 -
请注意,
j
在我们所需的输出中不会显示为标签。 我们刚刚使用i
(我们想结束一维数组)。 通过省略标签,我们告诉einsum
沿着这个轴进行求和 。 换句话说,我们正在总结产品的行,就像.sum(axis=1)
一样。
这基本上是所有你需要知道使用einsum
。 这有助于发挥一点; 如果我们在输出中留下两个标签,我们就得到一个二维数组(与A[:, np.newaxis] * B
)相同。 如果我们说没有输出标签,我们得到一个单一的数字(和(A[:, np.newaxis] * B).sum()
)。
然而,关于einsum
在于,并不是首先build立一个临时的产品阵容; 它只是对产品进行总结。 这可以导致内存使用的大量节省。
一个稍大的例子
为了解释点积,下面是两个新的数组:
A = array([[1, 1, 1], [2, 2, 2], [5, 5, 5]]) B = array([[0, 1, 0], [1, 1, 0], [1, 1, 1]])
我们将使用np.einsum('ij,jk->ik', A, B)
来计算点积。 下面的图片显示了A
和B
的标签以及我们从函数中获得的输出数组:
你可以看到标签j
被重复 – 这意味着我们将A
的行与B
的列相乘。 此外,标签j
不包括在输出中 – 我们正在总结这些产品。 标签i
和k
保留为输出,所以我们得到一个二维数组。
将这个结果与标签j
不加和的数组进行比较可能会更清楚。 下面,在左边你可以看到写入np.einsum('ij,jk->ijk', A, B)
(即我们保留了标签j
):
总结轴j
给出了预期的点积,如右图所示。
一些练习
为了获得更多的感觉,可以使用下标符号来实现熟悉的NumPy数组操作。 任何涉及乘法和求和轴的组合的东西都可以使用einsum
。
设A和B是两个具有相同长度的一维数组。 例如, A = np.arange(10)
和B = np.arange(5, 15)
。
-
A
的总和可以写成:np.einsum('i->', A)
-
单元乘法
A * B
可以写成:np.einsum('i,i->i', A, B)
-
内积或点积
np.inner(A, B)
或np.dot(A, B)
可写成:np.einsum('i,i->', A, B) # or just use 'i,i'
-
外部产品
np.outer(A, B)
可以写成:np.einsum('i,j->ij', A, B)
对于二维数组, C
和D
,假设这些轴是兼容的长度(两个长度相同或其中一个长度为1),下面是几个例子:
-
C
(主对angular线之和),np.trace(C)
可以写成:np.einsum('ii', C)
-
C
和D
,C * DT
的转置的元素乘法可以写成:np.einsum('ij,ji->ij', C, D)
-
将
C
的每个元素乘以数组D
(构成一个4D数组),C[:, :, None, None] * D
,可以写成:np.einsum('ij,kl->ijkl', C, D)
让我们制作2个不同的,但兼容的尺寸来突出他们的相互作用
In [43]: A=np.arange(6).reshape(2,3) Out[43]: array([[0, 1, 2], [3, 4, 5]]) In [44]: B=np.arange(12).reshape(3,4) Out[44]: array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]])
你的计算需要一个(2,3)和(3,4)的“点”(产品的总和)来产生一个(4,2)数组。 i
是A
的第一个暗淡, C
的最后一个; k
B
的最后一个, C
第一个。 j
是总和消耗的。
In [45]: C=np.einsum('ij,jk->ki',A,B) Out[45]: array([[20, 56], [23, 68], [26, 80], [29, 92]])
这与np.dot(A,B).T
– 它是最终的输出。
要查看更多关于j
发生了什么,请将C
下标更改为ijk
:
In [46]: np.einsum('ij,jk->ijk',A,B) Out[46]: array([[[ 0, 0, 0, 0], [ 4, 5, 6, 7], [16, 18, 20, 22]], [[ 0, 3, 6, 9], [16, 20, 24, 28], [40, 45, 50, 55]]])
这也可以通过以下方式产生:
A[:,:,None]*B[None,:,:]
也就是说,在A
的末尾添加一个k
维,并在B
的前面添加一个i
,得到一个(2,3,4)数组。
0 + 4 + 16 = 20
9 + 28 + 55 = 92
等; 求和j
并转置得到较早的结果:
np.sum(A[:,:,None] * B[None,:,:], axis=1).T # C[k,i] = sum(j) A[i,j (,k) ] * B[(i,) j,k]
我发现NumPy:交易的技巧(第二部分)有启发性
我们用 – >来表示输出数组的顺序。 所以把'ij,i-> j'看作左边(LHS)和右边(RHS)。 LHS上任何重复的标签都会计算产品元素,然后进行求和。 通过改变RHS(输出)侧的标签,我们可以定义我们想要对input数组进行处理的轴,即轴0,1上的求和等等。
import numpy as np >>> a array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) >>> b array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) >>> d = np.einsum('ij, jk->ki', a, b)
注意有三个坐标轴,i,j,k,j是重复的(在左边)。 i,j
表示一个行和列。 j,k
为b
。
为了计算产品和alignmentj
轴,我们需要添加一个轴到a
。 ( b
将沿(?)第一轴播放)
a[i, j, k] b[j, k] >>> c = a[:,:,np.newaxis] * b >>> c array([[[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8]], [[ 0, 2, 4], [ 6, 8, 10], [12, 14, 16]], [[ 0, 3, 6], [ 9, 12, 15], [18, 21, 24]]])
j
在右边不存在,所以我们求和3x3x3arrays的第二个轴j
>>> c = c.sum(1) >>> c array([[ 9, 12, 15], [18, 24, 30], [27, 36, 45]])
最后,右边的(按字母顺序)颠倒了索引,所以我们转换。
>>> cT array([[ 9, 18, 27], [12, 24, 36], [15, 30, 45]]) >>> np.einsum('ij, jk->ki', a, b) array([[ 9, 18, 27], [12, 24, 36], [15, 30, 45]]) >>>