def forward(
    self,
    input_ids,
    position_ids,
    attention_mask,
    beam_select_idx,
    input_log_probs,
    input_unfinished_sents,
    prev_step_results,
    prev_step_scores,
    *past,
):
    input_ids = input_ids.view(self.config.batch_size, -1, input_ids.size(-1))
    # Reorder the cached past states to follow the beams selected last step.
    past = [past[i].index_select(1, beam_select_idx[0]) for i in range(len(past))]
    result = super().forward(
        input_ids.view(-1, input_ids.size(-1)),
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past,
        return_dict=False,
    )
    logits_flat, present_flat = MyGPT2Model.post_process(result, self.config.n_layer)
    next_token_logits = logits_flat[:, -1].view(self.config.batch_size, -1, logits_flat.size(-1))
    next_token_log_probs = torch.log_softmax(next_token_logits, dim=-1)
    next_token_log_probs, next_token_ids = torch.topk(
        next_token_log_probs, self.config.beam_size, dim=-1, largest=True, sorted=True
    )

    # A finished sentence always emits EOS; all of its candidates but the first
    # get -inf so they are automatically dropped in this round of beam search.
    finished_sents = ~input_unfinished_sents
    next_token_log_probs.masked_fill_(finished_sents.unsqueeze(-1), -numpy.inf)
    next_token_log_probs[..., 0].masked_fill_(finished_sents, 0)
    next_token_ids.masked_fill_(finished_sents.unsqueeze(-1), self.config.eos_token_id)

    output_log_probs = input_log_probs.unsqueeze(-1) + next_token_log_probs

    # Select N sequences from the beams of each input, sorted by sequence probability.
    output_log_probs = output_log_probs.view(self.config.batch_size, -1)  # shape=(batch, beam_size^2)
    output_log_probs, selected_index_flat = output_log_probs.topk(
        self.config.beam_size, dim=-1, largest=True, sorted=True
    )  # output shape=(batch, beam_size)

    # Select the corresponding sentences/next tokens.
    selected_input_seq = selected_index_flat // self.config.beam_size
    next_token_ids = next_token_ids.view(self.config.batch_size, -1).gather(-1, selected_index_flat)

    prev_step_results = prev_step_results.view(self.config.batch_size, -1, prev_step_results.size(-1))
    prev_step_results = prev_step_results.gather(
        1, selected_input_seq.unsqueeze(-1).repeat(1, 1, prev_step_results.size(-1))
    )

    output_unfinished_sents = input_unfinished_sents.gather(1, selected_input_seq)
    output_unfinished_sents = output_unfinished_sents & next_token_ids.ne(self.config.eos_token_id)

    # Get the next full input_ids.
    current_step_results = torch.cat([prev_step_results, next_token_ids.unsqueeze(-1)], dim=-1).contiguous()

    prev_step_scores = prev_step_scores.view(self.config.batch_size, -1, prev_step_scores.size(-1))
    prev_step_scores = prev_step_scores.gather(
        1, selected_input_seq.unsqueeze(-1).repeat(1, 1, prev_step_scores.size(-1))
    )
    current_step_scores = torch.cat([prev_step_scores, output_log_probs.unsqueeze(-1)], dim=-1).contiguous()

    return (
        next_token_ids,
        present_flat,
        selected_input_seq,
        output_log_probs,
        output_unfinished_sents,
        current_step_results.view(self.config.batch_size * self.config.beam_size, -1),
        current_step_scores.view(self.config.batch_size * self.config.beam_size, -1),
    )
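# Illustration only: the core beam bookkeeping used in the forward above.
# Candidate scores from beam_size beams x beam_size next tokens are flattened,
# the top beam_size are kept, and each winner's parent beam is recovered with
# integer division while its token is recovered with gather. Toy numbers,
# independent of any model; the function name is hypothetical and not part of
# this file's API.
def _beam_select_demo():
    import torch

    batch_size, beam_size = 1, 2
    # Per-beam top-`beam_size` next-token log-probs and ids: (batch, beam, beam).
    next_token_log_probs = torch.tensor([[[-0.1, -2.0], [-0.5, -0.7]]])
    next_token_ids = torch.tensor([[[4, 1], [2, 0]]])
    input_log_probs = torch.tensor([[-1.0, -0.3]])  # running beam scores

    # Accumulate, flatten to (batch, beam_size^2), keep the best beam_size.
    output_log_probs = input_log_probs.unsqueeze(-1) + next_token_log_probs
    output_log_probs = output_log_probs.view(batch_size, -1)
    output_log_probs, selected_index_flat = output_log_probs.topk(beam_size, dim=-1)

    selected_input_seq = selected_index_flat // beam_size  # parent beam per winner
    next_tokens = next_token_ids.view(batch_size, -1).gather(-1, selected_index_flat)
    # Candidates score (-1.1, -3.0, -0.8, -1.0); the top two are -0.8 and -1.0,
    # both children of parent beam 1, emitting tokens 2 and 0.
    print(selected_input_seq, next_tokens, output_log_probs)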
def forward(
    self,
    input_ids,
    beam_select_idx,
    input_log_probs,
    input_unfinished_sents,
    prev_step_scores,
    *past,
):
    input_ids = input_ids.view(self.config.batch_size, -1, input_ids.size(-1))
    input_num_seq_per_sample = input_ids.size(1)
    # Run the model only on sequences that are still unfinished.
    input_ids_unfinished_flat = self.collapse_first_two_dims(input_ids).index_select(
        0, input_unfinished_sents.view(-1).nonzero(as_tuple=False).view(-1)
    )

    if self.config.ignore_eos:
        attention_mask = (input_ids_unfinished_flat != self.config.eos_token_id).float()
    else:
        attention_mask = torch.ones(input_ids_unfinished_flat.shape).float().to(
            input_ids_unfinished_flat.device
        )
    position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0).long()

    if past:
        # Feed only the new tokens; the rest is covered by the past states.
        last_seq_len = past[0].size(-2)
        input_ids_unfinished_flat = input_ids_unfinished_flat[:, last_seq_len:]
        position_ids = position_ids[:, last_seq_len:]
        unfinished_index_relative_to_last_unfinished = beam_select_idx.view(-1)[
            input_unfinished_sents.view(-1).nonzero(as_tuple=False).view(-1)
        ]
        past = tuple(
            p.index_select(1, unfinished_index_relative_to_last_unfinished) for p in past
        )

    result = super().forward(
        input_ids_unfinished_flat.view(-1, input_ids_unfinished_flat.size(-1)),
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past,
        return_dict=False,
    )
    logits_flat, present_flat = MyGPT2Model.post_process(result, self.config.n_layer)

    # Insert the finished sequences back to form a square shape of
    # (batch_size, beam_size); finished rows get a dominant EOS logit.
    next_token_logits = logits_flat.new_zeros(input_ids.size()[:2] + (logits_flat.size(-1),))
    next_token_logits.index_fill_(
        2, torch.LongTensor([self.config.eos_token_id]).to(input_ids.device), -BIG_NEG
    )
    next_token_logits.masked_scatter_(
        input_unfinished_sents.unsqueeze(-1).expand_as(next_token_logits), logits_flat[:, -1]
    )

    # Repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858).
    if self.config.repetition_penalty != 1.0:
        _pen = next_token_logits.gather(2, input_ids)
        _pen = torch.where(
            _pen > 0, _pen / self.config.repetition_penalty, _pen * self.config.repetition_penalty
        )
        next_token_logits.scatter_(2, input_ids, _pen)

    # In a similar way, adjust the EOS logit to encourage short sentences:
    # if it is positive, increase it; otherwise, decrease it.
    if self.config.length_penalty != 1.0:
        _pen = next_token_logits[..., self.config.eos_token_id]
        _pen = torch.where(
            _pen > 0, _pen * self.config.length_penalty, _pen / self.config.length_penalty
        )
        next_token_logits[..., self.config.eos_token_id] = _pen

    if self.config.temperature != 1.0:
        next_token_logits = next_token_logits / self.config.temperature

    # Exclude excluded_token_ids.
    if self.config.excluded_token_ids is not None:
        next_token_logits.index_fill_(
            2, self.config.excluded_token_ids.to(next_token_logits.device), BIG_NEG
        )

    # batch x beams/sequences x vocab_size
    next_token_log_probs = torch.log_softmax(next_token_logits, dim=-1)

    if self.config.do_sample:
        vocab_size = next_token_log_probs.size(-1)
        _next_token_log_probs = self.top_k_top_p_filtering(
            next_token_log_probs.view(-1, vocab_size),
            top_k=self.config.do_sample_top_k,
            top_p=self.config.do_sample_top_p,
        )
        next_token_ids = torch.multinomial(
            _next_token_log_probs.exp(), num_samples=self.config.beam_size, replacement=False
        )
        next_token_ids = next_token_ids.view(self.config.batch_size, input_num_seq_per_sample, -1)
        next_token_log_probs = next_token_log_probs.gather(-1, next_token_ids)
    else:
        next_token_log_probs, next_token_ids = torch.topk(
            next_token_log_probs, self.config.beam_size, dim=-1, largest=True, sorted=True
        )

    output_log_probs = input_log_probs.unsqueeze(-1) + next_token_log_probs

    # Select N sequences from the beams of each input, sorted by sequence probability.
    output_log_probs = output_log_probs.view(self.config.batch_size, -1)  # shape=(batch, beam_size^2)
    output_log_probs, selected_index_flat = output_log_probs.topk(
        self.config.beam_size, dim=-1, largest=True, sorted=True
    )  # output shape=(batch, beam_size)

    # Select the corresponding sentences/next tokens.
    selected_input_seq = selected_index_flat // self.config.beam_size
    next_token_ids = next_token_ids.view(self.config.batch_size, -1).gather(-1, selected_index_flat)

    prev_step_results = input_ids.view(self.config.batch_size, -1, input_ids.size(-1)).contiguous()
    prev_step_results = prev_step_results.gather(
        1, selected_input_seq.unsqueeze(-1).expand(selected_input_seq.shape + (prev_step_results.size(-1),))
    )

    output_unfinished_sents = input_unfinished_sents.gather(1, selected_input_seq)
    output_unfinished_sents = output_unfinished_sents & next_token_ids.ne(self.config.eos_token_id)

    current_step_results = torch.cat([prev_step_results, next_token_ids.unsqueeze(-1)], dim=-1).contiguous()

    prev_step_scores = prev_step_scores.view(self.config.batch_size, -1, prev_step_scores.size(-1))
    prev_step_scores = prev_step_scores.gather(
        1, selected_input_seq.unsqueeze(-1).expand(selected_input_seq.shape + (prev_step_scores.size(-1),))
    )
    current_step_scores = torch.cat([prev_step_scores, output_log_probs.unsqueeze(-1)], dim=-1).contiguous()

    # Indices of the surviving beams relative to the last set of unfinished
    # sequences, used to reorder the past state at the next step.
    index_relative_to_last_unfinished = (
        (input_unfinished_sents.view(-1).float().cumsum(-1) - 1)
        .clamp(min=0)
        .long()
        .reshape_as(input_unfinished_sents)
        .gather(1, selected_input_seq)
    )

    return (
        current_step_results.view(self.config.batch_size * self.config.beam_size, -1),
        present_flat,
        index_relative_to_last_unfinished,
        output_log_probs,
        output_unfinished_sents,
        current_step_scores.view(self.config.batch_size * self.config.beam_size, -1),
    )
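# The sampling branch above calls `self.top_k_top_p_filtering`, which is not
# shown in this excerpt. Below is a minimal sketch of the standard top-k /
# nucleus (top-p) filtering (Holtzman et al., https://arxiv.org/abs/1904.09751)
# that such a helper conventionally implements; the name, signature, and
# filter value here are assumptions, not necessarily what this file uses.
def top_k_top_p_filtering_sketch(logits, top_k=0, top_p=1.0, filter_value=-float("inf")):
    """Filter a (batch, vocab_size) tensor of logits (or log-probs, as above)
    so only the top-k tokens and/or the smallest set of tokens whose
    cumulative probability exceeds top_p remain; the rest get filter_value."""
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Drop every token scored below the k-th largest value per row.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(indices_to_remove, filter_value)
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once cumulative probability passes top_p, shifting right
        # so the first token crossing the threshold is still kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
# With filter_value = -inf applied to log-probs, the caller's `.exp()` maps
# filtered entries to probability 0 before `torch.multinomial`, matching how
# the sampling branch above consumes the filtered distribution.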