2017-09-15 4 views
2

fit_generator()は、必要以上に多くのステップを実行します。
steps_per_epoch=100と設定しました。 iとkはともに0から始まりますが、トレーニングの最後にk = 109が印刷されます。この状況は、検証データが追加された場合にのみ発生します。Keras fit_generatorのバグは、それ以上のステップを実行している必要があります。

def data_generate(xfd, yfd, x_line_offset, y_line_offset): 

    while True: 
     k = 0 

     x_line_offset, y_line_offset = shuffle_list(x_line_offset, y_line_offset) 

     for i in range(100): 
      print('i = {}'.format(i)) 
      print('k = {}'.format(k)) 
      k += 1 

      x_train = get_line_by_offset(xfd, x_line_offset[i]) 
      x_train = rescaling(x_train, 0, 65535, 0, 1) 
      y_train = get_line_by_offset(yfd, y_line_offset[i]) 

      yield x_train, y_train 

train_generator = data_generate(xfd_train, yfd_train, x_train_line_offset, y_train_line_offset) 
validation_generator = data_generate(xfd_valid, yfd_valid, x_valid_line_offset, y_valid_line_offset) 

model.fit_generator(train_generator, steps_per_epoch=100, 
        validation_data=validation_generator, 
        validation_steps=len(fix_y_valid_line_offset), epochs=1) 

k = 109が印刷されるので、それ以上のステップを実行すると仮定します。バグかどうかわかりません。しかし、ケラスのログメッセージはk = 99の後には表示されません。 enter image description here

答えて

3

ここにバグはありません。実装の詳細です。関数fit_generator()にはデフォルトの引数max_queue_size=10があります。 train_generatorvalidation_generatorのバッチは、モデルにフィット/評価する前にキューに挿入されます。

最初のエポックが終了すると、100個のバッチが生成されます(k = 99)。ただし、ジェネレーターはキューを埋めるために10個のバッチを生成し続けます。だから、k = 100k = 109が表示されています。同時に、検証プロセスが開始されますので、validation_generatorからのk = 0, ...も表示されます。

関連する問題