Example #1
    def test_alias_ctypes(self):
        # use xxnrm2 to test calling a C function with ctypes
        from numba.targets.linalg import _BLAS
        xxnrm2 = _BLAS().numba_xxnrm2(types.float64)

        def remove_dead_xxnrm2(rhs, lives, call_list):
            if call_list == [xxnrm2]:
                return rhs.args[4].name not in lives
            return False

        # adding this handler is effectively a no-op since it won't match
        # anything else, but it's cleaner to save the state and restore it
        old_remove_handlers = remove_call_handlers[:]
        remove_call_handlers.append(remove_dead_xxnrm2)

        def func(ret):
            a = np.ones(4)
            xxnrm2(100, 4, a.ctypes, 1, ret.ctypes)

        A1 = np.zeros(1)
        A2 = A1.copy()

        try:
            pfunc = self.compile_parallel(func, (numba.typeof(A1),))
            numba.njit(func)(A1)
            pfunc(A2)
        finally:
            # recover global state
            remove_call_handlers[:] = old_remove_handlers

        self.assertEqual(A1[0], A2[0])
Example #2
    def callable(cls, nans=False, reverse=False, scalar=False):
        """ Compile a jitted function doing the hard part of the job """
        _valgetter = cls._valgetter_scalar if scalar else cls._valgetter
        valgetter = nb.njit(_valgetter)
        outersetter = nb.njit(cls._outersetter)

        _cls_inner = nb.njit(cls._inner)
        if nans:
            def _inner(ri, val, ret, counter, mean):
                if not np.isnan(val):
                    _cls_inner(ri, val, ret, counter, mean)
            inner = nb.njit(_inner)
        else:
            inner = _cls_inner

        def _loop(group_idx, a, ret, counter, mean, outer, fill_value, ddof):
            # fill_value and ddof must be present so that the signature is
            # interchangeable with loop_2pass
            size = len(ret)
            rng = range(len(group_idx) - 1, -1, -1) if reverse else range(len(group_idx))
            for i in rng:
                ri = group_idx[i]
                if ri < 0:
                    raise ValueError("negative indices not supported")
                if ri >= size:
                    raise ValueError("one or more indices in group_idx are too large")
                val = valgetter(a, i)
                inner(ri, val, ret, counter, mean)
                outersetter(outer, i, ret[ri])
        return nb.njit(_loop, nogil=True)
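A minimal, self-contained sketch of the closure-specialization pattern used above (the names here are illustrative, not from the original module): the outer factory captures jitted helpers by closure, so each option combination compiles to its own native loop.

import numba as nb
import numpy as np

@nb.njit
def _add(ri, val, ret):
    ret[ri] += val

def make_group_sum(nans=False):
    inner = _add
    if nans:
        # wrap the kernel so NaN values are skipped, then jit the wrapper
        @nb.njit
        def inner(ri, val, ret):
            if not np.isnan(val):
                _add(ri, val, ret)

    def _loop(group_idx, a, ret):
        for i in range(len(group_idx)):
            inner(group_idx[i], a[i], ret)

    return nb.njit(_loop, nogil=True)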
Example #3
    def test_record_real(self):
        rectyp = np.dtype([('real', np.float32), ('imag', np.complex64)])
        arr = np.zeros(3, dtype=rectyp)
        arr['real'] = np.random.random(arr.size)
        arr['imag'] = np.random.random(arr.size) * 1.3j

        # check numpy behavior
        # .real is identity
        self.assertIs(array_real(arr), arr)
        # .imag is zero_like
        self.assertEqual(array_imag(arr).tolist(), np.zeros_like(arr).tolist())

        # check numba behavior
        # it's most likely a user error, anyway
        jit_array_real = njit(array_real)
        jit_array_imag = njit(array_imag)

        with self.assertRaises(TypingError) as raises:
            jit_array_real(arr)
        self.assertIn("cannot access .real of array of Record",
                      str(raises.exception))

        with self.assertRaises(TypingError) as raises:
            jit_array_imag(arr)
        self.assertIn("cannot access .imag of array of Record",
                      str(raises.exception))
Example #4
 def test_jit(self):
     def foo(x):
         return x + math.sin(x)
     fastfoo = njit(fastmath=True)(foo)
     slowfoo = njit(foo)
     self.assertEqual(fastfoo(0.5), slowfoo(0.5))
     fastllvm = fastfoo.inspect_llvm(fastfoo.signatures[0])
     slowllvm = slowfoo.inspect_llvm(slowfoo.signatures[0])
     # Ensure fast attribute in fast version only
     self.assertIn('fadd fast', fastllvm)
     self.assertIn('call fast', fastllvm)
     self.assertNotIn('fadd fast', slowllvm)
     self.assertNotIn('call fast', slowllvm)
Example #5
 def test_eq(self, flags=no_pyobj_flags):
     pyfunc = eq_usecase
     cfunc = njit(pyfunc)
     for a in UNICODE_EXAMPLES:
         for b in reversed(UNICODE_EXAMPLES):
             self.assertEqual(pyfunc(a, b),
                              cfunc(a, b), '%s, %s' % (a, b))
Example #6
    def _check_ordering_op(self, usecase):
        pyfunc = usecase
        cfunc = njit(pyfunc)

        # Check comparison to self
        for a in UNICODE_ORDERING_EXAMPLES:
            self.assertEqual(
                pyfunc(a, a),
                cfunc(a, a),
                '%s: "%s", "%s"' % (usecase.__name__, a, a),
            )

        # Check comparison to adjacent
        for a, b in permutations(UNICODE_ORDERING_EXAMPLES, r=2):
            self.assertEqual(
                pyfunc(a, b),
                cfunc(a, b),
                '%s: "%s", "%s"' % (usecase.__name__, a, b),
            )
            # and reversed
            self.assertEqual(
                pyfunc(b, a),
                cfunc(b, a),
                '%s: "%s", "%s"' % (usecase.__name__, b, a),
            )
Example #7
    def test_zfill(self):
        pyfunc = zfill_usecase
        cfunc = njit(pyfunc)

        ZFILL_INPUTS = [
            'ascii',
            '+ascii',
            '-ascii',
            '-asc ii-',
            '12345',
            '-12345',
            '+12345',
            '',
            '¡Y tú crs?',
            '🐍⚡',
            '+🐍⚡',
            '-🐍⚡',
            '大眼,小手。',
            '+大眼,小手。',
            '-大眼,小手。',
        ]

        with self.assertRaises(TypingError) as raises:
            cfunc(ZFILL_INPUTS[0], 1.1)
        self.assertIn('<width> must be an Integer', str(raises.exception))

        for s in ZFILL_INPUTS:
            for width in range(-3, 20):
                self.assertEqual(pyfunc(s, width),
                                 cfunc(s, width))
Example #8
    def test_split_whitespace(self):
        # explicit sep=None cases covered in test_split and test_split_with_maxsplit
        pyfunc = split_whitespace_usecase
        cfunc = njit(pyfunc)

        # list copied from https://github.com/python/cpython/blob/master/Objects/unicodetype_db.h
        all_whitespace = ''.join(map(chr, [
            0x0009, 0x000A, 0x000B, 0x000C, 0x000D, 0x001C, 0x001D, 0x001E, 0x001F, 0x0020,
            0x0085, 0x00A0, 0x1680, 0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006,
            0x2007, 0x2008, 0x2009, 0x200A, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000
        ]))

        CASES = [
            '',
            'abcabc',
            '🐍 ⚡',
            '🐍 ⚡ 🐍',
            '🐍   ⚡ 🐍  ',
            '  🐍   ⚡ 🐍',
            ' 🐍' + all_whitespace + '⚡ 🐍  ',
        ]
        for test_str in CASES:
            self.assertEqual(pyfunc(test_str),
                             cfunc(test_str),
                             "'%s'.split()?" % (test_str,))
Example #9
    def test_case06_double_objmode(self):
        def foo(x):
            # would nested ctx in the same scope ever make sense? Is this
            # pattern useful?
            with objmode_context():
                #with npmmode_context(): not implemented yet
                    with objmode_context():
                        print(x)
            return x

        with self.assertRaises(errors.TypingError) as raises:
            njit(foo)(123)
        # Check that an error occurred in with-lifting in objmode
        pat = ("During: resolving callee type: "
               "type\(ObjModeLiftedWith\(<.*>\)\)")
        self.assertRegexpMatches(str(raises.exception), pat)
Example #10
 def compile_func(self, pyfunc):
     def check(*args, **kwargs):
         expected = pyfunc(*args, **kwargs)
         result = f(*args, **kwargs)
         self.assertPreciseEqual(result, expected)
     f = njit(pyfunc)
     return f, check
Example #11
    def _get_stencil_last_ind(self, dim_size, end_length, gen_nodes, scope,
                                                                        loc):
        last_ind = dim_size
        if end_length != 0:
            # set last index to size minus stencil size to avoid invalid
            # memory access
            index_const = ir.Var(scope, mk_unique_var("stencil_const_var"),
                                                                        loc)
            self.typemap[index_const.name] = types.intp
            if isinstance(end_length, numbers.Number):
                const_assign = ir.Assign(ir.Const(end_length, loc),
                                                        index_const, loc)
            else:
                const_assign = ir.Assign(end_length, index_const, loc)

            gen_nodes.append(const_assign)
            last_ind = ir.Var(scope, mk_unique_var("last_ind"), loc)
            self.typemap[last_ind.name] = types.intp

            g_var = ir.Var(scope, mk_unique_var("compute_last_ind_var"), loc)
            check_func = numba.njit(_compute_last_ind)
            func_typ = types.functions.Dispatcher(check_func)
            self.typemap[g_var.name] = func_typ
            g_obj = ir.Global("_compute_last_ind", check_func, loc)
            g_assign = ir.Assign(g_obj, g_var, loc)
            gen_nodes.append(g_assign)
            index_call = ir.Expr.call(g_var, [dim_size, index_const], (), loc)
            self.calltypes[index_call] = func_typ.get_call_type(
                self.typingctx, [types.intp, types.intp], {})
            index_assign = ir.Assign(index_call, last_ind, loc)
            gen_nodes.append(index_assign)

        return last_ind
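The helper _compute_last_ind referenced above is defined elsewhere in numba's stencil support and is not shown here; a sketch of its likely shape (an assumption based on the surrounding code, not the verbatim source):

def _compute_last_ind(dim_size, end_length):
    # last valid index: shrink the dimension by the stencil's end length
    if end_length == 0:
        return dim_size
    else:
        return dim_size - end_length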
Example #12
 def test_default_args(self):
     """
     Test a nested function call using default argument values.
     """
     cfunc = njit(g)
     self.assertEqual(cfunc(1, 2, 3), g(1, 2, 3))
     self.assertEqual(cfunc(1, y=2, z=3), g(1, 2, 3))
Example #13
 def test_named_args(self):
     """
     Test a nested function call with named (keyword) arguments.
     """
     cfunc = njit(f)
     self.assertEqual(cfunc(1, 2, 3), f(1, 2, 3))
     self.assertEqual(cfunc(1, y=2, z=3), f(1, 2, 3))
Example #14
    def test_misc(self):

        @njit
        def swap(x, y):
            return y, x

        def test_bug2537(m):
            a = np.ones(m)
            b = np.ones(m)
            for i in range(m):
                a[i], b[i] = swap(a[i], b[i])

        try:
            njit(test_bug2537, parallel=True)(10)
        except IndexError:
            self.fail("test_bug2537 raised IndexError!")
Example #15
    def collect_results(fn, fast=True, should_jit=True):
        if fast:
            sizes = (100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000)

            if should_jit and not isinstance(fn, numba.targets.registry.CPUDispatcher):
                fn = njit(fn)

            fn(np.random.random((10, 10)))  # force compile
        else:
            sizes = (100, 200, 300, 400, 500, 600)

        raw_data = []
        for n in sizes:
            data = np.random.random((n, n))
            start = time.time()
            lower, upper = fn(data)
            end = time.time()
            elapsed = float(str(end - start)[:8])
            print(f'{fn.__name__}: {n}: {elapsed}')
            raw_data.append((n, elapsed))

            # sanity check
            assert_array_almost_equal(lower @ upper, data)

        raw_results = np.array(raw_data)
        raw_df = pd.DataFrame(raw_results, columns=['idx', 'vals']).set_index('idx')
        return raw_df
Example #16
def fjit(fun):
    'just-in-time compile a function by wrapping it in a singleton class'
    
    from numba import jitclass
    import time
    
    # the function is jitted first
    jitted_fun = njit(fun)

    # Generate a random class name like 'Singleton_Sat_Jan__2_18_08_32_2016'
    classname = 'Singleton_' + time.asctime().replace(' ','_').replace(':','_')
    
    # programmatically create a class equivalent to :
    # class Singleton_Sat_Jan__2_18_08_32_2016:
    #     def __init__(self): pass
    #     def value(self, x): return fj(x)
    
    def __init__(self): pass
    def value(self, x): return jitted_fun(x)
    SingletonClass = type(classname, (object,), {'__init__': __init__, 'value': value})
    
    # jit compile the class
    # spec is [] since we don't store attributes
    spec = []
    sc = jitclass(spec)(SingletonClass)
    
    # return a unique instance of the class
    return sc()
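Hypothetical usage of fjit (illustrative, not part of the original snippet): the returned singleton instance exposes the compiled function through its value method, as the snippet intends.

sq = fjit(lambda x: x ** 2.0)
print(sq.value(3.0))  # -> 9.0, dispatched to the jitted lambda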
Example #17
def test(*args):
    ns = {}
    exec(format_stmt(gen_index_i_at_dim(2)), ns)  # Python 3: exec into a dict so index_2i is reachable
    nb_index_2i = nb.njit(ns['index_2i'])
    Y = np.arange(4).reshape(2,2)
    assert nb_index_2i(Y, 0) == 0
    assert nb_index_2i(Y, 1) == 3
    assert nb_index_2i(Y, -1) == 3
    return nb_index_2i
Example #18
    def test_case18_njitfunc_passed_to_objmode_ctx(self):
        def foo(func, x):
            with objmode_context():
                func(x[0])

        x = np.array([1, 2, 3])
        fn = njit(lambda z: z + 5)
        self.assert_equal_return_and_stdout(foo, fn, x)
Example #19
def _generate_property(field, template, fname):
    """
    Generate simple function that get/set a field of the instance
    """
    source = template.format(field)
    glbls = {}
    exec_(source, glbls)
    return njit(glbls[fname])
Example #20
 def test_inplace_concat(self, flags=no_pyobj_flags):
     pyfunc = inplace_concat_usecase
     cfunc = njit(pyfunc)
     for a in UNICODE_EXAMPLES:
         for b in UNICODE_EXAMPLES[::-1]:
             self.assertEqual(pyfunc(a, b),
                              cfunc(a, b),
                              "'%s' + '%s'?" % (a, b))
Example #21
 def test_literal_getitem(self):
     def pyfunc(which):
         return 'abc'[which]
     cfunc = njit(pyfunc)
     for a in [-1, 0, 1, slice(1, None), slice(None, -1)]:
         args = [a]
         self.assertEqual(pyfunc(*args), cfunc(*args),
                          msg='failed on {}'.format(args))
Example #22
    def test_literal_xyzwith(self):
        def pyfunc(x, y):
            return 'abc'.startswith(x), 'cde'.endswith(y)

        cfunc = njit(pyfunc)
        for args in permutations('abcdefg', r=2):
            self.assertEqual(pyfunc(*args), cfunc(*args),
                             msg='failed on {}'.format(args))
Example #23
 def test_unicode_stopiteration_iter(self):
     self.disable_leak_check()
     pyfunc = iter_stopiteration_usecase
     cfunc = njit(pyfunc)
     for f in (pyfunc, cfunc):
         for a in UNICODE_EXAMPLES:
             with self.assertRaises(StopIteration):
                 f(a)
Example #24
 def test_endswith(self, flags=no_pyobj_flags):
     pyfunc = endswith_usecase
     cfunc = njit(pyfunc)
     for a in UNICODE_EXAMPLES:
         for b in ['', 'x', a[:-2], a[3:], a, a + a]:
             self.assertEqual(pyfunc(a, b),
                              cfunc(a, b),
                              '%s, %s' % (a, b))
Example #25
 def test_issue_1264(self):
     n = 100
     x = np.random.uniform(size=n*3).reshape((n,3))
     expected = distance_matrix(x)
     actual = njit(distance_matrix)(x)
     np.testing.assert_array_almost_equal(expected, actual)
     # Avoid sporadic failures in MemoryLeakMixin.tearDown()
     gc.collect()
Example #26
 def test_size_after_slicing(self):
     pyfunc = size_after_slicing_usecase
     cfunc = njit(pyfunc)
     arr = np.arange(2 * 5).reshape(2, 5)
     for i in range(arr.shape[0]):
         self.assertEqual(pyfunc(arr, i), cfunc(arr, i))
     arr = np.arange(2 * 5 * 3).reshape(2, 5, 3)
     for i in range(arr.shape[0]):
         self.assertEqual(pyfunc(arr, i), cfunc(arr, i))
Example #27
    def test_literal_in(self):
        def pyfunc(x):
            return x in '9876zabiuh'

        cfunc = njit(pyfunc)
        for a in ['a', '9', '1', '', '8uha', '987']:
            args = [a]
            self.assertEqual(pyfunc(*args), cfunc(*args),
                             msg='failed on {}'.format(args))
Example #28
    def test_literal_find(self):
        def pyfunc(x):
            return 'abc'.find(x), x.find('a')

        cfunc = njit(pyfunc)
        for a in ['ab']:
            args = [a]
            self.assertEqual(pyfunc(*args), cfunc(*args),
                             msg='failed on {}'.format(args))
Example #29
 def test_repeat(self, flags=no_pyobj_flags):
     pyfunc = repeat_usecase
     cfunc = njit(pyfunc)
     for a in UNICODE_EXAMPLES + ['']:
         for b in (-1, 0, 1, 2, 3, 4, 5, 7, 8, 15, 70):
             self.assertEqual(pyfunc(a, b),
                              cfunc(a, b))
             self.assertEqual(pyfunc(b, a),
                              cfunc(b, a))
Example #30
    def test_getitem(self):
        pyfunc = getitem_usecase
        cfunc = njit(pyfunc)

        for s in UNICODE_EXAMPLES:
            for i in range(-len(s), len(s)):
                self.assertEqual(pyfunc(s, i),
                                 cfunc(s, i),
                                 "'%s'[%d]?" % (s, i))
Example #31
def _gen_csv_reader_py_pyarrow_jit_func(csv_reader_py):
    # TODO: no_cpython_wrapper=True crashes for some reason
    jit_func = numba.njit(csv_reader_py)
    compiled_funcs.append(jit_func)
    return jit_func
Example #32
w_bins : :class:`numpy.ndarray`
    W stacking bins of shape :code:`(nw + 1,)`

Returns
-------
:class:`numpy.ndarray`
    W-coordinate centroids of shape :code:`(nw,)`
    in wavelengths.
"""

if on_rtd():

    def w_stacking_centroids(w_bins):
        pass
else:
    w_stacking_centroids = numba.njit(nogil=True,
                                      cache=True)(_w_stacking_centroids)

w_stacking_centroids.__doc__ = WSTACK_DOCS
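# `on_rtd` is not shown in this snippet; a common implementation (an
# assumption here, not the original code) checks the READTHEDOCS environment
# variable so that documentation builds skip native compilation:
#
#     def on_rtd():
#         import os
#         return os.environ.get('READTHEDOCS') == 'True'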


@numba.jit(nopython=True, nogil=True, cache=True)
def w_bin_masks(uvw, w_bins):
    indices = np.digitize(uvw[:, 2], w_bins) - 1
    return [i == indices for i in range(w_bins.shape[0])]


@numba.jit(nopython=True, nogil=True, cache=True)
def numba_grid(vis, uvw, flags, weights, ref_wave, convolution_filter, w_bins,
               cell_size, grids):

    assert len(grids) == w_bins.shape[0] - 1
Example #33
    # Maintain order of dimensions
    return ranks.transpose(*da_ams.dims)


def ecdf(rank, n_obs):
    """Return the ECDF
    Recommended as an unbiased estimator for PWM by:
    Hosking, J. R. M., and J. R. Wallis. 1995.
    “A Comparison of Unbiased and Plotting-Position Estimators of L Moments.”
    Water Resources Research 31 (8): 2019–25.
    https://doi.org/10.1029/95WR01230.
    """
    return rank / n_obs


ecdf_jit = nb.njit(ecdf)


def pp_weibull(rank, n_obs):
    """Return the Weibull plotting position
    Recommended by:
    Makkonen, Lasse. 2006.
    “Plotting Positions in Extreme Value Analysis.”
    Journal of Applied Meteorology and Climatology 45 (2): 334–40.
    https://doi.org/10.1175/JAM2349.1.
    """
    return rank / (n_obs + 1)


def pp_cunnane(rank, n_obs):
    """The Cunnane plotting position.
Example #34
def make_njit_fn(args, fn_expr):
    fn_str = lambdastr(args, fn_expr).replace('MutableDenseMatrix', '')\
                                                  .replace('(([[', '[') \
                                                  .replace(']]))', ']')
    jit_fn = numba.njit(parallel=True)(eval(fn_str))
    return jit_fn
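Illustrative usage of make_njit_fn (assumes sympy's lambdastr and numba are imported at module top, as the snippet implies; the symbols below are example inputs, not from the original code):

import sympy as sp

x, y = sp.symbols('x y')
f = make_njit_fn((x, y), x ** 2 + y)
print(f(2.0, 1.0))  # -> 5.0, evaluated by the parallel-jitted lambda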
Example #35
_gdb_cond = os.environ.get('GDB_TEST', None) == '1'
needs_gdb_harness = unittest.skipUnless(_gdb_cond, "needs gdb harness")

# check if gdb is present and working
try:
    _confirm_gdb()
    _HAVE_GDB = True
except Exception:
    _HAVE_GDB = False

_msg = "functioning gdb with correct ptrace permissions is required"
needs_gdb = unittest.skipUnless(_HAVE_GDB, _msg)
long_running = tag('long_running')

_dbg_njit = njit(debug=True)
_dbg_jit = jit(forceobj=True, debug=True)


def impl_gdb_call(a):
    gdb('-ex', 'set confirm off', '-ex', 'c', '-ex', 'q')
    b = a + 1
    c = a * 2.34
    d = (a, b, c)
    print(a, b, c, d)


def impl_gdb_call_w_bp(a):
    gdb_init('-ex', 'set confirm off', '-ex', 'c', '-ex', 'q')
    b = a + 1
    c = a * 2.34
Example #36
 def test_liftcall4(self):
     with self.assertRaises(errors.TypingError) as raises:
         njit(liftcall4)()
      # Known error.  We only support one context manager per function
      # for bodies that are lifted.
     self.assertIn("re-entrant", str(raises.exception))
Example #37
                lower[i][i] = 1.0
            else:
                total = 0.0
                for j in range(i):
                    total += lower[k][j] * upper[j][i]

                lower[k][i] = (x[k][i] - total) / upper[i][i]

    return lower, upper


if __name__ == '__main__':
    n = 1024
    X = np.random.uniform(size=(n, n))

    lu_decomp_original = njit(lu_decomp_original)
    lu_decomp_c_fortran = njit(lu_decomp_c_fortran)

    for i in range(2):
        t1 = time.process_time()
        l, u = lu_decomp_original(X)
        t2 = time.process_time()
        print('Naive took:   ', t2 - t1)

        t1 = time.process_time()
        l2, u2 = lu_decomp_c_fortran(X)
        t2 = time.process_time()
        np.testing.assert_allclose(l, l2)
        np.testing.assert_allclose(u, u2)
        print('Adjusted took:', t2 - t1)
        print('------------')
Example #38
import numpy as np
from .tensor_data import (
    count,
    index_to_position,
    broadcast_index,
    shape_broadcast,
    MAX_DIMS,
)
from numba import njit, prange

# This code will JIT compile fast versions of your tensor_data functions.
# If you get an error, read the NUMBA docs for what is allowed
# in these functions.
count = njit(inline="always")(count)
index_to_position = njit(inline="always")(index_to_position)
broadcast_index = njit(inline="always")(broadcast_index)


def tensor_map(fn):
    """
    NUMBA higher-order tensor map function. ::

      fn_map = tensor_map(fn)
      fn_map(out, ... )

    Args:
        fn: function mappings floats-to-floats to apply.
        out (array): storage for out tensor.
        out_shape (array): shape for out tensor.
        out_strides (array): strides for out tensor.
        in_storage (array): storage for in tensor.
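The snippet above is cut off inside the docstring. For context, a minimal sketch of how such a higher-order map is typically finished and compiled (an assumption in the spirit of the comments above, not the original module's code):

def _tensor_map_sketch(fn):
    fn = njit(inline="always")(fn)

    def _map(out, out_shape, out_strides, in_storage, in_shape, in_strides):
        # simplified: assumes aligned storage and ignores broadcasting
        for i in prange(len(out)):
            out[i] = fn(in_storage[i])

    return njit(parallel=True)(_map)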
Example #39

def pi(n):
    c = 0
    for _ in range(n):
        x = random.uniform(0, 1)
        y = random.uniform(0, 1)
        if x * x + y * y < 1:
            c += 1
    return c


print("\ndepth:")
n = int(input())
cores = os.cpu_count()
compiled_pi = njit(nogil=True)(pi)
compiled_pi(1)

start = time()
print('\ninterpreted:\n', 4 * pi(n) / n)
end = time()
print('Time taken:', "{:.3f}".format(1000 * (end - start)), "ms")

start = time()
print('\ncompiled:\n', 4 * compiled_pi(n) / n)
end = time()
print('Time taken:', "{:.3f}".format(1000 * (end - start)), "ms")

with ThreadPoolExecutor(cores) as ex:
    start = time()
    bean = ex.map(compiled_pi, np.full(shape=(cores),
Example #40
    """
    if mapping is None:
        return values
    uniques, indices = np.unique(values, return_inverse=True)
    if isinstance(mapping, dict):
        mapping = lambda x, m=mapping: m.get(x, x)
    elif not callable(mapping):
        raise TypeError("Only `dict` and `callable` mappings are supported.")
    mapped_values = np.array([mapping(x) for x in uniques])[indices]
    return mapped_values


def for_fill_intervals(arr, starts, ends, values):
    """`fill_intervals` subfunction moved out of its scope for consistency."""
    for start, end, value in zip(starts, ends, values):
        arr[start:end] = value
    return arr


njit_fill_intervals = njit(for_fill_intervals)


def fill_intervals(arr, starts, ends, values):
    """Fill intervals in `arr` in limits defined by `starts` and `ends`
    with `values` using either simple for loop or wrapped with @njit."""
    if values.dtype.kind in set('buif'):
        return njit_fill_intervals(arr, starts, ends, values)
    if values.dtype.kind in set('UO'):
        return for_fill_intervals(arr, starts, ends, values)
    raise TypeError("Only numeric, str and object dtypes are supported.")
Example #41
 def check_issue_4708(pyfunc, m, n):
     expected = pyfunc(m, n)
     got = njit(pyfunc)(m, n)
      # values in the arrays are equal, but stronger assertions do not
      # hold (layout and strides equality)
     np.testing.assert_equal(got, expected)
Example #42
from numba import njit, prange, guvectorize
from cuda_friendly_vincenty import vincenty
import numpy as np

wrap = njit(parallel=True)
wrapped_vincenty = wrap(vincenty)


@guvectorize('void(int64, float32[:,:], float32[:])',
             '(),(n, m) -> ()',
             target='parallel',
             cache=True)
def get_min_distances(idx, points, result):
    l = len(points)
    min_dist = -1

    for i in prange(l):
        if i != idx:
            distance = wrapped_vincenty(points[idx, 0], points[idx, 1],
                                        points[i, 0], points[i, 1])

            if (min_dist == -1) or (distance < min_dist):
                min_dist = distance

    result[0] = min_dist


@njit
def compiled_vincenty(point1, point2):
    return wrapped_vincenty(point1[0], point1[1], point2[0], point2[1])
Example #43
            # Compute residual path delay (linear combination of left
            # and right path delay)
            polyfit = np.polynomial.polynomial.polyfit
            pol = polyfit([self.beam_positions[0], self.beam_positions[1]],
                          [beam_l, beam_r], 1)
            beam = (np.array(num_pixels * [pol[0]]).T +
                    np.array(num_lines * [x_ac]) *
                    np.array(num_pixels * [pol[1]]).T)
            wet_tropo = wt - beam
            wet_tropo_nadir = wt_large[:, naclarge //
                                       2] - beam[:, num_pixels // 2]
        else:
            raise ValueError("nbeam must be in [1, 2]")

        # wt_nadir = wt_large[:, naclarge // 2]

        return {
            "wet_troposphere": wet_tropo,
            "wet_troposphere_nadir": wet_tropo_nadir
        }


if nb is not None:
    _calculate_path_delay = nb.njit(cache=True)(calculate_path_delay)
    _calculate_path_delay_lr = nb.njit(cache=True)(calculate_path_delay_lr)
    _meshgrid = nb.njit(cache=True)(meshgrid)
else:
    _calculate_path_delay = calculate_path_delay
    _calculate_path_delay_lr = calculate_path_delay_lr
    _meshgrid = meshgrid
Example #44
def register_class_type(cls, spec, class_ctor, builder):
    """
    Internal function to create a jitclass.

    Args
    ----
    cls: the original class object (used as the prototype)
    spec: the structural specification contains the field types.
    class_ctor: the numba type to represent the jitclass
    builder: the internal jitclass builder
    """
    # Normalize spec
    if isinstance(spec, Sequence):
        spec = OrderedDict(spec)
    _validate_spec(spec)

    # Fix up private attribute names
    spec = _fix_up_private_attr(cls.__name__, spec)

    # Copy methods from base classes
    clsdct = {}
    for basecls in reversed(inspect.getmro(cls)):
        clsdct.update(basecls.__dict__)

    methods = dict((k, v) for k, v in clsdct.items()
                   if isinstance(v, pytypes.FunctionType))
    props = dict((k, v) for k, v in clsdct.items() if isinstance(v, property))

    others = dict((k, v) for k, v in clsdct.items()
                  if k not in methods and k not in props)

    # Check for name shadowing
    shadowed = (set(methods) | set(props)) & set(spec)
    if shadowed:
        raise NameError("name shadowing: {0}".format(', '.join(shadowed)))

    docstring = others.pop('__doc__', "")
    _drop_ignored_attrs(others)
    if others:
        msg = "class members are not yet supported: {0}"
        members = ', '.join(others.keys())
        raise TypeError(msg.format(members))

    for k, v in props.items():
        if v.fdel is not None:
            raise TypeError("deleter is not supported: {0}".format(k))

    jitmethods = {}
    for k, v in methods.items():
        jitmethods[k] = njit(v)

    jitprops = {}
    for k, v in props.items():
        dct = {}
        if v.fget:
            dct['get'] = njit(v.fget)
        if v.fset:
            dct['set'] = njit(v.fset)
        jitprops[k] = dct

    # Instantiate class type
    class_type = class_ctor(cls, ConstructorTemplate, spec, jitmethods,
                            jitprops)

    cls = JitClassType(cls.__name__, (cls, ),
                       dict(class_type=class_type, __doc__=docstring))

    # Register resolution of the class object
    typingctx = cpu_target.typing_context
    typingctx.insert_global(cls, class_type)

    # Register class
    targetctx = cpu_target.target_context
    builder(class_type, methods, typingctx, targetctx).register()

    return cls
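register_class_type sits behind numba's jitclass decorator; for context, typical user-facing usage looks like this (standard numba API for the era of this snippet, where jitclass is importable from numba; in recent versions it lives in numba.experimental):

from numba import jitclass, float64

@jitclass([('value', float64)])
class Accumulator:
    def __init__(self):
        self.value = 0.0

    def add(self, x):
        self.value += x

acc = Accumulator()
acc.add(2.5)  # __init__ and add are compiled through njit, as shown above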
Example #45
                PHI, PHI_jm1, COLP, COLP_jm1, POTT, POTT_jm1, PVTF, PVTF_jm1,
                PVTFVB, PVTFVB_jm1, PVTFVB_jm1_kp1, PVTFVB_kp1, dsigma,
                sigma_vb, sigma_vb_kp1, dxjs)
        # NUMERICAL HORIZONTAL DIFFUSION
        if i_UVFLX_num_dif and (UVFLX_dif_coef > wp(0.)):
            dVFLXdt = dVFLXdt + num_dif(VFLX, VFLX_im1, VFLX_ip1, VFLX_jm1,
                                        VFLX_jp1, UVFLX_dif_coef)

    return (dVFLXdt, dVFLXdt_TURB)
    #return(dVFLXdt,)


###############################################################################
### SPECIALIZE FOR GPU
###############################################################################
UVFLX_hor_adv = njit(UVFLX_hor_adv_py, device=True, inline=True)
interp_VAR_ds = njit(interp_VAR_ds_py, device=True, inline=True)
coriolis_and_spherical_VWIND = njit(coriolis_and_spherical_VWIND_py,
                                    device=True,
                                    inline=True)
pre_grad = njit(pre_grad_py, device=True, inline=True)
num_dif = njit(num_dif_py, device=True, inline=True)
add_up_tendencies = njit(add_up_tendencies_py, device=True, inline=True)


def launch_cuda_main_kernel(dVFLXdt, VFLX, UWIND, VWIND, RFLX_3D, SFLX_3D,
                            TFLX_3D, QFLX_3D, PHI, PHIVB, COLP, POTT, PVTF,
                            PVTFVB, WWIND_VWIND, KMOM_dVWINDdz, RHO,
                            dVFLXdt_TURB, SMOMYFLX, corf, lat_rad, dlon_rad,
                            dlat_rad, dxjs, dsigma, sigma_vb, UVFLX_dif_coef):
Example #46
 def test_literal_unroll(self):
     arr = np.array([1, 2], dtype=recordtype2)
     pyfunc = get_field2
     jitfunc = njit(pyfunc)
     self.assertEqual(pyfunc(arr[0]), jitfunc(arr[0]))
Example #47
        return (x + y) / 2
    if func == 'max':
        return max(x, y)
    if func == 'min':
        return min(x, y)
    if func == 'diff':
        return abs(x - y)


def gumbel():
    return -math.log(-math.log(random.uniform(0.0, 1.0)))


try:
    import numba as nb
    gumbel = nb.njit(gumbel)
except ImportError:
    pass


class GreedyCompressed:
    """A greedy contraction path finder that takes into account the effect of
    compression, and can also make use of subgraph size and centrality.

    Parameters
    ----------
    chi : int
        The maximum bond size between nodes to compress to.
    coeff_size_compressed : float, optional
        When assessing contractions, how to weight the size of the output
        tensor, post compression.
Example #48
            for i, n in enumerate(nus):
                if nu <= n and i > 0:
                    Mhigh = Ms[i]
                    NUhigh = n
                    Mlow = Ms[i - 1]
                    NUlow = nus[i - 1]
                    break
            if nu - NUlow > NUhigh - nu:
                return Mhigh
            else:
                return Mlow
        else:
            raise ValueError(
                'ERROR: NU out of range. Must be between 0 and 88.4.')

    MofNU = njit(MofNU)
    NUofM = njit(NUofM)

    A = np.zeros((Np + Nc, 10))
    Rp = np.ones(Np + Nc)  # R+, Riemann invariant of C+
    Rm = np.ones(Np + Nc)  # R-, Riemann invariant of C-
    theta = np.ones(Np + Nc)  # Flow turning angle
    nu = np.ones(Np + Nc)  # Prandtl-Meyer angle
    M = np.ones(Np + Nc)  # Mach number
    mu = np.ones(Np + Nc)  # Mach angle
    alpha_p = np.ones(Np + Nc)  # Angle of C+
    alpha_m = np.ones(Np + Nc)  # Angle of C-
    x = np.ones(Np + Nc)  # x-coordinate
    y = np.ones(Np + Nc)  # y-coordinate

    # Array of points along each characteristic line.
Example #49
	Added [13/10/2018]
	X @ y is computed. SciPy & HyperLearn have similar speed. Notice that
	HyperLearn can now be parallelised! This reduces the complexity to approx.
	O(np/c) where c = number of threads / cores
	"""
    Z = zeros(n, dtype=y.dtype)

    for i in prange(n):
        s = 0
        for j in range(rowIndices[i], rowIndices[i + 1]):
            s += val[j] * y[colPointer[j]]
        Z[i] = s
    return Z


mat_vec = njit(_mat_vec, fastmath=True, nogil=True, cache=True)
mat_vec_parallel = njit(_mat_vec, fastmath=True, nogil=True, parallel=True)
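# Illustrative call (not part of the original module): the three CSR arrays
# map onto SciPy's data / indices / indptr attributes, e.g.
#
#     from scipy.sparse import random as sprandom
#     X = sprandom(5, 3, density=0.5, format='csr')
#     y = np.ones(3)
#     Z = mat_vec(X.data, X.indices, X.indptr, 5, 3, y)  # == X @ y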


def _matT_vec(val, colPointer, rowIndices, n, p, y):
    """
	Added [13/10/2018]
	X.T @ y is found. Notice how instead of converting CSR to CSC matrix, a direct
	X.T @ y can be found. Same complexity as mat_vec(X, y). Also, HyperLearn is
	parallelized, allowing for O(np/c) complexity.
	"""
    Z = zeros(p, dtype=y.dtype)

    for i in prange(n):
        yi = y[i]
        for j in range(rowIndices[i], rowIndices[i + 1]):
Example #50

init_state, mu = imag_time_gpe1D(v=initial_trap,
                                 g=g,
                                 dt=1e-3,
                                 epsilon=1e-8,
                                 **params)

init_state, mu = imag_time_gpe1D(init_wavefunction=init_state,
                                 g=g,
                                 v=initial_trap,
                                 dt=1e-4,
                                 epsilon=1e-9,
                                 **params)

flipped_initial_trap = njit(lambda x, t: initial_trap(-x, t))

flipped_init_state, mu = imag_time_gpe1D(v=flipped_initial_trap,
                                         g=g,
                                         dt=1e-3,
                                         epsilon=1e-8,
                                         **params)

flipped_init_state, mu = imag_time_gpe1D(init_wavefunction=flipped_init_state,
                                         g=g,
                                         v=flipped_initial_trap,
                                         dt=1e-4,
                                         epsilon=1e-9,
                                         **params)

# qsys = SplitOpGPE1D(
Example #51
        threading_enabled = False
        warnings.warn(
            'Threading not available for the simulation.\n'
            'Simulations will still run, but using only 1 thread on CPU.\n'
            'Please ensure that numba 0.34 or >=0.36 is installed.\n'
            '(e.g. by typing `conda update numba` in a terminal)')

# Set the function njit_parallel and prange to the correct object
if not threading_enabled:
    # Use regular serial compilation function
    njit_parallel = njit
    prange = range
    nthreads = 1
else:
    # Use the parallel compilation function
    njit_parallel = njit(parallel=True)
    prange = numba_prange
    nthreads = numba.config.NUMBA_NUM_THREADS


def get_chunk_indices(Ntot, nthreads):
    """
    Divide `Ntot` in `nthreads` chunks (almost equal in size), and
    return the indices that bound the chunks

    Parameters
    ----------
    Ntot: int
        Typically, the number of particles in a species
    nthreads: int
        The number of threads among which the work is divided
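The docstring above is truncated; a plausible body for such a chunking helper (a sketch matching the described semantics, not necessarily the original code):

import numpy as np

def _get_chunk_indices_sketch(Ntot, nthreads):
    # distribute the remainder over the first chunks so sizes differ by at most 1
    chunk, remainder = divmod(Ntot, nthreads)
    bounds = [0]
    for i in range(nthreads):
        bounds.append(bounds[-1] + chunk + (1 if i < remainder else 0))
    return np.array(bounds)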
Example #52
 def test_array_ctypes_data(self):
     pyfunc = array_ctypes_data
     cfunc = njit(pyfunc)
     arr = np.arange(3)
     self.assertEqual(pyfunc(arr), cfunc(arr))
Example #53
def mismatch_from_strains(
    h1,
    h2,
    fmin=0,
    fmax=np.inf,
    noises=None,
    antenna_patterns=None,
    num_polarization_shifts=100,
    num_time_shifts=100,
    time_shift_start=-5,
    time_shift_end=5,
    force_numba=False,
):
    r"""Compute the network-mismatch between ``h1`` and ``h2`` by maximizing the
    overlap over time and polarization shifts.

    Network here means that the inner product is computed for N detectors, as
    provided by the lists antenna_patterns and noises. Noises and antenna
    patterns have to be properly ordered: ``noises[i]`` has to correspond to
    ``antenna_pattern[i]``.

    See :ref:`gw_mismatch:Overlap and mismatch` for formulas and details.

    The mismatch is computed by maximizing over time and polarization shifts.
    Polarization shifts are taken in the interval (0, 2 pi); time shifts are
    specified by time_shift_start and time_shift_end. If num_time_shifts is 1,
    then no time shift is performed. For times, we make sure that the zero
    time shift is always included. All the transformations are done in h2.

    This computation is a maximisation, which is very expensive. So, we have a
    very fast core function called _mismatch_core_numerical to do all the hard
    work. This function is compiled to native code by numba, resulting in an
    enormous speed-up. There is an overhead in calling numba, so by default we
    do not always use it: we use it only when num_polarization_shifts *
    num_time_shifts is at least 500 * 500. You can force the use of numba by
    passing the keyword argument force_numba=True.

    We do not perform phase shifts here, so this function makes sense only
    for the (2,2) mode.

    h1 and h2 have to be already pre-processed for Fourier transform, so you
    should window them and zero pad as needed.

    :param h1: First strain.
    :type h1: :py:class:`~.TimeSeries`
    :param h2: Second strain (the one that will be modified).
    :type h2: :py:class:`~.TimeSeries`
    :param fmin: Lower limit of the integration.
    :type fmin: float
    :param fmax: Higher limit of the integration.
    :type fmax: float
    :param noises: Power spectral density of the noise for all the detectors.
                   If None, a uniform noise is applied.
    :type noises: list of :py:class:`~.FrequencySeries`, or None
    :param antenna_patterns: Fc, Fp for all the detectors. It has to be ordered
                             in the same way as noises. If None, a uniform antenna
                             pattern is applied.
    :type antenna_patterns: list of tuples, or None
    :param num_polarization_shifts: How many points to divide the range
                                    (0, 2 pi) in the polarization shift.
    :type num_polarization_shifts: int
    :param num_time_shifts: How many points to divide the range
                            (time_shift_start, time_shift_end) in the time shift.
    :type num_time_shifts: int
    :param time_shift_start: Minimum time shift applied. Search will be done
                             linearly up to time_shift_end.
    :type time_shift_start: float
    :param time_shift_end: Largest value of time shift applied.
    :type time_shift_end: float
    :param force_numba: Use numba irrespectively of the size of the input.
    :type force_numba: bool

    """

    # In kuibit, we have a beautiful collection of classes to represent
    # different data types (TimeSeries, FrequencySeries, ...).
    # However, from great abstraction come great performance penalties.
    # Using these classes is too slow for expensive operations.
    # The reason for this are (at least):
    # 1. large number of function calls (expensive in Python)
    # 2. several redundant operations
    # 3. several checks that we can guarantee will be passed
    # ...
    # Computing the mismatch is a numerical operation, so we should be able
    # to crunch numbers at the speed of light (i.e., as fast as C). For this,
    # we use numba and we break apart all our abstractions to expose only
    # the data as NumPy arrays. In this function we pre-process the
    # FrequencySeries so that we can feed _mismatch_core_numerical with
    # what we need. _mismatch_core_numerical takes only standard NumPy
    # objects (arrays, tuples, and floats) and return the mismatch and
    # the phase/time shifts needed for it.
    #
    # An important step will be to guarantee that everything (the series and
    # the noise) is defined over the same frequency range.
    #
    # What we are doing is:
    # 1. Prepare the arrays for the shifts that have to be performed

    polarization_shifts = np.linspace(0, 2 * np.pi, num_polarization_shifts)

    # We make sure that we always include the zero time shift.
    time_shifts = np.append(
        [0],
        np.linspace(time_shift_start, time_shift_end, num_time_shifts - 1),
    )

    # 2. We resample h1 and h2 to a common timeseries (linearly spaced). This
    #    guarantees that their Fourier transform will be defined over the same
    #    frequencies. To avoid throwing away signal, we resample the two series
    #    to the union of their times, setting them to zero where they were not
    #    defined, and choosing as number of points the smallest number of
    #    points between the two series.

    (smallest_len, largest_tmax, smallest_tmin) = (
        min(len(h1), len(h2)),
        max(h1.tmax, h2.tmax),
        min(h1.tmin, h2.tmin),
    )

    union_times = np.linspace(smallest_tmin, largest_tmax, smallest_len)

    # ext=1 sets zero where the series is not defined
    h1_res, h2_res = (
        h1.resampled(union_times, ext=1),
        h2.resampled(union_times, ext=1),
    )

    # 3. We take the Fourier transform of the two polarizations. In doing this,
    #    we also make sure that the arrays are complex. This is because we
    #    will be doing complex operations (e.g. phase-shift), so we will
    #    always deal with complex series. Enforcing that they are complex
    #    from the beginning makes bookkeeping easier for the Fourier
    #    transform, as the operation behaves differently for real and
    #    complex data.
    #
    #    We crop h1_res to the requested frequencies and we only take the
    #    positive ones. We will resample the noise to match h1_res.

    h1_p_res = h1_res.real()
    h1_c_res = -h1_res.imag()

    h1_p_res.y = h1_p_res.y.astype("complex128")
    h1_c_res.y = h1_c_res.y.astype("complex128")

    h1f_p_res = h1_p_res.to_FrequencySeries()
    h1f_p_res.band_pass(fmin, fmax)
    h1f_p_res.negative_frequencies_remove()

    h1f_c_res = h1_c_res.to_FrequencySeries()
    h1f_c_res.band_pass(fmin, fmax)
    h1f_c_res.negative_frequencies_remove()

    # 3. Then, we resample the noise to be defined on the same frequencies
    #    as h1f. We will only need to take care of the h2. If the noise is
    #    None, we prepare a unweighted noise (ones everywhere).
    #
    #    The problem with resampling noises is that PSD curves often have
    #    strong discontinuities, which are not correctly captured by the
    #    splines. Therefore, instead of using cubic splines, here we prefer
    #    using a piecewise constant approximation. Since the noise has
    #    typically a lot of points, this should be a better approximation than
    #    having large jumps. kuibit does not have this option, so we use
    #    SciPy's interp1d directly.

    if noises is not None:
        # With this, we can guarantee that everything has the same domain.
        # If there's a None entry, we fill it with a constant noise.
        noises_res = []
        for noise in noises:
            noises_res.append(
                fs.FrequencySeries(h1f_p_res.f, np.ones_like(h1f_p_res.fft)))
            if noise is not None:
                # TODO: Now the Series class has a function for this kind of
                #       resampling. Use that.
                #
                # We start with a FrequencySeries of ones, and we overwrite the
                # fft attribute
                noises_res[-1] = noise.resampled(h1f_p_res.f,
                                                 piecewise_constant=True)
    else:
        # Here we prepare a noise that is made by ones everywhere. This is what
        # happens internally when noises is None. However, here we do it
        # explicitly because we are going to pass it to the numba function.
        noises_res = [
            fs.FrequencySeries(h1f_p_res.f, np.ones_like(h1f_p_res.fft))
        ]

    # 4. We use the linearity of the Fourier transform to apply the antenna
    #    pattern. (This is why we have to carry around the two polarizations
    #    separately.) We have to compute tilde(h_1) * tilde(h_2).conj().
    #    But h_i = Fp h_p + Fc h_c. So, for linearity
    #    tilde(h_1) = Fp tilde(h_p) + Fc tilde(h_c). Similarly with h_2.
    #    Therefore, we have to prepare the antenna patterns for each detector.

    # This case is "we have 3 noise curves, but we don't care about the antenna
    # response". So we have to have 3 antenna patterns.
    if antenna_patterns is None:
        antenna_patterns = [(1 / 2, 1 / 2)] * len(noises_res)

    # This case is "we have N detectors, but we don't care about the actual
    # noise curve". So we have to have N noises. Before, we set noises =
    # [ones], so we duplicate that.
    #
    # If both noises and antenna_patterns are None, we will have a single
    # element in the noises list, which is what we expect.
    if noises is None:
        noises_res *= len(antenna_patterns)

    # Numba doesn't support lists, so we generate a tuple of arrays
    antenna_patterns = tuple(antenna_patterns)
    noises = tuple(n.fft for n in noises_res)

    # 5. Now, we have to prepare a frequency mask. This is an array of bools
    #    that indicates which frequencies in h2 should be used. This is because
    #    we are taking the Fourier transform in _mismatch_core_numerical, but
    #    we need to make sure that we consider only positive frequencies
    #    from fmin to fmax.

    all_frequencies = np.fft.fftfreq(len(h2_res.t), d=h2_res.dt)
    shifted_frequencies = np.fft.fftshift(all_frequencies)

    frequency_mask = np.array([f in h1f_p_res.f for f in shifted_frequencies])

    # 6. Finally we can call the numerical routine which will return the
    #    un-normalized mismatch and the shifts required. We will Fourier
    #    transform h2 in there. We must do that because we have to perform
    #    the polarization shifts in the time domain.

    frequencies = h1f_p_res.f  # from fmin to fmax

    use_numba = (force_numba
                 or num_polarization_shifts * num_time_shifts >= 500 * 500)

    if use_numba and "njit" not in globals():
        if force_numba:
            warn("numba not available, ignoring force_numba")
        use_numba = False

    if use_numba:
        globals()["objmode"] = numba_objmode
        _core_function = njit(_mismatch_core_numerical)
    else:
        # HACK: Now we have to do something dirty. _mismatch_core_numerical
        #       calls numba.objmode to perform FFTs, but when numba is not
        #       available, objmode is unknown. Hence, we have to provide a
        #       dummy objmode that does nothing. As long as numba doesn't
        #       support FFTs natively, this code has to be here. However, we
        #       cannot put it in _mismatch_core_numerical because numba would
        #       not be able to compile the function.
        @contextmanager
        def nullcontext(*args, **kwargs):
            yield None

        # We override objmode in the global scope with nullcontext
        globals()["objmode"] = nullcontext

        _core_function = _mismatch_core_numerical

    (unnormalized_max_overlap, index_max) = _core_function(
        h1f_c_res.fft,
        h1f_p_res.fft,
        h2_res.y,
        h2_res.dt,
        frequencies,
        frequency_mask,
        noises,
        antenna_patterns,
        polarization_shifts,
        time_shifts,
    )

    # 12. The normalization is constant. Again, we do not include df or the
    #     factor of 4.

    h2_p_res = h2_res.real()
    h2_c_res = -h2_res.imag()

    # Transform to complex
    h2_p_res.y = h2_p_res.y.astype("complex128")
    h2_c_res.y = h2_c_res.y.astype("complex128")

    h2f_p_res = h2_p_res.to_FrequencySeries()
    h2f_p_res.band_pass(fmin, fmax)
    h2f_p_res.negative_frequencies_remove()

    h2f_c_res = h2_c_res.to_FrequencySeries()
    h2f_c_res.band_pass(fmin, fmax)
    h2f_c_res.negative_frequencies_remove()

    inner11 = fs.FrequencySeries(h1f_p_res.f, np.zeros_like(h1f_p_res.f))
    inner22 = fs.FrequencySeries(h2f_p_res.f, np.zeros_like(h2f_p_res.f))

    for noise, antenna_pattern in zip(noises_res, antenna_patterns):

        Fc, Fp = antenna_pattern

        numerator11 = Fp * h1f_p_res + Fc * h1f_c_res
        numerator11 *= (Fp * h1f_p_res + Fc * h1f_c_res).conjugate()

        inner11 += numerator11 / noise

        numerator22 = Fp * h2f_p_res + Fc * h2f_c_res
        numerator22 *= (Fp * h2f_p_res + Fc * h2f_c_res).conjugate()

        inner22 += numerator22 / noise

    inner11 = np.sum(inner11.fft).real
    inner22 = np.sum(inner22.fft).real

    norm = np.sqrt(inner11 * inner22)

    # Values that maximise the overlap

    # pylint: disable=unbalanced-tuple-unpacking
    (p_index,
     t_index) = np.unravel_index(index_max,
                                 (num_polarization_shifts, num_time_shifts))

    # Check whether t_index is close to the boundary and emit a warning.
    # We have to check for t_index = 0 because we always put tshift=0 there
    if (not 0.05 < t_index / num_time_shifts < 0.95) and t_index != 0:
        warn("Maximum of overlap near the boundary of the time shift interval")

    p_shift_max = polarization_shifts[p_index]
    t_shift_max = time_shifts[t_index]

    return 1 - unnormalized_max_overlap / norm, (
        p_shift_max,
        t_shift_max,
    )
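An illustrative call (hypothetical data; kuibit's TimeSeries is assumed importable, as in the module this function comes from):

# from kuibit.timeseries import TimeSeries
# t = np.linspace(0, 20, 4096)
# h1 = TimeSeries(t, np.exp(-(t - 10) ** 2) * np.exp(20j * t))
# h2 = TimeSeries(t, np.exp(-(t - 10) ** 2) * np.exp(20j * t + 0.1j))
# mism, (p_shift, t_shift) = mismatch_from_strains(
#     h1, h2, fmin=0.1, fmax=10, num_polarization_shifts=50, num_time_shifts=50
# )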
Example #54
from __future__ import print_function, absolute_import, division

import sys
import numpy as np
import threading
import random

from numba import unittest_support as unittest
from numba import njit
from numba import utils
from numba.numpy_support import version as numpy_version
from .support import MemoryLeakMixin, TestCase

nrtjit = njit(_nrt=True, nogil=True)


class BaseTest(TestCase):
    def check_outputs(self, pyfunc, argslist, exact=True):
        cfunc = nrtjit(pyfunc)
        for args in argslist:
            expected = pyfunc(*args)
            ret = cfunc(*args)
            self.assertEqual(ret.size, expected.size)
            self.assertEqual(ret.dtype, expected.dtype)
            self.assertStridesEqual(ret, expected)
            if exact:
                np.testing.assert_equal(expected, ret)
            else:
                np.testing.assert_allclose(expected, ret)

Example #55
 def jit_compile(self, fn):
     jitted_fn = numba.njit(fn)
     return jitted_fn
Example #56
    def __init__(self,
                 *,
                 x_grid_dim,
                 x_amplitude,
                 v,
                 k,
                 dt,
                 g,
                 epsilon=1e-2,
                 diff_k=None,
                 diff_v=None,
                 t=0,
                 abs_boundary=1.,
                 fftw_wisdom_fname='fftw.wisdom',
                 **kwargs):
        """
        :param x_grid_dim: the grid size
        :param x_amplitude: the maximum value of the coordinates
        :param v: the potential energy (as a function)
        :param k: the kinetic energy (as a function)
        :param diff_k: the derivative of the kinetic energy for the Ehrenfest theorem calculations
        :param diff_v: the derivative of the potential energy for the Ehrenfest theorem calculations
        :param t: initial value of time
        :param dt: initial time increment
        :param g: the coupling constant
        :param epsilon: relative error tolerance
        :param abs_boundary: absorbing boundary
        :param fftw_wisdom_fname: File name from where the FFT wisdom will be loaded from and saved to
        :param kwargs: ignored
        """

        # saving the properties
        self.x_grid_dim = x_grid_dim
        self.x_amplitude = x_amplitude
        self.v = v
        self.k = k
        self.diff_v = diff_v
        self.t = t
        self.dt = dt
        self.g = g
        self.epsilon = epsilon
        self.abs_boundary = abs_boundary

        ####################################################################################################
        #
        #   Initialize Fourier transform for efficient calculations
        #
        ####################################################################################################

        # Load the FFTW wisdom
        try:
            with open(fftw_wisdom_fname, 'rb') as fftw_wisdow:
                pyfftw.import_wisdom(pickle.load(fftw_wisdow))
        except (FileNotFoundError, EOFError):
            pass

        # allocate the array for wave function
        self.wavefunction = pyfftw.empty_aligned(x_grid_dim, dtype=complex)

        # allocate an extra copy for the wavefunction necessary for adaptive time step propagation
        self.wavefunction_next = pyfftw.empty_aligned(self.x_grid_dim,
                                                      dtype=complex)

        # allocate the array for wave function in momentum representation
        self.wavefunction_next_p = pyfftw.empty_aligned(x_grid_dim,
                                                        dtype=complex)

        # allocate the array for calculating the momentum representation for the energy evaluation
        self.wavefunction_next_p_ = pyfftw.empty_aligned(x_grid_dim,
                                                         dtype=complex)

        # parameters for FFT
        self.fft_params = {
            "flags": ('FFTW_MEASURE', 'FFTW_DESTROY_INPUT'),
            "threads": cpu_count(),  #Removed cpu_count from here
            "planning_timelimit": 60,
        }

        # FFT
        self.fft = pyfftw.FFTW(self.wavefunction_next,
                               self.wavefunction_next_p, **self.fft_params)

        # iFFT
        self.ifft = pyfftw.FFTW(self.wavefunction_next_p,
                                self.wavefunction_next,
                                direction='FFTW_BACKWARD',
                                **self.fft_params)

        # fft for momentum representation
        self.fft_p = pyfftw.FFTW(self.wavefunction_next_p,
                                 self.wavefunction_next_p_, **self.fft_params)

        # Save the FFTW wisdom
        with open(fftw_wisdom_fname, 'wb') as fftw_wisdow:
            pickle.dump(pyfftw.export_wisdom(), fftw_wisdow)

        ####################################################################################################
        #
        #   Initialize grids
        #
        ####################################################################################################

        # Check that all attributes were specified
        # make sure self.x_grid_dim is a power of 2
        assert 2 ** int(np.log2(self.x_grid_dim)) == self.x_grid_dim, \
            "A value of the grid size (x_grid_dim) must be a power of 2"

        # get coordinate step size
        dx = self.dx = 2. * self.x_amplitude / self.x_grid_dim

        # generate coordinate range
        x = self.x = (np.arange(self.x_grid_dim) -
                      self.x_grid_dim / 2) * self.dx
        # The same as
        # self.x = np.linspace(-self.x_amplitude, self.x_amplitude - self.dx , self.x_grid_dim)

        # generate momentum range as it corresponds to FFT frequencies
        p = self.p = (np.arange(self.x_grid_dim) -
                      self.x_grid_dim / 2) * (np.pi / self.x_amplitude)

        # the relative change estimators for the time adaptive scheme
        self.e_n = self.e_n_1 = self.e_n_2 = 0

        self.previous_dt = 0

        # list of self.dt to monitor how the adaptive step method is working
        self.time_increments = []

        ####################################################################################################
        #
        # Codes for efficient evaluation
        #
        ####################################################################################################

        # Decide whether the potential depends on time
        try:
            v(x, 0)
            time_independent_v = False
        except TypeError:
            time_independent_v = True

        # Decide whether the kinetic energy depends on time
        try:
            k(p, 0)
            time_independent_k = False
        except TypeError:
            time_independent_k = True

        # pre-calculate the absorbing potential and the sequence of alternating signs

        abs_boundary = (abs_boundary if isinstance(abs_boundary,
                                                   (float,
                                                    int)) else abs_boundary(x))
        abs_boundary = (-1)**np.arange(self.wavefunction.size) * abs_boundary

        # Cache the potential if it does not depend on time
        if time_independent_v:
            pre_calculated_v = v(x)  # evaluate once since v does not depend on time
            v = njit(lambda _, __: pre_calculated_v)

        # Cache the kinetic energy if it does not depend on time
        if time_independent_k:
            pre_calculated_k = k(p)  # evaluate once since k does not depend on time
            k = njit(lambda _, __: pre_calculated_k)

        @njit  # (parallel=True)
        def expV(wavefunction, t, dt):
            """
            function to efficiently evaluate
                wavefunction *= (-1) ** k * exp(-0.5j * dt * v)
            """
            wavefunction *= abs_boundary * np.exp(
                -0.5j * dt *
                (v(x, t + 0.5 * dt) + g * np.abs(wavefunction)**2))
            wavefunction /= linalg.norm(wavefunction) * np.sqrt(dx)

        self.expV = expV

        @njit  # (parallel=True)
        def expK(wavefunction, t, dt):
            """
            function to efficiently evaluate
                wavefunction *= exp(-1j * dt * k)
            """
            wavefunction *= np.exp(-1j * dt * k(p, t + 0.5 * dt))

        self.expK = expK
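
        # A single second-order (Strang) propagation step would chain these as
        # (a sketch; the actual time-stepping loop lives elsewhere in this class):
        #
        #     expV(wavefunction, t, dt)    # half potential step (note the -0.5j * dt)
        #     fft()                        # to the momentum representation
        #     expK(wavefunction_p, t, dt)  # full kinetic step
        #     ifft()                       # back to the coordinate representation
        #     expV(wavefunction, t, dt)    # second half potential step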

        ####################################################################################################

        # Check whether the necessary terms are specified to calculate the first-order Ehrenfest theorems
        if diff_k and diff_v:

            # Cache the potential if it does not depend on time
            if time_independent_v:
                pre_calculated_diff_v = diff_v(x)
                diff_v = njit(lambda _, __: pre_calculated_diff_v)

            # Cache the kinetic energy if it does not depend on time
            if time_independent_k:
                pre_calculated_diff_k = diff_k(p)
                diff_k = njit(lambda _, __: pre_calculated_diff_k)

            # Get codes for efficiently calculating the Ehrenfest relations

            @njit
            def get_p_average_rhs(density, t):
                return np.sum(density * diff_v(x, t))

            self.get_p_average_rhs = get_p_average_rhs

            # The code above is equivalent to
            #self.get_p_average_rhs = njit(lambda density, t: np.sum(density * diff_v(x, t)))

            @njit
            def get_v_average(density, t):
                return np.sum((v(x, t) + 0.5 * g * density / dx) * density)

            self.get_v_average = get_v_average

            @njit
            def get_x_average(density):
                return np.sum(x * density)

            self.get_x_average = get_x_average

            @njit
            def get_x_average_rhs(density, t):
                return np.sum(diff_k(p, t) * density)

            self.get_x_average_rhs = get_x_average_rhs

            @njit
            def get_k_average(density, t):
                return np.sum(k(p, t) * density)

            self.get_k_average = get_k_average

            @njit
            def get_p_average(density):
                return np.sum(p * density)

            self.get_p_average = get_p_average

            # since a variable-time-step propagator is used, record the times at which expectation values are calculated
            self.times = []

            # lists where the expectation values of x and p will be stored
            self.x_average = []
            self.p_average = []

            # lists where the right-hand sides of the Ehrenfest theorems for x and p will be stored
            self.x_average_rhs = []
            self.p_average_rhs = []

            # list where the expectation value of the Hamiltonian will be stored
            self.hamiltonian_average = []

            # sequence of alternating signs for getting the wavefunction in the momentum representation
            self.minus = (-1)**np.arange(self.x_grid_dim)

            # flag indicating that the Ehrenfest theorem calculations are requested
            self.is_ehrenfest = True
        else:
            # Since diff_v and diff_k are not specified, we are not going to evaluate the Ehrenfest relations
            self.is_ehrenfest = False
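
The recorded lists make it possible to verify the first-order Ehrenfest theorems after propagation, e.g. d<x>/dt = <dK/dp>. A minimal sketch (check_ehrenfest is a name introduced here for illustration; it assumes the lists were filled during the propagation loop):

import numpy as np

def check_ehrenfest(times, average, average_rhs, rtol=1e-2):
    # compare the numerical derivative d<A>/dt of a recorded expectation
    # value with the recorded right-hand side of the corresponding theorem
    t = np.asarray(times)
    lhs = np.gradient(np.asarray(average), t)
    return np.allclose(lhs, np.asarray(average_rhs), rtol=rtol)

# e.g. check_ehrenfest(sys.times, sys.x_average, sys.x_average_rhs)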
Example #57
    def add_indices_to_kernel(self, kernel, index_names, ndim, neighborhood,
                              standard_indexed, typemap, calltypes):
        """
        Transforms the stencil kernel as specified by the user into one
        that includes each dimension's index variable as part of the getitem
        calls.  So, in effect array[-1] becomes array[index0-1].
        """
        const_dict = {}
        kernel_consts = []

        if config.DEBUG_ARRAY_OPT >= 1:
            print("add_indices_to_kernel", ndim, neighborhood)
            ir_utils.dump_blocks(kernel.blocks)

        if neighborhood is None:
            need_to_calc_kernel = True
        else:
            need_to_calc_kernel = False
            if len(neighborhood) != ndim:
                raise ValueError("%d dimensional neighborhood specified for %d " \
                    "dimensional input array" % (len(neighborhood), ndim))

        tuple_table = ir_utils.get_tuple_table(kernel.blocks)

        relatively_indexed = set()

        for block in kernel.blocks.values():
            scope = block.scope
            loc = block.loc
            new_body = []
            for stmt in block.body:
                if (isinstance(stmt, ir.Assign)
                        and isinstance(stmt.value, ir.Const)):
                    if config.DEBUG_ARRAY_OPT >= 1:
                        print("remembering in const_dict", stmt.target.name,
                              stmt.value.value)
                    # Remember consts for use later.
                    const_dict[stmt.target.name] = stmt.value.value
                if ((isinstance(stmt, ir.Assign)
                     and isinstance(stmt.value, ir.Expr)
                     and stmt.value.op in ['setitem', 'static_setitem']
                     and stmt.value.value.name in kernel.arg_names)
                        or (isinstance(stmt, ir.SetItem)
                            and stmt.target.name in kernel.arg_names)):
                    raise ValueError("Assignments to arrays passed to stencil " \
                        "kernels is not allowed.")
                if (isinstance(stmt, ir.Assign)
                        and isinstance(stmt.value, ir.Expr)
                        and stmt.value.op in ['getitem', 'static_getitem']
                        and stmt.value.value.name in kernel.arg_names
                        and stmt.value.value.name not in standard_indexed):
                    # We found a getitem from the input array.
                    if stmt.value.op == 'getitem':
                        stmt_index_var = stmt.value.index
                    else:
                        stmt_index_var = stmt.value.index_var
                        # allow static_getitem since rewrite passes are applied
                        #raise ValueError("Unexpected static_getitem in add_indices_to_kernel.")

                    relatively_indexed.add(stmt.value.value.name)

                    # Store the index used after looking up the variable in
                    # the const dictionary.
                    if need_to_calc_kernel:
                        assert hasattr(stmt_index_var, 'name')

                        if stmt_index_var.name in tuple_table:
                            kernel_consts += [tuple_table[stmt_index_var.name]]
                        elif stmt_index_var.name in const_dict:
                            kernel_consts += [const_dict[stmt_index_var.name]]
                        else:
                            raise ValueError(
                                "stencil kernel index is not "
                                "constant, 'neighborhood' option required")

                    if ndim == 1:
                        # Single dimension always has index variable 'index0'.
                        # tmpvar will hold the real index and is computed by
                        # adding the relative offset in stmt.value.index to
                        # the current absolute location in index0.
                        index_var = ir.Var(scope, index_names[0], loc)
                        tmpname = ir_utils.mk_unique_var("stencil_index")
                        tmpvar = ir.Var(scope, tmpname, loc)
                        stmt_index_var_typ = typemap[stmt_index_var.name]
                        # If the array is indexed with a slice then we
                        # have to add the index value with a call to
                        # slice_addition.
                        if isinstance(stmt_index_var_typ,
                                      types.misc.SliceType):
                            sa_var = ir.Var(
                                scope,
                                ir_utils.mk_unique_var("slice_addition"), loc)
                            sa_func = numba.njit(slice_addition)
                            sa_func_typ = types.functions.Dispatcher(sa_func)
                            typemap[sa_var.name] = sa_func_typ
                            g_sa = ir.Global("slice_addition", sa_func, loc)
                            new_body.append(ir.Assign(g_sa, sa_var, loc))
                            slice_addition_call = ir.Expr.call(
                                sa_var, [stmt_index_var, index_var], (), loc)
                            calltypes[
                                slice_addition_call] = sa_func_typ.get_call_type(
                                    self._typingctx,
                                    [stmt_index_var_typ, types.intp], {})
                            new_body.append(
                                ir.Assign(slice_addition_call, tmpvar, loc))
                            new_body.append(
                                ir.Assign(
                                    ir.Expr.getitem(stmt.value.value, tmpvar,
                                                    loc), stmt.target, loc))
                        else:
                            acc_call = ir.Expr.binop(operator.add,
                                                     stmt_index_var, index_var,
                                                     loc)
                            new_body.append(ir.Assign(acc_call, tmpvar, loc))
                            new_body.append(
                                ir.Assign(
                                    ir.Expr.getitem(stmt.value.value, tmpvar,
                                                    loc), stmt.target, loc))
                    else:
                        index_vars = []
                        sum_results = []
                        s_index_name = ir_utils.mk_unique_var("stencil_index")
                        s_index_var = ir.Var(scope, s_index_name, loc)
                        const_index_vars = []
                        ind_stencils = []

                        stmt_index_var_typ = typemap[stmt_index_var.name]
                        # Same idea as above but you have to extract
                        # individual elements out of the tuple indexing
                        # expression and add the corresponding index variable
                        # to them and then reconstitute as a tuple that can
                        # index the array.
                        for dim in range(ndim):
                            tmpname = ir_utils.mk_unique_var("const_index")
                            tmpvar = ir.Var(scope, tmpname, loc)
                            new_body.append(
                                ir.Assign(ir.Const(dim, loc), tmpvar, loc))
                            const_index_vars += [tmpvar]
                            index_var = ir.Var(scope, index_names[dim], loc)
                            index_vars += [index_var]

                            tmpname = ir_utils.mk_unique_var(
                                "ind_stencil_index")
                            tmpvar = ir.Var(scope, tmpname, loc)
                            ind_stencils += [tmpvar]
                            getitemname = ir_utils.mk_unique_var("getitem")
                            getitemvar = ir.Var(scope, getitemname, loc)
                            getitemcall = ir.Expr.getitem(
                                stmt_index_var, const_index_vars[dim], loc)
                            new_body.append(
                                ir.Assign(getitemcall, getitemvar, loc))
                            # Get the type of this particular part of the index tuple.
                            one_index_typ = stmt_index_var_typ[dim]
                            # If the array is indexed with a slice then we
                            # have to add the index value with a call to
                            # slice_addition.
                            if isinstance(one_index_typ, types.misc.SliceType):
                                sa_var = ir.Var(
                                    scope,
                                    ir_utils.mk_unique_var("slice_addition"),
                                    loc)
                                sa_func = numba.njit(slice_addition)
                                sa_func_typ = types.functions.Dispatcher(
                                    sa_func)
                                typemap[sa_var.name] = sa_func_typ
                                g_sa = ir.Global("slice_addition", sa_func,
                                                 loc)
                                new_body.append(ir.Assign(g_sa, sa_var, loc))
                                slice_addition_call = ir.Expr.call(
                                    sa_var, [getitemvar, index_vars[dim]], (),
                                    loc)
                                calltypes[
                                    slice_addition_call] = sa_func_typ.get_call_type(
                                        self._typingctx,
                                        [one_index_typ, types.intp], {})
                                new_body.append(
                                    ir.Assign(slice_addition_call, tmpvar,
                                              loc))
                            else:
                                acc_call = ir.Expr.binop(
                                    operator.add, getitemvar, index_vars[dim],
                                    loc)
                                new_body.append(
                                    ir.Assign(acc_call, tmpvar, loc))

                        tuple_call = ir.Expr.build_tuple(ind_stencils, loc)
                        new_body.append(ir.Assign(tuple_call, s_index_var,
                                                  loc))
                        new_body.append(
                            ir.Assign(
                                ir.Expr.getitem(stmt.value.value, s_index_var,
                                                loc), stmt.target, loc))
                else:
                    new_body.append(stmt)
            block.body = new_body

        if need_to_calc_kernel:
            # Find the size of the kernel by finding the maximum absolute value
            # index used in the kernel specification.
            neighborhood = [[0, 0] for _ in range(ndim)]
            if len(kernel_consts) == 0:
                raise ValueError("Stencil kernel with no accesses to "
                                 "relatively indexed arrays.")

            for index in kernel_consts:
                if isinstance(index, tuple) or isinstance(index, list):
                    for i in range(len(index)):
                        te = index[i]
                        if isinstance(te, ir.Var) and te.name in const_dict:
                            te = const_dict[te.name]
                        if isinstance(te, int):
                            neighborhood[i][0] = min(neighborhood[i][0], te)
                            neighborhood[i][1] = max(neighborhood[i][1], te)
                        else:
                            raise ValueError(
                                "stencil kernel index is not constant,"
                                "'neighborhood' option required")
                    index_len = len(index)
                elif isinstance(index, int):
                    neighborhood[0][0] = min(neighborhood[0][0], index)
                    neighborhood[0][1] = max(neighborhood[0][1], index)
                    index_len = 1
                else:
                    raise ValueError(
                        "Non-tuple or non-integer used as stencil index.")
                if index_len != ndim:
                    raise ValueError(
                        "Stencil index does not match array dimensionality.")

        return (neighborhood, relatively_indexed)
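
For reference, here is a minimal sketch of the user-facing feature this pass implements (a hypothetical 1D kernel; the neighborhood option would only be required if the relative indices were not compile-time constants):

import numpy as np
from numba import njit, stencil

@stencil(neighborhood=((-1, 1),))
def avg3(a):
    # the relative indices a[-1], a[0], a[1] are rewritten internally
    # to a[index0 - 1], a[index0], a[index0 + 1] by add_indices_to_kernel
    return (a[-1] + a[0] + a[1]) / 3

@njit
def smooth(a):
    return avg3(a)

print(smooth(np.arange(10.0)))  # boundary elements are left at 0.0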
Example #58
def foo(x):
    with objmode_context(y='int64[:]'):
        y = njit(bar)(x).astype('int64')
    return x + y
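
A runnable version of this snippet with the assumed surroundings made explicit (bar and the input array are placeholders; objmode_context is taken to be Numba's object-mode context manager, exposed publicly as numba.objmode):

import numpy as np
from numba import njit
from numba import objmode as objmode_context

def bar(x):
    return x * 2

@njit
def foo(x):
    with objmode_context(y='int64[:]'):  # y is materialized back as an int64 array
        y = njit(bar)(x).astype('int64')
    return x + y

print(foo(np.arange(3, dtype=np.int64)))  # [0 3 6]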
Example #59
def filters(
    hits,
    VL,
    tauV,
    tauL,
    tres,
    midpoints=1,
    which=['sample mode', 'cross correlation'],
    pbar_batch=None,
    VLER=0.3,
    VLNR=3,
    tcoinc=200,
    dcr=None,
    nph=None,
    sigma=None,
):
    """
    Run filters on hit times.
    
    Parameters
    ----------
    hits : array (nevents, nhits)
        The hit times. Need not be sorted.
    VL : scalar
        VL p_S1_gauss parameter for the filters 'cross correlation',
        'likelihood', 'sample mode cross correlation'.
    tauV, tauL, tres : scalar
        p_S1_gauss parameters for the filters 'cross correlation', 'ER', 'NR',
        'fast', 'slow', 'likelihood', 'sample mode cross correlation'.
    midpoints : int
        The continuous filters are computed on the hits times and on
        `midpoints` evenly spaced intermediate points between each hit, apart
        from the 'coinc*' filters which are computed only on hits.
    which : list of strings
        The filters to compute. Keywords:
            'cross correlation'
            'ER'
            'NR'
            'fast'
            'slow'
            'coinc'
            'coincX' where X is the time in nanoseconds
            'likelihood'
            'sample mode'
            'sample mode cross correlation'
    pbar_batch : int, optional
        If given, a progressbar is shown that ticks every `pbar_batch` events.
    VLER, VLNR : scalar
        VL p_S1_gauss parameter for the ER and NR filters.
    tcoinc : scalar
        Time for the 'coinc' filter.
    dcr, nph, sigma : scalar
        Parameters for the 'likelihood' filter.
    
    Returns
    -------
    out : array (nevents,)
        Structured numpy array with each field corresponding to a filter.
        The field values are themselves structured with fields 'time' and
        'value' containing arrays with the filter output.
    """
    hits = np.asarray(hits)
    assert len(hits.shape) == 2
    
    nt = (hits.shape[1] - 1) * (midpoints + 1) + 1
    def filtlength(f):
        if f == 'sample mode':
            return hits.shape[1] - 1
        elif f.startswith('coinc'):
            return hits.shape[1]
        else:
            return nt
    
    out = np.empty(len(hits), dtype=[
        (filter_name, [
            ('time', float, (filtlength(filter_name),)),
            ('value', float, (filtlength(filter_name),))
        ]) for filter_name in which
    ])
    
    all_hits = hits
    all_out = out
    
    template = dict()
    
    smcc = 'sample mode cross correlation'
    havesmcc = smcc in which
    
    sm = 'sample mode'
    havesm = sm in which

    for vl, f in [(VL, smcc), (VL, 'cross correlation'), (VLER, 'ER'), (VLNR, 'NR')]:
        if f not in which:
            continue
        offset = pS1.p_S1_gauss_maximum(vl, tauV, tauL, tres)
        ampl = pS1.p_S1_gauss(offset, vl, tauV, tauL, tres)
        fun = lambda t: pS1.p_S1_gauss(t + offset, vl, tauV, tauL, tres) / ampl
        template[f] = (numba.njit('f8(f8)')(fun), -5 * tres, 10 * tauL)
    
    for tau, f in [(tauV, 'fast'), (tauL, 'slow')]:
        if f not in which:
            continue
        offset = pS1.p_exp_gauss_maximum(tau, tres)
        ampl = pS1.p_exp_gauss(offset, tau, tres)
        fun = lambda t: pS1.p_exp_gauss(t + offset, tau, tres) / ampl
        template[f] = (numba.njit('f8(f8)')(fun), -5 * tres, max(10 * tau, 5 * tres))
    
    coincbounds = {}
    for f in which:
        if f.startswith('coinc'):
            T = float(f[5:]) if len(f) > 5 else tcoinc
            eps = 1e-6
            coincbounds[f] = (-eps * T, (1 - eps) * T)
    if coincbounds:
        templcoinc = numba.njit('f8(f8)')(lambda t: 1.0)
    
    if 'likelihood' in which:
        offset = pS1.p_S1_gauss_maximum(VL, tauV, tauL, sigma)
        ampl = pS1.log_likelihood(offset, VL, tauV, tauL, sigma, dcr, nph)
        fun = lambda t: pS1.log_likelihood(t + offset, VL, tauV, tauL, sigma, dcr, nph) / ampl
        template['likelihood'] = (numba.njit('f8(f8)')(fun), -5 * sigma, 10 * tauL)

    def batch(s):
        hits = all_hits[s]
        out = all_out[s]
        
        hits = np.sort(hits, axis=-1)
    
        timevalue = {}
    
        if havesm:
            v = filter_sample_mode(hits)
            t = (hits[:, 1:] + hits[:, :-1]) / 2
            timevalue[sm] = (t, v)
        
        if havesmcc or template:
            t = addmidpoints(hits, midpoints)
    
        for filt, (fun, left, right) in template.items():
            v = filter_cross_correlation(hits, t, fun, left, right)
            timevalue[filt] = (t, v)
        
        if havesmcc:
            fun, left, right = template[smcc]
            v = filter_sample_mode_cross_correlation(hits, t, fun, left, right)
            timevalue[smcc] = (t, v)
    
        for filt, (left, right) in coincbounds.items():
            v = filter_cross_correlation(hits, hits, templcoinc, left, right)
            timevalue[filt] = (hits, v)
        
        for k, (t, v) in timevalue.items():
            out[k]['time'] = t
            out[k]['value'] = v
    
    runsliced.runsliced(batch, len(out), pbar_batch)
    
    return out
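
A hypothetical invocation (all parameter values are invented for illustration, and the pS1 and runsliced modules must be importable):

import numpy as np

rng = np.random.default_rng(0)
hits = rng.exponential(scale=100, size=(1000, 50))  # hit times need not be sorted

out = filters(hits, VL=0.3, tauV=7, tauL=1600, tres=10,
              which=['sample mode', 'fast', 'coinc'])
print(out['fast']['time'].shape, out['fast']['value'].shape)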
Example #60

class CPUUseCase(UseCase):
    def _call(self, ret, *args):
        self._func(ret, *args)


# Using the same function as a cached CPU and CUDA-jitted function


def target_shared_assign(r, x):
    r[()] = x[()]


assign_cuda_kernel = cuda.jit(cache=True)(target_shared_assign)
assign_cuda = CUDAUseCase(assign_cuda_kernel)
assign_cpu_jitted = njit(cache=True)(target_shared_assign)
assign_cpu = CPUUseCase(assign_cpu_jitted)
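
# With cache=True, the compiled machine code is written next to this module
# (an index/data pair of *.nbi/*.nbc files under __pycache__) and reloaded on
# later imports, which is what _TestModule below exercises for both targets.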


class _TestModule(CUDATestCase):
    """
    Tests for functionality of this module's functions.
    Note this does not define any "test_*" method, instead check_module()
    should be called by hand.
    """
    def check_module(self, mod):
        self.assertPreciseEqual(mod.assign_cpu(5), 5)
        self.assertPreciseEqual(mod.assign_cpu(5.5), 5.5)
        self.assertPreciseEqual(mod.assign_cuda(5), 5)
        self.assertPreciseEqual(mod.assign_cuda(5.5), 5.5)