Example #1
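These excerpts follow the style of PyTorch's torch._prims meta functions and assume that module's surrounding context. A minimal sketch of the imports they rely on (the torch-internal names are assumptions and their module paths vary across PyTorch versions):

import math
import operator
from functools import reduce
from typing import List, Optional, Sequence, Tuple

# Assumed to be provided by the surrounding module (e.g. torch._prims /
# torch._prims_common): TensorLike, TensorLikeType, TensorMeta, ShapeType,
# StrideType, DimsSequenceType, and the `utils` validation helpers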
def _split_dim_meta(a: TensorLikeType, dim: int,
                    outer_length: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split with the specified outer_length
    _inner_length = a.shape[dim] / outer_length
    inner_length: int = int(_inner_length)

    if inner_length != _inner_length:
        msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
            a.shape[dim], outer_length)
        raise ValueError(msg)

    new_shape: List[int] = []
    new_strides: List[int] = []
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            new_strides.extend(
                (a.stride()[idx] * inner_length, a.stride()[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
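The stride arithmetic above can be sanity-checked against PyTorch's public view API; a minimal sketch, using torch.Tensor.view as a stand-in for the split_dim prim:

import torch

a = torch.arange(24).reshape(6, 4)  # shape (6, 4), strides (4, 1)
# Splitting dim 0 with outer_length=2 gives inner_length=3, so the meta
# function predicts shape (2, 3, 4) with strides
# (stride[0] * inner_length, stride[0], stride[1]) = (12, 4, 1)
b = a.view(2, 3, 4)
print(b.shape, b.stride())  # torch.Size([2, 3, 4]) (12, 4, 1)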
Example #2
def _collapse_view_meta(a: TensorLikeType, start: int,
                        end: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    shape = a.shape
    strides = a.stride()

    utils.validate_idx(len(shape), start)
    utils.validate_exclusive_idx(len(shape), end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    assert end > start

    length = 1
    for idx in range(start, end):
        if idx != (end - 1):
            assert strides[idx] == strides[idx + 1] * shape[idx + 1]
        length = length * shape[idx]

    # The collapsed dim advances with the stride of the innermost
    # collapsed dim
    stride = strides[end - 1]

    new_shape = shape[:start] + (length, ) + shape[end:]
    new_strides = strides[:start] + (stride, ) + strides[end:]

    return TensorMeta(a, shape=new_shape, strides=new_strides)
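As a check on the collapse condition and the collapsed stride, a minimal sketch using torch.Tensor.flatten as a stand-in for the collapse prim:

import torch

a = torch.empty(2, 3, 4)  # strides (12, 4, 1)
# Dims 1 and 2 satisfy strides[1] == strides[2] * shape[2] (4 == 1 * 4),
# so they collapse into one dim of length 3 * 4 = 12 whose stride is the
# innermost collapsed stride, 1
b = a.flatten(1, 2)
print(b.shape, b.stride())  # torch.Size([2, 12]) (12, 1)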
Example #3
def _split_dim_meta(a: TensorLikeType, dim: int,
                    outer_length: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split with the specified outer_length
    _inner_length = a.shape[dim] / outer_length
    inner_length: int = int(_inner_length)
    assert inner_length == _inner_length

    new_shape: List[int] = []
    new_strides: List[int] = []
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            new_strides.extend(
                (a.stride()[idx] * inner_length, a.stride()[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
Example #4
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    for idx in dimensions:
        utils.validate_idx(a.ndim, idx)
        assert a.shape[idx] == 1

    new_shape = []
    new_strides = []
    for idx in range(len(a.shape)):
        if idx in dimensions:
            continue

        new_shape.append(a.shape[idx])
        new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
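A quick illustration of the computed metadata, using the public torch.squeeze:

import torch

a = torch.empty(3, 1, 4)  # strides (4, 4, 1)
# Dropping the length-1 dim leaves the remaining dims' strides unchanged
b = a.squeeze(1)
print(b.shape, b.stride())  # torch.Size([3, 4]) (4, 1)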
Example #5
def _transpose_meta(a: TensorLikeType,
                    permutation: DimsSequenceType) -> TensorLikeType:
    if a.ndim != len(permutation):
        msg = "Attempting to permute a tensor of rank {0}, but received a permutation of length {1}!".format(
            a.ndim, len(permutation))
        raise ValueError(msg)

    if not utils.is_valid_permutation(a.ndim, permutation):
        msg = "Received an invalid permutation, {0}!".format(permutation)
        raise ValueError(msg)

    new_shape = [0] * a.ndim
    new_strides = [0] * a.ndim
    for idx, dim in enumerate(permutation):
        new_shape[idx] = a.shape[dim]
        new_strides[idx] = a.stride()[dim]

    return TensorMeta(a, shape=tuple(new_shape), strides=tuple(new_strides))
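The permutation simply gathers shape and stride entries; illustrated with the public torch.permute:

import torch

a = torch.empty(2, 3, 4)  # strides (12, 4, 1)
# new_shape[idx] = shape[permutation[idx]], and likewise for strides
b = a.permute(2, 0, 1)
print(b.shape, b.stride())  # torch.Size([4, 2, 3]) (1, 12, 4)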
Example #6
def _broadcast_in_dim_meta(a: TensorLikeType, shape: ShapeType,
                           broadcast_dimensions: Sequence[int]
                           ) -> TensorLikeType:
    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, int)
        assert x > acc
        assert x < len(shape)

        return x

    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # the original tensor's shape must be broadcastable to the new shape
    for idx, new_idx in enumerate(broadcast_dimensions):
        assert a.shape[idx] == 1 or a.shape[idx] == shape[new_idx]

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            new_strides.append(0)

    return TensorMeta(a, shape=shape, strides=new_strides)
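Dimensions named in broadcast_dimensions keep their original strides while every new dimension gets stride 0, so the result aliases the input; a sketch using unsqueeze/expand as a public-API analogue:

import torch

a = torch.empty(3)  # shape (3,), strides (1,)
# Analogous to broadcast_in_dim(a, shape=(4, 3), broadcast_dimensions=(1,)):
# dim 1 keeps stride 1, and the new dim 0 gets stride 0
b = a.unsqueeze(0).expand(4, 3)
print(b.shape, b.stride())  # torch.Size([4, 3]) (0, 1)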
Example #7
def _collapse_view_helper(
        a: TensorLikeType, start: int,
        end: int) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
    assert isinstance(a, TensorLike)

    # Special-case for zero dimensional tensors
    if a.ndim == 0:
        shape = (1, )
        strides = (1, )
    else:
        shape = a.shape  # type: ignore[assignment]
        strides = a.stride()

    utils.validate_idx(len(shape), start)
    utils.validate_exclusive_idx(len(shape), end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    if end <= start:
        msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
            end, start)
        raise ValueError(msg)

    length = 1
    for idx in range(start, end):
        if idx != (end - 1):
            if not (strides[idx] == strides[idx + 1] * shape[idx + 1]):
                return None, None
        length = length * shape[idx]

    # The collapsed dim advances with the stride of the innermost
    # collapsed dim
    stride = strides[end - 1]

    new_shape = shape[:start] + (length, ) + shape[end:]
    new_strides = strides[:start] + (stride, ) + strides[end:]

    return new_shape, new_strides
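When the stride condition fails the helper returns (None, None), signaling that the collapse cannot be expressed as a view; for example, a transposed tensor cannot be flattened without a copy:

import torch

a = torch.empty(3, 4).t()  # shape (4, 3), strides (1, 4)
# strides[0] != strides[1] * shape[1] (1 != 4 * 3), so collapsing dims 0..1
# is not expressible as a view and .view() refuses
try:
    a.view(12)
except RuntimeError as e:
    print(e)  # view size is not compatible with input tensor's size and stride ...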
Example #8
def _slice_meta(
    a: TensorLikeType,
    start_indices: DimsSequenceType,
    limit_indices: DimsSequenceType,
    strides: Optional[StrideType] = None,
) -> TensorLikeType:
    _strides = strides if strides is not None else [1] * len(start_indices)

    if a.ndim != len(start_indices):
        msg = "Attempting to slice tensor of rank {0} with start_indices of length {1}!".format(
            a.ndim, len(start_indices))
        raise ValueError(msg)

    if a.ndim != len(limit_indices):
        msg = "Attempting to slice tensor of rank {0} with limit_indices of length {1}!".format(
            a.ndim, len(limit_indices))
        raise ValueError(msg)

    if a.ndim != len(_strides):
        msg = (
            "Attempting to slice tensor of rank {0} with strides of length {1}!"
            .format(a.ndim, len(_strides)))
        raise ValueError(msg)

    for x, y in zip(start_indices, a.shape):
        if x < 0:
            msg = "Attempting to slice a tensor with a negative start index of {0}!".format(
                x)
            raise ValueError(msg)
        if x > y:
            msg = (
                "Attempting to slice a tensor but a start index in {0} is greater than"
                " the length of its corresponding dimension in shape {1}".
                format(start_indices, a.shape))
            raise ValueError(msg)

    for x, y, z in zip(limit_indices, a.shape, start_indices):
        if x < 0:
            msg = "Attempting to slice a tensor with a negative stop index of {0}!".format(
                x)
            raise ValueError(msg)
        if x > y:
            msg = (
                "Attempting to slice a tensor but a stop index in {0} is greater than the length of "
                " its corresponding dimension in shape {1}".format(
                    limit_indices, a.shape))
            raise ValueError(msg)
        if x < z:
            msg = (
                "Attempting to slice a tensor but a start index in {0} is greater than"
                " its corresponding stop index in {1}!".format(
                    start_indices, limit_indices))
            raise ValueError(msg)

    for x in _strides:
        if x <= 0:
            msg = (
                "Attempting to slice a tensor with a non-positive step of {0}!"
                .format(x))
            raise ValueError(msg)

    new_shape = []
    for x, y, z in zip(start_indices, limit_indices, _strides):
        # The number of indices in [start, limit) reachable with step z
        new_shape.append(math.ceil((y - x) / z))

    new_strides = []
    for x, y in zip(a.stride(), _strides):
        new_strides.append(x * y)

    return TensorMeta(a, shape=new_shape, strides=new_strides)
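The shape and stride formulas can be checked against basic indexing, which produces the same view metadata; note that the ceiling is required because rows 1 and 3 are both selected below, while floor division would count only one:

import torch

a = torch.arange(20).reshape(4, 5)  # strides (5, 1)
# start_indices (1, 0), limit_indices (4, 5), strides (2, 3):
# lengths ceil((4 - 1) / 2) = 2 and ceil((5 - 0) / 3) = 2,
# strides (5 * 2, 1 * 3) = (10, 3)
b = a[1:4:2, 0:5:3]
print(b.shape, b.stride())  # torch.Size([2, 2]) (10, 3)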
Example #9
def _reshape_view_helper(a: TensorLikeType, shape: ShapeType, *,
                         allow_copy: bool) -> TensorLikeType:
    # NOTE: Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred
    # Creates a valid shape

    for idx in range(len(shape)):
        if shape[idx] == -1:
            # Verifies there's only one dimension of length -1 in the shape
            if shape.count(-1) > 1:
                msg = "Can only infer the length of one dimension, but got shape {0}!".format(
                    str(shape))
                raise ValueError(msg)

            # TODO: improve error message
            if a.numel() > 0:
                length = reduce(operator.floordiv,
                                (x for x in shape if x != -1), a.numel())
            else:
                msg = "Cannot reshape a tensor of zero elements into shape {0} because the unspecified length is ambiguous!".format(
                    str(shape))
                raise ValueError(msg)

            shape = list(shape)
            shape[idx] = length
            break

    # Short-circuits if shape is the same
    utils.validate_shape(shape)
    if tuple(a.shape) == tuple(shape):
        return prims.view_of(a)

    numel = reduce(operator.mul, shape) if len(shape) > 0 else 1
    if a.numel() != numel:
        msg = "Attempting to reshape a tensor with shape {0} and {1} elements to a shape {2} with {3} elements!".format(
            str(a.shape), a.numel(), str(shape), numel)
        raise ValueError(msg)

    # Special-cases tensors with no elements
    if a.numel() == 0:
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        return _a

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        return _a

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape

    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than necessary. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            a_ = unsqueeze(a_, -1)
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if length == a_.shape[idx]:
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified above that
        # a and the shape specify the same number of elements
        accum = a_.shape[idx]
        end = idx
        while accum % length != 0:
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view of a copy

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, new_strides = prims._collapse_view_helper(
                a_, idx, end + 1)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor with shape {2}!".format(
                    a.shape, a.stride(), shape)
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length
        if accum != length:
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail
    while idx < a_.ndim:
        assert a_.shape[idx] == 1
        a_ = squeeze(a_, idx)

    return a_
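A short walkthrough of the algorithm on a concrete case, checked with the public torch.reshape: going from (2, 6) to (4, 3), the first target length 4 does not divide shape[0] = 2, so dims 0 and 1 are accumulated (2 * 6 = 12 and 12 % 4 == 0), flattened, and then split into (4, 3), all without copying:

import torch

a = torch.arange(12).reshape(2, 6)
b = a.reshape(4, 3)
print(b.shape, b.stride())           # torch.Size([4, 3]) (3, 1)
print(b.data_ptr() == a.data_ptr())  # True: the result is still a view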