def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  preshift_programs = programs  # Save pre-shifted programs for padding mask.
  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    # BOS tokens attend to all previous BOS tokens.
    decoder_bos_mask = nn.combine_masks(
        nn.make_attention_mask(programs == cfg.bos_token,
                               programs == cfg.bos_token,
                               dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    # Program tokens attend to all previous tokens in partial program.
    decoder_partial_mask = nn.combine_masks(
        make_partial_program_mask(programs,
                                  bos_token=cfg.bos_token,
                                  dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(preshift_programs > 0, preshift_programs > 0,
                               dtype=cfg.dtype),
        jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)
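# NOTE: `make_partial_program_mask` is called above but not defined in this
# excerpt. The sketch below is a minimal, illustrative implementation only,
# assuming partial programs are the token spans delimited by `bos_token` and
# that the mask should let each query attend only to keys inside its own
# partial program (the callers above additionally combine it with a causal
# mask). It is not necessarily the helper used in the original module.
import jax.numpy as jnp  # Standard aliases assumed by the surrounding code.
from flax import linen as nn


def make_partial_program_mask_sketch(programs, bos_token=1, dtype=jnp.float32):
  """Illustrative sketch of a same-partial-program attention mask."""
  # Tokens that have seen the same number of BOS separators belong to the
  # same partial program.
  segment_ids = jnp.cumsum(programs == bos_token, axis=-1)
  # Returns a [batch, 1, length, length] mask that is 1 where query and key
  # share a segment id, matching the shape convention of make_attention_mask.
  return nn.make_attention_mask(
      segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=dtype)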
def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config.base_config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  decoder_mask = None
  decoder_relative_position = None  # Relative positions.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    attention_mask_type = self.config.attention_mask_type
    if attention_mask_type == 'baseline':
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
    else:
      if attention_mask_type == 'bos_to_bos':
        # BOS tokens attend to all previous BOS tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      elif attention_mask_type == 'bos_to_last':
        # BOS tokens attend to all last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift bos mask to left to get all previous last partial program
        # tokens.
        decoder_bos_mask = shift_left(bos_mask)
      elif attention_mask_type == 'bos_to_bos_and_last':
        # BOS tokens attend to all previous BOS + last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift bos mask to left to get all previous last partial program
        # tokens.
        decoder_bos_mask = jnp.logical_or(bos_mask, shift_left(bos_mask))
      elif attention_mask_type == 'bos_full_attention':
        # BOS tokens attend to all previous tokens, including program tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs > 0,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      else:
        raise ValueError('Unhandled attention_mask_type: {}'.format(
            attention_mask_type))
      # Program tokens attend to all previous tokens in partial program.
      decoder_partial_mask = nn.combine_masks(
          make_partial_program_mask(programs,
                                    bos_token=cfg.bos_token,
                                    dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

    if self.config.bos_special_attention:
      # Make custom relative positions where BOS are separately indexed.
      decoder_relative_position = make_relative_position(programs)
      decoder_partial_relative_position = (
          make_partial_program_relative_position(programs,
                                                 bos_token=cfg.bos_token))
      decoder_relative_position = jnp.where(
          (programs == cfg.bos_token)[Ellipsis, None],
          decoder_partial_relative_position,
          decoder_relative_position)
    else:
      decoder_relative_position = None

    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask,
      decoder_relative_position)
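# NOTE: `shift_left` is called in the 'bos_to_last' and 'bos_to_bos_and_last'
# branches above but is not defined in this excerpt. A minimal sketch follows,
# assuming it shifts an attention mask one position to the left along the key
# (last) axis and zero-pads the vacated column, so that BOS queries pick up
# the token immediately preceding each earlier BOS, i.e. the last token of
# every previous partial program. This is illustrative only and may not match
# the original helper; it uses the module-level `jnp` alias.
def shift_left_sketch(mask):
  """Illustrative sketch: shift a [..., len_q, len_kv] mask left along keys."""
  pad_widths = [(0, 0)] * (mask.ndim - 1) + [(0, 1)]  # Zero-fill last key slot.
  return jnp.pad(mask[Ellipsis, 1:], pad_widths)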