def test_checkpoint_alpha_equal():
    """Compare the optimized gradient of a checkpointed expression with expected IR.

    The function under test is ``checkpoint((x0 + x1) * (x2 + x3))`` over four
    ``Tensor[(1), float32]`` parameters. Its gradient is run through
    PartialEvaluate and DeadCodeElimination(inline_once=True) and must be
    structurally equal to the Relay program parsed from the text below.
    """
    params = [
        relay.var("x{}".format(i), relay.TensorType((1, ), "float32"))
        for i in range(4)
    ]
    product = relay.multiply(
        relay.add(params[0], params[1]), relay.add(params[2], params[3]))
    func = relay.Function(params, relay.annotation.checkpoint(product))
    grad_fn = transform.gradient(run_infer_type(func))

    # run PE and DCE
    with tvm.transform.PassContext(opt_level=3):
        pipeline = tvm.transform.Sequential([
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ])
        grad_fn = pipeline(tvm.IRModule.from_expr(grad_fn))["main"]

    # Adjacent literals are concatenated into one program text; every fragment
    # ends with the single space separating it from the next one.
    expected_text = (
        ' #[version = "0.0.5"] '
        'fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], '
        '%z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> '
        '(Tensor[(1), float32], (Tensor[(1), float32], Tensor[(1), float32], '
        'Tensor[(1), float32], Tensor[(1), float32])) { '
        '%0 = add(%x, %y); '
        '%1 = add(%z, %w); '
        'let %x1: Tensor[(1), float32] = multiply(%0, %1); '
        'let %x2: Tensor[(1), float32] = ones_like(%x1); '
        'let %x3: Tensor[(1), float32] = add(%x, %y); '
        'let %x4: Tensor[(1), float32] = add(%z, %w); '
        '%2 = zeros_like(%x3); '
        '%3 = multiply(%x2, %x4); '
        '%4 = collapse_sum_like(%3, %x3); '
        'let %x5: Tensor[(1), float32] = add(%2, %4); '
        '%5 = zeros_like(%x4); '
        '%6 = multiply(%x2, %x3); '
        '%7 = collapse_sum_like(%6, %x4); '
        'let %x6: Tensor[(1), float32] = add(%5, %7); '
        '%8 = zeros_like(%x); '
        '%9 = collapse_sum_like(%x5, %x); '
        '%10 = add(%8, %9); '
        '%11 = zeros_like(%y); '
        '%12 = collapse_sum_like(%x5, %y); '
        '%13 = add(%11, %12); '
        '%14 = zeros_like(%z); '
        '%15 = collapse_sum_like(%x6, %z); '
        '%16 = add(%14, %15); '
        '%17 = zeros_like(%w); '
        '%18 = collapse_sum_like(%x6, %w); '
        '%19 = add(%17, %18); '
        '%20 = (%10, %13, %16, %19); '
        '(%x1, %20) } '
    )
    expected = tvm.parser.parse_expr(expected_text)
    tvm.ir.assert_structural_equal(grad_fn, expected)
def test_before_partial_eval():
    """Check the gradient's type when LazyGradientInit runs before PartialEvaluate.

    Takes the gradient of ``x * y`` over (10, 10) float32 tensors, applies
    LazyGradientInit, PartialEvaluate and DeadCodeElimination in that order,
    and asserts that ``main`` has type ``fn(t, t) -> (t, (t, t))``.
    """
    mod = tvm.IRModule()
    shape = (10, 10)
    dtype = "float32"
    tensor_ty = relay.TensorType(shape, dtype)

    x = relay.var("x", tensor_ty)
    y = relay.var("y", tensor_ty)
    forward_fn = run_infer_type(relay.Function([x, y], x * y))
    back_func = run_infer_type(transform.gradient(forward_fn))
    mod["main"] = back_func

    pipeline = tvm.transform.Sequential([
        transform.LazyGradientInit(),
        transform.PartialEvaluate(),
        transform.DeadCodeElimination(),
    ])
    mod = pipeline(mod)
    back_func = mod["main"]

    grads_ty = relay.TupleType([tensor_ty, tensor_ty])
    expected_ty = relay.FuncType(
        [tensor_ty, tensor_ty], relay.TupleType([tensor_ty, grads_ty]))
    assert mod["main"].checked_type == expected_ty

    # NOTE(review): the executor and the random inputs below are created but
    # not used anywhere in this block -- confirm whether a numeric check on
    # back_func was meant to follow.
    ex = create_executor(mod=mod)
    x = rand(dtype, *shape)
    y = rand(dtype, *shape)
