def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
    """Squeeze each dimension listed in *dims* out of *self*, one at a time."""
    rank = self.dim()
    canonical = utils.canonicalize_dims(rank, dims)
    assert isinstance(canonical, tuple)
    # Squeeze from the highest index downward so that removing a dimension
    # never shifts the position of a dimension still waiting to be removed.
    for d in sorted(set(canonical), reverse=True):
        self = self.squeeze(d)
    return self
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
    """Return a view of *a* with dimensions *dim0* and *dim1* exchanged."""
    # Canonicalize first so out-of-range dims raise even for 0-d/1-d inputs.
    _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]

    # Transposing a scalar or vector is a no-op.
    if a.ndim <= 1:
        return a

    perm = list(range(a.ndim))
    perm[_dim0], perm[_dim1] = perm[_dim1], perm[_dim0]
    return prims.transpose(a, perm)
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
    """Decomposition of transpose.int in terms of permute."""
    # Canonicalize before the early-outs so invalid dims still error.
    dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1))  # type: ignore[misc]

    # Nothing to do for scalars/vectors, or when both dims coincide.
    if self.dim() <= 1 or dim0 == dim1:
        return self

    order = list(range(self.dim()))
    order[dim0], order[dim1] = order[dim1], order[dim0]
    return torch.permute(self, order)
def cat(tensors: TensorSequenceType, dim: int = 0, out: TensorLikeType = None) -> TensorLikeType:
    """Concatenate *tensors* along *dim*, optionally writing into *out*."""
    if len(tensors) == 0:
        raise ValueError("cat expects at least one tensor, but received zero!")

    _dim = utils.canonicalize_dims(tensors[0].ndim, dim)

    # Promote every input to the common computation dtype before concatenating.
    dtype, _ = _elementwise_dtypes(
        *tensors, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    promoted = _convert_dtype(*tensors, dtype=dtype)
    result = prims.concatenate(promoted, _dim)

    if out is None:
        return result

    # Resize the destination if needed, then copy the result into it.
    out = _maybe_resize_out(out, result.shape)
    return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]
def expand_dims(a: TensorLikeType, dimensions: DimsSequenceType) -> TensorLikeType:
    """
    Creates a view of a with a.ndim + len(dimensions) dimensions, with new
    dimensions of length one at the dimensions specified by dimensions.

    Raises ValueError if dimensions contains duplicates (after wrapping
    negative indices).
    """
    dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))  # type: ignore[arg-type]
    if len(set(dims)) != len(dims):
        msg = "Received duplicate dimensions to expand in {0}".format(str(dimensions))
        raise ValueError(msg)

    new_shape = list(a.shape)
    for idx in dims:
        new_shape.insert(idx, 1)

    # BUGFIX: filter against the canonicalized `dims`, not the raw
    # `dimensions` argument — a negative entry in `dimensions` would never
    # match a non-negative `idx`, yielding an incorrect broadcast mapping.
    broadcast_dimensions = [
        idx for idx in range(len(new_shape)) if idx not in dims
    ]
    return broadcast_in_dim(a, new_shape, broadcast_dimensions)
def permute(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Return a view of *a* with its dimensions reordered per *dims*."""
    order = utils.canonicalize_dims(a.ndim, dims)
    return prims.transpose(a, order)
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Reverse the elements of *a* along every dimension listed in *dims*."""
    canonical = utils.canonicalize_dims(a.ndim, dims)
    return prims.rev(a, canonical)
def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    """Concatenate *tensors* along dimension *dim*."""
    # Robustness/consistency fix: an empty sequence previously raised a bare
    # IndexError from `tensors[0]`; raise the same clear ValueError that the
    # type-promoting cat reference uses.
    if len(tensors) == 0:
        msg = "cat expects at least one tensor, but received zero!"
        raise ValueError(msg)
    _dim = utils.canonicalize_dims(tensors[0].ndim, dim)
    return prims.concatenate(tensors, _dim)