Exemplo n.º 1
0
    def decode(
            self,
            encoded,
            inputs,  # only needed for masks
            targets,
            targets_positions=None,
            inputs_segmentation=None,
            targets_segmentation=None):
        """Run the decoder branch over the encoder output and target tokens.

        Args:
          encoded: output of the encoder branch.
          inputs: encoder-side token ids; only used here to build the
            encoder-decoder padding mask (`inputs > 0`).
          targets: decoder-side token ids.
          targets_positions: positions of target tokens within packed examples.
          inputs_segmentation: input segment ids for packed examples. When it
            is given, block-diagonal segment masks are added to both masks.
          targets_segmentation: target segment ids for packed examples; used
            together with `inputs_segmentation`.

        Returns:
          The decoder logits, cast to `self.config.dtype`.
        """
        cfg = self.config
        mask_dtype = cfg.dtype

        if cfg.decode:
            # Fast autoregressive decoding: no self-attention mask is built;
            # every query may attend to every non-padding input token.
            self_mask = None
            cross_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0, dtype=mask_dtype)
        else:
            target_valid = targets > 0
            self_mask = nn.combine_masks(
                nn.make_attention_mask(target_valid, target_valid,
                                       dtype=mask_dtype),
                nn.make_causal_mask(targets, dtype=mask_dtype))
            cross_mask = nn.make_attention_mask(target_valid, inputs > 0,
                                                dtype=mask_dtype)

        # Packed examples: restrict attention to tokens of the same segment.
        if inputs_segmentation is not None:
            self_mask = nn.combine_masks(
                self_mask,
                nn.make_attention_mask(targets_segmentation,
                                       targets_segmentation,
                                       jnp.equal,
                                       dtype=mask_dtype))
            cross_mask = nn.combine_masks(
                cross_mask,
                nn.make_attention_mask(targets_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=mask_dtype))

        logits = self.decoder(encoded,
                              targets,
                              targets_positions=targets_positions,
                              decoder_mask=self_mask,
                              encoder_decoder_mask=cross_mask)
        return logits.astype(cfg.dtype)
Exemplo n.º 2
0
    def decode(self, programs, encoded, encoded_padding_mask):
        """Decode `programs` conditioned on the encoded specification.

        Args:
          programs: 2-D array of program token ids.
          encoded: 4-D encoder output; its num_io dimension is collapsed with
            `base_models.flatten_num_io_dim` before decoding.
          encoded_padding_mask: padding mask for `encoded`, collapsed the same
            way.

        Returns:
          The result of `self.decoder(...)` on the (possibly shifted) programs.
        """
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_padding = base_models.flatten_num_io_dim(encoded_padding_mask)

        # The decoder padding mask below is built from the pre-shift programs.
        unshifted = programs
        if cfg.shift:
            decoder_input = base_models.shift_right(programs, cfg.bos_token)
        else:
            decoder_input = programs

        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1]
            # and cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            self_mask = None
            cross_mask = nn.make_attention_mask(
                jnp.ones_like(decoder_input),
                flat_padding,
                dtype=cfg.dtype)
        else:
            causal = nn.make_causal_mask(decoder_input, dtype=cfg.dtype)
            is_bos = decoder_input == cfg.bos_token
            # BOS tokens attend to all previous BOS tokens.
            bos_mask = nn.combine_masks(
                nn.make_attention_mask(is_bos, is_bos, dtype=cfg.dtype),
                causal)
            # Program tokens attend to all previous tokens in their partial program.
            partial_mask = nn.combine_masks(
                make_partial_program_mask(decoder_input,
                                          bos_token=cfg.bos_token,
                                          dtype=cfg.dtype),
                causal)
            not_pad = unshifted > 0
            self_mask = nn.combine_masks(
                nn.make_attention_mask(not_pad, not_pad, dtype=cfg.dtype),
                jnp.logical_or(bos_mask, partial_mask))
            cross_mask = nn.make_attention_mask(
                decoder_input > 0, flat_padding, dtype=cfg.dtype)

        return self.decoder(decoder_input, flat_encoded, self_mask, cross_mask)
