Example #1
0
    def __init__(self, *, name="f", **kwargs):
        """Build this network's body and register it as a global function.

        Parameters
        ----------
        name : str, keyword-only
            Base name of the global function; the current value of
            ``Network.cnt`` is appended to it and the counter is incremented.
        **kwargs
            Forwarded unchanged to ``self.build``.

        Notes
        -----
        When ``Network.stack`` is non-empty, the module and prelude of the
        entry on top of the stack are reused; otherwise a fresh ``Module`` and
        ``Prelude`` are created and the nat definitions are added to it.
        """
        name = f"{name}_{Network.cnt}"
        Network.cnt += 1
        # Compare by value: `is not 0` tests object identity, which is an
        # implementation detail for ints and a SyntaxWarning on Python 3.8+.
        if len(Network.stack) != 0:
            mod = Network.stack[-1].mod
            p = Network.stack[-1].p
        else:
            mod = Module()
            p = Prelude(mod)
            add_nat_definitions(p)

        self.mod = mod
        self.p = p
        self.inputs = []
        self.weights = OrderedSet()
        self.sub_network = OrderedSet()
        self.f = relay.GlobalVar(name)
        self.recurse = relay.Var("recurse")
        self.use_recurse = False
        self.ret_type = None
        # `build` is expected to return the body expression; the attributes
        # above are initialised first so it can read and update them.
        body = self.build(**kwargs)
        assert isinstance(body, relay.Expr)
        if self.use_recurse:
            # Bind `self.recurse` to a function over copies of the inputs that
            # goes through `call_from_outside`, and make it visible in `body`.
            inputs = [copy_var(v) for v in self.inputs]
            body = relay.Let(
                self.recurse,
                relay.Function(inputs, self.call_from_outside(*inputs)), body)
        # The registered function takes the inputs followed by all weights.
        self.mod[self.f] = relay.Function(self.inputs + self.all_weights(),
                                          body, self.ret_type)
Example #2
0
def test_add_convert():
    """Compile the prelude's ``add`` and check that 12 + 34 converts back to 46."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    compiled_add = compile(prelude.add, module)
    lhs = int_to_nat(prelude, 12)
    rhs = int_to_nat(prelude, 34)
    total = compiled_add(lhs, rhs)
    assert nat_to_int(total) == 46
Example #3
0
def test_double():
    """Run ``dcpe`` on ``double(3)`` and expect the nat expression for 6."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    doubled_three = prelude.double(make_nat_expr(prelude, 3))
    reduced = dcpe(doubled_three, mod=module)
    expected = make_nat_expr(prelude, 6)
    assert alpha_equal(reduced, expected)
Example #4
0
def test_nat_3():
    """Compile a nullary function returning ``s(s(s(z())))`` and expect 3."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    three = prelude.s(prelude.s(prelude.s(prelude.z())))
    compiled = compile(Function([], three), module)
    result = compiled()
    assert nat_to_int(result) == 3
def test_double():
    """Run ``dcpe`` on a nullary function wrapping ``double(3)``; its body must equal 6."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    doubled_three = prelude.double(make_nat_expr(prelude, 3))
    wrapped = Function([], doubled_three)
    reduced = dcpe(wrapped, mod=module)
    expected = make_nat_expr(prelude, 6)
    assert tvm.ir.structural_equal(reduced.body, expected)
Example #6
0
def test_global_match_nat_id():
    """Run ``dcpe`` on a ``Match`` that rebuilds its scrutinee (nat 3) and expect nat 3."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_type = prelude.nat()
    pred = Var("x", nat_type)
    # Each clause rebuilds the constructor it matched.
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_clause = Clause(
        PatternConstructor(prelude.s, [PatternVar(pred)]), prelude.s(pred))
    match_expr = Match(make_nat_expr(prelude, 3), [zero_clause, succ_clause])
    reduced = dcpe(match_expr, mod=module)
    assert alpha_equal(reduced, make_nat_expr(prelude, 3))
def test_global_match_nat_id():
    """Run ``dcpe`` on a nullary function whose body matches on nat 3; body must equal nat 3."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_type = prelude.nat()
    pred = Var("x", nat_type)
    # Each clause rebuilds the constructor it matched.
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_clause = Clause(
        PatternConstructor(prelude.s, [PatternVar(pred)]), prelude.s(pred))
    match_expr = Match(make_nat_expr(prelude, 3), [zero_clause, succ_clause])
    wrapped = Function([], match_expr)
    reduced = dcpe(wrapped, mod=module)
    assert tvm.ir.structural_equal(reduced.body, make_nat_expr(prelude, 3))
