def test_branch_elimination(self):
    """Branches on compile-time constants leave no ``icmp`` in the IR.

    ``f1`` tests a runtime parameter, so its dumped IR is expected to
    contain a comparison; ``f2`` tests only enclosing-scope constants,
    so its dumped IR is expected to contain none.
    """
    from nitrous.module import dump

    add_5 = False
    add_any = True

    @function(Long, a=Long, b=Bool)
    def f1(a, b):
        if add_any and b:
            a += 5
        return a

    @function(Long, a=Long)
    def f2(a):
        if add_any and add_5:
            a += 5
        return a

    def dump_on_one_line(compiled):
        # Join the dumped lines with spaces so the search sees one string.
        return " ".join(dump(compiled).split("\n"))

    # The conditional depends on parameter ``b``; a comparison must remain.
    runtime_ir = dump_on_one_line(module([f1]))
    self.assertRegexpMatches(runtime_ir, "icmp")

    # The whole conditional is resolved at compile time and optimized away.
    constant_ir = dump_on_one_line(module([f2]))
    self.assertNotRegexpMatches(constant_ir, "icmp")
def test_type_mismatch(self):
    """Comparing a Long with a float literal raises TypeError."""

    @function(Bool, x=Long)
    def f1(x):
        return x < 1.0

    # The error text is expected to quote the offending source line.
    offending_line = ">>> return x < 1.0"
    with self.assertRaisesRegexp(TypeError, offending_line):
        module([f1])
def test_assign_wrong_type(self):
    """Storing a Double into a Long array element raises TypeError."""
    # NOTE(review): another test with this exact name appears later in
    # this file; if both live in the same class, the later definition
    # silently replaces this one -- confirm.

    @function(x=Array(Long, shape=(1, )), y=Double)
    def f(x, y):
        x[0] = y

    # Raw string: the backslashes are regex escapes for the literal
    # brackets, not Python string escapes.
    message = r">>> x\[0\] = y"
    with self.assertRaisesRegexp(TypeError, message):
        module([f])
def test_not_iterable(self):
    """Unpacking a scalar Long value raises TypeError."""

    @function(a=Long)
    def foo(a):
        b, = a

    expected = "Value of type 'Long' is not an iterable"
    with self.assertRaisesRegexp(TypeError, expected):
        module([foo])
def test_return_non_void(self):
    """Returning a value from a void function raises ValueError."""

    @function()
    def f():
        return 5

    # The error text is expected to quote the offending source line.
    offending_line = ">>> return 5"
    with self.assertRaisesRegexp(ValueError, offending_line):
        module([f])
def test_missing_return(self):
    """A non-void function whose body never returns raises TypeError."""

    @function(Double)
    def f():
        pass

    # The error is expected to point at the last statement of the body.
    offending_line = ">>> pass"
    with self.assertRaisesRegexp(TypeError, offending_line):
        module([f])
def test_unexpected_type(self):
    """Returning an integer literal from a Double function is an error."""

    @function(Double, x=Double)
    def f(x):
        return 5

    # The TypeError text is expected to quote the offending source line.
    offending_line = ">>> return 5"
    with self.assertRaisesRegexp(TypeError, offending_line):
        module([f])
def test_compound_test(self):
    """Chained comparisons such as 1 < x < 2 raise NotImplementedError."""

    @function(Bool, x=Long)
    def f1(x):
        return 1 < x < 2

    # The error text is expected to quote the offending source line.
    offending_line = ">>> return 1 < x < 2"
    with self.assertRaisesRegexp(NotImplementedError, offending_line):
        module([f1])
def test_assign_wrong_type(self):
    """Storing a Double into a Long array element raises TypeError."""
    # NOTE(review): another test with this exact name appears earlier in
    # this file; if both live in the same class, this definition silently
    # replaces the earlier one -- confirm.

    @function(x=Array(Long, shape=(1,)), y=Double)
    def f(x, y):
        x[0] = y

    # Raw string: the backslashes are regex escapes for the literal
    # brackets, not Python string escapes.
    message = r">>> x\[0\] = y"
    with self.assertRaisesRegexp(TypeError, message):
        module([f])
