• 欢迎访问开心洋葱网站,在线教程,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站,欢迎加入开心洋葱 QQ群
  • 为方便开心洋葱网用户,开心洋葱官网已经开启复制功能!
  • 欢迎访问开心洋葱网站,手机也能访问哦~欢迎加入开心洋葱多维思维学习平台 QQ群
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏开心洋葱吧~~~~~~~~~~~~~!
  • 由于近期流量激增,小站的ECS没能经的起亲们的访问,本站依然没有盈利,如果各位看如果觉着文字不错,还请看官给小站打个赏~~~~~~~~~~~~~!

从手写三层循环到标准实现,矩阵相乘运行效率提高三万六千倍之路

其他 MCTW 2379次浏览 0个评论

前言

矩阵乘法可以说是最常见的运算之一。

本文介绍不同的方式实现的矩阵乘法,并比较它们运行速度的差异。

表示矩阵的方式有很多种,完善的矩阵类应该实现切片取值,获得矩阵形状等操作,但本文并不打算直接从原生Python实现一个矩阵类,而是直接用 Pytorch中的tensor表示矩阵。

开始: 三层循环

根据矩阵相乘定义,可通过三层循环实现该运算。

def matmul(a, b):
    r1, c1 = a.shape
    r2, c2 = b.shape
    
    assert c1 == r2
    
    rst = torch.zeros(r1, c2)
    
    for i in range(r1):
        for j in range(c2):
            for k in range(c1):
                rst[i][j] += a[i][k] * b[k][j]
    return rst

那么这个函数的运行效率如何呢?让我们尝试两个较大的矩阵相乘,测试一下运行时间。

m1 = torch.randn(5, 784)
m2 = torch.randn(784, 10)

%timeit -n 10 matmul(m1, m2)

得到结果如下:

624 ms ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

即每次矩阵相乘, 需要耗时 600ms 左右,这是一个非常非常慢的速度,慢到两次矩阵乘法居然要耗时1秒多,这是不可能被接受的。

相同形状的张量进行运算

如果两个张量的形状相同,则他们的运算为同一位置的数字进行运算。

a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])

a + b  # tensor([5., 7., 9.])
a * b  # tensor([ 4., 10., 18.])

康康之前用三层循环实现的矩阵相乘,发现最里面一层循环的本质就是两个同样大小的张量相乘,再进行求和。
即第一个矩阵中的一行 跟 第二个矩阵中的一列 进行运算,且这行和列中的元素个数相同,则我们可以通过同样形状的张量运算改写最内层循环:

def matmul(a, b):
    r1, c1 = a.shape
    r2, c2 = b.shape
    
    assert c1 == r2
    
    rst = torch.zeros(r1, c2)
    
    for i in range(r1):
        for j in range(c2):
            rst[i][j] = (a[i,:] * b[:,j]).sum()  # 改了这里
    return rst

%timeit -n 10 matmul(m1, m2)

得到结果如下

1.4 ms ± 92.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

624 / 1.4=445,只改写了一下最内层循环,就使得矩阵乘法快了445倍!

广播机制

广播机制使得不同形状的张量间可以进行运算:

  1. 两个张量扩充成同样的形状
  2. 再按相同形状的张量进行运算
# shape: [2, 3]
a = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
])

# shape: [1]
b = torch.tensor([1])

# shape: [3]
c = torch.tensor([10, 20, 30])

形状为 [2, 3] 和 [1] 的两个张量相加:

a + b

"""输出:
tensor([[2, 3, 4],
        [5, 6, 7]])
"""

形状为 [2, 3] 和 [3] 的两个张量相加:

b + c

"""输出:
tensor([[11, 22, 33],
        [14, 25, 36]])
"""

这两个例子中,维度低的张量都是暗地里先扩充成了维度高的张量,然后再参与的运算。

那么如何查看扩充后的张量是啥呢?用 expand_as 函数就可以查看:

b.expand_as(a)

"""输出
tensor([[1, 1, 1],
        [1, 1, 1]])
"""
b.expand_as(a)

"""输出
tensor([[10, 20, 30],
        [10, 20, 30]])
"""

这就一目了然了,形状不同的张量可以通过广播机制扩充成形状一致的张量再进行运算。

那么任意形状的两个张量都可以运算吗?当然不是了,判断两个张量是否能运算的规则如下:

先从两个张量的最后一个维度看起,如果维度的维数相同,或者其中一个维数为1,则可以继续判断,否则就失败。
然后看倒数第二个维度,倒数第三个维数,一直到遍历完某个张量的维数为止,一直没有失败则这两个张量可以通过广播机制进行运算。

那么这个广播机制和矩阵乘法有什么关系呢?答案就是它可以帮我们再去掉一层循环。

现在的最内存循环的本质是 一个形状为 [c1] 的张量 和 一个形状为 [c1, c2] 的张量做运算,最终生成一个形状为 [c2] 的张量。

则我们可以把矩阵运算改写为:

def matmul(a, b):
    r1, c1 = a.shape
    r2, c2 = b.shape
    
    assert c1 == r2
    
    rst = torch.zeros(r1, c2)
    
    for i in range(r1):
        rst[i] = (a[i, :].unsqueeze(-1) * b).sum(0)
    return rst

%timeit -n 10 matmul(m1, m2)

"""输出
249 µs ± 66.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""

现在已经把每次矩阵运算的时间压缩到了 249 µs!!!,比最开始的 624ms 快了 2500倍!

对于 unsqueeze 操作不太熟悉的小伙伴请看我的另一篇文档: Pytorch 中张量的理解

但是还没结束。。。因为两个矩阵的相乘,就是 [r1, c1] 和 [c1, c2] 两个张量的运算,我们可以直接把它用广播机制一次到位的算出结果,连唯一的那层循环也可以省去:

def matmul(a, b):
    r1, c1 = a.shape
    r2, c2 = b.shape
    
    assert c1 == r2
    
    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(1)

%timeit -n 10 matmul(m1, m2)

"""输出:
169 µs ± 41.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""

这个 169µs 已经是最开始矩阵相乘版本的 3700 倍了。。( Ĭ ^ Ĭ )泪目,果然知识是第一生产力。

爱因斯坦求和

接下来就是 pytorch 自带的矩阵运算工具了,其中一个是爱因斯坦求和,貌似知道这个的同学不多。。
简单来说,它能让我们几乎不编写代码就能进行矩阵运算,只需要确定输入和输出矩阵的形状即可:

def matmul(a, b):
    return torch.einsum("ik,kj->ij", a, b)

%timeit -n 10 matmul(a, b)

"""输出
74 µs ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""

74µs 这个速度已经是原始版本的 8000 多倍了。。。但是对于工业级别的要求似乎仍然不够快~

pytorch 的矩阵相乘标准实现

最后祭出 pytorch 的矩阵相乘官方版本:

def matmul(a, b):
    return a @ b

%timeit -n 10 matmul(m1, m2)

"""输出
17.1 µs ± 28.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""

17.1 µs 是原始三层循环版本的 36000 倍,官方实现就是这么简单枯燥,朴实无华~


开心洋葱 , 版权所有丨如未注明 , 均为原创丨未经授权请勿修改 , 转载请注明从手写三层循环到标准实现,矩阵相乘运行效率提高三万六千倍之路
喜欢 (0)

您必须 登录 才能发表评论!

加载中……