def _fast_translate_batch(self, batch, max_length, min_length=0):
    # TODO: faster code path for beam_size == 1.
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    beam_size = self.beam_size
    batch_size = batch.batch_size
    src = batch.src
    segs = batch.segs
    mask_src = batch.mask_src

    src_features = self.model.bert(src, segs, mask_src)
    dec_states = self.model.decoder.init_decoder_state(src, src_features,
                                                       with_cache=True)
    device = src_features.device

    # Tile states and memory beam_size times.
    dec_states.map_batch_fn(
        lambda state, dim: tile(state, beam_size, dim=dim))
    src_features = tile(src_features, beam_size, dim=0)
    batch_offset = torch.arange(batch_size,
                                dtype=torch.long,
                                device=device)
    beam_offset = torch.arange(0,
                               batch_size * beam_size,
                               step=beam_size,
                               dtype=torch.long,
                               device=device)
    alive_seq = torch.full([batch_size * beam_size, 1],
                           self.start_token,
                           dtype=torch.long,
                           device=device)

    # Give full probability to the first beam on the first step.
    topk_log_probs = (torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                                   device=device).repeat(batch_size))

    # Structure that holds finished hypotheses.
    hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

    results = {}
    results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
    results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
    results["gold_score"] = [0] * batch_size
    results["batch"] = batch

    for step in range(max_length):
        decoder_input = alive_seq[:, -1].view(1, -1)

        # Decoder forward.
        decoder_input = decoder_input.transpose(0, 1)

        dec_out, dec_states = self.model.decoder(decoder_input, src_features,
                                                 dec_states, step=step)

        # Generator forward.
        log_probs = self.generator.forward(
            dec_out.transpose(0, 1).squeeze(0))
        vocab_size = log_probs.size(-1)

        if step < min_length:
            log_probs[:, self.end_token] = -1e20

        # Multiply probs by the beam probability.
        log_probs += topk_log_probs.view(-1).unsqueeze(1)

        alpha = self.global_scorer.alpha
        length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

        # Flatten probs into a list of possibilities.
        curr_scores = log_probs / length_penalty

        if self.args.block_trigram:
            cur_len = alive_seq.size(1)
            if cur_len > 3:
                for i in range(alive_seq.size(0)):
                    fail = False
                    words = [int(w) for w in alive_seq[i]]
                    words = [self.vocab.ids_to_tokens[w] for w in words]
                    words = ' '.join(words).replace(' ##', '').split()
                    if len(words) <= 3:
                        continue
                    trigrams = [(words[i - 1], words[i], words[i + 1])
                                for i in range(1, len(words) - 1)]
                    trigram = tuple(trigrams[-1])
                    if trigram in trigrams[:-1]:
                        fail = True
                    if fail:
                        curr_scores[i] = -10e20

        curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
        topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

        # Recover log probs.
        topk_log_probs = topk_scores * length_penalty

        # Resolve beam origin and true word ids.
        topk_beam_index = topk_ids.div(vocab_size)
        topk_ids = topk_ids.fmod(vocab_size)

        # Map beam_index to batch_index in the flat representation.
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)

        # Append last prediction.
        alive_seq = torch.cat([
            alive_seq.index_select(0, select_indices),
            topk_ids.view(-1, 1)
        ], -1)

        is_finished = topk_ids.eq(self.end_token)
        if step + 1 == max_length:
            is_finished.fill_(1)
        # End condition is top beam is finished.
        end_condition = is_finished[:, 0].eq(1)
        # Save finished hypotheses.
        if is_finished.any():
            predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = batch_offset[i]
                if end_condition[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero().view(-1)
                # Store finished hypotheses for this batch.
                for j in finished_hyp:
                    hypotheses[b].append(
                        (topk_scores[i, j], predictions[i, j, 1:]))
                # If the batch reached the end, save the n_best hypotheses.
                if end_condition[i]:
                    best_hyp = sorted(hypotheses[b],
                                      key=lambda x: x[0],
                                      reverse=True)
                    score, pred = best_hyp[0]
                    results["scores"][b].append(score)
                    results["predictions"][b].append(pred)
            non_finished = end_condition.eq(0).nonzero().view(-1)
            # If all sentences are translated, no need to go further.
            if len(non_finished) == 0:
                break
            # Remove finished batches for the next step.
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))
        # Reorder states.
        select_indices = batch_index.view(-1)
        src_features = src_features.index_select(0, select_indices)
        dec_states.map_batch_fn(
            lambda state, dim: state.index_select(dim, select_indices))

    return results
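# All of these beam searches rely on a `tile` helper imported elsewhere in the
# project. Its behavior can be inferred from how it is used together with
# `beam_offset` (step=beam_size): each slice along `dim` is repeated
# `beam_size` times, with the copies of one example kept adjacent. A minimal
# sketch of that assumed behavior is given below; `_tile_sketch` is a
# hypothetical name, not the project's actual function.
def _tile_sketch(x, count, dim=0):
    """Repeat each slice of `x` along `dim` `count` times (assumed semantics).

    E.g. for dim=0, a (batch, ...) tensor becomes (batch * count, ...) laid
    out as [b0, b0, ..., b1, b1, ...].
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1) \
         .transpose(0, 1) \
         .repeat(count, 1) \
         .transpose(0, 1) \
         .contiguous() \
         .view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x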


def _fast_translate_batch(self,
                          batch,
                          memory_bank,
                          max_length,
                          memory_mask=None,
                          min_length=2,
                          beam_size=3,
                          hidden_state=None,
                          copy_attn=False):
    batch_size = memory_bank.size(0)

    if self.args.decoder == "rnn":
        dec_states = self.decoder.init_decoder_state(
            batch.src, memory_bank, hidden_state)
    else:
        dec_states = self.decoder.init_decoder_state(batch.src,
                                                     memory_bank,
                                                     with_cache=True)

    # Tile states and memory beam_size times.
    dec_states.map_batch_fn(
        lambda state, dim: tile(state, beam_size, dim=dim))
    memory_bank = tile(memory_bank, beam_size, dim=0)
    memory_mask = tile(memory_mask, beam_size, dim=0)
    if copy_attn:
        src_map = tile(batch.src_map, beam_size, dim=0)
    else:
        src_map = None

    batch_offset = torch.arange(batch_size,
                                dtype=torch.long,
                                device=self.device)
    beam_offset = torch.arange(0,
                               batch_size * beam_size,
                               step=beam_size,
                               dtype=torch.long,
                               device=self.device)
    alive_seq = torch.full([batch_size * beam_size, 1],
                           self.start_token,
                           dtype=torch.long,
                           device=self.device)

    # Give full probability to the first beam on the first step.
    topk_log_probs = (torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                                   device=self.device).repeat(batch_size))

    # Structure that holds finished hypotheses.
    hypotheses = [[] for _ in range(batch_size)]  # noqa: F812
    results = [[] for _ in range(batch_size)]  # noqa: F812

    for step in range(max_length):
        # Decoder forward.
        decoder_input = alive_seq[:, -1].view(1, -1)
        decoder_input = decoder_input.transpose(0, 1)

        if self.args.decoder == "rnn":
            dec_out, dec_states, attn = self.decoder(
                decoder_input,
                memory_bank,
                dec_states,
                step=step,
                memory_masks=memory_mask)
        else:
            dec_out, dec_states, attn = self.decoder(
                decoder_input,
                memory_bank,
                dec_states,
                step=step,
                memory_masks=memory_mask,
                requires_att=copy_attn)

        # Generator forward.
        if copy_attn:
            probs = self.generator(
                dec_out.transpose(0, 1).squeeze(0),
                attn['copy'].transpose(0, 1).squeeze(0), src_map)
            probs = collapse_copy_scores(
                probs.unsqueeze(1), batch, self.vocab,
                tile(batch_offset, beam_size, dim=0))
            log_probs = probs.squeeze(1)[:, :self.vocab_size].log()
        else:
            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))

        vocab_size = log_probs.size(-1)

        if step < min_length:
            log_probs[:, self.end_token] = -1e20

        if self.args.block_trigram:
            cur_len = alive_seq.size(1)
            if cur_len > 3:
                for i in range(alive_seq.size(0)):
                    fail = False
                    words = [int(w) for w in alive_seq[i]]
                    if len(words) <= 3:
                        continue
                    trigrams = [(words[i - 1], words[i], words[i + 1])
                                for i in range(1, len(words) - 1)]
                    trigram = tuple(trigrams[-1])
                    if trigram in trigrams[:-1]:
                        fail = True
                    if fail:
                        log_probs[i] = -1e20

        # Multiply probs by the beam probability.
        log_probs += topk_log_probs.view(-1).unsqueeze(1)

        alpha = self.args.alpha
        length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

        # Flatten probs into a list of possibilities.
        curr_scores = log_probs / length_penalty
        curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
        topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

        # Recover log probs.
        topk_log_probs = topk_scores * length_penalty

        # Resolve beam origin and true word ids.
        topk_beam_index = topk_ids.div(vocab_size)
        topk_ids = topk_ids.fmod(vocab_size)

        # Map beam_index to batch_index in the flat representation.
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)

        # Append last prediction.
        alive_seq = torch.cat([
            alive_seq.index_select(0, select_indices),
            topk_ids.view(-1, 1)
        ], -1)

        is_finished = topk_ids.eq(self.end_token)
        if step + 1 == max_length:
            is_finished.fill_(1)
        # End condition is top beam is finished.
        end_condition = is_finished[:, 0].eq(1)
        # Save finished hypotheses.
        if is_finished.any():
            predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = batch_offset[i]
                if end_condition[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero().view(-1)
                # Store finished hypotheses for this batch.
                for j in finished_hyp:
                    hypotheses[b].append(
                        (topk_scores[i, j], predictions[i, j, 1:]))
                # If the batch reached the end, save the n_best hypotheses.
                if end_condition[i]:
                    best_hyp = sorted(hypotheses[b],
                                      key=lambda x: x[0],
                                      reverse=True)
                    _, pred = best_hyp[0]
                    results[b].append(pred)
            non_finished = end_condition.eq(0).nonzero().view(-1)
            # If all sentences are translated, no need to go further.
            if len(non_finished) == 0:
                break
            # Remove finished batches for the next step.
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))
        # Reorder states.
        select_indices = batch_index.view(-1)
        if memory_bank is not None:
            memory_bank = memory_bank.index_select(0, select_indices)
        if memory_mask is not None:
            memory_mask = memory_mask.index_select(0, select_indices)
        if src_map is not None:
            src_map = src_map.index_select(0, select_indices)
        dec_states.map_batch_fn(
            lambda state, dim: state.index_select(dim, select_indices))

    results = [t[0] for t in results]
    return results
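
# Note on the length penalty shared by all of these beam searches: it is the
# GNMT-style penalty lp(step) = ((5.0 + (step + 1)) / 6.0) ** alpha. Scores
# are divided by lp before the top-k selection, so longer hypotheses are not
# unfairly penalized when ranking, and multiplied by lp again afterwards,
# which restores topk_log_probs to an unnormalized sum of log probabilities
# for the next step. A small worked example with an assumed alpha = 0.9:
#   step = 0 (first generated token)  -> lp = (6 / 6) ** 0.9  = 1.0
#   step = 9 (tenth generated token)  -> lp = (15 / 6) ** 0.9 ~= 2.28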


def _fast_translate_batch(self, batch, max_length, min_length=0):
    # TODO: faster code path for beam_size == 1.
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    beam_size = self.beam_size
    batch_size = batch.batch_size
    src = batch.src
    segs = batch.segs
    mask_src = batch.mask_src
    # print("mask_tgt = ", batch.mask_tgt.size())
    # print(batch.mask_tgt)
    # print("tgt_segs = ", batch.tgt_segs.size())
    # print(batch.tgt_segs)
    # print("tgt = ", batch.tgt.size())
    # print(batch.tgt)
    # exit()

    if self.args.bart:
        src_features = self.model.bert.model.encoder(
            input_ids=src, attention_mask=mask_src)[0]
        # print("output = ", self.model)
        # past_key_values =
        # print("src_features = ", src_features.size())
        # print(src_features)
    else:
        src_features = self.model.bert(src, segs, mask_src)

    dec_states = self.model.decoder.init_decoder_state(src, src_features,
                                                       with_cache=True)
    device = src_features.device

    # Tile states and memory beam_size times.
    dec_states.map_batch_fn(
        lambda state, dim: tile(state, beam_size, dim=dim))
    src_features = tile(src_features, beam_size, dim=0)
    if self.args.bart:
        # print("src = ", src.size())
        # print(src)
        # print("mask_src0 = ", mask_src.size())
        # print(mask_src)
        mask_src = tile(mask_src, beam_size, dim=0).byte()
        bart_dec_cache = None

    batch_offset = torch.arange(batch_size,
                                dtype=torch.long,
                                device=device)
    beam_offset = torch.arange(0,
                               batch_size * beam_size,
                               step=beam_size,
                               dtype=torch.long,
                               device=device)
    alive_seq = torch.full([batch_size * beam_size, 1],
                           self.start_token,
                           dtype=torch.long,
                           device=device)

    if self.args.language_limit:
        language_limit = torch.Tensor(json.load(open(
            self.args.tgt_mask))).long().cuda()

    # Language segment ids, one per beam hypothesis.
    language_segs = torch.full([batch_size * beam_size, 1],
                               0,
                               dtype=torch.long,
                               device=device)

    # Give full probability to the first beam on the first step.
    topk_log_probs = (torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                                   device=device).repeat(batch_size))

    # Structure that holds finished hypotheses.
    hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

    results = {}
    results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
    results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
    results["gold_score"] = [0] * batch_size
    results["batch"] = batch

    for step in range(max_length):
        # print("alive_seq = ", alive_seq.size())
        # print(alive_seq)
        # print("language_segs = ", language_segs.size())
        # print(language_segs)
        decoder_input = alive_seq[:, -1].view(1, -1)
        # Keep language_segs aligned with alive_seq: whenever alive_seq is
        # reordered, select language_segs the same way. After selecting,
        # check whether the token generated at this step is 5 (the language
        # separator), build a 0/1 mask from it, and flip the segment value at
        # the masked positions (zero them first, then add the flipped values
        # multiplied by the mask). Finally concatenate the new column.
        decoder_seg_input = language_segs[:, -1].view(1, -1)

        # Decoder forward.
        decoder_input = decoder_input.transpose(0, 1)
        decoder_seg_input = decoder_seg_input.transpose(0, 1)

        if self.args.bart:
            tgt_mask = torch.zeros(decoder_input.size()).byte().cuda()
            # causal_mask = (1 - _get_attn_subsequent_mask(tgt_mask.size(1)).float().cuda())
            # * float("-inf")).cuda()
            causal_mask = torch.triu(
                torch.zeros(tgt_mask.size(1), tgt_mask.size(1)).float()
                .fill_(float("-inf")).float(), 1).cuda()
            # print(src_features.size())
            # print(mask_src.size())
            # print(mask_src)
            # model_inputs = self.bert.model.prepare_inputs_for_generation(
            #     decoder_input, past=src_features, attention_mask=tgt_mask, use_cache=True
            # )
            dec_output = self.model.bert.model.decoder(
                input_ids=alive_seq,
                encoder_hidden_states=src_features,
                encoder_padding_mask=mask_src,
                decoder_padding_mask=tgt_mask,
                decoder_causal_mask=causal_mask,
                decoder_cached_states=bart_dec_cache,
                use_cache=True)
            dec_out = dec_output[0]
            bart_dec_cache = dec_output[1][1]
            # print(bart_dec_cache)
            # print('dec_out = ')
            # print(dec_out[0])
            # exit()
        elif self.args.predict_first_language and self.args.multi_task:
            dec_out, dec_states = self.model.decoder_monolingual(
                decoder_input,
                src_features,
                dec_states,
                step=step,
                tgt_segs=decoder_seg_input)
        else:
            dec_out, dec_states = self.model.decoder(
                decoder_input,
                src_features,
                dec_states,
                step=step,
                tgt_segs=decoder_seg_input)

        # Generator forward.
        log_probs = self.generator.forward(
            dec_out.transpose(0, 1).squeeze(0))
        vocab_size = log_probs.size(-1)

        if self.args.language_limit:
            mask_language_limit = torch.zeros(log_probs.size()).cuda()
            mask_language_limit.index_fill_(1, language_limit, 1)
            # If the two languages are concatenated, only constrain the
            # vocabulary once the second language starts; for direct
            # cross-lingual generation, constrain from the very first step.
            if self.args.predict_2language:
                mask_language_limit = mask_language_limit.long() \
                    * decoder_seg_input
                mask_language_limit = mask_language_limit + (
                    1 - decoder_seg_input)
            else:
                mask_language_limit = mask_language_limit.long() \
                    * torch.ones(decoder_seg_input.size()).long().cuda()
            # Set every position outside the allowed vocabulary to -inf.
            log_probs.masked_fill_((1 - mask_language_limit).byte(), -1e20)

        if step < min_length:
            log_probs[:, self.end_token] = -1e20

        # Multiply probs by the beam probability.
        log_probs += topk_log_probs.view(-1).unsqueeze(1)

        alpha = self.global_scorer.alpha
        length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

        # Flatten probs into a list of possibilities.
        curr_scores = log_probs / length_penalty

        if self.args.block_trigram:
            cur_len = alive_seq.size(1)
            if cur_len > 3:
                for i in range(alive_seq.size(0)):
                    fail = False
                    words = [int(w) for w in alive_seq[i]]
                    if self.args.bart:
                        words = [self.vocab.decoder[w] for w in words]
                    else:
                        words = [
                            self.vocab.ids_to_tokens[w] for w in words
                        ]
                    words = ' '.join(words).replace(' ##', '').split()
                    if len(words) <= 3:
                        continue
                    trigrams = [(words[i - 1], words[i], words[i + 1])
                                for i in range(1, len(words) - 1)]
                    trigram = tuple(trigrams[-1])
                    if trigram in trigrams[:-1]:
                        fail = True
                    if fail:
                        curr_scores[i] = -10e20

        curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
        topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

        # Recover log probs.
        topk_log_probs = topk_scores * length_penalty

        # Resolve beam origin and true word ids.
        topk_beam_index = topk_ids.div(vocab_size)
        topk_ids = topk_ids.fmod(vocab_size)

        # Map beam_index to batch_index in the flat representation.
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)

        # Append last prediction.
        alive_seq = torch.cat([
            alive_seq.index_select(0, select_indices),
            topk_ids.view(-1, 1)
        ], -1)
        # print("topk_ids = ", topk_ids)
        # print("end_token = ", self.end_token)

        if self.args.bart:
            is_finished = topk_ids.eq(2)
        else:
            is_finished = topk_ids.eq(self.end_token)

        # Track whether we have entered the second language: take the last
        # segment column, set the positions where a separator token 5 was
        # just generated to 1, then concatenate the new column.
        is_languaged = topk_ids.eq(5)
        # is_languaged = alive_seq[:, -2].unsqueeze(0).eq(5)
        language_segs = language_segs.index_select(0, select_indices)
        last_segs = language_segs[:, -1]
        tmp_seg = last_segs.masked_fill(is_languaged, 1)
        language_segs = torch.cat([language_segs, tmp_seg.view(-1, 1)], -1)
        # print("is_finished = ", is_finished)

        if step + 1 == max_length:
            is_finished.fill_(1)
        # End condition is top beam is finished.
        end_condition = is_finished[:, 0].eq(1)
        # Save finished hypotheses.
        if is_finished.any():
            predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = batch_offset[i]
                if end_condition[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero().view(-1)
                # Store finished hypotheses for this batch.
                for j in finished_hyp:
                    hypotheses[b].append(
                        (topk_scores[i, j], predictions[i, j, 1:]))
                # If the batch reached the end, save the n_best hypotheses.
                if end_condition[i]:
                    best_hyp = sorted(hypotheses[b],
                                      key=lambda x: x[0],
                                      reverse=True)
                    score, pred = best_hyp[0]
                    results["scores"][b].append(score)
                    results["predictions"][b].append(pred)
            non_finished = end_condition.eq(0).nonzero().view(-1)
            # If all sentences are translated, no need to go further.
            if len(non_finished) == 0:
                break
            # Remove finished batches for the next step.
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))
        # Reorder states.
        select_indices = batch_index.view(-1)
        src_features = src_features.index_select(0, select_indices)
        dec_states.map_batch_fn(
            lambda state, dim: state.index_select(dim, select_indices))

    return results


def _fast_translate_batch(self, batch, max_length, min_length=0):
    # TODO: faster code path for beam_size == 1.
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    beam_size = self.beam_size
    batch_size = batch.batch_size
    src = batch.src
    segs = batch.segs
    mask_src = batch.mask_src
    clss = batch.clss
    mask_cls = batch.mask_cls
    # print("src ", src.size())
    # print(src)
    # print("mask src", mask_src.size())
    # print(mask_src)

    if self.args.oracle:
        labels = batch.src_sent_labels
        ext_scores = ((labels.float() + 0.1) / 1.3) * mask_cls.float()
    else:
        ext_scores, _, sent_vec = self.model.extractor(
            src, segs, clss, mask_src, mask_cls)
    # print("ext_scores : ", ext_scores.size())
    # print(ext_scores)

    src_features = self.model.abstractor.bert(src, segs, mask_src)
    dec_states = self.model.abstractor.decoder.init_decoder_state(
        src, src_features, with_cache=True)
    device = src_features.device

    # Tile states and memory beam_size times.
    dec_states.map_batch_fn(
        lambda state, dim: tile(state, beam_size, dim=dim))
    src_features = tile(src_features, beam_size, dim=0)
    ext_scores = tile(ext_scores, beam_size, dim=0)
    mask_src = tile(mask_src, beam_size, dim=0)
    mask_cls = tile(mask_cls, beam_size, dim=0)
    clss = tile(clss, beam_size, dim=0)
    src = tile(src, beam_size, dim=0)

    batch_offset = torch.arange(batch_size,
                                dtype=torch.long,
                                device=device)
    beam_offset = torch.arange(0,
                               batch_size * beam_size,
                               step=beam_size,
                               dtype=torch.long,
                               device=device)
    alive_seq = torch.full([batch_size * beam_size, 1],
                           self.start_token,
                           dtype=torch.long,
                           device=device)

    # Give full probability to the first beam on the first step.
    topk_log_probs = (torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                                   device=device).repeat(batch_size))

    # Structure that holds finished hypotheses.
    hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

    results = {}
    results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
    results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
    results["gold_score"] = [0] * batch_size
    results["batch"] = batch

    for step in range(max_length):
        decoder_input = alive_seq[:, -1].view(1, -1)
        # print("alive seq ", alive_seq.size())
        # print(alive_seq)
        # print("decoder_input ", decoder_input.size())
        # print(decoder_input)

        # Decoder forward.
        decoder_input = decoder_input.transpose(0, 1)
        encoder_state = src_features

        decoder_outputs, dec_states, y_emb = self.model.abstractor.decoder(
            decoder_input, src_features, dec_states, step=step,
            need_y_emb=True)
        # print("encoder_state ", encoder_state.size())
        # print(encoder_state)
        # print("decoder_outputs", decoder_outputs.size())
        # print(decoder_outputs)
        # print("y_emb", y_emb.size())
        # print(y_emb)

        src_pad_mask = (1 - mask_src).unsqueeze(1)
        # print("src_pad_mask ", src_pad_mask.size())
        # print(src_pad_mask)

        context_vector, attn_dist = self.model.context_attn(
            src_features,
            src_features,
            decoder_outputs,
            mask=src_pad_mask,
            # layer_cache=layer_cache,
            type="context")
        # print("context vector ", context_vector.size())
        # print(context_vector)
        # print("attn_dist 0", attn_dist.size())
        # print(attn_dist)

        g = torch.sigmoid(
            F.linear(
                torch.cat([decoder_outputs, y_emb, context_vector], -1),
                self.model.v, self.model.bv))

        xids = src.unsqueeze(0).transpose(0, 1)
        # xids = xids * mask_tgt.unsqueeze(2)[:, :-1, :].long()

        len0 = src.size(1)
        len0 = torch.Tensor([[len0]]).repeat(src.size(0), 1).long().to('cuda')
        clss_up = torch.cat((clss, len0), dim=1)
        sent_len = (clss_up[:, 1:] - clss) * mask_cls.long()
        for i in range(mask_cls.size(0)):
            for j in range(mask_cls.size(1)):
                if sent_len[i][j] < 0:
                    sent_len[i][j] += src.size(1)
        # print("sent_len ", sent_len.size())
        # print(sent_len)

        ext_scores_0 = ext_scores.unsqueeze(1).transpose(1, 2).repeat(
            1, 1, mask_src.size(1))
        for i in range(mask_cls.size(0)):
            tmp_vec = ext_scores_0[i, 0, :sent_len[i][0].int()]
            for j in range(1, mask_cls.size(1)):
                tmp_vec = torch.cat(
                    (tmp_vec, ext_scores_0[i, j, :sent_len[i][j].int()]),
                    dim=0)
            if i == 0:
                ext_scores_new = tmp_vec.unsqueeze(0)
            else:
                ext_scores_new = torch.cat(
                    (ext_scores_new, tmp_vec.unsqueeze(0)), dim=0)
        # print("ext_score_new", ext_scores_new.size())
        # print(ext_scores_new)

        ext_scores_new = ext_scores_new * mask_src.float()
        attn_dist = attn_dist * (ext_scores_new + 1).unsqueeze(1)
        # print("attn_dist 1", attn_dist.size())
        # print(attn_dist)

        ext_dist = Variable(
            torch.zeros(
                src.size(0), 1,
                self.model.abstractor.bert.model.config.vocab_size).to(
                    'cuda'))
        # ext_vocab_prob = ext_dist.scatter_add(2, xids, (1 - g) * mask_tgt.unsqueeze(2)[:,:-1,:].float() * attn_pad_mask) * mask_tgt.unsqueeze(2)[:,:-1,:].float()
        ext_vocab_prob = ext_dist.scatter_add(2, xids, (1 - g) * attn_dist)

        # Generator forward.
        softmax_probs = self.model.abstractor.generator.forward(
            decoder_outputs.transpose(0, 1).squeeze(0)) * g.transpose(
                0, 1).squeeze(0) + ext_vocab_prob.transpose(0, 1).squeeze(0)
        log_probs = torch.log(softmax_probs)
        vocab_size = log_probs.size(-1)

        if step < min_length:
            log_probs[:, self.end_token] = -1e20

        # Multiply probs by the beam probability.
        log_probs += topk_log_probs.view(-1).unsqueeze(1)

        alpha = self.global_scorer.alpha
        length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

        # Flatten probs into a list of possibilities.
        curr_scores = log_probs / length_penalty

        if self.args.block_trigram:
            cur_len = alive_seq.size(1)
            if cur_len > 3:
                for i in range(alive_seq.size(0)):
                    fail = False
                    words = [int(w) for w in alive_seq[i]]
                    words = [self.vocab.ids_to_tokens[w] for w in words]
                    words = ' '.join(words).replace(' ##', '').split()
                    if len(words) <= 3:
                        continue
                    trigrams = [(words[i - 1], words[i], words[i + 1])
                                for i in range(1, len(words) - 1)]
                    trigram = tuple(trigrams[-1])
                    if trigram in trigrams[:-1]:
                        fail = True
                    if fail:
                        curr_scores[i] = -10e20

        curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
        topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

        # Recover log probs.
        topk_log_probs = topk_scores * length_penalty

        # Resolve beam origin and true word ids.
        topk_beam_index = topk_ids.div(vocab_size)
        topk_ids = topk_ids.fmod(vocab_size)

        # Map beam_index to batch_index in the flat representation.
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)

        # Append last prediction.
        alive_seq = torch.cat([
            alive_seq.index_select(0, select_indices),
            topk_ids.view(-1, 1)
        ], -1)

        is_finished = topk_ids.eq(self.end_token)
        if step + 1 == max_length:
            is_finished.fill_(1)
        # End condition is top beam is finished.
        end_condition = is_finished[:, 0].eq(1)
        # Save finished hypotheses.
        if is_finished.any():
            predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = batch_offset[i]
                if end_condition[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero().view(-1)
                # Store finished hypotheses for this batch.
                for j in finished_hyp:
                    hypotheses[b].append(
                        (topk_scores[i, j], predictions[i, j, 1:]))
                # If the batch reached the end, save the n_best hypotheses.
                if end_condition[i]:
                    best_hyp = sorted(hypotheses[b],
                                      key=lambda x: x[0],
                                      reverse=True)
                    score, pred = best_hyp[0]
                    results["scores"][b].append(score)
                    results["predictions"][b].append(pred)
            non_finished = end_condition.eq(0).nonzero().view(-1)
            # If all sentences are translated, no need to go further.
            if len(non_finished) == 0:
                break
            # Remove finished batches for the next step.
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))
        # Reorder states.
        select_indices = batch_index.view(-1)
        src_features = src_features.index_select(0, select_indices)
        ext_scores = ext_scores.index_select(0, select_indices)
        mask_src = mask_src.index_select(0, select_indices)
        mask_cls = mask_cls.index_select(0, select_indices)
        clss = clss.index_select(0, select_indices)
        src = src.index_select(0, select_indices)
        dec_states.map_batch_fn(
            lambda state, dim: state.index_select(dim, select_indices))

    return results