Beispiel #1
0
def explicit_mlp(scope, x, sizes=(3, 1)):
  """Runs `x` through a stack of dense layers with relu in between.

  For every entry in `sizes` a layer is obtained from
  `scope.param` under the name `dense_{i}`, passing
  `ExplicitDense.create` together with the current trailing dimension of
  `x` and the entry itself. The returned object is called on `x`.
  `nn.relu` follows every layer except the last one.

  Args:
    scope: object exposing `param(name, init_fn, *init_args)`.
    x: input value; must expose `.shape` and be accepted by the layers.
    sizes: sequence with one size value per layer.

  Returns:
    The output of the final layer, or `x` unchanged if `sizes` is empty.
  """
  for index, features in enumerate(sizes):
    layer = scope.param(
        f'dense_{index}', ExplicitDense.create, x.shape[-1], features)
    x = layer(x)
    # No activation after the final layer.
    if index + 1 >= len(sizes):
      continue
    x = nn.relu(x)
  return x
Beispiel #2
0
def semi_explicit_mlp(scope, x, sizes=(3, 1)):
  """Runs `x` through dense layers created inside child scopes.

  For every entry in `sizes`, `scope.child` wraps
  `ExplicitDense.create_in_scope` with the name prefix `'dense_'`; the
  wrapped factory is called with the current trailing dimension of `x` and
  the entry, and the object it returns is applied to `x`. `nn.relu`
  follows every layer except the last one.

  Args:
    scope: object exposing `child(fn, prefix=...)`.
    x: input value; must expose `.shape` and be accepted by the layers.
    sizes: sequence with one size value per layer.

  Returns:
    The output of the final layer, or `x` unchanged if `sizes` is empty.
  """
  for position, features in enumerate(sizes, start=1):
    make_layer = scope.child(ExplicitDense.create_in_scope, prefix='dense_')
    layer = make_layer(x.shape[-1], features)
    x = layer(x)
    # `position` is 1-based, so this is false only for the final layer.
    if position < len(sizes):
      x = nn.relu(x)
  return x
Beispiel #3
0
def mlp(scope: Scope, x: Array, hidden: int, out: int):
  """Applies two `nn.dense` layers with a relu between them.

  Args:
    scope: parent scope; the layers are bound to its children named
      `'hidden'` and `'out'`.
    x: input passed to the first dense layer.
    hidden: second argument given to the `'hidden'` dense layer
      (presumably its feature count -- confirm against `nn.dense`).
    out: second argument given to the `'out'` dense layer.

  Returns:
    The result of the `'out'` dense layer.
  """
  hidden_layer = scope.child(nn.dense, 'hidden')
  activations = nn.relu(hidden_layer(x, hidden))
  output_layer = scope.child(nn.dense, 'out')
  return output_layer(activations, out)
Beispiel #4
0
    dropout_rate=config.attention_dropout_rate,
    deterministic=config.deterministic,
    cache=config.decode
  )


def mlp_block(scope, inputs, config: TransformerConfig):
  """Applies Transformer MlpBlock module.

  The block is dense -> relu -> dropout -> dense -> dropout. Each layer is
  bound to its own child of `scope`.

  Args:
    scope: parent scope; must expose `child(fn)`.
    inputs: input data; must expose `.shape`.
    config: configuration read for `dtype`, `kernel_init`, `bias_init`,
      `dropout_rate`, `deterministic` and `mlp_dim`.

  Returns:
    The output of the final dropout. The second dense layer is given
    `inputs.shape[-1]` as its size argument (presumably so the trailing
    dimension matches the input -- confirm against `nn.dense`).
  """
  # Pre-bind the settings shared by both dense layers and both dropouts.
  dense = functools.partial(
      nn.dense,
      dtype=config.dtype,
      kernel_init=config.kernel_init,
      bias_init=config.bias_init)
  dropout = functools.partial(
      nn.dropout,
      rate=config.dropout_rate,
      deterministic=config.deterministic)

  x = scope.child(dense)(inputs, config.mlp_dim)
  x = nn.relu(x)
  x = scope.child(dropout)(x)
  output = scope.child(dense)(x, inputs.shape[-1])
  output = scope.child(dropout)(output)
  return output


def encoder_1d_block(scope, inputs, config: TransformerConfig):
  """Applies Encoder1DBlock module.
  Args:
    inputs: input data.
  Returns:
    output after transformer encoder block.
  """
  norm = functools.partial(nn.layer_norm, dtype=dtype)