Code example #1
# Assumed imports and enclosing class (not shown in the original snippet).
# The `nn.Dense(inputs, ...)` call style is the legacy, pre-Linen Flax API
# (`flax.nn`, later moved to `flax.deprecated.nn`).
import jax.numpy as jnp
from flax import nn


class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block."""

  def apply(self,
            inputs,
            mlp_dim,
            dtype=jnp.float32,
            out_dim=None,
            dropout_rate=0.1,
            deterministic=True,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6)):
    """Applies Transformer MlpBlock module."""
    # Project back to the input width unless an explicit output size is given.
    actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
    # Expand to mlp_dim, then apply GELU and dropout.
    x = nn.Dense(
        inputs,
        mlp_dim,
        dtype=dtype,
        kernel_init=kernel_init,
        bias_init=bias_init)
    x = nn.gelu(x)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    # Contract to actual_out_dim and apply dropout again.
    output = nn.Dense(
        x,
        actual_out_dim,
        dtype=dtype,
        kernel_init=kernel_init,
        bias_init=bias_init)
    output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
    return output
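
A minimal usage sketch, assuming the legacy Flax `Module.init`/`Module.call` interface; the shapes, RNG seed, and `mlp_dim` value below are illustrative assumptions, not taken from the original code:

import jax

# Illustrative input: (batch, sequence length, hidden size).
rng = jax.random.PRNGKey(0)
tokens = jnp.ones((1, 16, 256))

# Legacy Flax: Module.init returns (output, params); Module.call re-applies them.
out, params = MlpBlock.init(rng, tokens, mlp_dim=1024)
out = MlpBlock.call(params, tokens, mlp_dim=1024)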
Code example #2
# Assumed imports (not shown in the original snippet): every activation
# dispatched below exists in jax.nn; Flax's nn module re-exports the same set.
import jax.numpy as jnp
from jax import nn


def apply_activation(intermediate_output, intermediate_activation):
    """Applies the selected activation function to the intermediate output."""
    if intermediate_activation is None:
        return intermediate_output

    if intermediate_activation == 'gelu':
        intermediate_output = nn.gelu(intermediate_output)
    elif intermediate_activation == 'relu':
        intermediate_output = nn.relu(intermediate_output)
    elif intermediate_activation == 'sigmoid':
        intermediate_output = nn.sigmoid(intermediate_output)
    elif intermediate_activation == 'softmax':
        intermediate_output = nn.softmax(intermediate_output)
    elif intermediate_activation == 'celu':
        intermediate_output = nn.celu(intermediate_output)
    elif intermediate_activation == 'elu':
        intermediate_output = nn.elu(intermediate_output)
    elif intermediate_activation == 'log_sigmoid':
        intermediate_output = nn.log_sigmoid(intermediate_output)
    elif intermediate_activation == 'log_softmax':
        intermediate_output = nn.log_softmax(intermediate_output)
    elif intermediate_activation == 'soft_sign':
        intermediate_output = nn.soft_sign(intermediate_output)
    elif intermediate_activation == 'softplus':
        intermediate_output = nn.softplus(intermediate_output)
    elif intermediate_activation == 'swish':
        intermediate_output = nn.swish(intermediate_output)
    elif intermediate_activation == 'tanh':
        intermediate_output = jnp.tanh(intermediate_output)
    else:
        raise NotImplementedError(
            '%s activation function is not yet supported.' %
            intermediate_activation)

    return intermediate_output
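
A quick usage sketch; the input array and the chosen activation name are illustrative, and an unsupported name raises NotImplementedError as shown in the commented-out call:

x = jnp.array([[-1.0, 0.0, 1.0]])
y = apply_activation(x, 'gelu')   # dispatches to nn.gelu
# apply_activation(x, 'mish')     # raises NotImplementedError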