Example 1
    def decode(self, programs, encoded, encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = base_models.flatten_num_io_dim(
            encoded_padding_mask)

        preshift_programs = programs  # Save pre-shifted programs for padding mask.
        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        # Make attention masks.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
        else:
            # BOS tokens attend to all previous BOS tokens.
            decoder_bos_mask = nn.combine_masks(
                nn.make_attention_mask(programs == cfg.bos_token,
                                       programs == cfg.bos_token,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            # Program tokens attend to all previous tokens in partial program.
            decoder_partial_mask = nn.combine_masks(
                make_partial_program_mask(programs,
                                          bos_token=cfg.bos_token,
                                          dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(preshift_programs > 0,
                                       preshift_programs > 0,
                                       dtype=cfg.dtype),
                jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask)
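Example 1 depends on a make_partial_program_mask helper that is not included in the snippet. A minimal sketch of how such a mask could be built, assuming partial programs are delimited by BOS tokens as the comments suggest (illustrative only, not the repository's actual helper):

import jax.numpy as jnp
from flax import linen as nn

def make_partial_program_mask_sketch(programs, bos_token, dtype=jnp.float32):
    """Hypothetical sketch: allow attention only within the same partial program."""
    # Each position's partial-program id is the count of BOS tokens seen so far.
    segment_ids = jnp.cumsum(programs == bos_token, axis=-1)
    # Query and key may attend to each other only if they share a segment id.
    return nn.make_attention_mask(segment_ids, segment_ids, jnp.equal, dtype=dtype)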
Example 2
    def setup(self):
        self.embed_dim = self.config.hidden_size
        self.num_heads = self.config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
        self.scale = self.head_dim**-0.5
        self.dropout = self.config.attention_dropout

        self.k_proj = nn.Dense(self.embed_dim,
                               dtype=self.dtype,
                               kernel_init=jax.nn.initializers.normal(
                                   0.01, dtype=self.dtype))
        self.v_proj = nn.Dense(self.embed_dim,
                               dtype=self.dtype,
                               kernel_init=jax.nn.initializers.normal(
                                   0.01, dtype=self.dtype))
        self.q_proj = nn.Dense(self.embed_dim,
                               dtype=self.dtype,
                               kernel_init=jax.nn.initializers.normal(
                                   0.01, dtype=self.dtype))
        self.out_proj = nn.Dense(self.embed_dim,
                                 dtype=self.dtype,
                                 kernel_init=jax.nn.initializers.normal(
                                     0.01, dtype=self.dtype))

        self.causal = isinstance(self.config, CLIPTextConfig)
        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="i4"))
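The causal_mask cached here spans the full max_position_embeddings length; a common pattern (assumed, since __call__ is not shown in this snippet) is to slice it down to the current sequence length at attention time. A self-contained illustration with hypothetical sizes:

import jax.numpy as jnp
from flax import linen as nn

max_positions = 77  # hypothetical max_position_embeddings
full_causal_mask = nn.make_causal_mask(
    jnp.ones((1, max_positions), dtype="i4"))  # shape [1, 1, 77, 77]

# At call time the cached mask is typically sliced to the current length:
seq_len = 16
causal_mask = full_causal_mask[:, :, :seq_len, :seq_len]  # [1, 1, 16, 16]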
Example 3
    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.rotary_dim = config.rotary_dim

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

        self.causal_mask = make_causal_mask(
            jnp.ones((1, config.max_position_embeddings), dtype="bool"),
            dtype="bool")

        pos_embd_dim = self.rotary_dim or self.embed_dim
        self.embed_positions = create_sinusoidal_positions(
            config.max_position_embeddings, pos_embd_dim)
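create_sinusoidal_positions is not defined in this snippet; a rough sketch of how such a table is commonly built for rotary embeddings (an assumption about the helper, not its actual source):

import numpy as np

def create_sinusoidal_positions_sketch(num_pos, dim):
    """Sketch: [num_pos, dim] table of sin features followed by cos features."""
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    angles = np.einsum("i,j->ij", np.arange(num_pos), inv_freq)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)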
Example 4
  def decode(self,
             programs,
             encoded,
             encoded_padding_mask):
    """Applies decoder on programs and encoded specification."""
    cfg = self.config

    assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                ' but it is: %d' % programs.ndim)
    assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                               ' but it is: %d' % encoded.ndim)

    # Collapse num_io dimension
    flat_encoded = flatten_num_io_dim(encoded)
    flat_encoded_padding_mask = flatten_num_io_dim(encoded_padding_mask)

    # Make attention masks.
    if cfg.decode:
      # For fast decode with caching, programs shape == [batch_size, 1] and
      # cfg.shift = False, cfg.decode = True.
      decoder_mask = None
      encoder_decoder_mask = nn.make_attention_mask(
          jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    else:
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      encoder_decoder_mask = nn.make_attention_mask(
          programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

    return self.decoder(
        programs, flat_encoded, decoder_mask, encoder_decoder_mask)
Example 5
  def decode(self,
             programs,
             encoded,
             encoded_padding_mask):
    """Applies decoder on programs and encoded specification."""
    cfg = self.config

    # Allow for decoding without num_partial dimension for beam search.
    # programs shape == [batch_size, (num_partial), length]
    assert programs.ndim in [2, 3], ('Number of program dimensions should be '
                                     '2 or 3, but it is: %d' % programs.ndim)
    assert encoded.ndim == programs.ndim + 2

    # Collapse num_io dimension.
    num_io_axis = 1 if programs.ndim == 2 else 2
    flat_encoded = base_models.flatten_num_io_dim(encoded, axis=num_io_axis)
    flat_encoded_padding_mask = base_models.flatten_num_io_dim(
        encoded_padding_mask, axis=num_io_axis)

    # Make attention masks.
    if cfg.decode:
      # For fast decode with caching, programs shape == [batch_size, 1] and
      # cfg.shift = False, cfg.decode = True.
      decoder_mask = None
      encoder_decoder_mask = nn.make_attention_mask(
          jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    else:
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      encoder_decoder_mask = nn.make_attention_mask(
          programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

    return self.decoder(
        programs, flat_encoded, decoder_mask, encoder_decoder_mask)
Example 6
    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads}).")

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        dense = partial(
            nn.Dense,
            self.embed_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
        )

        self.q_proj = dense(use_bias=False)
        self.k_proj = dense(use_bias=False)
        self.v_proj = dense(use_bias=False)
        self.out_proj = dense()

        self.causal_mask = make_causal_mask(
            jnp.ones((1, config.max_position_embeddings), dtype="bool"),
            dtype="bool")
        if self.attention_type == "local":
            self.causal_mask = self.causal_mask ^ jnp.tril(
                self.causal_mask, -config.window_size)
Example 7
    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads

        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
                f"and `num_heads`: {self.num_heads}).")

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"),
                dtype="bool")
Example 8
    def decode(
            self,
            encoded,
            inputs,  # only needed for masks
            targets,
            targets_positions=None,
            inputs_segmentation=None,
            targets_segmentation=None):
        """Applies Transformer decoder-branch on encoded-input and target.

    Args:
      encoded: encoded input data from encoder.
      inputs: input data (only needed for masking).
      targets: target data.
      targets_positions: target subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.

    Returns:
      logits array from transformer decoder.
    """
        config = self.config

        # Make padding attention masks.
        if config.decode:
            # For fast autoregressive decoding, only a special encoder-decoder
            # mask is used.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(targets > 0,
                                       targets > 0,
                                       dtype=config.dtype),
                nn.make_causal_mask(targets, dtype=config.dtype))
            encoder_decoder_mask = nn.make_attention_mask(targets > 0,
                                                          inputs > 0,
                                                          dtype=config.dtype)

        # Add segmentation block-diagonal attention masks if using segmented data.
        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       targets_segmentation,
                                       jnp.equal,
                                       dtype=config.dtype))
            encoder_decoder_mask = nn.combine_masks(
                encoder_decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=config.dtype))
        logits = self.decoder(encoded,
                              targets,
                              targets_positions=targets_positions,
                              decoder_mask=decoder_mask,
                              encoder_decoder_mask=encoder_decoder_mask)
        return logits.astype(self.config.dtype)
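The segmentation masks above use nn.make_attention_mask with jnp.equal as the pairwise function, which yields a block-diagonal pattern for packed examples. A small self-contained illustration with hypothetical segment ids:

import jax.numpy as jnp
from flax import linen as nn

# One row packing two sequences: segment ids 1,1,1 and 2,2 (0 would be padding).
segs = jnp.array([[1, 1, 1, 2, 2]])
block_diag = nn.make_attention_mask(segs, segs, jnp.equal, dtype=jnp.float32)
# block_diag[0, 0] is a 5x5 block-diagonal matrix: each position may attend
# only to positions with the same segment id.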
Example 9
    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype)
        self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
        self.causal_mask = make_causal_mask(
            jnp.ones((1, config.max_position_embeddings), dtype="bool"),
            dtype="bool")
Example 10
    def decode(self, programs, latents, encoded, latents_padding_mask,
               encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert latents.ndim == 3, ('Number of latents dimensions should be 3,'
                                   ' but it is: %d' % latents.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = models.flatten_num_io_dim(
            encoded_padding_mask)

        latents = self.latent_pos_emb(latents)
        # Concatenate the i/o encoding and latents together.
        flat_encoded = jnp.concatenate([flat_encoded, latents], axis=1)

        # Make attention masks.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            decoder_mask = None
            latent_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs), latents_padding_mask, dtype=cfg.dtype)
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
            encoder_decoder_mask = jnp.concatenate(
                [encoder_decoder_mask, latent_decoder_mask], axis=-1)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(programs > 0,
                                       programs > 0,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            latent_decoder_mask = nn.make_attention_mask(programs > 0,
                                                         latents_padding_mask,
                                                         dtype=cfg.dtype)
            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)
            encoder_decoder_mask = jnp.concatenate(
                [encoder_decoder_mask, latent_decoder_mask], axis=-1)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask)
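Because flat_encoded and the latents are concatenated along the sequence axis, the corresponding masks must be concatenated along the key axis in the same order. A quick shape check with hypothetical sizes:

import jax.numpy as jnp
from flax import linen as nn

batch, prog_len, enc_len, latent_len = 2, 7, 12, 4
programs = jnp.ones((batch, prog_len), dtype=jnp.int32)
enc_pad = jnp.ones((batch, enc_len))
lat_pad = jnp.ones((batch, latent_len))

enc_dec_mask = nn.make_attention_mask(programs > 0, enc_pad)   # [2, 1, 7, 12]
lat_dec_mask = nn.make_attention_mask(programs > 0, lat_pad)   # [2, 1, 7, 4]
mask = jnp.concatenate([enc_dec_mask, lat_dec_mask], axis=-1)  # [2, 1, 7, 16]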
Example 11
    def get_attention_masks(self, inputs, targets):
        cfg = self.config
        if cfg.decode:
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(targets > 0,
                                       targets > 0,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(targets, dtype=cfg.dtype))
            encoder_decoder_mask = nn.make_attention_mask(targets > 0,
                                                          inputs > 0,
                                                          dtype=cfg.dtype)
        return decoder_mask, encoder_decoder_mask
Example 12
    def decode(self,
               encoded,
               inputs,
               targets,
               targets_positions=None,
               inputs_segmentation=None,
               targets_segmentation=None,
               train=False):
        # Make padding attention masks.
        dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32
        if self.should_decode:
            # For fast autoregressive decoding, only a special encoder-decoder mask is
            # used.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype),
                nn.make_causal_mask(targets, dtype=dtype))
            encoder_decoder_mask = nn.make_attention_mask(targets > 0,
                                                          inputs > 0,
                                                          dtype=dtype)

        # Add segmentation block-diagonal attention masks if using segmented data.
        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       targets_segmentation,
                                       jnp.equal,
                                       dtype=dtype))
            encoder_decoder_mask = nn.combine_masks(
                encoder_decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=dtype))

        logits = self.decoder(encoded,
                              targets,
                              targets_positions=targets_positions,
                              decoder_mask=decoder_mask,
                              encoder_decoder_mask=encoder_decoder_mask,
                              train=train)
        return logits
Example 13
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies TransformerLM on the inputs.

    Args:
      inputs: target data.
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      logits array from transformer decoder.
    """
        config = self.config

        # Make padding attention masks.
        if config.decode:
            # for fast autoregressive decoding we use no decoder mask
            decoder_mask = None
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(inputs > 0,
                                       inputs > 0,
                                       dtype=config.dtype),
                nn.make_causal_mask(inputs, dtype=config.dtype))

        # Add segmentation block-diagonal attention masks if using segmented data.
        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=config.dtype))

        logits = Decoder(config=config, shared_embedding=None, name='decoder')(
            inputs,
            inputs_positions=inputs_positions,
            inputs_segmentation=inputs_segmentation,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None)
        return logits.astype(self.config.dtype)
Example 14
    def __call__(self, inputs, temb, train, context, permutations):
        """Applies Transformer model on the inputs.

    Args:
      inputs: Input data.
      temb: The time embedding.
      train: Is the model training?
      context: A context to condition on.
      permutations: A batch of permutations that specifies generation order.

    Returns:
      Output of a transformer decoder.
    """
        cfg = self.config
        assert inputs.ndim == 2  # (batch, len)
        deterministic = not train

        # Permutations give the generation order, for XLNet-style training only.
        # It is important that permutations are applied _before shifting_. For
        # this reason, we also have to deal with the positional embeddings
        # separately at a later point.
        if permutations is not None:
            assert cfg.is_causal
            assert permutations.shape == inputs.shape

            # Use the permutations to act on the inputs.
            inputs = util_fns.batch_permute(inputs, permutations)

        # Target Embedding
        embedding_layer = nn.Embed(
            num_embeddings=cfg.output_vocab_size,
            features=cfg.emb_dim,
            embedding_init=nn.initializers.normal(stddev=1.0))

        # Concatenate context if available.
        if context is not None:
            assert cfg.context_length == context.shape[1], (
                f'{cfg.context_length} != {context.shape[1]} for {context.shape}')
            inputs = jnp.concatenate([context, inputs], axis=1)

        y = inputs.astype('int32')

        if cfg.is_causal:
            logging.info('Using causal Transformer')
            decoder_mask = nn.make_causal_mask(inputs, dtype=cfg.dtype)
        else:
            logging.info('Using fully connected (non-causal) Transformer')
            decoder_mask = None

        if cfg.is_causal:
            y = shift_inputs(y)
        y = embedding_layer(y)

        y = AddPositionEmbs(config=cfg, name='add_posemb')(y, permutations)

        y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=deterministic)

        y = y.astype(cfg.dtype)

        # Target-Input Decoder
        for lyr in range(cfg.num_layers):
            y = EncoderDecoder1DBlock(config=cfg,
                                      name=f'encoderdecoderblock_{lyr}')(
                                          y,
                                          temb,
                                          deterministic,
                                          decoder_mask=decoder_mask)
        y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

        logits = nn.Dense(cfg.output_vocab_size,
                          dtype=cfg.dtype,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init,
                          name='logitdense')(y)

        if context is not None:
            # Take only predictions for inputs, not context.
            logits = logits[:, cfg.context_length:]

        if permutations is not None:
            assert cfg.is_causal
            # Apply the inverse permutation to the logits.
            inv_permutations = util_fns.compute_batch_inverse_permute(
                permutations)
            logits = util_fns.batch_permute(logits, inv_permutations)

        return logits
Example 15
    def decode(self, programs, encoded, encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config.base_config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = base_models.flatten_num_io_dim(
            encoded_padding_mask)

        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        # Make attention masks.
        decoder_mask = None
        decoder_relative_position = None  # Relative positions.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
        else:
            attention_mask_type = self.config.attention_mask_type
            if attention_mask_type == 'baseline':
                decoder_mask = nn.combine_masks(
                    nn.make_attention_mask(programs > 0,
                                           programs > 0,
                                           dtype=cfg.dtype),
                    nn.make_causal_mask(programs, dtype=cfg.dtype))
            else:
                if attention_mask_type == 'bos_to_bos':
                    # BOS tokens attend to all previous BOS tokens.
                    decoder_bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                elif attention_mask_type == 'bos_to_last':
                    # BOS tokens attend to all last partial program tokens.
                    bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                    # Shift bos mask to left to get all previous last partial program
                    # tokens.
                    decoder_bos_mask = shift_left(bos_mask)
                elif attention_mask_type == 'bos_to_bos_and_last':
                    # BOS tokens attend to all previous BOS + last partial program tokens.
                    bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                    # Shift bos mask to left to get all previous last partial program
                    # tokens.
                    decoder_bos_mask = jnp.logical_or(bos_mask,
                                                      shift_left(bos_mask))
                elif attention_mask_type == 'bos_full_attention':
                    # BOS tokens attend to all previous tokens, including program tokens.
                    decoder_bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs > 0,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                else:
                    raise ValueError(
                        'Unhandled attention_mask_type: {}'.format(
                            attention_mask_type))
                # Program tokens attend to all previous tokens in partial program.
                decoder_partial_mask = nn.combine_masks(
                    make_partial_program_mask(programs,
                                              bos_token=cfg.bos_token,
                                              dtype=cfg.dtype),
                    nn.make_causal_mask(programs, dtype=cfg.dtype))
                decoder_mask = nn.combine_masks(
                    nn.make_attention_mask(programs > 0,
                                           programs > 0,
                                           dtype=cfg.dtype),
                    jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

                if self.config.bos_special_attention:
                    # Make custom relative positions where BOS are separately indexed.
                    decoder_relative_position = make_relative_position(
                        programs)
                    decoder_partial_relative_position = (
                        make_partial_program_relative_position(
                            programs, bos_token=cfg.bos_token))
                    decoder_relative_position = jnp.where(
                        (programs == cfg.bos_token)[Ellipsis, None],
                        decoder_partial_relative_position,
                        decoder_relative_position)
                else:
                    decoder_relative_position = None

            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask, decoder_relative_position)
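shift_left, used for the 'bos_to_last' and 'bos_to_bos_and_last' variants above, is not shown. A plausible sketch that shifts a mask one step to the left along the key axis and zero-pads on the right (an assumption, not the repository's implementation):

import jax.numpy as jnp

def shift_left_sketch(mask):
    """Sketch: mask[..., k] becomes mask[..., k + 1]; the last key slot is zeroed."""
    pad_widths = [(0, 0)] * (mask.ndim - 1) + [(0, 1)]
    return jnp.pad(mask[..., 1:], pad_widths)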
Example 16
    def __call__(self,
                 inputs,
                 train,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      train: bool: if model is training.
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output of a transformer decoder.
    """
        assert inputs.ndim == 2  # (batch, len)
        dtype = utils.dtype_from_str(self.model_dtype)

        if self.decode:
            # for fast autoregressive decoding we use no decoder mask
            decoder_mask = None
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
                nn.make_causal_mask(inputs, dtype=dtype))

        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=dtype))

        y = inputs.astype('int32')
        if not self.decode:
            y = shift_inputs(y, segment_ids=inputs_segmentation)

        # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
        # indices for dataset_lib:proteins_test. This will break when jnp.take() is
        # updated to return NaNs for out-of-bounds indices.
        # Debug why this is the case.
        y = jnp.clip(y, 0, self.vocab_size - 1)

        if self.shared_embedding is None:
            output_embed = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            output_embed = self.shared_embedding

        y = output_embed(y)

        y = AddPositionEmbs(max_len=self.max_len,
                            posemb_init=sinusoidal_init(max_len=self.max_len),
                            decode=self.decode,
                            name='posembed_output')(
                                y, inputs_positions=inputs_positions)
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

        y = y.astype(dtype)

        for _ in range(self.num_layers):
            y = Transformer1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                attention_fn=self.attention_fn,
                normalizer=self.normalizer,
                dtype=dtype)(
                    inputs=y,
                    train=train,
                    decoder_mask=decoder_mask,
                    encoder_decoder_mask=None,
                    inputs_positions=None,
                    inputs_segmentation=None,
                )
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            maybe_normalize = model_utils.get_normalizer(self.normalizer,
                                                         train,
                                                         dtype=dtype)
            y = maybe_normalize()(y)

        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(y.astype(jnp.float32))
            # Correctly normalize pre-softmax logits for this shared case.
            logits = logits / jnp.sqrt(y.shape[-1])
        else:
            logits = nn.Dense(self.vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              dtype=dtype,
                              name='logits_dense')(y)

        return logits.astype(dtype)
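shift_inputs, called in Examples 14 and 16, is not included here; the usual teacher-forcing shift in Flax-style language models looks roughly like the sketch below, with optional zeroing across packed-segment boundaries (assumed implementation, not the source of either example):

import jax.numpy as jnp
from jax import lax

def shift_right_sketch(x, axis=1):
    """Sketch: insert a zero at the start of `axis` and drop the last slot."""
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (1, 0)
    padded = jnp.pad(x, pad_widths)
    return lax.dynamic_slice_in_dim(padded, 0, x.shape[axis], axis)

def shift_inputs_sketch(x, segment_ids=None, axis=1):
    """Sketch: teacher-forcing shift that respects packed-segment boundaries."""
    shifted = shift_right_sketch(x, axis=axis)
    if segment_ids is not None:
        # Zero out tokens that were shifted across a segment boundary.
        shifted = shifted * (segment_ids == shift_right_sketch(segment_ids, axis=axis))
    return shifted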