def test_monomorphic_let():
    """Program: let x = 1; return x

    The bound value is annotated as a float64 scalar, so the whole
    program is expected to type-check as a float64 scalar.
    """
    builder = IRBuilder()
    bound = builder.let('x', 1.0, value_type=scalar_type('float64'))
    builder.ret(bound)
    program, _ = builder.get()
    expected = scalar_type('float64')
    assert_has_type(program, expected)
def test_let():
    """Check the IR produced for a single let binding.

    Program: let x = 1; return x

    The builder output should be a `Let` node whose variable is named
    'x', whose body is that same variable, and whose value is the
    constant 1.
    """
    builder = IRBuilder()
    bound = builder.let('x', 1)
    builder.ret(bound)
    program, _ = builder.get()

    assert isinstance(program, Let)
    let_var = program.var
    let_value = program.value
    assert let_var.name_hint == 'x'
    # `return x` means the let body is just the bound variable.
    assert let_var == program.body
    assert isinstance(let_value, Constant)
    assert let_value.data.asnumpy() == np.array(1)
def test_decl():
    """Type-check a top-level declaration.

    Program:
       def f(x : Tensor[f32, (10, 10)]) {
           let lx = log(x);
           return lx;
       }
    """
    builder = IRBuilder()
    # NOTE(review): the parameter is created without a type, unlike the
    # docstring's Tensor[f32, (10, 10)]; confirm this is intentional.
    param = builder.param('x')
    with builder.decl('f', param):
        log_x = builder.let('lx', log(param))
        builder.ret(log_x)
    _, environment = builder.get()
    expected = func_type(['float32'], 'float32')
    assert_decl_has_type(environment, 'f', expected)
def test_add_broadcast_op():
    """Type-check broadcasting addition of two differently-shaped tensors.

    Program:
        fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32])
            -> Tensor[(5, 10, 4), f32] {
            return x + y;
        }
    """
    b = IRBuilder()
    x = b.param('x', tensor_type(10, 4))
    y = b.param('y', tensor_type(5, 10, 1))
    with b.function(x, y) as func:
        b.ret(add(x.var, y.var))
    b.ret(func)
    b.get()
    # The expected signature must mirror the declared parameter types.
    # Previously this asserted (5, 5, 5) for every position, which was copied
    # from the same-shape test and contradicted the program being built.
    # Broadcasting (10, 4) against (5, 10, 1) yields (5, 10, 4).
    expected_ty = func_type(
        [tensor_type(10, 4), tensor_type(5, 10, 1)], tensor_type(5, 10, 4))
    assert_has_type(func.to_func(), expected_ty)
def check_binary_broadcast_op(opfunc):
    """Assert that `opfunc` type-checks with broadcasting operands.

    Program:
        fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32])
            -> Tensor[(5, 10, 4), f32] {
            return x <op> y;
        }

    Args:
        opfunc: callable building the binary op from two IR variables.
    """
    builder = IRBuilder()
    lhs = builder.param('x', tensor_type(10, 4))
    rhs = builder.param('y', tensor_type(5, 10, 1))
    with builder.function(lhs, rhs) as fn:
        builder.ret(opfunc(lhs.var, rhs.var))
    builder.ret(fn)
    builder.get()
    param_types = [tensor_type(10, 4), tensor_type(5, 10, 1)]
    result_type = tensor_type(5, 10, 4)
    assert_has_type(fn.to_func(), func_type(param_types, result_type))
def check_binary_op(opfunc):
    """Assert that `opfunc` type-checks on two same-shaped tensors.

    Program:
        fn (x: Tensor[(5, 5, 5)], y: Tensor[(5, 5, 5)]) -> Tensor[(5, 5, 5)] {
            return x <op> y;
        }

    Args:
        opfunc: callable building the binary op from two IR variables.
    """
    builder = IRBuilder()
    lhs = builder.param('x', tensor_type(5, 5, 5))
    rhs = builder.param('y', tensor_type(5, 5, 5))
    with builder.function(lhs, rhs) as fn:
        builder.ret(opfunc(lhs.var, rhs.var))
    builder.ret(fn)
    builder.get()
    # Both operands and the result share one tensor type.
    same_shape = tensor_type(5, 5, 5)
    expected = func_type([same_shape, same_shape], same_shape)
    assert_has_type(fn.to_func(), expected)