def _get_env_val_map(self, t):
    """Look up, creating on first use, the env-value entry for type ``t``.

    ``t`` is passed through ``broaden`` and the broadened type is used as
    the key into ``self.env_val_map``. When the key is missing, a new entry
    ``{'ctr': constructor}`` is stored. The constructor has these properties:

    * it is named ``v<N>``, where ``N`` is the number of entries before
      insertion;
    * it takes a single field of the relay type of the broadened ``t``;
    * it belongs to ``env_val``.

    Returns:
        A tuple ``(entry, relay_type)``. ``entry`` is the dict stored in
        ``self.env_val_map``; ``relay_type`` is the relay type of the
        broadened ``t``.
    """
    key = broaden(t)
    # The relay type is computed on every call, not only on a cache miss,
    # because it is also part of the return value.
    relay_t = to_relay_type(key)
    entries = self.env_val_map
    if key not in entries:
        ctr = adt.Constructor(f"v{len(entries)}", [relay_t], env_val)
        entries[key] = {'ctr': ctr}
    return entries[key], relay_t
def get_union_ctr(tag, t):
    """Get the relay constructor for a union tag.

    Constructors are cached in ``tag_map``. A cached tag returns its
    existing constructor and ``t`` is ignored. On the first request for a
    tag, a constructor named ``c<tag>`` is built and cached. It has a single
    field of relay type ``to_relay_type(t)`` and belongs to ``union_type``.

    Args:
        tag: Union option tag used as the cache key.
        t: Abstract type of the option's payload. It must not be ``None``
            when ``tag`` has not been seen yet; this is checked with an
            ``assert``.

    Returns:
        The relay ``adt.Constructor`` registered for ``tag``.
    """
    if tag in tag_map:
        return tag_map[tag]
    # A payload type is required to build a constructor for an unseen tag.
    assert t is not None
    payload = to_relay_type(t)
    tag_map[tag] = adt.Constructor(f"c{tag}", [payload], union_type)
    return tag_map[tag]
def initialize(self, mod, mng):
    """Add types to the module.

    When a manager ``mng`` is given, every node in ``mng.all_nodes`` is
    scanned first:

    * a node whose ``abstract`` is an ``AbstractTaggedUnion`` gets one
      constructor per option, registered through ``get_union_ctr``;
    * a ``P.env_setitem`` application stores the relay type of its value
      (input 3) in ``self.env_val_map`` under its key (input 2). The key
      must be a constant; this is checked with an ``assert``.

    Next, every ``self.env_val_map`` entry is replaced by
    ``(index, previous_value)``, where ``index`` is the key's position in
    sorted key order. This runs whether or not ``mng`` was given.

    Finally, four things happen, in this order:

    1. The union ADT (all constructors in ``tag_map``) is registered in
       ``mod``.
    2. The option ADT (``nil``/``some``) is registered in ``mod``.
    3. The env constructor ``"v"`` is built from ``self._build_env_type()``
       and stored on ``self.env_ctr``.
    4. The env ADT (``self.env_ctr`` plus ``dead_env``) is registered in
       ``mod``.

    NOTE(review): each call re-wraps every ``env_val_map`` value, so a
    second call would nest the tuples. Confirm this runs once per instance.

    Args:
        mod: Relay module that receives the ADT definitions.
        mng: Manager whose nodes are scanned, or ``None`` to skip the scan.
    """
    if mng is not None:
        for node in mng.all_nodes:
            abstract = node.abstract
            if isinstance(abstract, AbstractTaggedUnion):
                for option in abstract.options:
                    # Called for its side effect of filling tag_map.
                    get_union_ctr(*option)
                continue
            if node.is_apply(P.env_setitem):
                key_node = node.inputs[2]
                value_type = to_relay_type(node.inputs[3].abstract)
                assert key_node.is_constant()
                self.env_val_map[key_node.value] = value_type

    # sorted() builds the key list up front, so assigning to existing keys
    # inside the loop is safe.
    env_map = self.env_val_map
    for position, key in enumerate(sorted(env_map)):
        env_map[key] = (position, env_map[key])

    # The union ADT must be built after the scan so that it includes every
    # constructor get_union_ctr added to tag_map.
    union_ctrs = list(tag_map.values())
    mod[union_type] = adt.TypeData(union_type, [], union_ctrs)
    mod[option_type] = adt.TypeData(option_type, [a], [nil, some])
    env_fields = [self._build_env_type()]
    self.env_ctr = adt.Constructor("v", env_fields, env_type)
    mod[env_type] = adt.TypeData(env_type, [], [self.env_ctr, dead_env])
AbstractError, AbstractFunctionUnique, AbstractHandle, AbstractRandomState, AbstractScalar, AbstractTaggedUnion, AbstractTuple, AbstractType, TypedPrimitive, ) from myia.operations import primitives as P from myia.utils import overload from myia.xtype import Bool, EnvType, Nil, UniverseType, type_to_np_dtype union_type = relay.GlobalTypeVar("$_union_adt") empty_union = adt.Constructor("empty", [], union_type) tag_map = {None: empty_union} rev_tag_map = {} def get_union_ctr(tag, t): """Get the relay constructor for a union tag.""" if tag not in tag_map: assert t is not None rt = to_relay_type(t) ctr = adt.Constructor(f"c{tag}", [rt], union_type) tag_map[tag] = ctr return tag_map[tag] def fill_reverse_tag_map():
from ...abstract import (
    AbstractArray,
    AbstractError,
    AbstractFunction,
    AbstractScalar,
    AbstractTaggedUnion,
    AbstractTuple,
    TypedPrimitive,
    VirtualFunction,
    broaden,
)
from ...utils import overload
from ...xtype import Bool, EnvType, Nil, type_to_np_dtype

union_type = relay.GlobalTypeVar('$_union_adt')
empty_union = adt.Constructor("empty", [], union_type)
# Union tag -> relay constructor; filled lazily by get_union_ctr.
tag_map = {}
# Relay constructor -> union tag; filled alongside tag_map by get_union_ctr.
rev_tag_map = {}


def get_union_ctr(tag, t):
    """Get the relay constructor for a union tag.

    A tag already in ``tag_map`` returns its cached constructor and ``t``
    is ignored. On the first request for a tag, a constructor named
    ``c<tag>`` is built. It has a single field of relay type
    ``to_relay_type(t)`` and belongs to ``union_type``. It is recorded in
    both ``tag_map`` (tag -> constructor) and ``rev_tag_map``
    (constructor -> tag).

    Args:
        tag: Union option tag used as the cache key.
        t: Abstract type of the option's payload. It must not be ``None``
            when ``tag`` has not been seen yet; this is checked with an
            ``assert``.

    Returns:
        The relay ``adt.Constructor`` registered for ``tag``.
    """
    if tag in tag_map:
        return tag_map[tag]
    assert t is not None
    ctr = adt.Constructor(f"c{tag}", [to_relay_type(t)], union_type)
    tag_map[tag] = ctr
    rev_tag_map[ctr] = tag
    return tag_map[tag]