def destroy_ref(x):
    """Assert that no reference creation is detected after CPS, PE and DCE.

    ``x`` is type-inferred, converted with ``to_cps`` and type-inferred again.
    The CPS form is then optimized with PartialEvaluate followed by
    DeadCodeElimination(inline_once=True), and the result must not report
    ``Feature.fRefCreate``.
    """
    cps_expr = run_infer_type(to_cps(run_infer_type(x)))

    # The un_cps round trip is only type-checked; its result is not used below.
    run_infer_type(un_cps(cps_expr))

    pe_then_dce = transform.Sequential([
        transform.PartialEvaluate(),
        transform.DeadCodeElimination(inline_once=True),
    ])
    optimized = run_opt_pass(cps_expr, pe_then_dce)
    assert Feature.fRefCreate not in detect_feature(optimized)
def failing_tvm_typecheck():
    """Call ``fuzz_pass`` on PartialEvaluate with a single conv2d op description.

    The fuzzer is invoked with ``prog_len=2``, the one-entry op list below and
    ``shape=[1, 1, 2, 2]``.
    """
    partial_eval = transform.PartialEvaluate()
    conv2d_spec = {
        'func': relay.nn.conv2d,
        'arity': 2,
        'weight': 1,
    }
    fuzz_pass(partial_eval, prog_len=2, ops=[conv2d_spec], shape=[1, 1, 2, 2])
def test_checkpoint_alpha_equal_tuple():
    """Compare the optimized gradient of a checkpointed tuple with expected IR.

    The function under test is ``checkpoint((x0 + x1, x2 + x3))`` over four
    ``Tensor[(1), float32]`` parameters. Its gradient is run through
    PartialEvaluate and DeadCodeElimination and must be structurally equal to
    the Relay program parsed from the text below.
    """
    params = [
        relay.var("x{}".format(i), relay.TensorType((1, ), "float32"))
        for i in range(4)
    ]
    sums = relay.Tuple(
        [relay.add(params[0], params[1]), relay.add(params[2], params[3])])
    func = relay.Function(params, relay.annotation.checkpoint(sums))
    grad_fn = transform.gradient(run_infer_type(func))

    # run PE and DCE
    with tvm.transform.PassContext(opt_level=3):
        # The expected output assumes DCE can elide 'dead writes' to
        # references; 'ignore_impurity=True' keeps the legacy behaviour of
        # eliding all writes, which happens to be fine for this case.
        # TODO(mbs): Revisit once DCE supports dead reference writes.
        pipeline = tvm.transform.Sequential([
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
        ])
        grad_fn = pipeline(tvm.IRModule.from_expr(grad_fn))["main"]

    # Adjacent literals are concatenated into one program text; every fragment
    # ends with the single space separating it from the next one.
    expected_text = (
        ' #[version = "0.0.5"] '
        'fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], '
        '%z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> '
        '((Tensor[(1), float32], Tensor[(1), float32]), '
        '(Tensor[(1), float32], Tensor[(1), float32], '
        'Tensor[(1), float32], Tensor[(1), float32])) { '
        'let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */; '
        'let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */; '
        'let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */; '
        'let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */; '
        '%0 = (%x1, %x2); '
        '%1 = zeros_like(%x) /* ty=Tensor[(1), float32] */; '
        '%2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */; '
        '%3 = add(%1, %2) /* ty=Tensor[(1), float32] */; '
        '%4 = zeros_like(%y) /* ty=Tensor[(1), float32] */; '
        '%5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */; '
        '%6 = add(%4, %5) /* ty=Tensor[(1), float32] */; '
        '%7 = zeros_like(%z) /* ty=Tensor[(1), float32] */; '
        '%8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */; '
        '%9 = add(%7, %8) /* ty=Tensor[(1), float32] */; '
        '%10 = zeros_like(%w) /* ty=Tensor[(1), float32] */; '
        '%11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */; '
        '%12 = add(%10, %11) /* ty=Tensor[(1), float32] */; '
        '%13 = (%3, %6, %9, %12); '
        '(%0, %13) } '
    )
    expected = tvm.parser.parse_expr(expected_text)
    tvm.ir.assert_structural_equal(grad_fn, expected)
