Example #1
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "prelu: zero-dim input tensors are not allowed when weight is not a scalar.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(
        weight, a.shape, tuple() if weight.ndim == 0 else (1,)
    )

    return refs.where(a > 0, a, a * weight)
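
A small usage sketch (my own illustration with plain torch ops, not part of the reference above) showing how a per-channel weight is broadcast against dim 1 of an NCHW input, which is what prims.broadcast_in_dim(weight, a.shape, (1,)) achieves:

import torch

# Hypothetical example: input (N, C, H, W) with a per-channel weight of shape (C,)
a = torch.randn(2, 3, 4, 5)
weight = torch.tensor([0.1, 0.2, 0.3])
w = weight.reshape(1, -1, 1, 1).expand(a.shape)  # same layout as broadcast_in_dim with (1,)
out = torch.where(a > 0, a, a * w)
assert torch.allclose(out, torch.nn.functional.prelu(a, weight))
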
Example #2
def _split_dim_meta(a: TensorLikeType, dim: int,
                    outer_length: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split with the specified outer_length
    _inner_length = a.shape[dim] / outer_length
    inner_length: int = int(_inner_length)

    if inner_length != _inner_length:
        msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
            a.shape[dim], outer_length)
        raise ValueError(msg)

    new_shape: List[int] = []
    new_strides: List[int] = []
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            new_strides.extend(
                (a.stride()[idx] * inner_length, a.stride()[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
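
A worked illustration (my own example, assuming standard contiguous strides) of the shape/stride arithmetic above: splitting dim 0 of a contiguous (6, 4) tensor with outer_length=2 gives shape (2, 3, 4), where the new outer stride is the original stride times the inner length and the inner stride is unchanged.

import torch

t = torch.arange(24).reshape(6, 4)   # shape (6, 4), strides (4, 1)
split = t.view(2, 3, 4)              # the layout split_dim would describe
assert split.stride() == (12, 4, 1)  # 12 == 4 * inner_length (3)
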
Example #3
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements
    numel = reduce(operator.mul, shape, 1)
    if numel != a.numel():
        msg = "Attempting to reshape a tensor with {0} elements to a shape with {1} elements!".format(
            a.numel(), numel)
        raise ValueError(msg)

    return TensorMeta(a,
                      shape=shape,
                      strides=utils.make_contiguous_strides_for(shape))
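
A quick sanity check (illustration only) of the contiguous-stride convention this meta relies on: contiguous strides are the running products of the trailing dimension lengths, e.g. (2, 3, 4) -> (12, 4, 1).

import torch

t = torch.empty(2, 3, 4)
assert t.stride() == (12, 4, 1)
assert t.reshape(6, 4).stride() == (4, 1)
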
Example #4
def _maybe_resize_out(out: TensorLikeType, shape):
    if out.numel() == 0:
        return prims.resize(out, shape)

    if out.numel() != reduce(operator.mul, shape, 1):
        msg = (
            "An output with one or more elements was resized since it had shape {0} "
            "which does not match the required output shape {1}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
            .format(str(out.shape), str(shape)))
        warnings.warn(msg)
        return prims.resize(out, shape)

    return out
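
The eager-mode counterpart of this helper can be observed directly (illustration only, assuming current PyTorch out= semantics): a non-empty out tensor of the wrong shape is resized with a deprecation warning, while a zero-element out tensor is resized silently.

import torch
import warnings

out = torch.empty(5)
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    torch.add(torch.ones(2, 2), torch.ones(2, 2), out=out)  # emits the deprecation warning
assert out.shape == (2, 2)
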
Example #5
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, List) and dim is not None:
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = reduction_dtypes(
        x, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    elif ord == float("-inf"):
        return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = prims.convert_element_type(x, computation_dtype)
        reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim)

        # Avoid computing a sqrt in abs and then squaring (more stable)
        # This could potentially be done for complex dtypes as
        # x = torch.real(torch.conj(x) * x))
        # and it should be more stable, but it's not clear whether it'll be faster on, say
        # CPU (abs is 1 vectorised operation), so leaving it just for real dtypes for now
        if not (ord % 2.0 == 0.0 and is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))
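
A numeric spot-check (illustration only) of the branches above: for ord=2 the general branch reduces to sqrt(sum(|x| ** 2)), and ord=inf reduces to the maximum absolute value.

import torch

x = torch.randn(3, 4)
assert torch.allclose(torch.linalg.vector_norm(x, ord=2), torch.sqrt((x * x).sum()))
assert torch.allclose(torch.linalg.vector_norm(x, ord=float("inf")), x.abs().max())
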
Example #6
def _collapse_view_meta(a: TensorLikeType, start: int,
                        end: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    shape = a.shape
    strides = a.stride()

    utils.validate_idx(len(shape), start)
    utils.validate_exclusive_idx(len(shape), end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    assert end > start

    length = 1
    for idx in range(start, end):
        if idx != (end - 1):
            assert strides[idx] == strides[idx + 1] * shape[idx + 1]
        length = length * shape[idx]

    # The collapsed dimension steps by the stride of the innermost collapsed dim
    stride = strides[end - 1]

    new_shape = shape[:start] + (length, ) + shape[end:]
    new_strides = strides[:start] + (stride, ) + strides[end:]

    return TensorMeta(a, shape=new_shape, strides=new_strides)
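
A worked illustration (my own example) of the stride condition asserted above: for a contiguous (2, 3, 4) tensor the strides are (12, 4, 1), so 12 == 4 * 3 and 4 == 1 * 4, and collapsing all three dims yields a view whose stride is the innermost stride.

import torch

t = torch.empty(2, 3, 4)
flat = t.view(24)             # collapse of dims 0..2 as a view
assert flat.stride() == (1,)  # stride of the innermost collapsed dim
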
Example #7
def _reshape_meta(a: TensorLikeType, shape: Sequence):
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements
    numel = reduce(lambda acc, x: acc * x, shape, 1)
    assert a.numel() == numel

    return TensorMeta(a,
                      shape=shape,
                      strides=utils.make_contiguous_strides_for(shape))
Example #8
def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    # Validates the cast is safe
    # TODO: move this as an option on the reference
    # a_typ = utils.dtype_to_type(a.dtype)
    # b_typ = utils.dtype_to_type(b.dtype)
    # if a_typ is not utils.get_higher_type(a_typ, b_typ):
    #     raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")

    # Validates the tensors have the same number of elements
    if a.numel() != b.numel():
        msg = "Attempting to copy {0} elements to a tensor with {1} elements!".format(
            b.numel(), a.numel())
        raise RuntimeError(msg)

    return a
Example #9
def _split_dim_meta(a: TensorLikeType, dim: int,
                    outer_length: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split with the specified outer_length
    _inner_length = a.shape[dim] / outer_length
    inner_length: int = int(_inner_length)
    assert inner_length == _inner_length

    new_shape: List[int] = []
    new_strides: List[int] = []
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            new_strides.extend(
                (a.stride()[idx] * inner_length, a.stride()[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
Example #10
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    for idx in dimensions:
        utils.validate_idx(a.ndim, idx)
        assert a.shape[idx] == 1

    new_shape = []
    new_strides = []
    for idx in range(len(a.shape)):
        if idx in dimensions:
            continue

        new_shape.append(a.shape[idx])
        new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
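
A worked illustration (my own example): squeezing a size-1 dim simply drops its entry from both shape and strides, leaving the remaining strides untouched.

import torch

t = torch.empty(2, 1, 3)      # strides (3, 3, 1)
s = t.squeeze(1)
assert s.shape == (2, 3)
assert s.stride() == (3, 1)
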
Example #11
def _transpose_meta(a: TensorLikeType,
                    permutation: DimsSequenceType) -> TensorLikeType:
    if a.ndim != len(permutation):
        msg = "Attempting to permute a tensor of rank {0}, but received a permutation of length {1}!".format(
            a.ndim, len(permutation))
        raise ValueError(msg)

    if not utils.is_valid_permutation(a.ndim, permutation):
        msg = "Received an invalid permutation, {0}!".format(permutation)
        raise ValueError(msg)

    new_shape = [0] * a.ndim
    new_strides = [0] * a.ndim
    for idx, dim in enumerate(permutation):
        new_shape[idx] = a.shape[dim]
        new_strides[idx] = a.stride()[dim]

    return TensorMeta(a, shape=tuple(new_shape), strides=tuple(new_strides))
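
A worked illustration (my own example): a permutation only reorders shape and strides, no data is moved.

import torch

t = torch.empty(2, 3, 4)      # strides (12, 4, 1)
p = t.permute(2, 0, 1)
assert p.shape == (4, 2, 3)
assert p.stride() == (1, 12, 4)
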
Example #12
def _broadcast_in_dim_meta(a: TensorLikeType, shape: ShapeType,
                           broadcast_dimensions: Sequence[int]):
    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, int)
        assert x > acc
        assert x < len(shape)

        return x

    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # the tensor's shape must be broadcastable to the requested shape
    for idx, new_idx in enumerate(broadcast_dimensions):
        assert a.shape[idx] == 1 or a.shape[idx] == shape[new_idx]

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            new_strides.append(0)

    return TensorMeta(a, shape=shape, strides=new_strides)
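
A worked illustration (my own example) of the stride-0 broadcast this meta describes: dimensions not listed in broadcast_dimensions get stride 0, so the result is a view that repeats data. expand() produces the same layout.

import torch

w = torch.tensor([1.0, 2.0, 3.0])        # shape (3,), stride (1,)
b = w.reshape(1, 3, 1).expand(2, 3, 4)   # broadcast_dimensions == (1,)
assert b.shape == (2, 3, 4)
assert b.stride() == (0, 1, 0)
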
Example #13
def _collapse_view_helper(
        a: TensorLikeType, start: int,
        end: int) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
    assert isinstance(a, TensorLike)

    # Special-case for zero dimensional tensors
    if a.ndim == 0:
        shape = (1, )
        strides = (1, )
    else:
        shape = a.shape  # type: ignore[assignment]
        strides = a.stride()

    utils.validate_idx(len(shape), start)
    utils.validate_exclusive_idx(len(shape), end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    if end <= start:
        msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
            end, start)
        raise ValueError(msg)

    length = 1
    for idx in range(start, end):
        if idx != (end - 1):
            if not (strides[idx] == strides[idx + 1] * shape[idx + 1]):
                return None, None
        length = length * shape[idx]

    # The collapsed dimension steps by the stride of the innermost collapsed dim
    stride = strides[end - 1]

    new_shape = shape[:start] + (length, ) + shape[end:]
    new_strides = strides[:start] + (stride, ) + strides[end:]

    return new_shape, new_strides
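
An illustration of the None, None path (my own example): a transposed tensor violates the stride condition, so the collapse cannot be expressed as a view; .view() refuses it and .reshape() falls back to a copy.

import torch

t = torch.empty(3, 4).t()     # shape (4, 3), strides (1, 4)
raised = False
try:
    t.view(12)                # cannot be represented as a view
except RuntimeError:
    raised = True
assert raised
assert t.reshape(12).is_contiguous()  # reshape copies instead
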
Example #14
def _slice_meta(
    a: TensorLikeType,
    start_indices: DimsSequenceType,
    limit_indices: DimsSequenceType,
    strides: Optional[StrideType] = None,
) -> TensorLikeType:
    _strides = strides if strides is not None else [1] * len(start_indices)

    if a.ndim != len(start_indices):
        msg = "Attempting to slice tensor of rank {0} with start_indices of length {1}!".format(
            a.ndim, len(start_indices))
        raise ValueError(msg)

    if a.ndim != len(limit_indices):
        msg = "Attempting to slice tensor of rank {0} with limit_indices of length {1}!".format(
            a.ndim, len(limit_indices))
        raise ValueError(msg)

    if a.ndim != len(_strides):
        msg = (
            "Attempting to slice tensor of rank {0} with strides of length {1}!"
            .format(a.ndim, len(_strides)))
        raise ValueError(msg)

    for x, y in zip(start_indices, a.shape):
        if x < 0:
            msg = "Attempting to slice a tensor with a negative start index of {0}!".format(
                x)
            raise ValueError(msg)
        if x > y:
            msg = (
                "Attempting to slice a tensor but a start index in {0} is greater than"
                " the length of its corresponding dimension in shape {1}".
                format(start_indices, a.shape))
            raise ValueError(msg)

    for x, y, z in zip(limit_indices, a.shape, start_indices):
        if x < 0:
            msg = "Attempting to slice a tensor with a negative stop index of {0}!".format(
                x)
            raise ValueError(msg)
        if x > y:
            msg = (
                "Attempting to slice a tensor but a stop index in {0} is greater than the length of "
                " its corresponding dimension in shape {1}".format(
                    limit_indices, a.shape))
            raise ValueError(msg)
        if x < z:
            msg = (
                "Attempting to slice a tensor but a start index of {0} is greater than"
                " its corresponding stop index of {1}!".format(z, x))
            raise ValueError(msg)

    for x in _strides:
        if x <= 0:
            msg = (
                "Attempting to slice a tensor with a non-positive step of {0}!"
                .format(x))
            raise ValueError(msg)

    new_shape = []
    for x, y, z in zip(start_indices, limit_indices, _strides):
        new_shape.append(math.floor((y - x) / z))

    new_strides = []
    for x, y in zip(a.stride(), _strides):
        new_strides.append(x * y)

    return TensorMeta(a, shape=new_shape, strides=new_strides)
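
A worked illustration (my own example, using an exactly divisible range so the rounding does not matter): slicing with a step multiplies the stride, and the resulting length is (limit - start) / step.

import torch

t = torch.arange(10)          # shape (10,), stride (1,)
s = t[2:8:2]                  # start=2, limit=8, step=2 -> elements 2, 4, 6
assert s.shape == (3,)
assert s.stride() == (2,)
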
Example #15
def _reshape_view_helper(a: TensorLikeType, shape: ShapeType, *,
                         allow_copy: bool) -> TensorLikeType:
    # NOTE: Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred
    # Creates a valid shape

    for idx in range(len(shape)):
        if shape[idx] == -1:
            # Verifies there's only one dimension of length -1 in the shape
            if shape.count(-1) > 1:
                msg = "Can only infer the length of one dimension, but got shape {0}!".format(
                    str(shape))
                raise ValueError(msg)

            # TODO: improve error message
            if a.numel() > 0:
                length = reduce(operator.floordiv,
                                (x for x in shape if x != -1), a.numel())
            else:
                msg = "Cannot reshape a tensor of zero elements into shape {0} because the unspecified length is ambiguous!".format(
                    str(shape))
                raise ValueError(msg)

            shape = list(shape)
            shape[idx] = length
            break

    # Short-circuits if shape is the same
    utils.validate_shape(shape)
    if tuple(a.shape) == tuple(shape):
        return prims.view_of(a)

    numel = reduce(operator.mul, shape) if len(shape) > 0 else 1
    if a.numel() != numel:
        msg = "Attempting to reshape a tensor with shape {0} and {1} elements to a shape {2} with {3} elements!".format(
            str(a.shape), a.numel(), str(shape), numel)
        raise ValueError(msg)

    # Special-cases tensors with no elements
    if a.numel() == 0:
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        return _a

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        return _a

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape

    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than this algorithm. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            a_ = unsqueeze(a_, -1)
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if length == a_.shape[idx]:
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified a and the shape
        # specify the same number of elements above
        accum = a_.shape[idx]
        end = idx
        while accum % length != 0:
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a copy rather than a view

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, new_strides = prims._collapse_view_helper(
                a_, idx, end + 1)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor with shape {2}!".format(
                    a.shape, a.stride(), shape)
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length
        if accum != length:
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail
    while idx < a_.ndim:
        assert a_.shape[idx] == 1
        a_ = squeeze(a_, idx)

    return a_
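
A behavioural sketch of the note above (my own example): reshape returns a view when the strides permit it and copies otherwise, which corresponds to the allow_copy path.

import torch

a = torch.empty(4, 6)
v = a.reshape(2, 12)                  # strides allow a view
assert v.data_ptr() == a.data_ptr()

b = torch.empty(4, 6).t()             # shape (6, 4), non-contiguous
c = b.reshape(24)                     # flattening here forces a copy
assert c.data_ptr() != b.data_ptr()
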
Example #16
def _resize_meta(a: TensorLikeType, shape: Union[torch.Size, List[int],
                                                 Tuple[int, ...]]):
    assert a.numel() == 0
    return TensorMeta(a,
                      shape=shape,
                      strides=utils.make_contiguous_strides_for(shape))