I want to speed up the following code using Cython:
class A(object):
    def fun(self):  # cdef methods are only allowed inside a cdef class
        return 3

class B(object):
    def fun(self):
        return 2

def test():
    cdef int x, y, i, s = 0
    a = [[A(), B()], [B(), A()]]
    for i in xrange(1000):
        for x in xrange(2):
            for y in xrange(2):
                s += a[x][y].fun()
    return s
The only thing that comes to mind is something like this:
def test():
    cdef int x, y, i, s = 0
    types = [[0, 1], [1, 0]]
    data = [[...], [...]]
    for i in xrange(1000):
        for x in xrange(2):
            for y in xrange(2):
                # nested lists need a[x][y]; a[x, y] only works on NumPy arrays
                if types[x][y] == 0:
                    s += A(data[x][y]).fun()
                else:
                    s += B(data[x][y]).fun()
    return s
Basically, the solution in C++ would be to have an array of pointers to some base class with a virtual method fun(); you could then iterate through it pretty quickly. Is there a way to do this using Python/Cython?
BTW: would it be faster to use a NumPy 2D array with dtype=object_ instead of Python lists?
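For reference, here is a minimal plain-Python sketch of the same loop, written the way the C++ design would suggest: a common base class with an overridable fun(). It is only meant as a baseline to time a Cython version against, not as the fast solution. The names Base, test_baseline and iterations are my own for illustration, and it uses range, so it assumes Python 3.

```python
import timeit


class Base(object):
    def fun(self):
        # Subclasses override this; plays the role of a C++ virtual method.
        raise NotImplementedError


class A(Base):
    def fun(self):
        return 3


class B(Base):
    def fun(self):
        return 2


def test_baseline(iterations=1000):
    """Sum fun() over a 2x2 grid of objects, repeated `iterations` times."""
    grid = [[A(), B()], [B(), A()]]
    s = 0
    for _ in range(iterations):
        for row in grid:
            for obj in row:
                s += obj.fun()  # looked up dynamically on each call
    return s


if __name__ == "__main__":
    print(test_baseline())
    # Wall-clock seconds for 100 calls; compare against the compiled version.
    print(timeit.timeit(test_baseline, number=100))
```

Each pass over the grid adds 3 + 2 + 2 + 3 = 10, so the default 1000 iterations should return 10000, the same value the Cython test() is expected to produce.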
It looks like code like this gives about a 20x speedup:
import numpy as np
cimport numpy as np

cdef class Base(object):
    cdef int fun(self):
        return -1

cdef class A(Base):
    cdef int fun(self):
        return 3

cdef class B(Base):
    cdef int fun(self):
        return 2

def test():
    bbb = np.array([[A(), B()], [B(), A()]], dtype=np.object_)
    cdef np.ndarray[dtype=object, ndim=2] a = bbb
    cdef int i, x, y
    cdef int s = 0
    cdef Base u
    for i in xrange(1000):
        for x in xrange(2):
            for y in xrange(2):
                u = a[x,y]    # assigning to a typed variable checks the type
                s += u.fun()  # cdef method call dispatched at the C level
    return s
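It even checks that A and B inherit from Base. There is probably a way to disable this check in release builds and get an additional speedup.
EDIT: The check can be removed with an explicit cast (the cast is unchecked, so every element of the array must really be a Base instance):
u = <Base>a[x,y]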