示例#1
0
    def _make_env_find(self, m, rval_t):
        """Build a Relay global function that looks a key up in an env list.

        The generated function takes ``(env, key, dft)``. It matches ``env``
        against ``empty_env`` / ``cons_env``:

        * ``empty_env``: the key was not found, so ``dft`` is returned.
        * ``cons_env(k, v, r)``: if ``key == k``, ``v`` is unwrapped with the
          constructor ``m['ctr']`` and its payload is returned. Otherwise the
          function recurses on the tail ``r``.

        The unwrap match is partial (``complete=False``). Presumably every
        value stored under a given key uses ``m['ctr']``; TODO confirm.

        Args:
            m: Mapping that provides the value constructor under ``'ctr'``.
                The result is also stored in it under ``'env_find'``.
            rval_t: Relay type of the returned value and of ``dft``.

        Returns:
            A ``(GlobalVar, Function)`` pair, also stored in
            ``m['env_find']``.
        """
        constructor = m['ctr']
        find_gv = relay.GlobalVar(f"$_env_find<{constructor.name_hint}>")

        # Parameters of the generated function.
        env_param = relay.Var("env", env_type(env_val()))
        key_param = relay.Var("key", relay.ty.scalar_type('int64'))
        dft_param = relay.Var("dft", rval_t)

        # Variables bound by the patterns; their types are left to inference.
        head_key, head_val, tail, payload = (
            relay.Var(name) for name in ("k", "v", "r", "x"))

        # Key matched: strip the ``constructor`` wrapper off the stored value.
        unwrap = adt.Match(
            head_val,
            [adt.Clause(
                adt.PatternConstructor(constructor, [adt.PatternVar(payload)]),
                payload)],
            complete=False)
        # Key did not match: keep searching in the rest of the list.
        recurse = relay.Call(find_gv, [tail, key_param, dft_param])

        on_empty = adt.Clause(adt.PatternConstructor(empty_env, []), dft_param)
        cons_fields = [adt.PatternVar(p) for p in (head_key, head_val, tail)]
        on_cons = adt.Clause(
            adt.PatternConstructor(cons_env, cons_fields),
            relay.If(relay.equal(key_param, head_key), unwrap, recurse))

        find_fn = relay.Function(
            [env_param, key_param, dft_param],
            adt.Match(env_param, [on_empty, on_cons]),
            rval_t)

        result = (find_gv, find_fn)
        m['env_find'] = result
        return result
示例#2
0
    def _make_env_update(self, m, rval_t):
        """Build a Relay global function that sets a key in an env list.

        The generated function takes ``(env, key, val)`` and returns a new
        env list:

        * ``empty_env``: the key was not present, so a new entry
          ``cons_env(key, ctr(val), <empty>)`` is produced. New keys
          therefore end up at the end of the list.
        * ``cons_env(k, v, r)`` with ``key == k``: the matching entry is
          replaced by ``(key, ctr(val))`` in front of the unchanged tail
          ``r``.
        * ``cons_env(k, v, r)`` otherwise: the entry is kept and the update
          recurses into ``r``.

        Args:
            m: Mapping that provides the value constructor under ``'ctr'``.
                The result is also stored in it under ``'env_update'``.
            rval_t: Relay type of ``val``.

        Returns:
            A ``(GlobalVar, Function)`` pair, also stored in
            ``m['env_update']``.
        """
        ctr = m['ctr']
        gv = relay.GlobalVar(f"$_env_update<{ctr.name_hint}>")

        env = relay.Var("env", env_type(env_val()))
        key = relay.Var("key", relay.ty.scalar_type('int64'))
        val = relay.Var("val", rval_t)

        k = relay.Var("k")
        v = relay.Var("v")
        r = relay.Var("r")

        # End of list reached without finding ``key``: append a new entry
        # (``env`` is the empty list in this branch).
        empty_clause = adt.Clause(adt.PatternConstructor(empty_env, []),
                                  cons_env(key, ctr(val), env))
        cons_clause = adt.Clause(
            adt.PatternConstructor(
                cons_env,
                [adt.PatternVar(k),
                 adt.PatternVar(v),
                 adt.PatternVar(r)]),
            relay.If(relay.equal(key, k),
                     # Replace the matching entry. Using the tail ``r``
                     # (not ``env``) drops the stale (k, v) pair instead of
                     # leaving it shadowed behind the new one, so repeated
                     # updates of a key do not grow the list.
                     cons_env(key, ctr(val), r),
                     # Keep this entry and update the rest of the list.
                     cons_env(k, v, relay.Call(gv, [r, key, val]))))
        body = adt.Match(env, [empty_clause, cons_clause])
        fn = relay.Function([env, key, val], body, env_type(env_val()))
        m['env_update'] = (gv, fn)
        return gv, fn
