2015-12-02 6 views
11

私はやや複雑で非標準的なNNアーキテクチャをデバッグしようとしています。フォワードパスを計算することはうまくいくと思いますが、予想される結果が得られますが、Adamや標準オプティマイザを使用して最適化しようとすると、非常に小さな学習率で繰り返してもどこでもナンを得ることができます。私はそれらをローカライズしようとしていますが、最初のオカレンスをキャッチして、それがどのように発生したかを検出する方法があるかどうか疑問に思っていましたか?私はtf.add_check_numerics_ops()を試しましたが、それは何もしていないようです、あるいは私はそれを間違って使っています。逆方向パスのナンをデバッグ

答えて

18

大規模なネットワークを使用している場合は、特にNaNのデバッグが難しい場合があります。 tf.add_check_numerics_ops()は、グラフの各浮動小数点テンソルにNaN値が含まれていないと主張するopsをグラフに追加しますが、デフォルトではこれらのチェックを実行しません。次のように代わりにそれは、あなたが定期的に実行することができますオペアンプを返し、またはステップごとに:

train_op = ... 
check_op = tf.add_check_numerics_ops() 

sess = tf.Session() 
sess.run([train_op, check_op]) # Runs training and checks for NaNs 
+0

問題は、train_opを実行すると、ネットワーク全体にナンバーが伝播するため、その原因を見つけることは役に立たないということです。私がやりたいことは、順方向パスと逆方向パスを実行することです。そして、nanが生成されるとすぐに、違反操作によって例外がスローされます。 –

+6

'train_op'と' check_op'を一緒に実行すると、NaNを持つ最初のノードを報告するエラーが発生します。発生した 'tf.InvalidArgumentError'を捕捉し、その' opから 'opを抽出できます。 op'プロパティを使用します。 opのハンドルを使って、 'op.inputs [0]'プロパティにアクセスして、どのテンソルがNaN値を持っているかを知ることができます。 – mrry

+0

これはありがとう! –

2

たぶん、あなたは、OPSのプリント値を疑うために、この

print_ops = [] 
for op in ops: 
    print_ops.append(tf.Print(op, [op], 
        message='%s :' % op.name, summarize=10)) 
print_op = tf.group(*print_ops) 
sess.run([train_op, print_op]) 

のようなものを印刷OPSを追加することができますすべての操作に追加するには、add_check_numerics_opsの行に沿ってループを実行できます。

関連する問題