Example #1
File: greedy.py Project: ucam-smt/sgnmt
 def decode(self, src_sentence):
     """Decode a single source sentence in a greedy way: Always take
     the highest scoring word as next word and proceed to the next
     position. This makes it possible to decode without using the 
     predictors ``get_state()`` and ``set_state()`` methods as we
     do not have to keep track of predictor states.
     
     Args:
         src_sentence (list): List of source word ids without <S> or
                              </S> which make up the source sentence
     
     Returns:
         list. A list containing the single best ``Hypothesis`` instance."""
     self.initialize_predictors(src_sentence)
     trgt_sentence = []
     score_breakdown = []
     trgt_word = None
     score = 0.0
     while trgt_word != utils.EOS_ID and len(trgt_sentence) <= self.max_len:
         posterior,breakdown = self.apply_predictors(1)
         trgt_word = utils.argmax(posterior)
         score += posterior[trgt_word]
         trgt_sentence.append(trgt_word)
         logging.debug("Partial hypothesis (%f): %s" % (
                 score, " ".join([str(i) for i in trgt_sentence]))) 
         score_breakdown.append(breakdown[trgt_word])
         self.consume(trgt_word)
     self.add_full_hypo(Hypothesis(trgt_sentence, score, score_breakdown))
     return self.full_hypos
Example #2
 def decode(self, src_sentence):
     """Decode a single source sentence in a greedy way: Always take
     the highest scoring word as next word and proceed to the next
     position. This makes it possible to decode without using the 
     predictors ``get_state()`` and ``set_state()`` methods as we
     do not have to keep track of predictor states.
     
     Args:
         src_sentence (list): List of source word ids without <S> or
                              </S> which make up the source sentence
     
     Returns:
         list. A list containing the single best ``Hypothesis`` instance."""
     self.initialize_predictors(src_sentence)
     trgt_sentence = []
     score_breakdown = []
     trgt_word = None
     score = 0.0
     while trgt_word != utils.EOS_ID and len(trgt_sentence) <= self.max_len:
         posterior, breakdown = self.apply_predictors(1)
         trgt_word = utils.argmax(posterior)
         score += posterior[trgt_word]
         trgt_sentence.append(trgt_word)
         logging.debug("Partial hypothesis (%f): %s" %
                       (score, " ".join([str(i) for i in trgt_sentence])))
         score_breakdown.append(breakdown[trgt_word])
         self.consume(trgt_word)
     self.add_full_hypo(Hypothesis(trgt_sentence, score, score_breakdown))
     return self.full_hypos
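
Both variants above implement the same loop, so the pattern is worth spelling out once: score the next-word posterior, take the argmax, add its log-probability to the running score, and stop at the end-of-sentence symbol or the length limit. Below is a minimal, self-contained sketch of that pattern; the toy scoring function, vocabulary ids, and scores are illustrative assumptions, not SGNMT code.

 # Minimal sketch of the greedy loop above. The scoring function and all
 # numbers are toy stand-ins for SGNMT's predictor interface.
 EOS_ID = 2
 MAX_LEN = 10

 def toy_apply_predictors(prefix):
     """Toy posterior: log-probabilities over a three-word vocabulary."""
     if len(prefix) < 3:
         return {5: -0.125, 6: -2.0, EOS_ID: -5.0}
     return {5: -3.0, 6: -2.5, EOS_ID: -0.25}

 def greedy_decode():
     sentence, score, word = [], 0.0, None
     while word != EOS_ID and len(sentence) <= MAX_LEN:
         posterior = toy_apply_predictors(sentence)
         word = max(posterior, key=posterior.get)  # same job as utils.argmax
         score += posterior[word]                  # accumulate log-probability
         sentence.append(word)
     return sentence, score

 print(greedy_decode())  # ([5, 5, 5, 2], -0.625)
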
Example #3
File: decoder.py Project: chagge/sgnmt
 def greedy_decode(self, hypo):
     """Helper function for greedy decoding from a certain point in
     the search tree."""
     best_word = hypo.trgt_sentence[-1]
     prev_hypo = hypo
     while best_word != utils.EOS_ID:
         self.consume(best_word)
         posterior,score_breakdown = self.apply_predictors()
         if len(posterior) < 1:
             return
         best_word = utils.argmax(posterior)
         best_word_score = posterior[best_word]
         new_hypo = prev_hypo.expand(best_word,
                                     None,
                                     best_word_score,
                                     score_breakdown[best_word])
         if new_hypo.score < self.best_score: # Admissible pruning
             return
         logging.debug("Expanded hypo: score=%f prefix=%s" % (
                                                 new_hypo.score,
                                                 new_hypo.trgt_sentence))
         if len(posterior) > 1:
             posterior.pop(best_word)
             children = sorted([RestartingChild(w,
                                                posterior[w],
                                                score_breakdown[w])
                 for w in posterior], key=lambda c: c.score, reverse=True)
             prev_hypo.predictor_states = copy.deepcopy(
                                             self.get_predictor_states())
             heappush(self.open_nodes,
                      (best_word_score-children[0].score,
                       RestartingNode(prev_hypo, children)))
         prev_hypo = new_hypo
     self.hypos.append(prev_hypo.generate_full_hypothesis())
     self.best_score = max(self.best_score, prev_hypo.score)
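
The distinctive part of this decoder is the heap bookkeeping: each unexplored sibling of the greedy choice is pushed with the score gap to that choice as its priority, so the search later resumes at the cheapest deviation from the greedy path. Here is a self-contained sketch with toy numbers; the tie-breaking counter is a defensive addition for the sketch (on equal costs, heapq would otherwise try to compare the payloads), not part of the original code.

 # Sketch of the backtracking heap above, with toy log-probabilities.
 # Priority = score gap to the greedy choice, so cheaper deviations
 # from the greedy path are explored first.
 from heapq import heappush, heappop
 from itertools import count

 posterior = {7: -0.1, 3: -1.4, 9: -2.2}       # toy posterior
 best_word = max(posterior, key=posterior.get)  # greedy choice: word 7
 best_score = posterior.pop(best_word)

 open_nodes, tie = [], count()
 for word, score in posterior.items():
     # The counter keeps heap tuples comparable when two costs are equal.
     heappush(open_nodes, (best_score - score, next(tie), word))

 while open_nodes:
     cost, _, word = heappop(open_nodes)
     print("resume at word %d (detour cost %.1f)" % (word, cost))
 # resume at word 3 (detour cost 1.3)
 # resume at word 9 (detour cost 2.1)
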
Example #4
File: bow.py Project: ucam-smt/sgnmt
 def create_initial_node(self):
     """Create the root node for the search tree. """
     init_hypo = PartialHypothesis()
     posterior,score_breakdown = self.apply_predictors()
     best_word = utils.argmax(posterior)
     init_hypo.predictor_states = self.get_predictor_states()
     init_node = BOWNode(init_hypo, posterior, score_breakdown, [])
     self._add_to_heap(init_node, best_word, 0.0) # Expected score irrelevant 
Example #5
File: bow.py Project: ml-lab/sgnmt
 def create_initial_node(self):
     """Create the root node for the search tree. """
     init_hypo = PartialHypothesis()
     posterior, score_breakdown = self.apply_predictors()
     best_word = utils.argmax(posterior)
     init_hypo.predictor_states = self.get_predictor_states()
     init_node = BOWNode(init_hypo, posterior, score_breakdown, [])
     self._add_to_heap(init_node, best_word,
                       0.0)  # Expected score irrelevant
Example #6
File: bow.py Project: ml-lab/sgnmt
    def greedy_decode(self, node, word, single_step):
        """Helper function for greedy decoding from a certain point in
        the search tree."""

        prev_hypo = node.hypo.expand(word, None, node.posterior[word],
                                     node.score_breakdown[word])
        prev_nodes = node.prev_nodes + [node]

        best_word = word
        while ((prev_hypo.score > self.best_score or not self.early_stopping)
               and best_word != utils.EOS_ID):
            self.consume(best_word)
            posterior, score_breakdown = self.apply_predictors()
            if len(posterior) < 1:
                return
            self._update_best_word_scores(posterior)
            best_word = utils.argmax(posterior)
            best_word_score = posterior[best_word]
            new_hypo = prev_hypo.expand(best_word, None, best_word_score,
                                        score_breakdown[best_word])
            logging.debug("Expanded hypo: len=%d score=%f prefix= %s" %
                          (len(new_hypo.trgt_sentence), new_hypo.score,
                           ' '.join([str(w) for w in new_hypo.trgt_sentence])))
            node = BOWNode(prev_hypo, posterior, score_breakdown,
                           list(prev_nodes))
            if not single_step:
                del node.active_arcs[best_word]
                if len(node.active_arcs) > 0:
                    prev_hypo.predictor_states = copy.deepcopy(
                        self.get_predictor_states())
            else:
                prev_hypo.predictor_states = self.get_predictor_states()
            prev_nodes.append(node)
            prev_hypo = new_hypo
            if single_step:
                break
        full_hypo_score = prev_hypo.score
        if best_word == utils.EOS_ID:  # Full hypo
            self.add_full_hypo(prev_hypo.generate_full_hypothesis())
            self.best_score = max(self.best_score, full_hypo_score)
        else:
            full_hypo_score = self._estimate_full_hypo_score(prev_hypo)
        # Update the heap
        sen = prev_hypo.trgt_sentence
        l = len(sen)
        for pos in xrange(l):
            worst_scores = {}
            for p in xrange(pos, l):
                worst_scores[sen[p]] = min(worst_scores.get(sen[p], 0.0),
                                           prev_nodes[p].posterior[sen[p]])
            node = prev_nodes[pos]
            for w in node.active_arcs:
                expected_score = self._estimate_full_hypo_score(
                    node.hypo.cheap_expand(w, node.posterior[w],
                                           node.score_breakdown[w]))
                self._add_to_heap(node, w, expected_score)
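
A detail worth noting above is when predictor states are snapshotted: a deep copy is taken only if the node keeps unexplored arcs and decoding continues, because the decoder's live states move on while the node may be revisited later; in the single-step case the loop breaks immediately, so a plain reference suffices. The helper below is a hypothetical distillation of that rule, not SGNMT API.

 # Hypothetical helper distilling the state-snapshot rule above; not part
 # of SGNMT. States are deep-copied only when the node can be revisited
 # after the decoder's own states have moved on.
 import copy

 def save_states(active_arcs, current_states, single_step):
     if single_step:
         # Caller breaks out of the loop right away; live states stay valid.
         return current_states
     if active_arcs:
         # Node may be re-expanded later: isolate a snapshot.
         return copy.deepcopy(current_states)
     return None  # nothing left to expand from this node

 print(save_states({4: -1.0}, {"step": 3}, single_step=False))  # {'step': 3}
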
Example #7
 def greedy_decode(self, hypo):
     """Helper function for greedy decoding from a certain point in
     the search tree."""
     best_word = hypo.trgt_sentence[-1]
     prev_hypo = hypo
     remaining_exps = max(self.max_expansions - self.apply_predictors_count,
                          1)
     while (best_word != utils.EOS_ID 
            and len(prev_hypo.trgt_sentence) <= self.max_len):
         self.consume(best_word)
         posterior,score_breakdown = self.apply_predictors()
         if len(posterior) < 1:
             return
         best_word = utils.argmax(posterior)
         best_word_score = posterior[best_word]
         new_hypo = prev_hypo.expand(best_word,
                                     None,
                                     best_word_score,
                                     score_breakdown[best_word])
         if new_hypo.score < self.best_score: # Admissible pruning
             return
         logging.debug("Expanded hypo: score=%f prefix= %s" % (
                         new_hypo.score,
                         ' '.join([str(w) for w in new_hypo.trgt_sentence])))
         if len(posterior) > 1:
             if not self.always_single_step:
                 posterior.pop(best_word)
             children = sorted([RestartingChild(w,
                                                posterior[w],
                                                score_breakdown[w])
                 for w in posterior], key=lambda c: c.score, reverse=True)
             children = children[:remaining_exps]
             node_cost = self.get_node_cost(0.0, 
                                            best_word_score, 
                                            children[0].score)
             if node_cost <= self.max_heap_node_cost:
                 prev_hypo.predictor_states = copy.deepcopy(
                                             self.get_predictor_states())
                 heappush(self.open_nodes, (node_cost,
                                            RestartingNode(prev_hypo,
                                                           children)))
         prev_hypo = new_hypo
         if self.always_single_step:
             break
     if best_word == utils.EOS_ID:
         self.add_full_hypo(prev_hypo.generate_full_hypothesis())
         if prev_hypo.score > self.best_score: 
             logging.info("New_best (ID: %d): score=%f exp=%d hypo=%s" 
                 % (self.current_sen_id + 1,
                    prev_hypo.score, 
                    self.apply_predictors_count,
                    ' '.join([str(w) for w in prev_hypo.trgt_sentence])))
             self.best_score = prev_hypo.score
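
The `# Admissible pruning` test relies on scores being sums of log-probabilities: every expansion adds a non-positive term, so a partial hypothesis that already scores below the best completed one can never recover and is safe to discard. A toy illustration with made-up numbers:

 # Toy illustration of admissible pruning: expansions only ever add
 # non-positive log-probabilities, so a partial score below the best
 # completed score cannot recover. All numbers are made up.
 best_complete = -4.0
 partial = -3.5
 for step_logprob in (-0.25, -0.375):    # further greedy expansions
     partial += step_logprob
     if partial < best_complete:
         print("pruned at score", partial)  # pruned at score -4.125
         break
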
Example #8
 def _greedy_decode(self):
     """Performs greedy decoding from the start node. Used to obtain
     the initial hypothesis.
     """
     hypo = PartialHypothesis()
     hypos = []
     posteriors = []
     score_breakdowns = []
     scores = []
     bag = dict(self.full_bag)
     while bag:
         posterior,score_breakdown = self.apply_predictors()
         hypo.predictor_states = copy.deepcopy(self.get_predictor_states())
         hypos.append(hypo)
         posteriors.append(posterior)
         score_breakdowns.append(score_breakdown)
         best_word = utils.argmax({w: posterior[w] for w in bag})
         bag[best_word] -= 1
         if bag[best_word] < 1:
             del bag[best_word]
         self.consume(best_word)
         hypo = hypo.expand(best_word,
                            None,
                            posterior[best_word],
                            score_breakdown[best_word])
         scores.append(posterior[best_word])
     posterior,score_breakdown = self.apply_predictors()
     hypo.predictor_states = self.get_predictor_states()
     hypos.append(hypo)
     posteriors.append(posterior)
     score_breakdowns.append(score_breakdown)
     hypo = hypo.expand(utils.EOS_ID,
                        None,
                        posterior[utils.EOS_ID],
                        score_breakdown[utils.EOS_ID])
     logging.debug("Greedy hypo (%f): %s" % (
                       hypo.score,
                       ' '.join([str(w) for w in hypo.trgt_sentence])))
     scores.append(posterior[utils.EOS_ID])
     self.best_score = hypo.score
     self.add_full_hypo(hypo.generate_full_hypothesis())
     self._process_new_hypos(FlipCandidate(hypo.trgt_sentence,
                                            scores,
                                            self._create_dummy_bigrams(),
                                            hypo.score),
                              len(hypo.trgt_sentence),
                              hypos,
                              posteriors,
                              score_breakdowns)
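
The loop above constrains greedy decoding to a bag of words: the argmax is taken only over words still present in the multiset, and counts are decremented as words are emitted until the bag is empty. A minimal sketch of that selection rule, with a fixed toy posterior standing in for `apply_predictors`:

 # Sketch of bag-constrained greedy selection. The posterior is a toy
 # constant here; in the decoder it is recomputed at every step.
 from collections import Counter

 bag = Counter({5: 2, 7: 1})          # words that must appear, with counts
 emitted = []
 while bag:
     posterior = {5: -0.3, 7: -0.9}   # toy log-probabilities
     best = max(bag, key=lambda w: posterior[w])
     bag[best] -= 1
     if bag[best] < 1:
         del bag[best]
     emitted.append(best)
 print(emitted)  # [5, 5, 7]
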
Example #9
 def estimate_future_cost_without_cache(self, hypo):
     """Disabled cache... """
     old_states = self.decoder.get_predictor_states()
     self.decoder.set_predictor_states(copy.deepcopy(old_states))
     # Greedy decoding
     trgt_word = hypo.trgt_sentence[-1]
     score = 0.0
     while trgt_word != utils.EOS_ID:
         self.decoder.consume(trgt_word)
         posterior,_ = self.decoder.apply_predictors()
         trgt_word = utils.argmax(posterior)
         score += posterior[trgt_word]
     # Reset predictor states
     self.decoder.set_predictor_states(old_states)
     return -score
Example #10
 def estimate_future_cost_without_cache(self, hypo):
     """Disabled cache... """
     old_states = self.decoder.get_predictor_states()
     self.decoder.set_predictor_states(copy.deepcopy(old_states))
     # Greedy decoding
     trgt_word = hypo.trgt_sentence[-1]
     score = 0.0
     while trgt_word != utils.EOS_ID:
         self.decoder.consume(trgt_word)
         posterior, _ = self.decoder.apply_predictors()
         trgt_word = utils.argmax(posterior)
         score += posterior[trgt_word]
     # Reset predictor states
     self.decoder.set_predictor_states(old_states)
     return -score
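
Both versions above follow the same save/rollout/restore pattern: snapshot the predictor states, let the greedy rollout mutate a deep copy, then restore the snapshot so the caller's search is unaffected. A minimal sketch; `ToyDecoder` and its state dict are illustrative assumptions:

 # Sketch of the save/rollout/restore pattern around the greedy rollout.
 # ToyDecoder is an illustrative stand-in, not an SGNMT class.
 import copy

 class ToyDecoder(object):
     def __init__(self):
         self.states = {"step": 0}
     def get_predictor_states(self):
         return self.states
     def set_predictor_states(self, states):
         self.states = states

 decoder = ToyDecoder()
 old_states = decoder.get_predictor_states()
 decoder.set_predictor_states(copy.deepcopy(old_states))  # work on a copy
 decoder.states["step"] += 7   # ...the greedy rollout mutates the copy...
 decoder.set_predictor_states(old_states)                 # rewind
 print(decoder.states)  # {'step': 0}
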
Example #11
 def _greedy_decode(self):
     """Performs greedy decoding from the start node. Used to obtain
     initial bigram statistics.
     """
     hypo = PartialHypothesis()
     hypos = []
     posteriors = []
     score_breakdowns = []
     bag = dict(self.full_bag)
     while bag:
         posterior,score_breakdown = self.apply_predictors()
         hypo.predictor_states = copy.deepcopy(self.get_predictor_states())
         bag_posterior = {w: posterior[w] for w in self.full_bag_with_eos}
         bag_breakdown = {w: score_breakdown[w] 
                                     for w in self.full_bag_with_eos}
         posteriors.append(bag_posterior)
         score_breakdowns.append(bag_breakdown)
         hypos.append(hypo)
         best_word = utils.argmax({w: bag_posterior[w] for w in bag})
         bag[best_word] -= 1
         if bag[best_word] < 1:
             del bag[best_word]
         self.consume(best_word)
         hypo = hypo.expand(best_word,
                            None,
                            bag_posterior[best_word],
                            score_breakdown[best_word])
     posterior,score_breakdown = self.apply_predictors()
     hypo.predictor_states = copy.deepcopy(self.get_predictor_states())
     bag_posterior = {w: posterior[w] for w in self.full_bag_with_eos}
     bag_breakdown = {w: score_breakdown[w] for w in self.full_bag_with_eos}
     posteriors.append(bag_posterior)
     score_breakdowns.append(bag_breakdown)
     hypos.append(hypo)
     
     hypo = hypo.cheap_expand(utils.EOS_ID,
                              bag_posterior[utils.EOS_ID],
                              score_breakdown[utils.EOS_ID])
     logging.debug("Greedy hypo (%f): %s" % (
                       hypo.score,
                       ' '.join([str(w) for w in hypo.trgt_sentence])))
     self._process_new_hypos(hypos, posteriors, score_breakdowns, hypo)
Example #12
File: length.py Project: ucam-smt/sgnmt
 def initialize(self, src_sentence):
     """Runs greedy decoding on the slave predictor to populate
     self.scores and self.unk_scores, resets the history.
     """
     self.slave_predictor.initialize(src_sentence)
     self.scores = []
     self.unk_scores = []
     trg_word = -1
     max_len = self.max_len_factor * len(src_sentence)
     l = 0
     while trg_word != utils.EOS_ID and l <= max_len:
         posterior = self.slave_predictor.predict_next()
         trg_word = utils.argmax(posterior)
         self.scores.append(posterior)
         self.unk_scores.append(self.slave_predictor.get_unk_probability(
             posterior))
         self.slave_predictor.consume(utils.UNK_ID)
         l += 1
     logging.debug("ngramize uses %d time steps." % l)
     self.history = []
     self.cur_unk_score = utils.NEG_INF
Example #13
 def initialize(self, src_sentence):
     """Runs greedy decoding on the slave predictor to populate
     self.scores and self.unk_scores, resets the history.
     """
     self.slave_predictor.initialize(src_sentence)
     self.scores = []
     self.unk_scores = []
     trg_word = -1
     max_len = self.max_len_factor * len(src_sentence)
     l = 0
     while trg_word != utils.EOS_ID and l <= max_len:
         posterior = self.slave_predictor.predict_next()
         trg_word = utils.argmax(posterior)
         self.scores.append(posterior)
         self.unk_scores.append(
             self.slave_predictor.get_unk_probability(posterior))
         self.slave_predictor.consume(utils.UNK_ID)
         l += 1
     logging.debug("ngramize uses %d time steps." % l)
     self.history = []
     self.cur_unk_score = utils.NEG_INF
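
The `initialize` method above rolls out the slave predictor for at most `max_len` steps, recording the full posterior at every position, and consumes `UNK_ID` rather than the argmax so the rollout does not commit to a particular history. A minimal sketch of that table-building loop; the toy predictor and constants are assumptions:

 # Sketch of the rollout above: record the posterior per position, feed
 # UNK instead of the argmax, stop at EOS or the length bound.
 # The toy predictor and all constants are illustrative assumptions.
 EOS_ID, UNK_ID = 2, 0
 MAX_LEN = 4

 def toy_predict_next(step):
     return {5: -0.4, EOS_ID: (-0.1 if step >= 2 else -6.0)}

 scores, step, trg_word = [], 0, -1
 while trg_word != EOS_ID and step <= MAX_LEN:
     posterior = toy_predict_next(step)
     trg_word = max(posterior, key=posterior.get)  # greedy stopping test
     scores.append(posterior)       # keep the whole posterior per position
     step += 1                      # the real code consumes UNK_ID here
 print(len(scores))  # 3: word 5 at positions 0 and 1, then EOS
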
Example #14
 def estimate_future_cost_with_cache(self, hypo):
     """Enabled cache... """
     cached_cost = self.cache.get(hypo.trgt_sentence)
     if cached_cost is not None:
         return cached_cost
     old_states = self.decoder.get_predictor_states()
     self.decoder.set_predictor_states(copy.deepcopy(old_states))
     # Greedy decoding
     trgt_word = hypo.trgt_sentence[-1]
     scores = []
     words = []
     while trgt_word != utils.EOS_ID:
         self.decoder.consume(trgt_word)
         posterior, _ = self.decoder.apply_predictors()
         trgt_word = utils.argmax(posterior)
         scores.append(posterior[trgt_word])
         words.append(trgt_word)
     # Update cache using scores and words
     for i in xrange(1, len(scores)):
         self.cache.add(hypo.trgt_sentence + words[:i], -sum(scores[i:]))
     # Reset predictor states
     self.decoder.set_predictor_states(old_states)
     return -sum(scores)
Example #15
 def estimate_future_cost_with_cache(self, hypo):
     """Enabled cache... """
     cached_cost = self.cache.get(hypo.trgt_sentence)
     if cached_cost is not None:
         return cached_cost
     old_states = self.decoder.get_predictor_states()
     self.decoder.set_predictor_states(copy.deepcopy(old_states))
     # Greedy decoding
     trgt_word = hypo.trgt_sentence[-1]
     scores = []
     words = []
     while trgt_word != utils.EOS_ID:
         self.decoder.consume(trgt_word)
         posterior,_ = self.decoder.apply_predictors()
         trgt_word = utils.argmax(posterior)
         scores.append(posterior[trgt_word])
         words.append(trgt_word)
     # Update cache using scores and words
     for i in xrange(1,len(scores)):
         self.cache.add(hypo.trgt_sentence + words[:i], -sum(scores[i:]))
     # Reset predictor states
     self.decoder.set_predictor_states(old_states)
     return -sum(scores)
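
The cache update is the interesting step in both versions: a single greedy rollout determines the future cost of every intermediate prefix at once, namely the negated sum of the scores still ahead of it. The standalone sketch below replays that fill; tuple keys stand in for SGNMT's cache object, and `range` replaces the Python 2 `xrange` of the original.

 # Standalone replay of the cache fill above. Every intermediate prefix
 # gets the negated sum of the scores still ahead of it as its cached
 # future cost. Tuple keys and all numbers are toy stand-ins.
 prefix = [3, 4]                  # hypo.trgt_sentence so far
 words = [5, 6, 2]                # greedy continuation, ending in EOS
 scores = [-0.25, -0.5, -0.125]   # log-probabilities of those words

 cache = {}
 for i in range(1, len(scores)):
     cache[tuple(prefix + words[:i])] = -sum(scores[i:])
 print(cache)  # {(3, 4, 5): 0.625, (3, 4, 5, 6): 0.125}
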
Example #16
File: bow.py Project: ucam-smt/sgnmt
 def greedy_decode(self, node, word, single_step):
     """Helper function for greedy decoding from a certain point in
     the search tree."""
     
     prev_hypo = node.hypo.expand(word,
                                  None,
                                  node.posterior[word],
                                  node.score_breakdown[word])
     prev_nodes = node.prev_nodes + [node]
     
     best_word = word
     while ((prev_hypo.score > self.best_score or not self.early_stopping)
            and best_word != utils.EOS_ID):
         self.consume(best_word)
         posterior,score_breakdown = self.apply_predictors()
         if len(posterior) < 1:
             return
         self._update_best_word_scores(posterior)
         best_word = utils.argmax(posterior)
         best_word_score = posterior[best_word]
         new_hypo = prev_hypo.expand(best_word,
                                     None,
                                     best_word_score,
                                     score_breakdown[best_word])
         logging.debug("Expanded hypo: len=%d score=%f prefix= %s" % (
                         len(new_hypo.trgt_sentence),
                         new_hypo.score,
                         ' '.join([str(w) for w in new_hypo.trgt_sentence])))
         node = BOWNode(prev_hypo, 
                        posterior, 
                        score_breakdown, 
                        list(prev_nodes))
         if not single_step:
             del node.active_arcs[best_word]
             if len(node.active_arcs) > 0:
                 prev_hypo.predictor_states = copy.deepcopy(
                                             self.get_predictor_states())
         else:
             prev_hypo.predictor_states = self.get_predictor_states()
         prev_nodes.append(node)
         prev_hypo = new_hypo
         if single_step:
             break
     full_hypo_score = prev_hypo.score
     if best_word == utils.EOS_ID: # Full hypo
         self.add_full_hypo(prev_hypo.generate_full_hypothesis())
         self.best_score = max(self.best_score, full_hypo_score)
     else:
         full_hypo_score = self._estimate_full_hypo_score(prev_hypo)
     # Update the heap
     sen = prev_hypo.trgt_sentence
     l = len(sen)
     for pos in xrange(l):
         worst_scores = {}
         for p in xrange(pos, l):
             worst_scores[sen[p]] = min(worst_scores.get(sen[p], 0.0),
                                        prev_nodes[p].posterior[sen[p]])
         node = prev_nodes[pos]
         for w in node.active_arcs:
             expected_score = self._estimate_full_hypo_score(
                             node.hypo.cheap_expand(w, 
                                                    node.posterior[w], 
                                                    node.score_breakdown[w]))
             self._add_to_heap(node, w, expected_score) 
Example #17
File: parse.py Project: ucam-smt/sgnmt
 def find_word_greedy(self, posterior):
     while not self.are_best_terminal(posterior):
         best_rule_id = utils.argmax(posterior)
         self.consume(best_rule_id)
         posterior = self.predict_next(predicting_next_word=True)
     return posterior
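
The helper above descends through grammar rules until the highest-scoring entries in the posterior are terminals, consuming one best rule per step and re-predicting. A toy sketch of that loop follows; the id convention (ids of 100 and above are words, smaller ids are rules) and the toy posteriors are assumptions for illustration.

 # Toy sketch of find_word_greedy: consume the best rule id and re-predict
 # until the argmax of the posterior is a terminal (a word). The id
 # convention and posteriors below are illustrative assumptions.
 def are_best_terminal(posterior):
     return max(posterior, key=posterior.get) >= 100

 def toy_predict_next(depth):
     # After two rule applications the best item becomes a terminal.
     if depth < 2:
         return {10 + depth: -0.5, 105: -1.0}
     return {105: -0.2, 12: -1.5}

 depth = 0
 posterior = toy_predict_next(depth)
 while not are_best_terminal(posterior):
     depth += 1                      # stands in for consume(best_rule_id)
     posterior = toy_predict_next(depth)
 print(posterior)  # {105: -0.2, 12: -1.5}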