def test_missing_symbol(self):
    """Referencing a name that is defined nowhere raises NameError."""

    @function(Double, y=Long)
    def x(y):
        return z

    # The error text is expected to quote the offending source line.
    offending_line = ">>> return z"
    with self.assertRaisesRegexp(NameError, offending_line):
        module([x])
def test_shape_mismatch(self):
    """Unpacking two values into a single target raises ValueError."""

    @function(Long, a=Long, b=Long)
    def foo(a, b):
        b, = a, b

    expected = "Cannot unpack 2 values into 1"
    with self.assertRaisesRegexp(ValueError, expected):
        module([foo])
def test_unsupported_slice(self):
    """Raise error on an unsupported slice expression (eg. `y[:]`)."""

    @function(Long, y=Long)
    def x(y):
        y[:]
        return 0

    # Raw string: the backslashes are regex escapes for the literal
    # brackets, not Python string escapes.
    message = r">>> y\[:\]"
    with self.assertRaisesRegexp(NotImplementedError, message):
        module([x])
def test_unsupported_target(self):
    """Unpacking a single integer into two targets raises TypeError."""

    @function(Long, a=Long, b=Long)
    def f(a, b):
        a, b = 1
        return 0

    # The error text is expected to quote the offending source line.
    offending_line = ">>> a, b = 1"
    with self.assertRaisesRegexp(TypeError, offending_line):
        module([f])
def test_unsupported_chain(self):
    """Chained assignment (``a = b = 1``) raises NotImplementedError."""

    @function(Long)
    def f():
        a = b = 1
        return 0

    # The error text is expected to quote the offending source line.
    offending_line = ">>> a = b = 1"
    with self.assertRaisesRegexp(NotImplementedError, offending_line):
        module([f])
def test_not_wrong_type(self):
    """Applying ``not`` to a Long argument raises TypeError."""

    @function(Bool, a=Long)
    def not_(a):
        na = not a
        return na

    # The pattern deliberately starts with a space before the marker.
    offending_line = " >>> na = not a"
    with self.assertRaisesRegexp(TypeError, offending_line):
        module([not_])
def test_if_expr_type_mismatch(self):
    """An ``if`` expression with differently typed arms raises TypeError."""

    # Simple expression: a float literal on one arm, an integer on the other.
    @function(Long, a=Long, b=Long)
    def max2(a, b):
        return 1.0 if a > b else 0

    offending_line = ">>> return 1.0 if a > b else 0"
    with self.assertRaisesRegexp(TypeError, offending_line):
        module([max2])
def test_symbol_out_of_scope(self):
    """A name assigned only inside a loop body is not visible after it."""

    @function(Double, y=Long)
    def x(y):
        for i in range(y):
            z = i
        return z

    # NameError is expected to quote the line that reads ``z``.
    offending_line = ">>> return z"
    with self.assertRaisesRegexp(NameError, offending_line):
        module([x])
def test_unsupported_context(self):
    """A ``del`` statement raises NotImplementedError."""

    @function(Long)
    def x():
        y = 1
        del y
        return 0

    # The error text is expected to quote the offending source line.
    offending_line = ">>> del y"
    with self.assertRaisesRegexp(NotImplementedError, offending_line):
        module([x])
def test_duplicate_function(self):
    """Two functions with the same name in one module raise RuntimeError."""

    def get_foo():
        @function(Long, a=Long)
        def foo(a):
            return 1

        return foo

    # Each call to the factory yields a separate function, both named "foo".
    expected = "Duplicate function name: foo"
    with self.assertRaisesRegexp(RuntimeError, expected):
        module([get_foo(), get_foo()])
def test_invalid_cast(self):
    """Casting a Structure value to Long raises TypeError."""
    from nitrous.lib import cast
    from nitrous.types import Structure

    S = Structure("S", ("x", Long))

    @function(Long, a=S)
    def int_to_long(a):
        return cast(a, Long)

    expected = "Cannot cast"
    with self.assertRaisesRegexp(TypeError, expected):
        module([int_to_long])
