Kerasでチェックポイント機能を実装する際に助けが必要です。私は大規模なデータセットを訓練するつもりです。まず、虹彩の花のデータセットを使用してモデルを訓練しました:http://machinelearningmastery.com/multi-class-classification-tutorial-keras-deep-learning-library/Kerasのチェックポイント深い学習モデル
私自身のデータセットはそれによく似ています。違いは私のデータセットはもっとですより大きい。チェックポイント機能については
:http://machinelearningmastery.com/check-point-deep-learning-models-keras/
私はピマ・インディアンのデータセットを使用した例を理解します。 今、私はiris-flowerスクリプトで同じチェックポイント機能を実装しようとしています。これまで私が試したことがあります。
import numpy
from pandas import *
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import cross_val_score, KFold
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline
from keras.callbacks import ModelCheckpoint
seed = 7
numpy.random.seed(seed)
dataframe = read_csv("iris.csv", header=None)
dataset = dataframe.values
X = dataset[:,0:4].astype(float)
Y = dataset[:,4]
# encode class value as integers
encoder = LabelEncoder()
encoder.fit(Y)
encoded_Y = encoder.transform(Y)
dummy_y = np_utils.to_categorical(encoded_Y)
def baseline_model():
model = Sequential()
model.add(Dense(4, input_dim=4, init='normal', activation='relu'))
model.add(Dense(3, init='normal', activation='sigmoid'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
estimator = KerasClassifier(build_fn=baseline_model, validation_split=0.33, nb_epoch=200, batch_size=5, callbacks=callbacks_list, verbose=0)
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
results = cross_val_score(estimator, X, dummy_y, cv=kfold)
print("Baseline: %.2f%% (%.2f%%)" % (results.mean()*100, results.std()*100))
このスクリプトでは、次のエラーが発生しました。私はそれをトラブルシューティングする方法を知らない、またはスクリプト内の私の配置が間違っている。
RuntimeError: Cannot clone object <keras.wrappers.scikit_learn.KerasClassifier object at 0x10e120fd0>, as the constructor does not seem to set parameter callbacks
誰かが私にこれを手伝ってくれることを願っています。ありがとうございました。
あなたがエラーの原因となっているライン知っていますか? – jdehesa