2016-08-05 4 views
0

GoogleとTensorflowによってリリースされた新しいモデル(つまり、wide_n_deep学習)に興奮しています。だから私はthe tutorial exampleを実行してそれを試してみようとしています。Tensorflow Wide&Deepチュートリアルバッチを使用した例

マシンラーニングの一般的なトリックとして、トレーニングデータセット全体が大きくなるとバッチラーニングが重要になります。

index_in_epoch = 0 
num_examples = df_train.shape[0] 
for i in xrange(FLAGS.train_steps): 
    startTime = datetime.now() 
    print("start step %i" %i) 
    start = index_in_epoch 
    index_in_epoch += batch_size 
    if index_in_epoch > num_examples: 
     if start < num_examples: 
      m.fit(input_fn=lambda: input_fn(df_train[start:num_examples], steps=1) 
     df_train.reindex(np.random.permutation(df_train.index) 
     start = 0 
     index_in_epoch = batch_size 
    if i%5 == 1: 
     results = m.evaluate(input_fn=lambda: input_fn(df_test), steps = 1) 
     for key in sorted(results): 
      print("%s: %s %(key, results[key])) 
    end = index_in_epoch 
    m.fit(input_fn=lambda: input_fn(df_train[start:end], steps=1) 

簡単に言えば、私はバッチでバッチを設定し、全トレーニングデータを反復処理し、各バッチのために、私は「フィット呼び出します。だから私は、次のように、バッチ学習を取得するにはチュートリアルの例を学習wide_n_deep提供を変更しようモデルを再トレーニングする機能を備えています。

この単純な戦略の問題は、処理時間が非常に遅いことです(たとえば、400万レコードのデータセットを100回反復し、バッチサイズを100k、トレーニングおよび評価時間約1週間)。だから私はバッチ学習を適切な方法で使用しているのは本当に疑問です。

wide_n_deepラーニングモデルを使用してプレイするときに、バッチラーニングを処理するために才能があなたの経験を共有できるのであれば、私は感謝します。

+0

これらのトレーニング時間は珍しくありません。 Imagenetは約2週間かかります。 –

答えて

0

すべてのフィット/評価コールは、グラフとセッションを作成し、操作を実行します。ループの中でそれを行うと、遅くなります。 これを高速化するには、テンソルバッチと呼ばれるinput_fnを提供する必要があります。 データフレームからデータを読み取る場合は、to_feature_columns_and_input_fn のファイルを読み込みます。tf.Exampleを保持するファイルからデータを読み取る場合は、input_fnread_batch_examplesのようなものを使用できます。

+0

あなたの答えをありがとう!テンソルフローの初心者として、あなたのコメントを読んだ後、バッチ処理のアプローチが遅い理由を理解することができます。あなたの提案にしたがって、バッチごとにバッチを読み込む方法について、より具体的な例を提供できる場合は、より役に立ちます。ご協力いただきありがとうございます。 –

関連する問題