def meta_cdist_forward(x1, x2, p, compute_mode):
    """Validate cdist inputs and allocate an (uninitialised) output of the right shape.

    No distances are computed here; the function only runs the argument
    checks and returns ``x1.new_empty(...)`` with the result shape.

    Args:
        x1: tensor with at least 2 dimensions and a floating-point dtype.
        x2: tensor with at least 2 dimensions, a floating-point dtype and the
            same last-dimension size as ``x1``.
        p: order of the norm; must be non-negative.
        compute_mode: integer mode selector; must be 0, 1 or 2.

    Returns:
        A new empty tensor (created from ``x1``) whose shape is the broadcast
        of the two batch shapes (all dims but the last two) followed by
        ``[x1.size(-2), x2.size(-2)]``.

    Note:
        Validation is delegated to ``check``, which is defined elsewhere
        (presumably it raises with the lazily built message when the
        condition is false -- confirm against its definition).
    """
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    # These two messages must be f-strings, otherwise the placeholder text
    # "{x1.dtype}" is shown verbatim instead of the offending dtype.
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )

    r1 = x1.size(-2)
    r2 = x2.size(-2)
    # Everything except the trailing (rows, columns) pair is a batch dimension.
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    """Elementwise closeness test: ``|a - b| <= atol + rtol * |b|``.

    Args:
        a: first tensor.
        b: second tensor; must have the same dtype as ``a``.
        rtol: relative tolerance, must be >= 0.
        atol: absolute tolerance, must be >= 0.
        equal_nan: when True and the dtype is floating point or complex,
            positions where both inputs are NaN are reported as close.

    Returns:
        A tensor that is true where the elements are exactly equal (or both
        NaN when ``equal_nan`` applies), or where the difference is finite
        and within the allowed error.

    Raises:
        ValueError: if the dtypes differ, or if ``rtol`` or ``atol`` is
            negative.
    """
    if a.dtype != b.dtype:
        msg = "Attempting to compare tensors of different dtypes {0} and {1}!".format(
            a.dtype, b.dtype
        )
        raise ValueError(msg)
    if rtol < 0:
        msg = "rtol must be greater than or equal to zero, but got {0}!".format(rtol)
        raise ValueError(msg)
    if atol < 0:
        msg = "atol must be greater than or equal to zero, but got {0}!".format(atol)
        raise ValueError(msg)

    close = eq(a, b)
    if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
        close = logical_or(close, logical_and(isnan(a), isnan(b)))

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    result = logical_or(
        close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
    )

    return result
def norm(self: Tensor, p: float = 2, dim: List[int] = None, keepdim: bool = False):
    """Compute the ``p``-norm of ``self`` over ``dim``.

    Args:
        self: input tensor.
        p: order of the norm. ``0`` counts non-zero entries, ``inf`` / ``-inf``
            take the max / min absolute value, anything else evaluates
            ``sum(|x| ** p) ** (1 / p)``.
        dim: dimensions to reduce; ``None`` is replaced by an empty list
            before being handed to the reduction.
        keepdim: forwarded to the reduction.

    Returns:
        The reduced tensor.
    """
    reduce_dims = [] if dim is None else dim

    # Orders that are a single plain reduction.
    if p == 0:
        return (self != 0).sum(reduce_dims, keepdim=keepdim)
    if p == float('inf'):
        return self.abs().amax(reduce_dims, keepdim=keepdim)
    if p == -float('inf'):
        return self.abs().amin(reduce_dims, keepdim=keepdim)

    def _cheap_pow(t, exponent):
        # Use the dedicated op for the common exponents instead of a generic pow.
        if exponent == 1.0:
            return t
        if exponent == 2.0:
            return t.square()
        if exponent == 0.5:
            return t.sqrt()
        return t.pow(exponent)

    # abs() can be skipped only for even orders on floating-point input,
    # because the even power already discards the sign.
    needs_abs = not (p % 2.0 == 0.0 and utils.is_float_dtype(self.dtype))
    magnitudes = self.abs() if needs_abs else self

    summed = _cheap_pow(magnitudes, p).sum(reduce_dims, keepdim=keepdim)
    return _cheap_pow(summed, 1.0 / p)
