Example #1
class FastTranslator(Translator):
    """
    A fast implementation of the beam-search-based translator.
    Based on the Fairseq implementation.
    """

    def __init__(self, opt):
        # Configures fast beam-search decoding from `opt`: special-token ids, search
        # hyper-parameters, an optional output-vocabulary filter, optional sub-models
        # and ensemble weights.
        # NOTE(review): this listing appears to have lost lines. No base-class
        # initialisation is visible although self.bos_id / self.tgt_dict / self.n_models
        # are read below, and several `else:` branches, call heads and closing
        # parentheses seem to be missing. Confirm against the upstream file.


        # self.eos = onmt.constants.EOS
        # self.pad = onmt.constants.PAD
        # self.bos = self.bos_id

        self.src_bos = onmt.constants.SRC_BOS
        self.src_eos = onmt.constants.SRC_EOS
        self.src_pad = onmt.constants.SRC_PAD
        self.src_unk = onmt.constants.SRC_UNK

        # the target-side bos comes from self.bos_id rather than from onmt.constants
        self.tgt_bos = self.bos_id
        self.tgt_pad = onmt.constants.TGT_PAD
        self.tgt_eos = onmt.constants.TGT_EOS
        self.tgt_unk = onmt.constants.TGT_UNK

        self.search = BeamSearch(self.tgt_dict)

        self.vocab_size = self.tgt_dict.size()
        self.min_len = 1
        self.normalize_scores = opt.normalize
        self.len_penalty = opt.alpha
        self.buffering = not opt.no_buffering
        # self.buffering = False  # buffering is currently bugged

        if hasattr(opt, 'no_repeat_ngram_size'):
            self.no_repeat_ngram_size = opt.no_repeat_ngram_size
            # NOTE(review): presumably an `else:` is missing before the next line;
            # as written the default overwrites the value taken from opt.
            self.no_repeat_ngram_size = 0

        if hasattr(opt, 'dynamic_max_len'):
            self.dynamic_max_len = opt.dynamic_max_len
            # NOTE(review): same pattern — missing `else:` suspected.
            self.dynamic_max_len = False

        if hasattr(opt, 'dynamic_max_len_scale'):
            self.dynamic_max_len_scale = opt.dynamic_max_len_scale
            # NOTE(review): same pattern — missing `else:` suspected.
            self.dynamic_max_len_scale = 1.2

        if opt.verbose:
            # print('* Current bos id is: %d, default bos id is: %d' % (self.tgt_bos, onmt.constants.BOS))
            print("src bos id is %d; src eos id is %d;  src pad id is %d; src unk id is %d"
                  % (self.src_bos, self.src_eos, self.src_pad, self.src_unk))
            print("tgt bos id is %d; tgt eos id is %d;  tgt_pad id is %d; tgt unk id is %d"
                  % (self.tgt_bos, self.tgt_eos, self.tgt_pad, self.tgt_unk))
            print('* Using fast beam search implementation')

        # Optional vocabulary filter: a boolean mask over the target dictionary with
        # True for every word that may be generated.
        if opt.vocab_list:
            word_list = list()
            # NOTE(review): the file handle is never closed, and as written `word` is
            # never added to word_list — presumably an append line is missing.
            for line in open(opt.vocab_list).readlines():
                word = line.strip()

            self.filter = torch.Tensor(self.tgt_dict.size()).zero_()
            # eos and unk are always marked as allowed
            for word_idx in [self.tgt_eos, self.tgt_unk]:
                self.filter[word_idx] = 1

            for word in word_list:
                idx = self.tgt_dict.lookup(word)
                if idx is not None:
                    self.filter[idx] = 1

            self.filter = self.filter.bool()
            # print(self.filter)
            if opt.cuda:
                self.filter = self.filter.cuda()

            self.use_filter = True
            # NOTE(review): presumably the `else:` of `if opt.vocab_list:` is missing
            # here; as written the filter is always disabled again.
            self.use_filter = False

        # Optional sub-models, given as checkpoint paths separated by "|".
        if opt.sub_model:
            self.sub_models = list()
            self.sub_model_types = list()

            # models are string with | as delimiter
            sub_models = opt.sub_model.split("|")

            print("Loading sub models ... ")
            self.n_sub_models = len(sub_models)
            self.sub_type = 'text'

            for i, model_path in enumerate(sub_models):
                # map_location keeps every loaded tensor on its deserialised (CPU) storage
                checkpoint = torch.load(model_path,
                                        map_location=lambda storage, loc: storage)

                model_opt = checkpoint['opt']
                model_opt = backward_compatible(model_opt)
                if hasattr(model_opt, "enc_not_load_state"):
                    model_opt.enc_not_load_state = True
                    model_opt.dec_not_load_state = True

                dicts = checkpoint['dicts']

                # update special tokens
                onmt.constants = add_tokenidx(model_opt, onmt.constants, dicts)
                # self.bos_token = model_opt.tgt_bos_word

                """"BE CAREFUL: the sub-models might mismatch with the main models in terms of language dict"""
                """"REQUIRE RE-matching"""

                # only the first sub-model's source dictionary is kept on self
                if i == 0:
                    if "src" in checkpoint['dicts']:
                        self.src_dict = checkpoint['dicts']['src']
                #     else:
                #         self._type = "audio"
                #     self.tgt_dict = checkpoint['dicts']['tgt']
                #     if "langs" in checkpoint["dicts"]:
                #         self.lang_dict = checkpoint['dicts']['langs']
                #     else:
                #         self.lang_dict = {'src': 0, 'tgt': 1}
                #     self.bos_id = self.tgt_dict.labelToIdx[self.bos_token]
                if opt.verbose:
                    print('Loading sub-model from %s' % model_path)

                model = build_model(model_opt, checkpoint['dicts'])

                # NOTE(review): this `if` has only comments as its body in this listing;
                # the statement(s) it guarded (weight loading?) appear to be missing.
                if model_opt.model in model_list:
                    # if model.decoder.positional_encoder.len_max < self.opt.max_sent_length:
                    #     print("Not enough len to decode. Renewing .. ")
                    #     model.decoder.renew_buffer(self.opt.max_sent_length)

                if opt.fp16:
                    model = model.half()

                if opt.cuda:
                    model = model.cuda()
                    # NOTE(review): presumably an `else:` is missing; as written the
                    # model is moved straight back to the CPU.
                    model = model.cpu()

                if opt.dynamic_quantile == 1:

                    engines = torch.backends.quantized.supported_engines
                    if 'fbgemm' in engines:
                        torch.backends.quantized.engine = 'fbgemm'
                        # NOTE(review): the string fragments below look like the tail of
                        # a print(...) inside a missing `else:` branch that falls back
                        # to 'qnnpack'.
                            "[INFO] fbgemm is not found in the available engines. "
                            " Possibly the CPU does not support AVX2."
                            " It is recommended to disable Quantization (set to 0).")
                        torch.backends.quantized.engine = 'qnnpack'

                    # NOTE(review): the closing parenthesis of this call is missing, and
                    # no line that stores `model` in self.sub_models is visible.
                    model = torch.quantization.quantize_dynamic(
                        model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8


            # NOTE(review): presumably the `else:` of `if opt.sub_model:` is missing;
            # as written the sub-model list is always reset to empty.
            self.n_sub_models = 0
            self.sub_models = []

        # Ensemble weights are "|"-separated floats, one per model, normalised to sum to 1.
        if opt.ensemble_weight:
            ensemble_weight = [float(item) for item in opt.ensemble_weight.split("|")]
            assert len(ensemble_weight) == self.n_models

            if opt.sub_ensemble_weight:
                sub_ensemble_weight = [float(item) for item in opt.sub_ensemble_weight.split("|")]
                assert len(sub_ensemble_weight) == self.n_sub_models
                # sub-model weights follow the main-model weights
                ensemble_weight = ensemble_weight + sub_ensemble_weight

            total = sum(ensemble_weight)
            self.ensemble_weight = [ item / total for item in ensemble_weight]
            # NOTE(review): presumably an `else:` is missing; as written the normalised
            # weights are immediately discarded.
            self.ensemble_weight = None


    def translate_batch(self, batches, sub_batches=None):
        """Run beam search over ``batches`` with autograd disabled.

        Thin wrapper: forwards both arguments to ``_translate_batch`` inside a
        ``torch.no_grad()`` context and returns its result unchanged.
        """
        no_grad_ctx = torch.no_grad()
        with no_grad_ctx:
            decoded = self._translate_batch(batches, sub_batches=sub_batches)
        return decoded

    def _translate_batch(self, batches, sub_batches):
        # Beam-search decoding for one mini-batch; sizes and the source tensor are
        # taken from batches[0]. Returns (finalized, gold_scores, gold_words,
        # allgold_scores), where each finalized[i] is sorted by its entries' 'score'
        # in descending order.
        # NOTE(review): many lines of this listing appear to have been dropped
        # (docstring quotes, `else:` branches, call heads such as torch.masked_select /
        # torch.gather / torch.topk, closing parentheses, loop bodies). The notes below
        # mark the visible gaps; confirm every one against the upstream file.
        batch = batches[0]
        # Batch size is in different location depending on data.

        beam_size = self.opt.beam_size
        bsz = batch_size = batch.size

        max_len = self.opt.max_sent_length

        gold_scores = batch.get('source').data.new(batch_size).float().zero_()
        gold_words = 0
        allgold_scores = []

        if batch.has_target:
            # Use the first model to decode (also batches[0])
            model_ = self.models[0]

            gold_words, gold_scores, allgold_scores = model_.decode(batch)

        #  (3) Start decoding

        # initialize buffers
        src = batch.get('source')
        scores = src.new(bsz * beam_size, max_len + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src.new(bsz * beam_size, max_len + 2).long().fill_(self.tgt_pad)
        tokens_buf = tokens.clone()
        tokens[:, 0].fill_(self.tgt_bos)  # first token is bos
        attn, attn_buf = None, None
        nonpad_idxs = None
        src_tokens = src.transpose(0, 1)  # batch x time
        # source length = number of positions that are neither src eos nor src pad
        src_lengths = (src_tokens.ne(self.src_eos) & src_tokens.ne(self.src_pad)).long().sum(dim=1)
        blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1)  # forward and backward-compatible False mask
        prefix_tokens = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            # lazily creates one empty tensor per name, with the type of `type_of`
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            # NOTE(review): the three prose lines below look like a docstring whose
            # triple quotes were lost. The visible code only checks whether beam_size
            # hypotheses have been finalized for `sent`.
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores):
            # NOTE(review): the prose lines below look like a docstring whose triple
            # quotes (and probably "Args:" / "Returns:" headings) were lost.
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            assert not tokens_clone.eq(self.tgt_eos).any()
            tokens_clone[:, step] = self.tgt_eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            # NOTE(review): as written cum_unfin is never appended to, yet it is indexed
            # below — presumably an `else: cum_unfin.append(prev)` was lost.
            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                # if self.match_source_len and step > src_lengths[unfin_idx]:
                #     score = -math.inf

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i]
                        # NOTE(review): presumably an `else:` is missing before the next line
                        hypo_attn = None

                    # NOTE(review): the closing brace of this dict is missing
                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': None,
                        'positional_scores': pos_scores[i],

                # NOTE(review): the body of this `if` is missing — presumably it appended
                # get_hypo() to finalized[sent].
                if len(finalized[sent]) < beam_size:

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
                    # NOTE(review): as written nothing is ever added to newly_finished
            return newly_finished

        reorder_state = None
        batch_idxs = None

        # initialize the decoder state, including:
        # - expanding the context over the batch dimension len_src x (B*beam) x H
        # - expanding the mask over the batch dimension    (B*beam) x len_src
        decoder_states = dict()
        sub_decoder_states = dict()  # for sub-model
        # NOTE(review): both create_decoder_state(...) calls below are cut off before
        # their remaining arguments and closing parenthesis.
        for i in range(self.n_models):
            decoder_states[i] = self.models[i].create_decoder_state(batches[i], beam_size, type=2,
        if self.opt.sub_model:
            for i in range(self.n_sub_models):
                sub_decoder_states[i] = self.sub_models[i].create_decoder_state(sub_batches[i], beam_size, type=2,

        # optionally derive the length limit from the source length instead of opt.max_sent_length
        if self.dynamic_max_len:
            src_len = src.size(0)
            max_len = math.ceil(int(src_len) * self.dynamic_max_len_scale)

        # Start decoding
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                # NOTE(review): both loops below have no body in this listing — presumably
                # they reordered each (sub-)model's decoder state with reorder_state.
                for i, model in enumerate(self.models):
                for i, model in enumerate(self.sub_models):

            decode_input = tokens[:, :step + 1]

            # NOTE(review): the _decode(...) call is cut off before its last argument
            lprobs, avg_attn_scores = self._decode(decode_input, decoder_states,
            avg_attn_scores = None

            if self.use_filter:
                # the marked words are 1, so fill the reverse to inf
                lprobs.masked_fill_(~self.filter.unsqueeze(0), -math.inf)
            lprobs[:, self.tgt_pad] = -math.inf  # never select pad

            # handle min and max length constraints

            # at the length limit only eos remains selectable; before min_len eos is forbidden
            if step >= max_len:
                lprobs[:, :self.tgt_eos] = -math.inf
                lprobs[:, self.tgt_eos + 1:] = -math.inf
            elif step < self.min_len:
                lprobs[:, self.tgt_eos] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            # here prefix tokens is a list of word-ids
            # NOTE(review): prefix_tokens is set to None above and never reassigned in
            # the visible code, so this branch looks unreachable as shown.
            if prefix_tokens is not None:

                if step == 0 and bsz == 1:
                    # run the decoder through the prefix_tokens
                    # store the scores and store the incremental states
                    prefix_tokens = torch.tensor(prefix_tokens).type_as(tokens)

                    if step < prefix_tokens.size(1) and step < max_len:
                        prefix_tokens = torch.tensor(prefix_tokens).type_as(tokens)
                        prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
                        prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
                        prefix_mask = prefix_toks.ne(self.tgt_pad)
                        lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)

                        # NOTE(review): the closing parenthesis of this scatter(...) is missing
                        lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
                            -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]

                        # if prefix includes eos, then we should make sure tokens and
                        # scores are the same across all beams
                        eos_mask = prefix_toks.eq(self.tgt_eos)
                        if eos_mask.any():
                            # validate that the first beam matches the prefix
                            first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
                            eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
                            target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
                            assert (first_beam == target_prefix).all()

                            def replicate_first_beam(tensor, mask):
                                tensor = tensor.view(-1, beam_size, tensor.size(-1))
                                tensor[mask] = tensor[mask][:, :1, :]
                                return tensor.view(-1, tensor.size(-1))

                            # copy tokens, scores and lprobs from the first beam to all beams
                            tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
                            scores = replicate_first_beam(scores, eos_mask_batch_dim)
                            lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    # maps each (n-1)-gram prefix to the list of tokens that followed it
                    for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
                    attn_buf = attn.clone()
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)

            if self.no_repeat_ngram_size > 0:
                def calculate_banned_tokens(bbsz_idx):
                    # before decoding the next token, prevent decoding of ngrams that have already appeared
                    ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
                    return gen_ngrams[bbsz_idx].get(ngram_index, [])

                if step + 2 - self.no_repeat_ngram_size >= 0:
                    # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                    banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
                    # NOTE(review): presumably an `else:` is missing; as written the banned
                    # lists computed above are replaced by empty lists.
                    banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                for bbsz_idx in range(bsz * beam_size):
                    lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf

            # NOTE(review): the closing parenthesis of search.step(...) is missing
            cand_scores, cand_indices, cand_beams = self.search.step(
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos (except for blacklisted ones)
            eos_mask = cand_indices.eq(self.tgt_eos)
            eos_mask[:, :beam_size][blacklist] = 0

            # only consider eos when it's among the top beam_size indices
            # NOTE(review): the two argument lines below have lost their call head and
            # closing parenthesis (presumably a masked select writing into eos_bbsz_idx).
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:
                # NOTE(review): orphaned arguments again — presumably the matching masked
                # select that fills eos_scores.
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores)
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            # NOTE(review): this `if` has no body in the listing — presumably a `break`
            if num_remaining_sent == 0:
            # assert step < max_len
            if len(finalized_sents) > 0:

                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero(as_tuple=False).squeeze(-1)

                # drop the rows of finished sentences from every per-sentence tensor
                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                # if prefix_tokens is not None:
                #     prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                blacklist = blacklist[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                bsz = new_bsz
                # NOTE(review): presumably an `else:` is missing before the next line; as
                # written batch_idxs is cleared right after being computed.
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # blacklisted hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.

            active_mask = buffer('active_mask')
            eos_mask[:, :beam_size] |= blacklist
            # NOTE(review): orphaned argument — the call that fills active_mask is missing
                eos_mask.type_as(cand_offsets) * cand_size,

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_blacklist = buffer('active_hypos'), buffer('new_blacklist')
            # NOTE(review): orphaned arguments — call head and closing parenthesis missing
                active_mask, k=beam_size, dim=1, largest=False,
                out=(new_blacklist, active_hypos)

            # update blacklist to ignore any finalized hypos
            blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
            assert (~blacklist).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            # NOTE(review): orphaned arguments below; the torch.gather(...) that follows
            # also lacks its closing parenthesis.
                cand_bbsz_idx, dim=1, index=active_hypos,
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            # NOTE(review): every line in this group is an orphaned argument list; the
            # calls that copy into tokens_buf / scores_buf are missing.
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            if step > 0:
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],

            # copy attention for active hypotheses
            if attn is not None:
                # NOTE(review): orphaned arguments — call head missing
                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens

            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized, gold_scores, gold_words, allgold_scores

    def _decode(self, tokens, decoder_states, sub_decoder_states=None):

        # require batch first for everything
        outs = dict()
        attns = dict()

        for i in range(self.n_models):
            # decoder output contains the log-prob distribution of the next step
            decoder_output = self.models[i].step(tokens, decoder_states[i])

            outs[i] = decoder_output['log_prob']
            attns[i] = decoder_output['coverage']

        for j in range(self.n_sub_models):
            sub_decoder_output = self.sub_models[j].step(tokens, sub_decoder_states[j])
            outs[self.n_models + j] = sub_decoder_output['log_prob']

        out = self._combine_outputs(outs, weight=self.ensemble_weight)
        # attn = self._combine_attention(attns)

        if self.vocab_size > out.size(-1):
            self.vocab_size = out.size(-1)  # what the hell ?
        # attn = attn[:, -1, :] # I dont know what this line does
        attn = None  # attn is never used in decoding probably

        return out, attn

    def translate(self, src_data, tgt_data, sub_src_data=None, type='mt'):
        # Translates one mini-batch end to end: builds batches from the raw data,
        # runs translate_batch, then converts the finalized hypotheses back to words.
        # Returns (pred_batch, pred_score, pred_length, gold_score, gold_words,
        # allgold_words).
        # NOTE(review): this listing appears to have lost lines (`else:` branches, loop
        # bodies, the assignments that fill pred_batch / pred_score). Confirm against
        # the upstream file.
        #  (1) convert words to indexes
        if isinstance(src_data[0], list) and type == 'asr':
            batches = list()
            for src_data_ in src_data:
                dataset = self.build_data(src_data_, tgt_data, type=type)
                batch = dataset.get_batch(0)
                # NOTE(review): as written `batch` is never added to `batches`, and the
                # lines below presumably belonged to a missing `else:` branch.
            dataset = self.build_data(src_data, tgt_data, type=type)
            batch = dataset.get_batch(0)  # this dataset has only one mini-batch
            # the same batch object is shared by every model
            batches = [batch] * self.n_models
            src_data = [src_data] * self.n_models

        if sub_src_data is not None and len(sub_src_data) > 0:
            sub_dataset = self.build_data(sub_src_data, tgt_data, type='mt')
            sub_batch = sub_dataset.get_batch(0)
            sub_batches = [sub_batch] * self.n_sub_models
            sub_src_data = [sub_src_data] * self.n_sub_models
            # NOTE(review): presumably an `else:` is missing; as written the sub-batches
            # built above are discarded.
            sub_batches, sub_src_data = None, None

        batch_size = batches[0].size
        if self.cuda:
            # NOTE(review): both loops below have no body in this listing — presumably
            # they moved each batch to the GPU.
            for i, _ in enumerate(batches):
            if sub_batches:
                for i, _ in enumerate(sub_batches):

        #  (2) translate
        finalized, gold_score, gold_words, allgold_words = self.translate_batch(batches, sub_batches=sub_batches)
        pred_length = []

        #  (3) convert indexes to words
        pred_batch = []
        src_data = src_data[0]
        for b in range(batch_size):

            # probably when the src is empty so beam search stops immediately
            if len(finalized[b]) == 0:
                assert len(src_data[b]) == 0, "The target search result is empty, assuming that the source is empty."
                # NOTE(review): the two comprehensions below have lost the statements
                # that consumed them (presumably pred_batch.append(...) in an if/else).
                    [self.build_target_tokens([], src_data[b], None)
                     for n in range(self.opt.n_best)]
                    [self.build_target_tokens(finalized[b][n]['tokens'], src_data[b], None)
                     for n in range(self.opt.n_best)]
        pred_score = []
        for b in range(batch_size):
            if len(finalized[b]) == 0:
                # NOTE(review): only the tails of two comprehensions remain here; the code
                # that fills pred_score is missing.
                     for n in range(self.opt.n_best)]
                     for n in range(self.opt.n_best)]

        return pred_batch, pred_score, pred_length, gold_score, gold_words, allgold_words
