Example No. 1
0
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route torch function calls made on this tensor subclass.

        ``torch.where`` / ``torch.Tensor.where`` is validated by
        ``_check_args_kwargs_length`` (three positional arguments, no keyword
        arguments) and handed to ``_MaskedWhere.apply``.

        ``contiguous``, ``to_dense``, ``to_sparse`` and ``to_sparse_csr`` are
        matched by identity and forwarded to their dedicated ``_Masked*``
        autograd functions with only the first positional argument.

        Every other ``func``:

        * returns ``NotImplemented`` if ``cls`` is not a subclass of every
          entry in ``types``, so another override can take over;
        * otherwise runs with torch-function dispatch disabled. Its result is
          returned as-is when ``func`` is in ``get_default_nowrap_functions()``,
          and converted to ``cls`` via ``torch._tensor._convert`` otherwise.
        """
        kwargs = kwargs or {}

        if func in (torch.Tensor.where, torch.where):
            _check_args_kwargs_length(args,
                                      kwargs,
                                      "__torch_function__, torch.where",
                                      len_args=3,
                                      len_kwargs=0)
            return _MaskedWhere.apply(*args)

        # Single-input overrides, matched by identity (``is``). Any keyword
        # arguments and extra positional arguments are not forwarded.
        single_input_overrides = (
            (torch.Tensor.contiguous, _MaskedContiguous),
            (torch.Tensor.to_dense, _MaskedToDense),
            (torch.Tensor.to_sparse, _MaskedToSparse),
            (torch.Tensor.to_sparse_csr, _MaskedToSparseCsr),
        )
        for target, autograd_fn in single_input_overrides:
            if func is target:
                return autograd_fn.apply(args[0])

        # Defer to another override when cls does not derive from every type.
        if any(not issubclass(cls, t) for t in types):
            return NotImplemented

        with torch._C.DisableTorchFunction():
            result = func(*args, **kwargs)
            if func not in get_default_nowrap_functions():
                result = torch._tensor._convert(result, cls)
            return result
Example No. 2
0
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """
        Default ``__torch_function__`` handler that preserves subclass types.

        ``func`` is executed with torch-function dispatch disabled. Unless
        ``func`` appears in ``get_default_nowrap_functions()``, its result is
        passed through ``_convert(..., cls)``. As a result, methods called on a
        subclass return an instance of that subclass instead of a plain
        ``torch.Tensor``.

        A consequence is that subclasses overriding ``__torch_function__`` must
        themselves cover the ``torch.Tensor`` methods they care about.
        Calling ``super().__torch_function__`` as the fallback case is
        recommended when doing so.

        Making ``__torch_function__`` a classmethod is recommended, though not
        required.

        Returns ``NotImplemented`` if ``cls`` is not a subclass of every type
        in ``types``, allowing another override to handle the call.
        """
        kwargs = {} if kwargs is None else kwargs

        # Defer when cls does not derive from every participating type.
        if any(not issubclass(cls, t) for t in types):
            return NotImplemented

        with _C.DisableTorchFunction():
            out = func(*args, **kwargs)
            if func not in get_default_nowrap_functions():
                out = _convert(out, cls)
            return out
Example No. 3
0
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` while applying the ReplicatedTensor inter-op rules.

        Operands are inspected one at a time: positional arguments first, then
        keyword-argument values, in order.

        * The first ``ShardedTensor`` encountered stops the scan. The call is
          redispatched to that tensor's ``__torch_function__`` and its result
          is returned unchanged.
        * Every ``ReplicatedTensor`` must share a single ``process_group``.
          A mismatch found during the scan raises ``RuntimeError``.
        * Any operand that is not a ``ReplicatedTensor`` (plain tensors and
          non-tensor values such as Python scalars alike) marks the call as
          "not all replicated".

        Only top-level operands are inspected. Tensors nested in containers,
        e.g. the list given to ``torch.cat``, are not examined, so such a
        container counts as a non-replicated operand.

        Without a redispatch, ``func`` runs with torch-function dispatch
        disabled:

        * results of functions in ``get_default_nowrap_functions()`` are
          returned as-is;
        * if every operand was a ``ReplicatedTensor`` and the result is a
          ``torch.Tensor`` that is not already an instance of ``cls``, it is
          re-wrapped with ``as_subclass(ReplicatedTensor)`` and given the
          operands' ``process_group``;
        * any other result is returned unchanged.
        """
        if kwargs is None:
            kwargs = {}

        # Track whether every operand is a ReplicatedTensor, so the result is
        # only converted back to ReplicatedTensor when all operands are
        # replicated. Also remember the process group those operands share.
        all_replicated = True
        replicated_pg = None

        def dispatch_arg(arg):
            # Returns (redispatched, result): ``redispatched`` tells whether the
            # op was already executed through ShardedTensor, and ``result`` is
            # that execution's output (None otherwise).
            nonlocal replicated_pg, all_replicated
            if isinstance(arg, ShardedTensor):
                # Redispatch to ShardedTensor.
                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
                return True, arg.__torch_function__(func, types, args, kwargs)
            if isinstance(arg, ReplicatedTensor):
                if replicated_pg is None:
                    replicated_pg = arg.process_group
                elif replicated_pg != arg.process_group:
                    raise RuntimeError(
                        f"ReplicatedTensor operands must be in the same process group "
                        f"in torch function '{func.__name__}', but found at least two "
                        f"ReplicatedTensor operands in different process groups! "
                    )
            else:
                all_replicated = False

            return False, None

        # kwargs was normalized to a dict above, so keyword values can be
        # scanned unconditionally, right after the positional arguments.
        for arg in (*args, *kwargs.values()):
            redispatched, res = dispatch_arg(arg)
            if redispatched:
                return res

        # super().__torch_function__() cannot be used here: it implicitly
        # converts the result back to the tensor subclass, whereas the output
        # type here must follow the inter-op rules described above.
        with torch._C.DisableTorchFunction():
            rs = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return rs
            if all_replicated and isinstance(
                    rs, torch.Tensor) and not isinstance(rs, cls):
                # All operands are ReplicatedTensors (no ShardedTensor
                # redispatch happened) and the result is a torch.Tensor, so
                # wrap it as a ReplicatedTensor per the inter-op rule.
                rs = rs.as_subclass(ReplicatedTensor)  # type: ignore[arg-type]
                # Propagate the shared process_group to the result.
                rs.process_group = replicated_pg  # type: ignore[attr-defined]

            return rs