def test_eta_expand_constructor():
    """EtaExpand with ``expand_constructor=True`` wraps a bare constructor.

    ``@main`` returns the ``Cons`` constructor unapplied. After the pass the
    body should be an explicit function ``fn [A](%x, %xs) { Cons(%x, %xs) }``.
    The comparison uses ``assert_graph_equal``.
    """
    source_mod = relay.fromtext(r""" v0.0.4 type List[A] { Cons(A, List[A]), Nil, } def @main[A]() -> (fn(A, List[A]) -> List[A]) { Cons } """)

    eta_pass = _transform.EtaExpand(expand_constructor=True)
    pipeline = _transform.Sequential([eta_pass])
    with _transform.PassContext(opt_level=3):
        result_mod = pipeline(source_mod)

    reference_mod = relay.fromtext(r""" v0.0.4 type List[A] { Cons(A, List[A]), Nil, } def @main[A]() -> (fn(A, List[A]) -> List[A]) { fn [A](%x: A, %xs: List[A]) -> List[A] { Cons(%x, %xs) } } """)
    relay.analysis.assert_graph_equal(result_mod['main'], reference_mod['main'])
def run_opt_pass(expr, passes):
    """Apply ``passes`` to ``expr`` and return the module's ``main`` function.

    ``passes`` may be a single pass or a list of passes. ``expr`` is wrapped
    in a fresh ``tvm.IRModule``. The passes run as one ``transform.Sequential``
    inside a ``PassContext`` with ``opt_level=3``.
    """
    if not isinstance(passes, list):
        passes = [passes]
    module = tvm.IRModule.from_expr(expr)
    pipeline = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        module = pipeline(module)
    return module["main"]
def test_eta_expand_global_var():
    """EtaExpand with ``expand_global_var=True`` wraps a bare global reference.

    ``@main`` returns ``@aux`` unapplied. After the pass the body should be
    ``fn (%x) { @aux(%x) }``. The comparison uses ``assert_graph_equal``.
    """
    source_mod = relay.fromtext(r""" v0.0.4 def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] { %x } def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { @aux } """)

    eta_pass = _transform.EtaExpand(expand_global_var=True)
    pipeline = _transform.Sequential([eta_pass])
    with _transform.PassContext(opt_level=3):
        result_mod = pipeline(source_mod)

    reference_mod = relay.fromtext(r""" v0.0.4 def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] { %x } def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { fn (%x: Tensor[(), int32]) -> Tensor[(), int32] { @aux(%x) } } """)
    relay.analysis.assert_graph_equal(result_mod['main'], reference_mod['main'])
def test_checkpoint_alpha_equal():
    """Gradient of a checkpointed ``(x0 + x1) * (x2 + x3)`` matches a reference.

    The steps are:

    1. Build the function and take its gradient with ``transform.gradient``.
    2. Run ``PartialEvaluate`` followed by
       ``DeadCodeElimination(inline_once=True)``.
    3. Compare the result to a hand-written Relay program with
       ``assert_alpha_equal``.

    In the reference program the two ``add`` results are recomputed
    (``%x3``, ``%x4``) before being used in the backward computation.
    """
    params = [
        relay.var("x{}".format(idx), relay.TensorType((1, ), "float32"))
        for idx in range(4)
    ]
    lhs_sum = relay.add(params[0], params[1])
    rhs_sum = relay.add(params[2], params[3])
    checkpointed = relay.annotation.checkpoint(relay.multiply(lhs_sum, rhs_sum))
    df = transform.gradient(run_infer_type(relay.Function(params, checkpointed)))

    # Partial evaluation, then dead-code elimination.
    with transform.PassContext(opt_level=3):
        pipeline = transform.Sequential([
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ])
        df = pipeline(tvm.IRModule.from_expr(df))["main"]

    reference = relay.parser.fromtext(""" v0.0.4 fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], %z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> (Tensor[(1), float32], (Tensor[(1), float32], Tensor[(1), float32], Tensor[(1), float32], Tensor[(1), float32])) { %0 = add(%x, %y); %1 = add(%z, %w); let %x1: Tensor[(1), float32] = multiply(%0, %1); let %x2: Tensor[(1), float32] = ones_like(%x1); let %x3: Tensor[(1), float32] = add(%x, %y); let %x4: Tensor[(1), float32] = add(%z, %w); %2 = zeros_like(%x3); %3 = multiply(%x2, %x4); %4 = collapse_sum_like(%3, %x3); let %x5: Tensor[(1), float32] = add(%2, %4); %5 = zeros_like(%x4); %6 = multiply(%x2, %x3); %7 = collapse_sum_like(%6, %x4); let %x6: Tensor[(1), float32] = add(%5, %7); %8 = zeros_like(%x); %9 = collapse_sum_like(%x5, %x); %10 = add(%8, %9); %11 = zeros_like(%y); %12 = collapse_sum_like(%x5, %y); %13 = add(%11, %12); %14 = zeros_like(%z); %15 = collapse_sum_like(%x6, %z); %16 = add(%14, %15); %17 = zeros_like(%w); %18 = collapse_sum_like(%x6, %w); %19 = add(%17, %18); %20 = (%10, %13, %16, %19); (%x1, %20) } """)
    relay.analysis.assert_alpha_equal(df, reference)