Example #2
class GlobalStreamTranslator(Translator):
    """
    A fast implementation of the Beam Search based translator
    Based on Fairseq implementation
    """
    def __init__(self, opt):
        # Configure the beam-search settings of this translator from `opt`.
        # NOTE(review): self.tgt_dict, self.bos_id, self.opt and self.models are
        # read below but never assigned in this method; presumably a base-class
        # initializer is expected to have run first -- no such call is visible
        # here, confirm.

        self.search = BeamSearch(self.tgt_dict)
        self.eos = onmt.constants.EOS
        self.pad = onmt.constants.PAD
        self.bos = self.bos_id
        self.vocab_size = self.tgt_dict.size()
        self.min_len = 1
        self.normalize_scores = opt.normalize
        self.len_penalty = opt.alpha
        # Decoder states are looked up by model index; a missing index yields None.
        self.decoder_states = defaultdict(lambda: None)
        # One zero-initialised score slot and one length slot per beam.
        self.prev_scores = torch.Tensor(self.opt.beam_size).fill_(0)
        self.prev_lengths = torch.LongTensor(self.opt.beam_size).fill_(0)

        # NOTE(review): in each of the four hasattr blocks below the second
        # assignment immediately overwrites the value just read from `opt`, and
        # nothing is assigned when the attribute is absent. This looks like a
        # lost `else:` line in front of the default value -- confirm.
        if hasattr(opt, 'no_repeat_ngram_size'):
            self.no_repeat_ngram_size = opt.no_repeat_ngram_size
            self.no_repeat_ngram_size = 0

        if hasattr(opt, 'dynamic_max_len'):
            self.dynamic_max_len = opt.dynamic_max_len
            self.dynamic_max_len = False

        if hasattr(opt, 'dynamic_max_len_scale'):
            self.dynamic_max_len_scale = opt.dynamic_max_len_scale
            self.dynamic_max_len_scale = 1.2

        if hasattr(opt, 'dynamic_min_len_scale'):
            self.dynamic_min_len_scale = opt.dynamic_min_len_scale
            self.dynamic_min_len_scale = 0.8

        if opt.verbose:
            # Only self.bos_id is substituted into the format string;
            # onmt.constants.BOS is passed to print() as a second positional
            # argument and is therefore printed after the formatted message.
            print('* Current bos id: %d' % self.bos_id, onmt.constants.BOS)
            print('* Using fast beam search implementation')

        self.max_memory_size = opt.max_memory_size

        # NOTE(review): this loop has no visible body; the per-model statement
        # that belongs here appears to be missing.
        for i in range(len(self.models)):

    def reset_stream(self):
        """Discard every cached decoder state.

        ``decoder_states`` is rebound to a brand-new ``defaultdict`` whose
        missing keys evaluate to ``None``.
        """
        fresh_states = defaultdict(lambda: None)
        self.decoder_states = fresh_states

    def translateBatch(self, batch):
        """Run ``_translateBatch`` on ``batch`` with autograd disabled.

        The result of ``_translateBatch`` is returned unchanged.
        """
        grad_off = torch.no_grad()
        with grad_off:
            outcome = self._translateBatch(batch)
        return outcome

    def _translateBatch(self, batch):
        # Beam-search decode one batch and return the tuple
        # (finalized, gold_scores, gold_words, allgold_scores).
        # `finalized` is a per-sentence list of finished hypotheses.
        # NOTE(review): several statements in this method look truncated
        # (unclosed calls, branches and loops without bodies); the notes below
        # flag the places where the visible code cannot be what was intended.

        # Batch size is in different location depending on data.

        beam_size = self.opt.beam_size
        bsz = batch_size = batch.size

        max_len = self.opt.max_sent_length

        # Defaults returned when the batch carries no target side.
        gold_scores = batch.get('source').data.new(batch_size).float().zero_()
        gold_words = 0
        allgold_scores = []

        if batch.has_target:
            # Use the first model to decode
            model_ = self.models[0]

            gold_words, gold_scores, allgold_scores = model_.decode(batch)

        #  (3) Start decoding

        # initialize buffers
        src = batch.get('source')
        scores = src.new(bsz * beam_size, max_len + 1).float().fill_(0)
        # Bring the per-beam running score/length tensors to the dtype/device of
        # the score buffer.
        self.prev_scores = self.prev_scores.type_as(scores)
        self.prev_lengths = self.prev_lengths.to(scores.device)
        scores_buf = scores.clone()
        tokens = src.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
        beams = src.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
        tokens_buf = tokens.clone()
        beams_buf = beams.clone()

        tokens[:, 0].fill_(self.bos)  # first token is bos
        beams[:, 0].fill_(0)  # first one is the same ...
        attn, attn_buf = None, None
        nonpad_idxs = None
        src_tokens = src.transpose(0, 1)  # batch x time
        # Source length = number of positions that are neither eos nor pad.
        src_lengths = (src_tokens.ne(self.eos)
                       & src_tokens.ne(self.pad)).long().sum(dim=1)
        blacklist = src_tokens.new_zeros(bsz, beam_size).eq(
            -1)  # forward and backward-compatible False mask
        prefix_tokens = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        # NOTE(review): the bbsz_offsets expression below is cut off after `*`
        # (unclosed parenthesis); the right-hand operand is missing.
        bbsz_offsets = (torch.arange(0, bsz) *
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            # NOTE(review): the three prose lines below read like a docstring
            # whose quotes were lost. The visible body only checks whether
            # beam_size hypotheses have been finalized for `sent`; `step` and
            # `unfinalized_scores` are not used.
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores):
            # NOTE(review): the prose lines below read like a docstring whose
            # quotes (and an "Args:" header) were lost.
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            beams_clone = beams.index_select(0, bbsz_idx)
            prev_lengths = self.prev_lengths.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step +
                                        2]  # skip the first index, which is EOS
            beams_clone = beams_clone[:, 0:step + 2]
            # No eos may appear before the final position; eos is then written
            # at index `step` of the clipped token rows.
            assert not tokens_clone.eq(self.eos).any()
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(
                0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # Keep the un-normalized scores; eos_scores is divided in place below.
            raw_scores = eos_scores.clone()

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1 + prev_lengths)**self.len_penalty

            # NOTE(review): cum_unfin is never appended to in the visible loop,
            # yet it is indexed further down; an `else:` branch that appends
            # `prev` for unfinished sentences appears to be missing.
            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1

            sents_seen = set()

            # Only a single decoder state (index 0) is supported here.
            assert len(self.decoder_states) == 1
            beam_buffers = self.decoder_states[0].get_beam_buffer(bbsz_idx)

            for i, (idx, score) in enumerate(
                    zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                # looks like sent and unfin_idx are both 0 when batch_size is 1 ...
                # until everything is finished
                sents_seen.add((sent, unfin_idx))

                def get_buffer():
                    # Slice entry `i` (dimension 1) out of every tensor in
                    # beam_buffers, keeping that dimension with size 1.

                    buffer = dict()
                    for l in beam_buffers:
                        buffer[l] = dict()

                        # take that state
                        for key in beam_buffers[l]:
                            buffer[l][key] = beam_buffers[l][
                                key][:, i, :].unsqueeze(1)

                    return buffer

                def get_hypo():
                    # Build the result dict for hypothesis `i`.

                    # NOTE(review): hypo_attn is overwritten with None right
                    # after being set, and is undefined when attn_clone is None;
                    # an `else:` line appears to be missing.
                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i]
                        hypo_attn = None

                    # NOTE(review): the dict literal below is never closed.
                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': None,
                        'positional_scores': pos_scores[i],
                        'hidden_buffer': get_buffer(),
                        'raw_score': raw_scores[i]

                # NOTE(review): no body is visible for this condition;
                # presumably get_hypo() is appended to finalized[sent] -- confirm.
                if len(finalized[sent]) < beam_size:

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                # NOTE(review): unfin_idx is passed in the position of
                # is_finished's unfinalized_scores parameter, and nothing is
                # ever added to newly_finished in the visible code.
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
            return newly_finished

        reorder_state = None
        batch_idxs = None

        # initialize the decoder state, including:
        # - expanding the context over the batch dimension len_src x (B*beam) x H
        # - expanding the mask over the batch dimension    (B*beam) x len_src
        for i in range(self.n_models):
            # decoder_states[i] = self.models[i].create_decoder_state(batch, beam_size, type=2, streaming=False)
            # NOTE(review): the call below is cut off (arguments missing).
            self.decoder_states[i] = self.models[i].create_decoder_state(

        # NOTE(review): the max_len expression is unclosed and min_len is
        # immediately overwritten with self.min_len; an `else:` line appears to
        # be missing before the last assignment.
        if self.dynamic_max_len:
            src_len = src.size(0)
            max_len = min(math.ceil(int(src_len) * self.dynamic_max_len_scale),
            min_len = math.ceil(int(src_len) * self.dynamic_min_len_scale)
            min_len = self.min_len

        # Start decoding
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size)
                # NOTE(review): the per-model reorder statement is missing here.
                for i, model in enumerate(self.models):

            decode_input = tokens[:, :step + 1]
            # lprobs size: [batch x beam x vocab_size]
            # NOTE(review): the _decode call is cut off, and avg_attn_scores is
            # set to None right after it, so the attention-recording branch
            # further down is never entered in the visible code.
            lprobs, avg_attn_scores = self._decode(decode_input,
            avg_attn_scores = None

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.bos] = -math.inf  # never select bos ...

            # handle min and max length constraints
            # At or beyond max_len only eos stays selectable; before min_len
            # eos is forbidden.
            if step >= max_len:
                lprobs[:, :self.eos] = -math.inf
                lprobs[:, self.eos + 1:] = -math.inf
            elif step < min_len:
                lprobs[:, self.eos] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            # if prefix_tokens is not None and step < prefix_tokens.size(1):
            #     prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
            #     prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
            #     prefix_mask = prefix_toks.ne(self.pad)
            #     lprobs[prefix_mask] = -math.inf
            #     lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
            #         -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs
            #     )
            #     # if prefix includes eos, then we should make sure tokens and
            #     # scores are the same across all beams
            #     eos_mask = prefix_toks.eq(self.eos)
            #     if eos_mask.any():
            #         # validate that the first beam matches the prefix
            #         first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
            #         eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
            #         target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
            #         assert (first_beam == target_prefix).all()
            #         def replicate_first_beam(tensor, mask):
            #             tensor = tensor.view(-1, beam_size, tensor.size(-1))
            #             tensor[mask] = tensor[mask][:, :1, :]
            #             return tensor.view(-1, tensor.size(-1))
            #         # copy tokens, scores and lprobs from the first beam to all beams
            #         tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
            #         scores = replicate_first_beam(scores, eos_mask_batch_dim)
            #         lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                # gen_ngrams[row] maps an (n-1)-token prefix to the list of
                # tokens that have followed it in that row.
                # NOTE(review): the zip(*[...]) expression below is truncated.
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[
                            for i in range(self.no_repeat_ngram_size)
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1),
                                      max_len + 2)
                    attn_buf = attn.clone()
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)

            if self.no_repeat_ngram_size > 0:

                def calculate_banned_tokens(bbsz_idx):
                    # before decoding the next token, prevent decoding of ngrams that have already appeared
                    ngram_index = tuple(
                        tokens[bbsz_idx, step + 2 -
                               self.no_repeat_ngram_size:step + 1].tolist())
                    return gen_ngrams[bbsz_idx].get(ngram_index, [])

                # NOTE(review): the list comprehension is truncated and
                # banned_tokens is then reset to empty lists unconditionally; an
                # `else:` line appears to be missing.
                if step + 2 - self.no_repeat_ngram_size >= 0:
                    # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                    banned_tokens = [
                        for bbsz_idx in range(bsz * beam_size)
                    banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                for bbsz_idx in range(bsz * beam_size):
                    lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf

            # NOTE(review): the search.step call below is not closed.
            cand_scores, cand_indices, cand_beams = self.search.step(
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]

            # when bsz = 1, cand_bbsz_idx is not different than cand_beams
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos (except for blacklisted ones)
            eos_mask = cand_indices.eq(self.eos)
            eos_mask[:, :beam_size][blacklist] = 0

            # only consider eos when it's among the top beam_size indices
            # NOTE(review): the call that takes the next two arguments is missing.
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],

            # so: cand_bbsz_idx is a list of beam indices
            # eos_bbsz_idx in the case of batch_size 1: a list of beam_indices in which the eos is reached

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                finalized_sents = finalize_hypos(step, eos_bbsz_idx,
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            # NOTE(review): the body of the following `if` is missing
            # (presumably it leaves the decoding loop -- confirm).
            if num_remaining_sent == 0:
            assert step < max_len

            # if batch size == 1 then this block will not be touched
            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                # Drop the rows of finished sentences from every per-sentence tensor.
                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                # if prefix_tokens is not None:
                #     prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                blacklist = blacklist[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1)
                bsz = new_bsz
                # NOTE(review): batch_idxs is cleared in the same branch that
                # just computed it, so the correction at the top of the loop can
                # never see it; this looks like a lost `else:` -- confirm.
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # blacklisted hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.
            active_mask = buffer('active_mask')
            eos_mask[:, :beam_size] |= blacklist
                eos_mask.type_as(cand_offsets) * cand_size,

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_blacklist = buffer('active_hypos'), buffer(
                       out=(new_blacklist, active_hypos))

            # update blacklist to ignore any finalized hypos
            blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
            # Every sentence must keep at least one non-blacklisted hypothesis.
            assert (~blacklist).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            active_scores = torch.gather(
                out=scores[:, step].view(bsz, beam_size),

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            # NOTE(review): the calls that own the argument fragments below are
            # missing. Also note the beams destination is indexed
            # `[:, step + 1]` while the tokens destination is `[:, :step + 1]`;
            # check whether the missing colon is intended.
                tokens[:, :step + 1],
                out=tokens_buf[:, :step + 1],

                beams[:, :step + 1],
                out=beams_buf[:, step + 1],

            # add the cand_indices (words) into the token buffer of the last step
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],

                out=beams_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            # print(cand_indices.size(), cand_bbsz_idx.size())

            if step > 0:
                    scores[:, :step],
                    out=scores_buf[:, :step],
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],

            # copy attention for active hypotheses
            if attn is not None:
                    attn[:, :, :step + 2],
                    out=attn_buf[:, :, :step + 2],

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            beams, beams_buf = beams_buf, beams
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending

        # Re-encoding step
        # for beam in range(self.opt.beam_size):
        #     " batch size = 1 "
        #     tensor = finalized[0][beam]['tokens']
        #     words = " ".join(self.tgt_dict.convertToLabels(tensor, onmt.constants.EOS, including_stop=False))
        #     beam_org = finalized[0][beam]['beam_origin']
        #     print(beam_org, words)

        # NOTE(review): the sorted(...) call is not closed; going by the
        # "sort by score descending" comment above, a reverse flag presumably
        # followed the key argument.
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent],
                                     key=lambda r: r['score'],

        # NOTE(review): words, n_words, buffer_state and sentence are computed
        # per hypothesis but not used by any visible statement.
        for sent in range(len(finalized)):
            for beam in range(len(finalized[sent])):
                tensor = finalized[sent][beam]['tokens']
                words = self.tgt_dict.convertToLabels(tensor,
                n_words = len(words)
                buffer_state = finalized[sent][beam]['hidden_buffer']
                sentence = " ".join(words)
                # self.prev_scores[beam].fill_(finalized[sent][beam]['raw_score'])
                # self.prev_lengths[beam].fill_(n_words + 2)

            # assign the buffers to the decoder_states
            # at this point, we need to somehow make zero padding

        # self.decoder_states = defaultdict(lambda: None)

        # Should we do it before sorting, or after sorting
        # Step 1: revert the memory of the decoder to the starting point
        # Done. they are the buffer_state

        # Step 3: Re-select the buffer (

        # print(tensor)

        return finalized, gold_scores, gold_words, allgold_scores

    def _decode(self, tokens, decoder_states):
        # Run one decoding step on every model and combine the results.
        # Returns (out, attn): `out` is the result of _combine_outputs over the
        # per-model 'log_prob' entries; `attn` is always None (see below).

        # require batch first for everything
        outs = dict()
        attns = dict()

        for i in range(self.n_models):
            # streaming = True in this case
            # NOTE(review): the step(...) call below is cut off; the remaining
            # arguments are not visible.
            decoder_output = self.models[i].step(tokens,

            # take the last decoder state
            # decoder_hidden = decoder_hidden.squeeze(1)
            # attns[i] = coverage[:, -1, :].squeeze(1)  # batch * beam x src_len

            # batch * beam x vocab_size
            # outs[i] = self.models[i].generator(decoder_hidden)
            outs[i] = decoder_output['log_prob']
            attns[i] = decoder_output['coverage']

        out = self._combine_outputs(outs)
        # The combined attention is computed and then discarded: attn is reset
        # to None before returning.
        attn = self._combine_attention(attns)
        # attn = attn[:, -1, :] # I dont know what this line means
        attn = None  # lol this is never used probably

        return out, attn

    def translate(self, src_data, tgt_data, type='mt'):
        # Translate src_data in three stages: build a batch, run translateBatch,
        # and convert the results back to words.
        # Returns (pred_batch, pred_score, pred_length, gold_score, gold_words,
        # allgold_words).
        # NOTE(review): several statements below are truncated (an `if` with no
        # body, an unclosed call, loop bodies reduced to fragments), so the
        # exact contents of pred_batch and pred_score cannot be read from here.

        #  (1) convert words to indexes
        dataset = self.build_data(src_data, tgt_data, type=type)
        batch = dataset.next()[0]
        # NOTE(review): the statement guarded by self.cuda is missing.
        if self.cuda:
        # ~ batch = self.to_variable(dataset.next()[0])
        batch_size = batch.size

        #  (2) translate
        finalized, gold_score, gold_words, allgold_words = self.translateBatch(
        # pred_length is returned as an empty list; nothing visible fills it.
        pred_length = []

        #  (3) convert indexes to words
        pred_batch = []
        for b in range(batch_size):
                                         src_data[b], None)
                for n in range(self.opt.n_best)
        pred_score = []
        for b in range(batch_size):
                for n in range(self.opt.n_best)

        return pred_batch, pred_score, pred_length, gold_score, gold_words, allgold_words