def _one_step_forward(
    self,
    decoder_input_ids=None,
    encoder_hidden_states=None,
    encoder_input_mask=None,
    decoder_mems_list=None,
    lm_mems_list=None,
    pos=0,
):
    # Score the next token with the translation model (parent class).
    nmt_log_probs, decoder_mems_list = super()._one_step_forward(
        decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos,
    )

    # Score the same step with the external language model, reusing its cached memories.
    input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()
    lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos)
    lm_mems_list = self.language_model.encoder.encoder.forward(
        lm_hidden_states, input_mask, lm_mems_list, return_mems=True,
    )
    lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:])

    # Shallow fusion of translation-model and language-model scores.
    log_probs = nmt_log_probs + self.fusion_coef * lm_log_probs
    return log_probs, decoder_mems_list, lm_mems_list
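# Note on the fusion step in _one_step_forward above: the combined score is
#
#     log_probs = log P_nmt(y_t | y_<t, x) + fusion_coef * log P_lm(y_t | y_<t)
#
# so fusion_coef = 0.0 recovers plain NMT beam-search scoring, while larger
# values bias the search toward tokens the external language model prefers.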
def _one_step_forward_lm(self, decoder_input_ids=None, lm_mems_list=None, pos=0):
    # Language-model-only step: return next-token log-probabilities and updated LM memories.
    input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()
    lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos)
    lm_mems_list = self.language_model.encoder.encoder.forward(
        lm_hidden_states, input_mask, lm_mems_list, return_mems=True,
    )
    lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:])
    return lm_log_probs, lm_mems_list
def _one_step_forward(
    self,
    ensemble_index,
    decoder_input_ids=None,
    encoder_hidden_states=None,
    encoder_input_mask=None,
    decoder_mems_list=None,
    pos=0,
):
    """
    One step of autoregressive output generation for one particular model in the ensemble.

    Args:
        ensemble_index: index of the model whose embedding, decoder, and
            output layer are used for this step
        decoder_input_ids: starting sequence of tokens to generate from;
            if None, generation will start from a batch of <bos> tokens
        encoder_hidden_states: output of the encoder for conditional
            sequence generation; if None, generator will use unconditional
            mode (e.g., language modeling)
        encoder_input_mask: input mask used in the encoder
        decoder_mems_list: list of size num_layers with cached activations
            of sequence (x[1], ..., x[k-1]) for fast generation of x[k]
        pos: starting position in positional encoding
    """
    decoder_hidden_states = self.embeddings[ensemble_index].forward(decoder_input_ids, start_pos=pos)
    decoder_input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()

    if encoder_hidden_states is not None:
        # Conditional generation: the decoder attends over the encoder outputs.
        decoder_mems_list = self.decoders[ensemble_index].forward(
            decoder_hidden_states,
            decoder_input_mask,
            encoder_hidden_states,
            encoder_input_mask,
            decoder_mems_list,
            return_mems=True,
        )
    else:
        # Unconditional (language-modeling) generation.
        decoder_mems_list = self.decoders[ensemble_index].forward(
            decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True
        )
    log_probs = self.log_softmaxes[ensemble_index].forward(hidden_states=decoder_mems_list[-1][:, -1:])
    return log_probs, decoder_mems_list
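# Typical calling pattern for the step above (a sketch inferred from the
# docstring, not taken verbatim from this module): the first call presumably
# passes the full starting sequence with decoder_mems_list=None and pos=0;
# each later call passes only the most recently generated token, the
# decoder_mems_list returned by the previous step, and pos set to the current
# sequence length, so only the newest position is embedded and decoded.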