def dcpe(expr, mod=None, grad=False):
    """Apply PartialEvaluate then DeadCodeElimination(inline_once=True) to ``expr``.

    Parameters
    ----------
    expr
        The expression to optimize.
    mod
        Optional module. When truthy, ``expr`` must be a ``Function``; it is
        stored as ``mod["main"]``, the passes run over the whole module, and
        the resulting ``main`` is returned.
    grad
        When true, ``expr`` is first replaced by
        ``gradient(run_infer_type(expr))``.

    Returns
    -------
    The optimized expression (or the optimized ``main`` when ``mod`` is given).
    """
    pe_then_dce = [
        transform.PartialEvaluate(),
        transform.DeadCodeElimination(inline_once=True),
    ]
    if grad:
        expr = gradient(run_infer_type(expr))

    # Expression-only path: no module to update.
    if not mod:
        return run_opt_pass(expr, pe_then_dce)

    assert isinstance(expr, Function)
    mod["main"] = expr
    optimized_mod = transform.Sequential(pe_then_dce)(mod)
    return optimized_mod["main"]
def test_checkpoint_alpha_equal_tuple():
    """Compare the optimized gradient of a checkpointed tuple with expected IR.

    The function under test is ``checkpoint((x0 + x1, x2 + x3))`` over four
    ``Tensor[(1), float32]`` parameters. Its gradient is run through
    PartialEvaluate and DeadCodeElimination(inline_once=True) and must be
    alpha-equal to the Relay program parsed from the text below.
    """
    params = [
        relay.var("x{}".format(i), relay.TensorType((1, ), "float32"))
        for i in range(4)
    ]
    sums = relay.Tuple(
        [relay.add(params[0], params[1]), relay.add(params[2], params[3])])
    func = relay.Function(params, relay.annotation.checkpoint(sums))
    grad_fn = transform.gradient(run_infer_type(func))

    # run PE and DCE
    with transform.PassContext(opt_level=3):
        pipeline = transform.Sequential([
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ])
        grad_fn = pipeline(tvm.IRModule.from_expr(grad_fn))["main"]

    # Adjacent literals are concatenated into one program text; every fragment
    # ends with the single space separating it from the next one.
    expected_text = (
        ' v0.0.4 '
        'fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], '
        '%z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> '
        '((Tensor[(1), float32], Tensor[(1), float32]), '
        '(Tensor[(1), float32], Tensor[(1), float32], '
        'Tensor[(1), float32], Tensor[(1), float32])) { '
        'let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */; '
        'let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */; '
        'let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */; '
        'let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */; '
        '%0 = (%x1, %x2); '
        '%1 = zeros_like(%x) /* ty=Tensor[(1), float32] */; '
        '%2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */; '
        '%3 = add(%1, %2) /* ty=Tensor[(1), float32] */; '
        '%4 = zeros_like(%y) /* ty=Tensor[(1), float32] */; '
        '%5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */; '
        '%6 = add(%4, %5) /* ty=Tensor[(1), float32] */; '
        '%7 = zeros_like(%z) /* ty=Tensor[(1), float32] */; '
        '%8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */; '
        '%9 = add(%7, %8) /* ty=Tensor[(1), float32] */; '
        '%10 = zeros_like(%w) /* ty=Tensor[(1), float32] */; '
        '%11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */; '
        '%12 = add(%10, %11) /* ty=Tensor[(1), float32] */; '
        '%13 = (%3, %6, %9, %12); '
        '(%0, %13) } '
    )
    expected = relay.parser.fromtext(expected_text)
    relay.analysis.assert_alpha_equal(grad_fn, expected)
