Пример #1
0
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus

    Computes ``log1p(exp(beta * a)) / beta`` elementwise, falling back to the
    identity (returning ``a``) wherever ``beta * a > threshold``. When ``beta``
    is None the scaling and the division are both skipped.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if the Python type of ``beta`` cannot be safely cast to
            the Python type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    if beta is None:
        scaled_input = a
        smooth: TensorLikeType = torch.log1p(torch.exp(scaled_input))
    else:
        expected_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), expected_type):
            raise ValueError(
                "beta argument of type {0} cannot be safely cast to type {1}!".format(
                    type(beta), expected_type))
        scaled_input = a * beta
        # exp() may overflow for large inputs; those elements are discarded by
        # the torch.where below, which selects ``a`` for them instead.
        smooth = torch.true_divide(
            torch.log1p(torch.exp(scaled_input)), beta)  # type: ignore[arg-type]

    # Above the threshold, take the linear regime and return the input as-is.
    use_linear = scaled_input > threshold
    return torch.where(use_linear, a, smooth)
Пример #2
0
def leaky_relu(
    a: TensorLikeType,
    negative_slope: float = 0.01,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu

    Returns ``a`` where ``a > 0`` and ``a * negative_slope`` everywhere else.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if the Python type of ``negative_slope`` cannot be safely
            cast to the Python type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    target_type = utils.dtype_to_type(a.dtype)
    slope_type = type(negative_slope)
    if utils.is_weakly_lesser_type(slope_type, target_type):
        is_positive = torch.gt(a, 0)
        return torch.where(is_positive, a, torch.mul(a, negative_slope))

    raise ValueError(
        f"negative_slope argument of type {slope_type} cannot be safely cast to type {target_type}!"
    )
Пример #3
0
def _maybe_convert_to_dtype(
        a: Union[TensorLikeType, NumberType, Sequence],
        dtype: torch.dtype) -> Union[TensorLikeType, NumberType, Sequence]:
    """
    Convert a tensor, a number, or a (possibly nested) sequence of them to ``dtype``.

    - Tensors whose dtype differs are converted with
      ``prims.convert_element_type``; tensors already of ``dtype`` are
      returned unchanged (same object).
    - Numbers are converted with the Python scalar type that
      ``utils.dtype_to_type`` associates with ``dtype``.
    - Sequences (other than ``str``) are converted element-wise and
      returned as a tuple.

    Raises:
        ValueError: if ``a`` (or any nested element) is none of the above.
    """
    import torch._prims as prims
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)
    # A str is a Sequence whose elements are one-character strs (themselves
    # Sequences), so recursing into it never terminates. Reject it with the
    # ValueError below instead of overflowing the stack.
    if isinstance(a, Sequence) and not isinstance(a, str):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(
            type(a)))
Пример #4
0
def _maybe_convert_to_dtype(
    a: Union[TensorLikeType, NumberType, Sequence, None], dtype: torch.dtype
) -> Union[TensorLikeType, NumberType, Sequence, None]:
    """
    Convert a tensor, a number, or a (possibly nested) sequence of them to ``dtype``.

    - ``None`` is passed through unchanged.
    - Tensors whose dtype differs are converted with
      ``prims.convert_element_type``; tensors already of ``dtype`` are
      returned unchanged (same object).
    - Numbers are converted with the Python scalar type that
      ``utils.dtype_to_type`` associates with ``dtype``.
    - Sequences (other than ``str``) are converted element-wise and
      returned as a tuple.

    Raises:
        ValueError: if ``a`` (or any nested element) is none of the above.
    """
    import torch._prims as prims

    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)
    # A str is a Sequence whose elements are one-character strs (themselves
    # Sequences), so recursing into it never terminates. Reject it with the
    # ValueError below instead of overflowing the stack.
    if isinstance(a, Sequence) and not isinstance(a, str):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(type(a))
    )
Пример #5
0
def elu(a: TensorLikeType,
        alpha: Optional[NumberType] = None,
        inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu

    Returns ``a`` where ``a > 0`` and ``alpha * expm1(a)`` elsewhere. When
    ``alpha`` is None the multiplication is skipped, so the negative branch
    is just ``expm1(a)``.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if the Python type of ``alpha`` cannot be safely cast to
            the Python type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    if alpha is None:
        negative_part: TensorLikeType = torch.expm1(a)
    else:
        expected_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), expected_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!"
                .format(type(alpha), expected_type))
        negative_part = alpha * torch.expm1(a)

    return torch.where(a > 0, a, negative_part)
Пример #6
0
def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
    """
    Out-of-place ``masked_fill`` with a scalar value, expressed via torch.where.

    ``value`` is first converted with the Python scalar type that
    ``utils.dtype_to_type`` associates with ``self.dtype``. The converted
    value is taken wherever ``mask`` is True, and the elements of ``self``
    are kept elsewhere.
    """
    scalar_type = utils.dtype_to_type(self.dtype)
    fill_value = scalar_type(value)
    return torch.where(mask, fill_value, self)