def to_backend_value(self, v, t):
    """Translate the intermediate value ``v`` into its backend form.

    The abstract type ``t`` selects the conversion, checked in this order:

    * ``AbstractError`` (or ``v`` being ``abstract.DEAD``): ``None``.
    * ``AbstractType``: the ``_type_map`` entry for the element's xtype,
      or ``None`` when there is no entry.
    * ``AbstractArray``: ``self.from_numpy(v)``.
    * ``AbstractScalar``: ``self.from_scalar`` for Number/Bool/Nil types,
      an empty tuple for ``EnvType`` (whose ``_contents`` must be empty).
    * ``AbstractTuple``: a tuple of element-wise converted items.
    * ``AbstractTaggedUnion``: a ``TaggedValue`` with the same tag and a
      converted payload.
    * ``AbstractRandomState``: ``self.from_numpy`` on a copy of the state.

    Raises:
        NotImplementedError: if ``t`` (or its scalar type) is not handled.
    """
    if isinstance(t, abstract.AbstractError) or v is abstract.DEAD:
        return None

    if isinstance(t, abstract.AbstractType):
        # A type with no entry in _type_map has no backend counterpart
        # and is represented by None.
        return _type_map.get(t.element.xtype())

    if isinstance(t, abstract.AbstractArray):
        return self.from_numpy(v)

    if isinstance(t, abstract.AbstractScalar):
        scalar_type = t.values[abstract.TYPE]
        if issubclass(scalar_type, (xtype.Number, xtype.Bool, xtype.Nil)):
            return self.from_scalar(v, scalar_type)
        if issubclass(scalar_type, xtype.EnvType):
            # Only empty environments can be converted.
            assert len(v._contents) == 0
            return ()
        raise NotImplementedError(f"to_backend_value for {t}")

    if isinstance(t, abstract.AbstractTuple):
        return tuple(
            self.to_backend_value(item, item_type)
            for item, item_type in zip(v, t.elements)
        )

    if isinstance(t, abstract.AbstractTaggedUnion):
        # Pick the option type matching the value's tag and convert the
        # wrapped payload with it.
        option_type = t.options.get(v.tag)
        payload = self.to_backend_value(v.value, option_type)
        return TaggedValue(v.tag, payload)

    if isinstance(t, abstract.AbstractRandomState):
        # Convert a copy so the caller's state object is left untouched.
        state_copy = v.state.copy()
        return self.from_numpy(state_copy)

    raise NotImplementedError(f"to_backend_value for {t}")
def convert_tagged(self, v, t):
    """Convert a backend tagged value into a ``TaggedValue``.

    The backend tag is mapped through ``get_myia_tag``; the first field of
    ``v`` is then converted by calling ``self`` with the union option
    registered for that tag.

    Args:
        v: backend tagged object exposing ``tag`` and either indexing
            (``v[0]``) or a ``fields`` sequence.
        t: abstract tagged-union type whose ``options.get(tag)`` yields the
            type of the payload.

    Returns:
        ``TaggedValue(tag, converted_payload)``.
    """
    tag = get_myia_tag(v.tag)
    option_type = t.options.get(tag)

    # Only the subscript is guarded: a TypeError here means `v` is not
    # indexable and its payload lives in `v.fields`. Keeping the recursive
    # conversion outside the try block ensures a TypeError raised while
    # converting the payload is reported as-is instead of silently
    # triggering the fallback path (and possibly being masked by an
    # unrelated AttributeError on `v.fields`).
    try:
        payload = v[0]
    except TypeError:
        payload = v.fields[0]

    return TaggedValue(tag, self(payload, option_type))
def from_backend_value(self, v, t):
    """Translate the backend value ``v`` back to an intermediate value.

    The abstract type ``t`` selects the conversion, checked in this order:

    * ``AbstractScalar``: ``self.to_scalar(v)``.
    * ``AbstractArray``: ``self.to_numpy(v)``, cast with ``astype`` when
      the element xtype is truthy and missing from ``_type_map``.
    * ``AbstractTuple``: a tuple of element-wise converted items.
    * ``AbstractTaggedUnion``: a ``TaggedValue`` with the same tag and a
      converted payload.
    * ``AbstractRandomState``: a ``RandomStateWrapper`` around the numpy
      form of ``v``.
    * ``AbstractType``: ``HandleInstance`` for handle elements, the numpy
      attribute named by ``type_to_np_dtype`` for xtypes present in
      ``_type_map``, otherwise ``v`` unchanged.

    Raises:
        NotImplementedError: if ``t`` is none of the types above.
    """
    if isinstance(t, abstract.AbstractScalar):
        return self.to_scalar(v)

    if isinstance(t, abstract.AbstractArray):
        array = self.to_numpy(v)
        element_type = t.element.xtype()
        # Element types absent from _type_map (the original comment
        # suggests u16/u32/u64) are restored by casting the numpy array
        # to the expected dtype.
        needs_cast = element_type and element_type not in _type_map
        if needs_cast:
            array = array.astype(type_to_np_dtype(element_type))
        return array

    if isinstance(t, abstract.AbstractTuple):
        return tuple(
            self.from_backend_value(item, item_type)
            for item, item_type in zip(v, t.elements)
        )

    if isinstance(t, abstract.AbstractTaggedUnion):
        # Convert the payload with the option type registered for its tag.
        payload = self.from_backend_value(v.value, t.options.get(v.tag))
        return TaggedValue(v.tag, payload)

    if isinstance(t, abstract.AbstractRandomState):
        return RandomStateWrapper(self.to_numpy(v))

    if isinstance(t, abstract.AbstractType):
        if isinstance(t.element, abstract.AbstractHandle):
            return HandleInstance
        element_type = t.element.xtype()
        if element_type not in _type_map:
            # No mapped type: hand the backend value back untouched.
            return v
        return getattr(np, type_to_np_dtype(element_type))

    raise NotImplementedError(f"Don't know what to do for {t}")