def apply_activation(intermediate_output, intermediate_activation):
    """Applies the selected activation function to an intermediate output.

    Args:
      intermediate_output: Array to transform.
      intermediate_activation: One of:
        - ``None``: identity; ``intermediate_output`` is returned as is.
        - A supported activation name: 'gelu', 'relu', 'sigmoid', 'softmax',
          'celu', 'elu', 'log_sigmoid', 'log_softmax', 'soft_sign',
          'softplus', 'swish' (all taken from ``nn``), or 'tanh' (from
          ``jnp``). Each is called with its default arguments. For
          'softmax' and 'log_softmax', that means the library's default axis.
        - A callable, which is applied to ``intermediate_output`` directly.

    Returns:
      The activated output, or ``intermediate_output`` unchanged when
      ``intermediate_activation`` is ``None``.

    Raises:
      NotImplementedError: If ``intermediate_activation`` is not ``None``,
        not callable, and not one of the supported names.
    """
    if intermediate_activation is None:
        return intermediate_output

    # Allow callers to pass the activation function itself instead of a name.
    if callable(intermediate_activation):
        return intermediate_activation(intermediate_output)

    activation_fns = {
        'gelu': nn.gelu,
        'relu': nn.relu,
        'sigmoid': nn.sigmoid,
        'softmax': nn.softmax,
        'celu': nn.celu,
        'elu': nn.elu,
        'log_sigmoid': nn.log_sigmoid,
        'log_softmax': nn.log_softmax,
        'soft_sign': nn.soft_sign,
        'softplus': nn.softplus,
        'swish': nn.swish,
        'tanh': jnp.tanh,
    }

    # Only strings are looked up. A dict lookup would raise TypeError for
    # unhashable values such as lists, which should instead get the same
    # NotImplementedError as any other unsupported value.
    activation_fn = None
    if isinstance(intermediate_activation, str):
        activation_fn = activation_fns.get(intermediate_activation)

    if activation_fn is None:
        # Wrap in a 1-tuple so a tuple-valued argument is formatted as a
        # single value rather than unpacked into several %-arguments (which
        # used to raise TypeError instead of this error).
        raise NotImplementedError(
            '%s activation function is not yet supported.' %
            (intermediate_activation,))

    return activation_fn(intermediate_output)
# Example #2
# 0
def binary_cross_entropy_with_logits(logits, labels):
    """Returns the binary cross-entropy computed from logits, summed over all elements.

    The loss is ``-sum(labels * log(sigmoid(logits)) +
    (1 - labels) * log(1 - sigmoid(logits)))``.

    The negative-class term uses the identity
    ``log(1 - sigmoid(x)) == log_sigmoid(-x)``. The previous form,
    ``log(-expm1(log_sigmoid(x)))``, evaluates ``log(0) = -inf`` once
    ``log_sigmoid(x)`` rounds to 0 for large positive logits. It then
    produced NaN via ``0 * -inf`` even when the label is 1.

    Args:
      logits: Array of raw (pre-sigmoid) scores.
      labels: Array broadcastable with ``logits``. Presumably contains
        targets in [0, 1]; this function does not check that.

    Returns:
      Scalar sum of the per-element cross-entropy over every element.
    """
    log_prob_pos = nn.log_sigmoid(logits)
    log_prob_neg = nn.log_sigmoid(-logits)
    return -jnp.sum(labels * log_prob_pos + (1. - labels) * log_prob_neg)
# Example #3
# 0
def binary_cross_entropy_with_logits(logits, labels):
    """Returns the binary cross-entropy computed from logits, summed over all elements.

    The negative-class log-probability is derived from the positive one via
    ``log(1 - sigmoid(x)) == log_sigmoid(x) - x``. This avoids taking the
    log of ``1 - sigmoid(x)`` directly.

    Args:
      logits: Array of raw (pre-sigmoid) scores.
      labels: Array broadcastable with ``logits``. Presumably contains
        targets in [0, 1]; this function does not check that.

    Returns:
      Scalar: the negated sum over all elements of the per-element
      log-likelihood.
    """
    log_p = nn.log_sigmoid(logits)
    log_one_minus_p = log_p - logits
    log_likelihood = labels * log_p + (1 - labels) * log_one_minus_p
    return -jnp.sum(log_likelihood)