def _expect_label(backend_name, args, msg):
    """Build the failure label for one backend's comparison.

    The label is "<backend_name>: inputs = <args joined by ', '>". When
    ``msg`` is not None, it is suffixed as "-<msg>".
    """
    label = "%s: inputs = %s" % (backend_name, ", ".join(str(arg) for arg in args))
    if msg is not None:
        label += "-" + str(msg)
    return label


def _maybe_find_broken_transform(fn, args, expected):
    """Run ``find_broken_transform`` if the module-level debug flag is set.

    This is called from an ``except`` block just before the exception is
    re-raised. It only adds diagnostics and never suppresses the failure.
    """
    if testing_find_broken_transform:
        find_broken_transform(fn, args, expected)


def expect(fn, args, expected, msg = None, valid_types = None):
    """
    Helper function used for testing, assert that Parakeet evaluates given
    code to the correct result.

    ``fn`` is checked in two stages:

    1. It is translated with ``frontend.ast_conversion.translate_function_value``
       and run on the "interp" backend. The result is compared to ``expected``.
    2. It is run on the "c" backend via ``run_python_fn``. The result's type is
       optionally checked, and the result is compared to ``expected``.

    :param fn: Python function under test.
    :param args: Positional arguments. Each backend receives its own copy
        made by ``_copy_list``. Presumably the copy stops one backend's
        in-place mutation from affecting the other; TODO confirm.
    :param expected: Expected result. If it has a float16 ``dtype``, it is
        upcast to float32 before comparison. Presumably this is because the
        backends do not produce float16 values; TODO confirm.
    :param msg: Optional extra text appended to the failure label.
    :param valid_types: Optional type, or list/tuple of types. The native
        result's exact ``type()`` must be one of them. Subclasses are not
        accepted, because this is not an ``isinstance`` check.
    :raises AssertionError: If a result's type or value does not match.
        Exceptions raised while running either backend also propagate.
    """
    if hasattr(expected, 'dtype') and expected.dtype == 'float16':
        expected = expected.astype('float32')

    # --- interpreter backend ---
    untyped_fn = frontend.ast_conversion.translate_function_value(fn)
    try:
        interp_result = run_untyped_fn(untyped_fn, _copy_list(args), backend = "interp")
    except Exception:
        # `except Exception` rather than a bare `except`: a Ctrl-C or
        # SystemExit should not kick off transform bisection.
        _maybe_find_broken_transform(fn, args, expected)
        raise

    try:
        expect_eq(interp_result, expected, _expect_label("interp", args, msg))
    except Exception:
        _maybe_find_broken_transform(fn, args, expected)
        raise

    # --- native (C) backend ---
    # NOTE(review): unlike the interp run, an exception raised here does not
    # trigger find_broken_transform. This is kept as-is to preserve behavior.
    native_result = run_python_fn(fn, _copy_list(args), backend="c")
    if valid_types is not None:
        if not isinstance(valid_types, (tuple, list)):
            valid_types = [valid_types]
        assert type(native_result) in valid_types, \
            "Expected result to have type in %s but got %s" % (valid_types, type(native_result))

    try:
        expect_eq(native_result, expected, _expect_label("native", args, msg))
    except Exception:
        _maybe_find_broken_transform(fn, args, expected)
        raise
def test_identity():
    """Each identity function must hand back its argument unchanged, for both signs."""
    cases = [
        (identity_i64, 1, 1),
        (identity_i64, -1, -1),
        (identity_f64, 1.0, 1.0),
        (identity_f64, -1.0, -1.0),
    ]
    for identity_fn, value, want in cases:
        expect_eq(identity_fn(value), want)
def test_sum():
    """Sum reductions over a small integer array and a small float array."""
    int_values = np.array([1, 2, 3])
    expect_eq(sum_i64(int_values), 6)

    float_values = np.array([-1.0, 1.0, 2.0])
    expect_eq(sum_f64(float_values), 2.0)
def test_vec_add():
    """vec_add must agree with NumPy's elementwise `+` on two small int arrays."""
    left = np.array([1, 2, 3])
    right = np.array([10, 20, 30])
    # Call vec_add first, then compute the reference sum, matching the
    # original evaluation order.
    actual = vec_add(left, right)
    expect_eq(actual, left + right)