2017-07-03 6 views
1

の数(356)と一致しません:CNTK:シーケンス開始フラグの数(1)エラー取得シーケンス

trainer.train_minibatch(arguments) 

minibatch_dimで

ValueError: Value::Create:: The number (1) of sequence start flags does not match the number (356) of sequences, In

def create_inputs(output_length): 
    batch_axis = ct.Axis.default_batch_axis() 
    input_seq_axis = ct.Axis('inputAxis') 

    input_dynamic_axes = [batch_axis, input_seq_axis] 
    input_sequence = ct.input_variable(shape=1, dynamic_axes=input_dynamic_axes) 
    label_sequence = ct.input_variable(shape=output_length, dynamic_axes=input_dynamic_axes) 

    return input_sequence, label_sequence 

def make_model_and_train(model_root_path, epochs, minibatch_dim, output_length, test_minibatches): 
    vals = get_data('data.csv') 
    train, test, minibatches_per_epoch = make_sets(vals, minibatch_dim, output_length, test_minibatches) 

    input_sequence, label_sequence = create_inputs(output_length) 

    model = create_model(output_length) 

    z = model(input_sequence) 

    ce = ct.squared_error(z, label_sequence) 

    lt_per_sample = ct.learning_rate_schedule([(7000, 0.001),(10000, 0.0005)], ct.UnitType.sample, minibatches_per_epoch) 
    clipping_threshold_per_sample = 2 
    gradient_clipping_with_truncation = True 

    learner = ct.momentum_sgd(z.parameters, lt_per_sample, ct.momentum_as_time_constant_schedule(1100),gradient_clipping_threshold_per_sample = clipping_threshold_per_sample, gradient_clipping_with_truncation = gradient_clipping_with_truncation) 
    progress_printer = ct.logging.ProgressPrinter(100, tag = 'Training') 
    trainer = ct.Trainer(z, (ce), learner, progress_printer) 

    print ("Running %d epochs with %d minibatches per epoch" % (epochs, minibatches_per_epoch)) 
    print('') 

    for e in range(epochs): 
     mask = [True] 
     for b in range(minibatches_per_epoch): 
      arguments = ({input_sequence: train[0][b], label_sequence: train[1][b]}, mask) 
      mask = [False] 
      trainer.train_minibatch(arguments) 

      global_minibatch = e*minibatches_per_epoch + b 
     if e % 100 == 0 and e != 0: 
      model_filename = '%s/%s/%s_epoch_%g.dnn' % (model_root_path, name, name, e+1) 
      z.save_model(model_filename) 
      print("Saved model to '%s'" % model_filename) 

を、356であります output_lengthは356

私は主に私の他のLSTMのコードをコピーしましたが、これはうまくいったr。

これを修正するにはどうすればよいですか?

答えて

0

i番目のシーケンスが新しいシーケンスである場合、またはiが継続する場合は、リスト内のi番目のエレメントがTrueであるように、ミニバッファ内の各シーケンスに対してマスク値のリストを提供する必要があります - 前のミニブッチからのシーケンス

+0

しかし、どうすればいいですか?あなたは例を提供したり、私のコードを変更していただけますか?申し訳ありませんが、CNTKのドキュメントは不足しています。 – arduano

関連する問題