def reorder_encoder_out(self, encoder_out, new_order):
    """Reorder encoder output according to new_order (used when beam
    search reorders hypotheses across batch x beam positions)."""
    (x, src_tokens, encoder_padding_mask) = encoder_out
    src_tokens_tensor = pytorch_translate_utils.get_source_tokens_tensor(
        src_tokens
    )
    if x is not None:
        # x is T x B x C, so the batch/beam dimension is dim 1
        x = x.index_select(1, new_order)
    if src_tokens_tensor is not None:
        src_tokens_tensor = src_tokens_tensor.index_select(0, new_order)
    if encoder_padding_mask is not None:
        encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
    return (x, src_tokens_tensor, encoder_padding_mask)
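# Minimal sketch (toy shapes, not part of the original module) of why the
# index_select dims above differ: encoder states x are time-major
# (T x B x C), so the batch/beam dimension is dim 1, while src_tokens and
# the padding mask are batch-major (B x T), so they reorder along dim 0.
def _demo_reorder_dims():
    import torch

    T, B, C = 4, 3, 2
    x = torch.arange(T * B * C).view(T, B, C)  # T x B x C (time-major)
    src = torch.arange(B * 5).view(B, 5)  # B x T_src (batch-major)
    new_order = torch.tensor([2, 0, 1])
    assert torch.equal(x.index_select(1, new_order)[:, 0], x[:, 2])
    assert torch.equal(src.index_select(0, new_order)[0], src[2])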
def generate_hypo(self, repacked_inputs, maxlen_a=0.0, maxlen_b=None):
    if maxlen_b is None:
        maxlen_b = self.maxlen
    src_tokens = repacked_inputs["src_tokens"]
    srclen = pytorch_translate_utils.get_source_tokens_tensor(
        src_tokens
    ).size(1)
    hypos = self.generate(
        repacked_inputs,
        beam_size=self.beam_size,
        maxlen=int(maxlen_a * srclen + maxlen_b),
        # If we need to generate predictions with teacher forcing, this
        # won't work. Right now this is fine.
        prefix_tokens=None,
    )
    return self._pick_hypothesis_unpack_output(hypos)
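# Illustrative-only sketch of the decoding-length heuristic above: the
# budget grows linearly with source length, as
# maxlen = int(maxlen_a * srclen + maxlen_b). Values here are made up.
def _demo_maxlen_heuristic():
    maxlen_a, maxlen_b = 1.5, 5
    for srclen in (4, 10, 20):
        print(srclen, int(maxlen_a * srclen + maxlen_b))  # -> 11, 20, 35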
def forward(self, src_tokens, src_lengths):
    # Embed tokens
    x = self.embed_tokens(src_tokens)
    src_tokens_tensor = pytorch_translate_utils.get_source_tokens_tensor(
        src_tokens
    )

    # Add position embeddings and dropout
    x = self.embed_scale * x
    positions = self.embed_positions(src_tokens_tensor)
    x += positions
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # Compute padding mask (B x T); collapse to None if there is no
    # padding, so downstream attention can skip masking entirely
    encoder_padding_mask = src_tokens_tensor.eq(self.padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    return x, encoder_padding_mask, positions
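# Toy illustration (hypothetical padding_idx, not part of the original
# module) of the padding-mask logic above: eq(padding_idx) marks padded
# positions, and the mask becomes None when nothing is padded.
def _demo_padding_mask():
    import torch

    padding_idx = 0
    src = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])  # B x T, 0 = pad
    mask = src.eq(padding_idx)
    if not mask.any():
        mask = None
    print(mask)
    # tensor([[False, False, False,  True],
    #         [False, False,  True,  True]])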
def _decode_target(
    self,
    encoder_input,
    encoder_outs,
    incremental_states,
    diversity_sibling_gamma=0.0,
    beam_size=None,
    maxlen=None,
    prefix_tokens=None,
):
    src_tokens_tensor = pytorch_translate_utils.get_source_tokens_tensor(
        encoder_input["src_tokens"]
    )
    beam_size = beam_size if beam_size is not None else self.beam_size
    bsz = src_tokens_tensor.size(0)
    # repeat each batch index beam_size times: [0, 0, ..., 1, 1, ...]
    reorder_indices = (
        torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1).long()
    )
    for i, model in enumerate(self.models):
        encoder_outs[i] = model.encoder.reorder_encoder_out(
            encoder_out=encoder_outs[i],
            new_order=reorder_indices.type_as(src_tokens_tensor),
        )
    maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

    # initialize buffers
    scores = src_tokens_tensor.new(bsz * beam_size, maxlen + 1).float().fill_(0)
    scores_buf = scores.clone()
    tokens = src_tokens_tensor.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
    tokens_buf = tokens.clone()
    tokens[:, 0] = self.eos

    # length of the encoder output, which may differ from the source
    # input length
    if isinstance(encoder_outs[0], (list, tuple)):
        src_encoding_len = encoder_outs[0][0].size(0)
    elif isinstance(encoder_outs[0], dict):
        if isinstance(encoder_outs[0]["encoder_out"], tuple):
            # Fairseq compatibility
            src_encoding_len = encoder_outs[0]["encoder_out"][0].size(1)
        else:
            src_encoding_len = encoder_outs[0]["encoder_out"].size(0)

    attn = scores.new(bsz * beam_size, src_encoding_len, maxlen + 2)
    attn_buf = attn.clone()

    # list of completed sentences
    finalized = [[] for i in range(bsz)]
    finished = [False for i in range(bsz)]
    worst_finalized = [{"idx": None, "score": -math.inf} for i in range(bsz)]
    num_remaining_sent = bsz

    # number of candidate hypos per step
    cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

    # offset arrays for converting between different indexing schemes
    bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
    cand_offsets = torch.arange(0, cand_size).type_as(tokens)

    # helper function for allocating buffers on the fly
    buffers = {}

    def buffer(name, type_of=tokens):  # noqa
        if name not in buffers:
            buffers[name] = type_of.new()
        return buffers[name]

    def is_finished(sent, step, unfinalized_scores=None):
        """
        Check whether we've finished generation for a given sentence, by
        comparing the worst score among finalized hypotheses to the best
        possible score among unfinalized hypotheses.
        """
        assert len(finalized[sent]) <= beam_size
        if len(finalized[sent]) == beam_size:
            if self.stop_early or step == maxlen or unfinalized_scores is None:
                return True
            # stop if the best unfinalized score is worse than the worst
            # finalized one
            best_unfinalized_score = unfinalized_scores[sent].max()
            if self.normalize_scores:
                best_unfinalized_score /= (maxlen + 1) ** self.len_penalty
            if worst_finalized[sent]["score"] >= best_unfinalized_score:
                return True
        return False

    def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
        """
        Finalize the given hypotheses at this step, while keeping the total
        number of finalized hypotheses per sentence <= beam_size.

        Note: the input must be in the desired finalization order, so that
        hypotheses that appear earlier in the input are preferred to those
        that appear later.
        Args:
            step: current time step
            bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                indicating which hypotheses to finalize
            eos_scores: A vector of the same size as bbsz_idx containing
                scores for each hypothesis
            unfinalized_scores: A vector containing scores for all
                unfinalized hypotheses
        """
        assert bbsz_idx.numel() == eos_scores.numel()

        # clone relevant token and attention tensors; skip the first index,
        # which is EOS
        tokens_clone = tokens.index_select(0, bbsz_idx)
        tokens_clone = tokens_clone[:, 1 : step + 2]
        tokens_clone[:, step] = self.eos
        attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]

        # compute scores per token position
        pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
        pos_scores[:, step] = eos_scores
        # convert from cumulative to per-position scores
        pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

        # normalize sentence-level scores
        if self.normalize_scores:
            eos_scores /= (step + 1) ** self.len_penalty

        sents_seen = set()
        for i, (idx, score) in enumerate(
            zip(bbsz_idx.tolist(), eos_scores.tolist())
        ):
            sent = idx // beam_size
            sents_seen.add(sent)

            def get_hypo():
                _, alignment = attn_clone[i].max(dim=0)
                return {
                    "tokens": tokens_clone[i],
                    "score": score,
                    "attention": attn_clone[i],  # src_len x tgt_len
                    "alignment": alignment,
                    "positional_scores": pos_scores[i],
                }

            if len(finalized[sent]) < beam_size:
                finalized[sent].append(get_hypo())
            elif not self.stop_early and score > worst_finalized[sent]["score"]:
                # replace worst hypo for this sentence with new/better one
                worst_idx = worst_finalized[sent]["idx"]
                if worst_idx is not None:
                    finalized[sent][worst_idx] = get_hypo()
                # find new worst finalized hypo for this sentence
                idx, s = min(
                    enumerate(finalized[sent]), key=lambda r: r[1]["score"]
                )
                worst_finalized[sent] = {"score": s["score"], "idx": idx}

        # return number of hypotheses finished this step
        num_finished = 0
        for sent in sents_seen:
            # check termination conditions for this sentence
            if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                finished[sent] = True
                num_finished += 1
        return num_finished

    reorder_state = None
    for step in range(maxlen + 1):  # one extra step for EOS marker
        # reorder decoder internal states based on the prev choice of beams
        if reorder_state is not None:
            for model in self.models:
                if isinstance(model.decoder, FairseqIncrementalDecoder):
                    model.decoder.reorder_incremental_state(
                        incremental_states[model], reorder_state
                    )

        # Run decoder for one step
        logprobs, avg_attn, possible_translation_tokens = self._decode(
            tokens[:, : step + 1], encoder_outs, incremental_states
        )
        logprobs[:, self.pad] = -math.inf  # never select pad

        # apply unk reward
        if possible_translation_tokens is None:
            # No vocab reduction, so unk is represented by self.unk at
            # position self.unk
            unk_index = self.unk
            logprobs[:, unk_index] += self.unk_reward
        else:
            # When we use vocab reduction, the token value self.unk may not
            # be at position self.unk, but somewhere else in the list of
            # possible_translation_tokens. It may also be absent from
            # possible_translation_tokens entirely, meaning we can't
            # generate an unk at all.
            unk_pos = torch.nonzero(possible_translation_tokens == self.unk)
            if unk_pos.size()[0] != 0:
                # only add unk_reward if the unk index appears in
                # possible_translation_tokens
                unk_index = unk_pos[0][0]
                logprobs[:, unk_index] += self.unk_reward

        # external lexicon reward
        logprobs[:, self.lexicon_indices] += self.lexicon_reward

        # word reward for every token except EOS
        logprobs += self.word_reward
        logprobs[:, self.eos] -= self.word_reward

        # Record attention scores
        attn[:, :, step + 1].copy_(avg_attn)

        cand_scores = buffer("cand_scores", type_of=scores)
        cand_indices = buffer("cand_indices")
        cand_beams = buffer("cand_beams")
        eos_bbsz_idx = buffer("eos_bbsz_idx")
        eos_scores = buffer("eos_scores", type_of=scores)
        scores = scores.type_as(logprobs)
        scores_buf = scores_buf.type_as(logprobs)

        if step < maxlen:
            if prefix_tokens is not None and step < prefix_tokens.size(1):
                # force-select the provided prefix token in every beam
                logprobs_slice = logprobs.view(bsz, -1, logprobs.size(-1))[:, 0, :]
                cand_scores = torch.gather(
                    logprobs_slice,
                    dim=1,
                    index=prefix_tokens[:, step].view(-1, 1),
                ).expand(-1, cand_size)
                cand_indices = prefix_tokens[:, step].view(-1, 1).expand(
                    bsz, cand_size
                )
                cand_beams.resize_as_(cand_indices).fill_(0)
            else:
                possible_tokens_size = self.vocab_size
                if possible_translation_tokens is not None:
                    possible_tokens_size = possible_translation_tokens.size(0)
                if diversity_sibling_gamma > 0:
                    logprobs = self.diversity_sibling_rank(
                        logprobs.view(bsz, -1, possible_tokens_size),
                        diversity_sibling_gamma,
                    )
                cand_scores, cand_indices, cand_beams = self.search.step(
                    step,
                    logprobs.view(bsz, -1, possible_tokens_size),
                    scores.view(bsz, beam_size, -1)[:, :, :step],
                )
                # vocabulary reduction: map indices in the reduced vocab
                # back to full-vocabulary token ids
                if possible_translation_tokens is not None:
                    possible_translation_tokens = possible_translation_tokens.view(
                        1, possible_tokens_size
                    ).expand(cand_indices.size(0), possible_tokens_size)
                    cand_indices = torch.gather(
                        possible_translation_tokens,
                        dim=1,
                        index=cand_indices,
                        out=cand_indices,
                    )
        else:
            # finalize all active hypotheses once we hit maxlen: pick the
            # hypothesis with the highest log prob of EOS right now
            logprobs.add_(scores[:, step - 1].view(-1, 1))
            torch.sort(
                logprobs[:, self.eos],
                descending=True,
                out=(eos_scores, eos_bbsz_idx),
            )
            num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)
            assert num_remaining_sent == 0
            break

        # cand_bbsz_idx contains beam indices for the top candidate
        # hypotheses, with a range of values: [0, bsz*beam_size),
        # and dimensions: [bsz, cand_size]
        cand_bbsz_idx = cand_beams.add_(bbsz_offsets)

        # finalize hypotheses that end in eos
        eos_mask = cand_indices.eq(self.eos)
        if step >= self.minlen:
            # only consider eos when it's among the top beam_size indices
            torch.masked_select(
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],
                out=eos_bbsz_idx,
            )
            if eos_bbsz_idx.numel() > 0:
                torch.masked_select(
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_scores,
                )
                num_remaining_sent -= finalize_hypos(
                    step, eos_bbsz_idx, eos_scores, cand_scores
                )

        assert num_remaining_sent >= 0
        if num_remaining_sent == 0:
            break
        assert step < maxlen

        # set active_mask so that values > cand_size indicate eos hypos
        # and values < cand_size indicate candidate active hypos.
        # After, the min values per row are the top candidate active hypos
        active_mask = buffer("active_mask")
        torch.add(
            eos_mask.type_as(cand_offsets) * cand_size,
            cand_offsets[: eos_mask.size(1)],
            out=active_mask,
        )

        # get the top beam_size active hypotheses, which are just the hypos
        # with the smallest values in active_mask
        active_hypos, _ignore = buffer("active_hypos"), buffer("_ignore")
        torch.topk(
            active_mask,
            k=beam_size,
            dim=1,
            largest=False,
            out=(_ignore, active_hypos),
        )
        active_bbsz_idx = buffer("active_bbsz_idx")
        torch.gather(cand_bbsz_idx, dim=1, index=active_hypos, out=active_bbsz_idx)
        active_scores = torch.gather(
            cand_scores,
            dim=1,
            index=active_hypos,
            out=scores[:, step].view(bsz, beam_size),
        )
        active_bbsz_idx = active_bbsz_idx.view(-1)
        active_scores = active_scores.view(-1)

        # copy tokens and scores for active hypotheses
        torch.index_select(
            tokens[:, : step + 1],
            dim=0,
            index=active_bbsz_idx,
            out=tokens_buf[:, : step + 1],
        )
        torch.gather(
            cand_indices,
            dim=1,
            index=active_hypos,
            out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
        )
        if step > 0:
            torch.index_select(
                scores[:, :step],
                dim=0,
                index=active_bbsz_idx,
                out=scores_buf[:, :step],
            )
        torch.gather(
            cand_scores,
            dim=1,
            index=active_hypos,
            out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
        )

        # copy attention for active hypotheses
        torch.index_select(
            attn[:, :, : step + 2],
            dim=0,
            index=active_bbsz_idx,
            out=attn_buf[:, :, : step + 2],
        )

        # swap buffers
        tokens, tokens_buf = tokens_buf, tokens
        scores, scores_buf = scores_buf, scores
        attn, attn_buf = attn_buf, attn

        # reorder incremental state in decoder
        reorder_state = active_bbsz_idx

    # sort by score descending
    for sent in range(bsz):
        finalized[sent] = sorted(
            finalized[sent], key=lambda r: r["score"], reverse=True
        )
    return finalized
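# The two demos below are illustrative sketches (toy sizes, not part of the
# original module) of the index bookkeeping used throughout _decode_target.
# First: bbsz_offsets converts per-sentence beam indices in [0, beam_size)
# into flat indices in [0, bsz * beam_size) by adding sent * beam_size.
def _demo_bbsz_offsets():
    import torch

    bsz, beam_size = 2, 3
    bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)  # [[0], [3]]
    cand_beams = torch.tensor([[0, 2, 1], [1, 0, 2]])  # per-sentence beam ids
    print(cand_beams + bbsz_offsets)  # tensor([[0, 2, 1], [4, 3, 5]])


# Second: the active_mask trick. Adding cand_size at EOS positions pushes
# them past every non-EOS candidate, so topk(..., largest=False) selects
# the best still-active (non-EOS) hypotheses.
def _demo_active_mask():
    import torch

    beam_size = 2
    cand_size = 2 * beam_size
    cand_offsets = torch.arange(0, cand_size)  # [0, 1, 2, 3]
    eos_mask = torch.tensor([[True, False, False, True]])
    active_mask = eos_mask.long() * cand_size + cand_offsets  # [[4, 1, 2, 7]]
    _, active_hypos = torch.topk(active_mask, k=beam_size, dim=1, largest=False)
    print(active_hypos)  # tensor([[1, 2]]): the non-EOS candidates survive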