Example #1
0
 def __init__(self, device):
     """Create a PyTorch backend on the given device.

     `device` is any string accepted by ``torch.device`` ('cpu', 'cuda',
     'cuda:X', ...).
     """
     self.device = torch.device(device)

     # Closure over self so the graph compiler can call back into this
     # backend during conversion.
     def _convert(lst):
         return pytorch_convert(lst, self)

     self.compiler = CompileGraphs(_convert, nonlinear_ops, self)
Example #2
0
class PyTorchBackend(Backend):
    """Backend to run using pytorch.

    Backend options:

        :device: the target device for data storage ('cpu', 'cuda', 'cuda:X')

    """
    def __init__(self, device):
        """Create a PyTorch backend on the given device."""
        # torch.device validates/normalizes the device string.
        self.device = torch.device(device)
        # The compiler converts lists of graphs via pytorch_convert,
        # passing this backend through so conversion can reach the device.
        self.compiler = CompileGraphs(lambda lst: pytorch_convert(lst, self),
                                      nonlinear_ops, self)

    def compile(self, graph, *others):
        """Compile a graph and return the linked executable."""
        manage(graph)
        # Free variables must be eliminated before code generation.
        graph = closure_convert(graph)
        return self.compiler.compile_and_link(graph)

    def to_numpy(self, v):
        """Make a numpy array from a torch tensor."""
        if v.is_cuda:
            # .numpy() only works on CPU tensors, so copy back from the GPU.
            with untested_legacy():
                v = v.cpu()
        # detach() drops the autograd graph before exposing the data.
        return v.detach().numpy()

    def from_numpy(self, a):
        """Make a torch tensor, on this backend's device, from a numpy array."""
        return torch.from_numpy(a).to(self.device)

    def to_scalar(self, v):
        """Convert a torch tensor to a scalar.

        None, booleans and strings are passed through unchanged; any
        other value is assumed to support ``.item()`` (0-d tensor/array).
        """
        if (v is None) or (v is True) or (v is False) or (isinstance(v, str)):
            return v
        else:
            return v.item()

    def from_scalar(self, s, t):
        """Convert a scalar to a 0-d numpy array with the dtype matching *t*.

        NOTE(review): despite the surrounding torch naming, this returns a
        numpy value, not a torch tensor — downstream code appears to rely
        on that; confirm before changing.
        """
        if s is None:
            return None
        dt = type_to_np_dtype(t)
        return np.asarray(s, dtype=dt)

    def from_backend_value(self, v, t):
        """Convert a backend value to an intermediate value.

        Dispatches on the abstract type *t*; raises NotImplementedError
        for unsupported abstract types.
        """
        if isinstance(t, abstract.AbstractScalar):
            return self.to_scalar(v)
        elif isinstance(t, abstract.AbstractArray):
            # Convert torch tensor to numpy tensor.
            output = self.to_numpy(v)
            # If possible and necessary, cast numpy tensor to expected tensor.
            array_type = t.element.xtype()
            if array_type and array_type not in _type_map:
                # Probably u16, u32 or u64. Let's cast.
                output = output.astype(type_to_np_dtype(array_type))
            return output
        elif isinstance(t, abstract.AbstractTuple):
            # Convert each element against its own abstract element type.
            return tuple(
                self.from_backend_value(ve, te)
                for ve, te in zip(v, t.elements))
        elif isinstance(t, abstract.AbstractTaggedUnion):
            # Recurse into the payload using the option matching the tag.
            return TaggedValue(
                v.tag, self.from_backend_value(v.value, t.options.get(v.tag)))
        elif isinstance(t, abstract.AbstractRandomState):
            return RandomStateWrapper(self.to_numpy(v))
        else:
            raise NotImplementedError(f"Don't know what to do for {t}")

    def to_backend_value(self, v, t):
        """Convert an intermediate value to a backend value.

        Dispatches on the abstract type *t*; raises NotImplementedError
        for unsupported abstract types.
        """
        if isinstance(t, abstract.AbstractError) or v is abstract.DEAD:
            return None
        elif isinstance(t, abstract.AbstractType):
            # Handle abstract types.
            # Return None if type does not match any torch type.
            myia_type = t.element.xtype()
            return _type_map.get(myia_type, None)
        elif isinstance(t, abstract.AbstractArray):
            return self.from_numpy(v)
        elif isinstance(t, abstract.AbstractScalar):
            if issubclass(t.values[abstract.TYPE],
                          (xtype.Number, xtype.Bool, xtype.Nil)):
                return self.from_scalar(v, t.values[abstract.TYPE])
            elif issubclass(t.values[abstract.TYPE], xtype.EnvType):
                # Only the empty environment is supported here.
                assert len(v._contents) == 0
                return ()
            else:
                raise NotImplementedError(f"to_backend_value for {t}")
        elif isinstance(t, abstract.AbstractTuple):
            # Use distinct names (ve, te) so the generator does not shadow
            # the outer v/t, matching from_backend_value's convention.
            return tuple(
                self.to_backend_value(ve, te)
                for ve, te in zip(v, t.elements))
        elif isinstance(t, abstract.AbstractTaggedUnion):
            real_t = t.options.get(v.tag)
            return TaggedValue(v.tag, self.to_backend_value(v.value, real_t))
        elif isinstance(t, abstract.AbstractRandomState):
            # Copy so the backend never aliases the wrapper's state buffer.
            return self.from_numpy(v.state.copy())
        else:
            raise NotImplementedError(f"to_backend_value for {t}")