以下のnumpyの動作を理解できます。どのようにnumpy.where仕事ですか?
>>> a
array([[ 0. , 0. , 0. ],
[ 0. , 0.7, 0. ],
[ 0. , 0.3, 0.5],
[ 0.6, 0. , 0.8],
[ 0.7, 0. , 0. ]])
>>> argmax_overlaps = a.argmax(axis=1)
>>> argmax_overlaps
array([0, 1, 2, 2, 0])
>>> max_overlaps = a[np.arange(5),argmax_overlaps]
>>> max_overlaps
array([ 0. , 0.7, 0.5, 0.8, 0.7])
>>> gt_argmax_overlaps = a.argmax(axis=0)
>>> gt_argmax_overlaps
array([4, 1, 3])
>>> gt_max_overlaps = a[gt_argmax_overlaps,np.arange(a.shape[1])]
>>> gt_max_overlaps
array([ 0.7, 0.7, 0.8])
>>> gt_argmax_overlaps = np.where(a == gt_max_overlaps)
>>> gt_argmax_overlaps
(array([1, 3, 4]), array([1, 2, 0]))
Iは、[1,1]で0.7、0.7および0.8を理解[3,2]及び[4,0]私は0次及び1からなる各アレイれたタプル(array[1,3,4] and array[1,2,0])
を得これらの3つの要素のインデックス。次に私の理解が正しいことを見て他の例を試しました。
>>> np.where(a == [0.3])
(array([2]), array([1]))
0.3は[2,1]ですので、結果は期待通りです。それで試しました。
>>> np.where(a == [0.3, 0.5])
(array([], dtype=int64),)
??私は(array([2,2])、array([2,3]))を見たいと思っていました。上記の結果がなぜ表示されるのですか?
>>> np.where(a == [0.7, 0.7, 0.8])
(array([1, 3, 4]), array([1, 2, 0]))
>>> np.where(a == [0.8,0.7,0.7])
(array([1]), array([1]))
2番目の結果もわかりません。誰かがそれを私に説明してもらえますか?ありがとう。
'np.where((a == 0.3)|(a == 0.5))'と 'np.where((a == 0.7)|(a == 0.8))'を使用して正しい結果を得る。しかし、 'np.where(a == [0.7,0.8])'が 'DeprecationWarning'を投げている間に' np.where(a == [0.7、0.7、0.8]) 'が働く理由を知りません。バグのように見えます。 – Khris
'where'が予期しないインデックスを与えるとき、条件配列を見てください。 'where'はその配列が' True'であるところを伝えています。 – hpaulj