def test_call_wrong_arg_count(self):
    """Calling a one-argument function with two arguments is an error."""

    @function(Long, x=Long)
    def f1(x):
        return x

    @function(Long, x=Long)
    def f2(x):
        return f1(x, 1)

    # Raw string: the backslashes are regex escapes for the literal
    # parentheses, not Python string escapes.
    message = r"f1\(\) takes exactly 1 argument\(s\) \(2 given\)"
    with self.assertRaisesRegexp(TypeError, message):
        module([f2])
def test_call_wrong_arg_type(self):
    """Passing a float literal for a Long parameter raises TypeError."""

    @function(Long, x=Long)
    def f1(x):
        return x

    @function(Long, x=Long)
    def f2(x):
        return f1(1.0)

    # Raw string: the backslashes are regex escapes for the literal
    # parentheses, not Python string escapes.
    message = r"f1\(\) called with wrong argument type\(s\) for x"
    with self.assertRaisesRegexp(TypeError, message):
        module([f2])
def test_for_else(self):
    """for/else clause is not supported."""

    @function(Long, n=Long)
    def loop_1(n):
        for i in range(n):
            pass
        else:
            pass

        return 0

    # Raw string: the backslashes are regex escapes for the literal
    # parentheses, not Python string escapes.
    message = r">>> for i in range\(n\):"
    with self.assertRaisesRegexp(NotImplementedError, message):
        module([loop_1])
def test_for_range(self):
    """Loops over ``range(start, end)`` and ``range(start, end, step)``."""
    # NOTE(review): another test with this exact name appears later in
    # this file; if both live in the same class, the later definition
    # silently replaces this one -- confirm.
    import ctypes

    Long8 = Array(Long, (8, ))

    @function(Long, data=Long8, start=Long, end=Long)
    def loop_1(data, start, end):
        for i in range(start, end):
            data[i] = i

        return 0

    @function(Long, data=Long8, start=Long, end=Long, step=Long)
    def loop_2(data, start, end, step):
        for i in range(start, end, step):
            data[i] = i

        return 0

    compiled = module([loop_1, loop_2])

    # Each loop writes the index into the slots it visits; untouched
    # slots keep the zero they were created with.
    contiguous = (ctypes.c_long * 8)()
    compiled.loop_1(contiguous, 2, 7)
    self.assertEqual(list(contiguous), [0, 0, 2, 3, 4, 5, 6, 0])

    strided = (ctypes.c_long * 8)()
    compiled.loop_2(strided, 2, 7, 2)
    self.assertEqual(list(strided), [0, 0, 2, 0, 4, 0, 6, 0])
def test_or_eval(self):
    """``or`` evaluates its right operand only when the left is false."""
    from nitrous.types.array import Array

    Bool1 = Array(Bool, (1,))

    @function(Bool, b=Bool1)
    def side_effect(b):
        b[0] = True
        return True

    @function(Bool, a=Bool, b=Bool1)
    def or_(a, b):
        return a or side_effect(b)

    compiled = module([or_])

    # Left operand true: the right-hand call is skipped, flag untouched.
    flag = (Bool.c_type * 1)(False)
    compiled.or_(True, flag)
    self.assertFalse(flag[0])

    # Left operand false: the right-hand call runs and sets the flag.
    flag = (Bool.c_type * 1)(False)
    compiled.or_(False, flag)
    self.assertTrue(flag[0])
def test_and_eval(self):
    """``and`` evaluates its right operand only when the left is true."""
    from nitrous.types.array import Array

    Bool1 = Array(Bool, (1,))

    @function(Bool, b=Bool1)
    def side_effect(b):
        b[0] = True
        return True

    @function(Bool, a=Bool, b=Bool1)
    def and_(a, b):
        return a and side_effect(b)

    compiled = module([and_])

    # Left operand true: the right-hand call runs and sets the flag.
    flag = (Bool.c_type * 1)(False)
    compiled.and_(True, flag)
    self.assertTrue(flag[0])

    # Left operand false: the right-hand call is skipped, flag untouched.
    flag = (Bool.c_type * 1)(False)
    compiled.and_(False, flag)
    self.assertFalse(flag[0])
