2016-10-30 10 views
1

現在テンソルフローでhttp://www.aclweb.org/anthology/P15-1061を実装しています。Tensorflowでペアワイズランキング損失関数を効率的に計算する

Iは以下のようにペアワイズランキング損失関数(紙のセクション2.5)を実装している:

s_theta_y = tf.gather(tf.reshape(s_theta, [-1]), y_true_index) 
s_theta_c_temp = tf.reshape(tf.gather(tf.reshape(s_theta, [-1]), y_neg_index), [-1, classes_size]) 
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1]) 

私は後者がまだ勾配を用いて実装されていないので、むしろtf.gather_ndよりtf.gather使用していました降下。私はまた、フラットな行列ですべてのインデックスを正しいものに変換しなければなりませんでした。

tf.gather_ndが急降下して実施された場合は、私のコードは次のようにされているでしょう:

s_theta_y = tf.gather_nd(s_theta, y_t_index) 
s_theta_c_temp = tf.gather_nd(s_theta, y_neg_index) 
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1]) 

s_thetaは、紙のように、各クラスラベルのための計算されたスコアです。 y_true_indexには、s_theta_yを計算するために真のクラスのインデックスが含まれています。 y_neg_indexはすべての負のクラスのインデックスであり、その次元は#class-1または#classのいずれかであり、その関係はotherとして分類されます。

ただし、いくつかの文章は「その他」に分類されているため、s_theta_y は存在しません。計算には考慮しないでください。このような場合を処理するために、定数を0にしてその項を取り消し、負のクラスの次元ベクトルを同じにするために、インデックスのランダムな値をコピーするだけです。 (インデックスではなく)すべての負のクラス間の最大値。

損失関数のこれらの項をより効率的に計算する方法はありますか?私は、あまりにも多くの形をしたtf.gatherを使うのが非常に遅いと感じています。

答えて

1

もちろん、gather_ndはあなたの望みですが、グラジエントがそこに実装されるまで、私はreshape() reshape()は実質的にフリーです。

C++ implementation of the reshape() opは多くの作業をしているようですが、形状情報の確認はすばやくエラーです。 "仕事"は高価かもしれないが、実際にはポインタコピー(CopyFromはポインタをコピーするCopyFromInternalを呼び出します)のように聞こえる90行目のCopyFromで発生します。

これは完全な意味を持っています。下にあるバッファーは数字のフラットな配列で、row-major orderであり、その順序は形状情報に依存しません。同じ理由から、tf.transpose()のようなものはに一般的なコピーが必要です。

関連する問題