2017-10-29 1 views
0

%rを直接使用することはないので、そのエラーについては何の手掛かりもありません。私はどこかに型の混在があると思うが、どこで起こるのか分からない。コードをスピードアップするための他の提案はさておき、非常に感謝しています。Numbaは次のように答えています: 'TypingError:%rは同質のシーケンスでは許可されていません'

import numpy as np 
from numba import jit, float64 

c = 3*10**8 
epsilon = 8.854187817 * 10**(-12) 
mu = 4*np.pi *10**(-7) 

@jit(nopython=True) 
def cross(vec1, vec2): 

    result = np.array([0.,0.,0.]) 

    a1, a2, a3 = vec1[0],vec1[1], vec1[2] 
    b1, b2, b3 = vec2[0], vec2[1], vec2[2] 

    result[0] = a2 * b3 - a3 * b2 
    result[1] = a3 * b1 - a1 * b3 
    result[2] = a1 * b2 - a2 * b1 

    return result  

@jit(float64[:,:](float64[:],float64,float64,float64[:],float64[:],float64[:]), nopython = True) 

def jit_EM_field(position,length,ladung,velocity,acceleration,R): 
    #using solutions to lienard wiechert potential 
    radius = np.linalg.norm(R - position) 
    if radius != 0: 
     unitradius = (R - position)/radius 
    else: 
     unitradius = np.array([0.,0.,0.]) 

    if radius != 0 and np.dot(unitradius, velocity)!=1: 
     charge  = ladung/((1 - np.dot(unitradius, velocity)/c)** 3) 


     if radius < length: 
      radius = length 

     radius2  = radius ** 2 

     velocity_in_c = velocity/c 

     oneMinusV2 = 1 - np.dot(velocity_in_c, velocity_in_c) 
     uMinusV  = unitradius - velocity_in_c    
     aCrossUmV = cross(uMinusV, acceleration) 
     Eleft  = (oneMinusV2 * (unitradius - velocity_in_c))/radius2 
     Eright  = cross(unitradius, aCrossUmV)/(radius*c**2) 
     E   = (charge/(4*np.pi*epsilon)) * (Eleft - Eright) 

     B   = cross(unitradius/c, ((mu*epsilon*charge*c**2) * (Eleft - Eright))) 

     EM_field = np.array([E,B], dtype = float) 
    else: 
     EM_field = np.zeros((2,3), dtype = float) 

    return EM_field 

jit_EM_field(np.array([0.,1.,0.]),1.,0.1,np.array([0.,1.,0.]),np.array([0.,1.,0.]) 
,np.array([7.2,5.6,0.1])) 

ここに完全なエラーメッセージがあります。

runfile('C:/Users/Elios/testingjit.py', wdir='C:/Users/Elios') 
Traceback (most recent call last): 

    File "<ipython-input-26-221208a798d4>", line 1, in <module> 
    runfile('C:/Users/Elios/testingjit.py', wdir='C:/Users/Elios') 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\spyder\utils\site\sitecustomize.py", line 866, in runfile 
    execfile(filename, namespace) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile 
    exec(compile(f.read(), filename, 'exec'), namespace) 

    File "C:/Users/Elios/testingjit.py", line 32, in <module> 
    @jit(float64[:,:](float64[:],float64,float64,float64[:],float64[:],float64[:]), nopython = True) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\decorators.py", line 176, in wrapper 
    disp.compile(sig) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\dispatcher.py", line 531, in compile 
    cres = self._compiler.compile(args, return_type) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\dispatcher.py", line 80, in compile 
    flags=flags, locals=self.locals) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 725, in compile_extra 
    return pipeline.compile_extra(func) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 369, in compile_extra 
    return self.compile_bytecode(bc, func_attr=self.func_attr) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 378, in compile_bytecode 
    return self._compile_bytecode() 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 690, in _compile_bytecode 
    return self._compile_core() 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 677, in _compile_core 
    res = pm.run(self.status) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 257, in run 
    raise patched_exception 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 249, in run 
    stage() 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 476, in stage_nopython_frontend 
    self.locals) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\compiler.py", line 828, in type_inference_stage 
    infer.propagate() 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typeinfer.py", line 717, in propagate 
    raise errors[0] 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typeinfer.py", line 127, in propagate 
    constraint(typeinfer) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typeinfer.py", line 372, in __call__ 
    self.resolve(typeinfer, typevars, fnty) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typeinfer.py", line 385, in resolve 
    sig = typeinfer.resolve_call(fnty, pos_args, kw_args) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typeinfer.py", line 972, in resolve_call 
    return self.context.resolve_function_type(fnty, pos_args, kw_args) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typing\context.py", line 124, in resolve_function_type 
    return func.get_call_type(self, args, kws) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\types\functions.py", line 49, in get_call_type 
    sig = temp.apply(args, kws) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typing\templates.py", line 216, in apply 
    sig = typer(*args, **kws) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typing\npydecl.py", line 456, in typer 
    ndim, seq_dtype = _parse_nested_sequence(self.context, object) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typing\npydecl.py", line 423, in _parse_nested_sequence 
    n, dtype = _parse_nested_sequence(context, typ.dtype) 

    File "C:\Users\Elios\Anaconda4\lib\site-packages\numba\typing\npydecl.py", line 421, in _parse_nested_sequence 
    raise TypingError("%r not allowed in a homogenous sequence") 

