2015-11-19 3 views
8

は、ここで私はcsvファイルのTensorFlowのdtypeを変更するにはどうすればよいですか?

import tensorflow as tf 
import numpy as np 
import input_data 

filename_queue = tf.train.string_input_producer(["cs-training.csv"]) 

reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 

record_defaults = [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1]] 
col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11 = tf.decode_csv(
    value, record_defaults=record_defaults) 
features = tf.concat(0, [col2, col3, col4, col5, col6, col7, col8, col9, col10, col11]) 

with tf.Session() as sess: 
    # Start populating the filename queue. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(1200): 
    # Retrieve a single instance: 
    print i 
    example, label = sess.run([features, col1]) 
    try: 
     print example, label 
    except: 
     pass 

    coord.request_stop() 
    coord.join(threads) 

このコードは以下のエラーを返すランしようとしていたコードです。

--------------------------------------------------------------------------- 
InvalidArgumentError      Traceback (most recent call last) 
<ipython-input-23-e42fe2609a15> in <module>() 
     7  # Retrieve a single instance: 
     8  print i 
----> 9  example, label = sess.run([features, col1]) 
    10  try: 
    11   print example, label 

/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict) 
    343 
    344  # Run request and get response. 
--> 345  results = self._do_run(target_list, unique_fetch_targets, feed_dict_string) 
    346 
    347  # User may have fetched the same tensor multiple times, but we 

/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, target_list, fetch_list, feed_dict) 
    417   # pylint: disable=protected-access 
    418   raise errors._make_specific_exception(node_def, op, e.error_message, 
--> 419            e.code) 
    420   # pylint: enable=protected-access 
    421  raise e_type, e_value, e_traceback 

InvalidArgumentError: Field 1 in record 0 is not a valid int32: 0.766126609 

私はこの問題には関係しないと考えている情報がたくさんあります。明らかに問題は、私がプログラムに与えている多くのデータがdtype int32ではないということです。これは主に浮動小数点数です。私はtf.decode_csvtf.concatに明示的にdtype=float引数を設定するのと同様に、dtypeを変更するいくつかの試みを試みました。どちらもうまくいかなかった。それは無効な議論です。このコードが実際にデータを予測するかどうかはわかりません。私はそれがcol1が1か0になるかどうかを予測したいと思っていますし、コード内に実際にその予測をすることを示唆する何も表示されません。たぶん私は別のスレッドのためにその質問を保存します。どんな助けでも大歓迎です!

答えて

1

DTYPEを変更への答えは、あなたがCOL1をプリントアウトした場合、このメッセージが表示されます、それを行うの後だけSO-

record_defaults = [[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]] 

のようなデフォルト設定を変更することです。

Tensor("DecodeCSV_43:0", shape=TensorShape([]), dtype=float32) 

しかし、あなたは答えをおさらいするにwhich has been answered here.、に実行される別のエラーがあるが、この問題を回避するには、そのようtf.packtf.concatを変更することです。

features = tf.pack([col2, col3, col4, col5, col6, col7, col8, col9, col10, col11]) 
13

tf.decode_csv()へのインターフェイスは、少しトリッキーです。各列のdtypeは、record_defaults引数の対応する要素によって決まります。コード内のrecord_defaultsの値は、型としてtf.int32の各列として解釈され、浮動小数点データに遭遇するとエラーになります。次のようにrecord_defaultsを建設する、をを必要とされるすべての列を仮定し

4, 8, 9, 4.5 
2, 5, 1, 3.7 
2, 2, 2, 0.1 

浮動小数点列に続く3つの整数の列を含む、あなたは以下のCSVデータを持っていると言う

value = ... 

record_defaults = [tf.constant([], dtype=tf.int32), # Column 0 
        tf.constant([], dtype=tf.int32), # Column 1 
        tf.constant([], dtype=tf.int32), # Column 2 
        tf.constant([], dtype=tf.float32)] # Column 3 

col0, col1, col2, col3 = tf.decode_csv(value, record_defaults=record_defauts) 

assert col0.dtype == tf.int32 
assert col1.dtype == tf.int32 
assert col2.dtype == tf.int32 
assert col3.dtype == tf.float32 

record_defaultsの空の値は、値が必須であることを示します。 (例えば)2欄は、欠損値を持つことを許可されている場合は、次のように別の方法として、あなたはrecord_defaultsを定義します

record_defaults = [tf.constant([], dtype=tf.int32),  # Column 0 
        tf.constant([], dtype=tf.int32),  # Column 1 
        tf.constant([0], dtype=tf.int32), # Column 2 
        tf.constant([], dtype=tf.float32)] # Column 3 

構築し、のいずれかの値を予測するモデルを訓練する方法をあなたの質問の懸念の第二部入力データからの列現在のところ、このプログラムは単純に列を1つのテンソルに連結します(featuresという)。そのデータを解釈するモデルを定義して訓練する必要があります。そのような最も単純なアプローチの1つは線形回帰であり、このチュートリアルはlinear regression in TensorFlowで問題に適応できます。

関連する問題