示例#1
0
 def _get_env_val_map(self, t):
     """Return the cached env-value constructor entry for type ``t``.

     ``t`` is passed through ``broaden`` and the result is used as the cache
     key (presumably so near-identical abstract types share one entry --
     confirm). On a cache miss a new ``adt.Constructor`` named ``v<N>`` is
     created, where N is the number of entries already cached. It takes one
     field of the relay type of the broadened ``t`` and belongs to
     ``env_val``. It is stored in ``self.env_val_map`` under ``'ctr'``.

     Returns:
         A pair ``(entry, relay_type)``. ``entry`` is the cached dict holding
         the constructor under ``'ctr'``. ``relay_type`` is the relay type of
         the broadened ``t``, recomputed on every call.
     """
     key = broaden(t)
     relay_ty = to_relay_type(key)
     cache = self.env_val_map
     if key not in cache:
         ctr_name = f"v{len(cache)}"  # index taken before the insert below
         cache[key] = {'ctr': adt.Constructor(ctr_name, [relay_ty], env_val)}
     return cache[key], relay_ty
示例#2
0
def get_union_ctr(tag, t):
    """Return the relay ADT constructor registered for a union tag.

    Constructors are created lazily and memoized in the module-level
    ``tag_map``. The first time ``tag`` is seen, a one-field constructor
    named ``c<tag>`` is built for ``union_type``. Its field type is the relay
    translation of ``t``.

    Args:
        tag: Union tag, used as the ``tag_map`` key.
        t: Abstract type of the tagged value. It is only consulted, and must
            then be non-None, when ``tag`` has no constructor yet.

    Returns:
        The ``adt.Constructor`` stored in ``tag_map`` for ``tag``.
    """
    if tag in tag_map:
        return tag_map[tag]
    assert t is not None
    # Translate the field type before building the constructor (same order
    # as the original implementation).
    field_type = to_relay_type(t)
    ctr = adt.Constructor(f"c{tag}", [field_type], union_type)
    tag_map[tag] = ctr
    return ctr
示例#3
0
    def initialize(self, mod, mng):
        """Add the union, option and env ADT definitions to ``mod``.

        When ``mng`` is not None, its nodes are scanned first:

        * each option of an ``AbstractTaggedUnion`` node is passed to
          ``get_union_ctr`` (presumably registering a constructor per tag);
        * each ``env_setitem`` application records the relay type of its
          stored value (input 3) under the value of its constant key
          (input 2) in ``self.env_val_map``.

        Each recorded entry is then replaced by an ``(index, relay_type)``
        pair, where ``index`` is the key's position in sorted key order.
        The ADT definitions are registered after that.

        Args:
            mod: Relay module that receives the type definitions.
            mng: Manager whose ``all_nodes`` are scanned, or None to skip the
                scan.
        """
        if mng is not None:
            for node in mng.all_nodes:
                abstract = node.abstract
                if isinstance(abstract, AbstractTaggedUnion):
                    for tag_and_type in abstract.options:
                        get_union_ctr(*tag_and_type)
                    continue
                if not node.is_apply(P.env_setitem):
                    continue
                key_node = node.inputs[2]
                value_type = to_relay_type(node.inputs[3].abstract)
                assert key_node.is_constant()
                self.env_val_map[key_node.value] = value_type

        # sorted() materializes the keys first, so rewriting values is safe.
        for index, key in enumerate(sorted(self.env_val_map)):
            self.env_val_map[key] = (index, self.env_val_map[key])

        union_ctrs = list(tag_map.values())
        mod[union_type] = adt.TypeData(union_type, [], union_ctrs)
        mod[option_type] = adt.TypeData(option_type, [a], [nil, some])
        # _build_env_type runs after env_val_map holds (index, type) pairs.
        env_field_type = self._build_env_type()
        self.env_ctr = adt.Constructor("v", [env_field_type], env_type)
        mod[env_type] = adt.TypeData(env_type, [], [self.env_ctr, dead_env])
