Example #1
  def test_decoding(self, spatial_shape, attn_dims):
    bs = 2
    num_heads = 3
    num_features = 4
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    inputs = random.normal(
        key1, (bs,) + spatial_shape + (num_heads * num_features,))
    module = nn.MultiHeadDotProductAttention(
        num_heads=num_heads,
        qkv_features=num_heads * num_features,
        attention_axis=attn_dims,
        causal_mask=True,
        precision=lax.Precision.HIGHEST,
        decode=False)
    decode_module = module.clone(decode=True)

    initial_vars = decode_module.init(key2, inputs, inputs)
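    # Reference output: apply the non-decoding module to the whole sequence at once.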
    y_ref = jax.jit(lambda x: module.apply(initial_vars, x, x))(inputs)
    # feed the inputs sequentially to simulate decoding
    def body_fn(vars_in, x):
      y, vars_out = decode_module.apply(vars_in, x, x,
                                        decode=True, mutable=['cache'])
      return vars_out, y
    # scan_in_dim supports scanning multiple dims
    _, y = jax_utils.scan_in_dim(body_fn, initial_vars, inputs,
                                 axis=attn_dims, keepdims=True)

    np.testing.assert_allclose(y_ref, y, atol=1e-5)
Example #2
    def __call__(self, inputs, *, deterministic):
        """Applies Encoder1DBlock module.

    Args:
      inputs: Inputs to the layer.
      deterministic: Dropout will not be applied when set to true.

    Returns:
      output after transformer encoder block.
    """

        # Attention block.
        assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}'
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads)(x, x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = MlpBlock(mlp_dim=self.mlp_dim,
                     dtype=self.dtype,
                     dropout_rate=self.dropout_rate)(
                         y, deterministic=deterministic)

        return x + y
Example #3
  def __call__(self,
               targets,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None):
    """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder.
      encoded: input data from encoder.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output after transformer encoder-decoder block.
    """
    cfg = self.config

    # Decoder block.
    assert targets.ndim == 3
    x = nn.LayerNorm(dtype=cfg.dtype)(targets)
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        decode=cfg.decode)(x, decoder_mask)
    x = nn.Dropout(rate=cfg.dropout_rate)(
        x, deterministic=cfg.deterministic)
    x = x + targets

    # Encoder-Decoder block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    y = nn.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(
            y, encoded, encoder_decoder_mask)

    y = nn.Dropout(rate=cfg.dropout_rate)(
        y, deterministic=cfg.deterministic)
    y = y + x

    # MLP block.
    z = nn.LayerNorm(dtype=cfg.dtype)(y)
    z = MlpBlock(config=cfg)(z)

    return y + z
Example #4
    def __call__(self,
                 targets,
                 encoded,
                 decoder_mask=None,
                 encoder_decoder_mask=None):
        """Applies Transformer block.

    Args:
      targets: input data for decoder `[batch_size, ..., length, dim]`
      encoded: input data from encoder `[batch_size, ..., length2, dim2]`
      decoder_mask: decoder self-attention mask
      encoder_decoder_mask: encoder-decoder attention mask

    Returns:
      Decoded data `[batch_size, ..., length, mlp_dim]`
    """
        cfg = self.config

        # Decoder block.
        x = nn.LayerNorm(dtype=cfg.dtype)(targets)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=cfg.deterministic,
                             decode=cfg.decode)(x, decoder_mask)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + targets

        # Encoder-Decoder block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=cfg.num_heads,
            dtype=cfg.dtype,
            qkv_features=cfg.qkv_dim,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=cfg.attention_dropout_rate,
            deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
        y = nn.Dropout(rate=cfg.dropout_rate)(y,
                                              deterministic=cfg.deterministic)
        y = y + x

        # MLP block.
        z = nn.LayerNorm(dtype=cfg.dtype)(y)
        z = MLPBlock(config=cfg)(z)

        return y + z
Example #5
 def test_multihead_encoder_decoder_attention(self):
   rng = random.PRNGKey(0)
   q = jnp.ones((4, 2, 3, 5))
   kv = jnp.ones((4, 2, 3, 5))
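   # Encoder-decoder (cross) attention: query and key/value tensors are passed separately.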
   sa_module = nn.MultiHeadDotProductAttention(
       num_heads=8,
       qkv_features=16,
       kernel_init=initializers.ones,
       bias_init=initializers.zeros,
   )
   y, _ = sa_module.init_with_output(rng, q, kv)
   self.assertEqual(y.shape, q.shape)
Example #6
 def test_multihead_self_attention_w_dropout(self):
   rng = random.PRNGKey(0)
   x = jnp.ones((4, 2, 3, 5))
   sa_module = nn.MultiHeadDotProductAttention(
       num_heads=8,
       qkv_features=16,
       kernel_init=initializers.ones,
       bias_init=initializers.zeros,
       dropout_rate=0.1,
   )
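   # With dropout enabled, the module needs a dedicated 'dropout' PRNG stream alongside 'params'.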
   rng1, rng2 = random.split(rng)
   rngs = {'params': rng1, 'dropout': rng2}
   y, _ = sa_module.init_with_output(rngs, x, x)
   self.assertEqual(y.shape, x.shape)
Example #7
    def __call__(self, x):
        # TODO(lbeyer): condition on GAP(x)
        n, _, d = x.shape
        probe = self.param('probe', nn.initializers.xavier_uniform(),
                           (1, 1, d), x.dtype)
        probe = jnp.tile(probe, [n, 1, 1])

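        # MAP head: a learned probe token attends over all input tokens (attention pooling).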
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform())(probe, x)

        # TODO(lbeyer): dropout on head?
        y = nn.LayerNorm()(x)
        x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
        return x[:, 0]
Example #8
  def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
    train = nn.module.merge_param("train", self.train, train)
    transformer = self.transformer or {}
    # Convert images to patches.
    x = self.patches(images, self.hidden_size, self.patch_size, self.patch_grid)
    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
      cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)
    # Encode tokens.
    x, extra_info = BatchEnsembleEncoder(
        train=train, name="BatchEnsembleTransformer", **transformer)(
            x)
    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
      # Take the first token's output as representation as in BERT.
      x = x[:, 0]
    elif self.classifier == "gap":
      # Average all tokens.
      x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
      probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
      probe = jnp.tile(probe, [n, 1, 1])
      attention = nn.MultiHeadDotProductAttention(
          deterministic=not train,
          num_heads=transformer.get("attention", {}).get("num_heads", 1),
          kernel_init=nn.initializers.xavier_uniform())
      x = attention(inputs_q=probe, inputs_kv=x)
      y = nn.LayerNorm()(x)
      y = patch_transformer_lib.MlpBlock(
          mlp_dim=transformer["mlp_dim"],
          dropout_rate=0,
          deterministic=not train)(y)
      x = (x + y)[:, 0]
    else:
      raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
      x = identity.IdentityLayer(name="pre_logits")(x)
    else:
      x = nn.Dense(self.representation_size, name="pre_logits")(x)
      x = nn.tanh(x)

    x = nn.Dense(self.num_classes, kernel_init=self.head_kernel_init,
                 name="head")(x)
    return x, extra_info
Example #9
    def test_autoregresive_receptive_field_1d(self):
        """Tests the autoregresive self-attention receptive field."""
        rng = random.PRNGKey(0)
        rng1, rng2 = random.split(rng, num=2)

        length = 10
        dim = 1
        num_heads = 1
        input_shape = (1, length, dim)
        inputs = random.normal(rng2, input_shape)

        module = nn.MultiHeadDotProductAttention(
            num_heads=num_heads,
            kernel_init=jax.nn.initializers.ones,
            deterministic=False)

        initial_vars = module.init(rng1, inputs, inputs)
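        # Causal mask: each position may attend only to itself and earlier positions.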
        causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1]))

        def model_loss(inputs, pos):
            out = module.apply(initial_vars, inputs, inputs, causal_mask)
            assert out.shape == input_shape
            assert len(out.shape) == 3
            return out[0, pos, :].sum()

        grad_fn = jax.jit(jax.grad(model_loss))

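        # Non-zero input gradients mark the positions that can influence the output at position pos.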
        def get_receptive_field_1d(pos):
            g = grad_fn(inputs, pos)[0, :, :]
            return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1)

        for i in range(length):
            deps = get_receptive_field_1d(i)
            assert (deps[:i] == 1).all(), (
                'Receptive Field Error: Some of the '
                'previous positions are not reachable '
                'in autoregressive self-attention.')
            if i != length - 1:
                k = i + 1
                assert (deps[k:] == 0).all(), (
                    'Receptive Field Error: Some of the '
                    'future positions are reachable in '
                    'autoregressive self-attention.')
Example #10
  def __call__(self,
               inputs: jnp.ndarray,
               *,
               deterministic: Optional[bool] = None):
    """Applies Encoder1Dlock module."""
    assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"

    x = nn.LayerNorm(dtype=self.dtype, name="LayerNorm_0")(inputs)
    x = nn.MultiHeadDotProductAttention(
        dtype=self.dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        name="MultiHeadDotProductAttention_1",
        num_heads=self.num_heads,
        dropout_rate=self.attention_dropout_rate)(x, x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=self.dtype, name="LayerNorm_2")(x)
    y = self.mlp_class(name="MlpBlock_3")(y, deterministic=deterministic)

    return x + y
Example #11
  def __call__(self,
               targets,
               encoded,
               decoder_mask = None,
               encoder_decoder_mask = None,
               decoder_relative_position = None,
               encoder_decoder_relative_position = None):
    """Applies Transformer block.

    Args:
      targets: input data for decoder `[batch_size, ..., length, dim]`
      encoded: input data from encoder `[batch_size, ..., length2, dim2]`
      decoder_mask: decoder self-attention mask
      encoder_decoder_mask: encoder-decoder attention mask
      decoder_relative_position: decoder relative positions tensor
          `[batch_sizes..., length2, length2]`
      encoder_decoder_relative_position: encoder-decoder relative positions tensor
          `[batch_sizes..., length2, length]`

    Returns:
      Decoded data `[batch_size, ..., length2, mlp_dim]`
    """
    cfg = self.config

    # Decoder block.
    x = nn.LayerNorm(dtype=cfg.dtype)(targets)
    if cfg.use_relative_attention:
      x = relative_attention.RelativeSelfAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic,
          bidirectional=self.bidirectional_attention,
          num_relative_position_buckets=self.num_relative_position_buckets,
          max_distance=self.max_distance)(
              x, decoder_mask, decoder_relative_position)
    else:
      x = nn.SelfAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic)(x, decoder_mask)

    x = nn.Dropout(rate=cfg.dropout_rate)(
        x, deterministic=cfg.deterministic)
    x = x + targets

    # Encoder-Decoder block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    if self.relative_cross_attention:
      y = relative_attention.RelativeMultiHeadDotProductAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic,
          bidirectional=self.bidirectional_cross_attention,
          num_relative_position_buckets=(
              self.num_relative_position_buckets_cross_attention),
          max_distance=self.max_distance_cross_attention)(
              y, encoded, encoder_decoder_mask,
              encoder_decoder_relative_position)
    else:
      y = nn.MultiHeadDotProductAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)

    y = nn.Dropout(rate=cfg.dropout_rate)(
        y, deterministic=cfg.deterministic)
    y = y + x

    # MLP block.
    z = nn.LayerNorm(dtype=cfg.dtype)(y)
    z = MLPBlock(config=cfg)(z)

    return y + z
Example #12
    def __call__(self,
                 images: jnp.ndarray,
                 train: Optional[bool] = None,
                 mean_field_factor: float = -1.,
                 **gp_kwargs):
        train = nn.module.merge_param("train", self.train, train)
        transformer = self.transformer or {}
        # Convert images to patches.
        x = self.patches(images, self.hidden_size, self.patch_size,
                         self.patch_grid)
        # Add "class" token if necessary.
        n, _, c = x.shape
        if self.classifier == "token":
            cls = self.param("cls", nn.initializers.zeros,
                             (1, 1, self.hidden_size))
            cls = jnp.tile(cls, [n, 1, 1])
            x = jnp.concatenate([cls, x], axis=1)
        # Encode tokens.
        x, extra_info = vit_batchensemble.BatchEnsembleEncoder(
            train=train, name="Transformer", **transformer)(x)
        # Reduce tokens to a single vector representation.
        if self.classifier == "token":
            # Take the first token's output as representation as in BERT.
            x = x[:, 0]
        elif self.classifier == "gap":
            # Average all tokens.
            x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
        elif self.classifier == "map":
            probe = self.param("probe", nn.initializers.xavier_uniform(),
                               (1, 1, c))
            # x may have been subject to tiling, n can be different from x.shape[0].
            probe = jnp.tile(probe, [x.shape[0], 1, 1])
            attention = nn.MultiHeadDotProductAttention(
                deterministic=not train,
                num_heads=transformer.get("attention", {}).get("num_heads", 1),
                kernel_init=nn.initializers.xavier_uniform())
            x = attention(inputs_q=probe, inputs_kv=x)
            y = nn.LayerNorm()(x)
            y = vit.MlpBlock(mlp_dim=transformer["mlp_dim"],
                             dropout_rate=0)(y, deterministic=not train)
            x = (x + y)[:, 0]
        else:
            raise ValueError(f"Unknown classifier: {self.classifier}")

        if self.representation_size is None:
            x = vit.IdentityLayer(name="pre_logits")(x)
            extra_info["pre_logits"] = x
        else:
            x = nn.Dense(self.representation_size, name="pre_logits")(x)
            extra_info["pre_logits"] = x
            x = nn.tanh(x)

        if self.use_gp_layer:
            x_gp = self.gp_layer(x, **gp_kwargs)
            # Gaussian process layer output: a tuple of logits, covmat, and optionally
            # random features.
            extra_info["covmat"] = x_gp[1]
            if len(x_gp) > 2:
                extra_info["random_features"] = x_gp[2]
            if train:
                x = x_gp[0]
            else:
                # During inference, compute posterior mean by adjusting the original
                # logits with predictive uncertainty.
                x = ed.nn.utils.mean_field_logits(
                    logits=x_gp[0],
                    covmat=x_gp[1],
                    mean_field_factor=mean_field_factor)
        else:
            x = nn.Dense(self.num_classes,
                         kernel_init=self.head_kernel_init,
                         name="batchensemble_head")(x)
        return x, extra_info
Example #13
  def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
    train = nn.module.merge_param("train", self.train, train)
    transformer = self.transformer or {}
    # Convert images to patches.
    x = self.embed(images, self.hidden_size, self.patches.size)
    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
      cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)
    # Encode tokens.
    x, extra_info = BatchEnsembleEncoder(
        train=train, name="Transformer", **transformer)(
            x)
    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
      # Take the first token's output as representation as in BERT.
      x = x[:, 0]
    elif self.classifier == "gap":
      # Average all tokens.
      x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
      probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
      # x may have been subject to tiling, n can be different from x.shape[0].
      probe = jnp.tile(probe, [x.shape[0], 1, 1])
      attention = nn.MultiHeadDotProductAttention(
          deterministic=not train,
          num_heads=transformer.get("attention", {}).get("num_heads", 1),
          kernel_init=nn.initializers.xavier_uniform())
      x = attention(inputs_q=probe, inputs_kv=x)
      y = nn.LayerNorm()(x)
      y = vit.MlpBlock(
          mlp_dim=transformer["mlp_dim"], dropout_rate=0)(
              y, deterministic=not train)
      x = (x + y)[:, 0]
    else:
      raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
      x = IdentityLayer(name="pre_logits")(x)
      extra_info["pre_logits"] = x
    else:
      x = ed.nn.DenseBatchEnsemble(
          self.representation_size,
          self.transformer.get("ens_size"),
          activation=None,
          alpha_init=ed.nn.utils.make_sign_initializer(
              self.transformer.get("random_sign_init")),
          gamma_init=ed.nn.utils.make_sign_initializer(
              self.transformer.get("random_sign_init")),
          name="pre_logits")(x)
      extra_info["pre_logits"] = x
      x = nn.tanh(x)

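    # BatchEnsemble head: rank-1 alpha/gamma factors give each ensemble member its own classifier.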
    x = ed.nn.DenseBatchEnsemble(
        self.num_classes,
        self.transformer.get("ens_size"),
        activation=None,
        alpha_init=ed.nn.utils.make_sign_initializer(
            self.transformer.get("random_sign_init")),
        gamma_init=ed.nn.utils.make_sign_initializer(
            self.transformer.get("random_sign_init")),
        kernel_init=self.head_kernel_init,
        name="batchensemble_head")(x)
    return x, extra_info
Example #14
    def __call__(self,
                 inputs,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 inputs_kv=None):
        """Applies EncoderDecoder1DBlock module.

    Args:
      inputs: input data for decoder.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      inputs_kv: key/value inputs for cross-attention (used when not self-attention).

    Returns:
      output after transformer encoder-decoder block.
    """
        cfg = self.config

        # Decoder block.
        assert inputs.ndim == 3
        # assert decoder_mask.ndim == 4
        if cfg.use_layernorm:
            x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        else:
            x = inputs
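        # Either decoder self-attention or cross-attention over the separately provided inputs_kv.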
        if self.is_self_att:
            x = nn.SelfAttention(num_heads=cfg.num_heads,
                                 out_features=self.out_features,
                                 dtype=cfg.dtype,
                                 qkv_features=cfg.qkv_dim,
                                 kernel_init=cfg.kernel_init,
                                 bias_init=cfg.bias_init,
                                 use_bias=False,
                                 broadcast_dropout=False,
                                 dropout_rate=cfg.attention_dropout_rate,
                                 deterministic=cfg.deterministic,
                                 attention_fn=self.attention_fn,
                                 decode=cfg.decode)(x, decoder_mask)
        else:
            if cfg.use_layernorm:
                x_kv = nn.LayerNorm(dtype=cfg.dtype)(inputs_kv)
            else:
                x_kv = inputs_kv
            x = nn.MultiHeadDotProductAttention(
                num_heads=cfg.num_heads,
                out_features=self.out_features,
                dtype=cfg.dtype,
                qkv_features=cfg.qkv_dim,
                kernel_init=cfg.kernel_init,
                bias_init=cfg.bias_init,
                use_bias=False,
                broadcast_dropout=False,
                dropout_rate=cfg.attention_dropout_rate,
                deterministic=cfg.deterministic,
                attention_fn=self.attention_fn,
                decode=cfg.decode)(x, x_kv, mask=decoder_mask)

        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + inputs

        # MLP block.
        if cfg.use_layernorm:
            z = nn.LayerNorm(dtype=cfg.dtype)(x)
        else:
            z = x
        z = MlpBlock(config=cfg)(z)

        return x + z
Example #15
    def __call__(self,
                 targets,
                 encoded,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 train=True):
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      train: if it is training.

    Returns:
      output after transformer encoder-decoder block.
    """
        # Decoder block.
        assert targets.ndim == 3
        if self.normalizer in [
                'batch_norm', 'layer_norm', 'pre_layer_norm', 'none'
        ]:
            maybe_pre_normalize = model_utils.get_normalizer(
                self.normalizer, train)
            maybe_post_normalize = model_utils.get_normalizer('none', train)
        elif self.normalizer == 'post_layer_norm':
            maybe_pre_normalize = model_utils.get_normalizer('none', train)
            maybe_post_normalize = model_utils.get_normalizer(
                self.normalizer, train)
        else:
            raise ValueError('Unsupported normalizer: {}'.format(
                self.normalizer))

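        # 'Pre' normalizers run before each sublayer; 'post' normalizers run after the residual connection.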
        x = maybe_pre_normalize()(targets)
        x = nn.SelfAttention(num_heads=self.num_heads,
                             dtype=self.dtype,
                             qkv_features=self.qkv_dim,
                             kernel_init=self.dec_self_attn_kernel_init_fn,
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=self.attention_dropout_rate,
                             decode=self.decode,
                             name='DecoderSelfAttention')(
                                 x, decoder_mask, deterministic=not train)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
        x = x + targets

        x = maybe_post_normalize()(x)
        # Encoder-Decoder block.
        y = maybe_pre_normalize()(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            kernel_init=self.dec_cross_attn_kernel_init_fn,
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate)(y,
                                                      encoded,
                                                      encoder_decoder_mask,
                                                      deterministic=not train)

        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)
        y = y + x

        y = maybe_post_normalize()(y)
        # MLP block.
        z = maybe_pre_normalize()(y)
        z = MlpBlock(mlp_dim=self.mlp_dim,
                     dtype=self.dtype,
                     dropout_rate=self.dropout_rate,
                     name='MLPBlock')(z, train=train)

        res = y + z
        return maybe_post_normalize()(res)
Example #16
    def __call__(
            self,
            targets,
            encoded,
            inputs_segmentation=None,  # REFACTOR
            targets_segmentation=None,  # REFACTOR
            padding_mask=None,  # REFACTOR
            key_padding_mask=None):  # REFACTOR
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      padding_mask: bool, mask padding tokens
      key_padding_mask: bool, mask padding tokens

    Returns:
      output after transformer encoder-decoder block.
    """
        cfg = self.config

        # Decoder block.
        assert targets.ndim == 3
        x = nn.LayerNorm(dtype=cfg.dtype)(targets)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=True,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=cfg.deterministic,
                             decode=cfg.decode)(
                                 x,
                                 padding_mask=padding_mask,
                                 segmentation=targets_segmentation)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + targets

        # Encoder-Decoder block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=cfg.num_heads,
            dtype=cfg.dtype,
            qkv_features=cfg.qkv_dim,
            attention_axis=(1, ),
            causal_mask=False,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=cfg.attention_dropout_rate,
            deterministic=cfg.deterministic)(
                y,
                encoded,
                padding_mask=padding_mask,
                key_padding_mask=key_padding_mask,
                segmentation=targets_segmentation,
                key_segmentation=inputs_segmentation)

        y = nn.Dropout(rate=cfg.dropout_rate)(y,
                                              deterministic=cfg.deterministic)
        y = y + x

        # MLP block.
        z = nn.LayerNorm(dtype=cfg.dtype)(y)
        z = MlpBlock(config=cfg)(z)

        return y + z