import jax.numpy as jnp
from flax import nn  # pre-Linen Flax API, matching the calls below


class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block."""

  def apply(self,
            inputs,
            mlp_dim,
            dtype=jnp.float32,
            out_dim=None,
            dropout_rate=0.1,
            deterministic=True,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6)):
    """Applies Transformer MlpBlock module."""
    actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
    # First dense projection up to the hidden MLP width.
    x = nn.Dense(
        inputs,
        mlp_dim,
        dtype=dtype,
        kernel_init=kernel_init,
        bias_init=bias_init)
    x = nn.gelu(x)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    # Second dense projection back down to the output width, which
    # defaults to the input width when out_dim is None.
    output = nn.Dense(
        x,
        actual_out_dim,
        dtype=dtype,
        kernel_init=kernel_init,
        bias_init=bias_init)
    output = nn.dropout(
        output, rate=dropout_rate, deterministic=deterministic)
    return output
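# A minimal usage sketch for MlpBlock, assuming the pre-Linen flax.nn
# API used above, where `Module.partial` binds hyperparameters,
# `Module.init` runs the module once and returns (output, params), and
# `nn.Model` bundles module and params for later calls. The shapes and
# mlp_dim value here are illustrative, not from the source.
import jax


def _mlp_block_usage_sketch():
  rng = jax.random.PRNGKey(0)
  x = jnp.ones((1, 16, 768))  # (batch, tokens, hidden) -- example shape
  module = MlpBlock.partial(mlp_dim=3072)
  y, params = module.init(rng, x)  # dropout is a no-op: deterministic=True
  model = nn.Model(module, params)
  return model(x)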
def apply_activation(intermediate_output, intermediate_activation):
  """Applies selected activation function to intermediate output."""
  if intermediate_activation is None:
    return intermediate_output

  if intermediate_activation == 'gelu':
    intermediate_output = nn.gelu(intermediate_output)
  elif intermediate_activation == 'relu':
    intermediate_output = nn.relu(intermediate_output)
  elif intermediate_activation == 'sigmoid':
    intermediate_output = nn.sigmoid(intermediate_output)
  elif intermediate_activation == 'softmax':
    intermediate_output = nn.softmax(intermediate_output)
  elif intermediate_activation == 'celu':
    intermediate_output = nn.celu(intermediate_output)
  elif intermediate_activation == 'elu':
    intermediate_output = nn.elu(intermediate_output)
  elif intermediate_activation == 'log_sigmoid':
    intermediate_output = nn.log_sigmoid(intermediate_output)
  elif intermediate_activation == 'log_softmax':
    intermediate_output = nn.log_softmax(intermediate_output)
  elif intermediate_activation == 'soft_sign':
    intermediate_output = nn.soft_sign(intermediate_output)
  elif intermediate_activation == 'softplus':
    intermediate_output = nn.softplus(intermediate_output)
  elif intermediate_activation == 'swish':
    intermediate_output = nn.swish(intermediate_output)
  elif intermediate_activation == 'tanh':
    intermediate_output = jnp.tanh(intermediate_output)
  else:
    raise NotImplementedError(
        '%s activation function is not yet supported.' %
        intermediate_activation)
  return intermediate_output
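# A quick usage sketch for `apply_activation`; the input values are
# illustrative. Passing None returns the input unchanged, and an
# unrecognized name raises NotImplementedError.
if __name__ == '__main__':
  x = jnp.array([-1.0, 0.0, 1.0])
  print(apply_activation(x, 'relu'))  # [0. 0. 1.]
  print(apply_activation(x, 'tanh'))  # elementwise jnp.tanh
  print(apply_activation(x, None))    # identity pass-through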