示例#1
0
def _convert_element_type_meta(a: TensorLikeType,
                               dtype: torch.dtype) -> TensorLikeType:
    """
    Meta function for the element-type conversion prim.

    Returns ``TensorMeta(a, dtype=dtype)``: a meta tensor derived from ``a``
    with its dtype overridden by ``dtype``.
    """
    # Both inputs must already have the expected types
    assert isinstance(a, TensorLike)
    assert isinstance(dtype, torch.dtype)

    converted = TensorMeta(a, dtype=dtype)
    return converted
示例#2
0
def _concatenate_meta(tensors: Sequence[TensorLikeType],
                      dim: int) -> TensorLikeType:
    """
    Meta function for concatenating ``tensors`` along dimension ``dim``.

    Every tensor must be TensorLike, share dtype and device with the others,
    have the same rank, and match the first tensor's length on every dimension
    except ``dim``. The result takes the first tensor's shape with ``dim`` set
    to the summed lengths, and contiguous strides.
    """
    assert len(tensors) > 0

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    # The checkers take the tensors as varargs; passing the sequence itself
    # would hand them one non-tensor argument instead of the tensors.
    utils.check_same_dtype(*tensors)
    utils.check_same_device(*tensors, allow_scalars=False)

    shape = tensors[0].shape
    utils.validate_idx(shape, dim)

    # Verifies same rank and same shape (except in the concat dimension)
    concat_length = 0
    for tensor in tensors:
        # zip() below would silently truncate on a rank mismatch
        assert len(tensor.shape) == len(shape)
        for idx, (common_length, length) in enumerate(zip(shape,
                                                          tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length

    # list() already produces a fresh list; no extra copy needed
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )
示例#3
0
def _split_dim_meta(a: TensorLikeType, dim: int,
                    outer_length: int) -> TensorLikeType:
    """
    Meta function that splits dimension ``dim`` of ``a`` into two dimensions
    of lengths ``(outer_length, a.shape[dim] // outer_length)``.

    The outer dimension's stride is the original stride times the inner
    length; the inner dimension keeps the original stride.

    Raises:
        ValueError: if ``outer_length`` does not evenly divide ``a.shape[dim]``.
    """
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split with the specified outer_length.
    # Integer divmod is exact; float division can misjudge very large lengths.
    inner_length, remainder = divmod(a.shape[dim], outer_length)

    if remainder != 0:
        msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
            a.shape[dim], outer_length)
        raise ValueError(msg)

    strides = a.stride()
    new_shape: List[int] = []
    new_strides: List[int] = []
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            # One step along the outer dim skips a whole inner block
            new_strides.extend((strides[idx] * inner_length, strides[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(strides[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
示例#4
0
def _collapse_view_meta(a: TensorLikeType, start: int,
                        end: int) -> TensorLikeType:
    """
    Meta function that views dimensions ``[start, end)`` of ``a`` as a single
    dimension.

    The collapsed dimensions must be laid out contiguously relative to one
    another: each stride equals the next stride times the next length. The
    merged dimension's length is the product of their lengths. Its stride is
    the stride of the innermost collapsed dimension, ``strides[end - 1]``.
    Dimensions outside the interval keep their lengths and strides.
    """
    assert isinstance(a, TensorLike)

    shape = a.shape
    strides = a.stride()

    utils.validate_idx(shape, start)
    utils.validate_exclusive_idx(shape, end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    assert end > start

    length = 1
    for idx in range(start, end):
        if idx != (end - 1):
            assert strides[idx] == strides[idx + 1] * shape[idx + 1]
        length = length * shape[idx]

    # Advancing one element along the merged dim is one step of the innermost
    # collapsed dim, so its stride is that dim's stride (not a product).
    stride = strides[end - 1]

    new_shape = shape[:start] + (length, ) + shape[end:]
    # The trailing dims keep their strides (not their lengths)
    new_strides = strides[:start] + (stride, ) + strides[end:]

    return TensorMeta(a, shape=new_shape, strides=new_strides)
示例#5
0
def _concatenate_meta(tensors: Sequence[TensorLikeType],
                      dim: int) -> TensorLikeType:
    """
    Meta function for concatenating ``tensors`` along dimension ``dim``.

    Every tensor must be TensorLike, share dtype and device with the others
    (CPU scalar tensors are not exempt), have the same rank, and match the
    first tensor's length on every dimension except ``dim``. The result takes
    the first tensor's shape with ``dim`` set to the summed lengths, and
    contiguous strides.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    if len(tensors) == 0:
        msg = "concatenate expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_dtype(*tensors)
    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    shape = tensors[0].shape
    utils.validate_idx(tensors[0].ndim, dim)

    # Verifies same rank and same shape (except in the concat dimension)
    concat_length = 0
    for tensor in tensors:
        # zip() below would silently truncate on a rank mismatch
        assert len(tensor.shape) == len(shape)
        for idx, (common_length, length) in enumerate(zip(shape,
                                                          tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length

    # list() already produces a fresh list; no extra copy needed
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )
示例#6
0
def _collapse_view_meta(a: TensorLikeType, start: int,
                        end: int) -> TensorLikeType:
    """
    Meta function for the collapse-view prim.

    The shape and strides come from ``_collapse_view_helper``. A ``None``
    shape from the helper means the requested view cannot be formed.

    Raises:
        ValueError: if the helper reports that no collapsed view exists.
    """
    collapsed_shape, collapsed_strides = _collapse_view_helper(a, start, end)

    if collapsed_shape is not None:
        return TensorMeta(a, shape=collapsed_shape, strides=collapsed_strides)

    msg = "Attempting to view a collapsed tensor, but no such view exists!"
    raise ValueError(msg)
示例#7
0
def _elementwise_meta(*args, type_promotion):
    """
    Meta function for elementwise operations whose output dtype equals the
    dtype of their inputs.

    All args must agree in device, shape and dtype. If any TensorLike arg has
    strides differing from the first TensorLike's, the result is built from
    that arg with contiguous strides. Otherwise it is built from the first
    TensorLike. With no TensorLike args, it is built from the first Number
    arg, or from None if there is none. ``type_promotion`` is not read here.

    Stride logic is currently incorrect.
    """
    assert len(args) > 0

    utils.check_same_device(*args, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args, allow_cpu_scalar_tensors=True)
    utils.check_same_dtype(*args)

    tensor_args = [arg for arg in args if isinstance(arg, TensorLike)]

    if not tensor_args:
        # No tensors at all: fall back to the first Number (or None)
        first_number = next(
            (arg for arg in args if isinstance(arg, Number)), None)
        return TensorMeta(first_number)

    reference = tensor_args[0]
    reference_strides = reference.stride()

    # A tensor whose strides disagree with the first one forces a
    # contiguous result shaped like that tensor
    for candidate in tensor_args:
        if candidate.stride() != reference_strides:
            return TensorMeta(
                candidate,
                strides=utils.make_contiguous_strides_for(candidate.shape))

    # TODO: fix strides
    has_zero_stride = 0 in reference.stride()
    if has_zero_stride and reference.numel() > 0:
        return TensorMeta(
            reference,
            strides=utils.make_contiguous_strides_for(reference.shape))
    return TensorMeta(reference)
示例#8
0
def _reduction_meta(inp, dims, *, output_dtype=None):
    """
    Meta function for reduction operations that produce a single output.

    The output shape comes from ``utils.compute_reduction_output_shape``. The
    dtype is ``output_dtype``, defaulting to the input's dtype. The device is
    the input's device. No strides are passed to TensorMeta.
    Stride logic is incorrect.
    """
    assert isinstance(inp, TensorLike)
    result_dtype = inp.dtype if output_dtype is None else output_dtype
    reduced_shape = utils.compute_reduction_output_shape(inp.shape, dims)
    return TensorMeta(
        shape=reduced_shape,
        dtype=result_dtype,
        device=inp.device,
    )
示例#9
0
    def _traced(*args, executor="aten"):
        # Calls fn on placeholders inside a PrimContext, records fn's result
        # as the context's output, then runs execute() on the real args.
        ctx = PrimContext()
        with ctx:
            # torch.Tensor args become TensorMeta placeholders; any other
            # arg is registered as a placeholder unchanged.
            placeholders = [
                ctx.placeholder(
                    TensorMeta(arg) if isinstance(arg, torch.Tensor) else arg)
                for arg in args
            ]
            ctx.output(fn(*placeholders))
        return execute(ctx, *args, executor=executor)
示例#10
0
    def _traced(*args, executor="aten"):
        # Calls fn on placeholders while PrimContext is pushed as the torch
        # function mode, records fn's result as the context's output, then
        # runs execute() on the real args.
        ctx: PrimContext
        with torch.overrides.push_torch_function_mode(
                PrimContext) as ctx:  # type: ignore[attr-defined, assignment]
            # torch.Tensor args become TensorMeta placeholders; any other
            # arg is registered as a placeholder unchanged.
            placeholders = [
                ctx.placeholder(
                    TensorMeta(arg) if isinstance(arg, torch.Tensor) else arg)
                for arg in args
            ]
            ctx.output(fn(*placeholders))
        return execute(ctx, *args, executor=executor)
示例#11
0
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
    """
    Meta function that reshapes ``a`` to ``shape`` with contiguous strides.

    Raises:
        ValueError: if ``shape`` does not describe the same number of
            elements as ``a``.
    """
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements. The initializer 1 makes an empty shape
    # (a 0-dim tensor) count as one element instead of making reduce()
    # raise TypeError on an empty sequence.
    numel = reduce(operator.mul, shape, 1)
    if numel != a.numel():
        msg = "Attempting to reshape a tensor with {0} elements to a shape with {1} elements!".format(
            a.numel(), numel)
        raise ValueError(msg)

    return TensorMeta(a,
                      shape=shape,
                      strides=utils.make_contiguous_strides_for(shape))
示例#12
0
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
    """
    Meta function that drops the dimensions listed in ``dimensions``.

    Each listed dimension must be a valid index into ``a`` and have length 1.
    The remaining dimensions keep their lengths and strides, in order.
    """
    assert isinstance(a, TensorLike)

    for squeezed in dimensions:
        utils.validate_idx(a.ndim, squeezed)
        assert a.shape[squeezed] == 1

    kept_dims = [d for d in range(len(a.shape)) if d not in dimensions]
    new_shape = [a.shape[d] for d in kept_dims]
    new_strides = [a.stride()[d] for d in kept_dims]

    return TensorMeta(a, shape=new_shape, strides=new_strides)
示例#13
0
def _transpose_meta(a: TensorLikeType,
                    permutation: DimsSequenceType) -> TensorLikeType:
    """
    Meta function that reorders the dimensions of ``a``.

    Output dimension ``i`` takes the length and stride of input dimension
    ``permutation[i]``.

    Raises:
        ValueError: if ``permutation`` has the wrong length or is not a valid
            permutation of ``range(a.ndim)``.
    """
    if a.ndim != len(permutation):
        msg = "Attempting to permute a tensor of rank {0}, but received a permutation of length {1}!".format(
            a.ndim, len(permutation))
        raise ValueError(msg)

    if not utils.is_valid_permutation(a.ndim, permutation):
        msg = "Received an invalid permutation, {0}!".format(permutation)
        raise ValueError(msg)

    permuted_shape = tuple(a.shape[src] for src in permutation)
    permuted_strides = tuple(a.stride()[src] for src in permutation)

    return TensorMeta(a, shape=permuted_shape, strides=permuted_strides)
示例#14
0
def _broadcast_in_dim_meta(a: TensorLikeType, shape: ShapeType,
                           broadcast_dimensions: Sequence[int]):
    """
    Meta function that broadcasts ``a`` to ``shape``.

    ``broadcast_dimensions[i]`` names the dimension of ``shape`` that input
    dimension ``i`` maps to. Mapped dimensions keep the input's stride.
    Dimensions of ``shape`` not listed get stride 0.
    """
    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # Every input dimension needs exactly one target dimension
    assert a.ndim == len(broadcast_dimensions)

    # The target shape may add dimensions but never drop them
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be strictly ascending ints (no relative
    # reordering of dims), each a valid index into the new shape
    previous = -1
    for target_dim in broadcast_dimensions:
        assert isinstance(target_dim, int)
        assert target_dim > previous
        assert target_dim < len(shape)
        previous = target_dim

    # Each input length must be 1 or equal its target length
    for source_dim, target_dim in enumerate(broadcast_dimensions):
        assert (a.shape[source_dim] == 1
                or a.shape[source_dim] == shape[target_dim])

    source_strides = a.stride()
    new_strides = []
    next_source = 0
    for target_dim in range(len(shape)):
        if target_dim not in broadcast_dimensions:
            # Newly introduced dim: every index reads the same element
            new_strides.append(0)
            continue
        new_strides.append(source_strides[next_source])
        next_source += 1

    return TensorMeta(a, shape=shape, strides=new_strides)
示例#15
0
def _split_dim_meta(a: TensorLikeType, dim: int,
                    outer_length: int) -> TensorLikeType:
    """
    Meta function that splits dimension ``dim`` of ``a`` into two dimensions
    of lengths ``(outer_length, a.shape[dim] // outer_length)``.

    The outer dimension's stride is the original stride times the inner
    length; the inner dimension keeps the original stride. Asserts that
    ``outer_length`` evenly divides ``a.shape[dim]``.
    """
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.shape, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split with the specified outer_length.
    # Integer divmod is exact; float division can misjudge very large lengths.
    inner_length, remainder = divmod(a.shape[dim], outer_length)
    assert remainder == 0

    strides = a.stride()
    new_shape: List[int] = []
    new_strides: List[int] = []
    # Iterate over dimension indices (iterating a.shape would yield lengths)
    for idx in range(len(a.shape)):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            # One step along the outer dim skips a whole inner block
            new_strides.extend((strides[idx] * inner_length, strides[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(strides[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
示例#16
0
def _slice_meta(
    a: TensorLikeType,
    start_indices: DimsSequenceType,
    limit_indices: DimsSequenceType,
    strides: Optional[StrideType] = None,
) -> TensorLikeType:
    """
    Meta function for a per-dimension strided slice of ``a``.

    For each dimension ``i`` the result keeps the indices
    ``start_indices[i], start_indices[i] + step, ...`` strictly below
    ``limit_indices[i]``, where ``step`` is ``strides[i]`` (default 1). The
    result's strides are the input strides multiplied by the steps.

    Raises:
        ValueError: if start_indices, limit_indices or strides does not match
            the rank of ``a``; if an index is negative or exceeds its
            dimension's length; if a start index exceeds its stop index; or if
            a step is not positive.
    """
    _strides = strides if strides is not None else [1] * len(start_indices)

    if a.ndim != len(start_indices):
        msg = "Attempting to slice tensor of rank {0} with start_indices of length {1}!".format(
            a.ndim, len(start_indices))
        raise ValueError(msg)

    if a.ndim != len(limit_indices):
        msg = "Attempting to slice tensor of rank {0} with limit_indices of length {1}!".format(
            a.ndim, len(limit_indices))
        raise ValueError(msg)

    if a.ndim != len(_strides):
        msg = (
            "Attempting to slice tensor of rank {0} with strides of length {1}!"
            .format(a.ndim, len(_strides)))
        raise ValueError(msg)

    for x, y in zip(start_indices, a.shape):
        if x < 0:
            msg = "Attempting to slice a tensor with a negative start index of {0}!".format(
                x)
            raise ValueError(msg)
        if x > y:
            msg = (
                "Attempting to slice a tensor but a start index in {0} is greater than"
                " the length of its corresponding dimension in shape {1}".
                format(start_indices, a.shape))
            raise ValueError(msg)

    for x, y, z in zip(limit_indices, a.shape, start_indices):
        if x < 0:
            msg = "Attempting to slice a tensor with a negative stop index of {0}!".format(
                x)
            raise ValueError(msg)
        if x > y:
            msg = (
                "Attempting to slice a tensor but a stop index in {0} is greater than the length of "
                " its corresponding dimension in shape {1}".format(
                    limit_indices, a.shape))
            raise ValueError(msg)
        if x < z:
            # A stop index below its start index is invalid
            msg = (
                "Attempting to slice a tensor but a start index in {0} is greater than"
                " its corresponding stop index in {1}".format(
                    start_indices, limit_indices))
            raise ValueError(msg)

    for x in _strides:
        if x <= 0:
            msg = (
                "Attempting to slice a tensor with a non-positive step of {0}!"
                .format(x))
            raise ValueError(msg)

    # Element count of range(start, stop, step) = ceil((stop - start) / step),
    # computed exactly with integer floor division. floor() would undercount,
    # e.g. start=0, stop=5, step=2 selects 0, 2, 4 -> 3 elements, not 2.
    new_shape = [
        1 + (stop - start - 1) // step
        for start, stop, step in zip(start_indices, limit_indices, _strides)
    ]

    # Skipping `step` source elements per output element scales each stride
    new_strides = [
        stride * step for stride, step in zip(a.stride(), _strides)
    ]

    return TensorMeta(a, shape=new_shape, strides=new_strides)
示例#17
0
def _resize_meta(a: TensorLikeType, shape: Union[torch.Size, List[int],
                                                 Tuple[int, ...]]):
    """
    Meta function for resize: a meta tensor derived from ``a`` with shape
    ``shape`` and contiguous strides for that shape.
    """
    contiguous_strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(a, shape=shape, strides=contiguous_strides)
示例#18
0
 def wrap(t):
     # torch.Tensor inputs become TensorMeta; anything else passes through
     return TensorMeta(t) if isinstance(t, torch.Tensor) else t
示例#19
0
def _device_put_meta(a: TensorLikeType,
                     device: Union[str, torch.device]) -> TensorLikeType:
    """
    Meta function for moving ``a`` to ``device``.

    ``device`` may be a string or a ``torch.device``; it is normalized with
    ``utils.wrap_device`` before being passed to TensorMeta.
    """
    assert isinstance(a, TensorLike)
    assert isinstance(device, (str, torch.device))

    target_device = utils.wrap_device(device)
    return TensorMeta(a, device=target_device)
示例#20
0
def _clone_meta(a: TensorLikeType, *,
                memory_format: torch.memory_format) -> TensorLikeType:
    """
    Meta function for clone: ``TensorMeta(a)``.

    ``memory_format`` is accepted for interface compatibility but not read.
    """
    cloned = TensorMeta(a)
    return cloned
示例#21
0
def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """
    Meta function for reversing ``a`` along ``dims``.

    Only validates ``dims`` against the rank of ``a``; the result is
    ``TensorMeta(a)``.
    """
    utils.validate_dimension_indices(a.ndim, dims)
    reversed_meta = TensorMeta(a)
    return reversed_meta
示例#22
0
def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
    """Meta function for view_of: returns ``TensorMeta(a)``."""
    view = TensorMeta(a)
    return view