# Imports assumed by this module; `nn` may equally be `flax.linen`, which
# re-exports the same activation functions from `jax.nn`.
import jax.numpy as jnp
from jax import nn


def apply_activation(intermediate_output, intermediate_activation):
  """Applies the selected activation function to the intermediate output."""
  if intermediate_activation is None:
    return intermediate_output

  if intermediate_activation == 'gelu':
    intermediate_output = nn.gelu(intermediate_output)
  elif intermediate_activation == 'relu':
    intermediate_output = nn.relu(intermediate_output)
  elif intermediate_activation == 'sigmoid':
    intermediate_output = nn.sigmoid(intermediate_output)
  elif intermediate_activation == 'softmax':
    intermediate_output = nn.softmax(intermediate_output)
  elif intermediate_activation == 'celu':
    intermediate_output = nn.celu(intermediate_output)
  elif intermediate_activation == 'elu':
    intermediate_output = nn.elu(intermediate_output)
  elif intermediate_activation == 'log_sigmoid':
    intermediate_output = nn.log_sigmoid(intermediate_output)
  elif intermediate_activation == 'log_softmax':
    intermediate_output = nn.log_softmax(intermediate_output)
  elif intermediate_activation == 'soft_sign':
    intermediate_output = nn.soft_sign(intermediate_output)
  elif intermediate_activation == 'softplus':
    intermediate_output = nn.softplus(intermediate_output)
  elif intermediate_activation == 'swish':
    intermediate_output = nn.swish(intermediate_output)
  elif intermediate_activation == 'tanh':
    intermediate_output = jnp.tanh(intermediate_output)
  else:
    raise NotImplementedError(
        '%s activation function is not yet supported.' %
        intermediate_activation)

  return intermediate_output
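# Illustrative usage sketch (not part of the original module): shows how
# apply_activation dispatches on the activation name. The input values and
# the helper name `_activation_example` are assumptions for demonstration.
def _activation_example():
  x = jnp.array([[-1.0, 0.0, 2.0]])
  gelu_out = apply_activation(x, 'gelu')   # smooth, GELU-gated values
  relu_out = apply_activation(x, 'relu')   # negatives clamped to zero
  identity = apply_activation(x, None)     # None is a pass-through
  return gelu_out, relu_out, identity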
def binary_cross_entropy_with_logits(logits, labels):
  """Sigmoid cross-entropy: -sum(y * log p + (1 - y) * log(1 - p))."""
  log_p = nn.log_sigmoid(logits)
  # -expm1(log p) == 1 - p, so jnp.log(-jnp.expm1(log_p)) == log(1 - p).
  return -jnp.sum(labels * log_p + (1. - labels) * jnp.log(-jnp.expm1(log_p)))
def binary_cross_entropy_with_logits(logits, labels):
  """Same loss, written via the identity log(1 - sigmoid(x)) = log_sigmoid(x) - x."""
  # Note: this definition reuses the name above and therefore shadows it; both
  # are mathematically equivalent formulations of the sigmoid cross-entropy.
  return -jnp.sum(labels * nn.log_sigmoid(logits) +
                  (1 - labels) * (nn.log_sigmoid(logits) - logits))
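# Quick illustrative check (not part of the original module): both
# formulations above compute the same loss, since log(1 - sigmoid(x)) equals
# both log(-expm1(log_sigmoid(x))) and log_sigmoid(x) - x. The helper name
# and test values below are assumptions for demonstration only.
def _bce_example():
  logits = jnp.array([-2.0, 0.5, 3.0])
  labels = jnp.array([0.0, 1.0, 1.0])
  loss = binary_cross_entropy_with_logits(logits, labels)
  # Reference value via explicit probabilities (fine for these moderate
  # logits, less stable in general for extreme ones).
  p = nn.sigmoid(logits)
  reference = -jnp.sum(labels * jnp.log(p) + (1. - labels) * jnp.log(1. - p))
  return loss, reference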