1. ホーム
  2. python

[解決済み] numpy dot()とPython 3.5+の行列の乗算の違い@。

2022-05-30 14:17:43

質問

最近 Python 3.5 に移行したのですが、その際に 新しい行列の乗算演算子 (@) とは異なる挙動をすることがあります。 numpy dot 演算子と異なる動作をすることがあります。例として、3次元配列の場合。

import numpy as np

a = np.random.rand(8,13,13)
b = np.random.rand(8,13,13)
c = a @ b  # Python 3.5+
d = np.dot(a, b)

@ 演算子は形状の配列を返します。

c.shape
(8, 13, 13)

の間に np.dot() 関数が返ります。

d.shape
(8, 13, 8, 13)

同じ結果をnumpy dotで再現するにはどうしたらよいでしょうか?他に大きな違いはありますか?

どのように解決するのですか?

この @ 演算子は配列の __matmul__ メソッドではなく dot . このメソッドは、API でも関数 np.matmul .

>>> a = np.random.rand(8,13,13)
>>> b = np.random.rand(8,13,13)
>>> np.matmul(a, b).shape
(8, 13, 13)

ドキュメントから

matmul とは異なり dot とは二つの重要な点で異なります。

  • スカラーによる乗算は許可されていません。
  • 行列のスタックは、行列が要素であるかのように一緒に放送されます。

最後のポイントにより、以下のことが明らかになりました。 dotmatmul メソッドは、3次元(またはそれ以上の次元)の配列を渡されたときに異なる動作をします。もう少しドキュメントから引用します。

について matmul :

いずれかの引数がN-D, N > 2の場合、最後の2つのインデックスに存在する行列のスタックとして扱われ、それに応じてブロードキャストされます。

については np.dot :

2次元配列では行列の乗算、1次元配列ではベクトルの内積(複素共役なし)と等価です。 N次元の場合、aの最終軸とbの最後から2番目の軸の和積となる