def dcpe(expr, mod=None, grad=False):
    """Apply PartialEvaluate then DeadCodeElimination(inline_once=True) to ``expr``.

    Parameters
    ----------
    expr
        The expression to optimize.
    mod
        Optional module. When truthy, ``expr`` must be a ``Function``; it is
        stored under ``mod.entry_func``, the passes run over the whole module,
        and the resulting entry function is returned.
    grad
        When true, ``expr`` is first replaced by ``gradient(expr)``.

    Returns
    -------
    The optimized expression (or the optimized entry function when ``mod`` is
    given).
    """
    pe_then_dce = [
        transform.PartialEvaluate(),
        transform.DeadCodeElimination(inline_once=True),
    ]
    if grad:
        expr = gradient(expr)

    # Expression-only path: no module to update.
    if not mod:
        return transform.OptimizeOnExpr(expr, pe_then_dce)

    assert isinstance(expr, Function)
    mod[mod.entry_func] = expr
    optimized_mod = transform.Sequential(pe_then_dce)(mod)
    # The entry function is looked up on the module returned by the passes.
    return optimized_mod[optimized_mod.entry_func]
def test_partial_eval():
    """Run PartialEvaluate on the gradient of a constant-returning function.

    The function takes no parameters and returns a (10, 10) float32 constant
    of ones. The test makes no assertions; it only requires that the pass
    runs on the module without raising.
    """
    mod = tvm.IRModule()
    shape = (10, 10)
    dtype = "float32"
    # NOTE(review): this tensor type is not referenced again in this test.
    t = relay.TensorType(shape, dtype)

    ones_fn = run_infer_type(
        relay.Function([], relay.const(np.ones(shape, dtype))))
    back_func = run_infer_type(transform.gradient(ones_fn))
    mod["main"] = back_func
    back_func = mod["main"]

    partial_eval = transform.PartialEvaluate()
    partial_eval(mod)
def destroy_ref(x):
    """Assert that no reference creation is detected after CPS, PE and DCE.

    ``x`` is type-inferred, converted with ``to_cps`` and type-inferred again.
    The CPS form is then optimized with PartialEvaluate, InferType and
    DeadCodeElimination, and the result must not report
    ``Feature.fRefCreate``.
    """
    cps_expr = run_infer_type(to_cps(run_infer_type(x)))

    # The un_cps round trip is only type-checked; its result is not used below.
    run_infer_type(un_cps(cps_expr))

    # TODO(mbs): Revisit once DCE can eliminate dead writes.
    pipeline = tvm.transform.Sequential([
        transform.PartialEvaluate(),
        transform.InferType(),
        transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
    ])
    optimized = run_opt_pass(cps_expr, pipeline)
    assert Feature.fRefCreate not in detect_feature(optimized)
