Example #1
 def lmscore(self, bulist, K, LMModel, word_list=None, subvocab=None, clustermask=None, renorm=False, temperature=1):
     """
     Score function based on a pretrained RNN language model.
     """
     # note that LMModel should have the same vocab as that in Beam()
     
     ## when no candidate word list is provided, use the full vocabulary
     if word_list is None:
         word_list = self.vocab.itos
         subvocab = None
         clustermask = None
     
     if self.device is not None:
         LMModel = LMModel.cuda(device=self.device)
     LMModel.eval()
     with torch.no_grad():
         onbeam_ids = list(range(len(bulist)))
         batch_text = next(LMModel.parameters()).new_tensor([bulist[i].word_id for i in onbeam_ids],
                                                            dtype=torch.long).unsqueeze(0)
         if bulist[onbeam_ids[0]].lm_state is None:
             # 'lm_state' for the current beam is either all 'None' or all not 'None'.
             batch_hn = None
         else:
             batch_hn = (torch.cat([bulist[i].lm_state[0] for i in onbeam_ids], dim=1),
                        torch.cat([bulist[i].lm_state[1] for i in onbeam_ids], dim=1))  
         subprobs, probs, hn = prob_next_1step(LMModel, batch_text, hn=batch_hn,
                                               subvocab=subvocab, clustermask=clustermask, onscore=False, renorm=renorm,
                                               temperature=temperature)
     # convert the hidden state tuple into a list of tuples, corresponding to each beam sequence
     hn = list(zip(torch.chunk(hn[0], chunks=len(onbeam_ids), dim=1), torch.chunk(hn[1], chunks=len(onbeam_ids), dim=1)))
     lm_cum_logprob = subprobs.new_tensor([bulist[i].lm_score for i in onbeam_ids]).unsqueeze(1) + torch.log(subprobs)
     lm_cum_logprob = lm_cum_logprob.view(-1)      # this is the cumulative log probabilities
     
     ## rank and update
     if K > len(lm_cum_logprob):        # fewer candidates than K: sort and keep them all
         scores_sorted, ids_sorted = lm_cum_logprob.sort(descending=True)
         nexttopKK = [BeamUnit(self.vocab.stoi[word_list[i % len(word_list)]],
                              bulist[onbeam_ids[i // len(word_list)]].cur_loc,
                              m,
                              scores_sorted[m].item(),
                              bulist[onbeam_ids[i // len(word_list)]].seq_len + 1,
                              self.vocab,
                              lm_score=lm_cum_logprob[i].item(),
                              lm_state=hn[i // len(word_list)])
                      for (m, i) in enumerate(ids_sorted)]
     else:
         scores_topK, ids_topK = lm_cum_logprob.topk(K)
         nexttopKK = [BeamUnit(self.vocab.stoi[word_list[i % len(word_list)]],
                              bulist[onbeam_ids[i // len(word_list)]].cur_loc,
                              m,
                              scores_topK[m].item(),
                              bulist[onbeam_ids[i // len(word_list)]].seq_len + 1,
                              self.vocab,
                              lm_score=lm_cum_logprob[i].item(),
                              lm_state=hn[i // len(word_list)])
                     for (m, i) in enumerate(ids_topK)]
     
     endbus = []        # no finished sequences are produced by this score function
     
     return nexttopKK, endbus
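
The helper prob_next_1step is not shown in this listing. Below is a minimal, hypothetical sketch of what it plausibly computes (it is not the repository's implementation): assuming an LSTM-style LM whose forward pass returns (logits, hidden state), it restricts the next-token distribution to the candidate sub-vocabulary, optionally applies a cluster mask, and optionally renormalizes; the 'onscore' flag used above is omitted here.

import torch
import torch.nn.functional as F

def prob_next_1step_sketch(LMModel, batch_text, hn=None, subvocab=None,
                           clustermask=None, renorm=False, temperature=1):
    # Hypothetical stand-in for prob_next_1step; assumes LMModel(batch_text, hn)
    # returns (logits of shape (1, batch, vocab_size), new hidden state).
    logits, hn_new = LMModel(batch_text, hn)
    probs = F.softmax(logits.squeeze(0) / temperature, dim=-1)         # (batch, vocab_size)
    if clustermask is not None:
        probs = probs * clustermask                                    # zero out words outside the allowed cluster
    subprobs = probs[:, subvocab] if subvocab is not None else probs   # restrict to the candidate word list
    if renorm:
        subprobs = subprobs / subprobs.sum(dim=1, keepdim=True)        # renormalize over the sub-vocabulary
    return subprobs, probs, hn_new
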
Example #2
 def combscoreK_GPT2(self, bulist, K, template_vec, ge, LMModel,
               word_list=None, subvocab=None, clustermask=None,
               mono=True, normalized=True, renorm=False, temperature=1,
               bpe2word='last', alpha=0.01, stopbyLMeos=False, ifadditive=False):
     """
     Given a list of 'BeamUnit', score the next tokens from the candidate word list based on a combination of
     sentence similarities and a pretrained language model. Output the top K scored new 'BeamUnit' objects in a list.
     
     Input:
         stopbyLMeos: whether to let the LM's '<eos>' token alone decide the end of a sentence, i.e. when '<eos>' receives the highest probability from the LM, the generated sentence is removed from the beam. Default: False.
     
     Note:
     'word_list', 'subvocab', and 'clustermask' should be coupled with each other and sorted according to the full vocabulary.
     """
     
     ## when no candidate word list is provided, use the full vocabulary
     if word_list is None:
         word_list = self.vocab.itos
         subvocab = None
         clustermask = None
     
     ## calculate the similarity scores
     endbus = []                         # finished sequences 
     onbeam_ids = list(range(len(bulist)))        # keep track of sequences on beam that have not aligned to the end of the source sequence
     sim_cum_allbeam = None
     indices_allbeam = None
     states_allbeam = []
     for (i, bu) in enumerate(bulist):
         try:
             scores, indices, states = simScoreNext_GPT2(template_vec, word_list, ge,
                                                prevs_state=bu.gpt2_state,
                                                prevs_align=bu.align_loc if mono else None,
                                                normalized=normalized, bpe2word=bpe2word)
             scores_logprob = F.log_softmax(scores, dim=0)
             
             sim_cum_logprob = scores_logprob + torch.tensor(bu.sim_score, dtype=torch.float, device=self.device)
             
             sim_cum_allbeam = sim_cum_logprob if sim_cum_allbeam is None else torch.cat([sim_cum_allbeam, sim_cum_logprob])
             indices_allbeam = indices if indices_allbeam is None else torch.cat([indices_allbeam, indices])
             states_allbeam = states_allbeam + states
             
         # current sequence already aligned to the end: move out of beam
         except AssertionError as e:
             print('AssertionError:', e)
             endbus.append((i, bu.seq_len, bu.score, bu.sim_score, bu.lm_score))
             onbeam_ids.remove(i)
     
     ## calculate the RNN LM scores 
     ## note that LMModel should have the same vocab as that in Beam()
     if len(bulist) == 1 and bulist[0].word_id is None:
         # first beam step after initialization: rely only on the similarity scores; no LM calculation is needed
         scores_comb = sim_cum_allbeam
         lm_cum_logprob = torch.zeros_like(sim_cum_allbeam)
         hn = [None] * len(onbeam_ids)        # at the initial step, 'onbeam_ids' wouldn't be empty anyway
     else:
         ## all sequences have aligned to the end of the source sentence
         if onbeam_ids == []:
             return [], endbus
         ## do the RNN LM forward calculation
         if bulist[onbeam_ids[0]].lm_state is None:
             # 'lm_state' for the current beam is either all 'None' or all not 'None'.
             batch_hn = None
         else:
             batch_hn = (torch.cat([bulist[i].lm_state[0] for i in onbeam_ids], dim=1),
                        torch.cat([bulist[i].lm_state[1] for i in onbeam_ids], dim=1))    
         batch_text = next(LMModel.parameters()).new_tensor([bulist[i].word_id for i in onbeam_ids], dtype=torch.long).unsqueeze(0)
         subprobs, probs, hn = prob_next_1step(LMModel, batch_text, hn=batch_hn,
                                               subvocab=subvocab, clustermask=clustermask, onscore=False, renorm=renorm,
                                               temperature=temperature)
         
         ### LM predicts '<eos>' with the highest probability: move out of the beam
         if stopbyLMeos:
             subprobs_max, subprobs_maxids = torch.max(subprobs, dim=1)
             eospos = (subprobs_maxids == word_list.index('<eos>')).nonzero()
             if eospos.size(0) > 0:        # number of ended sentences
                 # Note: have to delete backwards! Otherwise the indices will change.
                 oob_ids = [onbeam_ids.pop(ep.item()) for ep in eospos.squeeze(1).sort(descending=True)[0]]
                 oob_ids = sorted(oob_ids)
                 print('-' * 5 + ' <eos> predicted most likely by LM at location:', *oob_ids)
                 for i in oob_ids:
                     endbus.append((i, bulist[i].seq_len, bulist[i].score, bulist[i].sim_score, bulist[i].lm_score))
                 # all sequences have been predicted with '<eos>' having highest probabilities
                 if onbeam_ids == []:
                     return [], endbus
                 else:
                     remainpos = [i for i in range(len(subprobs)) if i not in eospos]
                     subprobs = subprobs[remainpos, :]
                     probs = probs[remainpos, :]
                     hn = (hn[0][:, remainpos, :], hn[1][:, remainpos, :])
                     remainpos_simallbeam = []
                     for rp in remainpos:
                         remainpos_simallbeam += list(range(len(word_list) * rp, len(word_list) * (rp + 1)))
                     sim_cum_allbeam = sim_cum_allbeam[remainpos_simallbeam]
                     indices_allbeam = indices_allbeam[remainpos_simallbeam]
                     states_allbeam = [s for (i, s) in enumerate(states_allbeam) if i in remainpos_simallbeam]
                     
         # convert the hidden state tuple into a list of tuples, corresponding to each beam sequence
         hn = list(zip(torch.chunk(hn[0], chunks=len(onbeam_ids), dim=1), torch.chunk(hn[1], chunks=len(onbeam_ids), dim=1)))
         lm_cum_logprob = subprobs.new_tensor([bulist[i].lm_score for i in onbeam_ids]).unsqueeze(1) + torch.log(subprobs)
         lm_cum_logprob = lm_cum_logprob.view(-1)      # this is the cumulative log probabilities
         
         if ifadditive:
             scores_comb = torch.log((1 - alpha) * torch.exp(sim_cum_allbeam) + alpha * torch.exp(lm_cum_logprob))
         else:
             scores_comb = (1 - alpha) * sim_cum_allbeam + alpha * lm_cum_logprob
     
     ## rank and update
     if K > len(scores_comb):        # fewer candidates than K: sort and keep them all
         scores_comb_sorted, ids_sorted = scores_comb.sort(descending=True)
         nexttopKK = [BeamUnit(self.vocab.stoi[word_list[i % len(word_list)]],
                              bulist[onbeam_ids[i // len(word_list)]].cur_loc,
                              m,
                              scores_comb_sorted[m].item(),
                              bulist[onbeam_ids[i // len(word_list)]].seq_len + 1,
                              self.vocab,
                              sim_score=sim_cum_allbeam[i].item(),
                              lm_score=lm_cum_logprob[i].item(),
                              lm_state=hn[i // len(word_list)],
                              gpt2_state=states_allbeam[i],
                              align_loc=indices_allbeam[i])
                      for (m, i) in enumerate(ids_sorted)]
     else:
         scores_comb_topK, ids_topK = scores_comb.topk(K)
         nexttopKK = [BeamUnit(self.vocab.stoi[word_list[i % len(word_list)]],
                              bulist[onbeam_ids[i // len(word_list)]].cur_loc,
                              m,
                              scores_comb_topK[m].item(),
                              bulist[onbeam_ids[i // len(word_list)]].seq_len + 1,
                              self.vocab,
                              sim_score=sim_cum_allbeam[i].item(),
                              lm_score=lm_cum_logprob[i].item(),
                              lm_state=hn[i // len(word_list)],
                              gpt2_state=states_allbeam[i],
                              align_loc=indices_allbeam[i])
                     for (m, i) in enumerate(ids_topK)]
     
     return nexttopKK, endbus
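
For reference, the combination step near the end of combscoreK_GPT2 supports two modes: the default is a log-linear interpolation of the cumulative similarity and LM log-probabilities, while ifadditive=True mixes the two in probability space. A toy illustration with made-up numbers (not values from the repository):

import torch

alpha = 0.01
sim_cum = torch.log(torch.tensor([0.60, 0.30, 0.10]))   # hypothetical cumulative similarity log-probs for 3 candidates
lm_cum = torch.log(torch.tensor([0.10, 0.70, 0.20]))    # hypothetical cumulative LM log-probs for the same candidates

loglinear = (1 - alpha) * sim_cum + alpha * lm_cum                                    # ifadditive=False
additive = torch.log((1 - alpha) * torch.exp(sim_cum) + alpha * torch.exp(lm_cum))    # ifadditive=True

print(loglinear.topk(2))
print(additive.topk(2))

With a small alpha such as 0.01, the ranking is dominated by the similarity scores and the LM acts only as a mild fluency prior; increasing alpha shifts the ranking toward the LM.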