2012-12-06 7 views
8

デフォルトでは、配列ベースがpickleされていてもnumpyビュー配列をpicklingするとビューの関係が失われます。私の状況は、私はいくつかの複雑なコンテナオブジェクトをピクルされているということです。場合によっては、含まれているデータの中には、他のビューがあります。各ビューの独立した配列を保存するだけでなく、スペースの損失だけでなく、再ロードされたデータがビューの関係を失ってしまいます。pickling時のnumpyビューの保存

簡単な例は次のようになります(しかし、私の場合には、コンテナは、辞書よりも複雑です):

import numpy as np 
import cPickle 

tmp = np.zeros(2) 
d1 = dict(a=tmp,b=tmp[:]) # d1 to be saved: b is a view on a 

pickled = cPickle.dumps(d1) 
d2 = cPickle.loads(pickled) # d2 reloaded copy of d1 container 

print 'd1 before:', d1 
d1['b'][:] = 1 
print 'd1 after: ', d1 

print 'd2 before:', d2 
d2['b'][:] = 1 
print 'd2 after: ', d2 

印刷した:

d1 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])} 
d1 after: {'a': array([ 1., 1.]), 'b': array([ 1., 1.])} 
d2 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])} 
d2 after: {'a': array([ 0., 0.]), 'b': array([ 1., 1.])} # not a view anymore 

は私の質問:

( 1)それを保存する方法はありますか? (2)(より良い)ベースは、(1)私はなど、__reduce_ex___setstate__を変更することにより、いくつかの方法があるかもしれないと思うのために

を漬けされている場合にのみ、それを行う方法がある...のビュー配列。しかし、私は今はこれに自信を持っていません。 (2)私は分かりません。

答えて

7

これはNumPy固有のものではありません。なぜなら、ベース配列をピクルスするのは必ずしも意味をなさないので、pickleは他のオブジェクトもそのAPIの一部としてピクシングされているかどうかを確認する機能を公開していません。

しかし、この種のチェックは、NumPy配列のカスタムコンテナで行うことができます。たとえば、次のように

import numpy as np 
import pickle 

def byte_offset(array, source): 
    return array.__array_interface__['data'][0] - np.byte_bounds(source)[0] 

class SharedPickleList(object): 
    def __init__(self, arrays): 
     self.arrays = list(arrays) 

    def __getstate__(self): 
     unique_ids = {id(array) for array in self.arrays} 
     source_arrays = {} 
     view_tuples = {} 
     for array in self.arrays: 
      if array.base is None or id(array.base) not in unique_ids: 
       # only use views if the base is also being pickled 
       source_arrays[id(array)] = array 
      else: 
       view_tuples[id(array)] = (array.shape, 
              array.dtype, 
              id(array.base), 
              byte_offset(array, array.base), 
              array.strides) 
     order = [id(array) for array in self.arrays] 
     return (source_arrays, view_tuples, order) 

    def __setstate__(self, state): 
     source_arrays, view_tuples, order = state 
     view_arrays = {} 
     for k, view_state in view_tuples.items(): 
      (shape, dtype, source_id, offset, strides) = view_state 
      buffer = source_arrays[source_id].data 
      array = np.ndarray(shape, dtype, buffer, offset, strides) 
      view_arrays[k] = array 
     self.arrays = [source_arrays[i] 
         if i in source_arrays 
         else view_arrays[i] 
         for i in order] 

# unit tests 
def check_roundtrip(arrays): 
    unpickled_arrays = pickle.loads(pickle.dumps(
     SharedPickleList(arrays))).arrays 
    assert all(a.shape == b.shape and (a == b).all() 
       for a, b in zip(arrays, unpickled_arrays)) 

indexers = [0, None, slice(None), slice(2), slice(None, -1), 
      slice(None, None, -1), slice(None, 6, 2)] 

source0 = np.random.randint(100, size=10) 
arrays0 = [np.asarray(source0[k1]) for k1 in indexers] 
check_roundtrip([source0] + arrays0) 

source1 = np.random.randint(100, size=(8, 10)) 
arrays1 = [np.asarray(source1[k1, k2]) for k1 in indexers for k2 in indexers] 
check_roundtrip([source1] + arrays1) 

これは大幅なスペースの節約につながる:

source = np.random.rand(1000) 
arrays = [source] + [source[n:] for n in range(99)] 
print(len(pickle.dumps(arrays, protocol=-1))) 
# 766372 
print(len(pickle.dumps(SharedPickleList(arrays), protocol=-1))) 
# 11833 
関連する問題