def __torch_function__(cls, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}

    # Route the handful of ops that need masked-aware semantics to their
    # dedicated autograd functions.
    if func in [torch.Tensor.where, torch.where]:
        _check_args_kwargs_length(
            args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
        )
        return _MaskedWhere.apply(*args)
    if func is torch.Tensor.contiguous:
        return _MaskedContiguous.apply(args[0])
    if func is torch.Tensor.to_dense:
        return _MaskedToDense.apply(args[0])
    if func is torch.Tensor.to_sparse:
        return _MaskedToSparse.apply(args[0])
    if func is torch.Tensor.to_sparse_csr:
        return _MaskedToSparseCsr.apply(args[0])

    if not all(issubclass(cls, t) for t in types):
        return NotImplemented

    # Everything else falls through to the regular torch implementation,
    # with results re-wrapped into the subclass unless the op is in the
    # default "no wrap" set.
    with torch._C.DisableTorchFunction():
        ret = func(*args, **kwargs)
        if func in get_default_nowrap_functions():
            return ret
        else:
            return torch._tensor._convert(ret, cls)
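# Usage sketch: a minimal, hedged example of how the dispatch above is exercised.
# It assumes the prototype torch.masked API (the masked_tensor factory) is
# available in the installed PyTorch build; the sample values are illustrative.
import torch
from torch.masked import masked_tensor

data = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
mask = data < 0
mx = masked_tensor(data, mask)
my = masked_tensor(torch.ones_like(data), ~mask)

# torch.where with three positional args passes _check_args_kwargs_length and
# is routed to _MaskedWhere.apply(mask, mx, my) by the __torch_function__ above.
out = torch.where(mask, mx, my)

# A plain Tensor method such as .contiguous() is likewise intercepted and
# forwarded to _MaskedContiguous.apply(out).
out_contig = out.contiguous()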
def __torch_function__(cls, func, types, args=(), kwargs=None): """ This __torch_function__ implementation wraps subclasses such that methods called on subclasses return a subclass instance instead of a ``torch.Tensor`` instance. One corollary to this is that you need coverage for torch.Tensor methods if implementing __torch_function__ for subclasses. We recommend always calling ``super().__torch_function__`` as the base case when doing the above. While not mandatory, we recommend making `__torch_function__` a classmethod. """ if kwargs is None: kwargs = {} if not all(issubclass(cls, t) for t in types): return NotImplemented with _C.DisableTorchFunction(): ret = func(*args, **kwargs) if func in get_default_nowrap_functions(): return ret else: return _convert(ret, cls)
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}

    # Re-dispatch execution to ShardedTensor's __torch_function__ if any operand
    # is a ShardedTensor. We also check whether all tensor operands are
    # ReplicatedTensors: only then do we convert the result back to a
    # ReplicatedTensor.
    all_replicated = True
    replicated_pg = None

    def dispatch_arg(arg):
        # Returns a tuple: the first element indicates whether the op was
        # re-dispatched and executed, the second element is the result of that
        # execution (or None).
        nonlocal replicated_pg, all_replicated
        if isinstance(arg, ShardedTensor):
            # Re-dispatch to ShardedTensor.
            # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
            return True, arg.__torch_function__(func, types, args, kwargs)
        if isinstance(arg, ReplicatedTensor):
            if replicated_pg is None:
                replicated_pg = arg.process_group
            elif replicated_pg != arg.process_group:
                raise RuntimeError(
                    f"ReplicatedTensor operands must be in the same process group "
                    f"in torch function '{func.__name__}', but found at least two "
                    f"ReplicatedTensor operands in different process groups!"
                )
        else:
            all_replicated = False

        return False, None

    for arg in args:
        redispatched, res = dispatch_arg(arg)
        if redispatched:
            return res

    if kwargs is not None:
        for k, v in kwargs.items():
            redispatched, res = dispatch_arg(v)
            if redispatched:
                return res

    # We can't call super().__torch_function__() because it implicitly converts
    # the result back to the tensor subclass; here we need to control the output
    # type based on the inter-op rules defined above.
    with torch._C.DisableTorchFunction():
        rs = func(*args, **kwargs)
        if func in get_default_nowrap_functions():
            return rs

        if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
            # All operands are ReplicatedTensors and nothing was dispatched to
            # ShardedTensor's __torch_function__, so the result is a plain
            # torch.Tensor: convert it to a ReplicatedTensor per our inter-op
            # rule and propagate the process_group field.
            rs = rs.as_subclass(ReplicatedTensor)  # type: ignore[arg-type]
            rs.process_group = replicated_pg  # type: ignore[attr-defined]

        return rs
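# Usage sketch: a hedged illustration of the inter-op rule encoded above. The
# direct ReplicatedTensor(...) construction and the reliance on the default
# process group are assumptions about this prototype API
# (torch.distributed._shard.replicated_tensor); run under torchrun or after
# torch.distributed.init_process_group(...).
import torch
import torch.distributed as dist
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

def demo():
    assert dist.is_initialized(), "requires an initialized process group"

    a = ReplicatedTensor(torch.ones(4))
    b = ReplicatedTensor(torch.ones(4))

    # All operands are ReplicatedTensors in the same process group, so the
    # plain torch.Tensor result is converted back to a ReplicatedTensor and
    # inherits that process group.
    c = a + b
    assert isinstance(c, ReplicatedTensor)

    # Mixing in a plain torch.Tensor sets all_replicated to False, so the
    # result stays a plain torch.Tensor.
    d = a + torch.ones(4)
    assert not isinstance(d, ReplicatedTensor)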