def lmscore(self, bulist, K, LMModel, word_list=None, subvocab=None, clustermask=None,
            renorm=False, temperature=1):
    """
    Score function based on a pretrained RNN language model.

    Note: 'LMModel' should have the same vocabulary as that in Beam().
    """
    ## when no candidate word list is provided, use the full vocabulary
    if word_list is None:
        word_list = self.vocab.itos
        subvocab = None
        clustermask = None

    if self.device is not None:
        LMModel = LMModel.cuda(device=self.device)
    LMModel.eval()

    with torch.no_grad():
        onbeam_ids = list(range(len(bulist)))
        batch_text = next(LMModel.parameters()).new_tensor(
            [bulist[i].word_id for i in onbeam_ids], dtype=torch.long).unsqueeze(0)
        if bulist[onbeam_ids[0]].lm_state is None:
            # 'lm_state' for the current beam is either all 'None' or all not 'None'
            batch_hn = None
        else:
            batch_hn = (torch.cat([bulist[i].lm_state[0] for i in onbeam_ids], dim=1),
                        torch.cat([bulist[i].lm_state[1] for i in onbeam_ids], dim=1))
        subprobs, probs, hn = prob_next_1step(LMModel, batch_text, hn=batch_hn,
                                              subvocab=subvocab, clustermask=clustermask,
                                              onscore=False, renorm=renorm, temperature=temperature)

        # convert the hidden state tuple into a list of tuples, one per beam sequence
        hn = list(zip(torch.chunk(hn[0], chunks=len(onbeam_ids), dim=1),
                      torch.chunk(hn[1], chunks=len(onbeam_ids), dim=1)))

        # cumulative log probabilities, flattened over (beam, candidate word)
        lm_cum_logprob = subprobs.new_tensor([bulist[i].lm_score for i in onbeam_ids]).unsqueeze(1) \
            + torch.log(subprobs)
        lm_cum_logprob = lm_cum_logprob.view(-1)

        ## rank and update
        if K > len(lm_cum_logprob):
            scores_sorted, ids_sorted = lm_cum_logprob.sort(descending=True)
            nexttopKK = [BeamUnit(self.vocab.stoi[word_list[i % len(word_list)]],
                                  bulist[onbeam_ids[i // len(word_list)]].cur_loc,
                                  m,
                                  scores_sorted[m].item(),
                                  bulist[onbeam_ids[i // len(word_list)]].seq_len + 1,
                                  self.vocab,
                                  lm_score=lm_cum_logprob[i].item(),
                                  lm_state=hn[i // len(word_list)])
                         for (m, i) in enumerate(ids_sorted)]
        else:
            scores_topK, ids_topK = lm_cum_logprob.topk(K)
            nexttopKK = [BeamUnit(self.vocab.stoi[word_list[i % len(word_list)]],
                                  bulist[onbeam_ids[i // len(word_list)]].cur_loc,
                                  m,
                                  scores_topK[m].item(),
                                  bulist[onbeam_ids[i // len(word_list)]].seq_len + 1,
                                  self.vocab,
                                  lm_score=lm_cum_logprob[i].item(),
                                  lm_state=hn[i // len(word_list)])
                         for (m, i) in enumerate(ids_topK)]

    endbus = []
    return nexttopKK, endbus
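
# A minimal, self-contained sketch (illustrative only, not part of the original module) of the
# flattened-index arithmetic used in the ranking step above: cumulative scores for `n_beam`
# sequences over `n_word` candidate words are flattened to a 1-D tensor, and each flat index `i`
# is then decomposed back into its source beam (`i // n_word`) and candidate word (`i % n_word`).
def _flat_topk_demo():
    import torch
    n_beam, n_word, K = 3, 5, 4
    prev = torch.randn(n_beam, 1)                                 # cumulative score per beam
    step = torch.log_softmax(torch.randn(n_beam, n_word), dim=1)  # next-word log-probs
    flat = (prev + step).view(-1)                                 # same layout as 'lm_cum_logprob'
    scores, ids = flat.topk(K)
    for m, i in enumerate(ids.tolist()):
        beam_idx, word_idx = i // n_word, i % n_word              # recover the (beam, word) pair
        assert flat[i] == (prev + step)[beam_idx, word_idx]
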
def combscoreK_GPT2(self, bulist, K, template_vec, ge, LMModel, word_list=None, subvocab=None,
                    clustermask=None, mono=True, normalized=True, renorm=False, temperature=1,
                    bpe2word='last', alpha=0.01, stopbyLMeos=False, ifadditive=False):
    """
    Given a list of 'BeamUnit', score the next tokens from the candidate word list based on a
    combination of sentence similarities and a pretrained language model. Output the top K
    scored new 'BeamUnit', in a list.

    Input:
        stopbyLMeos: whether to let the LM's '<eos>' alone decide the end of a sentence, i.e.
            when '<eos>' gets the highest probability from the LM, the generated sentence is
            removed from the beam. Default: False.

    Note: 'word_list', 'subvocab', and 'clustermask' should be coupled, sorted based on the
    full vocabulary.
    """
    ## when no candidate word list is provided, use the full vocabulary
    if word_list is None:
        word_list = self.vocab.itos
        subvocab = None
        clustermask = None

    ## calculate the similarity scores
    endbus = []  # finished sequences
    onbeam_ids = list(range(len(bulist)))  # sequences on beam not yet aligned to the end of the source sequence
    sim_cum_allbeam = None
    indices_allbeam = None
    states_allbeam = []
    for (i, bu) in enumerate(bulist):
        try:
            scores, indices, states = simScoreNext_GPT2(template_vec, word_list, ge,
                                                        prevs_state=bu.gpt2_state,
                                                        prevs_align=bu.align_loc if mono else None,
                                                        normalized=normalized, bpe2word=bpe2word)
            scores_logprob = F.log_softmax(scores, dim=0)
            sim_cum_logprob = scores_logprob + torch.tensor(bu.sim_score, dtype=torch.float, device=self.device)
            sim_cum_allbeam = sim_cum_logprob if sim_cum_allbeam is None else torch.cat([sim_cum_allbeam, sim_cum_logprob])
            indices_allbeam = indices if indices_allbeam is None else torch.cat([indices_allbeam, indices])
            states_allbeam = states_allbeam + states
        except AssertionError as e:
            # current sequence already aligned to the end: move it out of the beam
            print('AssertionError:', e)
            endbus.append((i, bu.seq_len, bu.score, bu.sim_score, bu.lm_score))
            onbeam_ids.remove(i)

    ## calculate the RNN LM scores
    ## note that 'LMModel' should have the same vocabulary as that in Beam()
    if len(bulist) == 1 and bulist[0].word_id is None:
        # first beam step after initialization: rely on similarity scores only; no LM calculation is needed
        scores_comb = sim_cum_allbeam
        lm_cum_logprob = torch.zeros_like(sim_cum_allbeam)
        hn = [None] * len(onbeam_ids)  # 'onbeam_ids' cannot be empty at the initial step
    else:
        ## all sequences have aligned to the end of the source sentence
        if onbeam_ids == []:
            return [], endbus
        ## do the RNN LM forward calculation
        if bulist[onbeam_ids[0]].lm_state is None:
            # 'lm_state' for the current beam is either all 'None' or all not 'None'
            batch_hn = None
        else:
            batch_hn = (torch.cat([bulist[i].lm_state[0] for i in onbeam_ids], dim=1),
                        torch.cat([bulist[i].lm_state[1] for i in onbeam_ids], dim=1))
        batch_text = next(LMModel.parameters()).new_tensor(
            [bulist[i].word_id for i in onbeam_ids], dtype=torch.long).unsqueeze(0)
        subprobs, probs, hn = prob_next_1step(LMModel, batch_text, hn=batch_hn,
                                              subvocab=subvocab, clustermask=clustermask,
                                              onscore=False, renorm=renorm, temperature=temperature)

        ### LM predicts '<eos>' with the highest probability: move those sequences out of the beam
        if stopbyLMeos:
            subprobs_max, subprobs_maxids = torch.max(subprobs, dim=1)
            eospos = (subprobs_maxids == word_list.index('<eos>')).nonzero()
            if eospos.size(0) > 0:  # number of ended sentences
                # Note: have to delete backwards! Otherwise the indices will change.
                oob_ids = [onbeam_ids.pop(ep.item()) for ep in eospos.squeeze(1).sort(descending=True)[0]]
                oob_ids = sorted(oob_ids)
                print('-' * 5 + ' <eos> predicted most likely by LM at location:', *oob_ids)
                for i in oob_ids:
                    endbus.append((i, bulist[i].seq_len, bulist[i].score, bulist[i].sim_score, bulist[i].lm_score))
                # all sequences have been predicted with '<eos>' having the highest probability
                if onbeam_ids == []:
                    return [], endbus
                # drop the ended sequences from the LM outputs and the similarity scores
                remainpos = [i for i in range(len(subprobs)) if i not in eospos]
                subprobs = subprobs[remainpos, :]
                probs = probs[remainpos, :]
                hn = (hn[0][:, remainpos, :], hn[1][:, remainpos, :])
                remainpos_simallbeam = []
                for rp in remainpos:
                    remainpos_simallbeam += list(range(len(word_list) * rp, len(word_list) * (rp + 1)))
                sim_cum_allbeam = sim_cum_allbeam[remainpos_simallbeam]
                indices_allbeam = indices_allbeam[remainpos_simallbeam]
                states_allbeam = [s for (i, s) in enumerate(states_allbeam) if i in remainpos_simallbeam]

        # convert the hidden state tuple into a list of tuples, one per beam sequence
        hn = list(zip(torch.chunk(hn[0], chunks=len(onbeam_ids), dim=1),
                      torch.chunk(hn[1], chunks=len(onbeam_ids), dim=1)))

        # cumulative log probabilities, flattened over (beam, candidate word)
        lm_cum_logprob = subprobs.new_tensor([bulist[i].lm_score for i in onbeam_ids]).unsqueeze(1) \
            + torch.log(subprobs)
        lm_cum_logprob = lm_cum_logprob.view(-1)

        ## combine the similarity and LM scores
        if ifadditive:
            scores_comb = torch.log((1 - alpha) * torch.exp(sim_cum_allbeam) + alpha * torch.exp(lm_cum_logprob))
        else:
            scores_comb = (1 - alpha) * sim_cum_allbeam + alpha * lm_cum_logprob

    ## rank and update
    if K > len(scores_comb):
        scores_comb_sorted, ids_sorted = scores_comb.sort(descending=True)
        nexttopKK = [BeamUnit(self.vocab.stoi[word_list[i % len(word_list)]],
                              bulist[onbeam_ids[i // len(word_list)]].cur_loc,
                              m,
                              scores_comb_sorted[m].item(),
                              bulist[onbeam_ids[i // len(word_list)]].seq_len + 1,
                              self.vocab,
                              sim_score=sim_cum_allbeam[i].item(),
                              lm_score=lm_cum_logprob[i].item(),
                              lm_state=hn[i // len(word_list)],
                              gpt2_state=states_allbeam[i],
                              align_loc=indices_allbeam[i])
                     for (m, i) in enumerate(ids_sorted)]
    else:
        scores_comb_topK, ids_topK = scores_comb.topk(K)
        nexttopKK = [BeamUnit(self.vocab.stoi[word_list[i % len(word_list)]],
                              bulist[onbeam_ids[i // len(word_list)]].cur_loc,
                              m,
                              scores_comb_topK[m].item(),
                              bulist[onbeam_ids[i // len(word_list)]].seq_len + 1,
                              self.vocab,
                              sim_score=sim_cum_allbeam[i].item(),
                              lm_score=lm_cum_logprob[i].item(),
                              lm_state=hn[i // len(word_list)],
                              gpt2_state=states_allbeam[i],
                              align_loc=indices_allbeam[i])
                     for (m, i) in enumerate(ids_topK)]

    return nexttopKK, endbus
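
# A toy illustration (assumed values; not from the original code) of the two score-combination
# modes selected by 'ifadditive' in combscoreK_GPT2: a weighted sum of log-scores (log-linear
# interpolation) versus the log of a weighted sum of probabilities (additive mixture). The
# mixture can also be computed in a numerically stable way with `torch.logsumexp`.
def _score_combination_demo(alpha=0.01):
    import math
    import torch
    sim = torch.tensor([-1.2, -3.5])  # toy cumulative similarity log-scores
    lm = torch.tensor([-2.0, -0.7])   # toy cumulative LM log-probs
    loglinear = (1 - alpha) * sim + alpha * lm                                  # ifadditive=False
    additive = torch.log((1 - alpha) * torch.exp(sim) + alpha * torch.exp(lm))  # ifadditive=True
    # numerically stable equivalent of the additive mixture
    stable = torch.logsumexp(torch.stack([sim + math.log(1 - alpha),
                                          lm + math.log(alpha)]), dim=0)
    assert torch.allclose(additive, stable)
    return loglinear, additive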