コード例 #1
0
def test_type_relation():
    """Build a Broadcast TypeRelation over mixed argument types and pass it to check_visit."""
    broadcast = tvm.get_env_func('tvm.relay.type_relation.Broadcast')
    test_attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3, 4))
    relation_args = [
        TypeVar('tp'),
        FuncType([], TupleType([]), [], []),
        TensorType([1, 2, 3], 'float32'),
    ]
    # The literal 2 is passed positionally as before; presumably it is
    # TypeRelation's num_inputs -- confirm against the TVM version in use.
    relation = TypeRelation(broadcast, relation_args, 2, test_attrs)
    check_visit(relation)
コード例 #2
0
def test_type_relation():
    """Build a Broadcast TypeRelation over mixed argument types and pass it to check_visit."""
    broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
    test_attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
    relation_args = [
        TypeVar("tp"),
        FuncType([], TupleType([]), [], []),
        TensorType([1, 2, 3], "float32"),
    ]
    # The literal 2 is passed positionally as before; presumably it is
    # TypeRelation's num_inputs -- confirm against the TVM version in use.
    relation = TypeRelation(broadcast, relation_args, 2, test_attrs)
    check_visit(relation)
コード例 #3
0
def define_nat_nth(prelude):
    """Defines a function to get the nth element of a list using
    a nat to index into the list.

    nat_nth(l, n): fun<a>(list[a], nat) -> a

    The function is added to ``prelude.mod`` under the GlobalVar stored as
    ``prelude.nat_nth``. Index ``z`` yields ``hd(l)``; index ``s(m)`` recurses
    as ``nat_nth(tl(l), m)``.
    """
    # Bind the GlobalVar first so the successor clause can call it recursively.
    prelude.nat_nth = GlobalVar("nat_nth")
    a = TypeVar("a")
    x = Var("x", prelude.l(a))
    n = Var("n", prelude.nat())
    # Predecessor bound by the s(...) pattern; no type annotation is given
    # (presumably inferred from the pattern).
    y = Var("y")

    # NOTE(review): there is no bounds check -- an index past the end of the
    # list reaches hd/tl on an empty list; presumably that fails at runtime,
    # confirm against the prelude's hd/tl definitions.
    z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x))
    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
                    prelude.nat_nth(prelude.tl(x), y))

    prelude.mod[prelude.nat_nth] = Function([x, n], Match(n, [z_case, s_case]),
                                            a, [a])
コード例 #4
0
def define_nat_iterate(prelude):
    """Register ``nat_iterate`` on ``prelude``.

    ``nat_iterate(f, n)`` returns a closure applying ``f`` to its argument
    ``n`` times: for ``z`` it returns ``prelude.id``; for ``s(m)`` it returns
    ``compose(f, nat_iterate(f, m))``.

    Signature: fn<a>(fn(a) -> a, nat) -> fn(a) -> a
    """
    # The GlobalVar must exist before the body is built, since the successor
    # clause calls it recursively.
    prelude.nat_iterate = GlobalVar("nat_iterate")
    elem_ty = TypeVar("a")
    step_fn = Var("f", FuncType([elem_ty], elem_ty))
    count = Var("x", prelude.nat())
    pred = Var("y", prelude.nat())

    zero_clause = Clause(PatternConstructor(prelude.z), prelude.id)
    recursive_call = prelude.nat_iterate(step_fn, pred)
    succ_clause = Clause(PatternConstructor(prelude.s, [PatternVar(pred)]),
                         prelude.compose(step_fn, recursive_call))

    body = Match(count, [zero_clause, succ_clause])
    result_ty = FuncType([elem_ty], elem_ty)
    prelude.mod[prelude.nat_iterate] = Function([step_fn, count], body,
                                                result_ty, [elem_ty])
コード例 #5
0
def define_nat_update(prelude):
    """Register ``nat_update`` on ``prelude``: replace the element at a nat index.

    nat_update(l, i, v) : fun<a>(list[a], nat, a) -> list[a]

    Index ``z`` yields ``cons(v, tl(l))``; index ``s(m)`` keeps ``hd(l)`` and
    recurses as ``nat_update(tl(l), m, v)``.
    """
    # Bound first so the successor clause can refer to it recursively.
    prelude.nat_update = GlobalVar("nat_update")
    elem_ty = TypeVar("a")
    # pylint: disable=invalid-name
    lst = Var("l", prelude.l(elem_ty))
    idx = Var("n", prelude.nat())
    new_val = Var("v", elem_ty)
    pred = Var("y")

    replace_head = prelude.cons(new_val, prelude.tl(lst))
    zero_clause = Clause(PatternConstructor(prelude.z), replace_head)

    recurse_tail = prelude.nat_update(prelude.tl(lst), pred, new_val)
    succ_clause = Clause(PatternConstructor(prelude.s, [PatternVar(pred)]),
                         prelude.cons(prelude.hd(lst), recurse_tail))

    body = Match(idx, [zero_clause, succ_clause])
    prelude.mod[prelude.nat_update] = Function([lst, idx, new_val], body,
                                               prelude.l(elem_ty), [elem_ty])
コード例 #6
0
def test_func_type():
    """Pass a one-argument FuncType with a single type parameter to check_visit."""
    type_param = TypeVar("tv")
    shape = tvm.runtime.convert([1, 2, 3])
    tensor_ty = relay.TensorType(shape, "float32")
    check_visit(FuncType([tensor_ty], tensor_ty, type_params=[type_param]))
コード例 #7
0
def test_type_var():
    """Pass a lone TypeVar to check_visit."""
    check_visit(TypeVar("a"))
コード例 #8
0
def test_type_data():
    """Pass a TypeData with one type parameter and no constructors to check_visit."""
    header = GlobalTypeVar("td")
    type_params = [TypeVar("tv")]
    check_visit(TypeData(header, type_params, []))
コード例 #9
0
def test_type_data():
    """Pass a TypeData with one type parameter and no constructors to check_visit."""
    header = GlobalTypeVar('td')
    type_params = [TypeVar('tv')]
    check_visit(TypeData(header, type_params, []))
コード例 #10
0
def test_func_type():
    """Pass a one-argument FuncType with a single type parameter to check_visit."""
    type_param = TypeVar('tv')
    shape = tvm.convert([1, 2, 3])
    tensor_ty = relay.TensorType(shape, 'float32')
    check_visit(FuncType([tensor_ty], tensor_ty, type_params=[type_param]))