示例#4
0
    AbstractError,
    AbstractFunctionUnique,
    AbstractHandle,
    AbstractRandomState,
    AbstractScalar,
    AbstractTaggedUnion,
    AbstractTuple,
    AbstractType,
    TypedPrimitive,
)
from myia.operations import primitives as P
from myia.utils import overload
from myia.xtype import Bool, EnvType, Nil, UniverseType, type_to_np_dtype

# Relay global type variable naming the ADT that encodes tagged unions.
union_type = relay.GlobalTypeVar("$_union_adt")
# Zero-field constructor of the union ADT; registered below under tag None.
empty_union = adt.Constructor("empty", [], union_type)
# Memo of union tag -> relay constructor; get_union_ctr adds entries lazily.
tag_map = {None: empty_union}
# Reverse lookup table, empty at import time. NOTE(review): presumably
# constructor -> tag, populated by fill_reverse_tag_map -- confirm.
rev_tag_map = {}


def get_union_ctr(tag, t):
    """Return the relay ADT constructor registered for a union tag.

    If ``tag`` is already in the module-level ``tag_map``, the cached
    constructor is returned. Otherwise ``t`` must be non-None. A constructor
    named ``c<tag>`` is then created for ``union_type``, with a single field
    of the relay type of ``t``, and cached in ``tag_map``. This function does
    not touch ``rev_tag_map``.

    Args:
        tag: Union tag, used as the ``tag_map`` key.
        t: Abstract type of the tagged value; only used on a cache miss.

    Returns:
        The ``adt.Constructor`` stored in ``tag_map`` for ``tag``.
    """
    if tag in tag_map:
        return tag_map[tag]
    assert t is not None
    field_type = to_relay_type(t)
    new_ctr = adt.Constructor(f"c{tag}", [field_type], union_type)
    tag_map[tag] = new_ctr
    return new_ctr


def fill_reverse_tag_map():
示例#5
0
from ...abstract import (
    AbstractArray,
    AbstractError,
    AbstractFunction,
    AbstractScalar,
    AbstractTaggedUnion,
    AbstractTuple,
    TypedPrimitive,
    VirtualFunction,
    broaden,
)
from ...utils import overload
from ...xtype import Bool, EnvType, Nil, type_to_np_dtype

# Relay global type variable naming the ADT that encodes tagged unions.
union_type = relay.GlobalTypeVar('$_union_adt')
# Zero-field constructor of the union ADT. It is not placed in tag_map here.
# NOTE(review): confirm where, if anywhere, it is added to the union ADT.
empty_union = adt.Constructor("empty", [], union_type)
# Memo of union tag -> relay constructor; get_union_ctr adds entries lazily.
tag_map = {}
# Reverse memo of relay constructor -> union tag, kept in sync by get_union_ctr.
rev_tag_map = {}


def get_union_ctr(tag, t):
    """Return the relay ADT constructor registered for a union tag.

    Constructors are memoized in two module-level tables:

    * ``tag_map`` maps a tag to its constructor;
    * ``rev_tag_map`` maps the constructor back to its tag.

    On the first request for ``tag``, ``t`` must be non-None. A constructor
    named ``c<tag>`` is built for ``union_type``, with one field of the relay
    type of ``t``, and recorded in both tables.

    Args:
        tag: Union tag, used as the ``tag_map`` key.
        t: Abstract type of the tagged value; only used on a cache miss.

    Returns:
        The ``adt.Constructor`` stored in ``tag_map`` for ``tag``.
    """
    if tag in tag_map:
        return tag_map[tag]
    assert t is not None
    field_type = to_relay_type(t)
    ctr = adt.Constructor(f"c{tag}", [field_type], union_type)
    # Forward entry first, then reverse entry (same order as before).
    tag_map[tag] = ctr
    rev_tag_map[ctr] = tag
    return ctr