def test_string_literal(self):
    """A String function can return a string literal unchanged."""

    @function(String)
    def return_const():
        return "hello world"

    compiled = module([return_const])
    result = compiled.return_const()
    self.assertEqual(result, "hello world")
def test_if(self):
    """``if`` and ``if``/``else`` statements both select the larger value."""

    # if clause only
    @function(Long, a=Long, b=Long)
    def max2_1(a, b):
        v = b
        if a > b:
            v = a
        return v

    # if/else clause
    @function(Long, a=Long, b=Long)
    def max2_2(a, b):
        v = 0
        if a > b:
            v = a
        else:
            v = b
        return v

    compiled = module([max2_1, max2_2])

    # Both variants must agree on each input ordering.
    for max2 in (compiled.max2_1, compiled.max2_2):
        self.assertEqual(max2(2, 3), 3)
        self.assertEqual(max2(4, 1), 4)
def test_return_void(self):
    """A void function with a bare ``return`` yields None to the caller."""

    @function()
    def f():
        return

    compiled = module([f])
    result = compiled.f()
    self.assertIsNone(result)
def test_return_implicit_void(self):
    """A void function with no ``return`` at all yields None."""

    @function()
    def f():
        pass

    compiled = module([f])
    result = compiled.f()
    self.assertIsNone(result)
def test_for_range(self):
    """Loops over ``range(start, end)`` and ``range(start, end, step)``."""
    # NOTE(review): another test with this exact name appears earlier in
    # this file; if both live in the same class, this definition silently
    # replaces the earlier one -- confirm.
    import ctypes

    Long8 = Array(Long, (8,))

    @function(Long, data=Long8, start=Long, end=Long)
    def loop_1(data, start, end):
        for i in range(start, end):
            data[i] = i

        return 0

    @function(Long, data=Long8, start=Long, end=Long, step=Long)
    def loop_2(data, start, end, step):
        for i in range(start, end, step):
            data[i] = i

        return 0

    compiled = module([loop_1, loop_2])

    # Each loop writes the index into the slots it visits; untouched
    # slots keep the zero they were created with.
    contiguous = (ctypes.c_long * 8)()
    compiled.loop_1(contiguous, 2, 7)
    self.assertEqual(list(contiguous), [0, 0, 2, 3, 4, 5, 6, 0])

    strided = (ctypes.c_long * 8)()
    compiled.loop_2(strided, 2, 7, 2)
    self.assertEqual(list(strided), [0, 0, 2, 0, 4, 0, 6, 0])
def setUp(self):
    """Compile a pointer pass-through function and store it as ``self.m``."""

    @function(Pointer(Double), x=Pointer(Double))
    def f(x):
        return x

    compiled = module([f])
    self.m = compiled
    # Remove the attribute again once the test has finished.
    self.addCleanup(delattr, self, "m")
def test_sqrt(self):
    """The compiled Double ``sqrt`` agrees with ``math.sqrt``."""

    @function(Double, x=Double)
    def sqrt(x):
        return nitrous.lib.math.sqrt(Double)(x)

    compiled = module([sqrt])
    expected = math.sqrt(10.0)
    self.assertAlmostEqual(expected, compiled.sqrt(10.0))
def test_fill(self):
    """``fill`` overwrites every element of a four-element buffer."""
    compiled = module([fill])

    buf = (Float.c_type * 4)(1, 2, 3, 4)
    compiled.fill(buf, 100.0)

    self.assertEqual(list(buf), [100.0, 100.0, 100.0, 100.0])
def test_exp(self):
    """The compiled Double ``exp`` agrees with ``math.exp``."""

    @function(Double, x=Double)
    def exp(x):
        return nitrous.lib.math.exp(Double)(x)

    compiled = module([exp])
    expected = math.exp(10.0)
    self.assertAlmostEqual(expected, compiled.exp(10.0))