Ejemplo n.º 1
0
def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
    """Squeeze every dimension of ``self`` that is listed in ``dims``.

    ``dims`` is first normalized with ``utils.canonicalize_dims`` against
    the rank of ``self``. The dimensions are then visited from the last one
    to the first one, so squeezing one of them never shifts the index of a
    dimension that is still waiting to be processed.

    Args:
        self: the tensor to squeeze.
        dims: the dimensions to pass to ``squeeze``.

    Returns:
        The tensor obtained after calling ``squeeze`` once per listed
        dimension.
    """
    rank = self.dim()
    targets = utils.canonicalize_dims(rank, dims)
    assert isinstance(targets, tuple)

    result = self
    # Walk from the highest index down so earlier indices stay valid.
    for axis in reversed(range(rank)):
        if axis not in targets:
            continue
        result = result.squeeze(axis)
    return result
Ejemplo n.º 2
0
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
    """Swap dimensions ``dim0`` and ``dim1`` of ``a``.

    Both indices are normalized with ``utils.canonicalize_dims`` before the
    rank check, so they are validated even when no swap takes place.

    Args:
        a: the tensor to transpose.
        dim0: first dimension to swap.
        dim1: second dimension to swap.

    Returns:
        ``a`` itself when it has at most one dimension; otherwise the result
        of ``prims.transpose`` with the two positions exchanged.
    """
    first, second = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]

    if a.ndim > 1:
        # Start from the identity ordering and exchange the two entries.
        order = list(range(a.ndim))
        order[first], order[second] = second, first
        return prims.transpose(a, order)

    return a
Ejemplo n.º 3
0
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
    """Swap dimensions ``dim0`` and ``dim1`` of ``self``.

    Both indices are normalized with ``utils.canonicalize_dims`` first, so
    they are validated even when the call ends up being a no-op.

    Args:
        self: the tensor to transpose.
        dim0: first dimension to swap.
        dim1: second dimension to swap.

    Returns:
        ``self`` unchanged when it has at most one dimension or when both
        indices name the same dimension; otherwise ``torch.permute`` applied
        with the two positions exchanged.
    """
    rank = self.dim()
    first, second = utils.canonicalize_dims(rank, (dim0, dim1))  # type: ignore[misc]

    # Nothing to exchange: too few dimensions, or the same one twice.
    if rank <= 1 or first == second:
        return self

    order = list(range(rank))
    order[first] = second
    order[second] = first
    return torch.permute(self, order)
Ejemplo n.º 4
0
def cat(tensors: TensorSequenceType,
        dim: int = 0,
        out: TensorLikeType = None) -> TensorLikeType:
    """Concatenate ``tensors`` along ``dim``.

    The inputs are converted to a common dtype, computed by
    ``_elementwise_dtypes`` with the DEFAULT type-promotion kind, before
    being handed to ``prims.concatenate``.

    Args:
        tensors: the tensors to concatenate; must not be empty.
        dim: the dimension to concatenate along, normalized with
            ``utils.canonicalize_dims`` against the rank of the first tensor.
        out: optional destination. When given it is passed through
            ``_maybe_resize_out`` and the result is copied into it.

    Returns:
        The concatenated tensor, or the value returned by ``copy_to`` when
        ``out`` is supplied.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    if len(tensors) < 1:
        raise ValueError("cat expects at least one tensor, but received zero!")

    axis = utils.canonicalize_dims(tensors[0].ndim, dim)
    common_dtype, _ = _elementwise_dtypes(
        *tensors, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)

    converted = _convert_dtype(*tensors, dtype=common_dtype)
    joined = prims.concatenate(converted, axis)

    if out is None:
        return joined

    # Fit the destination to the result's shape, then copy into it.
    destination = _maybe_resize_out(out, joined.shape)
    return copy_to(destination, joined,
                   allow_cross_device=False)  # type: ignore[arg-type]
Ejemplo n.º 5
0
def expand_dims(a: TensorLikeType,
                dimensions: DimsSequenceType) -> TensorLikeType:
    """
    Creates a view of a with a.ndim + len(dimensions) dimensions, with new
    dimensions of length one at the dimensions specified by dimensions.

    Args:
        a: the tensor to expand.
        dimensions: positions at which length-one dimensions are inserted.
            They are normalized with utils.canonicalize_dims against a.ndim
            before use.

    Returns:
        The result of broadcast_in_dim applied to a with the expanded shape.

    Raises:
        ValueError: if two entries of dimensions name the same position
            after canonicalization.
    """
    dims = sorted(utils.canonicalize_dims(
        a.ndim, dimensions))  # type: ignore[arg-type]
    inserted = set(dims)
    if len(inserted) != len(dims):
        msg = "Received duplicate dimensions to expand in {0}".format(
            str(dimensions))
        raise ValueError(msg)

    # Insert in ascending order so each index refers to the shape as it
    # looks after the previous insertions.
    new_shape = list(a.shape)
    for idx in dims:
        new_shape.insert(idx, 1)

    # Filter against the canonicalized positions, not the raw `dimensions`:
    # a raw negative entry never equals an index produced by range(), so the
    # inserted dimension would be treated as one of a's own dimensions.
    broadcast_dimensions = [
        idx for idx in range(len(new_shape)) if idx not in inserted
    ]
    return broadcast_in_dim(a, new_shape, broadcast_dimensions)
Ejemplo n.º 6
0
def permute(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Reorder the dimensions of ``a`` according to ``dims``.

    Args:
        a: the tensor to permute.
        dims: the desired ordering of dimensions; normalized with
            ``utils.canonicalize_dims`` against ``a.ndim``.

    Returns:
        The result of ``prims.transpose`` with the normalized ordering.
    """
    return prims.transpose(a, utils.canonicalize_dims(a.ndim, dims))
Ejemplo n.º 7
0
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Reverse ``a`` along each dimension listed in ``dims``.

    Args:
        a: the tensor to flip.
        dims: the dimensions to reverse; normalized with
            ``utils.canonicalize_dims`` against ``a.ndim``.

    Returns:
        The result of ``prims.rev`` over the normalized dimensions.
    """
    reversed_axes = utils.canonicalize_dims(a.ndim, dims)
    return prims.rev(a, reversed_axes)
Ejemplo n.º 8
0
def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    """Concatenate ``tensors`` along ``dim``.

    Args:
        tensors: the tensors to concatenate; must not be empty.
        dim: the dimension to concatenate along, normalized with
            ``utils.canonicalize_dims`` against the rank of the first tensor.

    Returns:
        The result of ``prims.concatenate`` over ``tensors``.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    # Reject empty input explicitly; otherwise tensors[0] below fails with
    # an IndexError that says nothing about the actual problem.
    if len(tensors) == 0:
        msg = "cat expects at least one tensor, but received zero!"
        raise ValueError(msg)

    _dim = utils.canonicalize_dims(tensors[0].ndim, dim)
    return prims.concatenate(tensors, _dim)