2016-06-25 8 views
4

Tensorflowで事前秤量されたLSTMモデルを実装したいと考えています。これらの重量は、カフェやトーチから来るかもしれません。
ファイルrnn_cell.pyにLSTMセルがあり、rnn_cell.BasicLSTMCellrnn_cell.MultiRNNCellなどが見つかりました。しかし、どのようにしてこれらのLSTMセルの事前秤量をロードできますか?Tensorflowで事前に訓練されたLSTMモデルの荷重をロードする方法

答えて

0

これは、事前にトレーニングされたCaffeモデルを読み込むためのソリューションです。 (this threadの説明を参照)を参照してください。

net_caffe = caffe.Net(prototxt, caffemodel, caffe.TEST) 
caffe_layers = {} 

for i, layer in enumerate(net_caffe.layers): 
    layer_name = net_caffe._layer_names[i] 
    caffe_layers[layer_name] = layer 

def caffe_weights(layer_name): 
    layer = caffe_layers[layer_name] 
    return layer.blobs[0].data 

def caffe_bias(layer_name): 
    layer = caffe_layers[layer_name] 
    return layer.blobs[1].data 

#tensorflow uses [filter_height, filter_width, in_channels, out_channels] 2-3-1-0 
#caffe uses [out_channels, in_channels, filter_height, filter_width] 0-1-2-3 
def caffe2tf_filter(name): 
    f = caffe_weights(name) 
    return f.transpose((2, 3, 1, 0)) 

class ModelFromCaffe(): 
    def get_conv_filter(self, name): 
     w = caffe2tf_filter(name) 
     return tf.constant(w, dtype=tf.float32, name="filter") 

    def get_bias(self, name): 
     b = caffe_bias(name) 
     return tf.constant(b, dtype=tf.float32, name="bias") 

    def get_fc_weight(self, name): 
     cw = caffe_weights(name) 
     if name == "fc6": 
      assert cw.shape == (4096, 25088) 
      cw = cw.reshape((4096, 512, 7, 7)) 
      cw = cw.transpose((2, 3, 1, 0)) 
      cw = cw.reshape(25088, 4096) 
     else: 
      cw = cw.transpose((1, 0)) 

     return tf.constant(cw, dtype=tf.float32, name="weight") 

images = tf.placeholder("float", [None, 224, 224, 3], name="images") 
m = ModelFromCaffe() 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    batch = cat.reshape((1, 224, 224, 3)) 
    out = sess.run([m.prob, m.relu1_1, m.pool5, m.fc6], feed_dict={ images: batch }) 
... 
+1

ありがとうございました。それは私の多くを助けます。しかし、RNNのために、私は事前にトレーニングされた体重を初期化する方法を見つけませんでした。 –

+0

ModelFromCaffeクラスを使用して変数を作成します。 'fc6_W = tf.Variable(m.get_fc_weight(" fc6 ")、name =" fc6_W ")' [こちらのドキュメントはこちら](https://www.tensorflow.org/versions/r0.9/how_tos/variables/ index.html)。 – ssjadon

関連する問題