Example #1
    def _one_step_forward(
        self,
        decoder_input_ids=None,
        encoder_hidden_states=None,
        encoder_input_mask=None,
        decoder_mems_list=None,
        lm_mems_list=None,
        pos=0,
    ):
        # Run one step of the base NMT decoder (parent class), reusing its
        # cached activations (decoder_mems_list) for fast autoregressive generation.
        nmt_log_probs, decoder_mems_list = super()._one_step_forward(
            decoder_input_ids,
            encoder_hidden_states,
            encoder_input_mask,
            decoder_mems_list,
            pos,
        )

        # Run the same step through the external language model: mask padding,
        # embed the current tokens at position `pos`, update the LM's cached
        # activations, and score the last position.
        input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()
        lm_hidden_states = self.language_model.encoder.embedding.forward(
            decoder_input_ids, start_pos=pos)
        lm_mems_list = self.language_model.encoder.encoder.forward(
            lm_hidden_states,
            input_mask,
            lm_mems_list,
            return_mems=True,
        )
        lm_log_probs = self.language_model.log_softmax.forward(
            hidden_states=lm_mems_list[-1][:, -1:])

        # Shallow fusion: add the LM log-probabilities, scaled by the fusion
        # coefficient, to the NMT log-probabilities.
        log_probs = nmt_log_probs + self.fusion_coef * lm_log_probs

        return log_probs, decoder_mems_list, lm_mems_list
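
The final fusion step is just an element-wise combination of the two log-probability tensors. Below is a minimal standalone sketch of that combination; the tensors and the fusion_coef value are made up for illustration, and only the last two lines mirror what the method above does.

import torch

# Dummy scores with shape (batch, 1, vocab), as returned for the last position.
nmt_log_probs = torch.log_softmax(torch.randn(2, 1, 8), dim=-1)
lm_log_probs = torch.log_softmax(torch.randn(2, 1, 8), dim=-1)
fusion_coef = 0.2  # assumed value of self.fusion_coef

# Shallow fusion: LM scores, scaled by the coefficient, are added to NMT scores
# before the search picks the next token.
log_probs = nmt_log_probs + fusion_coef * lm_log_probs
next_tokens = log_probs[:, -1, :].argmax(dim=-1)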
Example #2
    def _one_step_forward_lm(self, decoder_input_ids=None, lm_mems_list=None, pos=0):
        # One decoding step of the language model alone: mask padding, embed the
        # current tokens at position `pos`, update the cached activations, and
        # return log-probabilities for the last position.
        input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()
        lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos)
        lm_mems_list = self.language_model.encoder.encoder.forward(
            lm_hidden_states, input_mask, lm_mems_list, return_mems=True,
        )
        lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:])
        return lm_log_probs, lm_mems_list
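
Methods with this signature are normally driven by an autoregressive loop that feeds back the newest token together with the cached mems and an advancing position. A hedged sketch of such a loop follows; step_fn, bos, eos, and max_len are placeholders and not part of the code above.

import torch

def greedy_decode(step_fn, bos, eos, max_len=16):
    # step_fn is assumed to have the signature of _one_step_forward_lm above.
    tokens = torch.full((1, 1), bos, dtype=torch.long)
    mems, pos = None, 0
    for _ in range(max_len):
        # After the first step only the newest token is fed in; earlier
        # positions are covered by the cached activations in `mems`.
        inp = tokens if pos == 0 else tokens[:, -1:]
        log_probs, mems = step_fn(decoder_input_ids=inp, lm_mems_list=mems, pos=pos)
        next_tok = log_probs[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        pos += inp.size(1)
        if next_tok.item() == eos:
            break
    return tokens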
Example #3
    def _one_step_forward(
        self,
        ensemble_index,
        decoder_input_ids=None,
        encoder_hidden_states=None,
        encoder_input_mask=None,
        decoder_mems_list=None,
        pos=0,
    ):
        """
        One step of autoregressive output generation for one particular model.

        Args:
            decoder_input_ids: starting sequence of tokens to generate from;
                if None, generation will start from a batch of <bos> tokens
            encoder_hidden_states: output of the encoder for conditional
                sequence generation; if None, generator will use unconditional
                mode (e.g., language modeling)
            encoder_input_mask: input mask used in the encoder
            decoder_mems_list: list of size num_layers with cached activations
                of sequence (x[1], ..., x[k-1]) for fast generation of x[k]
            pos: starting position in positional encoding
        """

        # Embed the current tokens with the embedding table of the selected
        # ensemble member and build the padding mask.
        decoder_hidden_states = self.embeddings[ensemble_index].forward(
            decoder_input_ids, start_pos=pos)
        decoder_input_mask = mask_padded_tokens(decoder_input_ids,
                                                self.pad).float()

        if encoder_hidden_states is not None:
            # Conditional generation: the decoder attends to the encoder output.
            decoder_mems_list = self.decoders[ensemble_index].forward(
                decoder_hidden_states,
                decoder_input_mask,
                encoder_hidden_states,
                encoder_input_mask,
                decoder_mems_list,
                return_mems=True,
            )
        else:
            # Unconditional generation (language-model mode).
            decoder_mems_list = self.decoders[ensemble_index].forward(
                decoder_hidden_states,
                decoder_input_mask,
                decoder_mems_list,
                return_mems=True)
        # Log-probabilities over the vocabulary for the last generated position.
        log_probs = self.log_softmaxes[ensemble_index].forward(
            hidden_states=decoder_mems_list[-1][:, -1:])
        return log_probs, decoder_mems_list
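
Since each ensemble member returns its own log_probs, the surrounding search has to merge them before choosing the next token. One common rule is to average the models' probabilities; the sketch below is a hypothetical illustration of that merge, not necessarily the exact combination used by the surrounding generator.

import torch

def combine_ensemble_log_probs(per_model_log_probs):
    # Average the per-model probabilities and go back to log space.
    probs = torch.stack([lp.exp() for lp in per_model_log_probs], dim=0).mean(dim=0)
    return probs.log()

# Example with two dummy models over a vocabulary of 8 tokens.
a = torch.log_softmax(torch.randn(2, 1, 8), dim=-1)
b = torch.log_softmax(torch.randn(2, 1, 8), dim=-1)
combined = combine_ensemble_log_probs([a, b])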