def meta_cdist_forward(x1, x2, p, compute_mode): check( x1.dim() >= 2, lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D", ) check( x2.dim() >= 2, lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D", ) check( x1.size(-1) == x2.size(-1), lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}", ) check( utils.is_float_dtype(x1.dtype), lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}", ) check( utils.is_float_dtype(x2.dtype), lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}", ) check(p >= 0, lambda: "cdist only supports non-negative p values") check( compute_mode >= 0 and compute_mode <= 2, lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}", ) r1 = x1.size(-2) r2 = x2.size(-2) batch_tensor1 = x1.shape[:-2] batch_tensor2 = x2.shape[:-2] output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2)) output_shape.extend([r1, r2]) return x1.new_empty(output_shape)
def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str): """ Checks related to the dtype kwarg in `linalg.*norm` functions """ if dtype is not None: check( utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", ) check( utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}". format( fn_name=fn_name, d="complex" if utils.is_complex_dtype(x_dtype) else "real", dtype=dtype, ), ) check( utils.get_higher_dtype(dtype, x_dtype) == dtype, lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " "without narrowing to the specified dtype ({dtype})", )
def norm(
    self: Tensor,
    p: float = 2,
    dim: Optional[List[int]] = None,
    keepdim: bool = False,
):
    """Compute the ``p``-norm of ``self`` over ``dim``.

    Args:
        self: input tensor.
        p: norm order. ``0`` counts non-zero elements, ``inf`` / ``-inf`` take
            the maximum / minimum absolute value, any other value computes
            ``sum(|x| ** p) ** (1 / p)``.
        dim: dimensions to reduce; ``None`` is replaced by an empty list.
        keepdim: forwarded to the underlying reduction.

    Returns:
        The reduced tensor. For floating-point and complex inputs the ``p == 0``
        result uses the input's (real) floating dtype, like the other branches,
        instead of the integer dtype a boolean sum would default to.
    """
    if dim is None:
        dim = []

    if p == 0:
        # Summing a boolean mask defaults to an integer result; request the
        # real floating dtype so this branch agrees with the others.
        if utils.is_complex_dtype(self.dtype):
            count_dtype = utils.corresponding_real_dtype(self.dtype)
        elif utils.is_float_dtype(self.dtype):
            count_dtype = self.dtype
        else:
            # Non-floating inputs keep sum's default accumulation dtype.
            count_dtype = None
        return (self != 0).sum(dim, keepdim=keepdim, dtype=count_dtype)
    elif p == float("inf"):
        return self.abs().amax(dim, keepdim=keepdim)
    elif p == -float("inf"):
        return self.abs().amin(dim, keepdim=keepdim)

    def fast_pow(x, exponent):
        # Use the dedicated op for the common exponents.
        if exponent == 1.0:
            return x
        elif exponent == 2.0:
            return x.square()
        elif exponent == 0.5:
            return x.sqrt()
        else:
            return x.pow(exponent)

    # For even p on a floating-point input x ** p is already non-negative, so
    # the abs() can be skipped.
    if not (p % 2.0 == 0.0 and utils.is_float_dtype(self.dtype)):
        self = self.abs()

    return fast_pow(fast_pow(self, p).sum(dim, keepdim=keepdim), 1.0 / p)
def meta_embedding_bag( weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=-1, ): check( indices.dtype in (torch.long, torch.int), lambda: f"expected indices to be long or int, got {indices.dtype}", ) check( offsets.dtype in (torch.long, torch.int), lambda: f"expected offsets to be long or int, got {offsets.dtype}", ) check( utils.is_float_dtype(weight.dtype), lambda: f"expected weight to be floating point type, got {weight.dtype}", ) num_bags = offsets.size(0) if include_last_offset: check(num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1") num_bags -= 1 output = weight.new_empty(num_bags, weight.size(1)) MODE_SUM, MODE_MEAN, MODE_MAX = range(3) if per_sample_weights is not None: check( mode == MODE_SUM, lambda: "embedding_bag: per_sample_weights only supported with mode='sum'", ) check( per_sample_weights.dtype == weight.dtype, lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype", ) check( per_sample_weights.ndim == 1, lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D", ) check( per_sample_weights.numel() == indices.numel(), lambda: (f"expected per_sample_weights.numel() ({per_sample_weights.numel()} " f"to be the same as indices.numel() ({indices.numel()})"), ) def is_fast_path_index_select_scale(src, scale, output, padding_idx): return (is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1) def is_fast_path_index_select(src, output, padding_idx): return ((src.dtype == torch.float or src.dtype == torch.half) and src.stride(1) == 1 and output.stride(1) == 1 and padding_idx < 0) def is_fast_path(src, scale, output, padding_idx): if scale is not None: return is_fast_path_index_select_scale(src, scale, output, padding_idx) else: return is_fast_path_index_select(src, output, padding_idx) if offsets.device.type != "cpu": offset2bag = 
indices.new_empty(indices.size(0)) bag_size = indices.new_empty(offsets.size()) if mode == MODE_MAX: max_indices = indices.new_empty(num_bags, weight.size(1)) else: max_indices = indices.new_empty(0) else: fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx) if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum: offset2bag = offsets.new_empty(indices.size(0)) else: offset2bag = offsets.new_empty(0) bag_size = offsets.new_empty(num_bags) max_indices = offsets.new_empty(bag_size.size()) return output, offset2bag, bag_size, max_indices
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Compute the vector ``ord``-norm of ``x`` over ``dim``.

    Args:
        x: input; its dtype must pass ``check_fp_or_complex``.
        ord: norm order. ``0`` counts non-zero entries, ``inf`` / ``-inf`` take
            the max / min absolute value, anything else computes
            ``sum(|x| ** ord) ** (1 / ord)``.
        dim: an int, a sequence of ints, or ``None``. It is normalised to a
            list (or left as ``None``) before being handed to the reductions.
        keepdim: forwarded to the reductions.
        dtype: optional dtype, validated by ``check_norm_dtype`` and passed to
            ``utils.reduction_dtypes``.

    Returns:
        The reduced tensor converted to the result dtype chosen by
        ``utils.reduction_dtypes``.

    Raises:
        Whatever ``check`` raises when the input is empty along a reduced
        dimension and ``ord`` is negative or ``inf``.
    """
    # ---- argument validation / normalisation ----
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif dim is not None and not isinstance(dim, List):
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        sizes = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                sizes[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )

    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = utils.reduction_dtypes(
        x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    def as_result(t):
        return prims.convert_element_type(t, dtype=result_dtype)

    # ---- special-cased orders ----
    if ord == 0.0:
        nonzero_mask = refs.ne(x, 0.0)
        return refs.sum(nonzero_mask, dim=dim, keepdim=keepdim, dtype=result_dtype)

    if ord == float("inf"):
        largest = refs.amax(torch.abs(x), dim=dim, keepdim=keepdim)
        return as_result(largest)

    if ord == float("-inf"):
        smallest = refs.amin(torch.abs(x), dim=dim, keepdim=keepdim)
        return as_result(smallest)

    # ---- general order ----
    # From here on the computation dtype is important as the reduction is non-trivial
    x = prims.convert_element_type(x, computation_dtype)

    # abs() is only skipped for an even order on a floating-point input.
    even_order_on_float = ord % 2.0 == 0.0 and utils.is_float_dtype(x.dtype)
    if not even_order_on_float:
        x = torch.abs(x)

    powered = torch.pow(x, ord)
    total = refs.sum(powered, dim=dim, keepdim=keepdim)
    return as_result(torch.pow(total, 1.0 / ord))