def test_ref_kind():
    """Check that check_kind reports TypeKind.Type for several RefTypes.

    Covers a reference to a tensor type, a reference to a function type,
    and a reference to a tuple of those two references.
    """
    # only contain type kinds
    tensor_ty = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
    func_ty = relay.FuncType(
        tvm.runtime.convert([]),
        tensor_ty,
        tvm.runtime.convert([]),
        tvm.runtime.convert([]),
    )

    ref_to_tensor = relay.RefType(tensor_ty)
    assert check_kind(ref_to_tensor) == relay.TypeKind.Type

    ref_to_func = relay.RefType(func_ty)
    assert check_kind(ref_to_func) == relay.TypeKind.Type

    # A reference whose payload is a tuple of the two references above.
    ref_to_tuple = relay.RefType(relay.TupleType([ref_to_tensor, ref_to_func]))
    assert check_kind(ref_to_tuple) == relay.TypeKind.Type
def test_ref():
    """Check the inferred types of RefCreate, RefRead and RefWrite.

    For float32 scalar variables, the test expects RefCreate to be typed as
    RefType(scalar), RefRead as the scalar type, and RefWrite as the empty
    tuple type.
    """

    def inferred(expr):
        # Run type inference and return the resulting checked type.
        return relay.ir_pass.infer_type(expr).checked_type

    initial = relay.var("x", "float32")
    replacement = relay.var("y", "float32")

    ref = relay.RefCreate(initial)
    scalar_ty = relay.scalar_type("float32")
    assert inferred(ref) == relay.RefType(scalar_ty)

    read = relay.RefRead(ref)
    assert inferred(read) == scalar_ty

    write = relay.RefWrite(ref, replacement)
    assert inferred(write) == relay.TupleType([])
def test_ref():
    """Check that dcpe reduces a reference-based squaring function to d * d.

    The function under test stores its argument in a reference, overwrites
    the reference with the product of two reads, and returns a final read.
    The expected result is Function([d], d * d) after transform.InferType().
    NOTE(review): dcpe and run_opt_pass are defined outside this block;
    their exact pipelines are not visible here.
    """
    scalar_ty = relay.TensorType([], "float32")
    arg = relay.Var("d", scalar_ty)
    ref = relay.Var("r", relay.RefType(scalar_ty))
    write_result = relay.Var("x")

    # Build the nested lets inside-out: the innermost expression comes first.
    final_read = relay.RefRead(ref)
    square_in_place = RefWrite(ref, RefRead(ref) * RefRead(ref))
    after_write = Let(write_result, square_in_place, final_read)
    with_ref = Let(ref, RefCreate(arg), after_write)

    square = Function([arg], with_ref)
    expected = run_opt_pass(Function([arg], arg * arg), transform.InferType())
    assert tvm.ir.structural_equal(dcpe(square), expected)
def test_ref():
    """Check that dcpe reduces a reference-based squaring function to d * d.

    The function under test stores its argument in a reference, overwrites
    the reference with the product of two reads, and returns a final read.
    The expected result is Function([d], d * d) passed through
    transform.OptimizeOnExpr with transform.InferType().
    NOTE(review): dcpe and alpha_equal are defined outside this block;
    their exact behavior is not visible here.
    """
    scalar_ty = relay.TensorType([], "float32")
    arg = relay.Var("d", scalar_ty)
    ref = relay.Var("r", relay.RefType(scalar_ty))
    write_result = relay.Var("x")

    # Build the nested lets inside-out: the innermost expression comes first.
    final_read = relay.RefRead(ref)
    square_in_place = RefWrite(ref, RefRead(ref) * RefRead(ref))
    after_write = Let(write_result, square_in_place, final_read)
    with_ref = Let(ref, RefCreate(arg), after_write)

    square = Function([arg], with_ref)
    expected = transform.OptimizeOnExpr(
        Function([arg], arg * arg), transform.InferType()
    )
    assert alpha_equal(dcpe(square), expected)
def test_ref():
    """Check that dcpe reduces a reference-based squaring function to d * d.

    The function under test stores its argument in a reference, overwrites
    the reference with the product of two reads, and returns a final read.
    dcpe is called with ignore_impurity=True, and the result is compared
    with Function([d], d * d) after transform.InferType().
    NOTE(review): dcpe and run_opt_pass are defined outside this block;
    their exact pipelines are not visible here.
    """
    scalar_ty = relay.TensorType([], "float32")
    arg = relay.Var("d", scalar_ty)
    ref = relay.Var("r", relay.RefType(scalar_ty))
    write_result = relay.Var("x")

    # Build the nested lets inside-out: the innermost expression comes first.
    final_read = relay.RefRead(ref)
    square_in_place = RefWrite(ref, RefRead(ref) * RefRead(ref))
    after_write = Let(write_result, square_in_place, final_read)
    with_ref = Let(ref, RefCreate(arg), after_write)

    square = Function([arg], with_ref)
    expected = run_opt_pass(Function([arg], arg * arg), transform.InferType())
    # TODO(mbs): Revisit once DCE eliminates dead writes.
    actual = dcpe(square, ignore_impurity=True)
    assert tvm.ir.structural_equal(actual, expected)
def test_invalid_ref_kind():
    """Run check_kind on a RefType over a type variable of kind ShapeVar.

    NOTE(review): no assertion or expected-exception marker is visible in
    this block; presumably check_kind is expected to reject this type.
    Confirm how the failure is asserted (e.g. by a decorator).
    """
    check_kind(relay.RefType(relay.TypeVar("tp", relay.TypeKind.ShapeVar)))
def test_invalid_ref_kind():
    """Run check_kind on a RefType over a type variable of kind Shape.

    NOTE(review): no assertion or expected-exception marker is visible in
    this block; presumably check_kind is expected to reject this type.
    Confirm how the failure is asserted (e.g. by a decorator).
    """
    check_kind(relay.RefType(relay.TypeVar('tp', relay.Kind.Shape)))