2017-02-11 3 views
0

私はTensorFlowの多目的ニューラルネットを自分自身の損失関数で訓練しており、バッチ処理がどのようにその機能と相互作用するかに関するドキュメントを見つけることはできません。バッチ処理はTensorFlowの損失機能とどのように相互作用しますか?

例えば、私は予測のテンソル/リストを受け取り、1個以下にしてくださいそれらの絶対値和ということになり、以下の私の損失関数のスニペット、持っている:

def fitness(predictions,actual): 

    absTensor = tf.abs(predictions) 
    sumTensor = tf.reduce_sum(absTensor) 
    oneTensor = tf.constant(1.0) 

    isGTOne = tf.greater(sumTensor,oneTensor) 

    def norm(): return predictions/sumTensor 
    def unchanged(): return predictions 

    predictions = tf.cond(isGTOne,norm,unchanged) 

    etc... 

をしかし、私はとき私は、この損失関数が、1に合計するそれぞれの集合ではなく、この点で合計1になるように入力全体を正規化していると感じる推定値のバッチを渡します。つまり、所望のものではなく
[[8、8]、[8,12] .8]、[.8、.8]] - [[.5、.5]、[.5、.5]]

誰かが明確に疑いを抱かせることができますか?これが私の機能が現在どのように働いているのか、それをどうやって変えるのですか?

答えて

2

リダクション軸を指定する必要があります。それ以外の場合は、すべての軸が減少します。伝統的に、これはあなたのテンソルの最初の次元です。したがって、2行目は次のようになります。

sumTensor = tf.reduce_sum(absTensor, 0) 

この変更を加えたら、別の問題が発生します。 sumTensorはもはやスカラーではなくなり、tf.condの条件としてはもはや意味をなさない(つまり、バッチのエントリごとに分岐することはどういう意味ですか?)。本当に欲しいのはtf.selectです。なぜなら、バッチエントリごとにロジックを分岐させたくないからです。このように:

isGTOne = tf.greater(sumTensor,oneTensor) 

norm = predictions/sumTensor 

predictions = tf.select(isGTOne,norm,predictions) 

これを見ても、条件付きでエントリを正規化することは気にしません。バッチの細分度で操作しているので、バッチのエントリを一度に正規化してパフォーマンスを得ることはできません。特に、分割は本当に高価な副作用ではないので。場合もあります:

def fitness(predictions,actual): 

    absTensor = tf.abs(predictions) 
    sumTensor = tf.reduce_sum(absTensor, 0) 

    predictions = predictions/sumTensor 

    etc... 

希望に役立ちます!

+0

これは完璧です。ありがとうございました。ドキュメントのどこかにこの動作について語っていますか?予期せぬことが何も起こらないようにするために読んでみたい – liqiudilk

+0

特にどのような振る舞いでドキュメントを探していましたか? [tf.select docs](https://www.tensorflow.org/api_docs/python/control_flow_ops/comparison_operators#select)は便利です。 – suharshs

関連する問題