2017-12-26 4 views
0

私は、tf.map_fn()で正しく動作するようにパラメータを構造化しようとしていますが、ほとんどのサンプルドキュメントでは、関数の引数と同じ形状の配列またはテンソルについてしか説明していません。Tensorflow tf.map_fn parameters

リンクは、次のとおり

Does tensorflow map_fn support taking more than one tensor?

私の具体的な例は、このです: Iパラメータテンソル形状として[なし]、[2]及び[X、Y]期待一部tensorflow機能を有しています。

テンソルAの形状の[BATCH_SIZE、のx * yを、2]

テンソルBの形状であるtensorflowドキュメントから[BATCH_SIZE、X、Y]

lambdaData = (tensorA, tensorB) 
lambdaFunc = lambda x: tensorflowFunc(x[0], x[1]) 
returnValues = tf.map_fn(lambdaFunc, lambdaData) 

ある:

If elems is a (possibly nested) list or tuple of tensors, then each of these 
tensors must have a matching first (unpack) dimension 

tensorsAとBは次元0でのみ一致するため、スタックしたり連結することはできません。私としてもlambdaDataを作成しようとした:

  1. 2個のテンソル
  2. 2個のテンソル
  3. テンソルペア

のリストのタプルのリスト上の結果のすべてのディメンションの不一致を変える中エラー。私はすべてのデータを単一のテンソルに配置するというドキュメントに従って推奨された使用に従いますが、テンソルAとテンソルBの間の次元の不一致のために私はできません。誰にもタプルやelemsの議論のリストがありましたか?

答えて

0

tf.map_fnのエラーメッセージはひどく誤解を招きます。ドキュメンテーションはこれについて詳細には言及していませんが、もしあなたがタプル/テンソルのリストを渡すならば、関数としての戻り値の正確な数を引数として必要とします。これを行う最も簡単な方法は、ジャンクを返してから、最初の戻り値だけを取得することです。

print(a.shape) #[batch, 784, 2] 
print(b.shape) #[batch, 28, 28] 
lambdaData = (a, b) 
testFunc = lambda x: return <somethingUseful>, 0 
returnValues, _ = tf.map_fn(testFunc, lambdaData) 

は期待どおりに機能します。