def test_after_partial_eval():
    """Check type and values when LazyGradientInit runs after PartialEvaluate.

    Takes the gradient of ``(x * y) * ones`` over (10, 10) float32 tensors,
    applies PartialEvaluate, LazyGradientInit and DeadCodeElimination (each
    followed by InferType), asserts that ``main`` has type
    ``fn(t, t) -> (t, (t, t))``, and then compares the forward value and both
    gradients against NumPy on random inputs.
    """
    mod = tvm.IRModule()
    shape = (10, 10)
    dtype = "float32"
    tensor_ty = relay.TensorType(shape, dtype)

    x = relay.var("x", tensor_ty)
    y = relay.var("y", tensor_ty)
    ones = relay.const(np.ones(shape, dtype))
    forward_fn = run_infer_type(relay.Function([x, y], (x * y) * ones))
    back_func = run_infer_type(transform.gradient(forward_fn))
    mod["main"] = back_func
    # NOTE(review): back_func is read back before the pass pipeline runs, so
    # the numeric check below evaluates this pre-pass function rather than the
    # transformed mod["main"] -- confirm that this is intended.
    back_func = mod["main"]

    pipeline = tvm.transform.Sequential([
        transform.PartialEvaluate(),
        transform.InferType(),
        transform.LazyGradientInit(),
        transform.InferType(),
        transform.DeadCodeElimination(),
        transform.InferType(),
    ])
    mod = pipeline(mod)

    grads_ty = relay.TupleType([tensor_ty, tensor_ty])
    expected_ty = relay.FuncType(
        [tensor_ty, tensor_ty], relay.TupleType([tensor_ty, grads_ty]))
    assert mod["main"].checked_type == expected_ty

    x_data = rand(dtype, *shape)
    y_data = rand(dtype, *shape)
    executor = create_executor(mod=mod)
    forward, (grad_x, grad_y) = executor.evaluate(back_func)(x_data, y_data)

    assert_allclose(forward.numpy(), x_data.numpy() * y_data.numpy())
    assert_allclose(grad_x.numpy(), y_data.numpy())
    assert_allclose(grad_y.numpy(), x_data.numpy())
# Maps each accepted pass name to the attribute of ``transform`` that builds the
# pass and the positional arguments that attribute is called with.
_PASS_CONSTRUCTORS = {
    'FoldScaleAxis': ('FoldScaleAxis', ()),
    'BackwardFoldScaleAxis': ('BackwardFoldScaleAxis', ()),
    'ForwardFoldScaleAxis': ('ForwardFoldScaleAxis', ()),
    'FuseOps': ('FuseOps', (3,)),
    'FoldConstant': ('FoldConstant', ()),
    # The accepted name ends in a lowercase 'd'; the constructor does not.
    'CombineParallelConv2d': ('CombineParallelConv2D', ()),
    'AlterOpLayout': ('AlterOpLayout', ()),
    'EliminateCommonSubexpr': ('EliminateCommonSubexpr', ()),
    'PartialEvaluate': ('PartialEvaluate', ()),
    'CanonicalizeCast': ('CanonicalizeCast', ()),
    'CanonicalizeOps': ('CanonicalizeOps', ()),
}


