def get_avg_pool(
    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
    model = EnsembleModel(models)

    # model.forward normally channels prev_output_tokens into the decoder
    # separately, but SequenceGenerator directly calls model.encoder
    encoder_input = {
        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
    }

    # compute the encoder output for each beam
    encoder_outs = model.forward_encoder(encoder_input)
    np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
        np.float32
    )
    encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
    if has_langtok:
        # drop the language-token position from both the mask and the states
        encoder_mask = encoder_mask[1:, :, :]
        np_encoder_outs = np_encoder_outs[1:, :, :]
    masked_encoder_outs = encoder_mask * np_encoder_outs
    avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
    return avg_pool
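The masked average pooling above is just a per-sentence mean of the encoder states over non-padding positions. A minimal, self-contained sketch of the same arithmetic on synthetic numpy arrays (shapes and values are invented for illustration):

import numpy as np

# toy encoder output: T x B x C (time, batch, channels) and a B x T padding mask
T, B, C = 5, 2, 4
np.random.seed(0)
np_encoder_outs = np.random.rand(T, B, C).astype(np.float32)
padding_mask = np.array([[0, 0, 0, 1, 1],   # sentence 0: last two positions are pad
                         [0, 0, 0, 0, 0]],  # sentence 1: no padding
                        dtype=np.float32)

encoder_mask = 1 - padding_mask                        # 1 where there is a real token
encoder_mask = np.expand_dims(encoder_mask.T, axis=2)  # T x B x 1

masked = encoder_mask * np_encoder_outs
avg_pool = (masked / encoder_mask.sum(axis=0)).sum(axis=0)  # B x C mean over real tokens
print(avg_pool.shape)  # (2, 4)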
class GenerateLogProbsForDecoding(nn.Module):
    def __init__(self, models, retain_dropout=False, apply_log_softmax=False):
        """Generate the neural network's output interpreted as log probabilities
        for decoding with Kaldi.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting
            retain_dropout (bool, optional): use dropout when generating
                (default: False)
            apply_log_softmax (bool, optional): apply log-softmax on top of the
                network's output (default: False)
        """
        super().__init__()
        from fairseq.sequence_generator import EnsembleModel

        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.retain_dropout = retain_dropout
        self.apply_log_softmax = apply_log_softmax

        if not self.retain_dropout:
            self.model.eval()

    def cuda(self):
        self.model.cuda()
        return self

    @torch.no_grad()
    def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        """Generate a batch of translations.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
        """
        self.model.reset_incremental_state()
        return self._generate(sample, **kwargs)

    def _generate(self, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        net_input = sample["net_input"]
        src_tokens = net_input["src_tokens"]
        bsz = src_tokens.size(0)

        # compute the encoder output
        encoder_outs = self.model.forward_encoder(net_input)
        logits = encoder_outs[0].encoder_out.transpose(0, 1).float()  # T x B x V -> B x T x V
        assert logits.size(0) == bsz
        padding_mask = (
            encoder_outs[0].encoder_padding_mask.t()
            if encoder_outs[0].encoder_padding_mask is not None
            else None
        )
        if self.apply_log_softmax:
            return F.log_softmax(logits, dim=-1), padding_mask
        return logits, padding_mask
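A hedged usage sketch for the class above. `models` is assumed to be an already-loaded list of fairseq models and `sample` a batch produced by a fairseq task/dataloader; neither is constructed here:

# assumptions: `models` is a list of loaded fairseq models and `sample` is a
# batch dict containing sample["net_input"]["src_tokens"] (placeholders)
generator = GenerateLogProbsForDecoding(models, apply_log_softmax=True)
if torch.cuda.is_available():
    generator = generator.cuda()

lprobs, padding_mask = generator.generate(models, sample)
# lprobs: B x T x V log-probabilities; padding_mask: B x T mask or None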
class Model():
    """Wrapper around the stack-transformer model"""

    def __init__(self, models, target_dictionary):
        self.temperature = 1.0
        self.target_dictionary = target_dictionary
        self.models = models
        self.reset()

    def reset(self):
        # This clears the cache of key/values; there may be more efficient ways
        self.model = EnsembleModel(self.models)
        # reset cache for encoder
        self.encoder_outs = None
        self.model.eval()

    def precompute_encoder(self, sample):
        """Encoder of the encoder-decoder is fixed and can be precomputed"""
        encoder_input = extract_encoder(sample)
        encoder_outs = self.model.forward_encoder(encoder_input)
        return encoder_outs

    def get_action(self, sample, parser_state, prev_actions):

        # Compute the part of the model that does not depend on episode steps
        # (the encoder) and cache it for future use, for speed
        if self.encoder_outs is None:
            self.encoder_outs = self.precompute_encoder(sample)

        # call the model with the pre-computed encoder, previously generated
        # actions (tokens) and the state machine status
        lprobs, avg_attn_scores = self.model.forward_decoder(
            prev_actions,
            self.encoder_outs,
            parser_state,
            temperature=self.temperature,
        )

        # Get the most probable action (flip the flag below to sample instead)
        if True:
            best_action_indices = lprobs.argmax(dim=1).tolist()
        else:
            # sampling
            best_action_indices = torch.squeeze(
                lprobs.exp().multinomial(1), 1
            ).tolist()

        actions = [self.target_dictionary[i] for i in best_action_indices]
        actions_lprob = [lprobs[0, i] for i in best_action_indices]

        return actions, actions_lprob
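The hard-coded `if True:` switch above toggles between greedy argmax and multinomial sampling over the action log-probabilities. A self-contained toy illustration of the two branches (the tensor values are invented):

import torch

torch.manual_seed(0)
lprobs = torch.log_softmax(torch.randn(1, 6), dim=1)  # 1 x vocab log-probs

# greedy: pick the single most probable action index
greedy_idx = lprobs.argmax(dim=1).tolist()

# sampling: draw one index from the corresponding probability distribution
sampled_idx = torch.squeeze(lprobs.exp().multinomial(1), 1).tolist()

print(greedy_idx, sampled_idx)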
def generate(self, models, sample, prefix_tokens=None, bos_token=None, **kwargs): """Generate a batch of translations. Args: models (List[~fairseq.models.FairseqModel]): ensemble of models sample (dict): batch prefix_tokens (torch.LongTensor, optional): force decoder to begin with these tokens """ model = EnsembleModel(models) if not self.retain_dropout: model.eval() # model.forward normally channels prev_output_tokens into the decoder # separately, but SequenceGenerator directly calls model.encoder encoder_input = { k: v for k, v in sample['net_input'].items() if k != 'prev_output_tokens' } src_tokens = encoder_input['src_tokens'] src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) input_size = src_tokens.size() # batch dimension goes first followed by source lengths bsz = input_size[0] src_len = input_size[1] beam_size = self.beam_size if self.match_source_len: max_len = src_lengths.max().item() else: max_len = min( int(self.max_len_a * src_len + self.max_len_b), # exclude the EOS marker model.max_decoder_positions() - 1, ) # compute the encoder output for each beam encoder_outs = model.forward_encoder(encoder_input) new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) new_order = new_order.to(src_tokens.device).long() encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) # initialize buffers scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0) scores_buf = scores.clone() tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad) tokens_buf = tokens.clone() tokens[:, 0] = bos_token or self.eos attn, attn_buf = None, None nonpad_idxs = None # list of completed sentences finalized = [[] for i in range(bsz)] finished = [False for i in range(bsz)] worst_finalized = [{ 'idx': None, 'score': -math.inf } for i in range(bsz)] num_remaining_sent = bsz # number of candidate hypos per step cand_size = 2 * beam_size # 2 x beam size in case half are EOS # offset arrays for converting between different indexing schemes bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) cand_offsets = torch.arange(0, cand_size).type_as(tokens) # helper function for allocating buffers on the fly buffers = {} def buffer(name, type_of=tokens): # noqa if name not in buffers: buffers[name] = type_of.new() return buffers[name] def is_finished(sent, step, unfinalized_scores=None): """ Check whether we've finished generation for a given sentence, by comparing the worst score among finalized hypotheses to the best possible score among unfinalized hypotheses. """ assert len(finalized[sent]) <= beam_size if len(finalized[sent]) == beam_size: if self.stop_early or step == max_len or unfinalized_scores is None: return True # stop if the best unfinalized score is worse than the worst # finalized one best_unfinalized_score = unfinalized_scores[sent].max() if self.normalize_scores: best_unfinalized_score /= max_len**self.len_penalty if worst_finalized[sent]['score'] >= best_unfinalized_score: return True return False def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None): """ Finalize the given hypotheses at this step, while keeping the total number of finalized hypotheses per sentence <= beam_size. Note: the input must be in the desired finalization order, so that hypotheses that appear earlier in the input are preferred to those that appear later. 
Args: step: current time step bbsz_idx: A vector of indices in the range [0, bsz*beam_size), indicating which hypotheses to finalize eos_scores: A vector of the same size as bbsz_idx containing scores for each hypothesis unfinalized_scores: A vector containing scores for all unfinalized hypotheses """ assert bbsz_idx.numel() == eos_scores.numel() # clone relevant token and attention tensors tokens_clone = tokens.index_select(0, bbsz_idx) tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS tokens_clone[:, step] = self.eos attn_clone = attn.index_select( 0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None # compute scores per token position pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1] pos_scores[:, step] = eos_scores # convert from cumulative to per-position scores pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] # normalize sentence-level scores if self.normalize_scores: eos_scores /= (step + 1)**self.len_penalty cum_unfin = [] prev = 0 for f in finished: if f: prev += 1 else: cum_unfin.append(prev) sents_seen = set() for i, (idx, score) in enumerate( zip(bbsz_idx.tolist(), eos_scores.tolist())): unfin_idx = idx // beam_size sent = unfin_idx + cum_unfin[unfin_idx] sents_seen.add((sent, unfin_idx)) if self.match_source_len and step > src_lengths[unfin_idx]: score = -math.inf def get_hypo(): if attn_clone is not None: # remove padding tokens from attn scores hypo_attn = attn_clone[i][nonpad_idxs[sent]] _, alignment = hypo_attn.max(dim=0) else: hypo_attn = None alignment = None return { 'tokens': tokens_clone[i], 'score': score, 'attention': hypo_attn, # src_len x tgt_len 'alignment': alignment, 'positional_scores': pos_scores[i], } if len(finalized[sent]) < beam_size: finalized[sent].append(get_hypo()) elif not self.stop_early and score > worst_finalized[sent][ 'score']: # replace worst hypo for this sentence with new/better one worst_idx = worst_finalized[sent]['idx'] if worst_idx is not None: finalized[sent][worst_idx] = get_hypo() # find new worst finalized hypo for this sentence idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score']) worst_finalized[sent] = { 'score': s['score'], 'idx': idx, } newly_finished = [] for sent, unfin_idx in sents_seen: # check termination conditions for this sentence if not finished[sent] and is_finished(sent, step, unfinalized_scores): finished[sent] = True newly_finished.append(unfin_idx) return newly_finished reorder_state = None batch_idxs = None for step in range(max_len + 1): # one extra step for EOS marker # reorder decoder internal states based on the prev choice of beams if reorder_state is not None: if batch_idxs is not None: # update beam indices to take into account removed sentences corr = batch_idxs - torch.arange( batch_idxs.numel()).type_as(batch_idxs) reorder_state.view(-1, beam_size).add_( corr.unsqueeze(-1) * beam_size) model.reorder_incremental_state(reorder_state) model.reorder_encoder_out(encoder_outs, reorder_state) lprobs, avg_attn_scores = model.forward_decoder( tokens[:, :step + 1], encoder_outs) lprobs[:, self.pad] = -math.inf # never select pad lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty if self.no_repeat_ngram_size > 0: # for each beam and batch sentence, generate a list of previous ngrams gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)] for bbsz_idx in range(bsz * beam_size): gen_tokens = tokens[bbsz_idx].tolist() for ngram in zip(*[ gen_tokens[i:] for i in range(self.no_repeat_ngram_size) ]): gen_ngrams[bbsz_idx][tuple(ngram[:-1])] 
= \ gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]] # Record attention scores if avg_attn_scores is not None: if attn is None: attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2) attn_buf = attn.clone() nonpad_idxs = src_tokens.ne(self.pad) attn[:, :, step + 1].copy_(avg_attn_scores) scores = scores.type_as(lprobs) scores_buf = scores_buf.type_as(lprobs) eos_bbsz_idx = buffer('eos_bbsz_idx') eos_scores = buffer('eos_scores', type_of=scores) if step < max_len: self.search.set_src_lengths(src_lengths) if self.no_repeat_ngram_size > 0: def calculate_banned_tokens(bbsz_idx): # before decoding the next token, prevent decoding of ngrams that have already appeared ngram_index = tuple( tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist()) return gen_ngrams[bbsz_idx].get(ngram_index, []) if step + 2 - self.no_repeat_ngram_size >= 0: # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet banned_tokens = [ calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size) ] else: banned_tokens = [[] for bbsz_idx in range(bsz * beam_size) ] for bbsz_idx in range(bsz * beam_size): lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf if prefix_tokens is not None and step < prefix_tokens.size(1): probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :] cand_scores = torch.gather( probs_slice, dim=1, index=prefix_tokens[:, step].view(-1, 1)).view( -1, 1).repeat(1, cand_size) if step > 0: # save cumulative scores for each hypothesis cand_scores.add_(scores[:, step - 1].view( bsz, beam_size).repeat(1, 2)) cand_indices = prefix_tokens[:, step].view(-1, 1).repeat( 1, cand_size) cand_beams = torch.zeros_like(cand_indices) # handle prefixes of different lengths partial_prefix_mask = prefix_tokens[:, step].eq(self.pad) if partial_prefix_mask.any(): partial_scores, partial_indices, partial_beams = self.search.step( step, lprobs.view(bsz, -1, self.vocab_size), scores.view(bsz, beam_size, -1)[:, :, :step], ) cand_scores[partial_prefix_mask] = partial_scores[ partial_prefix_mask] cand_indices[partial_prefix_mask] = partial_indices[ partial_prefix_mask] cand_beams[partial_prefix_mask] = partial_beams[ partial_prefix_mask] else: cand_scores, cand_indices, cand_beams = self.search.step( step, lprobs.view(bsz, -1, self.vocab_size), scores.view(bsz, beam_size, -1)[:, :, :step], ) else: # make probs contain cumulative scores for each hypothesis lprobs.add_(scores[:, step - 1].unsqueeze(-1)) # finalize all active hypotheses once we hit max_len # pick the hypothesis with the highest prob of EOS right now torch.sort( lprobs[:, self.eos], descending=True, out=(eos_scores, eos_bbsz_idx), ) num_remaining_sent -= len( finalize_hypos(step, eos_bbsz_idx, eos_scores)) assert num_remaining_sent == 0 break # cand_bbsz_idx contains beam indices for the top candidate # hypotheses, with a range of values: [0, bsz*beam_size), # and dimensions: [bsz, cand_size] cand_bbsz_idx = cand_beams.add(bbsz_offsets) # finalize hypotheses that end in eos eos_mask = cand_indices.eq(self.eos) finalized_sents = set() if step >= self.min_len: # only consider eos when it's among the top beam_size indices torch.masked_select( cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size], out=eos_bbsz_idx, ) if eos_bbsz_idx.numel() > 0: torch.masked_select( cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size], out=eos_scores, ) finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores, cand_scores) num_remaining_sent -= len(finalized_sents) assert num_remaining_sent >= 0 
if num_remaining_sent == 0: break assert step < max_len if len(finalized_sents) > 0: new_bsz = bsz - len(finalized_sents) # construct batch_idxs which holds indices of batches to keep for the next pass batch_mask = cand_indices.new_ones(bsz) batch_mask[cand_indices.new(finalized_sents)] = 0 batch_idxs = batch_mask.nonzero().squeeze(-1) eos_mask = eos_mask[batch_idxs] cand_beams = cand_beams[batch_idxs] bbsz_offsets.resize_(new_bsz, 1) cand_bbsz_idx = cand_beams.add(bbsz_offsets) cand_scores = cand_scores[batch_idxs] cand_indices = cand_indices[batch_idxs] if prefix_tokens is not None: prefix_tokens = prefix_tokens[batch_idxs] src_lengths = src_lengths[batch_idxs] scores = scores.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, -1) scores_buf.resize_as_(scores) tokens = tokens.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, -1) tokens_buf.resize_as_(tokens) if attn is not None: attn = attn.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, attn.size(1), -1) attn_buf.resize_as_(attn) bsz = new_bsz else: batch_idxs = None # set active_mask so that values > cand_size indicate eos hypos # and values < cand_size indicate candidate active hypos. # After, the min values per row are the top candidate active hypos active_mask = buffer('active_mask') torch.add( eos_mask.type_as(cand_offsets) * cand_size, cand_offsets[:eos_mask.size(1)], out=active_mask, ) # get the top beam_size active hypotheses, which are just the hypos # with the smallest values in active_mask active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore') torch.topk(active_mask, k=beam_size, dim=1, largest=False, out=(_ignore, active_hypos)) active_bbsz_idx = buffer('active_bbsz_idx') torch.gather( cand_bbsz_idx, dim=1, index=active_hypos, out=active_bbsz_idx, ) active_scores = torch.gather( cand_scores, dim=1, index=active_hypos, out=scores[:, step].view(bsz, beam_size), ) active_bbsz_idx = active_bbsz_idx.view(-1) active_scores = active_scores.view(-1) # copy tokens and scores for active hypotheses torch.index_select( tokens[:, :step + 1], dim=0, index=active_bbsz_idx, out=tokens_buf[:, :step + 1], ) torch.gather( cand_indices, dim=1, index=active_hypos, out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1], ) if step > 0: torch.index_select( scores[:, :step], dim=0, index=active_bbsz_idx, out=scores_buf[:, :step], ) torch.gather( cand_scores, dim=1, index=active_hypos, out=scores_buf.view(bsz, beam_size, -1)[:, :, step], ) # copy attention for active hypotheses if attn is not None: torch.index_select( attn[:, :, :step + 2], dim=0, index=active_bbsz_idx, out=attn_buf[:, :, :step + 2], ) # swap buffers tokens, tokens_buf = tokens_buf, tokens scores, scores_buf = scores_buf, scores if attn is not None: attn, attn_buf = attn_buf, attn # reorder incremental state in decoder reorder_state = active_bbsz_idx # sort by score descending for sent in range(len(finalized)): finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) return finalized
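A small, self-contained demonstration of the indexing-scheme conversion used in the beam search above: `cand_beams` holds per-sentence beam indices in [0, beam_size), and adding `bbsz_offsets` turns them into flat indices in [0, bsz * beam_size) that address the `tokens`/`scores` buffers directly (toy sizes and invented values):

import torch

bsz, beam_size = 2, 3
cand_size = 2 * beam_size  # 2x beam size, in case half the candidates are EOS

bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)  # [[0], [3]]
cand_beams = torch.tensor([[0, 2, 1, 0, 2, 1],
                           [1, 0, 2, 2, 1, 0]])                 # bsz x cand_size

cand_bbsz_idx = cand_beams.add(bbsz_offsets)
print(cand_bbsz_idx)
# tensor([[0, 2, 1, 0, 2, 1],
#         [4, 3, 5, 5, 4, 3]])  <- second sentence is shifted by beam_size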
class SimpleGreedyDecoder(nn.Module):
    def __init__(
        self,
        models,
        dictionary,
        max_len_a=0,
        max_len_b=200,
        max_len=0,
        temperature=1.0,
        eos=None,
        symbols_to_strip_from_output=None,
        for_validation=True,
        **kwargs,
    ):
        """Decode given speech audios with the simple greedy search.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting
            dictionary (~fairseq.data.Dictionary): dictionary
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            max_len (int, optional): the maximum length of the generated output
                (not including end-of-sentence)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            for_validation (bool, optional): indicate whether the decoder is
                used for validation. It affects how max_len is determined, and
                whether a tensor of lprobs is returned. If true, target should
                not be None
        """
        super().__init__()
        from fairseq.sequence_generator import EnsembleModel

        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.pad = dictionary.pad()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos() if eos is None else eos
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )
        self.vocab_size = len(dictionary)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.max_len = max_len or self.model.max_decoder_positions()
        self.temperature = temperature
        assert temperature > 0, "--temperature must be greater than 0"

        self.model.eval()
        self.for_validation = for_validation

    def cuda(self):
        self.model.cuda()
        return self

    @torch.no_grad()
    def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        """Generate a batch of translations. Match the API of other fairseq
        generators.
        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        return self._decode(sample, **kwargs)

    @torch.no_grad()
    def _decode(
        self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None
    ):
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(self.model.models_size)
            ],
        )
        net_input = sample["net_input"]
        src_tokens = net_input["src_tokens"]
        bsz, src_len = src_tokens.size()[:2]

        # compute the encoder output
        encoder_outs = self.model.forward_encoder(net_input)
        target = sample["target"]
        # target can only be None if not for validation
        assert target is not None or not self.for_validation
        max_encoder_output_length = encoder_outs[0]["encoder_out"][0].size(0)
        # for validation, make the maximum decoding length equal to at least the
        # length of target, and the length of encoder_out if possible; otherwise
        # max_len is obtained from max_len_a/b
        max_len = (
            max(max_encoder_output_length, target.size(1))
            if self.for_validation
            else min(
                int(self.max_len_a * src_len + self.max_len_b),
                self.max_len - 1,
            )
        )

        tokens = src_tokens.new(bsz, max_len + 2).long().fill_(self.pad)
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        # lprobs is only used when target is not None (i.e., for validation)
        lprobs = (
            encoder_outs[0]["encoder_out"][0].new_full(
                (bsz, target.size(1), self.vocab_size),
                -np.log(self.vocab_size),
            )
            if self.for_validation
            else None
        )
        attn = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            is_eos = tokens[:, step].eq(self.eos)
            if step > 0 and is_eos.sum() == is_eos.size(0):
                # all predictions are finished (i.e., ended with eos)
                tokens = tokens[:, : step + 1]
                if attn is not None:
                    attn = attn[:, :, : step + 1]
                break
            log_probs, avg_attn_scores = self.model.forward_decoder(
                tokens[:, : step + 1],
                encoder_outs,
                incremental_states,
                temperature=self.temperature,
            )
            tokens[:, step + 1] = log_probs.argmax(-1)
            if step > 0:  # deal with finished predictions
                # make log_probs uniform if the previous output token is EOS
                # and add consecutive EOS to the end of prediction
                log_probs[is_eos, :] = -np.log(log_probs.size(1))
                tokens[is_eos, step + 1] = self.eos
            if self.for_validation and step < target.size(1):
                lprobs[:, step, :] = log_probs

            # Record attention scores
            if type(avg_attn_scores) is list:
                avg_attn_scores = avg_attn_scores[0]
            if avg_attn_scores is not None:
                if attn is None:
                    attn = avg_attn_scores.new(
                        bsz, max_encoder_output_length, max_len + 2
                    )
                attn[:, :, step + 1].copy_(avg_attn_scores)

        return tokens[:, 1:], lprobs, attn
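A hedged usage sketch for `SimpleGreedyDecoder`. `models`, `dictionary`, and `sample` are assumed to be already prepared (a loaded ensemble, a fairseq dictionary, and a validation batch with `net_input` and `target`) and are not constructed here:

# assumptions: `models` is a loaded ensemble, `dictionary` a fairseq Dictionary,
# and `sample` a validation batch containing "net_input" and "target"
decoder = SimpleGreedyDecoder(models, dictionary, for_validation=True)
if torch.cuda.is_available():
    decoder = decoder.cuda()

tokens, lprobs, attn = decoder.decode(models, sample)
# tokens: B x T' greedy predictions (without the leading bos/eos token),
# lprobs: B x target_len x V log-probabilities for validation, attn: attention or None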
def generate(self, models, sample, prefix_tokens=None, bos_token=None, **kwargs): """Generate a batch of translations. Args: models (List[~fairseq.models.FairseqModel]): ensemble of models sample (dict): batch prefix_tokens (torch.LongTensor, optional): force decoder to begin with these tokens """ model = EnsembleModel(models) incremental_states = torch.jit.annotate( List[Dict[str, Dict[str, Optional[Tensor]]]], [ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) for i in range(model.models_size) ], ) if not self.retain_dropout: model.eval() # model.forward normally channels prev_output_tokens into the decoder # separately, but SequenceGenerator directly calls model.encoder encoder_input = { k: v for k, v in sample['net_input'].items() if k != 'prev_output_tokens' } src_tokens = encoder_input['src_tokens'] src_lengths_no_eos = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) input_size = src_tokens.size() # batch dimension goes first followed by source lengths bsz = input_size[0] src_len = input_size[1] beam_size = self.beam_size if self.match_source_len: max_len = src_lengths_no_eos.max().item() else: max_len = min( int(self.max_len_a * src_len + self.max_len_b), # exclude the EOS marker model.max_decoder_positions() - 1, ) # compute the encoder output for each beam encoder_outs = model.forward_encoder(encoder_input) new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) new_order = new_order.to(src_tokens.device).long() encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) src_lengths = encoder_input['src_lengths'] # initialize buffers scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0) lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0) scores_buf = scores.clone() tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad) tokens_buf = tokens.clone() tokens[:, 0] = self.eos if bos_token is None else bos_token # reorder source tokens so they may be used as a reference in generating P(S|T) src_tokens = reorder_all_tokens(src_tokens, src_lengths, self.src_dict.eos_index) src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len) src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view( bsz * beam_size, -1) attn, attn_buf = None, None nonpad_idxs = None # The cands_to_ignore indicates candidates that should be ignored. # For example, suppose we're sampling and have already finalized 2/5 # samples. Then the cands_to_ignore would mark 2 positions as being ignored, # so that we only finalize the remaining 3 samples. cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq( -1) # forward and backward-compatible False mask # list of completed sentences finalized = [[] for i in range(bsz)] finished = [False for i in range(bsz)] num_remaining_sent = bsz # number of candidate hypos per step cand_size = 2 * beam_size # 2 x beam size in case half are EOS # offset arrays for converting between different indexing schemes bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) cand_offsets = torch.arange(0, cand_size).type_as(tokens) # helper function for allocating buffers on the fly buffers = {} def buffer(name, type_of=tokens): # noqa if name not in buffers: buffers[name] = type_of.new() return buffers[name] def is_finished(sent, step, unfin_idx): """ Check whether we've finished generation for a given sentence, by comparing the worst score among finalized hypotheses to the best possible score among unfinalized hypotheses. 
""" assert len(finalized[sent]) <= beam_size if len(finalized[sent]) == beam_size: return True return False def finalize_hypos(step, bbsz_idx, eos_scores, combined_noisy_channel_eos_scores): """ Finalize the given hypotheses at this step, while keeping the total number of finalized hypotheses per sentence <= beam_size. Note: the input must be in the desired finalization order, so that hypotheses that appear earlier in the input are preferred to those that appear later. Args: step: current time step bbsz_idx: A vector of indices in the range [0, bsz*beam_size), indicating which hypotheses to finalize eos_scores: A vector of the same size as bbsz_idx containing fw scores for each hypothesis combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing combined noisy channel scores for each hypothesis """ assert bbsz_idx.numel() == eos_scores.numel() # clone relevant token and attention tensors tokens_clone = tokens.index_select(0, bbsz_idx) tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS assert not tokens_clone.eq(self.eos).any() tokens_clone[:, step] = self.eos attn_clone = attn.index_select( 0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None # compute scores per token position pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1] pos_scores[:, step] = eos_scores # convert from cumulative to per-position scores pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] # normalize sentence-level scores if self.normalize_scores: combined_noisy_channel_eos_scores /= (step + 1)**self.len_penalty cum_unfin = [] prev = 0 for f in finished: if f: prev += 1 else: cum_unfin.append(prev) sents_seen = set() for i, (idx, score) in enumerate( zip(bbsz_idx.tolist(), combined_noisy_channel_eos_scores.tolist())): unfin_idx = idx // beam_size sent = unfin_idx + cum_unfin[unfin_idx] sents_seen.add((sent, unfin_idx)) if self.match_source_len and step > src_lengths_no_eos[ unfin_idx]: score = -math.inf def get_hypo(): if attn_clone is not None: # remove padding tokens from attn scores hypo_attn = attn_clone[i][nonpad_idxs[sent]] _, alignment = hypo_attn.max(dim=0) else: hypo_attn = None alignment = None return { 'tokens': tokens_clone[i], 'score': score, 'attention': hypo_attn, # src_len x tgt_len 'alignment': alignment, 'positional_scores': pos_scores[i], } if len(finalized[sent]) < beam_size: finalized[sent].append(get_hypo()) newly_finished = [] for sent, unfin_idx in sents_seen: # check termination conditions for this sentence if not finished[sent] and is_finished(sent, step, unfin_idx): finished[sent] = True newly_finished.append(unfin_idx) return newly_finished def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens, k): """Rescore the top k hypothesis from each beam using noisy channel modeling Returns: new_fw_lprobs: the direct model probabilities after pruning the top k new_ch_lm_lprobs: the combined channel and language model probabilities new_lm_lprobs: the language model probabilities after pruning the top k """ with torch.no_grad(): lprobs_size = lprobs.size() if prefix_tokens is not None and step < prefix_tokens.size(1): probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :] cand_scores = torch.gather( probs_slice, dim=1, index=prefix_tokens[:, step].view(-1, 1).data).expand( -1, beam_size).contiguous().view(bsz * beam_size, 1) cand_indices = prefix_tokens[:, step].view(-1, 1).expand( bsz, beam_size).data.contiguous().view(bsz * beam_size, 1) # need to calculate and save fw and lm probs for prefix tokens 
fw_top_k = cand_scores fw_top_k_idx = cand_indices k = 1 else: # take the top k best words for every sentence in batch*beam fw_top_k, fw_top_k_idx = torch.topk(lprobs.view( beam_size * bsz, -1), k=k) eos_idx = torch.nonzero( fw_top_k_idx.view(bsz * beam_size * k, -1) == self.eos)[:, 0] ch_scores = fw_top_k.new_full((beam_size * bsz * k, ), 0) src_size = torch.sum( src_tokens[:, :] != self.src_dict.pad_index, dim=1, keepdim=True, dtype=fw_top_k.dtype) if self.combine_method != "lm_only": temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view( bsz * beam_size * k, -1) not_padding = temp_src_tokens_full[:, 1:] != self.src_dict.pad_index cur_tgt_size = step + 2 # add eos to all candidate sentences except those that already end in eos eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1) eos_tokens[eos_idx] = self.tgt_dict.pad_index if step == 0: channel_input = torch.cat( (fw_top_k_idx.view(-1, 1), eos_tokens), 1) else: # move eos from beginning to end of target sentence channel_input = torch.cat( (tokens[:, 1:step + 1].repeat(1, k).view(-1, step), fw_top_k_idx.view(-1, 1), eos_tokens), 1) ch_input_lengths = torch.tensor( np.full(channel_input.size(0), cur_tgt_size)) ch_input_lengths[eos_idx] = cur_tgt_size - 1 if self.channel_scoring_type == "unnormalized": ch_encoder_output = channel_model.encoder( channel_input, src_lengths=ch_input_lengths) ch_decoder_output, _ = channel_model.decoder( temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True) del ch_encoder_output ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target( ch_decoder_output, target_ids=temp_src_tokens_full[:, 1:]) ch_intermed_scores = ch_intermed_scores.float() ch_intermed_scores *= not_padding.float() ch_scores = torch.sum(ch_intermed_scores, dim=1) elif self.channel_scoring_type == "k2_separate": for k_idx in range(k): k_eos_tokens = eos_tokens[k_idx::k, :] if step == 0: k_ch_input = torch.cat( (fw_top_k_idx[:, k_idx:k_idx + 1], k_eos_tokens), 1) else: # move eos from beginning to end of target sentence k_ch_input = torch.cat( (tokens[:, 1:step + 1], fw_top_k_idx[:, k_idx:k_idx + 1], k_eos_tokens), 1) k_ch_input_lengths = ch_input_lengths[k_idx::k] k_ch_output = channel_model( k_ch_input, k_ch_input_lengths, src_tokens) k_ch_lprobs = channel_model.get_normalized_probs( k_ch_output, log_probs=True) k_ch_intermed_scores = torch.gather( k_ch_lprobs[:, :-1, :], 2, src_tokens[:, 1:].unsqueeze(2)).squeeze(2) k_ch_intermed_scores *= not_padding.float() ch_scores[k_idx::k] = torch.sum( k_ch_intermed_scores, dim=1) elif self.channel_scoring_type == "src_vocab": ch_encoder_output = channel_model.encoder( channel_input, src_lengths=ch_input_lengths) ch_decoder_output, _ = channel_model.decoder( temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True) del ch_encoder_output ch_lprobs = normalized_scores_with_batch_vocab( channel_model.decoder, ch_decoder_output, src_tokens, k, bsz, beam_size, self.src_dict.pad_index, top_k=self.top_k_vocab) ch_scores = torch.sum(ch_lprobs, dim=1) elif self.channel_scoring_type == "src_vocab_batched": ch_bsz_size = temp_src_tokens_full.shape[0] ch_lprobs_list = [None] * len( range(0, ch_bsz_size, self.ch_scoring_bsz)) for i, start_idx in enumerate( range(0, ch_bsz_size, self.ch_scoring_bsz)): end_idx = min(start_idx + self.ch_scoring_bsz, ch_bsz_size) temp_src_tokens_full_batch = temp_src_tokens_full[ start_idx:end_idx, :] channel_input_batch = channel_input[ start_idx:end_idx, :] ch_input_lengths_batch = ch_input_lengths[ start_idx:end_idx] 
ch_encoder_output_batch = channel_model.encoder( channel_input_batch, src_lengths=ch_input_lengths_batch) ch_decoder_output_batch, _ = channel_model.decoder( temp_src_tokens_full_batch, encoder_out=ch_encoder_output_batch, features_only=True) ch_lprobs_list[ i] = normalized_scores_with_batch_vocab( channel_model.decoder, ch_decoder_output_batch, src_tokens, k, bsz, beam_size, self.src_dict.pad_index, top_k=self.top_k_vocab, start_idx=start_idx, end_idx=end_idx) ch_lprobs = torch.cat(ch_lprobs_list, dim=0) ch_scores = torch.sum(ch_lprobs, dim=1) else: ch_output = channel_model(channel_input, ch_input_lengths, temp_src_tokens_full) ch_lprobs = channel_model.get_normalized_probs( ch_output, log_probs=True) ch_intermed_scores = torch.gather( ch_lprobs[:, :-1, :], 2, temp_src_tokens_full[:, 1:].unsqueeze( 2)).squeeze().view(bsz * beam_size * k, -1) ch_intermed_scores *= not_padding.float() ch_scores = torch.sum(ch_intermed_scores, dim=1) else: cur_tgt_size = 0 ch_scores = ch_scores.view(bsz * beam_size, k) expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze( 1).expand(-1, k).flatten() if self.share_tgt_dict: lm_scores = get_lm_scores( lm, tokens[:, :step + 1].view(-1, step + 1), lm_incremental_states, fw_top_k_idx.view(-1, 1), torch.tensor(np.full(tokens.size(0), step + 1)), k) else: new_lm_input = dict2dict( tokens[:, :step + 1].view(-1, step + 1), self.tgt_to_lm) new_cands = dict2dict(fw_top_k_idx.view(-1, 1), self.tgt_to_lm) lm_scores = get_lm_scores( lm, new_lm_input, lm_incremental_states, new_cands, torch.tensor(np.full(tokens.size(0), step + 1)), k) lm_scores.add_(expanded_lm_prefix_scores) ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores, lm_scores, src_size, cur_tgt_size) # initialize all as min value new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view( bsz * beam_size, -1) new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_( -1e17).view(bsz * beam_size, -1) new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view( bsz * beam_size, -1) new_fw_lprobs[:, self.pad] = -math.inf new_ch_lm_lprobs[:, self.pad] = -math.inf new_lm_lprobs[:, self.pad] = -math.inf new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k) new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores) new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k)) return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size, tgt_size): if self.channel_scoring_type == "unnormalized": ch_scores = self.log_softmax_fn( ch_scores.view(-1, self.beam_size * self.k2)).view( ch_scores.shape) ch_scores = ch_scores * self.ch_weight lm_scores1 = lm_scores1 * self.lm_weight if combine_type == "lm_only": # log P(T|S) + log P(T) ch_scores = lm_scores1.view(ch_scores.size()) elif combine_type == "noisy_channel": # 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T) if self.normalize_lm_scores_by_tgt_len: ch_scores.div_(src_size) lm_scores_norm = lm_scores1.view( ch_scores.size()).div(tgt_size) ch_scores.add_(lm_scores_norm) # 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T) else: ch_scores.add_(lm_scores1.view(ch_scores.size())) ch_scores.div_(src_size) return ch_scores if self.channel_models is not None: channel_model = self.channel_models[ 0] # assume only one channel_model model else: channel_model = None lm = EnsembleModel(self.lm_models) lm_incremental_states = torch.jit.annotate( List[Dict[str, Dict[str, Optional[Tensor]]]], [ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) for i in range(lm.models_size) ], ) reorder_state = None batch_idxs = None 
for step in range(max_len + 1): # one extra step for EOS marker # reorder decoder internal states based on the prev choice of beams if reorder_state is not None: if batch_idxs is not None: # update beam indices to take into account removed sentences corr = batch_idxs - torch.arange( batch_idxs.numel()).type_as(batch_idxs) reorder_state.view(-1, beam_size).add_( corr.unsqueeze(-1) * beam_size) model.reorder_incremental_state(incremental_states, reorder_state) encoder_outs = model.reorder_encoder_out( encoder_outs, reorder_state) lm.reorder_incremental_state(lm_incremental_states, reorder_state) fw_lprobs, avg_attn_scores = model.forward_decoder( tokens[:, :step + 1], encoder_outs, incremental_states, temperature=self.temperature, ) fw_lprobs[:, self.pad] = -math.inf # never select pad fw_lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring( fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2) # handle min and max length constraints if step >= max_len: fw_lprobs[:, :self.eos] = -math.inf fw_lprobs[:, self.eos + 1:] = -math.inf elif step < self.min_len: fw_lprobs[:, self.eos] = -math.inf # handle prefix tokens (possibly with different lengths) if prefix_tokens is not None and step < prefix_tokens.size(1): prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat( 1, beam_size).view(-1) prefix_mask = prefix_toks.ne(self.pad) prefix_fw_lprobs = fw_lprobs.gather(-1, prefix_toks.unsqueeze(-1)) fw_lprobs[prefix_mask] = -math.inf fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_( -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_fw_lprobs) prefix_ch_lm_lprobs = ch_lm_lprobs.gather( -1, prefix_toks.unsqueeze(-1)) ch_lm_lprobs[prefix_mask] = -math.inf ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_( -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_ch_lm_lprobs) prefix_lm_lprobs = lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1)) lm_lprobs[prefix_mask] = -math.inf lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_( -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lm_lprobs) # if prefix includes eos, then we should make sure tokens and # scores are the same across all beams eos_mask = prefix_toks.eq(self.eos) if eos_mask.any(): # validate that the first beam matches the prefix first_beam = tokens[eos_mask].view( -1, beam_size, tokens.size(-1))[:, 0, 1:step + 1] eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] assert (first_beam == target_prefix).all() def replicate_first_beam(tensor, mask): tensor = tensor.view(-1, beam_size, tensor.size(-1)) tensor[mask] = tensor[mask][:, :1, :] return tensor.view(-1, tensor.size(-1)) # copy tokens, scores and lprobs from the first beam to all beams tokens = replicate_first_beam(tokens, eos_mask_batch_dim) scores = replicate_first_beam(scores, eos_mask_batch_dim) fw_lprobs = replicate_first_beam(fw_lprobs, eos_mask_batch_dim) ch_lm_lprobs = replicate_first_beam( ch_lm_lprobs, eos_mask_batch_dim) lm_lprobs = replicate_first_beam(lm_lprobs, eos_mask_batch_dim) if self.no_repeat_ngram_size > 0: # for each beam and batch sentence, generate a list of previous ngrams gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)] for bbsz_idx in range(bsz * beam_size): gen_tokens = tokens[bbsz_idx].tolist() for ngram in zip(*[ gen_tokens[i:] for i in range(self.no_repeat_ngram_size) ]): gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \ gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]] # Record attention 
scores if avg_attn_scores is not None: if attn is None: attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2) attn_buf = attn.clone() nonpad_idxs = src_tokens.ne(self.pad) attn[:, :, step + 1].copy_(avg_attn_scores) scores = scores.type_as(fw_lprobs) scores_buf = scores_buf.type_as(fw_lprobs) self.search.set_src_lengths(src_lengths_no_eos) if self.no_repeat_ngram_size > 0: def calculate_banned_tokens(bbsz_idx): # before decoding the next token, prevent decoding of ngrams that have already appeared ngram_index = tuple( tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist()) return gen_ngrams[bbsz_idx].get(ngram_index, []) if step + 2 - self.no_repeat_ngram_size >= 0: # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet banned_tokens = [ calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size) ] else: banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)] for bbsz_idx in range(bsz * beam_size): fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step( step, fw_lprobs.view(bsz, -1, self.vocab_size), scores.view(bsz, beam_size, -1)[:, :, :step], ch_lm_lprobs.view(bsz, -1, self.vocab_size), lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method) # cand_bbsz_idx contains beam indices for the top candidate # hypotheses, with a range of values: [0, bsz*beam_size), # and dimensions: [bsz, cand_size] cand_bbsz_idx = cand_beams.add(bbsz_offsets) # finalize hypotheses that end in eos (except for candidates to be ignored) eos_mask = cand_indices.eq(self.eos) eos_mask[:, :beam_size] &= ~cands_to_ignore # only consider eos when it's among the top beam_size indices eos_bbsz_idx = torch.masked_select(cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]) finalized_sents = set() if eos_bbsz_idx.numel() > 0: eos_scores = torch.masked_select( fw_lprobs_top_k[:, :beam_size], mask=eos_mask[:, :beam_size]) combined_noisy_channel_eos_scores = torch.masked_select( combined_noisy_channel_scores[:, :beam_size], mask=eos_mask[:, :beam_size], ) # finalize hypo using channel model score finalized_sents = finalize_hypos( step, eos_bbsz_idx, eos_scores, combined_noisy_channel_eos_scores) num_remaining_sent -= len(finalized_sents) assert num_remaining_sent >= 0 if num_remaining_sent == 0: break if len(finalized_sents) > 0: new_bsz = bsz - len(finalized_sents) # construct batch_idxs which holds indices of batches to keep for the next pass batch_mask = cand_indices.new_ones(bsz) batch_mask[cand_indices.new(finalized_sents)] = 0 batch_idxs = torch.nonzero(batch_mask).squeeze(-1) eos_mask = eos_mask[batch_idxs] cand_beams = cand_beams[batch_idxs] bbsz_offsets.resize_(new_bsz, 1) cand_bbsz_idx = cand_beams.add(bbsz_offsets) lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs] fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs] cand_indices = cand_indices[batch_idxs] if prefix_tokens is not None: prefix_tokens = prefix_tokens[batch_idxs] src_lengths_no_eos = src_lengths_no_eos[batch_idxs] cands_to_ignore = cands_to_ignore[batch_idxs] scores = scores.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, -1) scores_buf.resize_as_(scores) tokens = tokens.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, -1) tokens_buf.resize_as_(tokens) src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, -1) src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, -1) lm_prefix_scores = lm_prefix_scores.view( bsz, 
-1)[batch_idxs].view(new_bsz * beam_size, -1).squeeze() if attn is not None: attn = attn.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, attn.size(1), -1) attn_buf.resize_as_(attn) bsz = new_bsz else: batch_idxs = None # Set active_mask so that values > cand_size indicate eos or # ignored hypos and values < cand_size indicate candidate # active hypos. After this, the min values per row are the top # candidate active hypos. eos_mask[:, :beam_size] |= cands_to_ignore active_mask = torch.add( eos_mask.type_as(cand_offsets) * cand_size, cand_offsets[:eos_mask.size(1)], ) # get the top beam_size active hypotheses, which are just the hypos # with the smallest values in active_mask active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer( 'new_cands_to_ignore') torch.topk(active_mask, k=beam_size, dim=1, largest=False, out=(new_cands_to_ignore, active_hypos)) # update cands_to_ignore to ignore any finalized hypos cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] assert (~cands_to_ignore).any(dim=1).all() active_bbsz_idx = buffer('active_bbsz_idx') torch.gather( cand_bbsz_idx, dim=1, index=active_hypos, out=active_bbsz_idx, ) active_scores = torch.gather( fw_lprobs_top_k, dim=1, index=active_hypos, out=scores[:, step].view(bsz, beam_size), ) active_bbsz_idx = active_bbsz_idx.view(-1) active_scores = active_scores.view(-1) # copy tokens and scores for active hypotheses torch.index_select( tokens[:, :step + 1], dim=0, index=active_bbsz_idx, out=tokens_buf[:, :step + 1], ) torch.gather( cand_indices, dim=1, index=active_hypos, out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1], ) if step > 0: torch.index_select( scores[:, :step], dim=0, index=active_bbsz_idx, out=scores_buf[:, :step], ) torch.gather( fw_lprobs_top_k, dim=1, index=active_hypos, out=scores_buf.view(bsz, beam_size, -1)[:, :, step], ) torch.gather(lm_lprobs_top_k, dim=1, index=active_hypos, out=lm_prefix_scores.view(bsz, beam_size)) # copy attention for active hypotheses if attn is not None: torch.index_select( attn[:, :, :step + 2], dim=0, index=active_bbsz_idx, out=attn_buf[:, :, :step + 2], ) # swap buffers tokens, tokens_buf = tokens_buf, tokens scores, scores_buf = scores_buf, scores if attn is not None: attn, attn_buf = attn_buf, attn # reorder incremental state in decoder reorder_state = active_bbsz_idx # sort by score descending for sent in range(len(finalized)): finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) return finalized
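The `combine_ch_lm` helper above merges the channel-model score log P(S|T) with the language-model score log P(T); the direct-model term log P(T|S) is folded in later by `self.search.step`. A toy numeric sketch of the "noisy_channel" branch with LM length normalization enabled (weights, lengths, and scores are all invented):

import torch

ch_weight, lm_weight = 1.0, 0.5
src_size, tgt_size = 7.0, 5.0                 # source / target lengths
ch_scores = torch.tensor([-12.0, -9.5])       # log P(S|T) for two candidates
lm_scores = torch.tensor([-6.0, -8.0])        # log P(T)   for two candidates

ch_scores = ch_scores * ch_weight
lm_scores = lm_scores * lm_weight

# "noisy_channel" with normalize_lm_scores_by_tgt_len=True:
# score = ch_weight * log P(S|T) / |S| + lm_weight * log P(T) / |T|
combined = ch_scores / src_size + lm_scores / tgt_size
print(combined)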
class FairseqPredictor(Predictor): """Predictor for using fairseq models.""" def __init__(self, model_path, user_dir, lang_pair, n_cpu_threads=-1): """Initializes a fairseq predictor. Args: model_path (string): Path to the fairseq model (*.pt). Like --path in fairseq-interactive. lang_pair (string): Language pair string (e.g. 'en-fr'). user_dir (string): Path to fairseq user directory. n_cpu_threads (int): Number of CPU threads. If negative, use GPU. """ super(FairseqPredictor, self).__init__() _initialize_fairseq(user_dir) self.use_cuda = torch.cuda.is_available() and n_cpu_threads < 0 parser = options.get_generation_parser() input_args = ["--path", model_path, os.path.dirname(model_path)] if lang_pair: src, trg = lang_pair.split("-") input_args.extend(["--source-lang", src, "--target-lang", trg]) args = options.parse_args_and_arch(parser, input_args) # Setup task, e.g., translation task = tasks.setup_task(args) self.src_vocab_size = len(task.source_dictionary) self.trg_vocab_size = len(task.target_dictionary) self.pad_id = task.source_dictionary.pad() # Load ensemble logging.info('Loading fairseq model(s) from {}'.format(model_path)) self.models, _ = checkpoint_utils.load_model_ensemble( model_path.split(':'), task=task, ) # Optimize ensemble for generation for model in self.models: model.make_generation_fast_( beamable_mm_beam_size=1, need_attn=False, ) if self.use_cuda: model.cuda() self.model = EnsembleModel(self.models) self.model.eval() def get_unk_probability(self, posterior): """Fetch posterior[utils.UNK_ID]""" return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF) def predict_next(self): """Call the fairseq model.""" lprobs, _ = self.model.forward_decoder( torch.LongTensor([self.consumed]), self.encoder_outs) lprobs[0, self.pad_id] = utils.NEG_INF return np.array(lprobs[0]) def initialize(self, src_sentence): """Initialize source tensors, reset consumed.""" self.consumed = [] src_tokens = torch.LongTensor([ utils.oov_to_unk(src_sentence + [utils.EOS_ID], self.src_vocab_size) ]) src_lengths = torch.LongTensor([len(src_sentence) + 1]) if self.use_cuda: src_tokens = src_tokens.cuda() src_lengths = src_lengths.cuda() self.encoder_outs = self.model.forward_encoder({ 'src_tokens': src_tokens, 'src_lengths': src_lengths }) self.consumed = [utils.GO_ID or utils.EOS_ID] # Reset incremental states for model in self.models: self.model.incremental_states[model] = {} def consume(self, word): """Append ``word`` to the current history.""" self.consumed.append(word) def get_state(self): """The predictor state is the complete history.""" return self.consumed, [ self.model.incremental_states[m] for m in self.models ] def set_state(self, state): """The predictor state is the complete history.""" self.consumed, inc_states = state for model, inc_state in zip(self.models, inc_states): self.model.incremental_states[model] = inc_state def is_equal(self, state1, state2): """Returns true if the history is the same """ return state1[0] == state2[0]
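A hedged sketch of the predictor's decode loop as SGNMT would drive it, using only the methods defined above. `MODEL_PATH`, `USER_DIR`, and `src_ids` are placeholders, and the greedy loop is purely illustrative:

# assumptions: MODEL_PATH / USER_DIR point to a real checkpoint and user dir,
# and src_ids is a list of source token ids (all placeholders)
predictor = FairseqPredictor(MODEL_PATH, USER_DIR, "en-fr")
predictor.initialize(src_ids)

hypo = []
for _ in range(100):                      # crude length limit
    posterior = predictor.predict_next()  # vocabulary-sized array of log-probs
    next_id = int(np.argmax(posterior))
    if next_id == utils.EOS_ID:
        break
    predictor.consume(next_id)
    hypo.append(next_id)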
class FairseqPredictor(Predictor): """Predictor for using fairseq models.""" name = 'fairseq' def __init__(self, args): super(FairseqPredictor, self).__init__() _initialize_fairseq(args.fairseq_user_dir) self.use_cuda = torch.cuda.is_available() and args.n_cpu_threads < 0 fairseq_args = get_fairseq_args(args.fairseq_path, args.fairseq_lang_pair) # Setup task, e.g., translation task = tasks.setup_task(fairseq_args) source_dict = task.source_dictionary target_dict = task.target_dictionary self.src_vocab_size = len(source_dict) + 1 self.trg_vocab_size = len(target_dict) + 1 self.pad_id = target_dict.pad() # Load ensemble self.models = self.load_models(args.fairseq_path, task) self.model = EnsembleModel(self.models) self.model.eval() self.incremental_states = [{}]*len(self.models) def load_models(self, model_path, task): logging.info('Loading fairseq model(s) from {}'.format(model_path)) models, _ = checkpoint_utils.load_model_ensemble( model_path.split(':'), task=task, ) # Optimize ensemble for generation for model in models: model.make_generation_fast_( beamable_mm_beam_size=1, need_attn=False, ) if self.use_cuda: model.cuda() return models def get_unk_probability(self, posterior): """Fetch posterior[utils.UNK_ID]""" return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF) @torch.no_grad() def predict_next(self): """Call the fairseq model.""" inputs = torch.LongTensor([self.consumed]) if self.use_cuda: inputs = inputs.cuda() lprobs, _ = self.model.forward_decoder( inputs, self.encoder_outs, self.incremental_states) lprobs[:, self.pad_id] = utils.NEG_INF return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0], dtype=np.float64) @torch.no_grad() def initialize(self, src_sentence): """Initialize source tensors, reset consumed.""" src_tokens = torch.LongTensor([ utils.oov_to_unk(src_sentence + [utils.EOS_ID], self.src_vocab_size)]) src_lengths = torch.LongTensor([len(src_sentence) + 1]) if self.use_cuda: src_tokens = src_tokens.cuda() src_lengths = src_lengths.cuda() self.encoder_outs = self.model.forward_encoder({ 'src_tokens': src_tokens, 'src_lengths': src_lengths}) self.consumed = [utils.GO_ID or utils.EOS_ID] self.reset_states() def reset_states(self, states=None): # Reset incremental states for i in range(len(self.models)): self.incremental_states[i] = {} def consume(self, word, i=None): """Append ``word`` to the current history.""" self.consumed.append(word) if i is None else self.consumed[i].append(word) def get_empty_str_prob(self): return self.get_initial_dist()[utils.EOS_ID].item() @torch.no_grad() def get_initial_dist(self): inputs = torch.LongTensor([[utils.GO_ID or utils.EOS_ID]]) if self.use_cuda: inputs = inputs.cuda() lprobs, _ = self.model.forward_decoder( inputs, self.encoder_outs, [{}]*len(self.models) ) return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0], dtype=np.float64) def get_state(self): """The predictor state is the complete history.""" return self.consumed, self.incremental_states def set_state(self, state): """The predictor state is the complete history.""" self.consumed, self.incremental_states = state def is_equal(self, state1, state2): """Returns true if the history is the same """ return state1[0] == state2[0] @staticmethod def add_args(parser): parser.add_argument("--fairseq_path", default="", help="Points to the model file (*.pt) for the fairseq " "predictor. 
                                 "Like --path in fairseq-interactive.")
        parser.add_argument("--fairseq_user_dir", default="",
                            help="fairseq user directory for additional models.")
        parser.add_argument("--fairseq_lang_pair", default="",
                            help="Language pair such as 'en-fr' for fairseq. Used "
                                 "to load fairseq dictionaries")
class FairseqPredictor(Predictor): """Predictor for using fairseq models.""" name = 'fairseq' def __init__(self, args): super(FairseqPredictor, self).__init__() _initialize_fairseq(args.fairseq_user_dir) self.use_cuda = torch.cuda.is_available() and args.n_cpu_threads < 0 fairseq_args = get_fairseq_args(args.fairseq_path, args.fairseq_lang_pair) # Setup task, e.g., translation task = tasks.setup_task(fairseq_args) source_dict = task.source_dictionary target_dict = task.target_dictionary self.src_vocab_size = len(source_dict) + 1 self.trg_vocab_size = len(target_dict) + 1 self.pad_id = target_dict.pad() # Load ensemble self.models = self.load_models(args.fairseq_path, task) self.model = EnsembleModel(self.models) self.model.eval() def load_models(self, model_path, task): logging.info('Loading fairseq model(s) from {}'.format(model_path)) models, _ = checkpoint_utils.load_model_ensemble( model_path.split(':'), task=task, ) # Optimize ensemble for generation for model in models: model.make_generation_fast_( beamable_mm_beam_size=1, need_attn=False, ) if self.use_cuda: model.cuda() return models def get_unk_probability(self, posterior): """Fetch posterior[utils.UNK_ID]""" return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF) def predict_next(self): """Call the fairseq model.""" inputs = torch.LongTensor([self.consumed]) if self.use_cuda: inputs = inputs.cuda() lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs) lprobs[:, self.pad_id] = utils.NEG_INF return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0], dtype=np.float64) def initialize(self, src_sentence): """Initialize source tensors, reset consumed.""" src_tokens = torch.LongTensor([ utils.oov_to_unk(src_sentence + [utils.EOS_ID], self.src_vocab_size) ]) src_lengths = torch.LongTensor([len(src_sentence) + 1]) if self.use_cuda: src_tokens = src_tokens.cuda() src_lengths = src_lengths.cuda() self.encoder_outs = self.model.forward_encoder({ 'src_tokens': src_tokens, 'src_lengths': src_lengths }) self.consumed = [utils.GO_ID or utils.EOS_ID] self.reset_states() def reset_states(self, states=None): # Reset incremental states if states is not None: assert len(states) == len(self.models) for i, model in enumerate(self.models): self.model.incremental_states[ model] = {} if states is None else states[i] def consume(self, word, i=None): """Append ``word`` to the current history.""" self.consumed.append(word) if i is None else self.consumed[i].append( word) def get_empty_str_prob(self): return self.get_initial_dist()[utils.EOS_ID].item() def get_initial_dist(self): old_states = [self.model.incremental_states[m] for m in self.models] self.reset_states() inputs = torch.LongTensor([[utils.GO_ID or utils.EOS_ID]]) if self.use_cuda: inputs = inputs.cuda() lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs) self.reset_states(old_states) return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0], dtype=np.float64) def get_state(self): """The predictor state is the complete history.""" return self.consumed, [ self.model.incremental_states[m] for m in self.models ] def set_state(self, state): """The predictor state is the complete history.""" self.consumed, inc_states = state for model, inc_state in zip(self.models, inc_states): self.model.incremental_states[model] = inc_state def is_equal(self, state1, state2): """Returns true if the history is the same """ return state1[0] == state2[0]
class FairseqPredictor(Predictor): """Predictor for using fairseq models.""" def __init__(self, model_path, user_dir, lang_pair, n_cpu_threads=-1, subtract_uni=False, subtract_marg=False, marg_path=None, lmbda=1.0, ppmi=False, epsilon=0): """Initializes a fairseq predictor. Args: model_path (string): Path to the fairseq model (*.pt). Like --path in fairseq-interactive. lang_pair (string): Language pair string (e.g. 'en-fr'). user_dir (string): Path to fairseq user directory. n_cpu_threads (int): Number of CPU threads. If negative, use GPU. """ super(FairseqPredictor, self).__init__() _initialize_fairseq(user_dir) self.use_cuda = torch.cuda.is_available() and n_cpu_threads < 0 args = get_fairseq_args(model_path, lang_pair) # Setup task, e.g., translation task = tasks.setup_task(args) source_dict = task.source_dictionary target_dict = task.target_dictionary self.src_vocab_size = len(source_dict) + 1 self.trg_vocab_size = len(target_dict) + 1 self.pad_id = target_dict.pad() self.eos_id = target_dict.eos() self.bos_id = target_dict.bos() # Load ensemble self.models = self.load_models(model_path, task) self.model = EnsembleModel(self.models) self.model.eval() assert not subtract_marg & subtract_uni self.use_uni_dist = subtract_uni self.use_marg_dist = subtract_marg assert not ppmi or subtract_marg or subtract_uni self.lmbda = lmbda if self.use_uni_dist: unigram_dist = torch.Tensor(target_dict.count) #change frequency of eos to frequency of '.' so it's more realistic. unigram_dist[self.eos_id] = unigram_dist[target_dict.index('.')] self.log_uni_dist = unigram_dist.cuda( ) if self.use_cuda else unigram_dist self.log_uni_dist = (self.log_uni_dist / self.log_uni_dist.sum()).log() if self.use_marg_dist: if not marg_path: raise AttributeError( "No path (--marg_path) given for marginal model when --subtract_marg used" ) args = get_fairseq_args(marg_path, lang_pair) self.ppmi = ppmi self.eps = epsilon # Setup task, e.g., translation task = tasks.setup_task(args) assert source_dict == task.source_dictionary assert target_dict == task.target_dictionary # Load ensemble self.marg_models = self.load_models(marg_path, task) self.marg_model = EnsembleModel(self.marg_models) self.marg_model.eval() def load_models(self, model_path, task): logging.info('Loading fairseq model(s) from {}'.format(model_path)) models, _ = checkpoint_utils.load_model_ensemble( model_path.split(':'), task=task, ) # Optimize ensemble for generation for model in models: model.make_generation_fast_( beamable_mm_beam_size=1, need_attn=False, ) if self.use_cuda: model.cuda() return models def get_unk_probability(self, posterior): """Fetch posterior[utils.UNK_ID]""" return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF) def predict_next(self): """Call the fairseq model.""" inputs = torch.LongTensor([self.consumed]) if self.use_cuda: inputs = inputs.cuda() lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs) lprobs[0, self.pad_id] = utils.NEG_INF if self.use_uni_dist: lprobs[0] = lprobs[0] - self.lmbda * self.log_uni_dist if self.use_marg_dist: marg_lprobs, _ = self.marg_model.forward_decoder( inputs, self.marg_encoder_outs) if self.ppmi: marg_lprobs[0] = torch.clamp(marg_lprobs[0], -self.eps) lprobs[0] = lprobs[0] - self.lmbda * marg_lprobs[0] return lprobs[0] if self.use_cuda else np.array(lprobs[0]) def initialize(self, src_sentence): """Initialize source tensors, reset consumed.""" self.consumed = [] src_tokens = torch.LongTensor([ utils.oov_to_unk(src_sentence + [utils.EOS_ID], self.src_vocab_size) ]) src_lengths = 
        src_lengths = torch.LongTensor([len(src_sentence) + 1])
        if self.use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()
        self.encoder_outs = self.model.forward_encoder({
            'src_tokens': src_tokens,
            'src_lengths': src_lengths})
        self.consumed = [utils.GO_ID or utils.EOS_ID]
        # Reset incremental states
        for model in self.models:
            self.model.incremental_states[model] = {}
        if self.use_marg_dist:
            self.initialize_marg()

    def initialize_marg(self):
        """Initialize the marginal model with an empty (EOS-only) source."""
        src_tokens = torch.LongTensor(
            [utils.oov_to_unk([utils.EOS_ID], self.src_vocab_size)])
        src_lengths = torch.LongTensor([1])
        if self.use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()
        self.marg_encoder_outs = self.marg_model.forward_encoder({
            'src_tokens': src_tokens,
            'src_lengths': src_lengths})
        # Reset incremental states
        for model in self.marg_models:
            self.marg_model.incremental_states[model] = {}

    def consume(self, word):
        """Append ``word`` to the current history."""
        self.consumed.append(word)

    def get_empty_str_prob(self):
        inputs = torch.LongTensor([[utils.GO_ID or utils.EOS_ID]])
        if self.use_cuda:
            inputs = inputs.cuda()
        lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs)
        if self.use_uni_dist:
            lprobs[0] = lprobs[0] - self.lmbda * self.log_uni_dist
        if self.use_marg_dist:
            lprobs_marg, _ = self.marg_model.forward_decoder(
                inputs, self.marg_encoder_outs)
            eos_prob = (lprobs[0, self.eos_id]
                        - self.lmbda * lprobs_marg[0, self.eos_id]).item()
            if self.ppmi:
                return min(eos_prob, 0)
            return eos_prob
        return lprobs[0, self.eos_id].item()

    def get_state(self):
        """The predictor state is the complete history."""
        return self.consumed, [
            self.model.incremental_states[m] for m in self.models]

    def set_state(self, state):
        """The predictor state is the complete history."""
        consumed, inc_states = state
        self.consumed = copy.copy(consumed)
        for model, inc_state in zip(self.models, inc_states):
            self.model.incremental_states[model] = inc_state

    def is_equal(self, state1, state2):
        """Returns true if the history is the same."""
        return state1[0] == state2[0]
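# The sketch below is not part of the original code: it is a minimal,
# hypothetical greedy-decoding loop showing how a FairseqPredictor instance
# is typically driven (initialize once per source sentence, then alternate
# predict_next / consume). The helper name and max_len default are
# assumptions made for illustration only.
def _greedy_decode_sketch(predictor, src_sentence, max_len=100):
    """Hypothetical helper: greedily decode one sentence with the predictor."""
    predictor.initialize(src_sentence)        # encode source, reset history
    hypothesis = []
    for _ in range(max_len):
        posterior = predictor.predict_next()  # scores over the target vocab
        next_id = int(posterior.argmax())     # greedy choice
        if next_id == predictor.eos_id:
            break
        predictor.consume(next_id)            # extend the decoder history
        hypothesis.append(next_id)
    return hypothesis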
def _paraphrase_sample(self, model, sample, sample_topN):
    """
    model: MT model being trained
    sample: fairseq data structure for training batch
    sample_topN: number of top candidates to sample from in the paraphraser
        output softmax
    """
    # Disable dropout in the model being trained
    model.eval()

    # Disable paraphraser dropout: train() on the paraphraser model gets
    # called automatically when train() is called on the criterion, so we
    # need to set it back to eval mode.
    self.paraphraser_model.eval()  # this should disable dropout
    self.paraphraser_model.training = False  # not sure if this does anything

    pad = self.task.target_dictionary.pad()
    eos = self.task.target_dictionary.eos()
    bos = self.task.target_dictionary.bos()
    assert pad == self.task.source_dictionary.pad()
    assert eos == self.task.source_dictionary.eos()
    assert bos == self.task.source_dictionary.bos()

    # We don't know how long the paraphrase will be, so we take the target
    # length and increase it a bit.
    target_length = sample['target'].shape[1]
    max_paraphrase_length = int(2 * target_length) + 3

    batch_size = sample['net_input']['prev_output_tokens'].shape[0]

    combined_tokens = sample['net_input']['prev_output_tokens'][:, :1]
    combined_tokens[:, :] = eos  # eos to match 'bug' in fairseq ("should" be bos)

    # Make the target look like a source, to feed it into the paraphraser
    # encoder.
    paraphraser_src_lengths = torch.ones(batch_size, dtype=torch.int)
    paraphraser_source = sample['target'].new_zeros(
        tuple(sample['target'].shape)) + pad
    for i in range(batch_size):
        n_pad = (sample['target'][i] == pad).sum()
        paraphraser_src_lengths[i] = target_length - n_pad
        paraphraser_source[i, n_pad:target_length] = sample['target'][
            i, :target_length - n_pad]

    paraphraser_prediction_tokens_list = []
    paraphraser_probs_list = []

    paraphraser = EnsembleModel([self.paraphraser_model, ])
    paraphraser_encoder_out = paraphraser.forward_encoder(
        dict(src_tokens=paraphraser_source,
             src_lengths=paraphraser_src_lengths))

    if self.paraphraser_lang_prefix:
        # Take one step to update the state of the paraphraser, so that the
        # "first" time step in the loop below will pass in the language
        # prefix.
        paraphraser_probs, _ = paraphraser.forward_decoder(
            tokens=combined_tokens,
            encoder_outs=paraphraser_encoder_out,
            temperature=self.paraphraser_temperature,
            use_log_probs=False)

        prefixed_combined_tokens = sample['net_input'][
            'prev_output_tokens'][:, :2]
        prefixed_combined_tokens[:, 0] = eos  # eos to match bug in fairseq ("should" be bos)
        prefixed_combined_tokens[:, 1] = self.task.target_dictionary.index(
            self.paraphraser_lang_prefix)
    else:
        prefixed_combined_tokens = None

    done = [False, ] * batch_size
    for ii in range(max_paraphrase_length + 1):
        # The paraphraser input may or may not have the language tag
        # prepended (after the go symbol).
        if prefixed_combined_tokens is None:
            paraphraser_combined_tokens = combined_tokens
        else:
            paraphraser_combined_tokens = prefixed_combined_tokens

        # This is used to compute the loss
        paraphraser_probs, _ = paraphraser.forward_decoder(
            tokens=paraphraser_combined_tokens,
            encoder_outs=paraphraser_encoder_out,
            temperature=self.paraphraser_temperature,
            use_log_probs=False)

        # This is used to generate the previous context word
        paraphraser_probs_context = paraphraser_probs

        # Save the paraphraser predictions to train toward (if we don't
        # have a distribution loss)
        _, paraphraser_predictions = torch.max(paraphraser_probs, 1)

        if self.distribution_loss:
            paraphraser_probs_list.append(paraphraser_probs.unsqueeze(1))

        # Paraphraser predictions are simply the most likely next word,
        # according to the paraphraser.
        paraphraser_prediction_tokens_list.append(
            paraphraser_predictions.reshape((-1, 1)))

        combined_probs = paraphraser_probs_context

        # Disallow length=0 paraphrases
        if ii == 0:
            combined_probs[:, eos] = 0.0
        # Disallow other undefined behavior
        combined_probs[:, pad] = 0.0
        combined_probs[:, bos] = 0.0

        if ii == max_paraphrase_length or all(done):
            break

        # Sample from the top N of the paraphraser distribution
        if sample_topN == 1:
            _, combined_predictions = torch.max(combined_probs, 1)
            combined_predictions = combined_predictions.reshape((-1, 1))
        else:
            topk_val, topk_ind = torch.topk(combined_probs, sample_topN)
            # Re-normalize the top values
            topk_val2 = topk_val / topk_val.sum(dim=1).reshape((-1, 1))
            # Make a distribution from the normalized top-k values
            # (Categorical would also accept un-normalized weights)
            mm = dis.Categorical(topk_val2)
            # Sample indexes into the top-k
            topk_idx_idx = mm.sample().reshape((-1, 1))
            # Convert top-k indexes back into vocab indexes
            combined_predictions = torch.cat(
                [v[i] for i, v in zip(topk_idx_idx, topk_ind)]).reshape(
                    (-1, 1))

        for jj in range(batch_size):
            if combined_predictions[jj, 0] == eos:
                done[jj] = True

        # Append the output tokens to the input for the next time step
        combined_tokens = torch.cat(
            (combined_tokens, combined_predictions), 1)
        if prefixed_combined_tokens is not None:
            prefixed_combined_tokens = torch.cat(
                (prefixed_combined_tokens, combined_predictions), 1)

    paraphraser_prediction_tokens = torch.cat(
        paraphraser_prediction_tokens_list, 1)
    if self.distribution_loss:
        paraphraser_probs_tokens = torch.cat(paraphraser_probs_list, 1)
    else:
        paraphraser_probs_tokens = None

    model.train()  # re-enable dropout

    # Compute the length of valid output for each sentence
    n_tokens = 0
    for i in range(batch_size):
        for j in range(paraphraser_prediction_tokens.shape[1]):
            if paraphraser_prediction_tokens[i, j] == eos:
                n_tokens += j  # TODO should this include EOS? HK
                # Set anything after EOS to PAD
                paraphraser_prediction_tokens[
                    i, j + 1:paraphraser_prediction_tokens.shape[1]] = pad
                break

    return (combined_tokens, paraphraser_prediction_tokens, n_tokens,
            paraphraser_probs_tokens)
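# Minimal, self-contained sketch (not part of the original code) of the
# top-N sampling step used above: keep the N most probable tokens per row,
# re-normalize them, and draw one token id per row with
# torch.distributions.Categorical. The gather call is just the vectorized
# equivalent of the list comprehension above; names here are illustrative.
import torch
import torch.distributions as dis


def _sample_topN_sketch(probs, sample_topN):
    """probs: (batch, vocab) non-negative scores; returns (batch, 1) ids."""
    topk_val, topk_ind = torch.topk(probs, sample_topN)       # (batch, N)
    topk_val = topk_val / topk_val.sum(dim=1, keepdim=True)   # re-normalize
    choice = dis.Categorical(topk_val).sample()               # (batch,)
    return topk_ind.gather(1, choice.unsqueeze(1))            # vocab ids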