Python计算库NumPy的einsum(爱因斯坦和)使用简介

einsum函数是 Numpy 库非常出色的计算函数之一。einsum函数非常灵活,也非常高效,运行时占用的存储空间很小。不过einsum函数这么多优异的特点不是没有代价的——要想灵活的使用它,就需要花一些时间了解它。

网络上已经有不少内容介绍einsum函数的基本原理和实现,所以本文不对这些内容再做赘述,相反,这里只是做一些相对易懂的基本介绍,重点在于介绍如何使用它。

einsum是干什么的

使用einsum函数可以为 numpy arrays 指定一些数学计算,指定方式主要参考爱因斯坦求和约定,来看个例子:

假设有两个 array,A 和 B,现在我们希望:

  • 求 A B(注意,这里的乘不一定是简单的标量乘法),得到新的 array,然后
  • 按照指定的纬度方向求,并且/或者
  • 按照指定的纬度,变换(transpose)新array的纬度方向

虽然常规的 numpy 乘法/加法/transpose 操作能够完成上述期望,但是使用einsum能够使用更小的空间,更快的速度完成。(https://blog.popkx.com 原创,未经许可抄袭可耻)我们将上述例子实例化:

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

期望是:先计算 A 和 B 的逐元素乘积,然后沿着 axis 1 求和。

先来看看使用 numpy 的常规操作。首先需要做的是 reshape A,然后在做 broadcast,将其变成和 B 一样的形状。然后将 B 的第一行乘以 0,第二行乘以 1,第三行乘以 2,最终得到一个结果 array,再将其沿着 axis 1 求和:

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

这行代码能够准确的计算出结果,但是使用einsum更好一点:

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

为什么更好?简单来说,使用einsum,我们就不需要先对 A 做 reshape 和 broadcast(这个操作通常可以隐式的完成)操作了,更重要的是,在计算执行过程中不需要像 A[:, np.newaxis] * B 那样创建临时array了。相反,einsum简单的沿着行求积的和。

实际上,即使是这个简单的例子,我测试使用 einsum 计算的速度要是常规手段的 3 倍。

怎么使用einsum

使用einsum的诀窍就是关注输入/输出 arrays 的轴。einsum支持两种使用方式来标注轴变化:字符串,或者数字列表。不过,据我了解,字符串方式要简单许多,不少开源工程都是使用此种方式。

这里以矩阵乘法作为简单实例,矩阵乘法其实就是一个矩阵的行逐元素乘以另一个矩阵的列,然后把乘积加在一起。(https://blog.popkx.com 原创,未经许可抄袭可耻)对于两个 2D arrays A 和 B,使用 einsum 计算它俩的矩阵成绩可以按照下面这么写:

np.einsum('ij,jk->ik', A, B)

第一个参数'ij,jk->ik'是什么意思呢?将'->'看作分隔符,左边的部分'ij'表示输入 A 的轴,'jk'表示 B 的轴,右边的'ik'表示输出结果的轴。所以现在我们能够知道,这行代码接收两个 2D arrays 作为输入,输出也是一个 2D array。下文将参数'ij,jk->ik'称作轴标签

将 A 和 B 实例化:

A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])

B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])

我们将轴标签的输入/输出轴画出来:

np.einsum('ij,jk->ik', A, B)

要理解上述计算过程,记住下面这三条规则

  • 输入arrays沿着重复的那个轴做乘法计算。

在本例中,轴标签的输入参数ij,jk中 'j' 被重复用了 2 次:1次表示A的第二个轴,1次表示B的第一个轴。这意味着要计算A的第二个轴(行方向)与B的第一个轴(列方向)的乘积。所以,此时要保证输入合法,就需要保证A的行长与B的列长一致。

  • 如果某个轴标签在输出标签中消失了,则表示要沿着该轴做求和计算。

此处,'j' 没有出现在输出标签中,这表示在执行乘法运算后,需要再沿着'j'轴做求和运算,求和运算减少了输出 array 的 1 个纬度。如果我们将输出标签改为'ijk',那么因为少了求和运算,最终得到的输出 array 将会是 3x3x3 形状的。

