2017-01-17 9 views
4

私はいくつかのBNの例を見てきましたが、まだ少し混乱しています。だから私は現在、ここで関数を呼び出すこの関数を使用しています。バッチ正規化 - Tensorflow

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.batch_norm.md

from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm 
import tensorflow as tf 

def bn(x,is_training,name): 
    bn_train = batch_norm(x, decay=0.9, center=True, scale=True, 
    updates_collections=None, 
    is_training=True, 
    reuse=None, 
    trainable=True, 
    scope=name) 
    bn_inference = batch_norm(x, decay=1.00, center=True, scale=True, 
    updates_collections=None, 
    is_training=False, 
    reuse=True, 
    trainable=False, 
    scope=name) 
    z = tf.cond(is_training, lambda: bn_train, lambda: bn_inference) 
    return z 

これ以下の部分は、私はちょうど関数は2つの機能のためのトレーニングステップで算出した平均と分散を再利用することを確認していますおもちゃの実行です。 is_training=Falseすなわち、テストモードでコードのこの部分を実行して、トレーニングステップで算出した移動平均/分散たちは、私が最初にbnParams

if __name__ == "__main__": 
    print("Example") 

    import os 
    import numpy as np 
    import scipy.stats as stats 
    np.set_printoptions(suppress=True,linewidth=200,precision=3) 
    np.random.seed(1006) 
    import pdb 
    path = "batchNorm/" 
    if not os.path.exists(path): 
     os.mkdir(path) 
    savePath = path + "bn.model" 

    nFeats = 2 
    X = tf.placeholder(tf.float32,[None,nFeats]) 
    is_training = tf.placeholder(tf.bool,name="is_training") 
    Y = bn(X,is_training=is_training,name="bn") 
    mvn = stats.multivariate_normal([0,100]) 
    bs = 4 
    load = 0 
    train = 1 
    saver = tf.train.Saver() 
    def bnCheck(batch,mu,std): 
     # Checking calculation 
     return (x - mu)/(std + 0.001) 
    with tf.Session() as sess: 
     if load == 1: 
      saver.restore(sess,savePath) 
     else: 
      tf.global_variables_initializer().run() 
     #### TRAINING ##### 
     if train == 1: 
      for i in xrange(100): 
       x = mvn.rvs(bs) 
       y = Y.eval(feed_dict={X:x, is_training.name: True}) 

     def bnParams(): 
      beta, gamma, mean, var = [v.eval() for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="bn")] 
      return beta, gamma, mean, var 

     beta, gamma, mean, var = bnParams() 
     #### TESTING ##### 
     for i in xrange(10): 
      x = mvn.rvs(1).reshape(1,-1) 
      check = bnCheck(x,mean,np.sqrt(var)) 
      y = Y.eval(feed_dict={X:x, is_training.name: False}) 
      print("x = {0}, y = {1}, check = {2}".format(x,y,check)) 
      beta, gamma, mean, var = bnParams() 
      print("BN Params: Beta {0} Gamma {1} mean {2} var{3} \n".format(beta,gamma,mean,var)) 

     saver.save(sess,savePath) 

を呼び出してから取得BN変数をプリントアウトしたときに見ることができる変更されていますテストループの3回の反復が以下のように見える。

x = [[ -1.782 100.941]], y = [[-1.843 1.388]], check = [[-1.842 1.387]] 
BN Params: Beta [ 0. 0.] Gamma [ 1. 1.] mean [ -0.2 99.93] var[ 0.818 0.589] 

x = [[ -1.245 101.126]], y = [[-1.156 1.557]], check = [[-1.155 1.557]] 
BN Params: Beta [ 0. 0.] Gamma [ 1. 1.] mean [ -0.304 100.05 ] var[ 0.736 0.53 ] 

x = [[ -0.107 99.349]], y = [[ 0.23 -0.961]], check = [[ 0.23 -0.96]] 
BN Params: Beta [ 0. 0.] Gamma [ 1. 1.] mean [ -0.285 99.98 ] var[ 0.662 0.477] 

私はBPをやっていないので、ベータとガンマは変わりません。しかし、私の実行手段/分散は変化しています。どこが間違っていますか?

編集: なぜこれらの変数がテストと列車の間で変更する必要がないかを知っておくとよいでしょう。

答えて

3

あなたのbn機能が間違っています。これを代わりに使用してください:

def bn(x,is_training,name): 
    return batch_norm(x, decay=0.9, center=True, scale=True, 
    updates_collections=None, 
    is_training=is_training, 
    reuse=None, 
    trainable=True, 
    scope=name) 

is_trainingは、実行中の平均などを更新するかどうかをシグナリングするブール0次元テンソルです。次に、テンソルを変更するだけで、あなたはトレーニング中であるかテスト中であるかをシグナリングしています。

EDIT: テンソルフローの多くの演算では、テンソルが受け入れられますが、True/Falseの引数は定数ではありません。

+0

これと私が投稿したものとの違いは何ですか? – mattdns

+0

あなたは計算のグラフ上に不要なノードを作成します。正直言って私はあなたのコードをテストしていませんが、私は悪いスタイルの多くの行を見ています(例えば、インポートはファイルの先頭にあるべきです、if_mainの下にインデントされていないなど) – lpp

+0

はい、 。 – mattdns

0

slim.batch_normを使用する場合は、tf.train.GradientDecentOptimizer(lr).minimize(loss)または他のオプティマイザの代わりにslim.learning.create_train_opを必ず使用してください。それが動作するかどうかを試してみてください!

+0

コードを書式設定して詳細を入力してください。 – Sid

関連する問題