Exemplo n.º 1
0
def test_eta_expand_constructor():
    """EtaExpand with ``expand_constructor=True`` wraps a bare constructor.

    ``@main`` returns the ``Cons`` constructor without applying it. After the
    pass it should instead return the explicit lambda
    ``fn [A](%x: A, %xs: List[A]) -> List[A] { Cons(%x, %xs) }``.
    Free variables are mapped during comparison.
    """
    source_text = r"""
        #[version = "0.0.5"]
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> fn(A, List[A]) -> List[A] {
            Cons
        }
    """
    expected_text = r"""
        #[version = "0.0.5"]
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> fn(A, List[A]) -> List[A] {
            fn [A](%x: A, %xs: List[A]) -> List[A] {
                Cons(%x, %xs)
            }
        }
    """

    before = tvm.parser.fromtext(source_text)
    eta_pass = _transform.EtaExpand(expand_constructor=True)
    pipeline = tvm.transform.Sequential([eta_pass])
    with tvm.transform.PassContext(opt_level=3):
        after = pipeline(before)

    reference = tvm.parser.fromtext(expected_text)
    tvm.ir.assert_structural_equal(
        after["main"], reference["main"], map_free_vars=True
    )
Exemplo n.º 2
0
def test_eta_expand_global_var():
    """EtaExpand with ``expand_global_var=True`` wraps a bare global reference.

    ``@main`` returns ``@aux`` (an int32 identity function) unapplied. After
    the pass it should return
    ``fn (%x: Tensor[(), int32]) -> Tensor[(), int32] { @aux(%x) }``.
    Free variables are mapped during comparison.
    """
    source_text = r"""
        #[version = "0.0.5"]
        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
            %x
        }
        def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] {
            @aux
        }
    """
    expected_text = r"""
        #[version = "0.0.5"]
        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
            %x
        }
        def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] {
            fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
                @aux(%x)
            }
        }
    """

    before = tvm.parser.fromtext(source_text)
    eta_pass = _transform.EtaExpand(expand_global_var=True)
    pipeline = tvm.transform.Sequential([eta_pass])
    with tvm.transform.PassContext(opt_level=3):
        after = pipeline(before)

    reference = tvm.parser.fromtext(expected_text)
    tvm.ir.assert_structural_equal(
        after["main"], reference["main"], map_free_vars=True
    )
def test_eta_expand_constructor():
    """Legacy-API variant: EtaExpand wraps a bare ``Cons`` constructor.

    Uses the ``v0.0.4`` Relay text format and ``relay.fromtext``. The
    unapplied ``Cons`` returned by ``@main`` should become
    ``fn [A](%x: A, %xs: List[A]) -> List[A] { Cons(%x, %xs) }``.
    The result is compared with ``relay.analysis.assert_graph_equal``.
    """
    before = relay.fromtext(r"""
        v0.0.4
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
            Cons
        }
    """)
    expand = _transform.EtaExpand(expand_constructor=True)
    pipeline = _transform.Sequential([expand])
    with _transform.PassContext(opt_level=3):
        after = pipeline(before)

    reference = relay.fromtext(r"""
        v0.0.4
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
            fn [A](%x: A, %xs: List[A]) -> List[A] {
                Cons(%x, %xs)
            }
        }
    """)
    relay.analysis.assert_graph_equal(after['main'], reference['main'])
def test_eta_expand_global_var():
    """Legacy-API variant: EtaExpand wraps a bare ``@aux`` global reference.

    Uses the ``v0.0.4`` Relay text format. The unapplied ``@aux`` returned by
    ``@main`` should become
    ``fn (%x: Tensor[(), int32]) -> Tensor[(), int32] { @aux(%x) }``.
    The result is compared with ``relay.analysis.assert_graph_equal``.
    """
    before = relay.fromtext(r"""
        v0.0.4
        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
            %x
        }
        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
            @aux
        }
    """)
    expand = _transform.EtaExpand(expand_global_var=True)
    pipeline = _transform.Sequential([expand])
    with _transform.PassContext(opt_level=3):
        after = pipeline(before)

    reference = relay.fromtext(r"""
        v0.0.4
        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
            %x
        }
        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
            fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
                @aux(%x)
            }
        }
    """)
    relay.analysis.assert_graph_equal(after['main'], reference['main'])
Exemplo n.º 5
0
def test_eta_expand_basic():
    """Legacy-API variant: EtaExpand on an int32 identity function.

    The module's entry function is expected to become ``fn (%y) { orig(%y) }``.
    Here ``orig`` is the original identity function. Both sides are
    type-inferred with ``relay.ir_pass.infer_type`` and then compared up to
    alpha-equivalence.
    """
    x = relay.var('x', 'int32')
    identity = relay.Function([x], x)
    module = _module.Module.from_expr(identity)
    pipeline = _transform.Sequential([_transform.EtaExpand()])
    with _transform.PassContext(opt_level=3):
        module = pipeline(module)

    entry_name = module.entry_func.name_hint
    actual = module[entry_name]

    y = relay.var('y', 'int32')
    # Hand-built eta-expanded form: a lambda that forwards its argument.
    wrapper = relay.Function([y], identity(y))

    actual = relay.ir_pass.infer_type(actual, module)
    wrapper = relay.ir_pass.infer_type(wrapper, module)
    assert relay.ir_pass.alpha_equal(actual, wrapper)
def test_eta_expand_basic():
    """EtaExpand on an int32 identity function (module-level InferType API).

    After the pass, ``main`` should be alpha-equal to ``fn (%y) { orig(%y) }``.
    Here ``orig`` is the original identity function. The expected function is
    typed by registering it as global ``gv`` in the module and running
    ``InferType``. ``main`` is read out before that step.
    """
    x = relay.var('x', 'int32')
    identity = relay.Function([x], x)
    module = _module.Module.from_expr(identity)
    pipeline = _transform.Sequential([_transform.EtaExpand()])
    with _transform.PassContext(opt_level=3):
        module = pipeline(module)

    # Capture the transformed entry point before the module is modified below.
    actual = module["main"]

    y = relay.var('y', 'int32')
    wrapper = relay.Function([y], identity(y))
    # Register the hand-built form so InferType annotates it in module context.
    module[relay.GlobalVar("gv")] = wrapper
    module = _transform.InferType()(module)
    typed_wrapper = module["gv"]

    assert relay.analysis.alpha_equal(actual, typed_wrapper)