def meta_linalg_qr_helper(input, mode):
    """Validate QR inputs and allocate empty ``Q`` and ``R`` tensors.

    Args:
        input: tensor with at least 2 dimensions and a float or complex dtype.
        mode: one of ``"reduced"``, ``"complete"`` or ``"r"``.

    Returns:
        ``(Q, R)``. Both are allocated with their last two dimensions swapped
        and then transposed in place. With ``m, n`` the last two sizes of
        ``input`` and ``k = min(m, n)``: ``Q`` is ``(..., m, k)`` for
        ``"reduced"``, ``(..., m, m)`` for ``"complete"`` and a 1-D empty
        tensor for ``"r"``; ``R`` is ``(..., m, n)`` for ``"complete"`` and
        ``(..., k, n)`` otherwise.

    Raises:
        RuntimeError: if ``mode`` is not one of the recognised strings.
    """
    # mode name -> (compute_q, reduced_mode)
    known_modes = (
        ("reduced", (True, True)),
        ("complete", (True, False)),
        ("r", (False, True)),
    )
    for mode_name, flags in known_modes:
        if mode == mode_name:
            compute_q, reduced_mode = flags
            break
    else:
        raise RuntimeError(f"qr received unrecognized mode {mode}")

    check(input.ndim >= 2, lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor")
    check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: f"expected float or complex tensor, but got {input.dtype}"
    )

    rows = input.size(-2)
    cols = input.size(-1)
    inner = min(rows, cols)

    if not compute_q:
        Q = input.new_empty(0)
    else:
        q_transposed_shape = list(input.size())
        q_transposed_shape[-2] = inner if reduced_mode else rows
        q_transposed_shape[-1] = rows
        Q = input.new_empty(q_transposed_shape)
        Q.transpose_(-2, -1)

    # Only the "complete" mode (Q computed, not reduced) keeps all `rows` rows in R.
    full_r = compute_q and not reduced_mode
    r_transposed_shape = list(input.size())
    r_transposed_shape[-2] = cols
    r_transposed_shape[-1] = rows if full_r else inner
    R = input.new_empty(r_transposed_shape)
    R.transpose_(-2, -1)

    return (Q, R)
