class Model:
    """Wrapper around the stack-transformer model"""
    def __init__(self, models, target_dictionary):
        self.temperature = 1.
        self.target_dictionary = target_dictionary
        self.models = models
        self.reset()

    def reset(self):
        # Re-create the ensemble to clear the cached decoder key/value states;
        # there may be more efficient ways to do this.
        self.model = EnsembleModel(self.models)
        # reset cache for encoder
        self.encoder_outs = None
        self.model.eval()

    def precompute_encoder(self, sample):
        """Encoder of the encoder-decoder is fixed and can be precomputed"""
        encoder_input = extract_encoder(sample)
        encoder_outs = self.model.forward_encoder(encoder_input)
        return encoder_outs

    def get_action(self, sample, parser_state, prev_actions):

        # Compute part of the model that does not depend on episode steps
        # (encoder). Cache it for future use
        # precompute encoder for speed
        if self.encoder_outs is None:
            self.encoder_outs = self.precompute_encoder(sample)

        # call model with pre-computed encoder, previous generated actions
        # (tokens) and state machine status
        lprobs, avg_attn_scores = self.model.forward_decoder(
            prev_actions,
            self.encoder_outs,
            parser_state,
            temperature=self.temperature)

        # Pick the next action greedily; set sample_actions=True to sample
        # from the predicted distribution instead
        sample_actions = False
        if not sample_actions:
            best_action_indices = lprobs.argmax(dim=1).tolist()
        else:
            # sampling
            best_action_indices = torch.squeeze(lprobs.exp().multinomial(1),
                                                1).tolist()
        actions = [self.target_dictionary[i] for i in best_action_indices]
        # log-probability of the chosen action for each batch element
        actions_lprob = [
            lprobs[j, i] for j, i in enumerate(best_action_indices)
        ]
        return actions, actions_lprob
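
# A minimal usage sketch for the wrapper above, kept as comments because it
# depends on objects built elsewhere: `models` and `target_dictionary` would
# come from a loaded fairseq checkpoint, while `sample`, `parser_state` and
# the `parser_state.is_closed` / `parser_state.step(...)` calls are
# hypothetical stand-ins for the caller's state machine.
#
#   model_wrapper = Model(models, target_dictionary)
#   prev_actions = []
#   while not parser_state.is_closed:
#       actions, actions_lprob = model_wrapper.get_action(
#           sample, parser_state, prev_actions)
#       parser_state.step(actions[0])
#       prev_actions.append(actions[0])
#   model_wrapper.reset()  # clear the cached encoder output before the next sample
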
Example #2
    def generate(self,
                 models,
                 sample,
                 prefix_tokens=None,
                 bos_token=None,
                 **kwargs):
        """Generate a batch of translations.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
        """
        model = EnsembleModel(models)
        if not self.retain_dropout:
            model.eval()

        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v
            for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens'
        }

        src_tokens = encoder_input['src_tokens']
        src_lengths = (src_tokens.ne(self.eos)
                       & src_tokens.ne(self.pad)).long().sum(dim=1)
        input_size = src_tokens.size()
        # batch dimension goes first followed by source lengths
        bsz = input_size[0]
        src_len = input_size[1]
        beam_size = self.beam_size

        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                model.max_decoder_positions() - 1,
            )
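            # e.g. with max_len_a=0 and max_len_b=200 (fairseq's usual
            # defaults) generation is capped at 200 tokens, or at the model's
            # maximum decoder positions minus one, whichever is smaller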

        # compute the encoder output for each beam
        encoder_outs = model.forward_encoder(encoder_input)
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)

        # initialize buffers
        scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size,
                                     max_len + 2).long().fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = bos_token or self.eos
        attn, attn_buf = None, None
        nonpad_idxs = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{
            'idx': None,
            'score': -math.inf
        } for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) *
                        beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)
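        # Hypotheses live in a flat "bbsz" dimension of size bsz * beam_size:
        # hypothesis b of sentence s sits at flat index s * beam_size + b,
        # while candidates are indexed as [bsz, cand_size]. bbsz_offsets
        # (shape [bsz, 1]) converts per-sentence beam indices into flat bbsz
        # indices, e.g. with bsz=2 and beam_size=3 it is [[0], [3]]. Taking
        # 2 * beam_size candidates works because each hypothesis can propose
        # EOS at most once, so at least beam_size candidates remain active.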

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == max_len or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= max_len**self.len_penalty
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step,
                           bbsz_idx,
                           eos_scores,
                           unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            # skip the first index, which is EOS
            tokens_clone = tokens_clone[:, 1:step + 2]
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(
                0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1)**self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)
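            # cum_unfin[i] is the number of already-finished sentences that
            # precede the i-th still-unfinished sentence, so below
            # sent = unfin_idx + cum_unfin[unfin_idx] recovers the original
            # sentence index; e.g. finished = [True, False, False] gives
            # cum_unfin = [1, 1], and unfin_idx 0 maps back to sentence 1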

            sents_seen = set()
            for i, (idx, score) in enumerate(
                    zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                if self.match_source_len and step > src_lengths[unfin_idx]:
                    score = -math.inf

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                        _, alignment = hypo_attn.max(dim=0)
                    else:
                        hypo_attn = None
                        alignment = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent][
                        'score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]),
                                 key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step,
                                                      unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(
                        batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size)
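                    # corr[i] = (original batch index of the i-th kept
                    # sentence) - i, so adding corr * beam_size maps the flat
                    # indices of the reduced batch back to the layout of the
                    # cached decoder/encoder states, which have not been
                    # reordered since the removal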
                model.reorder_incremental_state(reorder_state)
                encoder_outs = model.reorder_encoder_out(
                    encoder_outs, reorder_state)

            lprobs, avg_attn_scores = model.forward_decoder(
                tokens[:, :step + 1], encoder_outs)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[
                            gen_tokens[i:]
                            for i in range(self.no_repeat_ngram_size)
                    ]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1),
                                      max_len + 2)
                    attn_buf = attn.clone()
                    nonpad_idxs = src_tokens.ne(self.pad)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < max_len:
                self.search.set_src_lengths(src_lengths)

                if self.no_repeat_ngram_size > 0:

                    def calculate_banned_tokens(bbsz_idx):
                        # before decoding the next token, prevent decoding of ngrams that have already appeared
                        ngram_index = tuple(
                            tokens[bbsz_idx,
                                   step + 2 - self.no_repeat_ngram_size:step +
                                   1].tolist())
                        return gen_ngrams[bbsz_idx].get(ngram_index, [])

                    if step + 2 - self.no_repeat_ngram_size >= 0:
                        # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                        banned_tokens = [
                            calculate_banned_tokens(bbsz_idx)
                            for bbsz_idx in range(bsz * beam_size)
                        ]
                    else:
                        banned_tokens = [[]
                                         for bbsz_idx in range(bsz * beam_size)
                                         ]

                    for bbsz_idx in range(bsz * beam_size):
                        lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf

                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice,
                        dim=1,
                        index=prefix_tokens[:, step].view(-1, 1)).view(
                            -1, 1).repeat(1, cand_size)
                    if step > 0:
                        # save cumulative scores for each hypothesis
                        cand_scores.add_(scores[:, step - 1].view(
                            bsz, beam_size).repeat(1, 2))
                    cand_indices = prefix_tokens[:, step].view(-1, 1).repeat(
                        1, cand_size)
                    cand_beams = torch.zeros_like(cand_indices)

                    # handle prefixes of different lengths
                    partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
                    if partial_prefix_mask.any():
                        partial_scores, partial_indices, partial_beams = self.search.step(
                            step,
                            lprobs.view(bsz, -1, self.vocab_size),
                            scores.view(bsz, beam_size, -1)[:, :, :step],
                        )
                        cand_scores[partial_prefix_mask] = partial_scores[
                            partial_prefix_mask]
                        cand_indices[partial_prefix_mask] = partial_indices[
                            partial_prefix_mask]
                        cand_beams[partial_prefix_mask] = partial_beams[
                            partial_prefix_mask]
                else:
                    cand_scores, cand_indices, cand_beams = self.search.step(
                        step,
                        lprobs.view(bsz, -1, self.vocab_size),
                        scores.view(bsz, beam_size, -1)[:, :, :step],
                    )
            else:
                # make probs contain cumulative scores for each hypothesis
                lprobs.add_(scores[:, step - 1].unsqueeze(-1))

                # finalize all active hypotheses once we hit max_len
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    lprobs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(
                    finalize_hypos(step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents = set()
            if step >= self.min_len:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    finalized_sents = finalize_hypos(step, eos_bbsz_idx,
                                                     eos_scores, cand_scores)
                    num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < max_len

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )
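            # Worked example with beam_size=2 (cand_size=4): a row with
            # eos_mask = [1, 0, 0, 1] gives active_mask = [4, 1, 2, 7], and
            # the smallest-k topk below selects candidates 1 and 2, i.e. the
            # highest-ranked candidates that did not end in EOS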

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(active_mask,
                       k=beam_size,
                       dim=1,
                       largest=False,
                       out=(_ignore, active_hypos))

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx,
                dim=1,
                index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2],
                    dim=0,
                    index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent],
                                     key=lambda r: r['score'],
                                     reverse=True)
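        # finalized[sent] is now a list of at most beam_size hypothesis dicts
        # (keys: 'tokens', 'score', 'attention', 'alignment',
        # 'positional_scores') in descending score order, so
        # finalized[sent][0]['tokens'] is the best hypothesis for that sentence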

        return finalized
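
# Standalone illustration (plain Python, independent of fairseq) of the
# no_repeat_ngram_size bookkeeping used in the generator above: every
# generated n-gram is recorded as prefix -> continuations, and the
# continuations of the current (n-1)-token context are banned at the
# next step.
no_repeat_ngram_size = 2
gen_tokens = [5, 7, 5, 9, 5]
gen_ngrams = {}
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
    gen_ngrams[tuple(ngram[:-1])] = \
        gen_ngrams.get(tuple(ngram[:-1]), []) + [ngram[-1]]
# gen_ngrams == {(5,): [7, 9], (7,): [5], (9,): [5]}
banned = gen_ngrams.get(tuple(gen_tokens[-1:]), [])
# banned == [7, 9]: after another 5 the bigrams (5, 7) and (5, 9) would repeat,
# so tokens 7 and 9 get their log-probability set to -inf
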
Example #3
class SimpleGreedyDecoder(nn.Module):
    def __init__(
        self,
        models,
        dictionary,
        max_len_a=0,
        max_len_b=200,
        max_len=0,
        temperature=1.0,
        eos=None,
        symbols_to_strip_from_output=None,
        for_validation=True,
        **kwargs,
    ):
        """Decode given speech audios with the simple greedy search.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting
            dictionary (~fairseq.data.Dictionary): dictionary
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            max_len (int, optional): the maximum length of the generated output
                (not including end-of-sentence)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            for_validation (bool, optional): indicate whether the decoder is
                used for validation. It affects how max_len is determined, and
                whether a tensor of lprobs is returned. If true, target should be
                not None
        """
        super().__init__()
        from fairseq.sequence_generator import EnsembleModel

        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.pad = dictionary.pad()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos() if eos is None else eos
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None else {self.eos})
        self.vocab_size = len(dictionary)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.max_len = max_len or self.model.max_decoder_positions()
        self.temperature = temperature
        assert temperature > 0, "--temperature must be greater than 0"

        self.model.eval()
        self.for_validation = for_validation

    def cuda(self):
        self.model.cuda()
        return self

    @torch.no_grad()
    def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        """Generate a batch of translations. Match the api of other fairseq generators.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        return self._decode(sample, **kwargs)

    @torch.no_grad()
    def _decode(self,
                sample: Dict[str, Dict[str, Tensor]],
                bos_token: Optional[int] = None):
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(self.model.models_size)
            ],
        )
        net_input = sample["net_input"]
        src_tokens = net_input["src_tokens"]
        bsz, src_len = src_tokens.size()[:2]

        # compute the encoder output
        encoder_outs = self.model.forward_encoder(net_input)
        target = sample["target"]
        # target can only be None if not for validation
        assert target is not None or not self.for_validation
        max_encoder_output_length = encoder_outs[0]["encoder_out"][0].size(0)
        # for validation, make the maximum decoding length equal to at least the
        # length of target, and the length of encoder_out if possible; otherwise
        # max_len is obtained from max_len_a/b
        max_len = (max(max_encoder_output_length, target.size(1))
                   if self.for_validation else min(
                       int(self.max_len_a * src_len + self.max_len_b),
                       self.max_len - 1,
                   ))

        tokens = src_tokens.new(bsz, max_len + 2).long().fill_(self.pad)
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        # lprobs is only used when target is not None (i.e., for validation)
        lprobs = (encoder_outs[0]["encoder_out"][0].new_full(
            (bsz, target.size(1), self.vocab_size),
            -np.log(self.vocab_size),
        ) if self.for_validation else None)
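        # -np.log(self.vocab_size) is the log-probability of a uniform
        # distribution over the vocabulary, so target positions that are never
        # reached during decoding keep a uniform prediction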
        attn = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            is_eos = tokens[:, step].eq(self.eos)
            if step > 0 and is_eos.sum() == is_eos.size(0):
                # all predictions are finished (i.e., ended with eos)
                tokens = tokens[:, :step + 1]
                if attn is not None:
                    attn = attn[:, :, :step + 1]
                break
            log_probs, avg_attn_scores = self.model.forward_decoder(
                tokens[:, :step + 1],
                encoder_outs,
                incremental_states,
                temperature=self.temperature,
            )
            tokens[:, step + 1] = log_probs.argmax(-1)
            if step > 0:  # deal with finished predictions
                # make log_probs uniform if the previous output token is EOS
                # and add consecutive EOS to the end of prediction
                log_probs[is_eos, :] = -np.log(log_probs.size(1))
                tokens[is_eos, step + 1] = self.eos
            if self.for_validation and step < target.size(1):
                lprobs[:, step, :] = log_probs

            # Record attention scores
            if type(avg_attn_scores) is list:
                avg_attn_scores = avg_attn_scores[0]
            if avg_attn_scores is not None:
                if attn is None:
                    attn = avg_attn_scores.new(bsz, max_encoder_output_length,
                                               max_len + 2)
                attn[:, :, step + 1].copy_(avg_attn_scores)

        return tokens[:, 1:], lprobs, attn
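
    # A hedged usage sketch for SimpleGreedyDecoder (everything other than the
    # class and its methods -- e.g. `task`, `models`, `sample` -- is a
    # hypothetical stand-in for the caller's fairseq setup):
    #
    #   decoder = SimpleGreedyDecoder(models, task.target_dictionary,
    #                                 for_validation=False)
    #   tokens, lprobs, attn = decoder.decode(models, sample)
    #   best_hypo_str = task.target_dictionary.string(tokens[0])
    #
    # (with for_validation=False, lprobs is returned as None)
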
    def generate(self,
                 models,
                 sample,
                 prefix_tokens=None,
                 bos_token=None,
                 **kwargs):
        """Generate a batch of translations.
        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
        """
        model = EnsembleModel(models)
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(model.models_size)
            ],
        )
        if not self.retain_dropout:
            model.eval()

        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v
            for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens'
        }
        src_tokens = encoder_input['src_tokens']
        src_lengths_no_eos = (src_tokens.ne(self.eos)
                              & src_tokens.ne(self.pad)).long().sum(dim=1)
        input_size = src_tokens.size()
        # batch dimension goes first followed by source lengths
        bsz = input_size[0]
        src_len = input_size[1]
        beam_size = self.beam_size

        if self.match_source_len:
            max_len = src_lengths_no_eos.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                model.max_decoder_positions() - 1,
            )

        # compute the encoder output for each beam
        encoder_outs = model.forward_encoder(encoder_input)
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)

        src_lengths = encoder_input['src_lengths']
        # initialize buffers
        scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
        lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0)

        scores_buf = scores.clone()
        tokens = src_tokens.new(bsz * beam_size,
                                max_len + 2).long().fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos if bos_token is None else bos_token

        # reorder source tokens so they may be used as a reference in generating P(S|T)
        src_tokens = reorder_all_tokens(src_tokens, src_lengths,
                                        self.src_dict.eos_index)

        src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len)
        src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view(
            bsz * beam_size, -1)

        attn, attn_buf = None, None
        nonpad_idxs = None

        # The cands_to_ignore indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then the cands_to_ignore would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq(
            -1)  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) *
                        beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfin_idx):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores,
                           combined_noisy_channel_eos_scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    fw scores for each hypothesis
                combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing
                    combined noisy channel scores for each hypothesis
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            # skip the first index, which is EOS
            tokens_clone = tokens_clone[:, 1:step + 2]
            assert not tokens_clone.eq(self.eos).any()
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(
                0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                combined_noisy_channel_eos_scores /= (step +
                                                      1)**self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(
                    zip(bbsz_idx.tolist(),
                        combined_noisy_channel_eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                if self.match_source_len and step > src_lengths_no_eos[
                        unfin_idx]:
                    score = -math.inf

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                        _, alignment = hypo_attn.max(dim=0)
                    else:
                        hypo_attn = None
                        alignment = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens,
                                    k):
            """Rescore the top k hypothesis from each beam using noisy channel modeling
            Returns:
                new_fw_lprobs: the direct model probabilities after pruning the top k
                new_ch_lm_lprobs:  the combined channel and language model probabilities
                new_lm_lprobs: the language model probabilities after pruning the top k
            """
            with torch.no_grad():
                lprobs_size = lprobs.size()
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice,
                        dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data).expand(
                            -1,
                            beam_size).contiguous().view(bsz * beam_size, 1)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(
                        bsz,
                        beam_size).data.contiguous().view(bsz * beam_size, 1)

                    # need to calculate and save fw and lm probs for prefix tokens
                    fw_top_k = cand_scores
                    fw_top_k_idx = cand_indices
                    k = 1
                else:
                    # take the top k best words for every sentence in batch*beam
                    fw_top_k, fw_top_k_idx = torch.topk(
                        lprobs.view(beam_size * bsz, -1), k=k)
                eos_idx = torch.nonzero(
                    fw_top_k_idx.view(bsz * beam_size * k, -1) == self.eos)[:, 0]
                ch_scores = fw_top_k.new_full((beam_size * bsz * k, ), 0)
                src_size = torch.sum(
                    src_tokens[:, :] != self.src_dict.pad_index,
                    dim=1,
                    keepdim=True,
                    dtype=fw_top_k.dtype)

                if self.combine_method != "lm_only":
                    temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view(
                        bsz * beam_size * k, -1)
                    not_padding = (
                        temp_src_tokens_full[:, 1:] != self.src_dict.pad_index)
                    cur_tgt_size = step + 2

                    # add eos to all candidate sentences except those that already end in eos
                    eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1)
                    eos_tokens[eos_idx] = self.tgt_dict.pad_index

                    if step == 0:
                        channel_input = torch.cat(
                            (fw_top_k_idx.view(-1, 1), eos_tokens), 1)
                    else:
                        # move eos from beginning to end of target sentence
                        channel_input = torch.cat(
                            (tokens[:, 1:step + 1].repeat(1, k).view(-1, step),
                             fw_top_k_idx.view(-1, 1), eos_tokens), 1)

                    ch_input_lengths = torch.tensor(
                        np.full(channel_input.size(0), cur_tgt_size))
                    ch_input_lengths[eos_idx] = cur_tgt_size - 1
                    if self.channel_scoring_type == "unnormalized":
                        ch_encoder_output = channel_model.encoder(
                            channel_input, src_lengths=ch_input_lengths)
                        ch_decoder_output, _ = channel_model.decoder(
                            temp_src_tokens_full,
                            encoder_out=ch_encoder_output,
                            features_only=True)
                        del ch_encoder_output
                        ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target(
                            ch_decoder_output,
                            target_ids=temp_src_tokens_full[:, 1:])
                        ch_intermed_scores = ch_intermed_scores.float()
                        ch_intermed_scores *= not_padding.float()
                        ch_scores = torch.sum(ch_intermed_scores, dim=1)
                    elif self.channel_scoring_type == "k2_separate":
                        for k_idx in range(k):
                            k_eos_tokens = eos_tokens[k_idx::k, :]
                            if step == 0:
                                k_ch_input = torch.cat(
                                    (fw_top_k_idx[:, k_idx:k_idx + 1],
                                     k_eos_tokens), 1)
                            else:
                                # move eos from beginning to end of target sentence
                                k_ch_input = torch.cat(
                                    (tokens[:, 1:step + 1],
                                     fw_top_k_idx[:, k_idx:k_idx + 1],
                                     k_eos_tokens), 1)
                            k_ch_input_lengths = ch_input_lengths[k_idx::k]
                            k_ch_output = channel_model(
                                k_ch_input, k_ch_input_lengths, src_tokens)
                            k_ch_lprobs = channel_model.get_normalized_probs(
                                k_ch_output, log_probs=True)
                            k_ch_intermed_scores = torch.gather(
                                k_ch_lprobs[:, :-1, :], 2,
                                src_tokens[:, 1:].unsqueeze(2)).squeeze(2)
                            k_ch_intermed_scores *= not_padding.float()
                            ch_scores[k_idx::k] = torch.sum(
                                k_ch_intermed_scores, dim=1)
                    elif self.channel_scoring_type == "src_vocab":
                        ch_encoder_output = channel_model.encoder(
                            channel_input, src_lengths=ch_input_lengths)
                        ch_decoder_output, _ = channel_model.decoder(
                            temp_src_tokens_full,
                            encoder_out=ch_encoder_output,
                            features_only=True)

                        del ch_encoder_output
                        ch_lprobs = normalized_scores_with_batch_vocab(
                            channel_model.decoder,
                            ch_decoder_output,
                            src_tokens,
                            k,
                            bsz,
                            beam_size,
                            self.src_dict.pad_index,
                            top_k=self.top_k_vocab)
                        ch_scores = torch.sum(ch_lprobs, dim=1)
                    elif self.channel_scoring_type == "src_vocab_batched":
                        ch_bsz_size = temp_src_tokens_full.shape[0]
                        ch_lprobs_list = [None] * len(
                            range(0, ch_bsz_size, self.ch_scoring_bsz))
                        for i, start_idx in enumerate(
                                range(0, ch_bsz_size, self.ch_scoring_bsz)):
                            end_idx = min(start_idx + self.ch_scoring_bsz,
                                          ch_bsz_size)
                            temp_src_tokens_full_batch = temp_src_tokens_full[
                                start_idx:end_idx, :]
                            channel_input_batch = channel_input[
                                start_idx:end_idx, :]
                            ch_input_lengths_batch = ch_input_lengths[
                                start_idx:end_idx]
                            ch_encoder_output_batch = channel_model.encoder(
                                channel_input_batch,
                                src_lengths=ch_input_lengths_batch)
                            ch_decoder_output_batch, _ = channel_model.decoder(
                                temp_src_tokens_full_batch,
                                encoder_out=ch_encoder_output_batch,
                                features_only=True)
                            ch_lprobs_list[
                                i] = normalized_scores_with_batch_vocab(
                                    channel_model.decoder,
                                    ch_decoder_output_batch,
                                    src_tokens,
                                    k,
                                    bsz,
                                    beam_size,
                                    self.src_dict.pad_index,
                                    top_k=self.top_k_vocab,
                                    start_idx=start_idx,
                                    end_idx=end_idx)
                        ch_lprobs = torch.cat(ch_lprobs_list, dim=0)
                        ch_scores = torch.sum(ch_lprobs, dim=1)
                    else:
                        ch_output = channel_model(channel_input,
                                                  ch_input_lengths,
                                                  temp_src_tokens_full)
                        ch_lprobs = channel_model.get_normalized_probs(
                            ch_output, log_probs=True)
                        ch_intermed_scores = torch.gather(
                            ch_lprobs[:, :-1, :], 2,
                            temp_src_tokens_full[:, 1:].unsqueeze(
                                2)).squeeze().view(bsz * beam_size * k, -1)
                        ch_intermed_scores *= not_padding.float()
                        ch_scores = torch.sum(ch_intermed_scores, dim=1)

                else:
                    cur_tgt_size = 0
                ch_scores = ch_scores.view(bsz * beam_size, k)
                expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze(
                    1).expand(-1, k).flatten()

                if self.share_tgt_dict:
                    lm_scores = get_lm_scores(
                        lm, tokens[:, :step + 1].view(-1, step + 1),
                        lm_incremental_states, fw_top_k_idx.view(-1, 1),
                        torch.tensor(np.full(tokens.size(0), step + 1)), k)
                else:
                    new_lm_input = dict2dict(
                        tokens[:, :step + 1].view(-1, step + 1),
                        self.tgt_to_lm)
                    new_cands = dict2dict(fw_top_k_idx.view(-1, 1),
                                          self.tgt_to_lm)
                    lm_scores = get_lm_scores(
                        lm, new_lm_input, lm_incremental_states, new_cands,
                        torch.tensor(np.full(tokens.size(0), step + 1)), k)

                lm_scores.add_(expanded_lm_prefix_scores)
                ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores,
                                             lm_scores, src_size, cur_tgt_size)
                # initialize all as min value
                new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(
                    bsz * beam_size, -1)
                new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_(
                    -1e17).view(bsz * beam_size, -1)
                new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(
                    bsz * beam_size, -1)
                new_fw_lprobs[:, self.pad] = -math.inf
                new_ch_lm_lprobs[:, self.pad] = -math.inf
                new_lm_lprobs[:, self.pad] = -math.inf

                new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k)
                new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores)
                new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k))
                return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs

        def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size,
                          tgt_size):
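            # Notation in the comments below: S = source, T = target,
            # s = source length, t = target length. ch_scores carries the
            # channel model score log P(S|T) and lm_scores1 the language
            # model score log P(T); the direct model term log P(T|S) lives in
            # fw_lprobs and is combined separately by the search step.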
            if self.channel_scoring_type == "unnormalized":
                ch_scores = self.log_softmax_fn(
                    ch_scores.view(-1, self.beam_size * self.k2)).view(
                        ch_scores.shape)
            ch_scores = ch_scores * self.ch_weight
            lm_scores1 = lm_scores1 * self.lm_weight

            if combine_type == "lm_only":
                # log P(T|S) + log P(T)
                ch_scores = lm_scores1.view(ch_scores.size())
            elif combine_type == "noisy_channel":
                # 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T)
                if self.normalize_lm_scores_by_tgt_len:
                    ch_scores.div_(src_size)
                    lm_scores_norm = lm_scores1.view(
                        ch_scores.size()).div(tgt_size)
                    ch_scores.add_(lm_scores_norm)
                # 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T)
                else:
                    ch_scores.add_(lm_scores1.view(ch_scores.size()))
                    ch_scores.div_(src_size)

            return ch_scores

        if self.channel_models is not None:
            # assume only one channel model
            channel_model = self.channel_models[0]
        else:
            channel_model = None

        lm = EnsembleModel(self.lm_models)
        lm_incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(lm.models_size)
            ],
        )

        reorder_state = None
        batch_idxs = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(
                        batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size)
                model.reorder_incremental_state(incremental_states,
                                                reorder_state)
                encoder_outs = model.reorder_encoder_out(
                    encoder_outs, reorder_state)

                lm.reorder_incremental_state(lm_incremental_states,
                                             reorder_state)

            fw_lprobs, avg_attn_scores = model.forward_decoder(
                tokens[:, :step + 1],
                encoder_outs,
                incremental_states,
                temperature=self.temperature,
            )

            fw_lprobs[:, self.pad] = -math.inf  # never select pad
            fw_lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
            fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring(
                fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2)

            # handle min and max length constraints
            if step >= max_len:
                fw_lprobs[:, :self.eos] = -math.inf
                fw_lprobs[:, self.eos + 1:] = -math.inf
            elif step < self.min_len:
                fw_lprobs[:, self.eos] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            if prefix_tokens is not None and step < prefix_tokens.size(1):
                prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(
                    1, beam_size).view(-1)
                prefix_mask = prefix_toks.ne(self.pad)

                prefix_fw_lprobs = fw_lprobs.gather(-1,
                                                    prefix_toks.unsqueeze(-1))
                fw_lprobs[prefix_mask] = -math.inf
                fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_(
                    -1, prefix_toks[prefix_mask].unsqueeze(-1),
                    prefix_fw_lprobs)

                prefix_ch_lm_lprobs = ch_lm_lprobs.gather(
                    -1, prefix_toks.unsqueeze(-1))
                ch_lm_lprobs[prefix_mask] = -math.inf
                ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_(
                    -1, prefix_toks[prefix_mask].unsqueeze(-1),
                    prefix_ch_lm_lprobs)

                prefix_lm_lprobs = lm_lprobs.gather(-1,
                                                    prefix_toks.unsqueeze(-1))
                lm_lprobs[prefix_mask] = -math.inf
                lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_(
                    -1, prefix_toks[prefix_mask].unsqueeze(-1),
                    prefix_lm_lprobs)

                # if prefix includes eos, then we should make sure tokens and
                # scores are the same across all beams
                eos_mask = prefix_toks.eq(self.eos)
                if eos_mask.any():
                    # validate that the first beam matches the prefix
                    first_beam = tokens[eos_mask].view(
                        -1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
                    eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
                    target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
                    assert (first_beam == target_prefix).all()

                    def replicate_first_beam(tensor, mask):
                        tensor = tensor.view(-1, beam_size, tensor.size(-1))
                        tensor[mask] = tensor[mask][:, :1, :]
                        return tensor.view(-1, tensor.size(-1))

                    # copy tokens, scores and lprobs from the first beam to all beams
                    tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
                    scores = replicate_first_beam(scores, eos_mask_batch_dim)

                    fw_lprobs = replicate_first_beam(fw_lprobs,
                                                     eos_mask_batch_dim)
                    ch_lm_lprobs = replicate_first_beam(
                        ch_lm_lprobs, eos_mask_batch_dim)
                    lm_lprobs = replicate_first_beam(lm_lprobs,
                                                     eos_mask_batch_dim)

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[
                            gen_tokens[i:]
                            for i in range(self.no_repeat_ngram_size)
                    ]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                                gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1),
                                      max_len + 2)
                    attn_buf = attn.clone()
                    nonpad_idxs = src_tokens.ne(self.pad)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(fw_lprobs)
            scores_buf = scores_buf.type_as(fw_lprobs)

            self.search.set_src_lengths(src_lengths_no_eos)

            if self.no_repeat_ngram_size > 0:

                def calculate_banned_tokens(bbsz_idx):
                    # before decoding the next token, prevent decoding of ngrams that have already appeared
                    ngram_index = tuple(
                        tokens[bbsz_idx, step + 2 -
                               self.no_repeat_ngram_size:step + 1].tolist())
                    return gen_ngrams[bbsz_idx].get(ngram_index, [])
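                # e.g. continuing the example above, if the two most recently
                # generated tokens are (6, 7), then token 6 is banned at this step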

                if step + 2 - self.no_repeat_ngram_size >= 0:
                    # only ban tokens once at least no_repeat_ngram_size tokens
                    # have been generated; before that there is nothing to ban
                    banned_tokens = [
                        calculate_banned_tokens(bbsz_idx)
                        for bbsz_idx in range(bsz * beam_size)
                    ]
                else:
                    banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                for bbsz_idx in range(bsz * beam_size):
                    fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf

            combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step(
                step, fw_lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
                ch_lm_lprobs.view(bsz, -1, self.vocab_size),
                lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method)
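            # self.search.step combines the direct-model, channel-model and
            # language-model scores according to self.combine_method and returns
            # the top candidates for each sentence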

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)
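            # e.g. with beam_size=3, candidate beam j of sentence i maps to the
            # flattened row i * beam_size + j (bbsz_offsets adds i * beam_size)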

            # finalize hypotheses that end in eos (except for candidates to be ignored)
            eos_mask = cand_indices.eq(self.eos)
            eos_mask[:, :beam_size] &= ~cands_to_ignore

            # only consider eos when it's among the top beam_size indices
            eos_bbsz_idx = torch.masked_select(cand_bbsz_idx[:, :beam_size],
                                               mask=eos_mask[:, :beam_size])

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:
                eos_scores = torch.masked_select(
                    fw_lprobs_top_k[:, :beam_size],
                    mask=eos_mask[:, :beam_size])
                combined_noisy_channel_eos_scores = torch.masked_select(
                    combined_noisy_channel_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                )

                # finalize hypo using channel model score
                finalized_sents = finalize_hypos(
                    step, eos_bbsz_idx, eos_scores,
                    combined_noisy_channel_eos_scores)

                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = torch.nonzero(batch_mask).squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)

                lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs]

                fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths_no_eos = src_lengths_no_eos[batch_idxs]
                cands_to_ignore = cands_to_ignore[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                lm_prefix_scores = lm_prefix_scores.view(
                    bsz, -1)[batch_idxs].view(new_bsz * beam_size,
                                              -1).squeeze()

                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # ignored hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.
            eos_mask[:, :beam_size] |= cands_to_ignore
            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
            )
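            # e.g. with beam_size=2, cand_size=4 and an eos_mask row [0, 1, 0, 0],
            # the active_mask row is [0, 5, 2, 3]; topk(..., largest=False) below
            # then selects candidates 0 and 2, skipping the finalized candidate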

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer(
                'new_cands_to_ignore')
            torch.topk(active_mask,
                       k=beam_size,
                       dim=1,
                       largest=False,
                       out=(new_cands_to_ignore, active_hypos))

            # update cands_to_ignore to ignore any finalized hypos
            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
            assert (~cands_to_ignore).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx,
                dim=1,
                index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                fw_lprobs_top_k,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                fw_lprobs_top_k,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )
            torch.gather(lm_lprobs_top_k,
                         dim=1,
                         index=active_hypos,
                         out=lm_prefix_scores.view(bsz, beam_size))

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2],
                    dim=0,
                    index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent],
                                     key=lambda r: r['score'],
                                     reverse=True)
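        # finalized[sent][0] is now the highest-scoring hypothesis for sentence
        # `sent`; each hypothesis is a dict with at least a 'score' key and, in
        # the usual fairseq convention, the generated 'tokens'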

        return finalized
Exemple #5
class FairseqPredictor(Predictor):
    """Predictor for using fairseq models."""
    def __init__(self, model_path, user_dir, lang_pair, n_cpu_threads=-1):
        """Initializes a fairseq predictor.

        Args:
            model_path (string): Path to the fairseq model (*.pt). Like
                                 --path in fairseq-interactive.
            lang_pair (string): Language pair string (e.g. 'en-fr').
            user_dir (string): Path to fairseq user directory.
            n_cpu_threads (int): Number of CPU threads. If negative,
                                 use GPU.
        """
        super(FairseqPredictor, self).__init__()
        _initialize_fairseq(user_dir)
        self.use_cuda = torch.cuda.is_available() and n_cpu_threads < 0

        parser = options.get_generation_parser()
        input_args = ["--path", model_path, os.path.dirname(model_path)]
        if lang_pair:
            src, trg = lang_pair.split("-")
            input_args.extend(["--source-lang", src, "--target-lang", trg])
        args = options.parse_args_and_arch(parser, input_args)

        # Setup task, e.g., translation
        task = tasks.setup_task(args)
        self.src_vocab_size = len(task.source_dictionary)
        self.trg_vocab_size = len(task.target_dictionary)
        self.pad_id = task.source_dictionary.pad()

        # Load ensemble
        logging.info('Loading fairseq model(s) from {}'.format(model_path))
        self.models, _ = checkpoint_utils.load_model_ensemble(
            model_path.split(':'),
            task=task,
        )

        # Optimize ensemble for generation
        for model in self.models:
            model.make_generation_fast_(
                beamable_mm_beam_size=1,
                need_attn=False,
            )
            if self.use_cuda:
                model.cuda()
        self.model = EnsembleModel(self.models)
        self.model.eval()

    def get_unk_probability(self, posterior):
        """Fetch posterior[utils.UNK_ID]"""
        return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF)

    def predict_next(self):
        """Call the fairseq model."""
        inputs = torch.LongTensor([self.consumed])
        if self.use_cuda:
            inputs = inputs.cuda()
        lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs)
        lprobs[0, self.pad_id] = utils.NEG_INF
        return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0])

    def initialize(self, src_sentence):
        """Initialize source tensors, reset consumed."""
        self.consumed = []
        src_tokens = torch.LongTensor([
            utils.oov_to_unk(src_sentence + [utils.EOS_ID],
                             self.src_vocab_size)
        ])
        src_lengths = torch.LongTensor([len(src_sentence) + 1])
        if self.use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()
        self.encoder_outs = self.model.forward_encoder({
            'src_tokens': src_tokens,
            'src_lengths': src_lengths
        })
        self.consumed = [utils.GO_ID or utils.EOS_ID]
        # Reset incremental states
        for model in self.models:
            self.model.incremental_states[model] = {}

    def consume(self, word):
        """Append ``word`` to the current history."""
        self.consumed.append(word)

    def get_state(self):
        """The predictor state is the complete history."""
        return self.consumed, [
            self.model.incremental_states[m] for m in self.models
        ]

    def set_state(self, state):
        """The predictor state is the complete history."""
        self.consumed, inc_states = state
        for model, inc_state in zip(self.models, inc_states):
            self.model.incremental_states[model] = inc_state

    def is_equal(self, state1, state2):
        """Returns True if the two states have the same history."""
        return state1[0] == state2[0]
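
# Usage sketch (not part of the example above): a minimal greedy-decoding loop,
# assuming the predictor interface shown here and that `np` and `utils.EOS_ID`
# are available. The checkpoint path and `src_ids` are hypothetical placeholders.
#
#     predictor = FairseqPredictor("model.pt", user_dir="", lang_pair="en-fr")
#     predictor.initialize(src_ids)             # src_ids: list of source token ids
#     hypo = []
#     while True:
#         posterior = predictor.predict_next()  # log-probs over the target vocab
#         word = int(np.argmax(posterior))
#         if word == utils.EOS_ID:
#             break
#         predictor.consume(word)
#         hypo.append(word)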
Exemple #6
class FairseqPredictor(Predictor):
    """Predictor for using fairseq models."""
    name = 'fairseq'
    def __init__(self, args):
        super(FairseqPredictor, self).__init__()
        _initialize_fairseq(args.fairseq_user_dir)

        self.use_cuda = torch.cuda.is_available() and args.n_cpu_threads < 0
        fairseq_args = get_fairseq_args(args.fairseq_path, args.fairseq_lang_pair)

        # Setup task, e.g., translation
        task = tasks.setup_task(fairseq_args)
        source_dict = task.source_dictionary
        target_dict = task.target_dictionary
        self.src_vocab_size = len(source_dict) + 1
        self.trg_vocab_size = len(target_dict) + 1
        self.pad_id = target_dict.pad()
        # Load ensemble
        self.models = self.load_models(args.fairseq_path, task)
        self.model = EnsembleModel(self.models)
        self.model.eval()
        # one separate dict per model; a multiplied list ([{}] * n) would alias
        # the same dict across all models
        self.incremental_states = [{} for _ in self.models]


    def load_models(self, model_path, task):
        logging.info('Loading fairseq model(s) from {}'.format(model_path))
        models, _ = checkpoint_utils.load_model_ensemble(
            model_path.split(':'),
            task=task,
        )

        # Optimize ensemble for generation
        for model in models:
            model.make_generation_fast_(
                beamable_mm_beam_size=1,
                need_attn=False,
            )
            if self.use_cuda:
                model.cuda()
        return models

    def get_unk_probability(self, posterior):
        """Fetch posterior[utils.UNK_ID]"""
        return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF)

    @torch.no_grad()
    def predict_next(self):
        """Call the fairseq model."""
        inputs = torch.LongTensor([self.consumed])
        
        if self.use_cuda:
            inputs = inputs.cuda()
        lprobs, _  = self.model.forward_decoder(
            inputs, self.encoder_outs, self.incremental_states)
        lprobs[:, self.pad_id] = utils.NEG_INF
        return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0], dtype=np.float64)

    @torch.no_grad()
    def initialize(self, src_sentence):
        """Initialize source tensors, reset consumed."""

        src_tokens = torch.LongTensor([
            utils.oov_to_unk(src_sentence + [utils.EOS_ID],
                             self.src_vocab_size)])
        src_lengths = torch.LongTensor([len(src_sentence) + 1])
        if self.use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()
        self.encoder_outs = self.model.forward_encoder({
            'src_tokens': src_tokens,
            'src_lengths': src_lengths})

        self.consumed = [utils.GO_ID or utils.EOS_ID]
        self.reset_states()

    def reset_states(self, states=None):
        # Reset incremental states
        for i in range(len(self.models)):
            self.incremental_states[i] = {}

    def consume(self, word, i=None):
        """Append ``word`` to the current history."""
        if i is None:
            self.consumed.append(word)
        else:
            self.consumed[i].append(word)

    def get_empty_str_prob(self):
        return self.get_initial_dist()[utils.EOS_ID].item()

    @torch.no_grad()
    def get_initial_dist(self):
        inputs = torch.LongTensor([[utils.GO_ID or utils.EOS_ID]])
        if self.use_cuda:
            inputs = inputs.cuda()
        
        lprobs, _ = self.model.forward_decoder(
            inputs, self.encoder_outs, [{}]*len(self.models)
        )
        return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0], dtype=np.float64)

    def get_state(self):
        """The predictor state is the complete history."""
        return self.consumed, self.incremental_states
    
    def set_state(self, state):
        """The predictor state is the complete history."""
        self.consumed, self.incremental_states = state

    def is_equal(self, state1, state2):
        """Returns True if the two states have the same history."""
        return state1[0] == state2[0]

    @staticmethod
    def add_args(parser):
        parser.add_argument("--fairseq_path", default="",
                       help="Points to the model file (*.pt) for the fairseq "
                       "predictor. Like --path in fairseq-interactive.")
        parser.add_argument("--fairseq_user_dir", default="",
                           help="fairseq user directory for additional models.")
        parser.add_argument("--fairseq_lang_pair", default="",
                           help="Language pair such as 'en-fr' for fairseq. Used "
                           "to load fairseq dictionaries")
Exemple #7
class FairseqPredictor(Predictor):
    """Predictor for using fairseq models."""
    name = 'fairseq'

    def __init__(self, args):
        super(FairseqPredictor, self).__init__()
        _initialize_fairseq(args.fairseq_user_dir)
        self.use_cuda = torch.cuda.is_available() and args.n_cpu_threads < 0

        fairseq_args = get_fairseq_args(args.fairseq_path,
                                        args.fairseq_lang_pair)

        # Setup task, e.g., translation
        task = tasks.setup_task(fairseq_args)
        source_dict = task.source_dictionary
        target_dict = task.target_dictionary
        self.src_vocab_size = len(source_dict) + 1
        self.trg_vocab_size = len(target_dict) + 1
        self.pad_id = target_dict.pad()
        # Load ensemble
        self.models = self.load_models(args.fairseq_path, task)
        self.model = EnsembleModel(self.models)
        self.model.eval()

    def load_models(self, model_path, task):
        logging.info('Loading fairseq model(s) from {}'.format(model_path))
        models, _ = checkpoint_utils.load_model_ensemble(
            model_path.split(':'),
            task=task,
        )

        # Optimize ensemble for generation
        for model in models:
            model.make_generation_fast_(
                beamable_mm_beam_size=1,
                need_attn=False,
            )
            if self.use_cuda:
                model.cuda()
        return models

    def get_unk_probability(self, posterior):
        """Fetch posterior[utils.UNK_ID]"""
        return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF)

    def predict_next(self):
        """Call the fairseq model."""
        inputs = torch.LongTensor([self.consumed])

        if self.use_cuda:
            inputs = inputs.cuda()

        lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs)
        lprobs[:, self.pad_id] = utils.NEG_INF
        return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0],
                        dtype=np.float64)

    def initialize(self, src_sentence):
        """Initialize source tensors, reset consumed."""

        src_tokens = torch.LongTensor([
            utils.oov_to_unk(src_sentence + [utils.EOS_ID],
                             self.src_vocab_size)
        ])
        src_lengths = torch.LongTensor([len(src_sentence) + 1])
        if self.use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()
        self.encoder_outs = self.model.forward_encoder({
            'src_tokens': src_tokens,
            'src_lengths': src_lengths
        })

        self.consumed = [utils.GO_ID or utils.EOS_ID]
        self.reset_states()

    def reset_states(self, states=None):
        # Reset incremental states
        if states is not None:
            assert len(states) == len(self.models)
        for i, model in enumerate(self.models):
            self.model.incremental_states[
                model] = {} if states is None else states[i]

    def consume(self, word, i=None):
        """Append ``word`` to the current history."""
        if i is None:
            self.consumed.append(word)
        else:
            self.consumed[i].append(word)

    def get_empty_str_prob(self):
        return self.get_initial_dist()[utils.EOS_ID].item()

    def get_initial_dist(self):
        old_states = [self.model.incremental_states[m] for m in self.models]
        self.reset_states()
        inputs = torch.LongTensor([[utils.GO_ID or utils.EOS_ID]])
        if self.use_cuda:
            inputs = inputs.cuda()

        lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs)
        self.reset_states(old_states)
        return np.array(lprobs[0].cpu() if self.use_cuda else lprobs[0],
                        dtype=np.float64)

    def get_state(self):
        """The predictor state is the complete history."""
        return self.consumed, [
            self.model.incremental_states[m] for m in self.models
        ]

    def set_state(self, state):
        """The predictor state is the complete history."""
        self.consumed, inc_states = state
        for model, inc_state in zip(self.models, inc_states):
            self.model.incremental_states[model] = inc_state

    def is_equal(self, state1, state2):
        """Returns True if the two states have the same history."""
        return state1[0] == state2[0]
Exemple #8
class FairseqPredictor(Predictor):
    """Predictor for using fairseq models."""
    def __init__(self,
                 model_path,
                 user_dir,
                 lang_pair,
                 n_cpu_threads=-1,
                 subtract_uni=False,
                 subtract_marg=False,
                 marg_path=None,
                 lmbda=1.0,
                 ppmi=False,
                 epsilon=0):
        """Initializes a fairseq predictor.

        Args:
            model_path (string): Path to the fairseq model (*.pt). Like
                                 --path in fairseq-interactive.
            lang_pair (string): Language pair string (e.g. 'en-fr').
            user_dir (string): Path to fairseq user directory.
            n_cpu_threads (int): Number of CPU threads. If negative,
                                 use GPU.
        """
        super(FairseqPredictor, self).__init__()
        _initialize_fairseq(user_dir)
        self.use_cuda = torch.cuda.is_available() and n_cpu_threads < 0

        args = get_fairseq_args(model_path, lang_pair)

        # Setup task, e.g., translation
        task = tasks.setup_task(args)
        source_dict = task.source_dictionary
        target_dict = task.target_dictionary
        self.src_vocab_size = len(source_dict) + 1
        self.trg_vocab_size = len(target_dict) + 1
        self.pad_id = target_dict.pad()
        self.eos_id = target_dict.eos()
        self.bos_id = target_dict.bos()
        # Load ensemble
        self.models = self.load_models(model_path, task)
        self.model = EnsembleModel(self.models)
        self.model.eval()

        assert not (subtract_marg and subtract_uni)
        self.use_uni_dist = subtract_uni
        self.use_marg_dist = subtract_marg
        assert not ppmi or subtract_marg or subtract_uni

        self.lmbda = lmbda
        if self.use_uni_dist:
            unigram_dist = torch.Tensor(target_dict.count)
            # change frequency of eos to frequency of '.' so it's more realistic
            unigram_dist[self.eos_id] = unigram_dist[target_dict.index('.')]
            self.log_uni_dist = unigram_dist.cuda(
            ) if self.use_cuda else unigram_dist
            self.log_uni_dist = (self.log_uni_dist /
                                 self.log_uni_dist.sum()).log()
        if self.use_marg_dist:
            if not marg_path:
                raise AttributeError(
                    "No path (--marg_path) given for marginal model when --subtract_marg used"
                )
            args = get_fairseq_args(marg_path, lang_pair)
            self.ppmi = ppmi
            self.eps = epsilon
            # Setup task, e.g., translation
            task = tasks.setup_task(args)
            assert source_dict == task.source_dictionary
            assert target_dict == task.target_dictionary
            # Load ensemble
            self.marg_models = self.load_models(marg_path, task)
            self.marg_model = EnsembleModel(self.marg_models)
            self.marg_model.eval()

    def load_models(self, model_path, task):
        logging.info('Loading fairseq model(s) from {}'.format(model_path))
        models, _ = checkpoint_utils.load_model_ensemble(
            model_path.split(':'),
            task=task,
        )

        # Optimize ensemble for generation
        for model in models:
            model.make_generation_fast_(
                beamable_mm_beam_size=1,
                need_attn=False,
            )
            if self.use_cuda:
                model.cuda()
        return models

    def get_unk_probability(self, posterior):
        """Fetch posterior[utils.UNK_ID]"""
        return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF)

    def predict_next(self):
        """Call the fairseq model."""
        inputs = torch.LongTensor([self.consumed])
        if self.use_cuda:
            inputs = inputs.cuda()
        lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs)
        lprobs[0, self.pad_id] = utils.NEG_INF
        if self.use_uni_dist:
            lprobs[0] = lprobs[0] - self.lmbda * self.log_uni_dist
        if self.use_marg_dist:
            marg_lprobs, _ = self.marg_model.forward_decoder(
                inputs, self.marg_encoder_outs)
            if self.ppmi:
                marg_lprobs[0] = torch.clamp(marg_lprobs[0], -self.eps)
            lprobs[0] = lprobs[0] - self.lmbda * marg_lprobs[0]
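        # at this point lprobs[0] is log p(y_t | x, y_<t) minus lmbda times the
        # unigram or marginal model's log-probability, depending on the flags above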

        return lprobs[0] if self.use_cuda else np.array(lprobs[0])

    def initialize(self, src_sentence):
        """Initialize source tensors, reset consumed."""
        self.consumed = []
        src_tokens = torch.LongTensor([
            utils.oov_to_unk(src_sentence + [utils.EOS_ID],
                             self.src_vocab_size)
        ])
        src_lengths = torch.LongTensor([len(src_sentence) + 1])
        if self.use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()
        self.encoder_outs = self.model.forward_encoder({
            'src_tokens': src_tokens,
            'src_lengths': src_lengths
        })
        self.consumed = [utils.GO_ID or utils.EOS_ID]
        # Reset incremental states

        for model in self.models:
            self.model.incremental_states[model] = {}
        if self.use_marg_dist:
            self.initialize_marg()

    def initialize_marg(self):
        """Initialize source tensors, reset consumed."""
        src_tokens = torch.LongTensor(
            [utils.oov_to_unk([utils.EOS_ID], self.src_vocab_size)])
        src_lengths = torch.LongTensor([1])
        if self.use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()
        self.marg_encoder_outs = self.marg_model.forward_encoder({
            'src_tokens': src_tokens,
            'src_lengths': src_lengths
        })
        # Reset incremental states
        for model in self.marg_models:
            self.marg_model.incremental_states[model] = {}

    def consume(self, word):
        """Append ``word`` to the current history."""
        self.consumed.append(word)

    def get_empty_str_prob(self):
        inputs = torch.LongTensor([[utils.GO_ID or utils.EOS_ID]])
        if self.use_cuda:
            inputs = inputs.cuda()

        lprobs, _ = self.model.forward_decoder(inputs, self.encoder_outs)
        if self.use_uni_dist:
            lprobs[0] = lprobs[0] - self.lmbda * self.log_uni_dist
        if self.use_marg_dist:
            lprobs_marg, _ = self.marg_model.forward_decoder(
                inputs, self.marg_encoder_outs)
            eos_prob = (lprobs[0, self.eos_id] -
                        self.lmbda * lprobs_marg[0, self.eos_id]).item()
            if self.ppmi:
                return min(eos_prob, 0)
            return eos_prob

        return lprobs[0, self.eos_id].item()

    def get_state(self):
        """The predictor state is the complete history."""
        return self.consumed, [
            self.model.incremental_states[m] for m in self.models
        ]

    def set_state(self, state):
        """The predictor state is the complete history."""
        consumed, inc_states = state
        self.consumed = copy.copy(consumed)
        for model, inc_state in zip(self.models, inc_states):
            self.model.incremental_states[model] = inc_state

    def is_equal(self, state1, state2):
        """Returns True if the two states have the same history."""
        return state1[0] == state2[0]
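
# Construction sketch (hypothetical values, matching the signature above): with
# subtract_uni=True, predict_next() returns log p(y_t | x, y_<t) minus
# lmbda * log p_unigram(y_t), a PMI-style rescoring of the translation model.
#
#     predictor = FairseqPredictor("model.pt", user_dir="", lang_pair="en-fr",
#                                  subtract_uni=True, lmbda=0.5)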
Exemple #9
    def _paraphrase_sample(self, model, sample, sample_topN):
        """
        model: MT model being trained
        sample: fairseq data structure for training batch
        sample_topN: number of top candidates to sample from in paraphraser output softmax
        """
        # disable training model dropout
        model.eval()

        # disable paraphraser dropout
        # train() on the paraphraser model automatically gets called when train()
        # is called on the criterion, so we need to set it back to eval mode
        self.paraphraser_model.eval()  # this should disable dropout
        self.paraphraser_model.training = False  # not sure if this does anything

        pad = self.task.target_dictionary.pad()
        eos = self.task.target_dictionary.eos()
        bos = self.task.target_dictionary.bos()

        assert pad == self.task.source_dictionary.pad()
        assert eos == self.task.source_dictionary.eos()
        assert bos == self.task.source_dictionary.bos()

        # we don't know how long the paraphrase will be, so we take the target length and increase it a bit.
        target_length = sample['target'].shape[1]
        max_paraphrase_length = int(2 * target_length) + 3
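        # e.g. a 10-token target allows paraphrases of up to 2 * 10 + 3 = 23 tokens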

        batch_size = sample['net_input']['prev_output_tokens'].shape[0]

        combined_tokens = sample['net_input']['prev_output_tokens'][:, :1]
        combined_tokens[:, :] = eos  # eos to match 'bug' in fairseq ("should" be bos)

        # make the target look like a source, to feed it into the paraphraser encoder
        paraphraser_src_lengths = torch.ones(batch_size, dtype=torch.int)
        paraphraser_source = sample['target'].new_zeros(
            tuple(sample['target'].shape)) + pad
        for i in range(batch_size):
            n_pad = (sample['target'][i] == pad).sum()
            paraphraser_src_lengths[i] = target_length - n_pad
            paraphraser_source[i, n_pad:target_length] = sample['target'][
                i, :target_length - n_pad]
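        # e.g. a right-padded target row [w1, w2, w3, pad, pad] becomes the
        # left-padded encoder input [pad, pad, w1, w2, w3]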

        paraphraser_prediction_tokens_list = []
        paraphraser_probs_list = []

        paraphraser = EnsembleModel([
            self.paraphraser_model,
        ])

        paraphraser_encoder_out = paraphraser.forward_encoder(
            dict(src_tokens=paraphraser_source,
                 src_lengths=paraphraser_src_lengths))

        if self.paraphraser_lang_prefix:
            # take one step to update the state of the paraphraser, so that the "first" time step
            #    in the loop below will pass in the language prefix
            paraphraser_probs, _ = paraphraser.forward_decoder(
                tokens=combined_tokens,
                encoder_outs=paraphraser_encoder_out,
                temperature=self.paraphraser_temperature,
                use_log_probs=False)

            prefixed_combined_tokens = sample['net_input'][
                'prev_output_tokens'][:, :2]
            # eos to match bug in fairseq ("should" be bos)
            prefixed_combined_tokens[:, 0] = eos
            prefixed_combined_tokens[:, 1] = self.task.target_dictionary.index(
                self.paraphraser_lang_prefix)
        else:
            prefixed_combined_tokens = None

        done = [
            False,
        ] * batch_size
        for ii in range(max_paraphrase_length + 1):
            # paraphraser prefix may or may not have the language tag prepended (after the go symbol) to input
            if prefixed_combined_tokens is None:
                paraphraser_combined_tokens = combined_tokens
            else:
                paraphraser_combined_tokens = prefixed_combined_tokens

            # this is used to compute the loss
            paraphraser_probs, _ = paraphraser.forward_decoder(
                tokens=paraphraser_combined_tokens,
                encoder_outs=paraphraser_encoder_out,
                temperature=self.paraphraser_temperature,
                use_log_probs=False)

            # this is used to generate the previous context word
            paraphraser_probs_context = paraphraser_probs

            # save the paraphraser predictions to train toward (if we don't have a distribution loss)
            _, paraphraser_predictions = torch.max(paraphraser_probs, 1)
            if self.distribution_loss:
                paraphraser_probs_list.append(paraphraser_probs.unsqueeze(1))

            # paraphraser predictions are simply the most likely next word, according to the paraphraser
            paraphraser_prediction_tokens_list.append(
                paraphraser_predictions.reshape((-1, 1)))

            combined_probs = paraphraser_probs_context
            # disallow length=0 paraphrases
            if ii == 0:
                combined_probs[:, eos] = 0.0
            # disallow other undefined behavior
            combined_probs[:, pad] = 0.0
            combined_probs[:, bos] = 0.0

            if ii == max_paraphrase_length or all(done):
                break

            # sample from top N of paraphraser distribution
            if sample_topN == 1:
                _, combined_predictions = torch.max(combined_probs, 1)
                combined_predictions = combined_predictions.reshape((-1, 1))
            else:
                topk_val, topk_ind = torch.topk(combined_probs, sample_topN)
                # re-normalize top values
                topk_val2 = topk_val / topk_val.sum(dim=1).reshape((-1, 1))
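                # e.g. top-3 values [0.5, 0.3, 0.1] renormalize to [5/9, 3/9, 1/9]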
                # make distribution from normalized topk values
                mm = dis.Categorical(topk_val2)  # Categorical would also accept un-normalized weights
                # sample indexes into topk
                topk_idx_idx = mm.sample().reshape((-1, 1))
                # convert topk indexes back into vocab indexes
                combined_predictions = torch.cat(
                    [v[i] for i, v in zip(topk_idx_idx, topk_ind)]).reshape(
                        (-1, 1))

            for jj in range(batch_size):
                if combined_predictions[jj, 0] == eos:
                    done[jj] = True

            # append output tokens to input for next time step
            combined_tokens = torch.cat(
                (combined_tokens, combined_predictions), 1)
            if prefixed_combined_tokens is not None:
                prefixed_combined_tokens = torch.cat(
                    (prefixed_combined_tokens, combined_predictions), 1)

        paraphraser_prediction_tokens = torch.cat(
            paraphraser_prediction_tokens_list, 1)
        if self.distribution_loss:
            paraphraser_probs_tokens = torch.cat(paraphraser_probs_list, 1)
        else:
            paraphraser_probs_tokens = None

        model.train()  # re-enable dropout

        # compute length of valid output for each sentence
        n_tokens = 0
        for i in range(batch_size):
            for j in range(paraphraser_prediction_tokens.shape[1]):
                if paraphraser_prediction_tokens[i, j] == eos:
                    n_tokens += j  # TODO should this include EOS? HK
                    # set anything after EOS to PAD
                    paraphraser_prediction_tokens[
                        i, j + 1:paraphraser_prediction_tokens.shape[1]] = pad
                    break

        return combined_tokens, paraphraser_prediction_tokens, n_tokens, paraphraser_probs_tokens