1
の数(356)と一致しません:CNTK:シーケンス開始フラグの数(1)エラー取得シーケンス
trainer.train_minibatch(arguments)
minibatch_dimで
ValueError: Value::Create:: The number (1) of sequence start flags does not match the number (356) of sequences, In
def create_inputs(output_length):
batch_axis = ct.Axis.default_batch_axis()
input_seq_axis = ct.Axis('inputAxis')
input_dynamic_axes = [batch_axis, input_seq_axis]
input_sequence = ct.input_variable(shape=1, dynamic_axes=input_dynamic_axes)
label_sequence = ct.input_variable(shape=output_length, dynamic_axes=input_dynamic_axes)
return input_sequence, label_sequence
def make_model_and_train(model_root_path, epochs, minibatch_dim, output_length, test_minibatches):
vals = get_data('data.csv')
train, test, minibatches_per_epoch = make_sets(vals, minibatch_dim, output_length, test_minibatches)
input_sequence, label_sequence = create_inputs(output_length)
model = create_model(output_length)
z = model(input_sequence)
ce = ct.squared_error(z, label_sequence)
lt_per_sample = ct.learning_rate_schedule([(7000, 0.001),(10000, 0.0005)], ct.UnitType.sample, minibatches_per_epoch)
clipping_threshold_per_sample = 2
gradient_clipping_with_truncation = True
learner = ct.momentum_sgd(z.parameters, lt_per_sample, ct.momentum_as_time_constant_schedule(1100),gradient_clipping_threshold_per_sample = clipping_threshold_per_sample, gradient_clipping_with_truncation = gradient_clipping_with_truncation)
progress_printer = ct.logging.ProgressPrinter(100, tag = 'Training')
trainer = ct.Trainer(z, (ce), learner, progress_printer)
print ("Running %d epochs with %d minibatches per epoch" % (epochs, minibatches_per_epoch))
print('')
for e in range(epochs):
mask = [True]
for b in range(minibatches_per_epoch):
arguments = ({input_sequence: train[0][b], label_sequence: train[1][b]}, mask)
mask = [False]
trainer.train_minibatch(arguments)
global_minibatch = e*minibatches_per_epoch + b
if e % 100 == 0 and e != 0:
model_filename = '%s/%s/%s_epoch_%g.dnn' % (model_root_path, name, name, e+1)
z.save_model(model_filename)
print("Saved model to '%s'" % model_filename)
を、356であります output_lengthは356
私は主に私の他のLSTMのコードをコピーしましたが、これはうまくいったr。
これを修正するにはどうすればよいですか?
しかし、どうすればいいですか?あなたは例を提供したり、私のコードを変更していただけますか?申し訳ありませんが、CNTKのドキュメントは不足しています。 – arduano