def _greedy_decode(self):
    """Performs greedy decoding from the start node. Used to obtain the
    initial hypothesis.

    Starting from an empty ``PartialHypothesis``, the next position is
    repeatedly scored with ``self.apply_predictors()``. The highest-scoring
    word that is still available in a working copy of ``self.full_bag`` is
    then chosen. Each word may be used as often as its count in the bag.
    When the bag is exhausted, the hypothesis is closed with
    ``utils.EOS_ID``.

    Side effects:
        Calls ``self.consume`` for every chosen word and sets
        ``self.best_score`` to the score of the greedy hypothesis. Also
        registers the complete hypothesis via ``self.add_full_hypo``.
        Finally, it passes a ``FlipCandidate`` to ``self._process_new_hypos``,
        together with the sentence length, the per-position partial
        hypotheses, the posteriors and the score breakdowns.
    """
    hypo = PartialHypothesis()
    hypos = []
    posteriors = []
    score_breakdowns = []
    scores = []
    # Working copy of the bag; word counts are decremented as words are used.
    bag = dict(self.full_bag)
    while bag:
        posterior, score_breakdown = self.apply_predictors()
        # Snapshot the predictor states so that the following consume() call
        # cannot alter what is stored for this hypothesis.
        hypo.predictor_states = copy.deepcopy(self.get_predictor_states())
        hypos.append(hypo)
        posteriors.append(posterior)
        score_breakdowns.append(score_breakdown)
        # Only words still left in the bag are eligible.
        best_word = utils.argmax({w: posterior[w] for w in bag})
        bag[best_word] -= 1
        if bag[best_word] < 1:
            del bag[best_word]
        self.consume(best_word)
        hypo = hypo.expand(best_word, None, posterior[best_word],
                           score_breakdown[best_word])
        scores.append(posterior[best_word])
    posterior, score_breakdown = self.apply_predictors()
    # Deep-copy here as well, consistent with the loop above. Storing the
    # live states would let later predictor updates mutate the state
    # recorded for this final (pre-EOS) hypothesis.
    hypo.predictor_states = copy.deepcopy(self.get_predictor_states())
    hypos.append(hypo)
    posteriors.append(posterior)
    score_breakdowns.append(score_breakdown)
    hypo = hypo.expand(utils.EOS_ID, None, posterior[utils.EOS_ID],
                       score_breakdown[utils.EOS_ID])
    logging.debug("Greedy hypo (%f): %s",
                  hypo.score,
                  ' '.join(str(w) for w in hypo.trgt_sentence))
    scores.append(posterior[utils.EOS_ID])
    self.best_score = hypo.score
    self.add_full_hypo(hypo.generate_full_hypothesis())
    self._process_new_hypos(
        FlipCandidate(hypo.trgt_sentence, scores,
                      self._create_dummy_bigrams(), hypo.score),
        len(hypo.trgt_sentence),
        hypos,
        posteriors,
        score_breakdowns)
def _greedy_decode(self):
    """Greedily decodes a single hypothesis from the start node. The
    result is used to collect initial bigram statistics.

    At every position, including the final one before EOS, the current
    partial hypothesis is scored with ``self.apply_predictors()``, and a
    deep copy of the predictor states is attached to it. The posterior and
    score breakdown are recorded, restricted to ``self.full_bag_with_eos``.
    The best word still available in a working copy of ``self.full_bag`` is
    then consumed. After the bag runs out, EOS is appended via
    ``cheap_expand``. All recorded data is handed to
    ``self._process_new_hypos``.
    """
    node = PartialHypothesis()
    hypos, posteriors, score_breakdowns = [], [], []
    remaining = dict(self.full_bag)

    def record(partial):
        # Score successors of `partial`, snapshot the predictor states and
        # log the bag-restricted posterior and breakdown for this position.
        full_post, full_breakdown = self.apply_predictors()
        partial.predictor_states = copy.deepcopy(self.get_predictor_states())
        restricted_post = {w: full_post[w] for w in self.full_bag_with_eos}
        restricted_breakdown = {
            w: full_breakdown[w] for w in self.full_bag_with_eos}
        posteriors.append(restricted_post)
        score_breakdowns.append(restricted_breakdown)
        hypos.append(partial)
        return restricted_post, full_breakdown

    while remaining:
        post, breakdown = record(node)
        word = utils.argmax({w: post[w] for w in remaining})
        remaining[word] -= 1
        if remaining[word] < 1:
            del remaining[word]
        self.consume(word)
        node = node.expand(word, None, post[word], breakdown[word])

    post, breakdown = record(node)
    eos = utils.EOS_ID
    node = node.cheap_expand(eos, post[eos], breakdown[eos])
    logging.debug("Greedy hypo (%f): %s" % (
        node.score, ' '.join(map(str, node.trgt_sentence))))
    self._process_new_hypos(hypos, posteriors, score_breakdowns, node)