2015-11-13 15 views
7

平均二乗誤差を計算するTensorFlowに実装された損失関数があります。目的を計算するために使用されるすべてのテンソルはfloat64型であり、従って損失関数自体はdtype float64です。私は最小化しようとすると、特に、テンソルフロー損失最小化タイプエラー

print cost 
==> Tensor("add_5:0", shape=TensorShape([]), dtype=float64) 

はしかし、私はテンソルの種類についての値のエラーを取得:

GradientDescentOptimizer(learning_rate=0.1).minimize(cost) 
==> ValueError: Invalid type <dtype: 'float64'> for add_5:0, expected: [tf.float32]. 

テンソルの予想DTYPEがある理由を私は理解していません計算に至るすべての変数がfloat64型である場合の単精度浮動小数点数です。私は、すべての変数をfloat32に強制すると計算が正しく実行されることを確認しました。

なぜこのようなことが起こるのかについての洞察はありますか?私のコンピュータは64ビットマシンです。ここ

は動作

import tensorflow as tf 
import numpy as np 

# Make 100 phony data points in NumPy. 
x_data = np.random.rand(2, 100) # Random input 
y_data = np.dot([0.100, 0.200], x_data) + 0.300 

# Construct a linear model. 
b = tf.Variable(tf.zeros([1], dtype=np.float64)) 
W = tf.Variable(tf.random_uniform([1, 2], minval=-1.0, maxval=1.0, dtype=np.float64)) 
y = tf.matmul(W, x_data) + b 

# Minimize the squared errors. 
loss = tf.reduce_mean(tf.square(y - y_data)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss) 

# For initializing the variables. 
init = tf.initialize_all_variables() 

# Launch the graph 
sess = tf.Session() 
sess.run(init) 

# Fit the plane. 
for step in xrange(0, 201): 
    sess.run(train) 
    if step % 20 == 0: 
     print step, sess.run(W), sess.run(b) 

答えて

4

32ビット浮動小数点変数と損失値に現在tf.train.GradientDescentOptimizerクラスのみsupportsトレーニングを再生する例です。

しかし、カーネルが倍精度値用に実装されているように見えるので、あなたのシナリオで訓練することができるはずです。

class DoubleGDOptimizer(tf.train.GradientDescentOptimizer): 
    def _valid_dtypes(self): 
    return set([tf.float32, tf.float64]) 

...そしてtf.train.GradientDescentOptimizerの代わりにDoubleGDOptimizerを使用します。

迅速な回避策は、同様tf.float64値をサポートし、サブクラスを定義することです。

編集:この作業を行うには、学習率をtf.constant(learning_rate, tf.float64)として渡す必要があります。

NBこれは、サポートインタフェースではありません、将来変更される可能性がありますが、チームは倍精度浮動小数点数を最適化するための欲求を認識し、そして内蔵のソリューションを提供することを目的とします。 )

+0

わかりました!ありがとう! – user1936768

+0

Doens'tは現在動作しているようです(tf v0.6)。 'TypeError: 'ApplyGradientDescent'の 'alpha'を入力Opにfloat32型があり、引数 'var'のfloat64型に一致しません。 – colinfang

+0

ありがとうございました。私は修正でその答えを編集しました。 – mrry

関連する問題