def test_jit_allocate_out_1arg_values():
    """Test the numerical results of jit_allocate_out with a 1-argument function.

    Renamed from ``test_jit_allocate_out_1arg``: a later function with that exact
    name redefined (and therefore silently shadowed) this one, so pytest never
    collected it. This variant checks only the returned values, not JIT_COUNT.
    """

    def f(arr, out=None, args=None):
        out[:] = arr
        return out

    a = np.linspace(0, 1, 3)

    # Explicit output shape.
    g = jit_allocate_out(out_shape=a.shape)(f)
    np.testing.assert_equal(g(a), a)

    # Bare-decorator form (no keyword arguments).
    np.testing.assert_equal(jit_allocate_out(f)(a), a)
def test_jit_allocate_out_2arg_values():
    """Test the numerical results of jit_allocate_out with a 2-argument function.

    Renamed from ``test_jit_allocate_out_2arg``: a later function with that exact
    name redefined (and therefore silently shadowed) this one, so pytest never
    collected it. This variant checks only the returned values, not JIT_COUNT.
    """

    def f(a, b, out=None, args=None):
        out[:] = a + b
        return out

    a = np.linspace(0, 1, 3)  # [0, 0.5, 1]
    b = np.linspace(1, 2, 3)  # [1, 1.5, 2]
    c = np.linspace(1, 3, 3)  # [1, 2, 3] == a + b exactly in floating point

    # Explicit output shape.
    g = jit_allocate_out(out_shape=a.shape, num_args=2)(f)
    np.testing.assert_equal(g(a, b), c)

    # Output shape left to the decorator.
    np.testing.assert_equal(jit_allocate_out(num_args=2)(f)(a, b), c)
def test_jit_allocate_out_1arg():
    """Exercise jit_allocate_out on a function taking a single array argument.

    Both the explicit ``out_shape`` form and the bare-decorator form must return
    a copy of the input. After exercising both wrappers, JIT_COUNT is expected
    to have grown by 2, or to be unchanged when numba's JIT is disabled.
    """

    def f(arr, out=None, args=None):
        out[:] = arr
        return out

    count_before = int(JIT_COUNT)
    data = np.linspace(0, 1, 3)

    with_shape = jit_allocate_out(out_shape=data.shape)(f)
    np.testing.assert_equal(with_shape(data), data)

    bare = jit_allocate_out(f)
    np.testing.assert_equal(bare(data), data)

    # With DISABLE_JIT set, the counter is expected to stay where it was.
    expected_increase = 0 if nb.config.DISABLE_JIT else 2
    assert int(JIT_COUNT) == count_before + expected_increase
def test_jit_allocate_out_2arg():
    """Test jit_allocate_out of functions with 2 arguments.

    Both the explicit ``out_shape`` form and the form without it must return the
    element-wise sum. After exercising both wrappers, JIT_COUNT is expected to
    have grown by 2, or to be unchanged when numba's JIT is disabled.
    """

    def f(a, b, out=None, args=None):
        out[:] = a + b
        return out

    jit_count = int(JIT_COUNT)
    a = np.linspace(0, 1, 3)  # [0, 0.5, 1]
    b = np.linspace(1, 2, 3)  # [1, 1.5, 2]
    c = np.linspace(1, 3, 3)  # [1, 2, 3] == a + b exactly in floating point

    g = jit_allocate_out(out_shape=a.shape, num_args=2)(f)
    np.testing.assert_equal(g(a, b), c)
    np.testing.assert_equal(jit_allocate_out(num_args=2)(f)(a, b), c)

    if nb.config.DISABLE_JIT:
        assert int(JIT_COUNT) == jit_count
    else:
        assert int(JIT_COUNT) == jit_count + 2