2017-12-22 18 views
2

私はKerasニューラルネットワークを使用しています。これまで見てきたすべてのチュートリアルのように、ハードコードされていない入力次元を自動的に設定したいと思います。どのように私はこれを達成することができますか?KerasRegressorで入力引数を指定する

マイコード:

from keras.models import Sequential 
from keras.layers import Dense 
from keras.wrappers.scikit_learn import KerasRegressor 
seed = 1 

X = df_input 
Y = df_res 

def baseline_model(x): 
    # create model 
    model = Sequential()  
    model.add(Dense(20, input_dim=x, kernel_initializer='normal', activation=relu)) 
    model.add(Dense(1, kernel_initializer='normal')) 
    # Compile model 
    model.compile(loss='mean_absolute_error', optimizer='adam') 
    return model 

inpt = len(X.columns) 
estimator = KerasRegressor(build_fn = baseline_model(inpt ) , epochs=2, batch_size=1000, verbose=2) 
estimator.fit(X,Y) 

そして、私が取得エラー:

def baseline_model(x): 
    def baseline_model(): 
     # create model 
     model = Sequential() 
     model.add(Dense(20, input_dim=x, kernel_initializer='normal', activation='relu')) 
     model.add(Dense(1, kernel_initializer='normal')) 
     # Compile model 
     model.compile(loss='mean_absolute_error', optimizer='adam') 
     return model 
    return baseline_model 

そしてとしてKerasRegressorを定義し、フィットを:

Traceback (most recent call last):

File ipython-input-2-49d765e85d15, line 20, in estimator.fit(X,Y)

TypeError: call() missing 1 required positional argument: 'inputs'

+0

このエラーは、エスティメータがあなたが呼び出せるメソッドではないために発生します。具体的には、scikit-learn APIを持つオブジェクトです。つまり、 'estimator.fit(X、Y)'で推定子を訓練し、 'estimator.predict(X、Y)'で予測を行うことができます。 – rvinas

+0

ありがとうございます。また、実用的な解決策がありますか? –

+0

'estimator(X、Y)'を 'estimator.fit(X、Y)'に置き換えてください。 – rvinas

答えて

0

次のように私はあなたのbaseline_modelを包むだろう:

estimator = KerasRegressor(build_fn=baseline_model(inpt), epochs=2, batch_size=1000, verbose=2) 
estimator.fit(X, Y) 

これにより、baseline_modelに入力ディメンションをハードコードする必要がなくなります。

+0

ありがとう、私はこの解決策について考えたことはありません。私は機能のレイヤーを追加することが効果を発揮する理由についてもわかりません。 –

+0

私は助けることができてうれしいです。ここでのことは、 'KerasRegressor'は、モデル自体ではなく、モデルを構築する呼び出し可能コードを期待していることです。このように関数をラップすることで、指定された 'input_dim'を使ってビルド関数を(呼び出さずに)返すことができます。 – rvinas

関連する問題