Example 1
  def decode(self,
             programs,
             encoded,
             encoded_padding_mask):
    """Applies decoder on programs and encoded specification."""
    cfg = self.config

    # Allow for decoding without num_partial dimension for beam search.
    # programs shape == [batch_size, (num_partial), length]
    assert programs.ndim in [2, 3], ('Number of program dimensions should be '
                                     '2 or 3, but it is: %d' % programs.ndim)
    assert encoded.ndim == programs.ndim + 2

    # Collapse num_io dimension.
    num_io_axis = 1 if programs.ndim == 2 else 2
    flat_encoded = base_models.flatten_num_io_dim(encoded, axis=num_io_axis)
    flat_encoded_padding_mask = base_models.flatten_num_io_dim(
        encoded_padding_mask, axis=num_io_axis)

    # Make attention masks.
    if cfg.decode:
      # For fast decode with caching, programs shape == [batch_size, 1] and
      # cfg.shift = False, cfg.decode = True.
      decoder_mask = None
      encoder_decoder_mask = nn.make_attention_mask(
          jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    else:
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      encoder_decoder_mask = nn.make_attention_mask(
          programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

    return self.decoder(
        programs, flat_encoded, decoder_mask, encoder_decoder_mask)
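
For reference, here is a minimal sketch of how the padding and causal masks in the non-decode branch above compose (toy token ids; the only assumption is that 0 is the padding id):

import jax.numpy as jnp
import flax.linen as nn

# Toy batch of two token sequences; 0 marks padding.
programs = jnp.array([[5, 7, 2, 0, 0],
                      [3, 1, 4, 6, 0]])

padding_mask = nn.make_attention_mask(programs > 0, programs > 0)  # [2, 1, 5, 5]
causal_mask = nn.make_causal_mask(programs)                        # [2, 1, 5, 5]
decoder_mask = nn.combine_masks(padding_mask, causal_mask)         # logical AND of both

print(decoder_mask.shape)  # (2, 1, 5, 5)
print(decoder_mask[0, 0])  # lower-triangular, zeroed at padded positions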
Example 2
    def __call__(self, inputs, outputs):
        """Applies Transformer model to encode the IO specification.

        Args:
          inputs: input data [batch_size, num_io, length]
          outputs: output data [batch_size, num_io, length2]

        Returns:
          Encoded IO data `[batch_size, num_io, length2, dim]`
        """
        cfg = self.config

        # Inputs and outputs shared embeddings.
        embed = nn.Embed(num_embeddings=cfg.vocab_size,
                         features=cfg.emb_dim,
                         embedding_init=nn.initializers.normal(stddev=1.0),
                         name='embed')
        pos_emb = AddPositionEmbs(config=cfg, cache=False, name='posembed_io')

        x = inputs.astype('int32')
        y = outputs.astype('int32')

        # Make attention masks.
        inputs_encoder_mask = nn.make_attention_mask(x > 0,
                                                     x > 0,
                                                     dtype=cfg.dtype)
        outputs_encoder_mask = nn.make_attention_mask(y > 0,
                                                      y > 0,
                                                      dtype=cfg.dtype)
        encoder_decoder_mask = nn.make_attention_mask(y > 0,
                                                      x > 0,
                                                      dtype=cfg.dtype)

        # Embed inputs.
        x = embed(x)
        x = pos_emb(x)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)

        x = x.astype(cfg.dtype)
        for lyr in range(cfg.num_layers):
            x = EncoderBlock(  # Attend to inputs.
                config=cfg, name=f'encoderblock_{lyr}')(x, inputs_encoder_mask)
        x = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

        # Embed outputs.
        y = embed(y)
        y = pos_emb(y)
        y = nn.Dropout(rate=cfg.dropout_rate)(y,
                                              deterministic=cfg.deterministic)

        encode_decoder_cfg = cfg.replace(decode=False)
        for lyr in range(cfg.num_layers):
            y = EncoderDecoderBlock(  # Double attend to inputs and outputs.
                config=encode_decoder_cfg,
                name=f'encoderdecoderblock_{lyr}')(y, x, outputs_encoder_mask,
                                                   encoder_decoder_mask)
        y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

        return y
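
As a quick shape check for the cross-attention mask built above (toy sizes, all tokens non-padding): queries come from the output tokens, keys from the input tokens, and the num_io axis behaves as an extra batch dimension.

import jax.numpy as jnp
import flax.linen as nn

batch, num_io, len_in, len_out = 2, 3, 6, 4
x = jnp.ones((batch, num_io, len_in), dtype=jnp.int32)   # inputs
y = jnp.ones((batch, num_io, len_out), dtype=jnp.int32)  # outputs

encoder_decoder_mask = nn.make_attention_mask(y > 0, x > 0)
print(encoder_decoder_mask.shape)  # (2, 3, 1, 4, 6)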
Example 3
    def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
        """Applies Transformer encoder-branch on the inputs.

        Args:
          inputs: input data.
          inputs_positions: input subsequence positions for packed examples.
          inputs_segmentation: input segmentation info for packed examples.

        Returns:
          encoded feature array from the transformer encoder.
        """
        cfg = self.config
        # Make padding attention mask.
        encoder_mask = nn.make_attention_mask(inputs > 0,
                                              inputs > 0,
                                              dtype=cfg.dtype)
        # Add segmentation block-diagonal attention mask if using segmented data.
        if inputs_segmentation is not None:
            encoder_mask = nn.combine_masks(
                encoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=cfg.dtype))
        return self.encoder(inputs,
                            inputs_positions=inputs_positions,
                            encoder_mask=encoder_mask)
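
A minimal sketch of the block-diagonal segmentation mask for packed examples (toy data; token values and segment ids are made up): combining the padding mask with a jnp.equal mask over segment ids keeps attention inside each packed segment.

import jax.numpy as jnp
import flax.linen as nn

inputs = jnp.array([[11, 12, 13, 21, 22, 0]])          # two packed sequences + padding
inputs_segmentation = jnp.array([[1, 1, 1, 2, 2, 0]])  # segment ids, 0 for padding

padding_mask = nn.make_attention_mask(inputs > 0, inputs > 0)
segment_mask = nn.make_attention_mask(inputs_segmentation, inputs_segmentation, jnp.equal)
encoder_mask = nn.combine_masks(padding_mask, segment_mask)
print(encoder_mask[0, 0])  # block-diagonal: tokens attend only within their own segment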
Example 4
  def decode(self,
             programs,
             encoded,
             encoded_padding_mask):
    """Applies decoder on programs and encoded specification."""
    cfg = self.config

    assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                ' but it is: %d' % programs.ndim)
    assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                               ' but it is: %d' % encoded.ndim)

    # Collapse num_io dimension
    flat_encoded = flatten_num_io_dim(encoded)
    flat_encoded_padding_mask = flatten_num_io_dim(encoded_padding_mask)

    # Make attention masks.
    if cfg.decode:
      # For fast decode with caching, programs shape == [batch_size, 1] and
      # cfg.shift = False, cfg.decode = True.
      decoder_mask = None
      encoder_decoder_mask = nn.make_attention_mask(
          jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    else:
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      encoder_decoder_mask = nn.make_attention_mask(
          programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

    return self.decoder(
        programs, flat_encoded, decoder_mask, encoder_decoder_mask)
Example 5
    def decode(
            self,
            encoded,
            inputs,  # only needed for masks
            targets,
            targets_positions=None,
            inputs_segmentation=None,
            targets_segmentation=None):
        """Applies Transformer decoder-branch on encoded-input and target.

        Args:
          encoded: encoded input data from encoder.
          inputs: input data (only needed for masking).
          targets: target data.
          targets_positions: target subsequence positions for packed examples.
          inputs_segmentation: input segmentation info for packed examples.
          targets_segmentation: target segmentation info for packed examples.

        Returns:
          logits array from transformer decoder.
        """
        config = self.config

        # Make padding attention masks.
        if config.decode:
            # For fast autoregressive decoding, only a special encoder-decoder
            # mask is used.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(targets > 0,
                                       targets > 0,
                                       dtype=config.dtype),
                nn.make_causal_mask(targets, dtype=config.dtype))
            encoder_decoder_mask = nn.make_attention_mask(targets > 0,
                                                          inputs > 0,
                                                          dtype=config.dtype)

        # Add segmentation block-diagonal attention masks if using segmented data.
        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       targets_segmentation,
                                       jnp.equal,
                                       dtype=config.dtype))
            encoder_decoder_mask = nn.combine_masks(
                encoder_decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=config.dtype))
        logits = self.decoder(encoded,
                              targets,
                              targets_positions=targets_positions,
                              decoder_mask=decoder_mask,
                              encoder_decoder_mask=encoder_decoder_mask)
        return logits.astype(self.config.dtype)
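
To illustrate the config.decode branch: with cached autoregressive decoding the decoder is fed a single target position per step, so only the encoder-decoder padding mask is needed (toy shapes, 0 = padding):

import jax.numpy as jnp
import flax.linen as nn

inputs = jnp.array([[4, 8, 15, 0, 0]])   # encoder-side tokens
targets = jnp.zeros((1, 1), jnp.int32)   # current decode step; its value does not affect the mask

encoder_decoder_mask = nn.make_attention_mask(jnp.ones_like(targets) > 0, inputs > 0)
print(encoder_decoder_mask.shape)  # (1, 1, 1, 5)
print(encoder_decoder_mask[0, 0])  # [[1. 1. 1. 0. 0.]] -- only padded keys are masked out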
Example 6
  def __call__(self, input_qkv):
    cfg = self.config
    assert cfg.max_len % cfg.max_seg_len == 0  # Sequence must split evenly into segments.
    bsize = input_qkv.shape[0]
    features = self.out_features or input_qkv.shape[-1]
    num_seg = cfg.max_len // cfg.max_seg_len
    # Reshape the flat sequence into per-segment blocks: [batch, num_seg, max_seg_len, dim].
    x_sqr = input_qkv.reshape([bsize, num_seg, cfg.max_seg_len, input_qkv.shape[-1]])
    q_row_local, key_row_local, value_row_local, head_dim = get_qkv(cfg, x_sqr)
    local_logits = jnp.einsum('...qhd,...khd->...qhk', q_row_local, key_row_local)
    row_probs = jax.nn.softmax(local_logits)
    if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
      dropout_rng = self.make_rng('dropout')
      row_probs = dropatt(row_probs, dropout_rng, 1 - cfg.attention_dropout_rate)
    row_attn_out = jnp.einsum('...qhk,...khd->...qhd', row_probs, value_row_local)

    key_row = DenseGeneral(features=input_qkv.shape[-1],
                           axis=(-2, -1),
                           kernel_init=cfg.kernel_init,
                           bias_init=cfg.bias_init,
                           use_bias=False,
                           dtype=cfg.dtype)(row_attn_out)
    key_row = nn.Dropout(rate=cfg.dropout_rate)(key_row, deterministic=cfg.deterministic)
    key_row = key_row + x_sqr
    key_row = nn.LayerNorm(dtype=cfg.dtype)(key_row)
    key_row = DenseGeneral(axis=-1,
                           features=(cfg.num_heads, head_dim),
                           kernel_init=cfg.kernel_init,
                           bias_init=cfg.bias_init,
                           use_bias=False,
                           dtype=cfg.dtype)(key_row)
    idx_cols = jnp.arange(cfg.max_seg_len)
    local_mask = nn.make_attention_mask(idx_cols, idx_cols, jnp.less, extra_batch_dims=1)
    local_mask = jnp.expand_dims(local_mask, axis=-2) * -1e10
    local_logits = local_logits + local_mask

    global_logits = jnp.einsum('bqlhd,bklhd->bqlhk', q_row_local, key_row)
    idx_rows = jnp.arange(num_seg)
    global_mask = nn.make_attention_mask(idx_rows, idx_rows, jnp.less_equal)
    global_mask = global_mask[:, :, jnp.newaxis, jnp.newaxis, :] * -1e10
    global_logits = global_logits + global_mask

    joint_logits = jnp.concatenate((local_logits, global_logits), axis=-1)
    attn_probs = jax.nn.softmax(joint_logits, axis=-1)
    local_att, global_att = jnp.split(attn_probs, [cfg.max_seg_len], axis=-1)
    if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
      dropout_rng = self.make_rng('dropout')
      local_att = dropatt(local_att, dropout_rng, 1 - cfg.attention_dropout_rate)
    local_merged = jnp.einsum('bsqhk,bskhd->bsqhd', local_att, value_row_local)
    global_merged = jnp.einsum('bqlhv,bvlhd->bqlhd', global_att, row_attn_out)
    joint_merged = jnp.reshape(local_merged + global_merged, [bsize, cfg.max_len, cfg.num_heads, head_dim])
    x = DenseGeneral(features=features,
                  axis=(-2, -1),
                  kernel_init=cfg.kernel_init,
                  bias_init=cfg.bias_init,
                  use_bias=False,
                  dtype=cfg.dtype)(joint_merged)
    return x
Example 7
    def decode(self, programs, encoded, encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = base_models.flatten_num_io_dim(
            encoded_padding_mask)

        preshift_programs = programs  # Save pre-shifted programs for padding mask.
        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        # Make attention masks.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
        else:
            # BOS tokens attend to all previous BOS tokens.
            decoder_bos_mask = nn.combine_masks(
                nn.make_attention_mask(programs == cfg.bos_token,
                                       programs == cfg.bos_token,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            # Program tokens attend to all previous tokens in partial program.
            decoder_partial_mask = nn.combine_masks(
                make_partial_program_mask(programs,
                                          bos_token=cfg.bos_token,
                                          dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(preshift_programs > 0,
                                       preshift_programs > 0,
                                       dtype=cfg.dtype),
                jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask)
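
make_partial_program_mask is a project-specific helper that is not shown here. Purely as an illustration of the idea in the comment above ("program tokens attend to all previous tokens in partial program"), a hypothetical mask of that flavor could group tokens by the BOS-delimited partial program they belong to; this is a guess, not the project's implementation, and in the example above it is additionally combined with a causal mask:

import jax.numpy as jnp
import flax.linen as nn

def partial_program_mask_sketch(programs, bos_token, dtype=jnp.float32):
  # Hypothetical: a token shares a group with every token since the last BOS.
  group_ids = jnp.cumsum((programs == bos_token).astype(jnp.int32), axis=-1)
  return nn.make_attention_mask(group_ids, group_ids, jnp.equal, dtype=dtype)

programs = jnp.array([[9, 5, 5, 9, 7, 0]])  # 9 stands in for the BOS token
print(partial_program_mask_sketch(programs, bos_token=9)[0, 0])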
Example 8
    def decode(self, programs, latents, encoded, latents_padding_mask,
               encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert latents.ndim == 3, ('Number of latents dimensions should be 3,'
                                   ' but it is: %d' % latents.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = models.flatten_num_io_dim(
            encoded_padding_mask)

        latents = self.latent_pos_emb(latents)
        # Concatenate the i/o encoding and latents together.
        flat_encoded = jnp.concatenate([flat_encoded, latents], axis=1)

        # Make attention masks.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            decoder_mask = None
            latent_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs), latents_padding_mask, dtype=cfg.dtype)
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
            encoder_decoder_mask = jnp.concatenate(
                [encoder_decoder_mask, latent_decoder_mask], axis=-1)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(programs > 0,
                                       programs > 0,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            latent_decoder_mask = nn.make_attention_mask(programs > 0,
                                                         latents_padding_mask,
                                                         dtype=cfg.dtype)
            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)
            encoder_decoder_mask = jnp.concatenate(
                [encoder_decoder_mask, latent_decoder_mask], axis=-1)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask)
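
A shape sketch (toy sizes, everything unpadded) of how the latent mask is appended along the key axis, so each program token can attend to the flattened IO encoding followed by the latents:

import jax.numpy as jnp
import flax.linen as nn

batch, prog_len, io_len, latent_len = 2, 5, 12, 4
programs = jnp.ones((batch, prog_len), jnp.int32)
flat_encoded_padding_mask = jnp.ones((batch, io_len))
latents_padding_mask = jnp.ones((batch, latent_len))

io_mask = nn.make_attention_mask(programs > 0, flat_encoded_padding_mask)  # (2, 1, 5, 12)
latent_mask = nn.make_attention_mask(programs > 0, latents_padding_mask)   # (2, 1, 5, 4)
encoder_decoder_mask = jnp.concatenate([io_mask, latent_mask], axis=-1)
print(encoder_decoder_mask.shape)  # (2, 1, 5, 16): keys cover encoded IO, then latents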
Example 9
  def __call__(self,
               input_ids,
               input_mask,
               type_ids,
               deterministic=False):
    """Applies model on the inputs.

    Args:
      input_ids: Tokenized inputs of shape <int>[BATCH_SIZE, MAX_SEQ_LENGTH].
      input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
        inputs from padding. Only used by BERT.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] ids partitioning input into
        different types.
      deterministic: Whether or not to apply dropout in each layer.

    Returns:
      Hidden states of shape <float>[BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM],
        and pooled output <float>[BATCH_SIZE, HIDDEN_DIM] scaled to (-1, 1).
    """
    hidden_states = self.embedder(
        input_ids, type_ids, deterministic=deterministic)

    # Only used by (BERT) self-attention sublayer.
    padding_mask = input_mask.astype(jnp.int32)
    padding_mask = nn.make_attention_mask(
        query_input=padding_mask, key_input=padding_mask)

    for encoder_block in self.encoder_blocks:
      hidden_states = encoder_block(
          hidden_states, padding_mask, deterministic=deterministic)

    pooled_output = self.pooler(hidden_states[:, 0])
    pooled_output = jnp.tanh(pooled_output)

    return hidden_states, pooled_output
Example 10
    def get_attention_masks(self, inputs, targets):
        cfg = self.config
        if cfg.decode:
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(targets > 0,
                                       targets > 0,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(targets, dtype=cfg.dtype))
            encoder_decoder_mask = nn.make_attention_mask(targets > 0,
                                                          inputs > 0,
                                                          dtype=cfg.dtype)
        return decoder_mask, encoder_decoder_mask
Example 11
    def decode(self,
               encoded,
               inputs,
               targets,
               targets_positions=None,
               inputs_segmentation=None,
               targets_segmentation=None,
               train=False):
        # Make padding attention masks.
        dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32
        if self.should_decode:
            # For fast autoregressive decoding, only a special encoder-decoder mask is
            # used.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype),
                nn.make_causal_mask(targets, dtype=dtype))
            encoder_decoder_mask = nn.make_attention_mask(targets > 0,
                                                          inputs > 0,
                                                          dtype=dtype)

        # Add segmentation block-diagonal attention masks if using segmented data.
        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       targets_segmentation,
                                       jnp.equal,
                                       dtype=dtype))
            encoder_decoder_mask = nn.combine_masks(
                encoder_decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=dtype))

        logits = self.decoder(encoded,
                              targets,
                              targets_positions=targets_positions,
                              decoder_mask=decoder_mask,
                              encoder_decoder_mask=encoder_decoder_mask,
                              train=train)
        return logits
Example 12
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies TransformerLM on the inputs.

        Args:
          inputs: target data.
          inputs_positions: input subsequence positions for packed examples.
          inputs_segmentation: input segmentation info for packed examples.

        Returns:
          logits array from transformer decoder.
        """
        config = self.config

        # Make padding attention masks.
        if config.decode:
            # For fast autoregressive decoding, no decoder mask is needed.
            decoder_mask = None
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(inputs > 0,
                                       inputs > 0,
                                       dtype=config.dtype),
                nn.make_causal_mask(inputs, dtype=config.dtype))

        # Add segmentation block-diagonal attention masks if using segmented data.
        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=config.dtype))

        logits = Decoder(config=config, shared_embedding=None, name='decoder')(
            inputs,
            inputs_positions=inputs_positions,
            inputs_segmentation=inputs_segmentation,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None)
        return logits.astype(self.config.dtype)
Example 13
def make_causal_mask(x, length_axis, extra_batch_dims=0, strict=False):
    idxs = jnp.broadcast_to(jnp.arange(x.shape[length_axis], dtype=jnp.int32),
                            x.shape[:length_axis + 1])
    mask = nn.make_attention_mask(
        idxs,
        idxs,
        jnp.greater_equal if not strict else jnp.greater,
        extra_batch_dims=extra_batch_dims,
        dtype=jnp.float32)
    return mask
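
A usage sketch of the helper above (assuming it and flax.linen are in scope): position i may attend to positions j <= i, or j < i with strict=True.

import jax.numpy as jnp

x = jnp.zeros((2, 4, 8))  # e.g. [batch, length, features], so length_axis=1

mask = make_causal_mask(x, length_axis=1)
strict_mask = make_causal_mask(x, length_axis=1, strict=True)
print(mask.shape)         # (2, 1, 4, 4)
print(mask[0, 0])         # lower triangle including the diagonal
print(strict_mask[0, 0])  # lower triangle excluding the diagonal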
Example 14
    def encode(self,
               inputs,
               inputs_positions=None,
               inputs_segmentation=None,
               train=False):
        # Make padding attention mask.
        encoder_mask = nn.make_attention_mask(inputs > 0,
                                              inputs > 0,
                                              dtype=inputs.dtype)
        # Add segmentation block-diagonal attention mask if using segmented data.
        if inputs_segmentation is not None:
            encoder_mask = nn.combine_masks(
                encoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=inputs_segmentation.dtype))
        encoded = self.encoder(inputs,
                               inputs_positions=inputs_positions,
                               encoder_mask=encoder_mask,
                               train=train)

        return encoded
Example 15
    def __call__(self, targets, targets_mask=None):
        """Autoencodes program task.

        Args:
          targets: target data `[batch_size, length]`
          targets_mask: padding mask for targets.

        Returns:
          embedding sequence.
        """
        cfg = self.config
        assert targets.ndim == 2  # (batch, len)

        if targets_mask is None:
            targets_mask = jnp.where(targets > 0, 1, 0).astype(jnp.float32)
        encoder_mask = nn.make_attention_mask(targets_mask,
                                              targets_mask,
                                              dtype=cfg.dtype)

        output_embed = nn.Embed(
            num_embeddings=cfg.output_vocab_size,
            features=cfg.emb_dim,
            embedding_init=nn.initializers.normal(stddev=1.0),
            name='embed_output')

        # Add num_io dimension to latents and latents_mask.
        x = targets.astype('int32')
        x = output_embed(x)
        x = models.AddPositionEmbs(config=cfg,
                                   cache=cfg.decode,
                                   name='posembed')(x)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)

        for lyr in range(cfg.num_layers):
            x = models.EncoderBlock(  # Attend to inputs.
                config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)

        y = x * targets_mask[Ellipsis, None]
        for i in range(self.c):  # Strided convolutions to decrease length.
            y = nn.Conv(features=cfg.emb_dim,
                        kernel_size=(2, ),
                        strides=(2, ),
                        name=f'conv_{i}')(y)

        return y
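
Each stride-2 convolution roughly halves the sequence length, so self.c layers shrink a length-L embedding sequence to about L / 2**c. A toy sketch with assumed sizes (Downsample is an illustrative stand-in for the tail of the module above, not part of the source):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Downsample(nn.Module):
  emb_dim: int = 16
  c: int = 3  # number of stride-2 convolutions

  @nn.compact
  def __call__(self, y):
    for i in range(self.c):
      y = nn.Conv(features=self.emb_dim, kernel_size=(2,), strides=(2,), name=f'conv_{i}')(y)
    return y

y = jnp.ones((2, 64, 16))  # [batch, length, emb_dim]
out, _ = Downsample().init_with_output(jax.random.PRNGKey(0), y)
print(out.shape)  # (2, 8, 16): 64 -> 32 -> 16 -> 8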
Example 16
    def __call__(self, encoding: Array, attention_mask: Array,
                 deterministic: bool) -> Array:
        """Self attention layer forward.

        Args:
          encoding: [bsz, seq_len, model_dim] model state.
          attention_mask: [bsz, seq_len].
          deterministic: if true, do not apply dropout.

        Returns:
          Updated encoding.
        """

        attention_mask = nn.make_attention_mask(attention_mask, attention_mask)
        update = self.attention_layer(inputs_q=encoding,
                                      mask=attention_mask,
                                      deterministic=deterministic)
        update = self.dropout(update, deterministic=deterministic)
        encoding = self.layer_norm(encoding + update)

        return encoding
Example 17
    def __call__(self, inputs, dummy):
        """Vanilla Transformer encoder.

        Args:
          inputs: input data [batch_size, num_io, length]
          dummy: unused for SCAN dataset.

        Returns:
          Encoded inputs `[batch_size, num_io, length, dim]`
        """
        del dummy
        # TODO(kshi): possibly use dummy for RobustFill.

        cfg = self.config

        # Inputs and outputs shared embeddings.
        embed = nn.Embed(num_embeddings=cfg.vocab_size,
                         features=cfg.emb_dim,
                         embedding_init=nn.initializers.normal(stddev=1.0),
                         name='embed')

        x = inputs.astype('int32')
        encoder_mask = nn.make_attention_mask(x > 0, x > 0, dtype=cfg.dtype)

        # Embed outputs.
        x = embed(x)
        if not cfg.use_relative_attention:
            pos_emb = AddPositionEmbs(config=cfg,
                                      cache=False,
                                      name='posembed_io')
            x = pos_emb(x)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)

        for lyr in range(cfg.num_layers):
            x = EncoderBlock(  # Attend to inputs.
                config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)
        y = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

        return y
Example 18
    def __call__(self,
                 inputs,
                 train,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer model on the inputs.

        Args:
          inputs: input data.
          train: whether the model is training.
          inputs_positions: input subsequence positions for packed examples.
          inputs_segmentation: input segmentation info for packed examples.

        Returns:
          output of a transformer decoder.
        """
        assert inputs.ndim == 2  # (batch, len)
        dtype = utils.dtype_from_str(self.model_dtype)

        if self.decode:
            # For fast autoregressive decoding, no decoder mask is needed.
            decoder_mask = None
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
                nn.make_causal_mask(inputs, dtype=dtype))

        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=dtype))

        y = inputs.astype('int32')
        if not self.decode:
            y = shift_inputs(y, segment_ids=inputs_segmentation)

        # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
        # indices for dataset_lib:proteins_test. This will break when jnp.take() is
        # updated to return NaNs for out-of-bounds indices.
        # Debug why this is the case.
        y = jnp.clip(y, 0, self.vocab_size - 1)

        if self.shared_embedding is None:
            output_embed = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            output_embed = self.shared_embedding

        y = output_embed(y)

        y = AddPositionEmbs(max_len=self.max_len,
                            posemb_init=sinusoidal_init(max_len=self.max_len),
                            decode=self.decode,
                            name='posembed_output')(
                                y, inputs_positions=inputs_positions)
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

        y = y.astype(dtype)

        for _ in range(self.num_layers):
            y = Transformer1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                attention_fn=self.attention_fn,
                normalizer=self.normalizer,
                dtype=dtype)(
                    inputs=y,
                    train=train,
                    decoder_mask=decoder_mask,
                    encoder_decoder_mask=None,
                    inputs_positions=None,
                    inputs_segmentation=None,
                )
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            maybe_normalize = model_utils.get_normalizer(self.normalizer,
                                                         train,
                                                         dtype=dtype)
            y = maybe_normalize()(y)

        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(y.astype(jnp.float32))
            # Correctly normalize pre-softmax logits for this shared case.
            logits = logits / jnp.sqrt(y.shape[-1])
        else:
            logits = nn.Dense(self.vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              dtype=dtype,
                              name='logits_dense')(y)

        return logits.astype(dtype)
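
For the logits_via_embedding branch, Flax's nn.Embed.attend is a dot product against the embedding table, so the shared-embedding logit transform above is equivalent to the following toy computation (assumed sizes):

import jax.numpy as jnp

vocab_size, emb_dim = 10, 4
embedding = jnp.ones((vocab_size, emb_dim))       # rows are token embeddings
y = jnp.ones((2, 7, emb_dim))                     # decoder output [batch, length, emb_dim]

logits = jnp.einsum('bld,vd->blv', y, embedding)  # same contraction as output_embed.attend(y)
logits = logits / jnp.sqrt(emb_dim)               # rescale logits for the shared-embedding case
print(logits.shape)  # (2, 7, 10)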
Example 19
    def __call__(self, input_qkv):
        cfg = self.config
        assert cfg.max_len % cfg.max_seg_len == 0  # Sequence must split evenly into segments.
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        query, key, value, head_dim = get_qkv(cfg, input_qkv)

        num_seg = cfg.max_len // cfg.max_seg_len
        cur_query = query.reshape(
            [-1, cfg.max_seg_len, query.shape[-2], query.shape[-1]])
        merged_query = jnp.max(cur_query, axis=1,
                               keepdims=True) * jnp.sqrt(head_dim)
        cur_key = key.reshape(
            [-1, cfg.max_seg_len, key.shape[-2], key.shape[-1]])
        cur_value = value.reshape(
            [-1, cfg.max_seg_len, value.shape[-2], value.shape[-1]])
        dropout_rng = None
        if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')
        s = dot_product_attention(merged_query,
                                  cur_key,
                                  cur_value,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=cfg.attention_dropout_rate,
                                  broadcast_dropout=False,
                                  deterministic=cfg.deterministic,
                                  dtype=cfg.dtype)
        span_val = jnp.reshape(s, [bsize, -1, s.shape[-2], s.shape[-1]])
        span_key = jnp.max(cur_key, axis=1, keepdims=True)
        # (bsize, n_seg, n_head, dim_per_head)
        span_key = jnp.reshape(
            span_key, [bsize, -1, span_key.shape[-2], span_key.shape[-1]])

        local_mask = make_causal_mask(cur_query,
                                      length_axis=1).transpose([0, 2, 1, 3])
        local_bias = lax.select(
            local_mask > 0,
            jnp.full(local_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(local_mask.shape, -1e10).astype(cfg.dtype))
        # (bsize * n_seg, seg_len, n_head, seg_len)
        local_logits = jnp.einsum('...qhd,...khd->...qhk', cur_query,
                                  cur_key) + local_bias
        local_logits = jnp.reshape(local_logits,
                                   [bsize, -1, cfg.num_heads, cfg.max_seg_len])
        idx = jnp.broadcast_to(jnp.arange(span_key.shape[1], dtype=jnp.int32),
                               span_key.shape[:2])
        prev_mask = nn.make_attention_mask(idx,
                                           idx,
                                           jnp.greater,
                                           extra_batch_dims=0,
                                           dtype=jnp.float32).transpose(
                                               [0, 2, 1, 3])
        prev_mask = jnp.repeat(prev_mask, cfg.max_seg_len, axis=-3)
        prev_bias = lax.select(
            prev_mask > 0,
            jnp.full(prev_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(prev_mask.shape, -1e10).astype(cfg.dtype))
        # (bsize, max_len, n_head, num_segs)
        prev_logits = jnp.einsum('...qhd,...khd->...qhk', query,
                                 span_key) + prev_bias
        joint_logits = jnp.concatenate((local_logits, prev_logits), axis=-1)
        # (bsize x max_len,  n_head, seg_len + num_segs)
        attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
        local_att, prev_att = jnp.split(attn_weights, [cfg.max_seg_len],
                                        axis=-1)
        local_att = local_att.reshape(
            [bsize * num_seg, cfg.max_seg_len, cfg.num_heads, cfg.max_seg_len])
        local_merged = jnp.einsum('...qhk,...khd->...qhd', local_att,
                                  cur_value)
        prev_merged = jnp.einsum('...qhk,...khd->...qhd', prev_att, span_val)
        joint_merged = jnp.reshape(local_merged,
                                   prev_merged.shape) + prev_merged
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(joint_merged)
        return x
    def decode(self, programs, encoded, encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config.base_config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = base_models.flatten_num_io_dim(
            encoded_padding_mask)

        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        # Make attention masks.
        decoder_mask = None
        decoder_relative_position = None  # Relative positions.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
        else:
            attention_mask_type = self.config.attention_mask_type
            if attention_mask_type == 'baseline':
                decoder_mask = nn.combine_masks(
                    nn.make_attention_mask(programs > 0,
                                           programs > 0,
                                           dtype=cfg.dtype),
                    nn.make_causal_mask(programs, dtype=cfg.dtype))
            else:
                if attention_mask_type == 'bos_to_bos':
                    # BOS tokens attend to all previous BOS tokens.
                    decoder_bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                elif attention_mask_type == 'bos_to_last':
                    # BOS tokens attend to all last partial program tokens.
                    bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                    # Shift bos mask to left to get all previous last partial program
                    # tokens.
                    decoder_bos_mask = shift_left(bos_mask)
                elif attention_mask_type == 'bos_to_bos_and_last':
                    # BOS tokens attend to all previous BOS + last partial program tokens.
                    bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                    # Shift bos mask to left to get all previous last partial program
                    # tokens.
                    decoder_bos_mask = jnp.logical_or(bos_mask,
                                                      shift_left(bos_mask))
                elif attention_mask_type == 'bos_full_attention':
                    # BOS tokens attend to all previous tokens, including program tokens.
                    decoder_bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs > 0,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                else:
                    raise ValueError(
                        'Unhandled attention_mask_type: {}'.format(
                            attention_mask_type))
                # Program tokens attend to all previous tokens in partial program.
                decoder_partial_mask = nn.combine_masks(
                    make_partial_program_mask(programs,
                                              bos_token=cfg.bos_token,
                                              dtype=cfg.dtype),
                    nn.make_causal_mask(programs, dtype=cfg.dtype))
                decoder_mask = nn.combine_masks(
                    nn.make_attention_mask(programs > 0,
                                           programs > 0,
                                           dtype=cfg.dtype),
                    jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

                if self.config.bos_special_attention:
                    # Make custom relative positions where BOS are separately indexed.
                    decoder_relative_position = make_relative_position(
                        programs)
                    decoder_partial_relative_position = (
                        make_partial_program_relative_position(
                            programs, bos_token=cfg.bos_token))
                    decoder_relative_position = jnp.where(
                        (programs == cfg.bos_token)[Ellipsis, None],
                        decoder_partial_relative_position,
                        decoder_relative_position)
                else:
                    decoder_relative_position = None

            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask, decoder_relative_position)