def _make_env_find(self, m, rval_t):
    """Build a recursive Relay function that looks up a key in an env list.

    The generated function takes ``(env, key, dft)`` and walks the
    ``cons_env`` list. At the first cell whose key equals ``key`` it
    unwraps the stored value with the constructor ``m['ctr']``. If it
    reaches ``empty_env`` without a match, it returns ``dft``.

    Args:
        m: Mapping that provides the value constructor under ``'ctr'``.
            The generated ``(GlobalVar, Function)`` pair is stored back
            into it under ``'env_find'``.
        rval_t: Relay type of ``dft`` and of the function's result.

    Returns:
        The ``(GlobalVar, Function)`` pair for the lookup function.
    """
    ctr = m['ctr']
    # The GlobalVar lets the function body call itself recursively.
    self_ref = relay.GlobalVar(f"$_env_find<{ctr.name_hint}>")

    # Parameters of the generated function.
    env_param = relay.Var("env", env_type(env_val()))
    key_param = relay.Var("key", relay.ty.scalar_type('int64'))
    dft_param = relay.Var("dft", rval_t)

    # Pattern variables bound when a cons cell is destructured.
    cell_key = relay.Var("k")
    cell_val = relay.Var("v")
    rest = relay.Var("r")
    payload = relay.Var("x")

    # complete=False: only the ``ctr`` case of the stored value is handled.
    unwrap = adt.Match(
        cell_val,
        [adt.Clause(adt.PatternConstructor(ctr, [adt.PatternVar(payload)]),
                    payload)],
        complete=False)
    recurse = relay.Call(self_ref, [rest, key_param, dft_param])
    cons_pattern = adt.PatternConstructor(
        cons_env,
        [adt.PatternVar(cell_key), adt.PatternVar(cell_val),
         adt.PatternVar(rest)])

    clauses = [
        adt.Clause(adt.PatternConstructor(empty_env, []), dft_param),
        adt.Clause(cons_pattern,
                   relay.If(relay.equal(key_param, cell_key),
                            unwrap, recurse)),
    ]
    fn = relay.Function([env_param, key_param, dft_param],
                        adt.Match(env_param, clauses), rval_t)
    m['env_find'] = (self_ref, fn)
    return self_ref, fn
def _make_env_update(self, m, rval_t):
    """Build a recursive Relay function that sets a key in an env list.

    The generated function takes ``(env, key, val)`` and returns a new
    ``cons_env`` list in which ``key`` maps to ``m['ctr'](val)``:

    * If a cell with an equal key exists, that cell is replaced in place.
      Cells before it are kept in order, and cells after it are reused
      unchanged.
    * Otherwise a new cell is appended at the end of the list.

    Args:
        m: Mapping that provides the value constructor under ``'ctr'``.
            The generated ``(GlobalVar, Function)`` pair is stored back
            into it under ``'env_update'``.
        rval_t: Relay type of ``val``.

    Returns:
        The ``(GlobalVar, Function)`` pair for the update function.
    """
    ctr = m['ctr']
    # The GlobalVar lets the function body call itself recursively.
    gv = relay.GlobalVar(f"$_env_update<{ctr.name_hint}>")
    env = relay.Var("env", env_type(env_val()))
    key = relay.Var("key", relay.ty.scalar_type('int64'))
    val = relay.Var("val", rval_t)
    k = relay.Var("k")
    v = relay.Var("v")
    r = relay.Var("r")
    # End of list reached without a match: append a fresh cell.
    # ``env`` is the empty list here.
    empty_clause = adt.Clause(adt.PatternConstructor(empty_env, []),
                              cons_env(key, ctr(val), env))
    cons_clause = adt.Clause(
        adt.PatternConstructor(
            cons_env,
            [adt.PatternVar(k), adt.PatternVar(v), adt.PatternVar(r)]),
        relay.If(relay.equal(key, k),
                 # Replace the matching cell by consing onto its tail ``r``.
                 # Consing onto ``env`` would keep the stale (k, v) cell
                 # behind the new one and grow the list on every update.
                 cons_env(key, ctr(val), r),
                 cons_env(k, v, relay.Call(gv, [r, key, val]))))
    body = adt.Match(env, [empty_clause, cons_clause])
    fn = relay.Function([env, key, val], body, env_type(env_val()))
    m['env_update'] = (gv, fn)
    return gv, fn
