def test_tuple():
    """Check the inferred type of a global that returns its argument twice.

    Program:
        def dup(x) {
            return (x, x);
        }

    The declared signature is expected to be
    ``fn(tensor_type()) -> (tensor_type(), tensor_type())``.
    """
    builder = IRBuilder()
    dup_var = builder.global_var('dup')
    arg = builder.param('x')
    with builder.decl(dup_var, arg):
        pair = relay.Tuple([arg, arg])
        builder.ret(pair)
    # TODO: why is this not generalized? `x` is declared without an explicit
    # type, yet the expected signature pins it to a concrete tensor_type().
    param_types = [tensor_type()]
    result_type = relay.TupleType([tensor_type(), tensor_type()])
    expected = func_type(param_types, result_type)
    assert_decl_has_type(builder.env, dup_var, expected)
def test_concat():
    """Check that concatenating a (3, 2) and a (2, 2) tensor is typed as (5, 2).

    Program:
        def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
            return concat(x, y);
        }
    """
    lhs_shape, rhs_shape, out_shape = (3, 2), (2, 2), (5, 2)
    builder = IRBuilder()
    concat_var = builder.global_var('try_concat2')
    lhs = builder.param('x', ty=tensor_type(*lhs_shape))
    rhs = builder.param('y', ty=tensor_type(*rhs_shape))
    with builder.decl(concat_var, lhs, rhs):
        builder.ret(concat(lhs, rhs))
    arg_types = [tensor_type(*lhs_shape), tensor_type(*rhs_shape)]
    expected = func_type(arg_types, tensor_type(*out_shape))
    assert_decl_has_type(builder.env, concat_var, expected)
def test_dual_op():
    """Check type inference through two chained let-bound operators.

    Program:
        fn (x: Tensor[(10, 10), f32]) {
            let t1 = log(x);
            let t2 = add(t1, x);
            return t2;
        }

    The function is expected to have type
    ``fn(Tensor[(10, 10)]) -> Tensor[(10, 10)]``.
    """
    b = IRBuilder()
    with b.function(('x', tensor_type(10, 10))) as func:
        # The function declares exactly one parameter, so unpack it directly.
        x, = func.param_ids()
        t1 = b.let('t1', log(x))
        t2 = b.let('t2', add(t1, x))
        # The result is t2 (not t1), matching the program in the docstring.
        b.ret(t2)
    assert_has_type(func.to_func(),
                    func_type([tensor_type(10, 10)], tensor_type(10, 10)))
def test_add_broadcast_op():
    """Check that ``add`` broadcasts its operands when inferring the result type.

    Program:
        fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
            return x + y;
        }

    The previous body compared against ``fn(T(5,5,5), T(5,5,5)) -> T(5,5,5)``.
    That was copied from the same-shape test and contradicted both this
    docstring and the declared parameter shapes. Reusing the shared broadcast
    checker keeps the expected signature ((10, 4), (5, 10, 1)) -> (5, 10, 4)
    in one place.
    """
    check_binary_broadcast_op(add)
def check_binary_broadcast_op(opfunc):
    """Assert that ``opfunc`` broadcasts (10, 4) with (5, 10, 1) to (5, 10, 4).

    Program:
        fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
            return x <op> y;
        }

    ``opfunc`` is called with the two parameters' ``.var`` handles and must
    return the expression for the function body.
    """
    lhs_shape, rhs_shape, out_shape = (10, 4), (5, 10, 1), (5, 10, 4)
    builder = IRBuilder()
    lhs = builder.param('x', tensor_type(*lhs_shape))
    rhs = builder.param('y', tensor_type(*rhs_shape))
    with builder.function(lhs, rhs) as func:
        builder.ret(opfunc(lhs.var, rhs.var))
    builder.ret(func)
    _prog, _env = builder.get()
    arg_types = [tensor_type(*lhs_shape), tensor_type(*rhs_shape)]
    expected_ty = func_type(arg_types, tensor_type(*out_shape))
    assert_has_type(func.to_func(), expected_ty)
def check_binary_op(opfunc):
    """Assert that ``opfunc`` maps two (5, 5, 5) tensors to a (5, 5, 5) tensor.

    Program:
        fn (x: Tensor[(5, 5, 5), f32], y: Tensor[(5, 5, 5), f32]) {
            return x <op> y;
        }

    ``opfunc`` is called with the two parameters' ``.var`` handles and must
    return the expression for the function body.
    """
    shape = (5, 5, 5)
    builder = IRBuilder()
    lhs = builder.param('x', tensor_type(*shape))
    rhs = builder.param('y', tensor_type(*shape))
    with builder.function(lhs, rhs) as func:
        builder.ret(opfunc(lhs.var, rhs.var))
    builder.ret(func)
    _prog, _env = builder.get()
    # One type object is shared by both parameters and the result.
    same_ty = tensor_type(*shape)
    expected_ty = func_type([same_ty] * 2, same_ty)
    assert_has_type(func.to_func(), expected_ty)