2016-05-17 5 views
10

疎ベクトルをTFRecordに保存するにはどうすればよいでしょうか? valuesは「もの」のインデックスを含むリストである、ここでTFRecordからの可変サイズリストの保存と読み取り

example = tf.train.Example(
     features=tf.train.Features(
      feature={ 
       'label': self._int64_feature(label), 
       'features' : self._int64_feature_list(values) 
      } 
     ) 
    ) 

:私のスパースベクトルが唯一のものとゼロはので、私は、私はちょうど「もの」がこのように配置されているインデックスを救うことを決めたが含まれています。このvalues配列には何百もの要素が含まれていることがあります。その後、シリアライズされたサンプルをtfrecordに保存するだけです。その後、私はこのようなtfrecord読んでいます:

features = tf.parse_single_example(
    serialized_example, 
    features={ 
     # We know the length of both fields. If not the 
     # tf.VarLenFeature could be used 
     'label': tf.FixedLenFeature([], dtype=tf.int64), 
     'features': tf.VarLenFeature(dtype=tf.int64) 
    } 
) 

label = features['label'] 
values = features['features'] 

values配列がスパース配列として認識されていると私は私が保存したデータを得ることはありませんので、これは動作しません。 tfrecordsに疎テンソルを格納する最良の方法とそれを読む方法は何ですか?

+0

誰も答えを持っていますか?: –

+0

以下の私の答えを見てください。それはあなたの質問に答えますか? –

答えて

1

あなたはちょうどあなたが詐欺の少しであなたの正しいスパーステンソルを得ることができる必要があり1Sの位置をシリアル化している場合:

解析されたスパーステンソルfeatures['features']は次のようになります

features['features'].indices: [[batch_id, position]...]

ここで、positionは無駄な列挙型です。

が、あなたは本当にfeature['features']one_positionがあなたのまばらなテンソルで指定された実際の値である[[batch_id, one_position], ...]

に見えるようにしたいです。

ので:

indices = features['features'].indices 
indices = tf.transpose(indices) 
# Now looks like [[batch_id, batch_id, ...], [position, position, ...]] 
indices = tf.stack([indices[0], features['features'].values]) 
# Now looks like [[batch_id, batch_id, ...], [one_position, one_position, ...]] 
indices = tf.transpose(indices) 
# Now looks like [[batch_id, one_position], [batch_id, one_position], ...]] 
features['features'] = tf.SparseTensor(
    indices=indices, 
    values=tf.ones(shape=tf.shape(indices)[:1]) 
    dense_shape=1 + tf.reduce_max(indices, axis=[0]) 
) 

出来上がり! features['features']は、連結されたスパースベクトルのバッチであるマトリックスを表します。

注:これを高密度テンソルとして扱う場合は、tf.sparse_to_denseを実行する必要があり、高密度テンソルの形状は[None, None]になります(これは動作するのが難しくなります)。あなたはそれをハードコードしたいかもしれないベクトルの長さ:dense_shape=[batch_size, max_vector_length]

関連する問題