def test_ref_kind():
    """Check that a RefType wrapping Type-kinded types is itself Type-kinded.

    Covers a reference to a tensor type, to a function type, and to a tuple
    of the two previous reference types.
    """
    convert = tvm.runtime.convert
    type_kind = relay.TypeKind.Type

    tensor_ty = relay.TensorType(convert([1, 2, 3]), 'float32')
    func_ty = relay.FuncType(convert([]), tensor_ty, convert([]), convert([]))

    ref_tensor = relay.RefType(tensor_ty)
    assert check_kind(ref_tensor) == type_kind

    ref_func = relay.RefType(func_ty)
    assert check_kind(ref_func) == type_kind

    # A tuple made of Ref types is still an ordinary Type-kinded type.
    ref_tuple = relay.RefType(relay.TupleType([ref_tensor, ref_func]))
    assert check_kind(ref_tuple) == type_kind
# Example #2
# 0
def test_ref():
    """Check inferred types of RefCreate, RefRead and RefWrite on float32 scalars."""
    infer = relay.ir_pass.infer_type
    lhs = relay.var("x", "float32")
    rhs = relay.var("y", "float32")
    cell = relay.RefCreate(lhs)
    f32 = relay.scalar_type("float32")

    # Creating a reference to a float32 scalar yields Ref[float32].
    assert infer(cell).checked_type == relay.RefType(f32)

    # Reading the reference yields the stored scalar type.
    read = relay.RefRead(cell)
    assert infer(read).checked_type == f32

    # Writing through the reference is typed as the empty tuple.
    write = relay.RefWrite(cell, rhs)
    assert infer(write).checked_type == relay.TupleType([])
# Example #3
# 0
def test_ref():
    """Check that ``dcpe`` reduces a reference-based square of ``d`` to ``d * d``.

    The input program is, in ML-like notation:
        fn (d) { let r = ref(d); let x = (r := !r * !r); !r }
    """
    scalar_t = relay.TensorType([], "float32")
    arg = relay.Var("d", scalar_t)
    cell = relay.Var("r", relay.RefType(scalar_t))
    discard = relay.Var("x")  # bound to the write's result, never used

    overwrite = RefWrite(cell, RefRead(cell) * RefRead(cell))
    inner = Let(discard, overwrite, relay.RefRead(cell))
    square = Function([arg], Let(cell, RefCreate(arg), inner))

    expected = run_opt_pass(Function([arg], arg * arg), transform.InferType())
    assert tvm.ir.structural_equal(dcpe(square), expected)
# Example #4
# 0
def test_ref():
    """Check that ``dcpe`` reduces a reference-based square of ``d`` to ``d * d``.

    The input program is, in ML-like notation:
        fn (d) { let r = ref(d); let x = (r := !r * !r); !r }
    and the result is compared with alpha-equivalence.
    """
    scalar_t = relay.TensorType([], "float32")
    arg = relay.Var("d", scalar_t)
    cell = relay.Var("r", relay.RefType(scalar_t))
    discard = relay.Var("x")  # bound to the write's result, never used

    overwrite = RefWrite(cell, RefRead(cell) * RefRead(cell))
    inner = Let(discard, overwrite, relay.RefRead(cell))
    square = Function([arg], Let(cell, RefCreate(arg), inner))

    reference_fn = Function([arg], arg * arg)
    expected = transform.OptimizeOnExpr(reference_fn, transform.InferType())
    assert alpha_equal(dcpe(square), expected)
# Example #5
# 0
def test_ref():
    """Check that ``dcpe`` reduces a reference-based square of ``d`` to ``d * d``.

    The input program is, in ML-like notation:
        fn (d) { let r = ref(d); let x = (r := !r * !r); !r }
    ``dcpe`` is called with ``ignore_impurity=True``.
    """
    scalar_t = relay.TensorType([], "float32")
    arg = relay.Var("d", scalar_t)
    cell = relay.Var("r", relay.RefType(scalar_t))
    discard = relay.Var("x")  # bound to the write's result, never used

    overwrite = RefWrite(cell, RefRead(cell) * RefRead(cell))
    inner = Let(discard, overwrite, relay.RefRead(cell))
    square = Function([arg], Let(cell, RefCreate(arg), inner))

    expected = run_opt_pass(Function([arg], arg * arg), transform.InferType())
    # TODO(mbs): Revisit once DCE eliminates dead writes.
    reduced = dcpe(square, ignore_impurity=True)
    assert tvm.ir.structural_equal(reduced, expected)
# Example #6
# 0
def test_invalid_ref_kind():
    """Check that a RefType over a ShapeVar-kinded type variable is rejected.

    The original body called ``check_kind`` with no assertion and no visible
    expected-failure marker. A rejection therefore surfaced as a test error,
    and acceptance of the ill-kinded type passed silently. The expectation is
    now explicit.
    """
    tp = relay.TypeVar("tp", relay.TypeKind.ShapeVar)
    rt = relay.RefType(tp)
    try:
        check_kind(rt)
    except tvm.TVMError:
        # NOTE(review): assumes the kind checker reports the error by raising
        # a TVMError (or a subclass) — confirm against the TVM version in use.
        return
    raise AssertionError(
        "check_kind accepted a RefType over a ShapeVar-kinded TypeVar"
    )
# Example #7
# 0
def test_invalid_ref_kind():
    """Check that a RefType over a Shape-kinded type variable is rejected.

    The original body called ``check_kind`` with no assertion and no visible
    expected-failure marker. A rejection therefore surfaced as a test error,
    and acceptance of the ill-kinded type passed silently. The expectation is
    now explicit.
    """
    tp = relay.TypeVar('tp', relay.Kind.Shape)
    rt = relay.RefType(tp)
    try:
        check_kind(rt)
    except tvm.TVMError:
        # NOTE(review): assumes the kind checker reports the error by raising
        # a TVMError (or a subclass) — confirm against the TVM version in use.
        return
    raise AssertionError(
        "check_kind accepted a RefType over a Shape-kinded TypeVar"
    )