    # NOTE: The enclosing encode() signature below is reconstructed from the
    # variables `_prepare_source` closes over (`self`, `source_ids`,
    # `source_mask`); the remainder of the method is elided in this excerpt.
    def encode(self, source_ids, source_mask):
        """ Encodes source-side input tokens into contextually enriched representations. """

        def _prepare_source():
            """ Pre-processes inputs to the encoder and generates the corresponding attention masks. """
            # Embed
            source_embeddings = self._embed(source_ids)
            # Obtain length and depth of the input tensors
            _, time_steps, depth = get_shape_list(source_embeddings)
            # Transform the padding mask into an additive attention mask: padded
            # positions receive a large negative bias, so the subsequent soft-max
            # assigns them near-zero attention weight
            inverse_mask = tf.cast(tf.equal(source_mask, 0.0),
                                   dtype=self.float_dtype)
            attn_mask = inverse_mask * -1e9
            # Expansion to shape [batch_size, 1, 1, time_steps] is needed for
            # compatibility with the attention logits
            attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
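            # Illustration (hypothetical values): for source_mask = [[1., 1., 0.]]
            # (one padded position), attn_mask becomes [[[[0., 0., -1e9]]]] and
            # broadcasts over the head and query dimensions of attention logits
            # shaped [batch_size, num_heads, time_steps, time_steps].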
            # Differentiate between self-attention and cross-attention masks for
            # further, optional modifications
            self_attn_mask = attn_mask
            cross_attn_mask = attn_mask
            # Add positional encodings
            positional_signal = get_positional_signal(time_steps, depth,
                                                      self.float_dtype)
            source_embeddings += positional_signal
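            # get_positional_signal is assumed to implement the sinusoidal
            # encodings of Vaswani et al. (2017):
            #   PE[pos, 2i]     = sin(pos / 10000^(2i / depth))
            #   PE[pos, 2i + 1] = cos(pos / 10000^(2i / depth))
            # yielding a [1, time_steps, depth] tensor added to the embeddings.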
            # Apply dropout
            if self.config.dropout_embeddings > 0:
                source_embeddings = tf.layers.dropout(
                    source_embeddings,
                    rate=self.config.dropout_embeddings,
                    training=self.training)
            return source_embeddings, self_attn_mask, cross_attn_mask

    def decode_at_train(self, target_ids, enc_cache, cross_attn_mask):
        """ Returns the probability distribution over target-side tokens conditioned on the output of the encoder;
        performs decoding in parallel at training time. """
        def _decode_all(target_embeddings):
            """ Decodes the encoder-generated representations into target-side logits in parallel. """
            # Apply input dropout
            dec_input = tf.layers.dropout(
                target_embeddings,
                rate=self.config.dropout_embeddings,
                training=self.training)
            # Propagate inputs through the decoder stack
            dec_output = dec_input
            for layer_id in range(1, self.config.num_decoder_layers + 1):
                dec_output, _ = self.decoder_stack[layer_id][
                    'self_attn'].forward(dec_output, None, self_attn_mask)
                dec_output, _ = \
                    self.decoder_stack[layer_id]['cross_attn'].forward(dec_output, enc_cache, cross_attn_mask)
                dec_output = self.decoder_stack[layer_id]['ffn'].forward(
                    dec_output)

                # Record this layer's lexical shortcut gate activations for later inspection
                if self.gate_tracker:
                    self.gate_tracker['decoder_layer_{:d}'.format(layer_id)]['lexical_gate_keys'] = \
                        self.decoder_stack[layer_id]['cross_attn'].key_gate
                    self.gate_tracker['decoder_layer_{:d}'.format(layer_id)]['lexical_gate_values'] = \
                        self.decoder_stack[layer_id]['cross_attn'].value_gate

            return dec_output

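        # For reference, each self.decoder_stack entry appears to be a dict of
        # sub-layer modules keyed 'self_attn', 'cross_attn' and 'ffn', whose
        # forward() methods are assumed to apply residual connections and layer
        # normalization internally, e.g. (hypothetical classes):
        #   self.decoder_stack[layer_id] = {
        #       'self_attn': MultiHeadAttentionLayer(...),   # masked self-attention
        #       'cross_attn': MultiHeadAttentionLayer(...),  # attends to encoder states
        #       'ffn': FeedForwardLayer(...),                # position-wise feed-forward
        #   }
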
        def _prepare_targets():
            """ Pre-processes target token ids before they're passed on as input to the decoder
            for parallel decoding. """
            # Embed target_ids
            target_embeddings = self._embed(target_ids)
            target_embeddings += positional_signal
            if self.config.dropout_embeddings > 0:
                target_embeddings = tf.layers.dropout(
                    target_embeddings,
                    rate=self.config.dropout_embeddings,
                    training=self.training)
            return target_embeddings

        def _decoding_function():
            """ Generates logits for target-side tokens. """
            # Embed the target-side tokens (teacher forcing) and add positional information
            target_embeddings = _prepare_targets()
            # Pass encoder context and decoder embeddings through the decoder
            dec_output = _decode_all(target_embeddings)
            # Project decoder-stack outputs into logits over the target vocabulary
            full_logits = self.softmax_projection_layer.project(dec_output)
            return full_logits

        with tf.variable_scope(self.name):
            # Create nodes
            self._build_graph()

            self_attn_mask = get_right_context_mask(tf.shape(target_ids)[-1])
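            # get_right_context_mask presumably returns an additive causal mask;
            # for three target positions it would be (0 = visible, -1e9 = blocked
            # future position):
            #   [[   0, -1e9, -1e9],
            #    [   0,    0, -1e9],
            #    [   0,    0,    0]]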
            positional_signal = get_positional_signal(
                tf.shape(target_ids)[-1], self.config.embedding_size,
                self.float_dtype)
            logits = _decoding_function()
        return logits

    def decode_at_test(self, enc_output, cross_attn_mask, batch_size,
                       beam_size, do_sample):
        """ Returns the probability distribution over target-side tokens conditioned on the output of the encoder;
        performs decoding via auto-regression at test time. """
        def _decode_step(target_embeddings, memories):
            """ Decode the encoder-generated representations into target-side logits with auto-regression. """
            # Propagate inputs through the decoder stack
            dec_output = target_embeddings
            # NOTE: No explicit self-attention mask is needed during incremental
            # decoding, since only previously generated positions (cached in
            # `memories`) are available to attend to
            for layer_id in range(1, self.config.num_decoder_layers + 1):
                dec_output, memories['layer_{:d}'.format(layer_id)] = \
                    self.decoder_stack[layer_id]['self_attn'].forward(
                        dec_output, None, None, memories['layer_{:d}'.format(layer_id)])
                dec_output, _ = \
                    self.decoder_stack[layer_id]['cross_attn'].forward(dec_output, enc_output, cross_attn_mask)
                dec_output = self.decoder_stack[layer_id]['ffn'].forward(
                    dec_output)
            # Return prediction at the final time-step to be consistent with the inference pipeline
            dec_output = dec_output[:, -1, :]
            return dec_output, memories

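        # `memories` is assumed to cache each layer's self-attention keys and
        # values for all previously emitted positions, so each incremental step
        # only computes attention projections for the newest token before
        # appending them to the cache.
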
        def _pre_process_targets(step_target_ids, current_time_step):
            """ Pre-processes target token ids before they're passed on as input to the decoder
            for auto-regressive decoding. """
            # Embed target_ids
            target_embeddings = self._embed(step_target_ids)
            signal_slice = \
                positional_signal[:, current_time_step - 1:current_time_step, :]
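            # E.g. at current_time_step == 3 this selects positional_signal[:, 2:3, :],
            # the encoding of the single position currently being decoded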
            target_embeddings += signal_slice
            if self.config.dropout_embeddings > 0:
                target_embeddings = tf.layers.dropout(
                    target_embeddings,
                    rate=self.config.dropout_embeddings,
                    training=self.training)
            return target_embeddings

        def _decoding_function(step_target_ids, current_time_step, memories):
            """ Generates logits for the target-side token predicted for the next-time step with auto-regression. """
            # Embed the model's predictions up to the current time-step and add positional information
            target_embeddings = _pre_process_targets(step_target_ids,
                                                     current_time_step)
            # Pass encoder context and decoder embeddings through the decoder
            dec_output, memories = _decode_step(target_embeddings, memories)
            # Project decoder-stack outputs into logits over the target vocabulary
            step_logits = self.softmax_projection_layer.project(dec_output)
            return step_logits, memories

        with tf.variable_scope(self.name):
            # Create nodes
            self._build_graph()

            positional_signal = get_positional_signal(
                self.config.translation_max_len, self.config.embedding_size,
                self.float_dtype)
            if beam_size > 0:
                # Initialize target IDs with <GO>
                initial_ids = tf.cast(tf.fill([batch_size], 1),
                                      dtype=self.int_dtype)
                initial_memories = self._get_initial_memories(
                    batch_size, beam_size=beam_size)
                output_sequences, scores = beam_search(
                    _decoding_function, initial_ids, initial_memories,
                    self.int_dtype, self.float_dtype,
                    self.config.translation_max_len, batch_size, beam_size,
                    self.embedding_layer.get_vocab_size(), 0,
                    self.config.length_normalization_alpha)

            else:
                # Initialize target IDs with <GO>
                initial_ids = tf.cast(tf.fill([batch_size, 1], 1),
                                      dtype=self.int_dtype)
                initial_memories = self._get_initial_memories(batch_size,
                                                              beam_size=1)
                output_sequences, scores = greedy_search(
                    _decoding_function,
                    initial_ids,
                    initial_memories,
                    self.int_dtype,
                    self.float_dtype,
                    self.config.translation_max_len,
                    batch_size,
                    0,
                    do_sample,
                    time_major=False)
        return output_sequences, scores
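
    # Typical inference-time usage (sketch; assumes encode() returns the encoder
    # output and the cross-attention mask, as reconstructed above):
    #   enc_output, cross_attn_mask = model.encode(source_ids, source_mask)
    #   output_sequences, scores = model.decode_at_test(
    #       enc_output, cross_attn_mask, batch_size, beam_size=4, do_sample=False)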