def test_ad():
    """Check the partially-evaluated gradient program of ``f(d) = d * d``.

    The expected program binds ``x = d * d`` and ``x1 = ones_like(x)``. It then
    returns ``(x, (grad,))``, where ``grad`` starts from ``zeros_like(d)`` and
    adds two identical ``collapse_sum_like(x1 * d, d)`` terms, one for each
    occurrence of ``d`` in the product.
    """
    # NOTE(review): another ``def test_ad`` appears later in this file and
    # rebinds this name, so pytest only collects that later definition.
    # Consider deleting this copy or giving it a distinct name.
    shape = (10, 10)
    dtype = "float32"
    t = TensorType(shape, dtype)
    d = Var("d", t)
    f = Function([d], d * d)
    # TODO(mbs): Revisit once DCE eliminates dead writes.
    # The later test_ad in this file passes ignore_impurity=True to the same
    # dcpe helper; it is mirrored here so both copies check the same program.
    g = dcpe(f, grad=True, ignore_impurity=True)
    m = d * d
    x = relay.Var("x")
    o = op.ones_like(x)
    x1 = relay.Var("x1")
    # Keep two separate collapse_sum_like nodes (not one shared node) so the
    # expected graph has the same shape as the pass output.
    grad = op.zeros_like(d) + op.collapse_sum_like(x1 * d, d) + op.collapse_sum_like(x1 * d, d)
    body = Tuple([x, Tuple([grad])])
    body = relay.Let(x1, o, body)
    expected = Function([d], relay.Let(x, m, body))
    expected = run_opt_pass(expected, transform.InferType())
    # assert_alpha_equal is deprecated/removed in TVM. Use structural equality,
    # matching the later test_ad in this file.
    tvm.ir.assert_structural_equal(g, expected)
def test_ad():
    """Compare the partially-evaluated gradient of ``d * d`` against a hand-built program.

    The hand-built program is::

        let x = d * d
        let x1 = ones_like(x)
        (x, (zeros_like(d) + collapse_sum_like(x1 * d, d)
                           + collapse_sum_like(x1 * d, d),))
    """
    tensor_ty = TensorType((10, 10), "float32")
    inp = Var("d", tensor_ty)

    # TODO(mbs): Revisit once DCE eliminates dead writes.
    actual = dcpe(Function([inp], inp * inp), grad=True, ignore_impurity=True)

    fwd = relay.Var("x")
    seed = relay.Var("x1")
    # Two separate collapse_sum_like nodes, one per occurrence of ``d``.
    d_grad = (
        op.zeros_like(inp)
        + op.collapse_sum_like(seed * inp, inp)
        + op.collapse_sum_like(seed * inp, inp)
    )
    result = Tuple([fwd, Tuple([d_grad])])
    expected = Function(
        [inp],
        relay.Let(fwd, inp * inp, relay.Let(seed, op.ones_like(fwd), result)),
    )
    expected = run_opt_pass(expected, transform.InferType())

    tvm.ir.assert_structural_equal(actual, expected)