Example #1
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # softshrink(x) = x - lambd if x > lambd
    #               = x + lambd if x < -lambd
    #               = 0 otherwise
    ge_mask = a > lambd
    le_mask = a < -lambd
    zero_mask = torch.logical_not(refs.logical_or(ge_mask, le_mask))
    result = refs.where(ge_mask, a - lambd, a)
    result = refs.where(le_mask, a + lambd, result)
    return refs.where(zero_mask, 0, result)
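A quick sanity check of the piecewise formula, written with plain torch ops rather than the refs API (a minimal sketch; torch.nn.functional.softshrink is assumed to follow the same rule):

import torch

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
lambd = 0.5
# Apply the piecewise rule from the comment above with eager torch ops.
expected = torch.where(x > lambd, x - lambd,
                       torch.where(x < -lambd, x + lambd, torch.zeros_like(x)))
print(expected)                                  # tensor([-1.5000, 0.0000, 0.0000, 0.0000, 1.5000])
print(torch.nn.functional.softshrink(x, lambd))  # should match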
Example #2
def hinge_embedding_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    margin: float = 1.0,
    reduction: str = "mean",
) -> TensorLikeType:
    # Formula of loss (implementation gets confusing with all the refs.foo)
    # loss_without_reduction = input if y == 1
    #                        = max(0, margin - input) if y == -1
    _check_reduction_value(reduction)
    margin_clamp = refs.maximum(refs.sub(margin, input), 0)
    output_margin = refs.where(refs.ne(target, 1), margin_clamp, 0)
    output_self = refs.where(refs.ne(target, -1), input, 0)
    loss = refs.add(output_margin, output_self)
    return _apply_loss_reduction(loss, reduction)
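The same loss can be checked elementwise with eager ops (a minimal sketch; the input/target values are arbitrary, and torch.nn.functional.hinge_embedding_loss is assumed to apply the formula above with "mean" reduction by default):

import torch
import torch.nn.functional as F

inp = torch.tensor([0.3, 1.2, -0.4, 2.0])
target = torch.tensor([1, -1, 1, -1])
margin = 1.0
# input where y == 1, max(0, margin - input) where y == -1, then mean-reduce.
per_elem = torch.where(target == 1, inp, (margin - inp).clamp(min=0))
print(per_elem.mean())                                     # tensor(-0.0250)
print(F.hinge_embedding_loss(inp, target, margin=margin))  # should match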
Example #3
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # softshrink(x) = x - lambd if x > lambd
    #               = x + lambd if x < -lambd
    #               = 0 otherwise
    check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    ge_mask = a > lambd
    le_mask = a < -lambd
    zero_mask = torch.logical_not(refs.logical_or(ge_mask, le_mask))
    result = refs.where(ge_mask, a - lambd, a)
    result = refs.where(le_mask, a + lambd, result)
    return refs.where(zero_mask, 0, result)
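This version only adds the non-negativity check on lambd; assuming the eager op enforces the same constraint, a negative lambd should be rejected (a minimal sketch):

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.0, 1.0])
try:
    F.softshrink(x, -0.5)   # negative lambd should trip the check above
except RuntimeError as e:
    print(e)                # e.g. "lambda must be greater or equal to 0, ..."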
Example #4
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda:
        f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda:
            f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda:
        f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(weight, a.shape,
                                    tuple() if weight.ndim == 0 else (1, ))

    return refs.where(a > 0, a, a * weight)
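The broadcast_in_dim call aligns a 1D weight with the channel dimension (dim 1). A minimal sketch of the equivalent eager computation, assuming torch.nn.functional.prelu applies one slope per channel:

import torch
import torch.nn.functional as F

a = torch.randn(2, 3, 4, 4)              # (N, C, H, W) input
weight = torch.tensor([0.1, 0.2, 0.3])   # one negative-slope per channel (C == 3)
# Same broadcast the reference performs: align weight with dim 1, then blend.
w = weight.reshape(1, 3, 1, 1)
expected = torch.where(a > 0, a, a * w)
print(torch.allclose(expected, F.prelu(a, weight)))   # True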
Example #5
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if beta is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), python_type):
            msg = "beta argument of type {0} cannot be safely cast to type {1}!".format(
                type(beta), python_type)
            raise ValueError(msg)
        scaled_input = refs.mul(a, beta)
        rhs = refs.true_divide(refs.log1p(refs.exp(scaled_input)), beta)
    else:
        scaled_input = a
        rhs = refs.log1p(refs.exp(scaled_input))

    return refs.where(refs.gt(scaled_input, threshold), a, rhs)
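In formula form this is softplus(x) = log1p(exp(beta * x)) / beta, reverting to the identity once beta * x exceeds the threshold for numerical stability. A minimal eager sketch, assuming torch.nn.functional.softplus uses the same cutoff:

import torch
import torch.nn.functional as F

x = torch.tensor([-5.0, 0.0, 5.0, 30.0])
beta, threshold = 2.0, 20.0
scaled = x * beta
# log1p(exp(beta * x)) / beta below the threshold, x itself above it.
expected = torch.where(scaled > threshold, x, torch.log1p(torch.exp(scaled)) / beta)
print(torch.allclose(expected, F.softplus(x, beta=beta, threshold=threshold)))  # True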
Example #6
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu
    """

    if inplace:
        raise NotImplementedError

    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    rhs = refs.mul(alpha, refs.expm1(a))

    return refs.mul(scale, refs.where(refs.gt(a, 0), a, rhs))
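alpha and scale are the fixed constants from the SELU definition (Klambauer et al., "Self-Normalizing Neural Networks"). A minimal eager sketch, assuming torch.nn.functional.selu bakes in the same constants:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
# scale * (x for positives, alpha * (exp(x) - 1) otherwise).
expected = scale * torch.where(x > 0, x, alpha * torch.expm1(x))
print(torch.allclose(expected, F.selu(x)))   # True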
Example #7
def elu(a: TensorLikeType,
        alpha: Optional[NumberType] = None,
        inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if alpha is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!"
                .format(type(alpha), python_type))
            raise ValueError(msg)
        rhs = refs.mul(alpha, refs.expm1(a))
    else:
        rhs = refs.expm1(a)

    return refs.where(refs.gt(a, 0), a, rhs)
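A minimal eager sketch with a non-default alpha, assuming torch.nn.functional.elu computes x for positives and alpha * (exp(x) - 1) otherwise:

import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, 0.0, 2.0])
alpha = 0.7
# x for positives, alpha * (exp(x) - 1) for the rest.
expected = torch.where(x > 0, x, alpha * torch.expm1(x))
print(torch.allclose(expected, F.elu(x, alpha=alpha)))   # True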
Example #8
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # hardshrink(x) = x if x > lambd
    #               = x if x < -lambd
    #               = 0 otherwise
    return refs.where(refs.logical_and(a >= -lambd, a <= lambd), 0, a)
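Hardshrink keeps values outside the closed band [-lambd, lambd] and zeroes the rest. A minimal eager sketch, assuming torch.nn.functional.hardshrink follows this rule:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
lambd = 0.5
# Zero inside the band, pass everything else through unchanged.
expected = torch.where((x >= -lambd) & (x <= lambd), torch.zeros_like(x), x)
print(expected)                   # tensor([-2., 0., 0., 0., 2.])
print(F.hardshrink(x, lambd))     # should match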
Example #9
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # hardshrink(x) = x if x > lambd
    #               = x if x < -lambd
    #               = 0 otherwise
    return refs.where(abs(a) > abs(lambd), a, 0)
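This variant compares magnitudes instead of testing the band directly. For lambd >= 0 the two formulations agree; because of abs(lambd), they diverge when lambd is negative. A small sketch with hypothetical helper names illustrating the difference:

import torch

def hardshrink_band(a, lambd=0.5):
    # Example #8: zero inside the closed band [-lambd, lambd].
    return torch.where((a >= -lambd) & (a <= lambd), torch.zeros_like(a), a)

def hardshrink_abs(a, lambd=0.5):
    # Example #9: keep values whose magnitude exceeds |lambd|.
    return torch.where(a.abs() > abs(lambd), a, torch.zeros_like(a))

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(torch.equal(hardshrink_band(x), hardshrink_abs(x)))              # True
print(torch.equal(hardshrink_band(x, -0.5), hardshrink_abs(x, -0.5)))  # False: abs() flips the band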