Exemplo n.º 1
0
 def to_backend_value(self, v, t):
     """Convert an intermediate value to a backend value.

     Args:
         v: The intermediate value to convert.
         t: The abstract type describing ``v``; it selects the conversion.

     Returns:
         * ``None`` when ``t`` is an ``AbstractError`` or ``v`` is the
           ``abstract.DEAD`` marker;
         * for an ``AbstractType``, the ``_type_map`` entry for the element's
           xtype, or ``None`` when there is no entry;
         * for arrays, the result of ``self.from_numpy(v)``;
         * for Number/Bool/Nil scalars, ``self.from_scalar(v, type)``;
         * for an (empty) environment scalar, ``()``;
         * for tuples, a tuple of recursively converted elements;
         * for tagged unions, a ``TaggedValue`` with the same tag and a
           recursively converted payload;
         * for random states, ``self.from_numpy`` applied to a copy of
           ``v.state``.

     Raises:
         NotImplementedError: if ``t`` (or, for scalars, its element type)
             is not handled.
     """
     if isinstance(t, abstract.AbstractError) or v is abstract.DEAD:
         return None
     elif isinstance(t, abstract.AbstractType):
         # Handle abstract types.
         # Return None if type does not match any torch type.
         myia_type = t.element.xtype()
         return _type_map.get(myia_type, None)
     elif isinstance(t, abstract.AbstractArray):
         return self.from_numpy(v)
     elif isinstance(t, abstract.AbstractScalar):
         scalar_type = t.values[abstract.TYPE]
         if issubclass(scalar_type, (xtype.Number, xtype.Bool, xtype.Nil)):
             return self.from_scalar(v, scalar_type)
         elif issubclass(scalar_type, xtype.EnvType):
             # Only empty environments are supported; they map to ().
             assert len(v._contents) == 0
             return ()
         else:
             raise NotImplementedError(f"to_backend_value for {t}")
     elif isinstance(t, abstract.AbstractTuple):
         # Use distinct loop names so the comprehension does not shadow
         # the ``v``/``t`` parameters it iterates over (bugbear B020).
         return tuple(
             self.to_backend_value(elem, elem_t)
             for elem, elem_t in zip(v, t.elements))
     elif isinstance(t, abstract.AbstractTaggedUnion):
         real_t = t.options.get(v.tag)
         return TaggedValue(v.tag, self.to_backend_value(v.value, real_t))
     elif isinstance(t, abstract.AbstractRandomState):
         # A copy is converted, presumably so the backend value does not
         # share storage with v.state -- confirm against from_numpy.
         return self.from_numpy(v.state.copy())
     else:
         raise NotImplementedError(f"to_backend_value for {t}")
Exemplo n.º 2
0
 def convert_tagged(self, v, t):
     """Convert a backend tagged value into a ``TaggedValue``.

     The backend tag ``v.tag`` is mapped through ``get_myia_tag``. The
     payload is converted by calling ``self`` with the option type that
     ``t.options`` associates with that tag.

     Returns:
         ``TaggedValue(tag, converted_payload)``.
     """
     tag = get_myia_tag(v.tag)
     # The payload is read as ``v[0]``; if that raises TypeError (e.g. ``v``
     # is not subscriptable) it is read from ``v.fields[0]`` instead. Only
     # the lookup is guarded. A TypeError raised by the conversion itself
     # must propagate rather than silently trigger a retry via ``v.fields``.
     try:
         payload = v[0]
     except TypeError:
         payload = v.fields[0]
     conv_val = self(payload, t.options.get(tag))
     return TaggedValue(tag, conv_val)
Exemplo n.º 3
0
 def from_backend_value(self, v, t):
     """Turn a value produced by the backend back into an intermediate value.

     The abstract type ``t`` selects the conversion:

     * scalars are passed to ``self.to_scalar``;
     * arrays are passed to ``self.to_numpy``; when the element xtype is
       truthy but not a key of ``_type_map``, the result is cast with
       ``astype`` to the dtype given by ``type_to_np_dtype``;
     * tuples are converted element by element, recursively;
     * tagged unions keep their tag and get a recursively converted payload;
     * random states are wrapped in ``RandomStateWrapper``;
     * types become ``HandleInstance`` for handles. Otherwise they become the
       ``numpy`` attribute named by ``type_to_np_dtype`` when the element
       xtype is in ``_type_map``, or ``v`` unchanged when it is not.

     Raises:
         NotImplementedError: if ``t`` is none of the handled abstract kinds.
     """
     if isinstance(t, abstract.AbstractScalar):
         return self.to_scalar(v)

     if isinstance(t, abstract.AbstractArray):
         # Torch tensor -> numpy array.
         arr = self.to_numpy(v)
         elem_type = t.element.xtype()
         # Element types absent from _type_map (probably u16, u32 or u64)
         # need an explicit cast back to the expected dtype.
         if elem_type and elem_type not in _type_map:
             arr = arr.astype(type_to_np_dtype(elem_type))
         return arr

     if isinstance(t, abstract.AbstractTuple):
         pairs = zip(v, t.elements)
         return tuple(
             self.from_backend_value(item, item_t) for item, item_t in pairs)

     if isinstance(t, abstract.AbstractTaggedUnion):
         tag = v.tag
         payload = self.from_backend_value(v.value, t.options.get(tag))
         return TaggedValue(tag, payload)

     if isinstance(t, abstract.AbstractRandomState):
         return RandomStateWrapper(self.to_numpy(v))

     if isinstance(t, abstract.AbstractType):
         target = t.element
         if isinstance(target, abstract.AbstractHandle):
             return HandleInstance
         myia_type = target.xtype()
         if myia_type not in _type_map:
             return v
         return getattr(np, type_to_np_dtype(myia_type))

     raise NotImplementedError(f"Don't know what to do for {t}")