Exemplo n.º 1
0
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """
        This __torch_function__ implementation wraps subclasses such that
        methods called on subclasses return a subclass instance instead of
        a ``torch.Tensor`` instance.

        One corollary to this is that you need coverage for torch.Tensor
        methods if implementing __torch_function__ for subclasses.

        We recommend always calling ``super().__torch_function__`` as the base
        case when doing the above.

        While not mandatory, we recommend making `__torch_function__` a classmethod.
        """
        if kwargs is None:
            kwargs = {}

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        with _C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return ret
            else:
                return _convert(ret, cls)
Exemplo n.º 2
0
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}

            if not all(issubclass(cls, t) for t in types):
                return NotImplemented

            with _C.DisableTorchFunction():
                ret = func(*args, **kwargs)
                return cls._convert(ret)
Exemplo n.º 3
0
 def __torch_function__(cls, func, types, args=(), kwargs=None):
     """
     Delegate to the parent ``__torch_function__`` and convert its result.

     The parent implementation is called with the arguments unchanged; its
     return value is then passed to ``_convert`` with ``torchplus.Tensor`` as
     the target class, inside the ``_C.DisableTorchFunction()`` context.

     Args:
         cls: The class the override is being dispatched on.
         func: The callable being overridden.
         types: The types taking part in the dispatch.
         args: Positional arguments forwarded to the parent handler.
         kwargs: Keyword arguments forwarded to the parent handler.

     Returns:
         The value produced by ``_convert(parent_result, torchplus.Tensor)``.
     """
     parent_result = super().__torch_function__(func, types, args, kwargs)
     # The conversion target is always ``torchplus.Tensor``, not ``cls``.
     with _C.DisableTorchFunction():
         converted = _convert(parent_result, torchplus.Tensor)
     return converted