class FastTranslator(Translator):
    """
    A fast implementation of the beam-search-based translator.
    Based on the Fairseq implementation.
    """

    def __init__(self, opt):
        super().__init__(opt)

        # source-side special tokens
        self.src_bos = onmt.constants.SRC_BOS
        self.src_eos = onmt.constants.SRC_EOS
        self.src_pad = onmt.constants.SRC_PAD
        self.src_unk = onmt.constants.SRC_UNK

        # target-side special tokens
        self.tgt_bos = self.bos_id
        self.tgt_pad = onmt.constants.TGT_PAD
        self.tgt_eos = onmt.constants.TGT_EOS
        self.tgt_unk = onmt.constants.TGT_UNK

        self.search = BeamSearch(self.tgt_dict)

        self.vocab_size = self.tgt_dict.size()
        self.min_len = 1
        self.normalize_scores = opt.normalize
        self.len_penalty = opt.alpha
        self.buffering = not opt.no_buffering

        self.no_repeat_ngram_size = getattr(opt, 'no_repeat_ngram_size', 0)
        self.dynamic_max_len = getattr(opt, 'dynamic_max_len', False)
        self.dynamic_max_len_scale = getattr(opt, 'dynamic_max_len_scale', 1.2)

        if opt.verbose:
            print("src bos id is %d; src eos id is %d; src pad id is %d; src unk id is %d"
                  % (self.src_bos, self.src_eos, self.src_pad, self.src_unk))
            print("tgt bos id is %d; tgt eos id is %d; tgt pad id is %d; tgt unk id is %d"
                  % (self.tgt_bos, self.tgt_eos, self.tgt_pad, self.tgt_unk))
            print('* Using fast beam search implementation')

        # optional target vocabulary filter: only the listed words (plus EOS and UNK)
        # may be generated
        if opt.vocab_list:
            word_list = [line.strip() for line in open(opt.vocab_list).readlines()]

            self.filter = torch.Tensor(self.tgt_dict.size()).zero_()
            for word_idx in [self.tgt_eos, self.tgt_unk]:
                self.filter[word_idx] = 1
            for word in word_list:
                idx = self.tgt_dict.lookup(word)
                if idx is not None:
                    self.filter[idx] = 1

            self.filter = self.filter.bool()
            if opt.cuda:
                self.filter = self.filter.cuda()
            self.use_filter = True
        else:
            self.use_filter = False

        # optional additional ("sub") models that are ensembled with the main models
        if opt.sub_model:
            self.sub_models = list()
            self.sub_model_types = list()

            # model paths are given in a single string with '|' as delimiter
            sub_models = opt.sub_model.split("|")
            print("Loading sub models ...")
            self.n_sub_models = len(sub_models)
            self.sub_type = 'text'

            for i, model_path in enumerate(sub_models):
                checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)

                model_opt = checkpoint['opt']
                model_opt = backward_compatible(model_opt)
                if hasattr(model_opt, "enc_not_load_state"):
                    model_opt.enc_not_load_state = True
                    model_opt.dec_not_load_state = True

                dicts = checkpoint['dicts']

                # update special tokens
                onmt.constants = add_tokenidx(model_opt, onmt.constants, dicts)

                # BE CAREFUL: the sub-models might mismatch with the main models
                # in terms of language dict and would then require re-matching
                if i == 0:
                    if "src" in checkpoint['dicts']:
                        self.src_dict = checkpoint['dicts']['src']

                if opt.verbose:
                    print('Loading sub-model from %s' % model_path)

                model = build_model(model_opt, checkpoint['dicts'])
                optimize_model(model)
                model.load_state_dict(checkpoint['model'])

                if model_opt.model in model_list:
                    # make sure the positional buffer is long enough to decode max_sent_length tokens
                    model.renew_buffer(self.opt.max_sent_length)

                if opt.fp16:
                    model = model.half()

                if opt.cuda:
                    model = model.cuda()
                else:
                    model = model.cpu()

                if opt.dynamic_quantile == 1:
                    engines = torch.backends.quantized.supported_engines
                    if 'fbgemm' in engines:
                        torch.backends.quantized.engine = 'fbgemm'
                    else:
                        print("[INFO] fbgemm is not found in the available engines. "
                              "Possibly the CPU does not support AVX2. "
                              "It is recommended to disable quantization (set to 0).")
                        torch.backends.quantized.engine = 'qnnpack'

                    model = torch.quantization.quantize_dynamic(
                        model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
                    )

                model.eval()
                self.sub_models.append(model)
                self.sub_model_types.append(model_opt.model)
        else:
            self.n_sub_models = 0
            self.sub_models = []

        # ensemble weights: one per main model, optionally followed by one per sub-model
        if opt.ensemble_weight:
            ensemble_weight = [float(item) for item in opt.ensemble_weight.split("|")]
            assert len(ensemble_weight) == self.n_models

            if opt.sub_ensemble_weight:
                sub_ensemble_weight = [float(item) for item in opt.sub_ensemble_weight.split("|")]
                assert len(sub_ensemble_weight) == self.n_sub_models
                ensemble_weight = ensemble_weight + sub_ensemble_weight

            total = sum(ensemble_weight)
            self.ensemble_weight = [item / total for item in ensemble_weight]
        else:
            self.ensemble_weight = None

        print(self.main_model_opt)

    def translate_batch(self, batches, sub_batches=None):
        with torch.no_grad():
            return self._translate_batch(batches, sub_batches=sub_batches)

    def _translate_batch(self, batches, sub_batches):
        batch = batches[0]
        # Batch size is in a different location depending on the data type.
        beam_size = self.opt.beam_size
        bsz = batch_size = batch.size
        max_len = self.opt.max_sent_length

        gold_scores = batch.get('source').data.new(batch_size).float().zero_()
        gold_words = 0
        allgold_scores = []

        if batch.has_target:
            # Use the first model to score the gold target (also batches[0])
            model_ = self.models[0]
            gold_words, gold_scores, allgold_scores = model_.decode(batch)

        # (3) Start decoding
        # initialize buffers
        src = batch.get('source')
        scores = src.new(bsz * beam_size, max_len + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src.new(bsz * beam_size, max_len + 2).long().fill_(self.tgt_pad)
        tokens_buf = tokens.clone()
        tokens[:, 0].fill_(self.tgt_bos)  # first token is bos
        attn, attn_buf = None, None
        nonpad_idxs = None

        src_tokens = src.transpose(0, 1)  # batch x time
        src_lengths = (src_tokens.ne(self.src_eos) & src_tokens.ne(self.src_pad)).long().sum(dim=1)
        blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1)  # forward- and backward-compatible False mask
        prefix_tokens = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to
            the best possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is BOS
            assert not tokens_clone.eq(self.tgt_eos).any()
            tokens_clone[:, step] = self.tgt_eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                def get_hypo():
                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i]
                    else:
                        hypo_attn = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': None,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None

        # initialize the decoder states, including:
        # - expanding the context over the batch dimension: len_src x (B*beam) x H
        # - expanding the mask over the batch dimension: (B*beam) x len_src
        decoder_states = dict()
        sub_decoder_states = dict()  # for the sub-models

        for i in range(self.n_models):
            decoder_states[i] = self.models[i].create_decoder_state(batches[i], beam_size, type=2,
                                                                    buffering=self.buffering)

        if self.opt.sub_model:
            for i in range(self.n_sub_models):
                sub_decoder_states[i] = self.sub_models[i].create_decoder_state(sub_batches[i], beam_size,
                                                                                type=2,
                                                                                buffering=self.buffering)

        if self.dynamic_max_len:
            src_len = src.size(0)
            max_len = math.ceil(int(src_len) * self.dynamic_max_len_scale)

        # Start decoding
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    decoder_states[i]._reorder_incremental_state(reorder_state)
                for i, model in enumerate(self.sub_models):
                    sub_decoder_states[i]._reorder_incremental_state(reorder_state)

            decode_input = tokens[:, :step + 1]
            lprobs, avg_attn_scores = self._decode(decode_input, decoder_states,
                                                   sub_decoder_states=sub_decoder_states)
            avg_attn_scores = None

            if self.use_filter:
                # allowed words are marked with 1 in the filter, so mask out everything else
                lprobs.masked_fill_(~self.filter.unsqueeze(0), -math.inf)

            lprobs[:, self.tgt_pad] = -math.inf  # never select pad

            # handle min and max length constraints
            if step >= max_len:
                lprobs[:, :self.tgt_eos] = -math.inf
                lprobs[:, self.tgt_eos + 1:] = -math.inf
            elif step < self.min_len:
                lprobs[:, self.tgt_eos] = -math.inf

            # handle prefix tokens (possibly with different lengths);
            # prefix_tokens is a list of word ids
            if prefix_tokens is not None:
                if step == 0 and bsz == 1:
                    # run the decoder through the prefix_tokens,
                    # store the scores and the incremental states
                    pass
                else:
                    prefix_tokens = torch.tensor(prefix_tokens).type_as(tokens)
                    if step < prefix_tokens.size(1) and step < max_len:
                        prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
                        prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
                        prefix_mask = prefix_toks.ne(self.tgt_pad)
                        lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
                        lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
                            -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
                        )
                        # if the prefix includes eos, then we should make sure tokens and
                        # scores are the same across all beams
                        eos_mask = prefix_toks.eq(self.tgt_eos)
                        if eos_mask.any():
                            # validate that the first beam matches the prefix
                            first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
                            eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
                            target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
                            assert (first_beam == target_prefix).all()

                            def replicate_first_beam(tensor, mask):
                                tensor = tensor.view(-1, beam_size, tensor.size(-1))
                                tensor[mask] = tensor[mask][:, :1, :]
                                return tensor.view(-1, tensor.size(-1))

                            # copy tokens, scores and lprobs from the first beam to all beams
                            tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
                            scores = replicate_first_beam(scores, eos_mask_batch_dim)
                            lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
                    attn_buf = attn.clone()
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)

            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)

            if self.no_repeat_ngram_size > 0:
                def calculate_banned_tokens(bbsz_idx):
                    # before decoding the next token, prevent decoding of ngrams that have already appeared
                    ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
                    return gen_ngrams[bbsz_idx].get(ngram_index, [])

                if step + 2 - self.no_repeat_ngram_size >= 0:
                    # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                    banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
                else:
                    banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                for bbsz_idx in range(bsz * beam_size):
                    lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf

            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos (except for blacklisted ones)
            eos_mask = cand_indices.eq(self.tgt_eos)
            eos_mask[:, :beam_size][blacklist] = 0

            # only consider eos when it's among the top beam_size indices
            torch.masked_select(
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],
                out=eos_bbsz_idx,
            )

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:
                torch.masked_select(
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_scores,
                )
                finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores)
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero(as_tuple=False).squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                src_lengths = src_lengths[batch_idxs]
                blacklist = blacklist[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # blacklisted hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.
            active_mask = buffer('active_mask')
            eos_mask[:, :beam_size] |= blacklist
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_blacklist = buffer('active_hypos'), buffer('new_blacklist')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(new_blacklist, active_hypos)
            )

            # update blacklist to ignore any finalized hypos
            blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
            assert (~blacklist).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized, gold_scores, gold_words, allgold_scores

    def _decode(self, tokens, decoder_states, sub_decoder_states=None):
        # require batch first for everything
        outs = dict()
        attns = dict()

        for i in range(self.n_models):
            # the decoder output contains the log-prob distribution of the next step
            decoder_output = self.models[i].step(tokens, decoder_states[i])
            outs[i] = decoder_output['log_prob']
            attns[i] = decoder_output['coverage']

        for j in range(self.n_sub_models):
            sub_decoder_output = self.sub_models[j].step(tokens, sub_decoder_states[j])
            outs[self.n_models + j] = sub_decoder_output['log_prob']

        out = self._combine_outputs(outs, weight=self.ensemble_weight)

        # the generator may produce a distribution smaller than the full dictionary
        if self.vocab_size > out.size(-1):
            self.vocab_size = out.size(-1)

        # the attention weights are not used during decoding
        attn = None

        return out, attn

    def translate(self, src_data, tgt_data, sub_src_data=None, type='mt'):
        # (1) convert words to indexes
        if isinstance(src_data[0], list) and type == 'asr':
            batches = list()
            for src_data_ in src_data:
                dataset = self.build_data(src_data_, tgt_data, type=type)
                batch = dataset.get_batch(0)
                batches.append(batch)
        else:
            dataset = self.build_data(src_data, tgt_data, type=type)
            batch = dataset.get_batch(0)  # this dataset has only one mini-batch
            batches = [batch] * self.n_models
            src_data = [src_data] * self.n_models

        if sub_src_data is not None and len(sub_src_data) > 0:
            sub_dataset = self.build_data(sub_src_data, tgt_data, type='mt')
            sub_batch = sub_dataset.get_batch(0)
            sub_batches = [sub_batch] * self.n_sub_models
            sub_src_data = [sub_src_data] * self.n_sub_models
        else:
            sub_batches, sub_src_data = None, None

        batch_size = batches[0].size
        if self.cuda:
            for i, _ in enumerate(batches):
                batches[i].cuda(fp16=self.fp16)
            if sub_batches:
                for i, _ in enumerate(sub_batches):
                    sub_batches[i].cuda(fp16=self.fp16)

        # (2) translate
        finalized, gold_score, gold_words, allgold_words = self.translate_batch(batches, sub_batches=sub_batches)
        pred_length = []

        # (3) convert indexes to words
        pred_batch = []
        src_data = src_data[0]
        for b in range(batch_size):
            # an empty hypothesis list usually means the source was empty, so beam search stopped immediately
            if len(finalized[b]) == 0:
                assert len(src_data[b]) == 0, "The target search result is empty, assuming that the source is empty."
                pred_batch.append(
                    [self.build_target_tokens([], src_data[b], None)
                     for n in range(self.opt.n_best)]
                )
            else:
                pred_batch.append(
                    [self.build_target_tokens(finalized[b][n]['tokens'], src_data[b], None)
                     for n in range(self.opt.n_best)]
                )

        pred_score = []
        for b in range(batch_size):
            if len(finalized[b]) == 0:
                pred_score.append(
                    [torch.FloatTensor([0]) for n in range(self.opt.n_best)]
                )
            else:
                pred_score.append(
                    [torch.FloatTensor([finalized[b][n]['score']])
                     for n in range(self.opt.n_best)]
                )

        return pred_batch, pred_score, pred_length, gold_score, gold_words, allgold_words
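
# ---------------------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original API surface).
# A minimal outline of how FastTranslator is typically driven, assuming `opt` is the
# argparse namespace built by the usual translation options (model path(s), beam_size,
# n_best, max_sent_length, cuda, ...). The sample sentence below is hypothetical.
#
#   translator = FastTranslator(opt)
#   src_data = [["hello", "world"]]               # one pre-tokenized source sentence
#   pred_batch, pred_score, _, _, _, _ = translator.translate(src_data, [], type='mt')
#   best_tokens = pred_batch[0][0]                # best hypothesis of the first sentence
# ---------------------------------------------------------------------------------------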

class GlobalStreamTranslator(Translator):
    """
    A fast implementation of the beam-search-based translator for streaming translation.
    Based on the Fairseq implementation.
    """

    def __init__(self, opt):
        super().__init__(opt)
        self.search = BeamSearch(self.tgt_dict)
        self.eos = onmt.constants.EOS
        self.pad = onmt.constants.PAD
        self.bos = self.bos_id
        self.vocab_size = self.tgt_dict.size()
        self.min_len = 1
        self.normalize_scores = opt.normalize
        self.len_penalty = opt.alpha

        # streaming state carried over between consecutive calls
        self.decoder_states = defaultdict(lambda: None)
        self.prev_scores = torch.Tensor(self.opt.beam_size).fill_(0)
        self.prev_lengths = torch.LongTensor(self.opt.beam_size).fill_(0)

        self.no_repeat_ngram_size = getattr(opt, 'no_repeat_ngram_size', 0)
        self.dynamic_max_len = getattr(opt, 'dynamic_max_len', False)
        self.dynamic_max_len_scale = getattr(opt, 'dynamic_max_len_scale', 1.2)
        self.dynamic_min_len_scale = getattr(opt, 'dynamic_min_len_scale', 0.8)

        if opt.verbose:
            print('* Current bos id: %d' % self.bos_id, onmt.constants.BOS)
            print('* Using fast beam search implementation')

        self.max_memory_size = opt.max_memory_size
        for i in range(len(self.models)):
            self.models[i].set_memory_size(self.max_memory_size, self.max_memory_size)

    def reset_stream(self):
        self.decoder_states = defaultdict(lambda: None)

    def translateBatch(self, batch):
        with torch.no_grad():
            return self._translateBatch(batch)

    def _translateBatch(self, batch):
        # Batch size is in a different location depending on the data type.
        beam_size = self.opt.beam_size
        bsz = batch_size = batch.size
        max_len = self.opt.max_sent_length

        gold_scores = batch.get('source').data.new(batch_size).float().zero_()
        gold_words = 0
        allgold_scores = []

        if batch.has_target:
            # Use the first model to score the gold target
            model_ = self.models[0]
            gold_words, gold_scores, allgold_scores = model_.decode(batch)

        # (3) Start decoding
        # initialize buffers
        src = batch.get('source')
        scores = src.new(bsz * beam_size, max_len + 1).float().fill_(0)
        self.prev_scores = self.prev_scores.type_as(scores)
        self.prev_lengths = self.prev_lengths.to(scores.device)
        scores_buf = scores.clone()
        tokens = src.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
        beams = src.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
        tokens_buf = tokens.clone()
        beams_buf = beams.clone()
        tokens[:, 0].fill_(self.bos)  # first token is bos
        beams[:, 0].fill_(0)          # the first beam index is 0 for every hypothesis
        attn, attn_buf = None, None
        nonpad_idxs = None

        src_tokens = src.transpose(0, 1)  # batch x time
        src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
        blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1)  # forward- and backward-compatible False mask
        prefix_tokens = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to
            the best possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            beams_clone = beams.index_select(0, bbsz_idx)
            prev_lengths = self.prev_lengths.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is BOS
            beams_clone = beams_clone[:, 0:step + 2]
            assert not tokens_clone.eq(self.eos).any()
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            raw_scores = eos_scores.clone()
            # normalize sentence-level scores, taking the carried-over lengths into account
            if self.normalize_scores:
                eos_scores /= (step + 1 + prev_lengths) ** self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            assert len(self.decoder_states) == 1
            beam_buffers = self.decoder_states[0].get_beam_buffer(bbsz_idx)

            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]
                # sent and unfin_idx are both 0 when batch_size is 1,
                # until everything is finished
                sents_seen.add((sent, unfin_idx))

                def get_buffer():
                    buffer = dict()
                    for l in beam_buffers:
                        buffer[l] = dict()
                        # take the state of the i-th finished hypothesis
                        for key in beam_buffers[l]:
                            buffer[l][key] = beam_buffers[l][key][:, i, :].unsqueeze(1)
                    return buffer

                def get_hypo():
                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i]
                    else:
                        hypo_attn = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': None,
                        'positional_scores': pos_scores[i],
                        'hidden_buffer': get_buffer(),
                        'raw_score': raw_scores[i]
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None

        # initialize the decoder states, including:
        # - expanding the context over the batch dimension: len_src x (B*beam) x H
        # - expanding the mask over the batch dimension: (B*beam) x len_src
        for i in range(self.n_models):
            self.decoder_states[i] = self.models[i].create_decoder_state(
                batch, beam_size, previous_decoding_state=self.decoder_states[i], streaming=True)

        if self.dynamic_max_len:
            src_len = src.size(0)
            max_len = min(math.ceil(int(src_len) * self.dynamic_max_len_scale), self.opt.max_sent_length)
            min_len = math.ceil(int(src_len) * self.dynamic_min_len_scale)
        else:
            min_len = self.min_len

        # Start decoding
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    self.decoder_states[i]._reorder_incremental_state(reorder_state)

            decode_input = tokens[:, :step + 1]
            # lprobs size: [batch * beam x vocab_size]
            lprobs, avg_attn_scores = self._decode(decode_input, self.decoder_states)
            avg_attn_scores = None

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.bos] = -math.inf  # never select bos

            # handle min and max length constraints
            if step >= max_len:
                lprobs[:, :self.eos] = -math.inf
                lprobs[:, self.eos + 1:] = -math.inf
            elif step < min_len:
                lprobs[:, self.eos] = -math.inf

            # prefix-token handling (see FastTranslator) is disabled in the streaming translator

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
                    attn_buf = attn.clone()
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)

            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)

            if self.no_repeat_ngram_size > 0:
                def calculate_banned_tokens(bbsz_idx):
                    # before decoding the next token, prevent decoding of ngrams that have already appeared
                    ngram_index = tuple(
                        tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
                    return gen_ngrams[bbsz_idx].get(ngram_index, [])

                if step + 2 - self.no_repeat_ngram_size >= 0:
                    # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                    banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
                else:
                    banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                for bbsz_idx in range(bsz * beam_size):
                    lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf

            # the scores carried over from the previous segment serve as the initial beam scores
            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
                initial_score=self.prev_scores)

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            # when bsz = 1, cand_bbsz_idx is identical to cand_beams
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos (except for blacklisted ones)
            eos_mask = cand_indices.eq(self.eos)
            eos_mask[:, :beam_size][blacklist] = 0

            # only consider eos when it's among the top beam_size indices
            torch.masked_select(
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],
                out=eos_bbsz_idx,
            )
            # cand_bbsz_idx is a list of beam indices;
            # for batch_size 1, eos_bbsz_idx holds the beam indices in which eos was reached

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:
                torch.masked_select(
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_scores,
                )
                finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores)
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < max_len

            # if batch size == 1 then this block will not be touched
            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                src_lengths = src_lengths[batch_idxs]
                blacklist = blacklist[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # blacklisted hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.
            active_mask = buffer('active_mask')
            eos_mask[:, :beam_size] |= blacklist
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_blacklist = buffer('active_hypos'), buffer('new_blacklist')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(new_blacklist, active_hypos)
            )

            # update blacklist to ignore any finalized hypos
            blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
            assert (~blacklist).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.index_select(
                beams[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=beams_buf[:, :step + 1],
            )
            # add the cand_indices (words) into the token buffer of the last step
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=beams_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )

            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            beams, beams_buf = beams_buf, beams
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        # Re-encoding step: hand the hidden buffers of the finished hypotheses back
        # to the decoder state, so that the memory carries over to the next segment
        for sent in range(len(finalized)):
            for beam in range(len(finalized[sent])):
                tensor = finalized[sent][beam]['tokens']
                words = self.tgt_dict.convertToLabels(tensor, onmt.constants.EOS, including_stop=False)
                n_words = len(words)
                buffer_state = finalized[sent][beam]['hidden_buffer']
                sentence = " ".join(words)
                # self.prev_scores[beam].fill_(finalized[sent][beam]['raw_score'])
                # self.prev_lengths[beam].fill_(n_words + 2)

                # assign the buffers to the decoder state
                # (zero padding may still need to be handled here)
                self.decoder_states[sent].set_beam_buffer(finalized[sent])

        return finalized, gold_scores, gold_words, allgold_scores

    def _decode(self, tokens, decoder_states):
        # require batch first for everything
        outs = dict()
        attns = dict()

        for i in range(self.n_models):
            # streaming = True in this case
            decoder_output = self.models[i].step(tokens, decoder_states[i], streaming=True)
            # log_prob: batch * beam x vocab_size
            outs[i] = decoder_output['log_prob']
            attns[i] = decoder_output['coverage']

        out = self._combine_outputs(outs)
        attn = self._combine_attention(attns)
        # the combined attention is currently not used during decoding
        attn = None

        return out, attn

    def translate(self, src_data, tgt_data, type='mt'):
        # (1) convert words to indexes
        dataset = self.build_data(src_data, tgt_data, type=type)
        batch = dataset.next()[0]
        if self.cuda:
            batch.cuda(fp16=self.fp16)
        batch_size = batch.size

        # (2) translate
        finalized, gold_score, gold_words, allgold_words = self.translateBatch(batch)
        pred_length = []

        # (3) convert indexes to words
        pred_batch = []
        for b in range(batch_size):
            pred_batch.append([
                self.build_target_tokens(finalized[b][n]['tokens'], src_data[b], None)
                for n in range(self.opt.n_best)
            ])

        pred_score = []
        for b in range(batch_size):
            pred_score.append([
                torch.FloatTensor([finalized[b][n]['score']])
                for n in range(self.opt.n_best)
            ])

        return pred_batch, pred_score, pred_length, gold_score, gold_words, allgold_words
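
# ---------------------------------------------------------------------------------------
# Streaming usage sketch (illustrative only; not part of the original API surface).
# GlobalStreamTranslator keeps self.decoder_states between calls, so consecutive segments
# of the same document share the decoder memory (bounded by opt.max_memory_size), and
# reset_stream() clears that memory before a new document. The loop below is a hypothetical
# driver; `documents` is assumed to be a list of documents, each a list of pre-tokenized
# source segments.
#
#   translator = GlobalStreamTranslator(opt)
#   for document in documents:
#       translator.reset_stream()
#       for segment in document:
#           pred_batch, pred_score, _, _, _, _ = translator.translate([segment], [])
# ---------------------------------------------------------------------------------------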