def decode(self, src_sentence, trgt_sentence):
    """Force-decode ``trgt_sentence`` given ``src_sentence``.

    The fixed target (terminated with EOS) is stored on the decoder,
    then a single hypothesis is expanded step by step until it emits
    EOS, scored, and returned as the sole full hypothesis.
    """
    self.trgt_sentence = trgt_sentence + [utils.EOS_ID]
    self.initialize_predictor(src_sentence)
    cur_hypo = PartialHypothesis(self.get_predictor_states(),
                                 self.calculate_stats)
    # Keep expanding until the hypothesis ends with the EOS token.
    while cur_hypo.get_last_word() != utils.EOS_ID:
        self._expand_hypo(cur_hypo)
    cur_hypo.score = self.get_adjusted_score(cur_hypo)
    self.add_full_hypo(cur_hypo.generate_full_hypothesis())
    return self.get_full_hypos_sorted()
def decode(self, src_sentence, seed=0):
    """Sample ``self.nbest`` hypotheses independently in lockstep.

    Each hypothesis starts from its own deep copy of the predictor
    states and is expanded one token per timestep with a distinct seed
    offset. Hypotheses that reach EOS are finalized and removed from
    the active set; anything still unfinished at ``self.max_len`` is
    finalized as-is.
    """
    self.initialize_predictor(src_sentence)
    active = [
        PartialHypothesis(copy.deepcopy(self.get_predictor_states()))
        for _ in range(self.nbest)
    ]
    step = 0
    while active and step < self.max_len:
        survivors = []
        for offset, hypo in enumerate(active):
            if hypo.get_last_word() == utils.EOS_ID:
                # Finished: score, record, and drop from the active set.
                hypo.score = self.get_adjusted_score(hypo)
                self.add_full_hypo(hypo.generate_full_hypothesis())
            else:
                # Per-hypothesis seed offset keeps the samples distinct.
                self._expand_hypo(hypo, seed=seed + offset)
                survivors.append(hypo)
        active = survivors
        step += 1
    # Finalize hypotheses that hit the length limit without emitting EOS.
    for hypo in active:
        hypo.score = self.get_adjusted_score(hypo)
        self.add_full_hypo(hypo.generate_full_hypothesis())
    return self.get_full_hypos_sorted()
def decode(self, src_sentence):
    """Decodes a single source sentence using A* search.

    Pops hypotheses from a (possibly capacity-bounded) open set in the
    order determined by ``self.pop``, finalizes those ending in EOS, and
    expands the rest. Returns up to ``self.nbest`` full hypotheses,
    falling back to the precomputed lower-bound hypothesis if the search
    exhausts the open set without producing any.
    """
    self.initialize_predictor(src_sentence)
    # Fallback hypothesis used when no full hypothesis is found.
    self.lower_bound = self.get_empty_hypo(
    ) if self.use_lower_bound else None
    self.cur_capacity = self.capacity
    # Bounded MinMaxHeap when a positive capacity is given, else a plain list.
    open_set = MinMaxHeap(
        reserve=self.capacity) if self.capacity > 0 else []
    self.push(
        open_set, 0.0,
        PartialHypothesis(self.get_predictor_states(),
                          self.calculate_stats))
    while open_set:
        c, hypo = self.pop(open_set)
        if hypo.get_last_word() == utils.EOS_ID:
            # Found best hypothesis
            hypo.score = self.get_adjusted_score(hypo)
            self.add_full_hypo(hypo.generate_full_hypothesis())
            if len(self.full_hypos
                   ) == self.nbest:  # if we have enough hypos
                return self.get_full_hypos_sorted()
            # Each finished hypothesis shrinks the remaining budget.
            self.cur_capacity -= 1
            continue
        if len(hypo) == self.max_len:  # discard and continue
            continue
        # Expand and push children scored with the adjusted (search) score.
        for next_hypo in self._expand_hypo(hypo, self.capacity):
            score = self.get_adjusted_score(next_hypo)
            self.push(open_set, score, next_hypo)
    # NOTE(review): this fallback dereferences self.lower_bound, which is
    # None unless use_lower_bound is set -- presumably the two options are
    # configured together; verify against callers.
    if not self.full_hypos:
        self.add_full_hypo(self.lower_bound.generate_full_hypothesis())
    return self.get_full_hypos_sorted()
def decode(self, src_sentence):
    """Greedy decoding: append the single best-scoring (or, with Gumbel
    noise, sampled) token at every step until EOS or the length limit.

    Returns the list of full hypotheses (exactly one).
    """
    self.initialize_predictor(src_sentence)
    hypo = PartialHypothesis(self.get_predictor_states())
    while (hypo.get_last_word() != utils.EOS_ID
           and len(hypo) < self.max_len):
        # With Gumbel noise the predictor receives the hypothesis itself;
        # otherwise it relies only on the consumed history.
        predictor_arg = hypo if self.gumbel else None
        ids, posterior, original_posterior = self.apply_predictor(
            predictor_arg, 1)
        token = ids[0]
        if self.gumbel:
            # Track the un-perturbed score separately from the Gumbel score.
            hypo.base_score += original_posterior[0]
            hypo.score_breakdown.append(original_posterior[0])
        else:
            hypo.score += posterior[0]
            hypo.score_breakdown.append(posterior[0])
        hypo.trgt_sentence.append(token)
        self.consume(token)
    self.add_full_hypo(hypo.generate_full_hypothesis())
    return self.full_hypos
def initialize_order_ds(self):
    """Build the search-order data structures: one priority queue per
    hypothesis length, the queue-order pointer queue, and the per-length
    expansion budget, then seed everything with the BOS hypothesis."""
    self.queues = [MinMaxHeap() for _ in range(self.max_len + 1)]
    self.queue_order = PointerQueue([0.0], reserve=self.max_len)
    # Budget per timestep: the beam width when positive, else unbounded.
    self.time_sync = defaultdict(
        lambda: self.beam if self.beam > 0 else utils.INF)
    # Seed length 0 with the BOS hypothesis.
    bos_hypo = PartialHypothesis(self.get_predictor_states())
    self.queues[0].insert((0.0, bos_hypo))
    self.time_sync[0] = 1
def _get_initial_hypos(self):
    """Get the list of initial ``PartialHypothesis``, partitioned into
    consecutive groups according to ``self.group_sizes``."""
    bos_hypo = PartialHypothesis(self.get_predictor_states())
    expansions = self._expand_hypo(bos_hypo, self.beam_size)
    # Slice the flat expansion list into groups of the configured sizes.
    groups = []
    start = 0
    for size in self.group_sizes:
        groups.append(expansions[start:start + size])
        start += size
    return groups
def _get_initial_hypos(self):
    """Return the single BOS ``PartialHypothesis`` the search starts from."""
    initial_states = self.get_predictor_states()
    return [PartialHypothesis(initial_states)]
def _get_initial_hypos(self):
    """Return one singleton BOS hypothesis list per group, each backed by
    an independent deep copy of the predictor states."""
    groups = []
    for _ in range(self.num_groups):
        states = copy.deepcopy(self.get_predictor_states())
        groups.append([PartialHypothesis(states, self.calculate_stats)])
    return groups