Example #1
0
  def decode(self,
             programs,
             encoded,
             encoded_padding_mask):
    """Applies decoder on programs and encoded specification."""
    cfg = self.config

    # Allow for decoding without num_partial dimension for beam search.
    # programs shape == [batch_size, (num_partial), length]
    assert programs.ndim in [2, 3], ('Number of program dimensions should be '
                                     '2 or 3, but it is: %d' % programs.ndim)
    assert encoded.ndim == programs.ndim + 2

    # Collapse num_io dimension.
    num_io_axis = 1 if programs.ndim == 2 else 2
    flat_encoded = base_models.flatten_num_io_dim(encoded, axis=num_io_axis)
    flat_encoded_padding_mask = base_models.flatten_num_io_dim(
        encoded_padding_mask, axis=num_io_axis)

    # Make attention masks.
    if cfg.decode:
      # For fast decode with caching, programs shape == [batch_size, 1] and
      # cfg.shift = False, cfg.decode = True.
      decoder_mask = None
      encoder_decoder_mask = nn.make_attention_mask(
          jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    else:
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      encoder_decoder_mask = nn.make_attention_mask(
          programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

    return self.decoder(
        programs, flat_encoded, decoder_mask, encoder_decoder_mask)
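
# base_models.flatten_num_io_dim is used above but not shown in this snippet.
# The sketch below is an assumption about what it does, inferred from the
# shape comments and asserts above (the encoded spec carries a num_io axis
# that gets collapsed into the spec-length axis); it is not necessarily the
# actual base_models implementation.
import jax.numpy as jnp


def flatten_num_io_dim(x, axis=1):
  """Merges the num_io axis at `axis` into the axis that follows it.

  E.g. [batch_size, num_io, length, dim] -> [batch_size, num_io * length, dim].
  """
  new_shape = (x.shape[:axis]
               + (x.shape[axis] * x.shape[axis + 1],)
               + x.shape[axis + 2:])
  return jnp.reshape(x, new_shape)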
Example #2
0
    def decode(self, programs, encoded, encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = base_models.flatten_num_io_dim(
            encoded_padding_mask)

        preshift_programs = programs  # Save the unshifted programs for the padding mask.
        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        # Make attention masks.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
        else:
            # BOS tokens attend to all previous BOS tokens.
            decoder_bos_mask = nn.combine_masks(
                nn.make_attention_mask(programs == cfg.bos_token,
                                       programs == cfg.bos_token,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            # Program tokens attend to all previous tokens in partial program.
            decoder_partial_mask = nn.combine_masks(
                make_partial_program_mask(programs,
                                          bos_token=cfg.bos_token,
                                          dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(preshift_programs > 0,
                                       preshift_programs > 0,
                                       dtype=cfg.dtype),
                jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask)
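
# make_partial_program_mask is used above but not defined in this snippet.
# Below is a minimal sketch of the intended semantics, under the assumption
# that partial programs are delimited by BOS tokens: two positions may attend
# to each other only if they are preceded by the same number of BOS tokens,
# i.e. they belong to the same partial program. The output shape
# [batch_size, 1, length, length] matches nn.make_attention_mask, so it can be
# combined with the causal mask via nn.combine_masks as done above. This is an
# illustrative guess, not necessarily the project's actual helper.
import jax.numpy as jnp


def make_partial_program_mask(programs, bos_token, dtype=jnp.float32):
    """Allows attention only within the same BOS-delimited partial program."""
    partial_ids = jnp.cumsum(jnp.where(programs == bos_token, 1, 0), axis=-1)
    same_partial = jnp.equal(partial_ids[:, :, None], partial_ids[:, None, :])
    return jnp.expand_dims(same_partial, axis=1).astype(dtype)
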
    def decode(self, programs, latents, encoded, latents_padding_mask,
               encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert latents.ndim == 3, ('Number of latents dimensions should be 3,'
                                   ' but it is: %d' % latents.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = models.flatten_num_io_dim(
            encoded_padding_mask)

        latents = self.latent_pos_emb(latents)
        # Concatenate the i/o encoding and latents together.
        flat_encoded = jnp.concatenate([flat_encoded, latents], axis=1)

        # Make attention masks.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            decoder_mask = None
            latent_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs), latents_padding_mask, dtype=cfg.dtype)
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
            encoder_decoder_mask = jnp.concatenate(
                [encoder_decoder_mask, latent_decoder_mask], axis=-1)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(programs > 0,
                                       programs > 0,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            latent_decoder_mask = nn.make_attention_mask(programs > 0,
                                                         latents_padding_mask,
                                                         dtype=cfg.dtype)
            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)
            encoder_decoder_mask = jnp.concatenate(
                [encoder_decoder_mask, latent_decoder_mask], axis=-1)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask)
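
# The variant above concatenates the latent sequence onto the flattened i/o
# encoding along axis=1 and, correspondingly, concatenates the two
# cross-attention masks along the key axis (axis=-1). A small self-contained
# check of that bookkeeping; every size below is hypothetical and chosen only
# for illustration.
import jax.numpy as jnp
from flax import linen as nn

batch, num_io, spec_len, num_latents, length, dim = 2, 3, 5, 4, 7, 8
programs = jnp.ones((batch, length), jnp.int32)
flat_encoded = jnp.zeros((batch, num_io * spec_len, dim))
latents = jnp.zeros((batch, num_latents, dim))
flat_encoded_padding_mask = jnp.ones((batch, num_io * spec_len))
latents_padding_mask = jnp.ones((batch, num_latents))

# Keys/values seen by cross-attention: i/o encoding followed by latents.
memory = jnp.concatenate([flat_encoded, latents], axis=1)
# The mask is concatenated in the same order along its last (key) axis, so
# that mask column j lines up with memory position j.
enc_dec_mask = nn.make_attention_mask(programs > 0, flat_encoded_padding_mask)
lat_dec_mask = nn.make_attention_mask(programs > 0, latents_padding_mask)
mask = jnp.concatenate([enc_dec_mask, lat_dec_mask], axis=-1)

assert memory.shape == (batch, num_io * spec_len + num_latents, dim)
assert mask.shape == (batch, 1, length, num_io * spec_len + num_latents)
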
    def decode(self, programs, encoded, encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config.base_config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = base_models.flatten_num_io_dim(
            encoded_padding_mask)

        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        # Make attention masks.
        decoder_mask = None
        # Custom relative positions; only set when bos_special_attention is on.
        decoder_relative_position = None
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
        else:
            attention_mask_type = self.config.attention_mask_type
            if attention_mask_type == 'baseline':
                decoder_mask = nn.combine_masks(
                    nn.make_attention_mask(programs > 0,
                                           programs > 0,
                                           dtype=cfg.dtype),
                    nn.make_causal_mask(programs, dtype=cfg.dtype))
            else:
                if attention_mask_type == 'bos_to_bos':
                    # BOS tokens attend to all previous BOS tokens.
                    decoder_bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                elif attention_mask_type == 'bos_to_last':
                    # BOS tokens attend to the last token of each previous
                    # partial program.
                    bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                    # Shift the BOS mask one position to the left so each BOS
                    # attends to the last token of every previous partial
                    # program.
                    decoder_bos_mask = shift_left(bos_mask)
                elif attention_mask_type == 'bos_to_bos_and_last':
                    # BOS tokens attend to all previous BOS + last partial program tokens.
                    bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                    # Shift the BOS mask left and OR it with the original so
                    # BOS attends to previous BOS tokens and to the last token
                    # of every previous partial program.
                    decoder_bos_mask = jnp.logical_or(bos_mask,
                                                      shift_left(bos_mask))
                elif attention_mask_type == 'bos_full_attention':
                    # BOS tokens attend to all previous tokens, including program tokens.
                    decoder_bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs > 0,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                else:
                    raise ValueError(
                        'Unhandled attention_mask_type: {}'.format(
                            attention_mask_type))
                # Program tokens attend to all previous tokens in partial program.
                decoder_partial_mask = nn.combine_masks(
                    make_partial_program_mask(programs,
                                              bos_token=cfg.bos_token,
                                              dtype=cfg.dtype),
                    nn.make_causal_mask(programs, dtype=cfg.dtype))
                decoder_mask = nn.combine_masks(
                    nn.make_attention_mask(programs > 0,
                                           programs > 0,
                                           dtype=cfg.dtype),
                    jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

                if self.config.bos_special_attention:
                    # Make custom relative positions where BOS tokens are
                    # indexed separately.
                    decoder_relative_position = make_relative_position(
                        programs)
                    decoder_partial_relative_position = (
                        make_partial_program_relative_position(
                            programs, bos_token=cfg.bos_token))
                    decoder_relative_position = jnp.where(
                        (programs == cfg.bos_token)[Ellipsis, None],
                        decoder_partial_relative_position,
                        decoder_relative_position)
                else:
                    decoder_relative_position = None

            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask, decoder_relative_position)
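
# shift_left, make_relative_position, and make_partial_program_relative_position
# are used above but not defined in this snippet. The sketch below shows one
# plausible reading of shift_left for the 'bos_to_last' variant: shift the
# BOS-to-BOS mask one step along its last (key) axis, so a BOS query that
# attended to a previous BOS at key position j now attends to position j - 1,
# the last token of that previous partial program. This is an assumption about
# the helper's behavior, not necessarily the project's actual implementation;
# the relative-position helpers are not sketched here.
import jax.numpy as jnp


def shift_left(mask):
    """Shifts `mask` one step left along its last axis, zero-padding on the right."""
    pad_widths = [(0, 0)] * (mask.ndim - 1) + [(0, 1)]
    padded = jnp.pad(mask, pad_widths)
    return padded[..., 1:]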