def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            copy_from.device, copy_to.device
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        utils.check(
            copy_from.dtype == copy_to.dtype,
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        utils.check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)

def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)"""
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim),)
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    # Report the computed size rather than `n`, which may be None
    check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)

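# Illustrative sketch of the c2r length relation above, checked against the
# public torch.fft API: a half-spectrum of length m maps to a real signal of
# length n = 2 * (m - 1) when `n` is not given. The `_example_*` helper is
# illustrative only, not part of this module.
def _example_fft_c2r_default_length():
    import torch

    half_spectrum = torch.randn(9, dtype=torch.complex64)  # m == 9
    out = torch.fft.irfft(half_spectrum)  # n defaults to 2 * (9 - 1)
    assert out.shape == (16,) and out.dtype == torch.float32
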
def norm(
    A: TensorLikeType,
    ord: Optional[Union[float, str]] = None,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    if dim is not None:
        if isinstance(dim, int):
            dim = (dim,)  # type: ignore[assignment]
        check(
            len(dim) in (1, 2),
            lambda: f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
        )
    elif ord is not None:
        check(
            A.ndim in (1, 2),
            lambda: f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
        )

    if ord is not None and (
        (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2)
    ):
        if dim is None:
            dim = (0, 1)
        return matrix_norm(A, ord, dim, keepdim, dtype=dtype)
    else:
        if ord is None:
            ord = 2.0
        return vector_norm(A, ord, dim, keepdim, dtype=dtype)

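# Illustrative sketch of the dispatch above, via the public torch.linalg.norm
# entry point: a 2D input with `ord` takes the matrix-norm path, while
# ord=None reduces all elements with the vector 2-norm. The `_example_*`
# helper is illustrative only, not part of this module.
def _example_norm_dispatch():
    import torch

    A = torch.randn(3, 4)
    m = torch.linalg.norm(A, ord=1)  # 2D + ord -> matrix norm (max column sum)
    v = torch.linalg.norm(A)  # ord=None -> vector 2-norm over all elements
    assert torch.allclose(m, A.abs().sum(dim=0).max())
    assert torch.allclose(v, A.flatten().norm())
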
def meta_dot(self, tensor):
    check(
        self.dim() == 1 and tensor.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {tensor.dim()}D tensors",
    )
    return self.new_empty(())

def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True)

    if len(dim) == 1:
        tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False)
        return prims.conj(tmp)

    tmp = prims.conj_physical(tmp)
    tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False)
    return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False)

def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(
        weight, a.shape, tuple() if weight.ndim == 0 else (1,)
    )
    return refs.where(a > 0, a, a * weight)

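# Illustrative sketch of the broadcast above: a weight of numel C is laid out
# along dim 1 (the channel dimension), matching torch.nn.functional.prelu.
# The `_example_*` helper is illustrative only, not part of this module.
def _example_prelu_channel_weight():
    import torch

    x = torch.randn(2, 3, 4)  # channel_size == 3
    w = torch.tensor([0.1, 0.2, 0.3])  # one slope per channel
    y = torch.nn.functional.prelu(x, w)
    expected = torch.where(x > 0, x, x * w.view(1, 3, 1))
    assert torch.allclose(y, expected)
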
def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
    as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    if s is None or s[-1] == -1:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )

def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    dim = utils.canonicalize_dims(a.ndim, dim)
    check(
        a.shape[dim] % 2 == 0,
        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
    )
    b, c = torch.tensor_split(a, 2, dim)
    return b * torch.sigmoid(c)

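# Illustrative sketch of the gating above: the chosen dimension is halved and
# the first half is gated by the sigmoid of the second, matching
# torch.nn.functional.glu. The `_example_*` helper is illustrative only.
def _example_glu():
    import torch

    x = torch.randn(2, 6)
    y = torch.nn.functional.glu(x, dim=-1)
    b, c = x.split(3, dim=-1)
    assert y.shape == (2, 3)
    assert torch.allclose(y, b * torch.sigmoid(c))
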
def _fn(*args, out=None, **kwargs):
    if is_factory_fn and out is not None:
        for k in factory_kwargs:
            out_attr = getattr(out, k)
            if k not in kwargs:
                kwargs[k] = out_attr

    result = fn(*args, **kwargs)
    assert (
        isinstance(result, TensorLike)
        and is_tensor
        or isinstance(result, Tuple)  # type: ignore[arg-type]
        and len(result) == len(out_names)
    )
    if out is not None:
        # Naively you might expect this assert to be true, but
        # it's not:
        #
        #   assert type(out) == type(result)
        #
        # The reason is that functions under this wrapper can
        # get registered to the Meta dispatch key, and that
        # means they can be executed in a context where tensor
        # subclasses are disabled (with no_dispatch), which is a
        # handy way for an is-a tensor subclass (e.g.,
        # FakeTensor) to have the normal meta backend create a
        # meta tensor, to be wrapped once it gets returned.
        # In this situation, you will get a FakeTensor as
        # the output tensor, but not the result--which will
        # be a normal meta tensor, but this is perfectly
        # harmless.
        if is_tensor:
            assert isinstance(out, TensorLike)
            # These two operations are done in-place
            _maybe_resize_out(out, result.shape)
            _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
        else:
            assert isinstance(out, Tuple)  # type: ignore[arg-type]
            utils.check(
                len(out) == len(result),
                lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                TypeError,
            )
            for r, o in zip(result, out):
                # These two operations are done in-place
                _maybe_resize_out(o, r.shape)
                _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
    else:
        out = result
    # mypy does not see through the definition of out_type given that it's in a different scope
    return out if is_tensor else return_type(*out)  # type: ignore[operator]

def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Apply normalization to the un-normalized FFT result"""
    check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    normalize = (not forward and (norm is None or norm == "backward")) or (
        forward and norm == "forward"
    )
    return x * (1 / signal_numel) if normalize else x

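# Illustrative sketch of what the three norm modes mean end to end, via the
# public torch.fft API: "backward" (the default) puts the full 1/n factor on
# the inverse transform, "ortho" splits it as 1/sqrt(n) on both directions,
# and "forward" puts 1/n on the forward transform. The `_example_*` helper is
# illustrative only, not part of this module.
def _example_fft_norm_modes():
    import math

    import torch

    x = torch.randn(8, dtype=torch.complex64)
    n = x.numel()
    assert torch.allclose(
        torch.fft.fft(x, norm="ortho"),
        torch.fft.fft(x, norm="backward") / math.sqrt(n),
    )
    assert torch.allclose(
        torch.fft.fft(x, norm="forward"),
        torch.fft.fft(x, norm="backward") / n,
    )
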
def meta_diag(self, dim=0):
    check(self.dim() in (1, 2), lambda: "matrix or a vector expected")
    if self.dim() == 1:
        sz = self.size(0) + abs(dim)
        return self.new_empty((sz, sz))

    # 2D input case; note `dim` here is the diagonal offset, not a dimension index
    if dim >= 0:
        sz = min(self.size(0), self.size(1) - dim)
    else:
        sz = min(self.size(0) + dim, self.size(1))
    return self.new_empty((sz,))

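# Illustrative sketch of the output shapes computed above, checked against
# torch.diag: a length-3 vector with offset 1 embeds into a 4x4 matrix, and a
# 3x5 matrix with offset 1 yields a length-3 diagonal. The `_example_*` helper
# is illustrative only, not part of this module.
def _example_diag_shapes():
    import torch

    assert torch.diag(torch.ones(3), 1).shape == (4, 4)
    assert torch.diag(torch.ones(3, 5), 1).shape == (3,)
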
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
    check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
    check(p >= 0, lambda: "pdist only supports non-negative p values")
    # For p == 2 we can use an efficient implementation, but other values of p
    # require creating a much bigger tensor for an intermediate step
    if p == 2:
        # ||a_i - a_j||^2 == ||a_i||^2 + ||a_j||^2 - 2 a_i . a_j; clamp to 0
        # to guard against small negatives from rounding
        aTa = torch.mm(a, a.T)
        aTa_diag = torch.diag(aTa)
        t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0))
    else:
        t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2)

    i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device)
    # t is square, so t.shape[0] is also the row stride for row-major flattening
    return t.flatten().index_select(0, i[0] * t.shape[0] + i[1])

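# Illustrative sketch: both branches above agree with
# torch.nn.functional.pdist, which returns the strict upper triangle of the
# pairwise distance matrix as a flat vector of length n * (n - 1) / 2. The
# `_example_*` helper is illustrative only, not part of this module.
def _example_pdist():
    import torch

    a = torch.randn(5, 3).double()
    for p in (1.0, 2.0, 3.0):
        d = torch.nn.functional.pdist(a, p=p)
        full = torch.cdist(a, a, p=p)
        i = torch.triu_indices(5, 5, offset=1)
        assert d.shape == (10,)
        assert torch.allclose(d, full[i[0], i[1]])
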
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # softshrink(x) = x - lambd, if x > lambd
    #               = x + lambd, if x < -lambd
    #               = 0, otherwise
    check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    ge_mask = a > lambd
    le_mask = a < -lambd
    zero_mask = torch.logical_not(refs.logical_or(ge_mask, le_mask))
    result = refs.where(ge_mask, a - lambd, a)
    result = refs.where(le_mask, a + lambd, result)
    return refs.where(zero_mask, 0, result)

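# Illustrative sketch of the piecewise formula above on a few concrete values,
# via the public torch.nn.functional.softshrink. The `_example_*` helper is
# illustrative only, not part of this module.
def _example_softshrink():
    import torch

    x = torch.tensor([-1.0, -0.2, 0.3, 2.0])
    y = torch.nn.functional.softshrink(x, lambd=0.5)
    # -1.0 -> -0.5; -0.2 and 0.3 fall inside [-0.5, 0.5] -> 0; 2.0 -> 1.5
    assert torch.allclose(y, torch.tensor([-0.5, 0.0, 0.0, 1.5]))
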
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_r2c(input, dim=dim, onesided=True)
    return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)

def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither are optional"""
    input_dim = input.ndim
    input_sizes = input.shape

    if dim is not None:
        if not isinstance(dim, Sequence):
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim)

        # Check dims are unique (on the canonicalized dims, so that e.g. -1 and
        # input_dim - 1 are recognized as the same axis)
        check(len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique")

    if shape is not None:
        if not isinstance(shape, Sequence):
            shape = (shape,)

        # Has shape, might have dim
        check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)

        check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)
        )
    elif dim is None:
        # No shape, no dim
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim
        ret_shape = tuple(input_sizes[d] for d in ret_dims)

    for n in ret_shape:
        check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)

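# Illustrative sketch of the defaulting rules above as they surface in the
# public API: when `s` is given but `dim` is not, the transform runs over the
# last len(s) dimensions, and -1 entries in `s` keep the input size. The
# `_example_*` helper is illustrative only, not part of this module.
def _example_fftn_shape_dim_defaults():
    import torch

    x = torch.randn(4, 5, 6)
    out = torch.fft.fftn(x, s=(8, -1))  # dim defaults to (-2, -1); -1 keeps 6
    assert out.shape == (4, 8, 6)
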
def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
    check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    x = _resize_fft_input(input, dim, shape)
    output = prims.fft_c2c(x, dim=dim, forward=forward)
    return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward)

def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to complex FFT (fft or ifft)"""
    check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim),)
    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(ret, norm, input.shape[dim], forward)

def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions
    """
    if dtype is not None:
        check(
            utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
            lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
        )
        check(
            utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
            lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
                fn_name=fn_name,
                d="complex" if utils.is_complex_dtype(x_dtype) else "real",
                dtype=dtype,
            ),
        )
        check(
            utils.get_higher_dtype(dtype, x_dtype) == dtype,
            lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
            f"without narrowing to the specified dtype ({dtype})",
        )

def meta_pad2d(self, padding):
    valid_dims = self.size(1) != 0 and self.size(2) != 0
    check(
        (self.ndim == 3 and valid_dims)
        or (self.ndim == 4 and valid_dims and self.size(3) != 0),
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )
    if self.ndim == 4:
        nbatch, nplane, input_h, input_w = self.shape
    else:
        nbatch = 1
        nplane, input_h, input_w = self.shape

    pad_l, pad_r, pad_t, pad_b = padding

    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    if self.ndim == 3:
        return self.new_empty((nplane, output_h, output_w))
    else:
        return self.new_empty((nbatch, nplane, output_h, output_w))

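# Illustrative sketch of the output-shape arithmetic above, checked with a
# reflection pad (one of the 2D padding ops this shape function covers). The
# `_example_*` helper is illustrative only, not part of this module.
def _example_pad2d_shape():
    import torch

    x = torch.randn(1, 3, 10, 12)
    # padding order is (pad_l, pad_r, pad_t, pad_b)
    y = torch.nn.functional.pad(x, (1, 2, 3, 4), mode="reflect")
    assert y.shape == (1, 3, 10 + 3 + 4, 12 + 1 + 2)
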
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    self = self.expand((dim1, dim2))
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    check(
        batch1.size(0) == batch2.size(0),
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    check(
        batch1.size(2) == batch2.size(1),
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )
    check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())

def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Common code for performing any real to complex FFT (rfft or ihfft)"""
    check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim),)

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_r2c(input, dim=dims, onesided=onesided)
    ret = _apply_norm(ret, norm, input.shape[dim], forward)
    return ret if forward else torch.conj(ret)

def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, List) and dim is not None:
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = utils.reduction_dtypes(
        x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    elif ord == float("-inf"):
        return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = prims.convert_element_type(x, computation_dtype)
        reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim)

        # For even integer ord on a real input, x ** ord is already
        # non-negative, so abs can be skipped
        if not (ord % 2.0 == 0.0 and utils.is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))

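# Illustrative sketch of the special cases above against their closed forms:
# ord=0 counts nonzeros, ord=inf takes the max absolute value, and the generic
# branch computes (sum |x|^p)^(1/p). The `_example_*` helper is illustrative
# only, not part of this module.
def _example_vector_norm():
    import torch

    x = torch.tensor([-3.0, 0.0, 4.0])
    assert torch.linalg.vector_norm(x, ord=0) == 2
    assert torch.linalg.vector_norm(x, ord=float("inf")) == 4
    assert torch.allclose(
        torch.linalg.vector_norm(x, ord=3.0),
        (x.abs() ** 3).sum() ** (1 / 3),
    )
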
def dot_check(self, other):
    check(
        self.dim() == 1 and other.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )

def matrix_norm(
    A: TensorLikeType,
    ord: Union[float, str] = "fro",
    dim: DimsType = (-2, -1),
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # shape
    check_is_matrix(A, "linalg.matrix_norm")
    # dim
    dim = utils.canonicalize_dims(A.ndim, dim)
    if isinstance(dim, int):
        dim = (dim,)  # type: ignore[assignment]
    check(len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
    check(
        dim[0] != dim[1],
        lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
    )
    # dtype arg
    check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")

    if isinstance(ord, str):
        # ord
        check(
            ord in ("fro", "nuc"),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc"
        )

        if ord == "fro":
            return vector_norm(A, 2, dim, keepdim, dtype=dtype)
        else:  # ord == "nuc"
            if dtype is not None:
                A = prims.convert_element_type(A, dtype)
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
    else:
        # ord
        abs_ord = abs(ord)
        check(
            abs_ord in (2, 1, float("inf")),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2
        )

        max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim)

        if abs_ord == 2.0:
            if dtype is not None:
                A = prims.convert_element_type(A, dtype)
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
        else:  # 1, -1, inf, -inf
            dim0, dim1 = dim
            if abs_ord == float("inf"):
                dim0, dim1 = dim1, dim0
            if not keepdim and (dim0 < dim1):
                dim1 -= 1
            return max_min(vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1)

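# Illustrative sketch of two of the reductions above against their
# definitions: ord=1 is the max absolute column sum, and "fro" is the vector
# 2-norm of the flattened matrix. The `_example_*` helper is illustrative
# only, not part of this module.
def _example_matrix_norm():
    import torch

    A = torch.randn(3, 4)
    assert torch.allclose(
        torch.linalg.matrix_norm(A, ord=1), A.abs().sum(dim=0).max()
    )
    assert torch.allclose(
        torch.linalg.matrix_norm(A, ord="fro"), A.flatten().norm()
    )
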
def meta_adaptive_avg_pool2d(self, output_size):
    check(
        self.ndim == 3 or self.ndim == 4,
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    return self.new_empty(self.shape[:-2] + tuple(output_size))

def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1")
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()}) "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        return (
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
        )

    def is_fast_path_index_select(src, output, padding_idx):
        return (
            (src.dtype == torch.float or src.dtype == torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if offsets.device.type != "cpu":
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        max_indices = offsets.new_empty(bag_size.size())

    return output, offset2bag, bag_size, max_indices

def meta_cdist_forward(x1, x2, p, compute_mode):
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)

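# Illustrative sketch of the shape rule above: the leading batch dimensions of
# the two inputs broadcast, and the result has shape broadcast_batch + (r1, r2).
# The `_example_*` helper is illustrative only, not part of this module.
def _example_cdist_shape():
    import torch

    x1 = torch.randn(2, 1, 5, 3)  # batch (2, 1), r1 = 5
    x2 = torch.randn(4, 7, 3)  # batch (4,),   r2 = 7
    assert torch.cdist(x1, x2).shape == (2, 4, 5, 7)
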
def meta_index_Tensor(self, indices):
    check(indices, lambda: "at least one index must be provided")
    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            # byte is uint8; int8 indices are not supported by aten::index
            check(
                index.dtype in [torch.long, torch.uint8, torch.bool],
                lambda: "tensors used as indices must be long, byte or bool tensors",
            )
            if index.dtype in [torch.uint8, torch.bool]:
                nonzero = index.nonzero()
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them. If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)

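# Illustrative sketch of the before/replacement/after shape logic above: an
# index tensor replaces the dimension it indexes with its own (broadcast)
# shape, and when the advanced indices are not adjacent, the replacement
# dimensions move to the front. The `_example_*` helper is illustrative only,
# not part of this module.
def _example_advanced_index_shapes():
    import torch

    x = torch.randn(2, 3, 4)
    idx = torch.tensor([[0, 1], [2, 0]])  # shape (2, 2)
    assert x[:, idx].shape == (2, 2, 2, 4)  # dim 1 replaced by idx.shape
    assert x[idx, :, idx].shape == (2, 2, 3)  # non-adjacent -> result at front
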
def meta_adaptive_avg_pool3d(self, output_size):
    check(
        self.ndim == 4 or self.ndim == 5,
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    return self.new_empty(self.shape[:-3] + tuple(output_size))