2017-11-29 9 views
2
def closest_centroid(points, centroids): 
    """returns an array containing the index to the nearest centroid for each point""" 
    distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2)) 
    return np.argmin(distances, axis=0) 

誰かがこの関数の正確な働きを説明できますか?私は現在のようになっています:このPython関数で何が起こっているのか理解しようとしています

31998888119  0.94  34 
23423423422  0.45  43 
.... 

などです。このnumpy配列では、points[1]はロングIDであり、points[2]0.94であり、points[3]は最初のエントリでは34となります。

重心は、この特定の配列からわずかランダムな選択である:

def initialize_centroids(points, k): 
    """returns k centroids from the initial points""" 
    centroids = points.copy() 
    np.random.shuffle(centroids) 
    return centroids[:k] 

は今、私は(再び最初の列を無視してIDの最初の列を無視してpointsの値からユークリッド距離を取得し、centroidsしたいです)。私は文の文脈を正確に理解していません。distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))。なぜ新しい軸の解読があるのですが、正確に3列目を合計しますか:np.newaxis?また、どの軸に沿ってnp.argminを動作させるはずですか?

答えて

0

寸法を考えるのに役立ちます。 k=4とし、10点あると仮定しましょう。したがって、points.shape = (10,3)です。

次に、centroids = initialize_centroids(points, 4)は、次元が(4,3)のオブジェクトを返します。

が内側からこのラインを破るのをしてみましょう:

distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))

  1. 私たちは、各点から各重心を減算します。 pointscentroidsは2次元なので、それぞれpoints - centroidは2次元です。重心が1つしかない場合は、大丈夫です。しかし、我々は4重心を持っています!したがって、各重心についてpoints - centroidsを実行する必要があります。したがって、これを保存するには別の次元が必要です。したがって、np.newaxisの追加。

  2. 私たちは距離であるので正方形にします。したがって、ネガティブをポジティブに変換したい(そしてユークリッド距離を最小にしているからです)。

  3. 第3列を合計していません。実際には、各重心について点と重心の差を合計しています。

  4. np.argmin()は、最小距離で重心を見つける。したがって、各重心について、各点について、最小のインデックスを見つける(したがって、minの代わりにargmin)。その指数はその点に割り当てられた重心です。ここ

は一例であり:

points = np.array([ 
[ 1, 2, 4], 
[ 1, 1, 3], 
[ 1, 6, 2], 
[ 6, 2, 3], 
[ 7, 2, 3], 
[ 1, 9, 6], 
[ 6, 9, 1], 
[ 3, 8, 6], 
[ 10, 9, 6], 
[ 0, 2, 0], 
]) 

centroids = initialize_centroids(points, 4) 

print(centroids) 
array([[10, 9, 6], 
    [ 3, 8, 6], 
    [ 6, 2, 3], 
    [ 1, 1, 3]]) 

distances = (pts - centroids[:, np.newaxis])**2 

print(distances) 
array([[[ 81, 49, 4], 
    [ 81, 64, 9], 
    [ 81, 9, 16], 
    [ 16, 49, 9], 
    [ 9, 49, 9], 
    [ 81, 0, 0], 
    [ 16, 0, 25], 
    [ 49, 1, 0], 
    [ 0, 0, 0], 
    [100, 49, 36]], 

    [[ 4, 36, 4], 
    [ 4, 49, 9], 
    [ 4, 4, 16], 
    [ 9, 36, 9], 
    [ 16, 36, 9], 
    [ 4, 1, 0], 
    [ 9, 1, 25], 
    [ 0, 0, 0], 
    [ 49, 1, 0], 
    [ 9, 36, 36]], 

    [[ 25, 0, 1], 
    [ 25, 1, 0], 
    [ 25, 16, 1], 
    [ 0, 0, 0], 
    [ 1, 0, 0], 
    [ 25, 49, 9], 
    [ 0, 49, 4], 
    [ 9, 36, 9], 
    [ 16, 49, 9], 
    [ 36, 0, 9]], 

    [[ 0, 1, 1], 
    [ 0, 0, 0], 
    [ 0, 25, 1], 
    [ 25, 1, 0], 
    [ 36, 1, 0], 
    [ 0, 64, 9], 
    [ 25, 64, 4], 
    [ 4, 49, 9], 
    [ 81, 64, 9], 
    [ 1, 1, 9]]]) 

print(distances.sum(axis=2)) 
array([[134, 154, 106, 74, 67, 81, 41, 50, 0, 185], 
    [ 44, 62, 24, 54, 61, 5, 35, 0, 50, 81], 
    [ 26, 26, 42, 0, 1, 83, 53, 54, 74, 45], 
    [ 2, 0, 26, 26, 37, 73, 93, 62, 154, 11]]) 

# The minimum of the first 4 centroids is index 3. The minimum of the second 4 centroids is index 3 again. 

print(np.argmin(distances.sum(axis=2), axis=0)) 
array([3, 3, 1, 2, 2, 1, 1, 1, 0, 3]) 
関連する問題