2016-09-24 3 views
1

2次元の勾配降下の理解に問題があります。私は機能があると言うf(x,y)=x**2-xyここでdf/dx = 2x-ydf/dy = -xPythonでの2次元の勾配の低下

したがって、ポイントdf(2,3)の場合、出力ベクトルは[1、-2] .Tです。ベクトル[1、-2]が指しているところは、最も急な上昇(f(x、y)という別名の出力)の方向にあります。 私は固定ステップサイズを選択し、そのような大きさのステップがf(x、y)を最も大きくする方向を見つけるのはです。降下したい場合、-f(x、y)を最も速く増加させる方向を見つけたいと思いますか?

私の直感が正しいとすれば、これはどのようにコード化しますか?ポイント(x = 0、y = 5)から始まって、最小値を見つけるために勾配降下を実行したいとします。ここで

step_size = 0.01 
precision = 0.00001 #stopping point 
enter code here?? 

答えて

1

はmatplotlibの可視化と勾配降下の実装です:

import csv 
import math 
def loadCsv(filename): 
    lines = csv.reader(open(filename, "r")) 
    dataset = list(lines) 
    for i in range(len(dataset)): 
     dataset[i] = [float(x) for x in dataset[i]] 
    return dataset 

def h(o1,o2,x): 
    ans=o1+o2*x 
    return ans 

def costf(massiv,p1,p2): 
    sum1=0.0 
    sum2=0.0 
    for x,y in massiv: 
     sum1+=(math.pow(h(o1,o2,x)-y,2)) 
    sum2=(1.0/(2*len(massiv)))*sum1 
    return sum1,sum2 

def gradient(massiv,er,alpha,o1,o2,max_loop=1000): 
    i=0 
    J,e=costf(massiv,o1,o2) 
    conv=False 
    m=len(massiv) 
    while conv!=True: 
     sum1=0.0 
     sum2=0.0 
     for x,y in massiv: 
      sum1+=(o1+o2*x-y) 
      sum2+=(o1+o2*x-y)*x 
     grad0=1.0/m*sum1 
     grad1=1.0/m*sum2 

     temp0=o1-alpha*grad0 
     temp1=o2-alpha*grad1 
     print(temp0,temp1) 
     o1=temp0 
     o2=temp1 
     e=0.0 
     for x,y in massiv: 
      e+=(math.pow(h(o1,o2,x)-y,2)) 
     if abs(J-e)<=ep: 
      print('Successful\n') 
      conv=True 

     J=e 

     i+=1 
     if i>=max_loop: 
      print('Too much\n') 
      break 
    return o1,o2 


#data = massiv 
data=loadCsv('ex1data1.txt') 
o1=0.0 #temp0=0 
o2=1.0 #temp1=1 
alpha=0.01 
ep=0.01 
t0,t1=gradient(data,ep,alpha,o1,o2) 
print('temp0='+str(t0)+' \ntemp1='+str(t1)) 

x=35000 
while x<=70000: 
    y=h(t0,t1,x) 
    print('x='+str(x)+'\ny='+str(y)+'\n') 
    x+=5000 

maxx=data[0][0] 
for q,w in data: 
    maxx=max(maxx,q) 
maxx=round(maxx)+1 
line=[] 
ll=0 
while ll<maxx: 
    line.append(h(t0,t1,ll)) 
    ll+=1 
x=[] 
y=[] 
for q,w in data: 
    x.append(q) 
    y.append(w) 

import matplotlib.pyplot as plt 
plt.plot(x,y,'ro',line) 
plt.ylabel('some numbers') 
plt.show() 

matplotlibの出力:

ここからdowloadedすることができex1data1.txt

enter image description here

ex1data1.txt

コードは、Python 3.5でAnacondaディストリビューションでそのまま実行できます。