예제 #1
0
def expect(fn, args, expected, msg = None, valid_types = None):
  """
  Helper function used for testing: assert that Parakeet evaluates `fn`
  applied to `args` to `expected`, first on the "interp" backend and then
  on the native "c" backend.

  fn          -- Python function under test.
  args        -- sequence of positional arguments; each backend run receives
                 its own copy via `_copy_list`.
  expected    -- expected result. An array whose dtype is float16 is converted
                 to float32 before comparison.
  msg         -- optional text appended (after a "-") to the failure label.
  valid_types -- optional type, or tuple/list of types, that the exact type
                 of the native result must belong to.

  Raises whatever the backends or `expect_eq` raise, and AssertionError when
  `valid_types` does not match. When `testing_find_broken_transform` is set,
  `find_broken_transform(fn, args, expected)` is called before re-raising a
  failure of the interpreter run or of either result comparison.
  """
  if hasattr(expected, 'dtype') and expected.dtype == 'float16':
    expected = expected.astype('float32')

  def label_for(backend):
    # Builds e.g. "interp: inputs = 1, 2" or "native: inputs = 1, 2-<msg>".
    label = "%s: inputs = %s" % (backend, ", ".join(str(arg) for arg in args))
    if msg is not None:
      label += "-" + str(msg)
    return label

  def report_failure():
    # Optional debugging hook; callers re-raise the original error afterwards.
    if testing_find_broken_transform:
      find_broken_transform(fn, args, expected)

  untyped_fn = frontend.ast_conversion.translate_function_value(fn)

  # `except Exception` instead of a bare `except:` so KeyboardInterrupt /
  # SystemExit propagate immediately rather than first running the hook.
  try:
    interp_result = run_untyped_fn(untyped_fn, _copy_list(args), backend = "interp")
  except Exception:
    report_failure()
    raise

  interp_label = label_for("interp")
  try:
    expect_eq(interp_result, expected, interp_label)
  except Exception:
    report_failure()
    raise

  # NOTE(review): unlike the interpreter run above, a crash in the native run
  # itself is not routed through report_failure() — confirm this is intended.
  native_result = run_python_fn(fn, _copy_list(args), backend="c")

  if valid_types is not None:
    if not isinstance(valid_types, (tuple, list)):
      valid_types = [valid_types]
    assert type(native_result) in valid_types, \
      "Expected result to have type in %s but got %s" % (valid_types, type(native_result))

  native_label = label_for("native")
  try:
    expect_eq(native_result, expected, native_label)
  except Exception:
    report_failure()
    raise
예제 #2
0
def test_identity():
  """The identity kernels return their argument unchanged (ints and floats, both signs)."""
  cases = [
    (identity_i64, 1),
    (identity_i64, -1),
    (identity_f64, 1.0),
    (identity_f64, -1.0),
  ]
  for kernel, value in cases:
    expect_eq(kernel(value), value)
예제 #3
0
def test_sum():
  """sum_i64 / sum_f64 reduce a small 1-D array to its total."""
  int_total = sum_i64(np.array([1, 2, 3]))
  expect_eq(int_total, 6)

  float_total = sum_f64(np.array([-1.0, 1.0, 2.0]))
  expect_eq(float_total, 2.0)
예제 #4
0
def test_vec_add():
  """vec_add agrees with NumPy's elementwise `+` on two small integer arrays."""
  left = np.array([1, 2, 3])
  right = np.array([10, 20, 30])
  expect_eq(vec_add(left, right), left + right)