2017-01-27 4 views
1

私はTensorflowのニューラルネットワークで、入力画像のラベルに基づいて(ラベルに対応する別のデータセットからの)ベクトルをサンプリングし、入力画像のドット積をとるカスタム損失関数を使用しています埋め込み(softmax事前活性化)およびサンプリングされたベクトル。また、入力ラベルと矛盾する不一致ベクトルと、現在の入力ラベルとラベルが異なるランダムな訓練入力画像埋め込みを否定的にサンプリングする。もちろんこれらはすべて同じ次元テンソルとして構成されており、カスタム損失は次のとおりです。背景のために、主だったが、私がいる問題は、入力画像に対応するラベルの値にアクセスする方法ですトレーニング/評価中にTensorflowのラベルの値にアクセスするにはどうすればよいですか?

Loss = max(0, input_embedding * mismatching_sample - input_embedding * matching_sample + 1) \ 
+ max(0, random_embedding * matching_sample - input_embedding * matching_sample + 1) 

?正しいベクトルをサンプリングして損失を計算するために、これらのラベルの値にアクセスできる必要があります。

セッションで実行してテンソルの値を取得するには、.eval()を使用することができますが、この端末を試してみたら、ちょうどハングしました。私は自分のニューラルネットワークを訓練しているので、別のセッションの中でセカンドセッションを実行し、他の実行セッションの一部である値を評価しようとする問題があるかどうかはわかりません。とにかく、私はこの仕事をどのようにすることができるかについて、まったく考えられていません。どんな助けでも大歓迎です!問題は、私がサンプリングされた値で渡すfeed_dictを使用していなかったた

# compute custom loss function using tensors and tensor operations 
def compute_loss(e_list, labels, i): 
    embedding = e_list[i] #getting the current embedding tensor 
    label = labels[i] #getting the matching label tensor, this is value I need. 
    y_index = np.nonzero(label)[0][0] #this always returns 0, doesn't work :(
    target = get_mnist_embedding(y_index) 
    wrong_mnist = get_mismatch_mnist_embedding(y_index) 
    wrong_spec = get_random_spec_embedding(y_index, e_list, labels) 
    # compute the loss: 
    zero = tf.constant(0,dtype="float32") 
    one = tf.constant(1,dtype="float32") 
    mul1 = tf.mul(wrong_mnist,embedding) 
    dot1 = tf.reduce_sum(mul1) 
    mul2 = tf.mul(target,embedding) 
    dot2 = tf.reduce_sum(mul2) 
    mul3 = tf.mul(target,wrong_spec) 
    dot3 = tf.reduce_sum(mul3) 
    max1 = tf.maximum(zero, tf.add_n([dot1, tf.negative(dot2), one])) 
    max2 = tf.maximum(zero, tf.add_n([dot3, tf.negative(dot2), one])) 
    loss = tf.add(max1,max2) 
    return loss 
+0

おそらく、あなたがキューのランナーを開始しなかったために掛かったでしょう。初心者のためのMNISTはラベルを読む例を示しています。基本的には、「ラベル」と「イメージ」のバッチがあり、並行してキューから取り出すことができます。 –

+0

ありがとう@Yaroslav。私はキューランナーを使って画像とラベルバッチを作成していましたが、実行時にラベルテンソル値を評価して、損失計算に必要な値をサンプリングしていました。それらをfeed_dictオブジェクトで渡します。私の解決策を答えとして掲載しました。 – kashkar

答えて

0

は、ここに問題のある証明した私の元の試みです。私は損失の計算の時に入力ラベルのテンソルの値を特定しようとしていたので、に、次にのサンプルを入力して、私の損失関数に必要な値を得ました。

私はsess.run([train_op, loss...], feed_dict=feed_dict)を実行したとき、私の損失の値にアクセスするためにfeed_dict辞書オブジェクトを使用して事前計算し、例えばmismatching_samplematching_sample値を整理し、最初にこれらの値のtf.placeholderを使用して、その値を渡すことによってこの問題に対処してきました計算。

関連する問題