def do_env_find(self, env, key, dft):
    """Build the code to find a value in env.

    ``env`` is unwrapped with ``self.env_ctr`` to reach a tuple. The slot
    at index ``self.env_val_map[key][0]`` is read, and the option stored
    there is matched: ``some(x)`` yields ``x`` and ``nil`` yields ``dft``.

    Args:
        env: Relay expression wrapped by ``self.env_ctr``.
        key: Key into ``self.env_val_map``. An unknown key fails on that
            lookup.
        dft: Relay expression returned when the slot is ``nil``.

    Returns:
        A Relay ``Match`` expression producing the found value or ``dft``.
    """
    payload = relay.var("v")
    # complete=False: only the ``self.env_ctr`` case is handled.
    unwrap_env = adt.Clause(
        adt.PatternConstructor(self.env_ctr, [adt.PatternVar(payload)]),
        payload)
    slot_index = self.env_val_map[key][0]
    slot = relay.TupleGetItem(adt.Match(env, [unwrap_env], complete=False),
                              slot_index)
    found = relay.var("x")
    return adt.Match(slot, [
        adt.Clause(adt.PatternConstructor(some, [adt.PatternVar(found)]),
                   found),
        adt.Clause(adt.PatternConstructor(nil, []), dft),
    ])
def relay_casttag(c, x, tag):
    """Implementation of casttag for Relay.

    Unwrap the union value ``x`` through the constructor that
    ``get_union_ctr`` returns for the constant ``tag``. The generated match
    is non-exhaustive, so presumably a value built with a different tag
    fails at runtime; confirm against Relay's incomplete-match behavior.

    Args:
        c: Conversion context; ``c.ref(x)`` supplies the Relay expression
            for ``x``.
        x: Node whose ``abstract.options`` describes the union cases.
        tag: Node that must hold a constant ``int`` (checked by assert).

    Returns:
        A Relay ``Match`` expression yielding the unwrapped payload.
    """
    assert tag.is_constant(int)
    tag_value = tag.value
    case_ctr = get_union_ctr(tag_value, x.abstract.options.get(tag_value))
    inner = relay.Var("v")
    only_case = adt.Clause(
        adt.PatternConstructor(case_ctr, [adt.PatternVar(inner)]), inner)
    return adt.Match(c.ref(x), [only_case], complete=False)
def relay_hastag(c, x, tag):
    """Implementation of hastag for Relay.

    Build a boolean Relay expression that is ``True`` when ``x`` was built
    with the constructor that ``get_union_ctr`` returns for ``tag``, and
    ``False`` for any other value.

    Args:
        c: Conversion context; ``c.ref(x)`` supplies the Relay expression
            for ``x``.
        x: Node whose ``abstract.options`` describes the union cases.
        tag: Node that must hold a constant ``int`` (checked by assert).

    Returns:
        A Relay ``Match`` expression of boolean constants.
    """
    assert tag.is_constant(int)
    case_ctr = get_union_ctr(tag.value, x.abstract.options.get(tag.value))
    return adt.Match(c.ref(x), [
        # The payload is irrelevant; only the constructor is tested.
        adt.Clause(adt.PatternConstructor(case_ctr, [adt.PatternWildcard()]),
                   relay.const(True)),
        # Catch-all for every other constructor.
        adt.Clause(adt.PatternWildcard(), relay.const(False)),
    ])
def do_env_update(self, env_, key, val):
    """Build the code to update the env.

    ``env_`` is unwrapped with ``self.env_ctr`` to reach a tuple. A new
    tuple is built in which the slot belonging to ``key`` becomes
    ``some(val)`` and every other slot is copied from the old tuple. The
    result is re-wrapped with ``self.env_ctr``.

    Args:
        env_: Relay expression wrapped by ``self.env_ctr``.
        key: Key into ``self.env_val_map``; its entry's first element is
            the slot index.
        val: Relay expression stored into ``key``'s slot.

    Returns:
        A Relay expression ``self.env_ctr(<new tuple>)``.
    """
    v = relay.var("v")
    # complete=False: only the ``self.env_ctr`` case is handled.
    cl = adt.Clause(
        adt.PatternConstructor(self.env_ctr, [adt.PatternVar(v)]), v)
    env = adt.Match(env_, [cl], complete=False)
    # Invert env_val_map (key -> (index, _)) into index -> key.
    # The name avoids shadowing the builtin ``map``.
    index_to_key = {idx: k for k, (idx, _) in self.env_val_map.items()}
    # NOTE(review): if ``key`` is absent from env_val_map, no slot matches
    # and an unchanged copy of the env is returned, whereas do_env_find
    # fails on the lookup. Confirm whether a silent no-op is intended.
    new_env = relay.Tuple([
        some(val) if index_to_key[i] == key else relay.TupleGetItem(env, i)
        for i in range(len(index_to_key))
    ])
    return self.env_ctr(new_env)