2017-01-28 16 views
1

回帰ターゲット値が-1.0と1.0の間の深いニューラルネットワークを訓練するときに得られる以下のログに対して、どちらも、一見良いニュースで検証損失の減少を、トレーニング深いニューラルネットワークのトレーニングトレースの解釈:非常に低いトレーニング損失とさらに低い検証損失

____________________________________________________________________________________________________ 
Layer (type)      Output Shape   Param #  Connected to 
==================================================================================================== 
cropping2d_1 (Cropping2D)  (None, 138, 320, 3) 0   cropping2d_input_1[0][0] 
____________________________________________________________________________________________________ 
lambda_1 (Lambda)    (None, 66, 200, 3) 0   cropping2d_1[0][0] 
____________________________________________________________________________________________________ 
lambda_2 (Lambda)    (None, 66, 200, 3) 0   lambda_1[0][0] 
____________________________________________________________________________________________________ 
convolution2d_1 (Convolution2D) (None, 31, 98, 24) 1824  lambda_2[0][0] 
____________________________________________________________________________________________________ 
spatialdropout2d_1 (SpatialDropo (None, 31, 98, 24) 0   convolution2d_1[0][0] 
____________________________________________________________________________________________________ 
convolution2d_2 (Convolution2D) (None, 14, 47, 36) 21636  spatialdropout2d_1[0][0] 
____________________________________________________________________________________________________ 
spatialdropout2d_2 (SpatialDropo (None, 14, 47, 36) 0   convolution2d_2[0][0] 
____________________________________________________________________________________________________ 
convolution2d_3 (Convolution2D) (None, 5, 22, 48)  43248  spatialdropout2d_2[0][0] 
____________________________________________________________________________________________________ 
spatialdropout2d_3 (SpatialDropo (None, 5, 22, 48)  0   convolution2d_3[0][0] 
____________________________________________________________________________________________________ 
convolution2d_4 (Convolution2D) (None, 3, 20, 64)  27712  spatialdropout2d_3[0][0] 
____________________________________________________________________________________________________ 
spatialdropout2d_4 (SpatialDropo (None, 3, 20, 64)  0   convolution2d_4[0][0] 
____________________________________________________________________________________________________ 
convolution2d_5 (Convolution2D) (None, 1, 18, 64)  36928  spatialdropout2d_4[0][0] 
____________________________________________________________________________________________________ 
spatialdropout2d_5 (SpatialDropo (None, 1, 18, 64)  0   convolution2d_5[0][0] 
____________________________________________________________________________________________________ 
flatten_1 (Flatten)    (None, 1152)   0   spatialdropout2d_5[0][0] 
____________________________________________________________________________________________________ 
dropout_1 (Dropout)    (None, 1152)   0   flatten_1[0][0] 
____________________________________________________________________________________________________ 
activation_1 (Activation)  (None, 1152)   0   dropout_1[0][0] 
____________________________________________________________________________________________________ 
dense_1 (Dense)     (None, 100)   115300  activation_1[0][0] 
____________________________________________________________________________________________________ 
dropout_2 (Dropout)    (None, 100)   0   dense_1[0][0] 
____________________________________________________________________________________________________ 
dense_2 (Dense)     (None, 50)   5050  dropout_2[0][0] 
____________________________________________________________________________________________________ 
dense_3 (Dense)     (None, 10)   510   dense_2[0][0] 
____________________________________________________________________________________________________ 
dropout_3 (Dropout)    (None, 10)   0   dense_3[0][0] 
____________________________________________________________________________________________________ 
dense_4 (Dense)     (None, 1)    11   dropout_3[0][0] 
==================================================================================================== 
Total params: 252,219 
Trainable params: 252,219 
Non-trainable params: 0 
____________________________________________________________________________________________________ 
None 
Epoch 1/5 
19200/19200 [==============================] - 795s - loss: 0.0292 - val_loss: 0.0128 
Epoch 2/5 
19200/19200 [==============================] - 754s - loss: 0.0169 - val_loss: 0.0120 
Epoch 3/5 
19200/19200 [==============================] - 753s - loss: 0.0161 - val_loss: 0.0114 
Epoch 4/5 
19200/19200 [==============================] - 723s - loss: 0.0154 - val_loss: 0.0100 
Epoch 5/5 
19200/19200 [==============================] - 1597s - loss: 0.0151 - val_loss: 0.0098 

:0.001および4800分の19200訓練/検証サンプルの学習率を持ちます。しかし、どのようにして、最初のエポック時にトレーニングの喪失がどうして低く抑えられますか?また、検証の損失をさらに低く抑えることができますか?それは私のモデルやトレーニングの設定のどこかに体系的なエラーがあることを示していますか?

答えて

4

実際には、訓練の損失よりも小さい妥当性の喪失は、考えられるほどまれな現象ではありません。例えば、検証データ内のすべての例がである場合は、トレーニングセットの例でとなり、ネットワークはデータセットの実際の構造を簡単に学習しました。

データの構造があまり複雑でない場合は、非常に頻繁に発生します。実際、あなたを驚かせた最初の時代の後の損失の小さな値は、これがあなたの場合に起こったという手がかりかもしれません。

あなたの損失は何も指定していませんが、あなたの仕事が回帰だと仮定した場合、私はmseと推測しました。この場合、平均二乗誤差は0.01真の値と実際の値との平均ユークリッド距離は、0.1と等しく、値の直径が[-1, 1]5%であることが分かります。だから、このエラーは実際には小さいですか?

また、1エポック中に分析されるバッチの数も指定していません。データの構造が複雑でなく、バ​​ッチサイズが小さい場合は、データをうまく学習するのに十分な時間が必要でした。

モデルが訓練されているかどうかを確認するには、correlation plotをプロットすることをお勧めします。y_predをプロットしてください。 Y軸上のX軸およびy_true。実際にモデルが実際にどのように訓練されているかが実際に分かります。

EDIT:Neilが述べたように、小さな検証エラーの背後にはさらに多くの理由があるかもしれません。また、5エポックが90分を超えないようにして、モデルの結果を確認することもできます。たとえば、次のような古典的なクロスバリデーションスキーマを使用するとよいでしょう。 5つ折り。これは、あなたのデータセットの場合、あなたのモデルがうまく機能していることを保証します。

+0

その洞察に感謝します。あなたは正しかったです:私はmseを使用しています。バッチサイズは128です。 – user1934212

+0

私はあなたにもう一度ヒントを与えました。あなたはあなたの投稿にこのプロットを追加することができます。 –

+2

予期せず低い妥当性評価スコアの可能性のある理由の1つは、訓練対cvサンプルの分離が不十分であることです。たとえば、入力が画像で、類似のシナリオで複数の画像が撮影された場合(またはtrain/cvに分割する前にデータ拡大*を使用する場合)、cvセットがトレーニングに似ている可能性があり、cv結果が不正確になります。これはOPのポストから直接的には示されていませんが、それはチェックして防御するものです。この問題を解決するには、相関関係のある例をまとめて(セットで)確実に保持する必要があります。すべてを列車またはCVで設定する必要があります。 –

関連する問題