如果此处我们将输入/输出标签改为'ij,jk->',也即令所有的标签都不出现在输出轴标签里,那么我们将得到一个标量,这个标量是所有元素的和。

  • 我们可以获取任意顺序轴的结果

如果我们在轴标签中不写'->',那么 numpy 会将只出现一次的标签按照字母顺序组合,作为输出轴标签,所以 ij,jk->ikij,jk 效果上是等价的。指定轴顺序的输出,可以通过指定轴标签的顺序获得。(https://blog.popkx.com 原创,未经许可抄袭可耻)例如,'ij,jk->ki'可以得到'ij,jk->ik'的转置矩阵。

了解上面三条基本规则后,再来看einsum如何计算矩阵乘法的就简单了。下图是左边是计算np.einsum('ij,jk->ijk', A, B)的结果,右图则是按照'j'轴求和后的结果:

np.einsum('ij,jk->ijk', A, B)和np.einsum('ij,jk->ik', A, B)

按照前文所述,einsum是非常节约空间的计算函数,所以对于np.einsum('ij,jk->ik', A, B)einsum并不构建临时的 3D array 然后求和,而是直接在 2D 空间累加得到最终结果。

einsum 的简单操作

下面两张表描述了 einsum 的基本操作。

若 A 和 B 为两个 1D arrays(假设在相应的操作中,A和B的形状总是合适的),那么:

Call signature NumPy equivalent Description
('i', A) A A
('i->', A) sum(A) A的所有元素和
('i,i->i', A, B) A * B A和B逐元素乘积
('i,i', A, B) inner(A, B) A和B的内积
('i,j->ij', A, B) outer(A, B) A和B的外积

若 A 和 B 为两个 2D arrays(假设在相应的操作中,A和B的形状总是合适的),那么:

Call signature NumPy equivalent Description
('ij', A) A A
('ji', A) A.T A的转置
('ii->i', A) diag(A) A 的对角
('ii', A) trace(A) A的迹
('ij->', A) sum(A) A的所有元素和
('ij->j', A) sum(A, axis=0) A的沿着axis=0的和
('ij->i', A) sum(A, axis=1) A的沿着axis=1的和
('ij,ij->ij', A, B) A * B A和B逐元素乘积
('ij,ji->ij', A, B) A * B.T A 和 B.T 逐元素乘积
('ij,jk', A, B) dot(A, B) A 和 B 的矩阵乘法
('ij,kj->ik', A, B) inner(A, B) A 和 B 的内积
('ij,kj->ikj', A, B) A[:, None] * B A的每一行与B的乘积
('ij,kl->ijkl', A, B) A[:, :, None, None] * B A的每一个元素与B的乘积

einsum轴标签中的'...'符号

在处理比较多的纬度时,为了方便,可以像 numpy array 一样使用 '...' 符号省略一些纬度的显式表示。例如,

np.einsum('...ij,ji->...', a, b)

这一行代码计算的是 a 的后两个轴与 2D array b 的乘积。

注意事项

einsum 在求和时,不会提升数据类型。如果我们使用了位宽比较小的数据类型,可能会得到不期望的结果:

>>> a = np.ones(300, dtype=np.int8)
>>> np.sum(a) # correct result
300
>>> np.einsum('i->', a) # produces incorrect result
44

另外,einsum在 numpy 的计算库中并不一定总是最快的。类似于 dotinner 的函数一般链接到快速计算库 BLAS,可能会快于 einsum

本文译自: https://ajcr.net/Basic-guide-to-einsum/

阅读更多:   算法
添加新评论

icon_redface.gificon_idea.gificon_cool.gif2016kuk.gificon_mrgreen.gif2016shuai.gif2016tp.gif2016db.gif2016ch.gificon_razz.gif2016zj.gificon_sad.gificon_cry.gif2016zhh.gificon_question.gif2016jk.gif2016bs.gificon_lol.gif2016qiao.gificon_surprised.gif2016fendou.gif2016ll.gif