現在テンソルフローでhttp://www.aclweb.org/anthology/P15-1061を実装しています。Tensorflowでペアワイズランキング損失関数を効率的に計算する
Iは以下のようにペアワイズランキング損失関数(紙のセクション2.5)を実装している:
s_theta_y = tf.gather(tf.reshape(s_theta, [-1]), y_true_index)
s_theta_c_temp = tf.reshape(tf.gather(tf.reshape(s_theta, [-1]), y_neg_index), [-1, classes_size])
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])
私は後者がまだ勾配を用いて実装されていないので、むしろtf.gather_ndよりtf.gather使用していました降下。私はまた、フラットな行列ですべてのインデックスを正しいものに変換しなければなりませんでした。
tf.gather_ndが急降下して実施された場合は、私のコードは次のようにされているでしょう:
s_theta_y = tf.gather_nd(s_theta, y_t_index)
s_theta_c_temp = tf.gather_nd(s_theta, y_neg_index)
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])
s_thetaは、紙のように、各クラスラベルのための計算されたスコアです。 y_true_indexには、s_theta_yを計算するために真のクラスのインデックスが含まれています。 y_neg_indexはすべての負のクラスのインデックスであり、その次元は#class-1または#classのいずれかであり、その関係はotherとして分類されます。
ただし、いくつかの文章は「その他」に分類されているため、s_theta_y は存在しません。計算には考慮しないでください。このような場合を処理するために、定数を0にしてその項を取り消し、負のクラスの次元ベクトルを同じにするために、インデックスのランダムな値をコピーするだけです。 (インデックスではなく)すべての負のクラス間の最大値。
損失関数のこれらの項をより効率的に計算する方法はありますか?私は、あまりにも多くの形をしたtf.gatherを使うのが非常に遅いと感じています。