TypingError: %r not allowed in a homogenous sequence 
+0

'cross 'を' jit'するだけでいいのですか?もしあなたがそれを稼働させても、numbaが2位で改善できるかなと思う。それは 'numpy'コードがたくさん含まれていて、私はループを見ません。 – hpaulj

+0

numbaコンパイルされた関数を呼び出す関数numbaは、そうでなければ実行するよりも速く実行できるので、はい、私はそれをやっています。それはちょうど推測tho。 – Ezrael

答えて

0

Numbaはネストされた配列をサポートしていないようです。次のコードで問題を回避することができました。 jit_EM_field()は、長さ3の2つのネストされた配列の代わりに長さ6の配列を返します。これは、jitデコレータに@jit(float64[:]と表示されます。冗長else節も削除しました。

import numpy as np 
from numba import jit, float64 

c = 3*10**8 
epsilon = 8.854187817 * 10**(-12) 
mu = 4*np.pi *10**(-7) 

@jit(nopython=True) 
def cross(vec1, vec2): 
    result = np.array([0.,0.,0.]) 
    a1, a2, a3 = vec1[0],vec1[1], vec1[2] 
    b1, b2, b3 = vec2[0], vec2[1], vec2[2] 
    result[0] = a2 * b3 - a3 * b2 
    result[1] = a3 * b1 - a1 * b3 
    result[2] = a1 * b2 - a2 * b1 
    return result  

@jit(float64[:](float64[:],float64,float64,float64[:],float64[:],float64[:]), nopython=True) 
def jit_EM_field(position,length,ladung,velocity,acceleration,R): 
    #using solutions to lienard wiechert potential 
    EM_field = np.array([0.,0.,0.,0.,0.,0.]) 
    radius = np.linalg.norm(R - position) 
    if radius != 0: 
     unitradius = (R - position)/radius 
     if np.dot(unitradius, velocity) != 1: 
      charge = ladung/((1 - np.dot(unitradius, velocity)/c)** 3) 
      if radius < length: 
       radius = length 
      radius2 = radius ** 2 
      velocity_in_c = velocity/c 
      oneMinusV2 = 1 - np.dot(velocity_in_c, velocity_in_c) 
      uMinusV = unitradius - velocity_in_c    
      aCrossUmV = cross(uMinusV, acceleration) 
      Eleft = (oneMinusV2 * (unitradius - velocity_in_c))/radius2 
      Eright = cross(unitradius, aCrossUmV)/(radius*c**2) 
      E = (charge/(4*np.pi*epsilon)) * (Eleft - Eright) 
      B = cross(unitradius/c, ((mu*epsilon*charge*c**2) * (Eleft - Eright))) 
      EM_field = np.array([E[0],E[1],E[2],B[0],B[1],B[2]]) 
    return EM_field 

em_field = jit_EM_field(np.array([0.,1.,0.]),1.,0.1,np.array([0.,1.,0.]),np.array([0.,1.,0.]),np.array([7.2,5.6,0.1])) 
em_field_zero = jit_EM_field(np.array([0.,1.,0.]),1.,0.1,np.array([0.,1.,0.]),np.array([0.,1.,0.]),np.array([0.,1.,0.])) 
import pprint as pp 
pp.pprint(em_field) 
pp.pprint(em_field_zero) 
+0

dtype = float_とdouble()を指定して関数を呼び出すと、同じエラーメッセージが表示されます。私はあなたの答えを誤解しましたか? – Ezrael

+0

あなたは '' 'dtype = float64' ''を試しましたか? – Lindehaven

+0

はい、インポートされたfloat64との潜在的な競合を避けるために 'float_'に変更しました。 – Ezrael

関連する問題