def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    """
    Returns a boolean tensor indicating where a and b are "close":
    |a - b| <= atol + rtol * |b|, elementwise, with exact equality and
    (optionally) NaN == NaN short-circuits.

    Raises:
        ValueError: if a and b have different dtypes, or if rtol or atol
            is negative.
    """
    if a.dtype != b.dtype:
        msg = "Attempting to compare tensors of different dtypes {0} and {1}!".format(
            a.dtype, b.dtype
        )
        # Fix: previously raised ValueError(a, b), discarding the message.
        raise ValueError(msg)
    if rtol < 0:
        msg = "rtol must be greater than or equal to zero, but got {0}!".format(rtol)
        # Fix: the message was constructed but never raised.
        raise ValueError(msg)
    if atol < 0:
        msg = "atol must be greater than or equal to zero, but got {0}!".format(atol)
        # Fix: the message was constructed but never raised.
        raise ValueError(msg)

    close = eq(a, b)
    if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
        close = logical_or(close, logical_and(isnan(a), isnan(b)))

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    result = logical_or(
        close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
    )

    return result
def float_power(
    a: Union[Tensor, Number], b: Union[Tensor, Number], out: Optional[Tensor] = None
) -> Tensor:
    """
    Computes a ** b in double precision — complex128 when either operand is
    complex, float64 otherwise — optionally writing the result into out.
    """
    assert isinstance(a, (Tensor, Number))
    assert isinstance(b, (Tensor, Number))
    assert out is None or isinstance(out, TensorLike)

    # Special-cases Number x Number case
    if isinstance(a, Number) and isinstance(b, Number):
        a, b = utils.wrap_scalars(a, b)

    # Handles type promotion: the computation always happens in the widest
    # floating point (or complex) dtype.
    promoted = utils.get_higher_dtype(a, b)
    target_dtype = torch.complex128 if utils.is_complex_dtype(promoted) else torch.float64
    a, b = _convert_dtype(a, b, dtype=target_dtype)

    # Broadcasting
    a, b = broadcast(a, b)

    pow_result = prims.pow(a, b)
    if out is None:
        return pow_result
    out = _maybe_resize_out(out, pow_result.shape)
    return copy_to(out, pow_result, allow_cross_device=False)  # type: ignore[arg-type]
def meta_linalg_qr_helper(input, mode):
    """
    Meta implementation for QR: validates the input and allocates (Q, R)
    outputs of the correct shapes without computing a decomposition.
    """
    # mode -> (compute_q, reduced_mode)
    mode_table = {
        "reduced": (True, True),
        "complete": (True, False),
        "r": (False, True),
    }
    if mode not in mode_table:
        raise RuntimeError(f"qr received unrecognized mode {mode}")
    compute_q, reduced_mode = mode_table[mode]

    check(
        input.ndim >= 2,
        lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor",
    )
    check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: f"expected float or complex tensor, but got {input.dtype}",
    )

    m, n = input.size(-2), input.size(-1)
    mn = min(m, n)

    if compute_q:
        # Allocate Q transposed, then swap the last two dims in place.
        q_shape = list(input.size())
        q_shape[-2] = mn if reduced_mode else m
        q_shape[-1] = m
        Q = input.new_empty(q_shape)
        Q.transpose_(-2, -1)
    else:
        Q = input.new_empty(0)

    # Allocate R transposed, then swap the last two dims in place.
    r_shape = list(input.size())
    r_shape[-2] = n
    r_shape[-1] = mn if reduced_mode or not compute_q else m
    R = input.new_empty(r_shape)
    R.transpose_(-2, -1)
    return (Q, R)
def meta_linalg_cholesky_ex(input, upper=False, check_errors=False):
    """
    Meta implementation for linalg.cholesky_ex: validates the input and
    allocates the factor tensor L and per-matrix info tensor without
    performing any factorization.
    """
    check(
        input.ndim >= 2,
        lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor",
    )
    check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: f"expected float or complex tensor, but got {input.dtype}",
    )
    check(
        input.size(-1) == input.size(-2),
        lambda: f"expected square matrix but got {input.shape}",
    )

    # Allocate L the same size as input, then swap the last two dims in place.
    L = input.new_empty(input.size())
    L.transpose_(-2, -1)

    # One int info value per matrix in the batch (batch dims only).
    info = input.new_empty(input.size()[:-2], dtype=torch.int)
    return L, info
