def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  # Allow for decoding without num_partial dimension for beam search.
  # programs shape == [batch_size, (num_partial), length]
  assert programs.ndim in [2, 3], ('Number of program dimensions should be '
                                   '2 or 3, but it is: %d' % programs.ndim)
  assert encoded.ndim == programs.ndim + 2

  # Collapse num_io dimension.
  num_io_axis = 1 if programs.ndim == 2 else 2
  flat_encoded = base_models.flatten_num_io_dim(encoded, axis=num_io_axis)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask, axis=num_io_axis)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)
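# Illustrative sketch only, not necessarily the actual base_models helper:
# flatten_num_io_dim is assumed to merge the num_io axis into the length axis
# that follows it, turning [batch, num_io, length, ...] into
# [batch, num_io * length, ...]. The `axis` argument is what lets the decode
# above handle the extra num_partial dimension. The name
# flatten_num_io_dim_sketch is hypothetical.
import jax.numpy as jnp


def flatten_num_io_dim_sketch(x, axis=1):
  """Merges the num_io axis at `axis` into the axis that follows it."""
  new_shape = (x.shape[:axis] + (x.shape[axis] * x.shape[axis + 1],)
               + x.shape[axis + 2:])
  return jnp.reshape(x, new_shape)


# Example: encoded of shape [batch, num_io, spec_len, emb] becomes
# [batch, num_io * spec_len, emb].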
def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  preshift_programs = programs  # Save pre-shifted programs for padding mask.
  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    # BOS tokens attend to all previous BOS tokens.
    decoder_bos_mask = nn.combine_masks(
        nn.make_attention_mask(programs == cfg.bos_token,
                               programs == cfg.bos_token,
                               dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    # Program tokens attend to all previous tokens in partial program.
    decoder_partial_mask = nn.combine_masks(
        make_partial_program_mask(programs,
                                  bos_token=cfg.bos_token,
                                  dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(preshift_programs > 0, preshift_programs > 0,
                               dtype=cfg.dtype),
        jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(programs, flat_encoded, decoder_mask,
                      encoder_decoder_mask)
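# Hedged sketch of a helper in the spirit of make_partial_program_mask (the
# real helper may differ): it builds a [batch, 1, len, len] mask that lets
# each token attend to tokens in the same partial program, where partial
# programs are the segments delimited by BOS tokens. Segment ids come from a
# cumulative count of BOS tokens; the caller combines this with a causal mask
# as above. The _sketch name and implementation are assumptions.
import jax.numpy as jnp


def make_partial_program_mask_sketch(programs, bos_token=1,
                                     dtype=jnp.float32):
  """Mask allowing attention only within the same BOS-delimited segment."""
  segment_ids = jnp.cumsum(programs == bos_token, axis=-1)
  mask = segment_ids[:, :, None] == segment_ids[:, None, :]
  # Add the head axis expected by Flax attention: [batch, 1, len, len].
  return mask[:, None, :, :].astype(dtype)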
def decode(self, programs, latents, encoded, latents_padding_mask,
           encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert latents.ndim == 3, ('Number of latents dimensions should be 3,'
                             ' but it is: %d' % latents.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = models.flatten_num_io_dim(
      encoded_padding_mask)

  latents = self.latent_pos_emb(latents)
  # Concatenate the i/o encoding and latents together.
  flat_encoded = jnp.concatenate([flat_encoded, latents], axis=1)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    latent_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), latents_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = jnp.concatenate(
        [encoder_decoder_mask, latent_decoder_mask], axis=-1)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    latent_decoder_mask = nn.make_attention_mask(
        programs > 0, latents_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = jnp.concatenate(
        [encoder_decoder_mask, latent_decoder_mask], axis=-1)

  return self.decoder(programs, flat_encoded, decoder_mask,
                      encoder_decoder_mask)
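# Shape sketch for the mask concatenation above, with illustrative numbers
# only: flat_encoded covers num_io * spec_len key positions and the latents
# add num_latents more, so the encoder-decoder mask must cover both key
# blocks along its last axis.
import jax.numpy as jnp
from flax import linen as nn

batch, prog_len, num_io, spec_len, num_latents = 2, 5, 3, 7, 4
programs = jnp.ones((batch, prog_len), dtype=jnp.int32)
flat_encoded_padding_mask = jnp.ones((batch, num_io * spec_len))
latents_padding_mask = jnp.ones((batch, num_latents))

encoder_decoder_mask = nn.make_attention_mask(
    programs > 0, flat_encoded_padding_mask)               # [2, 1, 5, 21]
latent_decoder_mask = nn.make_attention_mask(
    programs > 0, latents_padding_mask)                    # [2, 1, 5, 4]
combined = jnp.concatenate(
    [encoder_decoder_mask, latent_decoder_mask], axis=-1)  # [2, 1, 5, 25]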
def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config.base_config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  decoder_mask = None
  decoder_relative_position = None  # Relative positions.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    attention_mask_type = self.config.attention_mask_type
    if attention_mask_type == 'baseline':
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
    else:
      if attention_mask_type == 'bos_to_bos':
        # BOS tokens attend to all previous BOS tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      elif attention_mask_type == 'bos_to_last':
        # BOS tokens attend to all last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift bos mask to left to get all previous last partial program
        # tokens.
        decoder_bos_mask = shift_left(bos_mask)
      elif attention_mask_type == 'bos_to_bos_and_last':
        # BOS tokens attend to all previous BOS + last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift bos mask to left to get all previous last partial program
        # tokens.
        decoder_bos_mask = jnp.logical_or(bos_mask, shift_left(bos_mask))
      elif attention_mask_type == 'bos_full_attention':
        # BOS tokens attend to all previous tokens, including program tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs > 0,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      else:
        raise ValueError(
            'Unhandled attention_mask_type: {}'.format(attention_mask_type))
      # Program tokens attend to all previous tokens in partial program.
      decoder_partial_mask = nn.combine_masks(
          make_partial_program_mask(programs,
                                    bos_token=cfg.bos_token,
                                    dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

    if self.config.bos_special_attention:
      # Make custom relative positions where BOS are separately indexed.
      decoder_relative_position = make_relative_position(programs)
      decoder_partial_relative_position = (
          make_partial_program_relative_position(programs,
                                                 bos_token=cfg.bos_token))
      decoder_relative_position = jnp.where(
          (programs == cfg.bos_token)[Ellipsis, None],
          decoder_partial_relative_position,
          decoder_relative_position)
    else:
      decoder_relative_position = None

    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(programs, flat_encoded, decoder_mask,
                      encoder_decoder_mask, decoder_relative_position)
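# Hedged sketches of two helpers referenced above; the real shift_left and
# make_relative_position may differ, and the _sketch names are hypothetical.
# shift_left is assumed to move a [batch, 1, len, len] mask one step toward
# smaller key indices, so a mask marking BOS key positions instead marks the
# token just before each BOS (the last token of the previous partial
# program). make_relative_position is assumed to return key-minus-query
# relative positions, broadcastable over the batch.
import jax.numpy as jnp


def shift_left_sketch(mask):
  """Shifts the key (last) axis left by one, padding with zeros."""
  return jnp.pad(mask[Ellipsis, 1:], [(0, 0)] * (mask.ndim - 1) + [(0, 1)])


def make_relative_position_sketch(programs):
  """Relative positions: key position minus query position, [1, len, len]."""
  positions = jnp.arange(programs.shape[-1])
  return positions[None, None, :] - positions[None, :, None]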