def _prepare_source():
    """ Embeds the source tokens and builds the attention masks used downstream.

    Reads ``self``, ``source_ids`` and ``source_mask`` from the enclosing scope.

    Returns:
        ``(source_embeddings, self_attn_mask, cross_attn_mask)``:
        - ``source_embeddings``: the embeddings with positional encodings added. Embedding dropout
          is also applied when ``self.config.dropout_embeddings > 0``.
        - ``self_attn_mask`` and ``cross_attn_mask``: the encoder self-attention mask and the
          cross-attention mask. Both are currently the same tensor. They are returned separately so
          that either one can be modified later without affecting the other.
    """
    embedded = self._embed(source_ids)
    # The time and depth sizes are needed to build the positional signal
    _, num_steps, model_depth = get_shape_list(embedded)

    # Additive attention bias: -1e9 where source_mask == 0 (presumably padding) and 0 elsewhere.
    # The bias is presumably added to the attention logits, so that masked positions are
    # effectively ignored by the soft-max.
    padding_flags = tf.cast(tf.equal(source_mask, 0.0), dtype=self.float_dtype)
    attention_bias = padding_flags * -1e9
    # Insert two singleton axes after the batch axis so the bias broadcasts against the attention
    # logits. Assuming source_mask is [batch_size, time_steps], this yields
    # [batch_size, 1, 1, time_steps].
    attention_bias = tf.expand_dims(tf.expand_dims(attention_bias, 1), 1)

    # Inject positional information
    embedded += get_positional_signal(num_steps, model_depth, self.float_dtype)

    # Optional embedding dropout
    dropout_rate = self.config.dropout_embeddings
    if dropout_rate > 0:
        embedded = tf.layers.dropout(embedded, rate=dropout_rate, training=self.training)

    # One bias tensor serves as both the self-attention and the cross-attention mask
    return embedded, attention_bias, attention_bias
def decode_at_train(self, target_ids, enc_cache, cross_attn_mask):
    """ Returns the probability distribution over target-side tokens conditioned on the output of the encoder;
    performs decoding in parallel at training time.

    Args:
        target_ids: target-side token ids. The number of time-steps is read from the last axis
            (``tf.shape(target_ids)[-1]``).
        enc_cache: encoder representations fed to every cross-attention sub-layer.
        cross_attn_mask: attention mask passed unchanged to every cross-attention sub-layer.

    Returns:
        The result of ``self.softmax_projection_layer.project`` applied to the output of the last
        decoder layer. NOTE(review): the result is named ``logits``; confirm whether ``project``
        already applies the soft-max.
    """

    def _decode_all(target_embeddings):
        """ Runs the prepared target embeddings through the decoder stack in parallel. """
        # Embedding dropout is applied exactly once, in _prepare_targets (as on the source side and
        # in decode_at_test). Applying it here as well would drop out the embeddings twice.
        dec_output = target_embeddings
        # Propagate inputs through the decoder stack
        for layer_id in range(1, self.config.num_decoder_layers + 1):
            layer = self.decoder_stack[layer_id]
            # self_attn_mask comes from get_right_context_mask (built below). Because all target
            # positions are decoded at once here, it presumably hides positions to the right of
            # each query.
            dec_output, _ = layer['self_attn'].forward(dec_output, None, self_attn_mask)
            dec_output, _ = layer['cross_attn'].forward(dec_output, enc_cache, cross_attn_mask)
            dec_output = layer['ffn'].forward(dec_output)
            # Record the cross-attention lexical gates, but only when the gate-tracker
            # has been populated
            if len(self.gate_tracker.keys()) > 0:
                layer_key = 'decoder_layer_{:d}'.format(layer_id)
                self.gate_tracker[layer_key]['lexical_gate_keys'] = layer['cross_attn'].key_gate
                self.gate_tracker[layer_key]['lexical_gate_values'] = layer['cross_attn'].value_gate
        return dec_output

    def _prepare_targets():
        """ Embeds target_ids, adds positional information and optionally applies embedding dropout. """
        target_embeddings = self._embed(target_ids)
        target_embeddings += positional_signal
        if self.config.dropout_embeddings > 0:
            target_embeddings = tf.layers.dropout(
                target_embeddings, rate=self.config.dropout_embeddings, training=self.training)
        return target_embeddings

    def _decoding_function():
        """ Generates logits for all target-side positions. """
        # Embed the target tokens, add positional information and apply dropout
        target_embeddings = _prepare_targets()
        # Pass encoder context and decoder embeddings through the decoder
        dec_output = _decode_all(target_embeddings)
        # Project the decoder stack outputs onto the target vocabulary
        full_logits = self.softmax_projection_layer.project(dec_output)
        return full_logits

    with tf.variable_scope(self.name):
        # Create nodes
        self._build_graph()
        # The helpers above read these two tensors through their closures,
        # so they must exist before _decoding_function is called
        self_attn_mask = get_right_context_mask(tf.shape(target_ids)[-1])
        positional_signal = get_positional_signal(
            tf.shape(target_ids)[-1], self.config.embedding_size, self.float_dtype)
        logits = _decoding_function()
    return logits