def _find_highest_dtype_filtered(
    args, filter, *, float_as_complex=False, all_tensors_equal=False
) -> Optional[torch.dtype]:
    """Return the highest dtype among the tensors in ``args`` accepted by ``filter``.

    Args:
        args: iterable of candidate values; entries that are not ``TensorLike``
            are ignored.
        filter: predicate applied to each tensor's dtype; only tensors for
            which it is truthy participate.
        float_as_complex: when True, floating-point dtypes are replaced by
            their corresponding complex dtype before being compared.
        all_tensors_equal: when True, zero-dimensional tensors are pooled with
            the tensors that have one or more dimensions.

    Returns:
        The highest dtype among participating tensors with one or more
        dimensions if there is one, otherwise the highest dtype among the
        zero-dimensional ones, otherwise ``None``.
    """
    best_zero_dim = None
    best_with_dims = None

    for candidate in args:
        if not (isinstance(candidate, TensorLike) and filter(candidate.dtype)):
            continue

        candidate_dtype = candidate.dtype
        if float_as_complex and utils.is_float_dtype(candidate_dtype):
            candidate_dtype = utils.corresponding_complex_dtype(candidate_dtype)

        counts_as_zero_dim = candidate.ndim == 0 and not all_tensors_equal
        if counts_as_zero_dim:
            best_zero_dim = utils.get_higher_dtype(best_zero_dim, candidate_dtype)
        else:
            # candidate.ndim > 0 or all_tensors_equal
            best_with_dims = utils.get_higher_dtype(best_with_dims, candidate_dtype)

    # Tensors with one or more dimensions take priority over zero-dim tensors.
    return best_zero_dim if best_with_dims is None else best_with_dims
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Compute the vector ``ord``-norm of ``x`` over ``dim``.

    Args:
        x: input tensor; its dtype is validated by ``check_fp_or_complex``.
        ord: order of the norm. ``0`` counts non-zero entries, ``inf`` /
            ``-inf`` take the max / min absolute value, anything else
            evaluates ``sum(|x| ** ord) ** (1 / ord)``.
        dim: an int, a sequence of ints, or ``None``; normalised to a list
            (or left as ``None``) before use.
        keepdim: forwarded to the reductions.
        dtype: optional dtype, validated by ``check_norm_dtype`` and passed to
            ``reduction_dtypes``.

    Returns:
        The reduced tensor, converted to the result dtype chosen by
        ``reduction_dtypes``.
    """
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, List) and dim is not None:
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    # Negative and inf orders have no identity element, so reducing over an
    # empty dimension is rejected up front.
    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )

    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = reduction_dtypes(
        x, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    def to_result_dtype(t):
        return prims.convert_element_type(t, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        nonzero_mask = refs.ne(x, 0.0)
        return refs.sum(nonzero_mask, dim=dim, keepdim=keepdim, dtype=result_dtype)
    if ord == float("inf"):
        return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    if ord == float("-inf"):
        return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))

    # From here on the computation dtype is important as the reduction is non-trivial
    x = prims.convert_element_type(x, computation_dtype)

    # Avoid computing a sqrt in abs and then squaring (more stable)
    # This could potentially be done for complex dtypes as
    # x = torch.real(torch.conj(x) * x))
    # and it should be more stable, but it's not clear whether it'll be faster on, say
    # CPU (abs is 1 vectorised operation), so leaving it just for real dtypes for now
    if ord % 2.0 != 0.0 or not is_float_dtype(x.dtype):
        x = torch.abs(x)

    powered = torch.pow(x, ord)
    summed = refs.sum(powered, dim=dim, keepdim=keepdim)
    return to_result_dtype(torch.pow(summed, 1.0 / ord))
def meta_linalg_cholesky_ex(input, upper=False, check_errors=False):
    """Validate Cholesky inputs and allocate the empty factor and ``info`` tensors.

    Args:
        input: tensor with at least 2 dimensions, a float or complex dtype and
            equal sizes in its last two dimensions.
        upper: accepted for interface compatibility; not read here.
        check_errors: accepted for interface compatibility; not read here.

    Returns:
        ``(L, info)`` where ``L`` is an empty tensor with the size of
        ``input`` whose last two dimensions have been transposed in place,
        and ``info`` is an empty ``torch.int`` tensor shaped like the batch
        dimensions of ``input`` (all but the last two).
    """
    check(input.ndim >= 2, lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor")

    dtype_supported = utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype)
    check(dtype_supported, lambda: f"expected float or complex tensor, but got {input.dtype}")

    is_square = input.size(-1) == input.size(-2)
    check(is_square, lambda: f"expected square matrix but got {input.shape}")

    factor = input.new_empty(input.size())
    factor.transpose_(-2, -1)

    # One status entry per matrix in the batch.
    batch_shape = input.size()[:-2]
    status = input.new_empty(batch_shape, dtype=torch.int)
    return factor, status
def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions

    Nothing is checked when ``dtype`` is ``None``. Otherwise ``dtype`` must:
      - be floating point or complex,
      - be complex exactly when ``x_dtype`` is complex,
      - be at least as high as ``x_dtype`` (``get_higher_dtype(dtype, x_dtype)``
        must return ``dtype``), i.e. converting the input does not narrow it.

    Args:
        dtype: the requested dtype, or ``None``.
        x_dtype: dtype of the input tensor.
        fn_name: name of the calling function, used as a prefix in messages.
    """
    if dtype is None:
        return

    check(
        is_float_dtype(dtype) or is_complex_dtype(dtype),
        lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
    )
    check(
        is_complex_dtype(dtype) == is_complex_dtype(x_dtype),
        lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
            fn_name=fn_name,
            d="complex" if is_complex_dtype(x_dtype) else "real",
            dtype=dtype,
        ),
    )
    # Both halves of this message must be f-strings; otherwise "{dtype}" is
    # printed literally instead of the requested dtype.
    check(
        get_higher_dtype(dtype, x_dtype) == dtype,
        lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
        f"without narrowing to the specified dtype ({dtype})",
    )
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    """Validate embedding_bag inputs and allocate its four (empty) outputs.

    No values are computed; every returned tensor comes from ``new_empty``.

    Args:
        weight: floating-point tensor; ``weight.size(1)`` is the output width.
        indices: ``torch.long`` or ``torch.int`` tensor.
        offsets: ``torch.long`` or ``torch.int`` tensor; ``offsets.size(0)``
            is the number of bags (minus one when ``include_last_offset``).
        scale_grad_by_freq: accepted for interface compatibility; not read here.
        mode: 0 (sum), 1 (mean) or 2 (max).
        sparse: accepted for interface compatibility; not read here.
        per_sample_weights: optional 1-D tensor with the dtype of ``weight``
            and as many elements as ``indices``; only allowed with mode sum.
        include_last_offset: when True the last entry of ``offsets`` is not
            counted as a bag.
        padding_idx: used only to decide whether the CPU fast path applies
            (it requires ``padding_idx < 0``).

    Returns:
        ``(output, offset2bag, bag_size, max_indices)``. ``output`` has shape
        ``(num_bags, weight.size(1))``; the shapes of the other three depend
        on the device type of ``offsets``, on ``mode`` and on whether the
        fast path applies (see the branches below).
    """
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(
            num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
        )
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                # The closing ")" after the first count was missing, leaving
                # the message with unbalanced parentheses.
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()}) "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        # Fast path with per-sample scaling additionally needs a contiguous scale.
        return (
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
        )

    def is_fast_path_index_select(src, output, padding_idx):
        # float/half source, unit stride in dim 1 for source and output, no padding index.
        return (
            (src.dtype == torch.float or src.dtype == torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if offsets.device.type != "cpu":
        # Non-CPU: bookkeeping tensors are allocated from `indices`.
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        # CPU: bookkeeping tensors are allocated from `offsets`; offset2bag is
        # left empty only for mode sum on the fast path.
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices
def _elementwise_dtypes(
    *_args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND
) -> Tuple[torch.dtype, torch.dtype]:
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
    although it is cast to the Python type corresponding to the computation dtype that
    the type promotion algorithm determines.

    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
    first decides which of four ordered types to use:

    bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
          then the result dtype is the (default) dtype corresponding to the selected type
          (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
      - if the result type is complex then the dtype is:
        -  the default complex dtype if there are no floating point or complex tensors
        -  if there are floating point or complex tensors with one or more dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
            (for example, double + cfloat -> cdouble)
        -  if there are only floating point or complex tensors with zero dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
          all tensors with one or more dimensions of the output type, and if there are no such
          tensors then it's the highest dtype among all tensors with zero dimensions of the output type
          (for example, long + half -> half, even if the half tensor has zero dimensions)

    The "corresponding complex dtypes" are:
      float16    -> complex32
      bfloat16   -> complex64
      float32    -> complex64
      float64    -> complex128
      complex32  -> complex32
      complex64  -> complex64
      complex128 -> complex128

    The DEFAULT type promotion option computes per above, and uses the result dtype as the computation dtype.

    The OP_MATH, INT_TO_FLOAT, COMPLEX_TO_FLOAT and BOOL_TO_LONG type promotion options tweak the above slightly.
    OP_MATH determines a "computation dtype" from the result dtype, and the mapping is simple:

      float16   -> float32
      bfloat16  -> float32
      complex32 -> complex64

    INT_TO_FLOAT, COMPLEX_TO_FLOAT, and BOOL_TO_LONG compute the computation type in the same way,
    but INT_TO_FLOAT and BOOL_TO_LONG map the result dtype to another dtype first, and COMPLEX_TO_FLOAT
    maps its result dtype after the computation dtype is determined, as follows:

      INT_TO_FLOAT  maps all boolean and integer result dtypes to the default floating point dtype
      COMPLEX_TO_FLOAT  maps complex result dtypes to their corresponding floating point dtype
      BOOL_TO_LONG maps the boolean result dtype to long

    The "corresponding floating point dtypes" are:
      complex32  -> float16
      complex64  -> float32
      complex128 -> float64

    The ALWAYS_BOOL type promotion option always maps the result dtype to bool.

    Example operators for each type promotion option:
      DEFAULT          : nextafter
      OP_MATH          : add
      INT_TO_FLOAT     : sin
      COMPLEX_TO_FLOAT : abs
      BOOL_TO_LONG     : pow
      ALWAYS_BOOL      : eq

    Returns a ``(computation_dtype, result_dtype)`` tuple. ``None`` arguments are ignored.
    Raises ValueError if an argument is neither a Number nor a TensorLike, or if the
    type promotion kind is not one of the options above.
    """

    args = tuple(x for x in _args if x is not None)

    # Step 1: find the highest Python type (bool < int < float < complex)
    # among all numbers and tensors.
    highest_type: type = bool
    for x in args:
        if not isinstance(x, (Number, TensorLike)):
            msg = "Unexpected type {0} when computing elementwise type promotion!".format(
                str(type(x))
            )
            raise ValueError(msg)

        if isinstance(x, Number):
            highest_type = utils.get_higher_type(highest_type, type(x))
        else:
            # x is a TensorLike
            highest_type = utils.get_higher_type(highest_type, utils.dtype_to_type(x.dtype))

    result_dtype = None

    def _find_highest_dtype_filtered(
        args, filter, *, float_as_complex=False, all_tensors_equal=False
    ) -> Optional[torch.dtype]:
        # Highest dtype among tensors accepted by `filter`, preferring tensors
        # with one or more dimensions over zero-dim tensors.
        zero_dim_tensor_dtype = None
        one_plus_dim_tensor_dtype = None
        for x in args:
            if isinstance(x, TensorLike) and filter(x.dtype):
                _dtype = x.dtype
                if float_as_complex and utils.is_float_dtype(_dtype):
                    _dtype = utils.corresponding_complex_dtype(_dtype)
                if x.ndim == 0 and not all_tensors_equal:
                    zero_dim_tensor_dtype = utils.get_higher_dtype(
                        zero_dim_tensor_dtype, _dtype
                    )
                else:
                    # x.ndim > 0 or all_tensors_equal
                    one_plus_dim_tensor_dtype = utils.get_higher_dtype(
                        one_plus_dim_tensor_dtype, _dtype
                    )

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            return one_plus_dim_tensor_dtype

        return zero_dim_tensor_dtype

    # Step 2: pick the concrete result dtype within the selected type.
    if highest_type is float:
        result_dtype = _find_highest_dtype_filtered(args, utils.is_float_dtype)
        result_dtype = torch.get_default_dtype() if result_dtype is None else result_dtype
    elif highest_type is complex:
        # NOTE: complex x float type promotion is incorrectly implemented in PyTorch today
        # it will treat zero dim and non-zero-dim float and complex tensors equally
        # unless there's a non-zero-dim complex tensor
        # the following captures this oddity
        has_one_plus_dim_complex_tensor = False
        for x in args:
            if isinstance(x, TensorLike) and x.ndim > 0 and utils.is_complex_dtype(x.dtype):
                has_one_plus_dim_complex_tensor = True
                break

        if has_one_plus_dim_complex_tensor:
            result_dtype = _find_highest_dtype_filtered(
                args,
                lambda x: utils.is_float_dtype(x) or utils.is_complex_dtype(x),
                float_as_complex=True,
            )
        else:
            # no complex tensors of rank 1+
            # NOTE: bugged case where all tensors are equal
            result_dtype = _find_highest_dtype_filtered(
                args,
                lambda x: utils.is_float_dtype(x) or utils.is_complex_dtype(x),
                float_as_complex=True,
                all_tensors_equal=True,
            )

        if result_dtype is None:
            result_dtype = utils.corresponding_complex_dtype(torch.get_default_dtype())
    elif highest_type is int:
        result_dtype = _find_highest_dtype_filtered(args, utils.is_integer_dtype)
        result_dtype = torch.long if result_dtype is None else result_dtype
    else:
        # highest_type is bool
        result_dtype = torch.bool

    # Step 3: apply the promotion-kind specific tweak.
    if type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        return result_dtype, result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH:
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
        if utils.is_integer_dtype(result_dtype) or utils.is_boolean_dtype(result_dtype):
            result_dtype = torch.get_default_dtype()
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
        if utils.is_complex_dtype(result_dtype):
            # Note: computation still occurs in complex
            return _get_computation_dtype(result_dtype), utils.corresponding_real_dtype(
                result_dtype
            )
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
        if utils.is_boolean_dtype(result_dtype):
            return torch.long, torch.long
        # As documented above, BOOL_TO_LONG derives its computation dtype the
        # same way OP_MATH does (e.g. float16 results are computed in float32).
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
        return result_dtype, torch.bool
    else:
        raise ValueError("Unknown type promotion kind {0}".format(str(type_promotion)))