Example #8
0
def test_nat_id():
    """Run ``dcpe`` on a call to a global identity function applied to nat 3.

    The global ``nat_id`` is defined as ``Function([x], x)``; the result of
    ``dcpe`` must be alpha-equal to the nat expression for 3.
    """
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    # The unused local `y` from the original was dropped; nothing read it.
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = nat_id(make_nat_expr(p, 3))
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res, make_nat_expr(p, 3))
Example #9
0
def test_swap_loop():
    """Run ``dcpe`` on a call to a self-recursive, argument-swapping global.

    ``loop(x, y)`` is defined as ``loop(y, x)``; the test expects the program
    returned by ``dcpe`` to be alpha-equal to the original call.
    """
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_type = prelude.nat()
    first = Var("x", nat_type)
    second = Var("y", nat_type)
    loop = GlobalVar("loop")
    swapped_call = loop(second, first)
    module[loop] = Function([first, second], swapped_call, nat_type)
    program = loop(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    reduced = dcpe(program, mod=module)
    assert alpha_equal(program, reduced)
def test_nat_id():
    """Run ``dcpe`` on a nullary function calling a global identity on nat 3.

    The global ``nat_id`` is defined as ``Function([x], x)``; the body of the
    function returned by ``dcpe`` must be structurally equal to the nat
    expression for 3.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    # The unused local `y` from the original was dropped; nothing read it.
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = nat_id(make_nat_expr(p, 3))
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
Example #11
0
def test_compose():
    """Compile ``compose(inc, double)`` and check that applying it to nat 2 gives 5."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    # Global `inc` wraps its argument in one more `s` constructor.
    inc_param = relay.Var('x')
    inc = GlobalVar('inc')
    module[inc] = Function([inc_param], prelude.s(inc_param))

    # Global `func` calls the composition of `inc` and the prelude's `double`.
    func_param = relay.Var('x')
    func = GlobalVar('func')
    composed = prelude.compose(inc, prelude.double)
    module[func] = Function([func_param], relay.Call(composed, [func_param]))

    compiled = compile(func, module)
    two = prelude.s(prelude.s(prelude.z()))
    assert nat_to_int(compiled(two)) == 5
def test_swap_loop():
    """Run ``dcpe`` on a nullary function calling a self-recursive, argument-swapping global.

    ``loop(x, y)`` is defined as ``loop(y, x)``; the body of the function
    returned by ``dcpe`` must be structurally equal to the original call.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_type = prelude.nat()
    first = Var("x", nat_type)
    second = Var("y", nat_type)
    loop = GlobalVar("loop")
    swapped_call = loop(second, first)
    module[loop] = Function([first, second], swapped_call, nat_type)
    program = loop(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    wrapped = Function([], program)
    reduced = dcpe(wrapped, mod=module)
    assert tvm.ir.structural_equal(program, reduced.body)
Example #13
0
def test_match_nat_id():
    """Run ``dcpe`` on a call to a global that rebuilds its nat argument via ``Match``.

    The result for the argument nat 3 must be alpha-equal to nat 3.
    """
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_type = prelude.nat()
    arg = Var("x", nat_type)
    pred = Var("y", nat_type)
    nat_id = GlobalVar("nat_id")
    # Each clause rebuilds the constructor it matched.
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_clause = Clause(
        PatternConstructor(prelude.s, [PatternVar(pred)]), prelude.s(pred))
    module[nat_id] = Function([arg], Match(arg, [zero_clause, succ_clause]))
    call = nat_id(make_nat_expr(prelude, 3))
    reduced = dcpe(call, mod=module)
    assert alpha_equal(reduced, make_nat_expr(prelude, 3))
Example #14
0
def test_nat_add():
    """Check the type of the prelude's ``add`` and evaluate 1 + 1, directly and in A-normal form."""
    module = relay.Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_type = prelude.nat
    add_fn = prelude.add
    succ = prelude.s
    zero = prelude.z
    llvm_ctx = tvm.context("llvm", 0)
    executor = create_executor(mod=module, ctx=llvm_ctx, target="llvm")

    expected_type = relay.FuncType([nat_type(), nat_type()], nat_type())
    assert module[add_fn].checked_type == expected_type

    # A fresh `1 + 1` expression is built for each evaluation.
    direct = executor.evaluate(add_fn(succ(zero()), succ(zero())))
    assert count(direct) == 2

    anf_expr = to_a_normal_form(add_fn(succ(zero()), succ(zero())), module)
    assert count(executor.evaluate(anf_expr)) == 2

    assert "let" in module[add_fn].astext()
Example #15
0
def test_constructor_tag_round_trip():
    """Check that constructor tags from one module resolve in another.

    Two modules are set up identically (prelude plus nat definitions). For
    every constructor of the first, its tag is looked up in the second module
    and must yield the constructor at the same position there, with the same
    ``name_hint``.
    """
    mod1 = tvm.IRModule()
    p1 = Prelude(mod1)
    add_nat_definitions(p1)
    mod2 = tvm.IRModule()
    p2 = Prelude(mod2)
    add_nat_definitions(p2)

    # ensure hashes match across modules
    ctors1 = constructor_list(p1)
    ctors2 = constructor_list(p2)

    # Index-based iteration over ctors1 alone would silently ignore extra
    # constructors in the second module; require the same count up front.
    assert len(ctors1) == len(ctors2)

    for ctor1, ctor2 in zip(ctors1, ctors2):
        ctor = mod2.get_constructor(ctor1.tag)
        assert ctor == ctor2
        assert ctor.name_hint == ctor1.name_hint
Example #16
0
def test_recursion():
    """Round-trip the entry function through ``to_cps``/``un_cps`` and evaluate it.

    The entry function applies ``nat_iterate(double, 3)`` to its input, and
    the evaluated output is compared against ``8 *`` the input.
    """
    module = relay.Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    shape = (10, 10)
    dtype = 'float32'
    tensor_type = relay.TensorType(shape, dtype)

    doubled = relay.var("x", tensor_type)
    double = relay.Function([doubled], doubled + doubled)
    inp = relay.var("i", tensor_type)
    iterated = prelude.nat_iterate(double, make_nat_expr(prelude, 3))
    module[module.entry_func] = relay.Function([inp], iterated(inp))

    # Convert to CPS and back; the entry function is re-read from the module
    # before each step, as each step stores its result there.
    cps_func = to_cps(module[module.entry_func], mod=module)
    module[module.entry_func] = cps_func
    direct_func = un_cps(module[module.entry_func])
    module[module.entry_func] = direct_func

    executor = create_executor(mod=module)
    input_nd = rand(dtype, *shape)
    output = executor.evaluate(module.entry_func)(input_nd)
    tvm.testing.assert_allclose(output.asnumpy(), 8 * input_nd.asnumpy())
Example #17
0
def test_constructor_tag_differences():
    """Check that consecutive constructors of the same ADT have consecutive tags."""
    # ensure that if we have the type data for a given ADT, the tags
    # for the constructors of the *same ADT* are simple offsets from
    # each other
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    for adt in adt_list(prelude):
        type_data = module[adt]
        pair_count = len(type_data.constructors) - 1
        for idx in range(pair_count):
            current = type_data.constructors[idx]
            following = type_data.constructors[idx + 1]
            assert following.tag - current.tag == 1
            # make sure there is something present at the MSB
            assert current.tag - idx != 0
            assert following.tag - (idx + 1) != 0
def test_nat_add():
    """Check the type of the prelude's ``add`` and evaluate 1 + 1 before and after ``ToANormalForm``."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_type = prelude.nat
    add_fn = prelude.add
    succ = prelude.s
    zero = prelude.z
    llvm_ctx = tvm.context("llvm", 0)
    # The executor is created over the module as it is before the pass runs.
    executor = create_executor(mod=module, ctx=llvm_ctx, target="llvm")

    expected_type = relay.FuncType([nat_type(), nat_type()], nat_type())
    assert module[add_fn].checked_type == expected_type

    direct = executor.evaluate(add_fn(succ(zero()), succ(zero())))
    assert count(prelude, direct) == 2

    # Register `1 + 1` as the body of a global "f", run the pass over the
    # whole module, then evaluate the transformed body.
    one_plus_one = add_fn(succ(zero()), succ(zero()))
    holder = relay.GlobalVar("f")
    module[holder] = relay.Function([], one_plus_one)
    anf_module = transform.ToANormalForm()(module)
    anf_func = anf_module["f"]
    assert count(prelude, executor.evaluate(anf_func.body)) == 2

    assert Feature.fLet in detect_feature(anf_module[add_fn])
Example #19
0
def test_pow():
    """Take the gradient of ``nat_iterate(double, 3)`` applied to a tensor.

    Checks the type of the gradient function, then compares the forward
    value against ``8 *`` the input and the input gradient against ``8 *``
    an array of ones.
    """
    module = relay.Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    shape = (10, 10)
    dtype = 'float32'
    tensor_type = relay.TensorType(shape, dtype)

    doubled = relay.var("x", tensor_type)
    double = relay.Function([doubled], doubled + doubled)
    inp = relay.var("i", tensor_type)
    iterated = prelude.nat_iterate(double, make_nat_expr(prelude, 3))
    func = relay.Function([inp], iterated(inp))

    grad_func = gradient(func, mod=module)
    back_func = relay.ir_pass.infer_type(grad_func, mod=module)
    expected_type = relay.FuncType(
        [tensor_type],
        relay.TupleType([tensor_type, relay.TupleType([tensor_type])]))
    assert back_func.checked_type == expected_type

    input_nd = rand(dtype, *shape)
    executor = create_executor(mod=module)
    forward, (input_grad, ) = executor.evaluate(back_func)(input_nd)
    tvm.testing.assert_allclose(forward.asnumpy(), 8 * input_nd.asnumpy())
    tvm.testing.assert_allclose(input_grad.asnumpy(),
                                8 * np.ones_like(input_grad.asnumpy()))
Example #20
0
def test_abs_diff():
    """Run ``dcpe`` on a nullary function calling the recursive global ``diff`` on nats 7 and 3.

    The body of the returned function must be alpha-equal to nat 4.
    """
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_type = prelude.nat()
    lhs = Var("x", nat_type)
    rhs = Var("y", nat_type)
    lhs_pred = Var("x'", nat_type)
    rhs_pred = Var("y'", nat_type)
    diff = GlobalVar("diff")

    # Inner match on the second argument; only reached once the first is s(x').
    rhs_zero = Clause(PatternConstructor(prelude.z, []), lhs)
    rhs_succ = Clause(
        PatternConstructor(prelude.s, [PatternVar(rhs_pred)]),
        diff(rhs_pred, lhs_pred))
    inner_match = Match(rhs, [rhs_zero, rhs_succ])

    # Outer match on the first argument.
    lhs_zero = Clause(PatternConstructor(prelude.z, []), rhs)
    lhs_succ = Clause(
        PatternConstructor(prelude.s, [PatternVar(lhs_pred)]), inner_match)
    module[diff] = Function([lhs, rhs], Match(lhs, [lhs_zero, lhs_succ]))

    call = diff(make_nat_expr(prelude, 7), make_nat_expr(prelude, 3))
    wrapped = Function([], call)
    reduced = dcpe(wrapped, mod=module)
    assert alpha_equal(reduced.body, make_nat_expr(prelude, 4))
def test_nat_update():
    """Run ``ToANormalForm`` and then ``PartialEvaluate`` over a module with nat definitions.

    Contains no assertions; it only requires both passes to run.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    to_anf = transform.ToANormalForm()
    module = to_anf(module)
    partial_eval = transform.PartialEvaluate()
    partial_eval(module)
Example #22
0
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay.backend.interpreter import ConstructorValue
from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr

import numpy as np

# Module-level fixtures: a Relay module, its prelude, and the nat
# definitions added to that prelude.
mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)


def count(e):
    """Call the imported ``count_`` helper on *e* with the module-level prelude ``p``."""
    result = count_(p, e)
    return result


# Executor over the module-level `mod`, created with an "llvm" context
# (device index 0) and an "llvm" target.
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")

# Short module-level aliases for members of the prelude `p`.
z = p.z
s = p.s
nat = p.nat
double = p.double
add = p.add