def decode_at_test(self, enc_output, cross_attn_mask, batch_size, beam_size, do_sample):
    """ Auto-regressively decodes target-side tokens from the encoder output at test time.

    When ``beam_size > 0``, decoding is delegated to ``beam_search`` with that beam width.
    Otherwise ``greedy_search`` is used and ``do_sample`` is forwarded to it.

    Args:
        enc_output: encoder representations consumed by every cross-attention sub-layer.
        cross_attn_mask: mask forwarded unchanged to every cross-attention sub-layer.
        batch_size: used to build the initial <GO> ids and the initial decoder memories.
        beam_size: beam width. Values > 0 select beam search; anything else selects greedy search.
        do_sample: forwarded to greedy_search only; not used on the beam-search path.

    Returns:
        The ``(output_sequences, scores)`` pair produced by whichever search routine was used.
    """

    def _step_logits(step_target_ids, current_time_step, memories):
        """ Computes the logits for the token that follows ``step_target_ids``.

        ``memories`` holds one entry per decoder layer (keys ``'layer_<n>'``). It is updated in place
        and also returned.
        """
        # Embed the current step's tokens and add the positional encoding for this step.
        # The slice keeps a singleton time axis.
        hidden = self._embed(step_target_ids)
        hidden += positional_signal[:, current_time_step - 1:current_time_step, :]
        if self.config.dropout_embeddings > 0:
            hidden = tf.layers.dropout(hidden, rate=self.config.dropout_embeddings, training=self.training)

        # Propagate through the decoder stack. No self-attention mask is applied, because future
        # information is unavailable during auto-regressive decoding.
        for layer_idx in range(1, self.config.num_decoder_layers + 1):
            sublayers = self.decoder_stack[layer_idx]
            memory_key = 'layer_{:d}'.format(layer_idx)
            hidden, memories[memory_key] = sublayers['self_attn'].forward(
                hidden, None, None, memories[memory_key])
            hidden, _ = sublayers['cross_attn'].forward(hidden, enc_output, cross_attn_mask)
            hidden = sublayers['ffn'].forward(hidden)

        # Only the final time-step is handed to the inference pipeline
        final_state = hidden[:, -1, :]
        return self.softmax_projection_layer.project(final_state), memories

    with tf.variable_scope(self.name):
        # Create nodes
        self._build_graph()
        # Positional encodings up to the maximum translation length.
        # _step_logits slices out one position per step.
        positional_signal = get_positional_signal(
            self.config.translation_max_len, self.config.embedding_size, self.float_dtype)
        go_id = 1  # target IDs are initialised with <GO>

        if beam_size > 0:
            start_ids = tf.cast(tf.fill([batch_size], go_id), dtype=self.int_dtype)
            start_memories = self._get_initial_memories(batch_size, beam_size=beam_size)
            # NOTE(review): the literal 0 is presumably the <EOS> id; confirm against beam_search
            output_sequences, scores = beam_search(
                _step_logits, start_ids, start_memories, self.int_dtype, self.float_dtype,
                self.config.translation_max_len, batch_size, beam_size,
                self.embedding_layer.get_vocab_size(), 0, self.config.length_normalization_alpha)
        else:
            start_ids = tf.cast(tf.fill([batch_size, 1], go_id), dtype=self.int_dtype)
            start_memories = self._get_initial_memories(batch_size, beam_size=1)
            output_sequences, scores = greedy_search(
                _step_logits, start_ids, start_memories, self.int_dtype, self.float_dtype,
                self.config.translation_max_len, batch_size, 0, do_sample, time_major=False)
    return output_sequences, scores