def run_opt_pass(expr, passes):
    """Apply ``passes`` to ``expr`` and return the optimized result.

    ``passes`` may be a single pass or a list of passes. ``expr`` is wrapped
    in a fresh ``tvm.IRModule``. The passes run as one ``transform.Sequential``
    inside a ``PassContext`` with ``opt_level=3``.

    If ``expr`` is a ``relay.Function``, the optimized ``main`` function is
    returned. Otherwise only its body is returned, so the caller gets back the
    same kind of expression it passed in.
    """
    pass_list = passes if isinstance(passes, list) else [passes]
    module = tvm.IRModule.from_expr(expr)
    pipeline = transform.Sequential(pass_list)
    with transform.PassContext(opt_level=3):
        module = pipeline(module)
    main_fn = module["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body
def test_checkpoint_alpha_equal_tuple():
    """Gradient of a checkpointed tuple ``(x0 + x1, x2 + x3)`` matches a reference.

    This mirrors ``test_checkpoint_alpha_equal`` but checkpoints a
    ``relay.Tuple`` instead of a single tensor expression. The gradient is
    taken, ``PartialEvaluate`` and ``DeadCodeElimination(inline_once=True)``
    are run, and the result is compared to a hand-written program with
    ``assert_alpha_equal``.
    """
    params = [
        relay.var("x{}".format(idx), relay.TensorType((1, ), "float32"))
        for idx in range(4)
    ]
    first_sum = relay.add(params[0], params[1])
    second_sum = relay.add(params[2], params[3])
    checkpointed = relay.annotation.checkpoint(relay.Tuple([first_sum, second_sum]))
    df = transform.gradient(run_infer_type(relay.Function(params, checkpointed)))

    # Partial evaluation, then dead-code elimination.
    with transform.PassContext(opt_level=3):
        pipeline = transform.Sequential([
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ])
        df = pipeline(tvm.IRModule.from_expr(df))["main"]

    reference = relay.parser.fromtext(""" v0.0.4 fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], %z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> ((Tensor[(1), float32], Tensor[(1), float32]), (Tensor[(1), float32], Tensor[(1), float32], Tensor[(1), float32], Tensor[(1), float32])) { let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */; let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */; let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */; let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */; %0 = (%x1, %x2); %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */; %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */; %3 = add(%1, %2) /* ty=Tensor[(1), float32] */; %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */; %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */; %6 = add(%4, %5) /* ty=Tensor[(1), float32] */; %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */; %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */; %9 = add(%7, %8) /* ty=Tensor[(1), float32] */; %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */; %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */; %12 = add(%10, %11) /* ty=Tensor[(1), float32] */; %13 = (%3, %6, %9, %12); (%0, %13) } """)
    relay.analysis.assert_alpha_equal(df, reference)
def test_eta_expand_basic():
    """EtaExpand wraps the identity function in a forwarding lambda.

    The result of running ``EtaExpand`` on ``fn (%x) { %x }`` is compared,
    up to alpha-equivalence, with ``fn (%y) { orig(%y) }``.
    """
    x = relay.var('x', 'int32')
    orig = relay.Function([x], x)
    mod = _module.Module.from_expr(orig)
    seq = _transform.Sequential([_transform.EtaExpand()])
    with _transform.PassContext(opt_level=3):
        mod = seq(mod)
    # `Module.entry_func` no longer exists; the entry point is always "main".
    got = mod["main"]

    y = relay.var('y', 'int32')
    expected = relay.Function([y], orig(y))
    # `relay.ir_pass.infer_type` was removed. Type-check `expected` by adding
    # it to the same module and running the InferType pass instead.
    gv = relay.GlobalVar("gv")
    mod[gv] = expected
    mod = _transform.InferType()(mod)
    expected = mod["gv"]
    assert relay.analysis.alpha_equal(got, expected)
def test_eta_expand_basic():
    """EtaExpand wraps the identity function in a forwarding lambda.

    The result of running ``EtaExpand`` on ``fn (%x) { %x }`` is compared,
    up to alpha-equivalence, with ``fn (%y) { identity(%y) }``. The reference
    is type-checked by adding it to the module as ``@gv``.
    """
    arg = relay.var('x', 'int32')
    identity = relay.Function([arg], arg)
    module = _module.Module.from_expr(identity)
    eta_pipeline = _transform.Sequential([_transform.EtaExpand()])
    with _transform.PassContext(opt_level=3):
        module = eta_pipeline(module)
    actual = module["main"]

    fresh = relay.var('y', 'int32')
    reference = relay.Function([fresh], identity(fresh))
    ref_var = relay.GlobalVar("gv")
    module[ref_var] = reference
    module = _transform.InferType()(module)
    reference = module["gv"]
    assert relay.analysis.alpha_equal(actual, reference)
def check(shape):
    """Check that CanonicalizeCast turns ``before(...)`` into ``expected(...)``.

    ``before`` and ``expected`` are graph builders not defined in this block;
    presumably they come from the enclosing test (TODO confirm). Both get the
    same input variables. The transformed ``main`` is compared with the
    type-inferred reference using ``alpha_equal``.

    :param shape: shape of the int8 ``data`` input.
    """
    data = relay.var("data", shape=shape, dtype="int8")
    conv_weight = relay.var("weight")
    bias1, bias2 = [
        relay.var(bias_name, shape=(16, 1, 1), dtype="int32")
        for bias_name in ("bias1", "bias2")
    ]
    inputs = (data, conv_weight, bias1, bias2)

    mod = tvm.IRModule.from_expr(before(*inputs))
    pipeline = _transform.Sequential([
        _transform.InferType(),
        _transform.CanonicalizeCast(),
        _transform.InferType(),
    ])
    with _transform.PassContext(opt_level=3):
        mod = pipeline(mod)
    actual = mod["main"]

    # Add the reference to the same module so InferType can annotate it.
    mod[relay.GlobalVar("expected")] = expected(*inputs)
    mod = _transform.InferType()(mod)
    assert relay.analysis.alpha_equal(actual, mod["expected"])
def check(shape):
    """Check that CanonicalizeCast turns ``before(...)`` into ``expected(...)``.

    ``before`` and ``expected`` are graph builders not defined in this block;
    presumably they come from the enclosing test (TODO confirm).

    :param shape: shape of the int8 ``data`` input.
    """
    data = relay.var("data", shape=shape, dtype="int8")
    conv_weight = relay.var("weight")
    bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32")
    bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
    y = before(data, conv_weight, bias1, bias2)
    mod = _module.Module.from_expr(y)
    seq = _transform.Sequential([
        _transform.InferType(),
        _transform.CanonicalizeCast(),
        _transform.InferType()
    ])
    with _transform.PassContext(opt_level=3):
        mod = seq(mod)
    # `Module.entry_func` no longer exists; the entry point is always "main".
    y = mod["main"]

    y_expected = expected(data, conv_weight, bias1, bias2)
    # `relay.ir_pass.infer_type` was removed. Type-check the reference by
    # adding it to the module and running the InferType pass instead.
    gv = relay.GlobalVar("expected")
    mod[gv] = y_expected
    mod = _transform.InferType()(mod)
    y_expected = mod["expected"]
    assert relay.analysis.alpha_equal(y, y_expected)