0

私はMATLABで確率的勾配降下を実装しようとしていますが、どこかで間違っています。私はコンバージェンスをチェックする方法が間違っていると思うかもしれません(各反復でエバリュエーターを更新する方法はわかりませんでしたが)。私は基本的な線形データに合わせるよう努力してきましたが、かなり遠い結果を得ています。私はいくつかの助けを得ることを望んでいます。誰かが私がどこに間違っているのか、なぜこれが正しく動作していないのかを指摘できますか?MATLABの確率勾配降下アルゴリズム

ありがとうございます!ここで

は設定データと一般的なコードです:

clear all; 
close all; 
clc 

N_features = 2; 
d = 100; 
m = 100; 

X_train = 10*rand(d,1); 
X_test = 10*rand(d,1); 
X_train = [ones(d,1) X_train]; 
X_test = [ones(d,1) X_test]; 

y_train = 5 + X_train(:,2) + 0.5*randn(d,1); 
y_test = 5 + X_test(:,2) + 0.5*randn(d,1); 

gamma = 0.01; %learning rate 

[sgd_est_train,sgd_est_test,SSE_train,SSE_test,w] = stoch_grad(d,m,N_features,X_train,y_train,X_test,y_test,gamma); 

figure(1) 
plot(X_train(:,2),sgd_est_train,'ro',X_train(:,2),y_train,'go') 

figure(2) 
plot(X_test(:,2),sgd_est_test,'bo',X_test(:,2),y_test,'go') 

と実際にSGDを実装する関数は次のとおりです。

% stochastic gradient descent 

function [sgd_est_train,sgd_est_test,SSE_train,SSE_test,w] = stoch_grad(d,m,N_features,X_train,y_train,X_test,y_test,gamma) 

    epsilon = 0.01; %convergence criterion 
    max_iter = 10000; 

    w0 = zeros(N_features,1); %initial guess 
    w = zeros(N_features,1); %for convenience 

    x = zeros(d,1); 
    z = zeros(d,1); 

    for jj=1:max_iter; 
     for kk=1:d; 
      x = X_train(kk,:)'; 
      z = gamma*((w0'*x-y_train(kk))*x); 
      w = w0 - z; 
     end 

     if norm(w0-w,2)<epsilon 
      break; 
     else 
      w0 = w; 
     end 
    end 

    sgd_est_test = zeros(m,1); 
    sgd_est_train = zeros(d,1); 

    for ll=1:m; 
     sgd_est_test(ll,1) = w'*X_test(ll,:)'; 
    end 

    for ii=1:d; 
     sgd_est_train(ii,1) = w'*X_train(ii,:)'; 
    end 

    SSE_test = sum((sgd_est_test - y_test).^2); 
    SSE_train = sum((sgd_est_train - y_train).^2); 

end 
+0

あなたは変数を記述できますか?彼らはどういう意味ですか?私は手伝ってくれるかもしれませんが、私はあなたが物事をどのようにしているのか正確に分からない限りできません。 –

+0

申し訳ありません:あなたの更新式は何ですか?あなたは私のためのグラデーションを数式の形で書くことができますか? –

+0

私はそれをちょっと試しました。私はあなたの更新が間違っている、つまり 'z = ....'という行を98%忠告しています。あなたがデータからノイズを除去し、正確な解を与えるなら、それは働きます。 'w0(2)'ビットを変更すると正しい値を見つけることができますが、 'w0(1)'を変更すると収束しません。 –

答えて

0

私は0.001で学習率を下げる試み、これになりました:
あなたのアルゴリズムは、y = aの代わりにy = a という形式の推定値を生成することを示しています。 x + b(なんらかの理由で定数項を無視する)、収束させるために学習率を下げる必要があります。

+0

私はy切片を取り出すことも試みましたが、あなたは正しいです - それはy切片値を全く得られません!しかし、私はまた、学習率が低いので、それを複数回実行しました。場合によっては助けにも見えますが、他の人にはまだ奇妙な見積もりが残っています。 – poppy3345

関連する問題