def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
    """Return ``a`` with dimensions ``dim0`` and ``dim1`` exchanged.

    Both dimensions are first canonicalized against ``a.ndim`` with
    ``utils.canonicalize_dims``. Inputs with at most one dimension are
    returned as-is; otherwise the swap is expressed as a full permutation
    and delegated to ``prims.transpose``.
    """
    first, second = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]
    if a.ndim > 1:
        # Start from the identity permutation and exchange the two entries.
        order = list(range(a.ndim))
        order[first], order[second] = order[second], order[first]
        return prims.transpose(a, order)
    return a
def matrix_norm(
    A: TensorLikeType,
    ord: Union[float, str] = "fro",
    dim: DimsType = (-2, -1),
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Compute a matrix norm of ``A`` over the pair of dimensions ``dim``.

    Args:
        A: input tensor; must pass ``check_is_matrix``.
        ord: ``"fro"`` (Frobenius), ``"nuc"`` (nuclear: sum of singular
            values), or a number whose absolute value is 1, 2 or ``inf``
            (i.e. ``1, -1, 2, -2, inf, -inf``). Positive orders take the
            maximum of the underlying reduction; negative orders take the
            minimum.
        dim: the two distinct dimensions that hold the matrices.
        keepdim: if True, the two reduced dimensions are kept in the
            result as size-1 dimensions.
        dtype: optional dtype, validated by ``check_norm_dtype``. On the
            SVD paths (``"nuc"``, ``+-2``), ``A`` is converted to it first.
            On the other paths, it is forwarded to ``vector_norm``.

    Returns:
        The norm of each matrix in ``A``.

    Errors are reported through ``check`` when ``dim`` is not a pair of
    distinct dimensions or ``ord`` is unsupported. They are also reported
    through ``check_fp_or_complex`` when ``A.dtype`` is not allowed for
    ``ord``.
    """
    # shape
    check_is_matrix(A, "linalg.matrix_norm")
    # dim
    dim = utils.canonicalize_dims(A.ndim, dim)
    if isinstance(dim, int):
        dim = (dim,)  # type: ignore[assignment]
    check(
        len(dim) == 2,
        lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}",
    )
    check(
        dim[0] != dim[1],
        lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
    )
    # dtype arg
    check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")

    if isinstance(ord, str):
        # ord
        check(
            ord in ("fro", "nuc"),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype: the nuclear norm goes through svdvals, so low-precision
        # dtypes are only allowed for the Frobenius norm.
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc"
        )
        if ord == "fro":
            return vector_norm(A, 2, dim, keepdim, dtype=dtype)
        else:  # ord == "nuc"
            if dtype is not None:
                A = prims.convert_element_type(A, dtype)
            # presumably moves dim[0], dim[1] to the last two positions so
            # svdvals sees them as the matrix dims — TODO confirm helper
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
            if keepdim:
                # Re-insert the second reduced dim, then undo the permutation
                # so the size-1 dims land back at their original positions.
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
    else:
        # ord
        abs_ord = abs(ord)
        check(
            abs_ord in (2, 1, float("inf")),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype: both ord=2 and ord=-2 go through svdvals, so low precision
        # is rejected for either sign (compare abs_ord, not ord).
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=abs_ord != 2
        )

        max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim)

        if abs_ord == 2.0:
            if dtype is not None:
                A = prims.convert_element_type(A, dtype)
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
        else:  # 1, -1, inf, -inf
            dim0, dim1 = dim
            # ord=+-1 sums along dim[0] then takes max/min over dim[1];
            # ord=+-inf does the opposite.
            if abs_ord == float("inf"):
                dim0, dim1 = dim1, dim0
            # Reducing dim0 without keepdim removes it, which shifts every
            # later dimension one position to the left.
            if not keepdim and (dim0 < dim1):
                dim1 -= 1
            return max_min(
                vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1
            )
def permute(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Reorder the dimensions of ``a`` as given by ``dims``.

    ``dims`` is canonicalized against ``a.ndim`` with
    ``utils.canonicalize_dims`` and the result is passed directly to
    ``prims.transpose``.
    """
    return prims.transpose(a, utils.canonicalize_dims(a.ndim, dims))