2016-03-04 34 views
6

私は、次元(d、k)のn個の行列と次元(k、n)の行列VからなるテンソルUを持っています。numpy tensordotによるテンソル乗算

結果は、列jがUの行列jとVの列jとの間の行列乗算の結果である次元(d、n)の行列を返すようにそれらを掛けたいと考えます。

これを取得する

enter image description here

1つの可能な方法は次のとおりです。numpyライブラリを使用してより高速なアプローチがある場合、私は疑問に思って

for j in range(n): 
    res[:,j] = U[:,:,j] * V[:,j] 

。特に、私はnp.tensordot()の機能を考えています。

この小さなスニペットでは、単一の行列にスカラーを乗算することができますが、ベクトルへの明らかな一般化は、私が望んでいたものを返すことではありません。

a = np.array(range(1, 17)) 
a.shape = (4,4) 
b = np.array((1,2,3,4,5,6,7)) 
r1 = np.tensordot(b,a, axes=0) 

ご提案がありますか?

+0

あなたがあなたのイメージを描画するために使用しているどのようなソフトウェア? – hlin117

+1

@ hlin117 - キーノートを使用しました。 – Matteo

答えて

6

これを行うにはいくつかの方法があります。心に来る最初の事はnp.einsumです:

# some fake data 
gen = np.random.RandomState(0) 
ni, nj, nk = 10, 20, 100 
U = gen.randn(ni, nj, nk) 
V = gen.randn(nj, nk) 

res1 = np.zeros((ni, nk)) 
for k in range(nk): 
    res1[:,k] = U[:,:,k].dot(V[:,k]) 

res2 = np.einsum('ijk,jk->ik', U, V) 

print(np.allclose(res1, res2)) 
# True 

np.einsumはテンソル収縮を表現するためにEinstein notationを使用しています。上記式'ijk,jk->ik'の式において、i,jおよびkは、UおよびVの異なる寸法に対応する下付き文字である。カンマで区切られた各グループは、np.einsumに渡されるオペランドの1つに対応します(この場合、のサイズはijkで、Vのサイズはjkです)。 '->ik'の部分には、出力配列の寸法を指定します。出力文字列に存在しない下付き文字を持つ次元は合計されます。

np.einsumは、複雑なテンソル収縮を実行するのに非常に便利ですが、動作の仕方を頭の中で完全に包むにはしばらく時間がかかることがあります。ドキュメントの例を見てください(上にリンクされています)。


いくつかの他のオプション:broadcasting

  1. 要素ごとの乗算、合計が続く:

    from numpy.core.umath_tests import inner1d 
    
    res4 = inner1d(U.transpose(0, 2, 1), V.T) 
    
  2. :転置の荷重で

    res3 = (U * V[None, ...]).sum(1) 
    

いくつかのベンチマーク:

In [1]: ni, nj, nk = 100, 200, 1000 

In [2]: %%timeit U = gen.randn(ni, nj, nk); V = gen.randn(nj, nk) 
    ....: np.einsum('ijk,jk->ik', U, V) 
    ....: 
10 loops, best of 3: 23.4 ms per loop 

In [3]: %%timeit U = gen.randn(ni, nj, nk); V = gen.randn(nj, nk) 
(U * V[None, ...]).sum(1) 
    ....: 
10 loops, best of 3: 59.7 ms per loop 

In [4]: %%timeit U = gen.randn(ni, nj, nk); V = gen.randn(nj, nk) 
inner1d(U.transpose(0, 2, 1), V.T) 
    ....: 
10 loops, best of 3: 45.9 ms per loop 
+0

答えをありがとう!機能の仕組みに関する説明を追加してください。例えば、 '(ni、nj、nk)'ではなくUが '(nk、ni、nj)'だったら、関数呼び出しはどう変わるのでしょうか? – Matteo

+0

すばらしい答え!どうもありがとう! – Matteo