Example #1
    def decode(self, programs, encoded, encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = base_models.flatten_num_io_dim(
            encoded_padding_mask)

        preshift_programs = programs  # Save pre-shifted programs for padding mask.
        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        # Make attention masks.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
        else:
            # BOS tokens attend to all previous BOS tokens.
            decoder_bos_mask = nn.combine_masks(
                nn.make_attention_mask(programs == cfg.bos_token,
                                       programs == cfg.bos_token,
                                       dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            # Program tokens attend to all previous tokens in partial program.
            decoder_partial_mask = nn.combine_masks(
                make_partial_program_mask(programs,
                                          bos_token=cfg.bos_token,
                                          dtype=cfg.dtype),
                nn.make_causal_mask(programs, dtype=cfg.dtype))
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(preshift_programs > 0,
                                       preshift_programs > 0,
                                       dtype=cfg.dtype),
                jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask)
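
The mask construction above is easier to see on a toy batch. The sketch below is not part of the example; it rebuilds the same three ingredients (padding, BOS-to-BOS, within-partial-program) with plain flax.linen utilities, assuming pad id 0 and BOS id 1, and using a hypothetical partial_program_mask as a stand-in for the project's make_partial_program_mask:

import jax.numpy as jnp
from flax import linen as nn

bos_token = 1
programs = jnp.array([[1, 5, 6, 1, 7, 0]])  # partial programs separated by BOS, right-padded

def partial_program_mask(programs, bos_token):
    # Stand-in (assumption): two tokens may attend to each other when they have
    # seen the same number of BOS tokens, i.e. they belong to the same partial program.
    segment = jnp.cumsum(programs == bos_token, axis=-1)
    return nn.make_attention_mask(segment, segment, pairwise_fn=jnp.equal)

causal = nn.make_causal_mask(programs)                        # [1, 1, 6, 6]
padding = nn.make_attention_mask(programs > 0, programs > 0)  # block padding positions
bos_to_bos = nn.combine_masks(
    nn.make_attention_mask(programs == bos_token, programs == bos_token, dtype=jnp.float32),
    causal)
within_partial = nn.combine_masks(partial_program_mask(programs, bos_token), causal)

# Mirrors the non-decoding branch of `decode`: BOS-to-BOS OR within-partial-program,
# restricted to non-padding positions.
decoder_mask = nn.combine_masks(padding, jnp.logical_or(bos_to_bos, within_partial))
print(decoder_mask.shape)  # (1, 1, 6, 6)

Because combine_masks takes a logical AND while jnp.logical_or takes a union, the result lets each token see earlier tokens of its own partial program, lets BOS tokens additionally see earlier BOS tokens, and blocks padding everywhere, which is the behavior the decode method above encodes.
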
Example #2

    def decode(self, programs, encoded, encoded_padding_mask):
        """Applies decoder on programs and encoded specification."""
        cfg = self.config.base_config

        assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                                    ' but it is: %d' % programs.ndim)
        assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                                   ' but it is: %d' % encoded.ndim)

        # Collapse num_io dimension
        flat_encoded = base_models.flatten_num_io_dim(encoded)
        flat_encoded_padding_mask = base_models.flatten_num_io_dim(
            encoded_padding_mask)

        if cfg.shift:
            programs = base_models.shift_right(programs, cfg.bos_token)

        # Make attention masks.
        decoder_mask = None
        decoder_relative_position = None  # Relative positions.
        if cfg.decode:
            # For fast decode with caching, programs shape == [batch_size, 1] and
            # cfg.shift = False, cfg.decode = True.
            # TODO(jxihong): Fast decoding currently does not work with new attention.
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(programs),
                flat_encoded_padding_mask,
                dtype=cfg.dtype)
        else:
            attention_mask_type = self.config.attention_mask_type
            if attention_mask_type == 'baseline':
                decoder_mask = nn.combine_masks(
                    nn.make_attention_mask(programs > 0,
                                           programs > 0,
                                           dtype=cfg.dtype),
                    nn.make_causal_mask(programs, dtype=cfg.dtype))
            else:
                if attention_mask_type == 'bos_to_bos':
                    # BOS tokens attend to all previous BOS tokens.
                    decoder_bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                elif attention_mask_type == 'bos_to_last':
                    # BOS tokens attend to all last partial program tokens.
                    bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                    # Shift bos mask to left to get all previous last partial program
                    # tokens.
                    decoder_bos_mask = shift_left(bos_mask)
                elif attention_mask_type == 'bos_to_bos_and_last':
                    # BOS tokens attend to all previous BOS + last partial program tokens.
                    bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs == cfg.bos_token,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                    # Shift bos mask to left to get all previous last partial program
                    # tokens.
                    decoder_bos_mask = jnp.logical_or(bos_mask,
                                                      shift_left(bos_mask))
                elif attention_mask_type == 'bos_full_attention':
                    # BOS tokens attend to all previous tokens, including program tokens.
                    decoder_bos_mask = nn.combine_masks(
                        nn.make_attention_mask(programs == cfg.bos_token,
                                               programs > 0,
                                               dtype=cfg.dtype),
                        nn.make_causal_mask(programs, dtype=cfg.dtype))
                else:
                    raise ValueError(
                        'Unhandled attention_mask_type: {}'.format(
                            attention_mask_type))
                # Program tokens attend to all previous tokens in partial program.
                decoder_partial_mask = nn.combine_masks(
                    make_partial_program_mask(programs,
                                              bos_token=cfg.bos_token,
                                              dtype=cfg.dtype),
                    nn.make_causal_mask(programs, dtype=cfg.dtype))
                decoder_mask = nn.combine_masks(
                    nn.make_attention_mask(programs > 0,
                                           programs > 0,
                                           dtype=cfg.dtype),
                    jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

                if self.config.bos_special_attention:
                    # Make custom relative positions where BOS are separately indexed.
                    decoder_relative_position = make_relative_position(
                        programs)
                    decoder_partial_relative_position = (
                        make_partial_program_relative_position(
                            programs, bos_token=cfg.bos_token))
                    decoder_relative_position = jnp.where(
                        (programs == cfg.bos_token)[Ellipsis, None],
                        decoder_partial_relative_position,
                        decoder_relative_position)
                else:
                    decoder_relative_position = None

            encoder_decoder_mask = nn.make_attention_mask(
                programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

        return self.decoder(programs, flat_encoded, decoder_mask,
                            encoder_decoder_mask, decoder_relative_position)
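
Between the two examples, the only mask helper that is neither a flax.linen utility nor illustrated above is shift_left, used by the 'bos_to_last' and 'bos_to_bos_and_last' branches. A minimal stand-in consistent with the comments (an assumption, not the repository's implementation) shifts the BOS mask one position to the left along the key axis, so a query that attended to a BOS token now attends to the token just before it, i.e. the last token of the previous partial program:

import jax.numpy as jnp
from flax import linen as nn

def shift_left(mask):
    # Hypothetical stand-in: drop the first key position and zero-pad on the
    # right, shifting every attended position one step to the left.
    pad_width = [(0, 0)] * (mask.ndim - 1) + [(0, 1)]
    return jnp.pad(mask[..., 1:], pad_width)

bos_token = 1
programs = jnp.array([[1, 5, 6, 1, 7, 8]])   # two partial programs, no padding
bos_mask = nn.combine_masks(
    nn.make_attention_mask(programs == bos_token, programs == bos_token),
    nn.make_causal_mask(programs))

decoder_bos_mask = shift_left(bos_mask)                  # 'bos_to_last'
both = jnp.logical_or(bos_mask, shift_left(bos_mask))    # 'bos_to_bos_and_last'
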