Example #1
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda:
        f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "prelu: zero-dim input tensors are not allowed.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda:
            f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda:
        f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(weight, a.shape,
                                    tuple() if weight.ndim == 0 else (1, ))

    return refs.where(a > 0, a, a * weight)
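As a quick usage sketch (not part of the reference itself) of the public torch.nn.functional.prelu op this implementation mirrors, the weight must hold either a single slope or one slope per channel (input.shape[1] for inputs with ndim >= 2); the values below are illustrative only:

import torch

x = torch.randn(2, 3, 4)           # (N, C, L) with C = 3 channels
w = torch.tensor([0.1, 0.2, 0.3])  # one negative-slope per channel
y = torch.nn.functional.prelu(x, w)
# Positive entries pass through unchanged; negative entries in channel c are
# multiplied by w[c], matching refs.where(a > 0, a, a * weight) above.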
Example #2
def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            copy_from.device, copy_to.device
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        utils.check(
            copy_from.dtype == copy_to.dtype,
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            "but got {copy_to.dtype} instead",
        )
    else:
        utils.check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)
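To illustrate the check this helper enforces, here is a sketch using the public out= convention rather than the internal helper: copying a floating-point result into an integer out tensor is rejected, while a safe widening cast is allowed.

import torch

x = torch.randn(3)
bad_out = torch.empty(3, dtype=torch.int64)
# torch.abs(x, out=bad_out)  # raises RuntimeError: float result cannot be safely cast to int64
ok_out = torch.empty(3, dtype=torch.float64)
torch.abs(x, out=ok_out)     # float32 -> float64 is a safe cast, so this succeeds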
Example #3
def _maybe_resize_out(out: TensorLikeType, shape):
    if out.numel() == 0:
        return out.resize_(shape)

    if out.numel() != reduce(operator.mul, shape, 1):
        msg = (
            "An output with one or more elements was resized since it had shape {0} "
            "which does not match the required output shape {1}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
            .format(str(out.shape), str(shape)))
        warnings.warn(msg)
        return out.resize_(shape)

    return out
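The behavior above can be observed through any op that takes an out= argument (a sketch, assuming eager-mode torch follows the same convention): a zero-element out tensor is resized silently, while a non-empty tensor of the wrong shape is resized with a deprecation warning.

import torch
import warnings

x = torch.randn(2, 3)

empty_out = torch.empty(0)
torch.add(x, 1, out=empty_out)       # resized to (2, 3) without a warning

wrong_out = torch.empty(5)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.add(x, 1, out=wrong_out)   # resized to (2, 3), but emits a UserWarning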
Example #4
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
    check(a.ndim == 2,
          lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
    check(p >= 0, lambda: "pdist only supports non-negative p values")
    # For p == 2 we can use an efficient implementation, but other values of p
    # require creating a much bigger tensor for an intermediate step
    if p == 2:
        aTa = torch.mm(a, a.T)
        aTa_diag = torch.diag(aTa)
        t = torch.sqrt(
            torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0))
    else:
        t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2)
    # Select the strict upper triangle (offset=1) of the pairwise distance matrix
    i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device)
    return t.flatten().index_select(0, i[0] * t.shape[0] + i[1])
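For reference, a small sketch of the public torch.nn.functional.pdist op the implementation above mirrors: for n rows it returns the n*(n-1)/2 pairwise distances of the strict upper triangle, flattened row by row.

import torch

a = torch.randn(5, 3)
d = torch.nn.functional.pdist(a, p=2)
assert d.shape == (5 * 4 // 2,)  # 10 pairwise distances for 5 rows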
Example #5
def _resize_fft_input(x: TensorLikeType, dims: Tuple[int, ...],
                      sizes: Tuple[int, ...]) -> TensorLikeType:
    """
    Fixes the shape of x such that x.size(dims[i]) == sizes[i],
    either by zero-padding, or by slicing x starting from 0.
    """
    assert len(dims) == len(sizes)
    must_copy = False
    x_sizes = x.shape
    pad_amount = [0] * len(x_sizes) * 2
    for i in range(len(dims)):
        if sizes[i] == -1:
            continue

        if x_sizes[dims[i]] < sizes[i]:
            must_copy = True
            pad_idx = len(pad_amount) - 2 * dims[i] - 1
            pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]

        if x_sizes[dims[i]] > sizes[i]:
            x = x.narrow(dims[i], 0, sizes[i])

    return torch.constant_pad_nd(x, pad_amount) if must_copy else x
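The same pad-or-slice behavior is visible through the public FFT API when an explicit size is requested (a sketch, assuming torch.fft follows the convention this helper implements):

import torch

x = torch.randn(8)
y_pad = torch.fft.fft(x, n=16)   # x is zero-padded to length 16 before the transform
y_crop = torch.fft.fft(x, n=4)   # x is sliced to its first 4 elements
assert y_pad.shape == (16,) and y_crop.shape == (4,)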
Example #6
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, List) and dim is not None:
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda:
            f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda:
                f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = utils.reduction_dtypes(
        x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype)

    to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0),
                        dim=dim,
                        keepdim=keepdim,
                        dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(
            refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    elif ord == float("-inf"):
        return to_result_dtype(
            refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = prims.convert_element_type(x, computation_dtype)
        reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim)

        if not (ord % 2.0 == 0.0 and utils.is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(
            torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))
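A brief sketch of the public torch.linalg.vector_norm op this reference mirrors, covering the special-cased orders handled above:

import torch

x = torch.randn(2, 3)
n2 = torch.linalg.vector_norm(x, ord=2, dim=1)           # Euclidean norm of each row
n0 = torch.linalg.vector_norm(x, ord=0)                  # number of nonzero elements
ninf = torch.linalg.vector_norm(x, ord=float("inf"))     # largest absolute value
nninf = torch.linalg.vector_norm(x, ord=float("-inf"))   # smallest absolute value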