Exemplo n.º 3
0
  def decode(self,
             programs,
             encoded,
             encoded_padding_mask):
    """Run the decoder over `programs`, attending to the encoded specification.

    Args:
      programs: 2-D array of program token ids.
      encoded: 4-D encoder output; its num_io dimension is flattened away with
        `flatten_num_io_dim` before decoding.
      encoded_padding_mask: padding mask for `encoded`, flattened likewise.

    Returns:
      The result of `self.decoder(...)`.
    """
    cfg = self.config

    assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                ' but it is: %d' % programs.ndim)
    assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                               ' but it is: %d' % encoded.ndim)

    flat_encoded = flatten_num_io_dim(encoded)
    flat_padding = flatten_num_io_dim(encoded_padding_mask)

    if not cfg.decode:
      not_pad = programs > 0
      self_mask = nn.combine_masks(
          nn.make_attention_mask(not_pad, not_pad, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      cross_mask = nn.make_attention_mask(
          not_pad, flat_padding, dtype=cfg.dtype)
    else:
      # Fast decode with caching: programs shape == [batch_size, 1] and
      # cfg.shift = False; no self-attention mask is built here.
      self_mask = None
      cross_mask = nn.make_attention_mask(
          jnp.ones_like(programs), flat_padding, dtype=cfg.dtype)

    return self.decoder(programs, flat_encoded, self_mask, cross_mask)
Exemplo n.º 4
0
    def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
        """Run the encoder branch of the Transformer.

        Args:
          inputs: input token ids; ids <= 0 are masked out as padding.
          inputs_positions: token positions within packed examples.
          inputs_segmentation: segment ids for packed examples. When given,
            attention is also restricted to tokens of the same segment.

        Returns:
          The output of `self.encoder`.
        """
        mask_dtype = self.config.dtype
        is_token = inputs > 0
        encoder_mask = nn.make_attention_mask(is_token, is_token,
                                              dtype=mask_dtype)
        if inputs_segmentation is not None:
            same_segment = nn.make_attention_mask(inputs_segmentation,
                                                  inputs_segmentation,
                                                  jnp.equal,
                                                  dtype=mask_dtype)
            encoder_mask = nn.combine_masks(encoder_mask, same_segment)
        return self.encoder(inputs,
                            inputs_positions=inputs_positions,
                            encoder_mask=encoder_mask)
Exemplo n.º 5
0
  def decode(self,
             programs,
             encoded,
             encoded_padding_mask):
    """Applies decoder on programs and encoded specification.

    Programs may come with or without a num_partial dimension; the latter form
    is used for beam search.

    Args:
      programs: program token ids, shape [batch_size, (num_partial), length],
        i.e. 2-D or 3-D.
      encoded: encoder output with exactly two more dimensions than
        `programs`. Its num_io axis (axis 1 for 2-D programs, axis 2 for 3-D)
        is flattened before decoding.
      encoded_padding_mask: padding mask for `encoded`, flattened along the
        same axis.

    Returns:
      The result of `self.decoder(...)`.

    Raises:
      AssertionError: if `programs` is not 2-D/3-D or `encoded` has the wrong
        rank.
    """
    cfg = self.config

    # Allow for decoding without num_partial dimension for beam search.
    # programs shape == [batch_size, (num_partial), length]
    assert programs.ndim in [2, 3], ('Number of program dimensions should be'
                                     ' 2 or 3, but it is: %d' % programs.ndim)
    assert encoded.ndim == programs.ndim + 2, (
        'Number of encoded dimensions should be %d, but it is: %d' %
        (programs.ndim + 2, encoded.ndim))

    # Collapse num_io dimension.
    num_io_axis = 1 if programs.ndim == 2 else 2
    flat_encoded = base_models.flatten_num_io_dim(encoded, axis=num_io_axis)
    flat_encoded_padding_mask = base_models.flatten_num_io_dim(
        encoded_padding_mask, axis=num_io_axis)

    # Make attention masks.
    if cfg.decode:
      # For fast decode with caching, programs shape == [batch_size, 1] and
      # cfg.shift = False, cfg.decode = True.
      decoder_mask = None
      encoder_decoder_mask = nn.make_attention_mask(
          jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    else:
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      encoder_decoder_mask = nn.make_attention_mask(
          programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

    return self.decoder(
        programs, flat_encoded, decoder_mask, encoder_decoder_mask)
Exemplo n.º 6
0
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252

        Args:
            key: projected keys for the new step(s). On the first call their
                shape/dtype define the zero-initialised "cached_key" variable.
            value: projected values; likewise define "cached_value".
            query: projected queries; only `query.shape[1]` is read, as the
                number of new positions written into the cache.
            attention_mask: mask combined with the cache-validity mask below.

        Returns:
            `(key, value, attention_mask)`. When no "cached_key" existed yet,
            the three inputs are returned unchanged (the call only sets up the
            cache variables). Otherwise key/value are the updated full caches
            and the mask additionally hides cache slots at or beyond the new
            fill position.
        """
        # detect if we're initializing by absence of existing cache data.
        # Read before the self.variable(...) calls below, which may create it.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape,
                                   key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros,
                                     value.shape, value.dtype)
        # Write position for the next update; starts at 0.
        cache_index = self.variable("cache", "cache_index",
                                    lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # The unpacking assumes a cache layout of
            # (*batch_dims, max_length, num_heads, depth_per_head).
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            # Offset cur_index along the max_length axis, 0 on every other axis.
            indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value,
                                             indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query
            # position should only attend to those key positions that have
            # already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask
Exemplo n.º 7
0
    def decode(self,
               encoded,
               inputs,
               targets,
               targets_positions=None,
               inputs_segmentation=None,
               targets_segmentation=None,
               train=False):
        """Run the decoder over `targets`, attending to `encoded`.

        Args:
          encoded: encoder output.
          inputs: encoder-side token ids; used only for the encoder-decoder
            padding mask (`inputs > 0`).
          targets: decoder-side token ids.
          targets_positions: target positions within packed examples.
          inputs_segmentation: input segment ids for packed examples. When
            given, segment masks built from it and `targets_segmentation` are
            added.
          targets_segmentation: target segment ids for packed examples.
          train: forwarded to `self.decoder`.

        Returns:
          The logits returned by `self.decoder` (not re-cast).
        """
        # Masks are bfloat16 when self.use_bfloat16 is set, float32 otherwise.
        mask_dtype = jnp.float32
        if self.use_bfloat16:
            mask_dtype = jnp.bfloat16

        if not self.should_decode:
            self_mask = nn.combine_masks(
                nn.make_attention_mask(targets > 0, targets > 0,
                                       dtype=mask_dtype),
                nn.make_causal_mask(targets, dtype=mask_dtype))
            cross_mask = nn.make_attention_mask(targets > 0, inputs > 0,
                                                dtype=mask_dtype)
        else:
            # Fast autoregressive decoding only uses the encoder-decoder mask.
            self_mask = None
            cross_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0, dtype=mask_dtype)

        if inputs_segmentation is not None:
            def _same_segment(query_seg, key_seg):
                # Block-diagonal mask: True where both ids are equal.
                return nn.make_attention_mask(query_seg, key_seg, jnp.equal,
                                              dtype=mask_dtype)

            self_mask = nn.combine_masks(
                self_mask,
                _same_segment(targets_segmentation, targets_segmentation))
            cross_mask = nn.combine_masks(
                cross_mask,
                _same_segment(targets_segmentation, inputs_segmentation))

        return self.decoder(encoded,
                            targets,
                            targets_positions=targets_positions,
                            decoder_mask=self_mask,
                            encoder_decoder_mask=cross_mask,
                            train=train)
Exemplo n.º 8
0
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        """Multi-head self-attention over `hidden_states`.

        Args:
          hidden_states: activations projected by q_proj/k_proj/v_proj.
          attention_mask: optional padding mask over key positions; expanded
            on axes (-3, -2) and, when `self.causal`, combined with the causal
            mask.
          deterministic: when False and `self.dropout > 0`, attention dropout
            uses the "dropout" RNG stream.
          output_attentions: if True, also return the attention weights.

        Returns:
          `(attn_output,)`, or `(attn_output, attn_weights)` when
          `output_attentions` is True.
        """
        projected = [proj(hidden_states)
                     for proj in (self.q_proj, self.k_proj, self.v_proj)]
        query, key, value = [self._split_heads(t) for t in projected]

        if self.causal:
            q_len, k_len = query.shape[1], key.shape[1]
            causal_mask = self.causal_mask[:, :, k_len - q_len:k_len, :k_len]
        else:
            causal_mask = None

        if attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            if causal_mask is not None:
                attention_mask = combine_masks(attention_mask, causal_mask,
                                               dtype="i4")
        else:
            attention_mask = causal_mask

        # Turn the mask into an additive bias: 0 to keep, -1e4 to suppress.
        attention_bias = None
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )

        use_dropout = not deterministic and self.dropout > 0.0
        dropout_rng = self.make_rng("dropout") if use_dropout else None

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = self.out_proj(self._merge_heads(
            jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)))

        if output_attentions:
            return (attn_output, attn_weights)
        return (attn_output,)
Exemplo n.º 9
0
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Run the decoder-only TransformerLM over `inputs`.

        Args:
          inputs: token ids fed to the decoder.
          inputs_positions: token positions within packed examples.
          inputs_segmentation: segment ids for packed examples. When given,
            attention is restricted to tokens in the same segment.

        Returns:
          Decoder logits cast to `self.config.dtype`.
        """
        cfg = self.config
        dtype = cfg.dtype

        # With cfg.decode (fast autoregressive decoding) no padding/causal
        # mask is built here.
        decoder_mask = None
        if not cfg.decode:
            valid = inputs > 0
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(valid, valid, dtype=dtype),
                nn.make_causal_mask(inputs, dtype=dtype))

        if inputs_segmentation is not None:
            same_segment = nn.make_attention_mask(inputs_segmentation,
                                                  inputs_segmentation,
                                                  jnp.equal,
                                                  dtype=dtype)
            decoder_mask = nn.combine_masks(decoder_mask, same_segment)

        decoder = Decoder(config=cfg, shared_embedding=None, name='decoder')
        logits = decoder(
            inputs,
            inputs_positions=inputs_positions,
            inputs_segmentation=inputs_segmentation,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None)
        return logits.astype(dtype)
Exemplo n.º 10
0
    def decode(self, programs, latents, encoded, latents_padding_mask,
               encoded_padding_mask):
        """Decode `programs`, attending to the encoded I/O examples and latents.

        The flattened I/O encoding and the position-embedded latents are
        concatenated along axis 1 into one cross-attention memory. The two
        padding masks are concatenated along the key axis in the same order.

        Args:
          programs: 2-D program token ids.
          latents: 3-D latents, passed through `self.latent_pos_emb`.
          encoded: 4-D encoder output; its num_io dimension is flattened.
          latents_padding_mask: padding mask over latent positions.
          encoded_padding_mask: padding mask for `encoded`, flattened likewise.

        Returns:
          The result of `self.decoder(...)`.
        """
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert latents.ndim == 3, ('Number of latents dimensions should be 3,'
                                   ' but it is: %d' % latents.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        flat_encoded = models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = models.flatten_num_io_dim(
            encoded_padding_mask)

        # I/O encoding first, latents second; the masks follow the same order.
        memory = jnp.concatenate(
            [flat_encoded, self.latent_pos_emb(latents)], axis=1)

        if cfg.decode:
            # Fast decode with caching: programs shape == [batch_size, 1] and
            # cfg.shift = False, so no self-attention mask is built.
            query_valid = jnp.ones_like(programs)
            decoder_mask = None
        else:
            query_valid = programs > 0
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(query_valid, query_valid,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))

        latent_decoder_mask = nn.make_attention_mask(
            query_valid, latents_padding_mask, dtype=cfg.dtype)
        io_decoder_mask = nn.make_attention_mask(
            query_valid, flat_encoded_padding_mask, dtype=cfg.dtype)
        encoder_decoder_mask = jnp.concatenate(
            [io_decoder_mask, latent_decoder_mask], axis=-1)

        return self.decoder(programs, memory, decoder_mask,
                            encoder_decoder_mask)
Exemplo n.º 11
0
 def get_attention_masks(self, inputs, targets):
     """Build the decoder self-attention and encoder-decoder attention masks.

     Args:
       inputs: encoder-side token ids; ids <= 0 are masked out as padding.
       targets: decoder-side token ids; ids <= 0 are masked out as padding.

     Returns:
       A `(decoder_mask, encoder_decoder_mask)` pair. In `cfg.decode` mode the
       decoder mask is None and only the encoder-decoder padding mask is built.
       Every mask built here uses `dtype=cfg.dtype`.
     """
     cfg = self.config
     if cfg.decode:
         decoder_mask = None
         # Pass cfg.dtype so the mask dtype matches the non-decode branch
         # instead of falling back to make_attention_mask's default.
         encoder_decoder_mask = nn.make_attention_mask(
             jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype)
     else:
         decoder_mask = nn.combine_masks(
             nn.make_attention_mask(targets > 0,
                                    targets > 0,
                                    dtype=cfg.dtype),
             nn.make_causal_mask(targets, dtype=cfg.dtype))
         encoder_decoder_mask = nn.make_attention_mask(targets > 0,
                                                       inputs > 0,
                                                       dtype=cfg.dtype)
     return decoder_mask, encoder_decoder_mask
Exemplo n.º 12
0
    def encode(self,
               inputs,
               inputs_positions=None,
               inputs_segmentation=None,
               train=False):
        """Run the encoder over `inputs`.

        Args:
          inputs: input token ids; ids <= 0 are masked out as padding.
          inputs_positions: token positions within packed examples.
          inputs_segmentation: segment ids for packed examples. When given,
            attention is limited to tokens of the same segment.
          train: forwarded to `self.encoder`.

        Returns:
          The encoder output.
        """
        not_pad = inputs > 0
        # NOTE(review): the mask dtypes are taken from the id arrays
        # themselves (inputs.dtype / inputs_segmentation.dtype) rather than a
        # model/config dtype; confirm this is intended.
        encoder_mask = nn.make_attention_mask(not_pad, not_pad,
                                              dtype=inputs.dtype)
        if inputs_segmentation is not None:
            segment_mask = nn.make_attention_mask(
                inputs_segmentation,
                inputs_segmentation,
                jnp.equal,
                dtype=inputs_segmentation.dtype)
            encoder_mask = nn.combine_masks(encoder_mask, segment_mask)
        return self.encoder(inputs,
                            inputs_positions=inputs_positions,
                            encoder_mask=encoder_mask,
                            train=train)
Exemplo n.º 13
0
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Causal multi-head self-attention with optional key/value caching.

        Args:
          hidden_states: input activations. `self.c_attn` projects them and
            the result is split into query, key and value along axis 2.
          attention_mask: optional padding mask over key positions. It is
            expanded on axes (-3, -2), broadcast to the causal mask shape and
            combined with it. When None, only the causal mask is applied.
          deterministic: when False and `self.config.attn_pdrop > 0`,
            attention dropout uses the "dropout" RNG stream; also forwarded to
            `self.resid_dropout`.
          init_cache: when True, keys/values go through
            `self._concatenate_to_cache` even if no cache exists yet.
          output_attentions: currently unused; attention weights are not
            returned (see the TODO at the end).

        Returns:
          A 1-tuple `(attn_output,)`.
        """
        qkv_out = self.c_attn(hidden_states)
        query, key, value = jnp.split(qkv_out, 3, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            # A cache exists: take causal-mask rows starting at the current
            # cache index, spanning the full cached key length.
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        if attention_mask is None:
            # attention_mask defaults to None; jnp.expand_dims(None, ...) would
            # fail, so fall back to the causal mask alone.
            attention_mask = causal_mask
        else:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # Transform the mask into an additive float bias:
        # 0 where attending is allowed, -1e4 where it is masked out.
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
        )

        # usual dot product attention
        attn_output = dot_product_attention(
            query,
            key,
            value,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        # TODO: at the moment it's not possible to retrieve attn_weights from
        # dot_product_attention, but should be in the future -> add functionality then

        return (attn_output, )
Exemplo n.º 14
0
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Multi-head attention: self-attention, or cross-attention when
        `key_value_states` is given.

        Input shape: Batch x Time x Channel

        Args:
          hidden_states: query-side activations (projected by q_proj, and also
            by k_proj/v_proj for self-attention).
          key_value_states: if given, keys and values are projected from it
            instead of `hidden_states` (cross-attention).
          attention_mask: optional padding mask. It is expanded on axes
            (-3, -2) and, when `self.causal`, broadcast to and combined with
            the causal mask.
          init_cache: when True (and `self.causal`), keys/values go through
            `self._concatenate_to_cache` even if no cache exists yet.
          deterministic: when False and `self.dropout > 0`, attention dropout
            uses the "dropout" RNG stream.

        Returns:
          A 2-tuple `(attn_output, attn_weights)`.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[
                1]
            if self.has_variable("cache", "cached_key"):
                # A cache exists: take causal-mask rows starting at the current
                # cache index, spanning the full cached key length.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"][
                    "cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length))
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :
                                               key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) +
                                           causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key")
                            or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask)

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias:
            # 0 where allowed, the dtype's most negative finite value otherwise
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape,
                         jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Weighted sum of values per head, then merge heads and project out.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights,
                                 value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
Exemplo n.º 15
0
    def __call__(
        self,
        hidden_states,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Multi-head attention using the fused `c_attn` projection.

        Acts as self-attention, or as cross-attention when `key_value_states`
        is given.

        Args:
          hidden_states: query-side activations. For self-attention,
            `self.c_attn` projects them to concatenated q/k/v (split into 3
            along axis 2). For cross-attention, `self.q_attn` projects only
            the queries.
          key_value_states: if given, `self.c_attn` projects it into keys and
            values (split into 2 along axis 2).
          attention_mask: optional padding mask. It is expanded on axes
            (-3, -2) and, when `self.causal`, broadcast to and combined with
            the causal mask.
          deterministic: when False and `self.config.attn_pdrop > 0`,
            attention dropout uses the "dropout" RNG stream; also forwarded to
            `self.resid_dropout`.
          init_cache: when True (and `self.causal`), keys/values go through
            `self._concatenate_to_cache` even if no cache exists yet.
          output_attentions: if True, also return the attention weights.

        Returns:
          `(attn_output,)`, or `(attn_output, attn_weights)` when
          `output_attentions` is True.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        if not is_cross_attention:
            qkv_out = self.c_attn(hidden_states)
            query, key, value = jnp.split(qkv_out, 3, axis=2)
        else:
            q_out = self.q_attn(hidden_states)
            # Splitting into a single piece just yields [q_out].
            (query, ) = jnp.split(q_out, 1, axis=2)
            kv_out = self.c_attn(key_value_states)
            key, value = jnp.split(kv_out, 2, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.causal:
            if self.has_variable("cache", "cached_key"):
                # A cache exists: take causal-mask rows starting at the current
                # cache index, spanning the full cached key length.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"][
                    "cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length))
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :
                                               key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) +
                                           causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key")
                            or init_cache):
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        # (0 where allowed, -1e4 where masked out; None means no bias)
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )
        else:
            attention_bias = None

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Weighted sum of values per head, then merge heads, project, dropout.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Exemplo n.º 16
0
    def __call__(self,
                 inputs,
                 train,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Computes vocabulary logits with a decoder-only Transformer.

        Args:
          inputs: integer token ids of rank 2, (batch, len). Ids equal to 0 are
            excluded from attention by the padding mask.
          train: bool; when False dropout is deterministic. Also forwarded to
            each Transformer1DBlock and to the final normalizer.
          inputs_positions: optional subsequence positions for packed examples,
            passed to the output position-embedding layer.
          inputs_segmentation: optional segment ids for packed examples; when
            given, tokens may only attend to tokens with an equal segment id.

        Returns:
          Logits over the vocabulary, cast to the model dtype.
        """
        assert inputs.ndim == 2  # (batch, len)
        dtype = utils.dtype_from_str(self.model_dtype)

        # Self-attention mask. Fast autoregressive decoding builds none;
        # otherwise it is (non-padding query AND non-padding key) AND causal.
        if self.decode:
            self_attn_mask = None
        else:
            not_padding = inputs > 0
            self_attn_mask = nn.combine_masks(
                nn.make_attention_mask(not_padding, not_padding, dtype=dtype),
                nn.make_causal_mask(inputs, dtype=dtype))

        # Packed data: add a block-diagonal same-segment restriction.
        if inputs_segmentation is not None:
            same_segment = nn.make_attention_mask(inputs_segmentation,
                                                  inputs_segmentation,
                                                  jnp.equal,
                                                  dtype=dtype)
            self_attn_mask = nn.combine_masks(self_attn_mask, same_segment)

        token_ids = inputs.astype('int32')
        if not self.decode:
            # Outside step-by-step decoding the inputs go through shift_inputs
            # (segment-aware when segmentation is provided).
            token_ids = shift_inputs(token_ids, segment_ids=inputs_segmentation)

        # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
        # indices for dataset_lib:proteins_test. This will break when jnp.take() is
        # updated to return NaNs for out-of-bounds indices.
        # Debug why this is the case.
        token_ids = jnp.clip(token_ids, 0, self.vocab_size - 1)

        # NOTE: submodules below are constructed in the same order as before so
        # Flax's auto-generated names (Embed_0, Dropout_0, ...) stay stable.
        if self.shared_embedding is not None:
            embedder = self.shared_embedding
        else:
            embedder = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))

        hidden = embedder(token_ids)
        position_layer = AddPositionEmbs(
            max_len=self.max_len,
            posemb_init=sinusoidal_init(max_len=self.max_len),
            decode=self.decode,
            name='posembed_output')
        hidden = position_layer(hidden, inputs_positions=inputs_positions)
        hidden = nn.Dropout(rate=self.dropout_rate)(hidden,
                                                    deterministic=not train)
        hidden = hidden.astype(dtype)

        block_config = dict(
            qkv_dim=self.qkv_dim,
            mlp_dim=self.mlp_dim,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            attention_fn=self.attention_fn,
            normalizer=self.normalizer,
            dtype=dtype,
        )
        for _ in range(self.num_layers):
            hidden = Transformer1DBlock(**block_config)(
                inputs=hidden,
                train=train,
                decoder_mask=self_attn_mask,
                encoder_decoder_mask=None,
                inputs_positions=None,
                inputs_segmentation=None,
            )

        # Final normalization only for the normalizers listed here.
        if self.normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm'):
            final_norm = model_utils.get_normalizer(self.normalizer,
                                                    train,
                                                    dtype=dtype)
            hidden = final_norm()(hidden)

        if not self.logits_via_embedding:
            logits = nn.Dense(self.vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              dtype=dtype,
                              name='logits_dense')(hidden)
        else:
            # Tied weights: project with the transposed embedding matrix, then
            # scale by 1/sqrt(feature size) to normalize pre-softmax logits.
            logits = embedder.attend(hidden.astype(jnp.float32))
            logits = logits / jnp.sqrt(hidden.shape[-1])

        return logits.astype(dtype)
Exemplo n.º 17
0
    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Causal self-attention with rotary position embeddings.

        Args:
          hidden_states: input activations; batch size is read from axis 0.
            Presumably (batch, seq_len, embed_dim) -- TODO confirm against
            `_split_heads`.
          attention_mask: padding mask; it is expanded at axes (-3, -2) and
            broadcast to the causal mask's shape, so presumably (batch, key_len).
          position_ids: indices into `self.embed_positions` selecting the
            sin/cos tables used by `apply_rotary_pos_emb`.
          deterministic: when False and `config.attn_pdrop > 0`, attention
            dropout is applied using the "dropout" RNG stream.
          init_cache: when True, keys/values go through `_concatenate_to_cache`
            even if no "cache" variable exists yet.
          output_attentions: when True, attention weights are also returned.

        Returns:
          `(attn_output,)`, or `(attn_output, attn_weights)` when
          `output_attentions` is True.
        """

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        sincos = jnp.take(self.embed_positions, position_ids, axis=0)
        sincos = jnp.split(sincos, 2, axis=-1)
        if self.rotary_dim is not None:
            # Rotate only the first `rotary_dim` entries of the last axis;
            # the remainder passes through unchanged.
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, sincos)
            q_rot = apply_rotary_pos_emb(q_rot, sincos)

            key = jnp.concatenate([k_rot, k_pass], axis=-1)
            query = jnp.concatenate([q_rot, q_pass], axis=-1)
        else:
            key = apply_rotary_pos_emb(key, sincos)
            query = apply_rotary_pos_emb(query, sincos)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            # Cached decoding: take the causal-mask rows for the current cache
            # position, spanning the full cache length.
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # Transform the boolean mask into an additive float bias. Use the most
        # negative *finite* value of the compute dtype instead of a hard-coded
        # -1e9: -1e9 overflows to -inf in float16, and a query row whose keys
        # are all masked then softmaxes to NaN (-inf - (-inf)).
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape,
                     jnp.finfo(self.dtype).min).astype(self.dtype),
        )

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
    def decode(self, programs, encoded, encoded_padding_mask):
        """Decodes `programs` while attending to the encoded specification.

        Args:
          programs: program token ids; must be rank 2. Ids equal to 0 are
            treated as padding when building attention masks.
          encoded: encoder output; must be rank 4. Its num_io axis is
            flattened with `base_models.flatten_num_io_dim` before decoding.
          encoded_padding_mask: padding mask for `encoded`, flattened the same
            way and used as the key side of the encoder-decoder mask.

        Returns:
          The result of `self.decoder(programs, flat_encoded, decoder_mask,
          encoder_decoder_mask, decoder_relative_position)`.

        Raises:
          ValueError: if `config.attention_mask_type` is not one of 'baseline',
            'bos_to_bos', 'bos_to_last', 'bos_to_bos_and_last',
            'bos_full_attention' (only checked when not in decode mode).
        """
        cfg = self.config.base_config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Merge the num_io axis of both the encodings and their padding mask.
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_padding_mask = base_models.flatten_num_io_dim(encoded_padding_mask)

        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        self_mask = None
        relative_position = None

        if cfg.decode:
            # Fast decode with caching: programs is [batch_size, 1] with
            # cfg.shift = False, so only the encoder-decoder mask is built.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            cross_mask = nn.make_attention_mask(jnp.ones_like(programs),
                                                flat_padding_mask,
                                                dtype=cfg.dtype)
            return self.decoder(programs, flat_encoded, self_mask, cross_mask,
                                relative_position)

        non_pad = programs > 0
        causal = nn.make_causal_mask(programs, dtype=cfg.dtype)
        padding_mask = nn.make_attention_mask(non_pad, non_pad, dtype=cfg.dtype)
        mask_type = self.config.attention_mask_type

        if mask_type == 'baseline':
            # Standard causal attention over non-padding tokens.
            self_mask = nn.combine_masks(padding_mask, causal)
        else:
            is_bos = programs == cfg.bos_token

            if mask_type == 'bos_full_attention':
                # BOS tokens attend to all previous tokens, including program
                # tokens.
                bos_mask = nn.combine_masks(
                    nn.make_attention_mask(is_bos, non_pad, dtype=cfg.dtype),
                    causal)
            elif mask_type in ('bos_to_bos', 'bos_to_last',
                               'bos_to_bos_and_last'):
                # Shared base: BOS tokens attend causally to earlier BOS tokens.
                bos_to_bos = nn.combine_masks(
                    nn.make_attention_mask(is_bos, is_bos, dtype=cfg.dtype),
                    causal)
                if mask_type == 'bos_to_bos':
                    bos_mask = bos_to_bos
                elif mask_type == 'bos_to_last':
                    # Shifting left targets the last token of each previous
                    # partial program instead of the BOS itself.
                    bos_mask = shift_left(bos_to_bos)
                else:
                    bos_mask = jnp.logical_or(bos_to_bos,
                                              shift_left(bos_to_bos))
            else:
                raise ValueError(
                    'Unhandled attention_mask_type: {}'.format(mask_type))

            # Program tokens attend to all previous tokens in partial program.
            partial_mask = nn.combine_masks(
                make_partial_program_mask(programs,
                                          bos_token=cfg.bos_token,
                                          dtype=cfg.dtype),
                causal)
            self_mask = nn.combine_masks(
                padding_mask, jnp.logical_or(bos_mask, partial_mask))

            if self.config.bos_special_attention:
                # BOS positions use partial-program relative positions; all
                # other positions keep the regular relative positions.
                global_positions = make_relative_position(programs)
                partial_positions = make_partial_program_relative_position(
                    programs, bos_token=cfg.bos_token)
                relative_position = jnp.where(is_bos[..., None],
                                              partial_positions,
                                              global_positions)

        cross_mask = nn.make_attention_mask(non_pad,
                                            flat_padding_mask,
                                            dtype=cfg.dtype)
        return self.decoder(programs, flat_encoded, self_mask, cross_mask,
                            relative_position)