示例#3
0
文件: relay.py 项目: GonChen/myia
def relay_hastag(c, x, tag):
    """Implementation of hastag for Relay.

    Emits a Relay ``Match`` on the value of ``x``. The result is the
    constant True when that value was built with the union constructor
    selected for ``tag.value``, and the constant False for any other
    constructor.

    ``tag`` must be an int constant (checked with an assert). The
    constructor comes from ``get_union_ctr`` applied to the tag and the
    entry for that tag in ``x.abstract.options``.
    """
    assert tag.is_constant(int)
    union_ctr = get_union_ctr(tag.value, x.abstract.options.get(tag.value))

    clauses = [
        # Tagged constructor with any payload -> True.
        adt.Clause(
            adt.PatternConstructor(union_ctr, [adt.PatternWildcard()]),
            relay.const(True),
        ),
        # Anything else -> False.
        adt.Clause(adt.PatternWildcard(), relay.const(False)),
    ]
    return adt.Match(c.ref(x), clauses)
示例#4
0
    def do_env_find(self, env, key, dft):
        """Build the code to find a value in env.

        ``env`` is first unwrapped from ``self.env_ctr`` with a partial match,
        which yields a tuple. The slot at index ``self.env_val_map[key][0]``
        is then matched as an option value: ``some(x)`` yields ``x`` and
        ``nil`` yields ``dft``.
        """
        inner = relay.var("v")
        unwrapped = adt.Match(
            env,
            [adt.Clause(
                adt.PatternConstructor(self.env_ctr, [adt.PatternVar(inner)]),
                inner)],
            complete=False)

        slot = relay.TupleGetItem(unwrapped, self.env_val_map[key][0])

        found = relay.var("x")
        # ``some`` clause first, then ``nil`` falls back to the default.
        return adt.Match(slot, [
            adt.Clause(adt.PatternConstructor(some, [adt.PatternVar(found)]),
                       found),
            adt.Clause(adt.PatternConstructor(nil, []), dft),
        ])
示例#5
0
def relay_casttag(c, x, tag):
    """Implementation of casttag for Relay.

    Emits a partial Relay ``Match`` on the value of ``x`` that extracts the
    payload of the union constructor selected for ``tag.value``. No other
    constructor is handled (``complete=False``).

    ``tag`` must be an int constant (checked with an assert).
    """
    assert tag.is_constant(int)
    union_ctr = get_union_ctr(tag.value, x.abstract.options.get(tag.value))

    payload = relay.Var("v")
    extract = adt.Clause(
        adt.PatternConstructor(union_ctr, [adt.PatternVar(payload)]),
        payload,
    )
    return adt.Match(c.ref(x), [extract], complete=False)
示例#6
0
    def do_env_update(self, env_, key, val):
        """Build the code to update the env.

        ``env_`` is unwrapped from ``self.env_ctr`` with a partial match,
        which yields a tuple. A new tuple is built in which the slot belonging
        to ``key`` (per ``self.env_val_map``) becomes ``some(val)`` and every
        other slot is copied from the old tuple. The new tuple is re-wrapped
        with ``self.env_ctr``.

        If ``key`` is not in ``self.env_val_map``, no slot matches and the
        tuple is rebuilt with all slots copied unchanged.
        """
        v = relay.var("v")
        cl = adt.Clause(
            adt.PatternConstructor(self.env_ctr, [adt.PatternVar(v)]), v)
        env = adt.Match(env_, [cl], complete=False)

        # Reverse lookup: tuple slot index -> key stored in that slot.
        # Named so it does not shadow the builtin ``map``.
        # NOTE(review): assumes the slot indices are exactly 0..n-1;
        # otherwise the lookup below raises KeyError.
        slot_to_key = {i: k for k, (i, _) in self.env_val_map.items()}
        new_env = relay.Tuple([
            some(val) if slot_to_key[i] == key else relay.TupleGetItem(env, i)
            for i in range(len(slot_to_key))
        ])
        return self.env_ctr(new_env)