Example #1
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "Zero-dim input tensors are not allowed.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(
        weight, a.shape, tuple() if weight.ndim == 0 else (1,)
    )

    return refs.where(a > 0, a, a * weight)
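
A minimal usage sketch of the behavior this reference mirrors, via the public torch.nn.functional.prelu (values shown are illustrative):

import torch

a = torch.tensor([[1.0, -2.0], [-3.0, 4.0]])
weight = torch.tensor([0.25, 0.5])  # one slope per channel (dim 1 of `a`)

out = torch.nn.functional.prelu(a, weight)
# Negative entries are scaled by their channel's slope:
# tensor([[ 1.0000, -1.0000],
#         [-0.7500,  4.0000]])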
Example #2
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements
    numel = reduce(operator.mul, shape, 1)  # the initial value 1 handles empty (0-dim) shapes
    if numel != a.numel():
        msg = "Attempting to reshape a tensor with {0} elements to a shape with {1} elements!".format(
            a.numel(), numel)
        raise ValueError(msg)

    return TensorMeta(a,
                      shape=shape,
                      strides=utils.make_contiguous_strides_for(shape))
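
The same element-count rule is observable through the public reshape; a quick sketch:

import torch

t = torch.arange(6)
t.reshape(2, 3)    # ok: both sides specify 6 elements
# t.reshape(2, 4)  # raises, since 6 elements cannot fill a shape with 8 elements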
Example #3
def _maybe_resize_out(out: TensorLikeType, shape):
    if out.numel() == 0:
        return prims.resize(out, shape)

    if out.numel() != reduce(operator.mul, shape, 1):
        msg = (
            "An output with one or more elements was resized since it had shape {0} "
            "which does not match the required output shape {1}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
            .format(str(out.shape), str(shape)))
        warnings.warn(msg)
        return prims.resize(out, shape)

    return out
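
The eager out= kwarg follows the same policy; a small sketch:

import torch

out = torch.empty(0)  # zero elements: resized silently
torch.add(torch.ones(3), torch.ones(3), out=out)

out = torch.empty(5)  # non-empty but mismatched: resized, with the deprecation warning above
torch.add(torch.ones(3), torch.ones(3), out=out)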
Example #4
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, list) and dim is not None:
        # refs.amin only accepts a list rather than DimsType (a tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = reduction_dtypes(
        x, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    elif ord == float("-inf"):
        return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = prims.convert_element_type(x, computation_dtype)
        reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim)

        # Avoid computing a sqrt in abs and then squaring (more stable)
        # This could potentially be done for complex dtypes as
        # x = torch.real(torch.conj(x) * x))
        # and it should be more stable, but it's not clear whether it'll be faster on, say
        # CPU (abs is 1 vectorised operation), so leaving it just for real dtypes for now
        if not (ord % 2.0 == 0.0 and is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))
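
A short sketch of the branches above via the public torch.linalg.vector_norm:

import torch

x = torch.tensor([3.0, -4.0])
torch.linalg.vector_norm(x, ord=2)             # tensor(5.)
torch.linalg.vector_norm(x, ord=float("inf"))  # tensor(4.)  -- max of |x|
torch.linalg.vector_norm(x, ord=0)             # tensor(2.)  -- counts nonzero entries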
Example #5
def _reshape_meta(a: TensorLikeType, shape: Sequence):
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements
    numel = reduce(lambda acc, x: acc * x, shape, 1)  # the initial value 1 handles empty (0-dim) shapes
    assert a.numel() == numel, f"shape specifies {numel} elements, but the tensor has {a.numel()}"
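
Why the initial value matters: without it, reduce raises on the empty shape of a 0-dim tensor, which holds exactly one element.

from functools import reduce
import operator

reduce(operator.mul, (2, 3, 4), 1)  # 24
reduce(operator.mul, (), 1)         # 1 -- the 0-dim case
# reduce(operator.mul, ())          # TypeError: no initial value for an empty iterable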
Example #6
def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    # Validates the cast is safe
    # TODO: move this as an option on the reference
    # a_typ = utils.dtype_to_type(a.dtype)
    # b_typ = utils.dtype_to_type(b.dtype)
    # if a_typ is not utils.get_higher_type(a_typ, b_typ):
    #     raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")

    # Validates the tensors have the same number of elements
    if a.numel() != b.numel():
        msg = "Attempting to copy {0} elements to a tensor with {1} elements!".format(
            b.numel(), a.numel())
        raise RuntimeError(msg)

    return a
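
The safe-cast rule sketched in the commented-out TODO is exposed publicly as torch.can_cast; for illustration:

import torch

torch.can_cast(torch.int32, torch.float32)  # True: int -> float is a safe cast
torch.can_cast(torch.float32, torch.int32)  # False: float -> int can lose information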
Example #7
def _reshape_view_helper(a: TensorLikeType, shape: ShapeType, *,
                         allow_copy: bool) -> TensorLikeType:
    # NOTE: Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred
    # Creates a valid shape

    for idx in range(len(shape)):
        if shape[idx] == -1:
            # Verifies there's only one dimension of length -1 in the shape
            if shape.count(-1) > 1:
                msg = "Can only infer the length of one dimension, but got shape {0}!".format(
                    str(shape))
                raise ValueError(msg)

            # TODO: improve error message
            if a.numel() > 0:
                length = reduce(operator.floordiv,
                                (x for x in shape if x != -1), a.numel())
            else:
                msg = "Cannot reshape a tensor of zero elements into shape {0} because the unspecified length is ambiguous!".format(
                    str(shape))
                raise ValueError(msg)

            shape = list(shape)
            shape[idx] = length
            break

    # Short-circuits if shape is the same
    utils.validate_shape(shape)
    if tuple(a.shape) == tuple(shape):
        return prims.view_of(a)

    numel = reduce(operator.mul, shape) if len(shape) > 0 else 1
    if a.numel() != numel:
        msg = "Attempting to reshape a tensor with shape {0} and {1} elements to a shape {2} with {3} elements!".format(
            str(a.shape), a.numel(), str(shape), numel)
        raise ValueError(msg)

    # Special-cases tensors with no elements
    if a.numel() == 0:
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        return _a

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        return _a

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape

    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than this algorithm. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            a_ = unsqueeze(a_, -1)
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if length == a_.shape[idx]:
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified a and the shape
        # specify the same number of elements above
        accum = a_.shape[idx]
        end = idx
        while accum % length != 0:
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view of a copy

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, new_strides = prims._collapse_view_helper(
                a_, idx, end + 1)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor with shape {2}!".format(
                    a.shape, a.stride(), shape)
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length
        if accum != length:
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail
    while idx < a_.ndim:
        assert a_.shape[idx] == 1
        a_ = squeeze(a_, idx)

    return a_
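
The view-or-copy distinction the helper navigates is visible through the public reshape/view; a sketch:

import torch

base = torch.arange(6).reshape(2, 3)
v = base.reshape(3, 2)            # contiguous strides permit a view
v.data_ptr() == base.data_ptr()   # True: same storage

t = base.t()                      # shape (3, 2), strides (1, 3): not contiguous
c = t.reshape(6)                  # flattening cannot be a view here, so reshape copies
c.data_ptr() == t.data_ptr()      # False: new storage
# t.view(6)                       # the view-only path raises a RuntimeError instead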
Example #8
def _resize_meta(a: TensorLikeType, shape: Union[torch.Size, List[int],
                                                 Tuple[int, ...]]):
    # NOTE: this meta only models resizing tensors with zero elements
    assert a.numel() == 0
    return TensorMeta(a,
                      shape=shape,
                      strides=utils.make_contiguous_strides_for(shape))
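
For reference, a simplified, hypothetical mirror of what utils.make_contiguous_strides_for computes (the real helper may special-case degenerate dimensions):

from functools import reduce
import operator

def contiguous_strides(shape):
    # the stride of dim i is the product of all later dimension lengths
    return tuple(reduce(operator.mul, shape[i + 1:], 1) for i in range(len(shape)))

contiguous_strides((2, 3, 4))  # (12, 4, 1)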