2017-01-23 34 views
3

TensorFlowを使用して作成されたカスタム目的関数を使用して、Kerasシーケンシャルモデルのフィット段階で次のエラーが発生します。ValueError:なし値はサポートされていませんTensorflowのKerasカスタム損失関数

File "basicCNN.py", line 110, in <module> 
callbacks=[TensorBoard(log_dir="./logs/{}".format(now))]) 
File "/home/garethjones/.local/lib/python2.7/site-packages/keras/models.py", line 664, in fit 
sample_weight=sample_weight) 
File "/home/garethjones/.local/lib/python2.7/site-packages/keras/engine/training.py", line 1115, in fit 
self._make_train_function() 
File "/home/garethjones/.local/lib/python2.7/site-packages/keras/engine/training.py", line 713, in _make_train_function 
self.total_loss) 
File "/home/garethjones/.local/lib/python2.7/site-packages/keras/optimizers.py", line 391, in get_updates 
m_t = (self.beta_1 * m) + (1. - self.beta_1) * g 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_ops.py", line 813, in binary_op_wrapper 
y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y") 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 669, in convert_to_tensor 
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/constant_op.py", line 176, in _constant_tensor_conversion_function 
return constant(v, dtype=dtype, name=name) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/constant_op.py", line 165, in constant 
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape)) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_util.py", line 360, in make_tensor_proto 
raise ValueError("None values not supported.") 

マイカスタム関数は、私は、対話型セッションを持っているとき、私はこれを実行することができ、この

def PAI(y_true, y_pred, k): 
    ''' 
    Args: 
     y_true (tensor): (batch x numCells) 
     y_pred (tensor): (batch x numCells) 
     k: The optimal number of hotspots 
     area: 
    Returns: 
     cfsRatio (tensor): The inverse of the percentage of crimes in hotspots per observation 
    ''' 
    # Compute total crime for each obs 
    totalCFS = tf.reduce_sum(y_true, axis=1) # batch x 1 
    # Flatten for gather 
    flatTruth = tf.reshape(y_true, [-1]) # 1 x batch * numCells 
    # Select top candidate cells 
    _, predHS = tf.nn.top_k(y_true, k) 
    # Convert indices for gather 
    predHSFlat = tf.range(0, tf.shape(y_true)[0]) * tf.shape(y_true)[1] + predHS) 
    # Map hotspot predictions to crimes 
    hsCFS = tf.gather(flatTruth, predHSFlat) 
    # Number of crimes commited in hotspots 
    hsCFSsum = tf.reduce_sum(hsCFS, axis=1) # batch x 1 
    # Ratio of crimes committed in hotspots and inverted for minimization 
    cfsRatio = tf.truediv(1.0, tf.truediv(hsCFSsum, totalCFS)) 

    return cfsRatio 

です。この関数は、主にこのTensorflowの問題https://github.com/tensorflow/tensorflow/issues/206のコードに依存しています。

+0

私も同様の問題があります。あなたは解決策を見つけましたか? – user2962197

答えて

0

Kerasのカスタム損失機能は、呼び出されたときにグラフを作成するだけです(TensorFlowをバックエンドとして使用する場合)。 TensorFlowコードは、あなたが気付いたようにfit()の呼び出しまで実際には実行されません(グラフは実行されません)。

したがって、一般的なデバッガを使用してコードをステップ実行し、問題のある行を見つけたり、データを調べたりすることはできません。印刷に

total = tf.Print(total, [total], 'total', summarize=10) 

:コンソール

ラップなどtf.Printを使用してコンソール出力で実行する際に

  • 印刷テンソル値:デバッグのため

    いくつかの技術テンソルの最初の10個の値total

    • 計算依存関係を最初の値は、(式LHSに割り当てられる)の計算に使用されるものであること

    メモを削除。第2引数は印刷されたものです。

    TF計算をテストして出力を見るには、それらを第2引数として渡してください。そうすれば、デバッグ中の出力には影響しません。

    戻り値から変数を削除し、問題のある計算を見つけます。

    knownGoodValue = K.sum(tf.square(y_true - y_pred)) # any expr known to work 
    printExpr = tf.reshape(y_true, [-1]) # 1 x batch * numCells 
    knownGoodValue = tf.Print(knownGoodValue, [printExpr], 'flatTruth', summarize=10) 
    
    return knownGoodValue 
    

    これを印刷し、任意のTFのexprがactaullyその結果を使用せずに、テスト式をラップすることによって、印刷/テストすることを可能にしながら、あなたが作品を知っているいくつかの表現を返します。たとえば。

    • 利用TFデバッガ

    https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/debug

    TensorFlowデバッガは、あなたのフィット関数の前に次のコードを追加することで呼び出すことができます。

    import keras.backend as K 
    from tensorflow.python import debug as tf_debug 
    sess = K.get_session() 
    sess = tf_debug.LocalCLIDebugWrapperSession(sess) 
    K.set_session(sess) 
    

    それはその後、GDB-などを提示fit()が呼び出されたときのcmdlineインターフェイス。

関連する問題