def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions

    Validates that, when given, dtype is floating point or complex, matches the
    real/complex-ness of x_dtype, and that x_dtype can be promoted to dtype
    without narrowing.
    """
    if dtype is not None:
        check(
            is_float_dtype(dtype) or is_complex_dtype(dtype),
            lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
        )
        check(
            is_complex_dtype(dtype) == is_complex_dtype(x_dtype),
            lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
                fn_name=fn_name,
                d="complex" if is_complex_dtype(x_dtype) else "real",
                dtype=dtype,
            ),
        )
        check(
            get_higher_dtype(dtype, x_dtype) == dtype,
            # Fix: the second fragment was a plain string literal (missing the
            # f-prefix), so "{dtype}" was emitted literally in the message.
            lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
            f"without narrowing to the specified dtype ({dtype})",
        )
def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
    """
    Reference implementation of float_power: promotes both operands to
    double precision (complex128 when the promoted dtype is complex,
    float64 otherwise), broadcasts them, and computes a ** b.
    """
    # Handles type promotion
    promoted = utils.get_higher_dtype(a, b)
    assert promoted is not None
    computation_dtype = (
        torch.complex128 if utils.is_complex_dtype(promoted) else torch.float64
    )
    a = _maybe_convert_to_dtype(a, dtype=computation_dtype)  # type: ignore[assignment]
    b = _maybe_convert_to_dtype(b, dtype=computation_dtype)  # type: ignore[assignment]

    # Broadcasting
    a, b = _maybe_broadcast(a, b)

    return prims.pow(a, b)
def is_c_of_r(complex_dtype, real_dtype):
    """Return True iff complex_dtype is complex and its real counterpart is real_dtype."""
    if not is_complex_dtype(complex_dtype):
        return False
    return corresponding_real_dtype(complex_dtype) == real_dtype
def _elementwise_dtypes(
    *_args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND
) -> Tuple[torch.dtype, torch.dtype]:
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
    although it is cast to the Python type corresponding to the computation dtype that
    the type promotion algorithm determines.

    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
    first decides which of four ordered types to use:

    bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
          then the result dtype is the (default) dtype corresponding to the selected type
          (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
      - if the result type is complex then the dtype is:
        -  the default complex dtype if there are no floating point or complex tensors
        -  if there are floating point or complex tensors with one or more dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
            (for example, double + cfloat -> cdouble)
        -  if there are only floating point or complex tensors with zero dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
          all tensors with one or more dimensions of the output type, and if there are no such
          tensors then it's the highest dtype among all tensors with zero dimensions of the output type
          (for example, long + half -> half, even if the half tensor has zero dimensions)

    The "corresponding complex dtypes" are:
      float16    -> complex32
      bfloat16   -> complex64
      float32    -> complex64
      float64    -> complex128
      complex32  -> complex32
      complex64  -> complex64
      complex128 -> complex128

    The DEFAULT type promotion option computes per above, and uses the result dtype as the computation dtype.

    The OP_MATH, INT_TO_FLOAT, COMPLEX_TO_FLOAT and BOOL_TO_LONG type promotion options
    tweak the above slightly.

    OP_MATH determines a "computation dtype" from the result dtype, and the mapping is simple:

      float16   -> float32
      bfloat16  -> float32
      complex32 -> complex64

    INT_TO_FLOAT, COMPLEX_TO_FLOAT, and BOOL_TO_LONG compute the computation type in the same way,
    but INT_TO_FLOAT and BOOL_TO_LONG map the result dtype to another dtype first, and
    COMPLEX_TO_FLOAT maps its result dtype after the computation dtype is determined, as follows:

      INT_TO_FLOAT maps all boolean and integer result dtypes to the default floating point dtype
      COMPLEX_TO_FLOAT maps complex result dtypes to their corresponding floating point dtype
      BOOL_TO_LONG maps the boolean result dtype to long

    The "corresponding floating point dtypes" are:
      complex32  -> float16
      complex64  -> float32
      complex128 -> float64

    The ALWAYS_BOOL type promotion option always maps the result dtype to bool.

    Example operators for each type promotion option:
      DEFAULT          : nextafter
      OP_MATH          : add
      INT_TO_FLOAT     : sin
      COMPLEX_TO_FLOAT : abs
      BOOL_TO_LONG     : pow
      ALWAYS_BOOL      : eq

    Returns a (computation dtype, result dtype) pair.
    """
    # None arguments do not participate in promotion at all.
    args = tuple(x for x in _args if x is not None)

    # First pass: find the highest participating Python type
    # (bool < int < float < complex).
    highest_type: type = bool
    for x in args:
        if not isinstance(x, (Number, TensorLike)):
            msg = (
                "Unexpected type {0} when computing elementwise type promotion!"
                .format(str(type(x)))
            )
            raise ValueError(msg)

        if isinstance(x, Number):
            highest_type = utils.get_higher_type(highest_type, type(x))
        else:
            # x is a TensorLike
            highest_type = utils.get_higher_type(
                highest_type, utils.dtype_to_type(x.dtype)
            )

    result_dtype = None

    def _find_highest_dtype_filtered(
        args, filter, *, float_as_complex=False, all_tensors_equal=False
    ) -> Optional[torch.dtype]:
        # Scans the tensor arguments whose dtype passes `filter`, tracking the
        # highest dtype seen among 0-dim tensors and among 1+-dim tensors
        # separately; 1+-dim tensors take precedence.
        # float_as_complex: float dtypes are lifted to their corresponding
        #   complex dtype before comparison.
        # all_tensors_equal: 0-dim tensors are treated like 1+-dim tensors.
        zero_dim_tensor_dtype = None
        one_plus_dim_tensor_dtype = None
        for x in args:
            if isinstance(x, TensorLike) and filter(x.dtype):
                _dtype = x.dtype
                if float_as_complex and utils.is_float_dtype(_dtype):
                    _dtype = utils.corresponding_complex_dtype(_dtype)
                if x.ndim == 0 and not all_tensors_equal:
                    zero_dim_tensor_dtype = utils.get_higher_dtype(
                        zero_dim_tensor_dtype, _dtype
                    )
                else:
                    # x.ndim > 0 or all_tensors_equal
                    one_plus_dim_tensor_dtype = utils.get_higher_dtype(
                        one_plus_dim_tensor_dtype, _dtype
                    )

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            return one_plus_dim_tensor_dtype

        return zero_dim_tensor_dtype

    if highest_type is float:
        result_dtype = _find_highest_dtype_filtered(args, utils.is_float_dtype)
        result_dtype = (
            torch.get_default_dtype() if result_dtype is None else result_dtype
        )
    elif highest_type is complex:
        # NOTE: complex x float type promotion is incorrectly implemented in PyTorch today
        # it will treat zero dim and non-zero-dim float and complex tensors equally
        # unless there's a non-zero-dim complex tensor
        # the following captures this oddity
        has_one_plus_dim_complex_tensor = False
        for x in args:
            if (
                isinstance(x, TensorLike)
                and x.ndim > 0
                and utils.is_complex_dtype(x.dtype)
            ):
                has_one_plus_dim_complex_tensor = True
                break

        if has_one_plus_dim_complex_tensor:
            result_dtype = _find_highest_dtype_filtered(
                args,
                lambda x: utils.is_float_dtype(x) or utils.is_complex_dtype(x),
                float_as_complex=True,
            )
        else:
            # no complex tensors of rank 1+
            # NOTE: bugged case where all tensors are equal
            result_dtype = _find_highest_dtype_filtered(
                args,
                lambda x: utils.is_float_dtype(x) or utils.is_complex_dtype(x),
                float_as_complex=True,
                all_tensors_equal=True,
            )
        if result_dtype is None:
            # no float/complex tensors at all -> default complex dtype
            result_dtype = utils.corresponding_complex_dtype(
                torch.get_default_dtype()
            )
    elif highest_type is int:
        result_dtype = _find_highest_dtype_filtered(args, utils.is_integer_dtype)
        result_dtype = torch.long if result_dtype is None else result_dtype
    else:
        # highest_type is bool
        result_dtype = torch.bool

    # Second stage: derive (computation dtype, result dtype) per the
    # requested promotion kind (see docstring for each kind's tweak).
    if type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        return result_dtype, result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH:
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
        if utils.is_integer_dtype(result_dtype) or utils.is_boolean_dtype(
            result_dtype
        ):
            result_dtype = torch.get_default_dtype()
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
        if utils.is_complex_dtype(result_dtype):
            # Note: computation still occurs in complex
            return _get_computation_dtype(
                result_dtype
            ), utils.corresponding_real_dtype(result_dtype)
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
        if utils.is_boolean_dtype(result_dtype):
            return torch.long, torch.long
        return result_dtype, result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
        return result_dtype, torch.bool
    else:
        raise ValueError("Unknown type promotion kind {0}".format(
            str(type_promotion)))