def match_pass_name(name):
    """Return a freshly constructed ``transform`` pass for ``name``.

    Parameters
    ----------
    name
        One of the keys of ``_PASS_CONSTRUCTORS``. ``'FuseOps'`` is built as
        ``transform.FuseOps(3)``; every other pass is built with no arguments.

    Returns
    -------
    The object returned by the matching ``transform`` constructor.

    Raises
    ------
    ValueError
        If ``name`` is not a known pass name. ``ValueError`` derives from
        ``Exception``, so callers catching ``Exception`` are unaffected.
    """
    try:
        attr, args = _PASS_CONSTRUCTORS[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which can never match a key.
        raise ValueError('Name {} does not match any pass'.format(name)) from None
    # Resolve the constructor lazily so only the requested pass is touched.
    return getattr(transform, attr)(*args)
def tipe(expr):
    """Return ``expr`` after PartialEvaluate followed by InferType.

    Both passes are handed to ``run_opt_pass`` as a list, in that order.
    """
    pe_then_infer = [transform.PartialEvaluate(), transform.InferType()]
    return run_opt_pass(expr, pe_then_infer)
def test_nat_update():
    """Run PartialEvaluate on an A-normal-form module that imports ``nat.rly``.

    The test makes no assertions; it only requires that the passes run
    without raising.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    prelude.mod.import_from_std("nat.rly")

    to_anf = transform.ToANormalForm()
    module = to_anf(module)

    partial_eval = transform.PartialEvaluate()
    partial_eval(module)
def tipe(expr):
    """Return ``expr`` after InferType, PartialEvaluate and InferType again.

    The three passes are handed to ``transform.OptimizeOnExpr`` as a list, in
    that order.
    """
    infer_pe_infer = [
        transform.InferType(),
        transform.PartialEvaluate(),
        transform.InferType(),
    ]
    return transform.OptimizeOnExpr(expr, infer_pe_infer)
def test():
    """Call ``fuzz_pass`` on PartialEvaluate with no further arguments."""
    fuzz_pass(transform.PartialEvaluate())
def test_nat_update():
    """Run PartialEvaluate on an A-normal-form module with nat definitions added.

    The test makes no assertions; it only requires that the passes run
    without raising.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    to_anf = transform.ToANormalForm()
    module = to_anf(module)

    partial_eval = transform.PartialEvaluate()
    partial_eval(module)
def test_checkpoint_alpha_equal():
    """Compare the optimized gradient of a checkpointed expression with expected IR.

    The function under test is ``checkpoint((x0 + x1) * (x2 + x3))`` over four
    ``Tensor[(1), float32]`` parameters. Its gradient is run through
    PartialEvaluate and DeadCodeElimination and must be structurally equal to
    the Relay program parsed from the text below.
    """
    params = [
        relay.var("x{}".format(i), relay.TensorType((1,), "float32"))
        for i in range(4)
    ]
    product = relay.multiply(
        relay.add(params[0], params[1]), relay.add(params[2], params[3])
    )
    func = relay.Function(params, relay.annotation.checkpoint(product))
    grad_fn = transform.gradient(run_infer_type(func))

    # run PE and DCE
    with tvm.transform.PassContext(opt_level=3):
        # The expected output assumes DCE can elide 'dead writes' to
        # references. When this test was written DCE elided all writes, which
        # is unsound in general but happens to work for this case;
        # 'ignore_impurity=True' preserves that legacy behaviour.
        # TODO(mbs): Revisit once DCE supports dead reference writes.
        pipeline = tvm.transform.Sequential(
            [
                transform.PartialEvaluate(),
                transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
            ]
        )
        grad_fn = pipeline(tvm.IRModule.from_expr(grad_fn))["main"]

    # Adjacent literals are concatenated into one program text; every fragment
    # ends with the single space separating it from the next one.
    expected_text = (
        ' #[version = "0.0.5"] '
        'fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], '
        '%z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> '
        '(Tensor[(1), float32], (Tensor[(1), float32], Tensor[(1), float32], '
        'Tensor[(1), float32], Tensor[(1), float32])) { '
        '%0 = add(%x, %y); '
        '%1 = add(%z, %w); '
        'let %x1: Tensor[(1), float32] = multiply(%0, %1); '
        'let %x2: Tensor[(1), float32] = ones_like(%x1); '
        'let %x3: Tensor[(1), float32] = add(%x, %y); '
        'let %x4: Tensor[(1), float32] = add(%z, %w); '
        '%2 = zeros_like(%x3); '
        '%3 = multiply(%x2, %x4); '
        '%4 = collapse_sum_like(%3, %x3); '
        'let %x5: Tensor[(1), float32] = add(%2, %4); '
        '%5 = zeros_like(%x4); '
        '%6 = multiply(%x2, %x3); '
        '%7 = collapse_sum_like(%6, %x4); '
        'let %x6: Tensor[(1), float32] = add(%5, %7); '
        '%8 = zeros_like(%x); '
        '%9 = collapse_sum_like(%x5, %x); '
        '%10 = add(%8, %9); '
        '%11 = zeros_like(%y); '
        '%12 = collapse_sum_like(%x5, %y); '
        '%13 = add(%11, %12); '
        '%14 = zeros_like(%z); '
        '%15 = collapse_sum_like(%x6, %z); '
        '%16 = add(%14, %15); '
        '%17 = zeros_like(%w); '
        '%18 = collapse_sum_like(%x6, %w); '
        '%19 = add(%17, %18); '
        '%20 = (%10, %13, %16, %19); '
        '(%x1, %20) } '
    )
    expected = tvm.parser.parse_expr(expected_text)
    tvm.ir.assert_structural_equal(grad_fn, expected)