コード例 #1
0
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus

    Computes ``log1p(exp(a * beta)) / beta`` elementwise. Wherever
    ``a * beta`` is greater than ``threshold`` the input ``a`` is passed
    through instead. When ``beta`` is None the input is neither scaled nor
    divided.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if ``beta`` is given and ``utils.is_weakly_lesser_type``
            rejects ``type(beta)`` against the Python type of ``a.dtype``.
    """

    if inplace:
        raise NotImplementedError

    smooth: TensorLikeType
    if beta is None:
        # No scaling requested: the raw input is what gets compared
        # against the threshold below.
        scaled = a
        smooth = refs.log1p(refs.exp(scaled))
    else:
        target_type = utils.dtype_to_type(a.dtype)
        beta_type = type(beta)
        if not utils.is_weakly_lesser_type(beta_type, target_type):
            raise ValueError(
                "beta argument of type {0} cannot be safely cast to type {1}!".format(
                    beta_type, target_type
                )
            )
        scaled = refs.mul(a, beta)
        exp_term = refs.exp(scaled)
        smooth = refs.true_divide(refs.log1p(exp_term), beta)

    # Above the threshold, return the input itself rather than the
    # log1p/exp expression.
    above_threshold = refs.gt(scaled, threshold)
    return refs.where(above_threshold, a, smooth)
コード例 #2
0
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu

    Elementwise, returns ``scale * a`` where ``a > 0`` and
    ``scale * alpha * expm1(a)`` elsewhere, using the fixed ``alpha`` and
    ``scale`` constants defined in the body.

    Raises:
        NotImplementedError: if ``inplace`` is True.
    """

    if inplace:
        raise NotImplementedError

    selu_alpha = 1.6732632423543772848170429916717
    selu_scale = 1.0507009873554804934193349852946

    # Value used wherever the input is not strictly positive.
    non_positive_branch = refs.mul(selu_alpha, refs.expm1(a))
    is_positive = refs.gt(a, 0)
    unscaled = refs.where(is_positive, a, non_positive_branch)

    return refs.mul(selu_scale, unscaled)
コード例 #3
0
def elu(a: TensorLikeType,
        alpha: Optional[NumberType] = None,
        inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu

    Elementwise, returns ``a`` where ``a > 0`` and ``alpha * expm1(a)``
    elsewhere. When ``alpha`` is None the multiplication is skipped and
    plain ``expm1(a)`` is used.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if ``alpha`` is given and ``utils.is_weakly_lesser_type``
            rejects ``type(alpha)`` against the Python type of ``a.dtype``.
    """

    if inplace:
        raise NotImplementedError

    non_positive_branch: TensorLikeType
    if alpha is None:
        non_positive_branch = refs.expm1(a)
    else:
        target_type = utils.dtype_to_type(a.dtype)
        alpha_type = type(alpha)
        if not utils.is_weakly_lesser_type(alpha_type, target_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!"
                .format(alpha_type, target_type)
            )
        non_positive_branch = refs.mul(alpha, refs.expm1(a))

    is_positive = refs.gt(a, 0)
    return refs.where(is_positive, a, non_positive_branch)