Ejemplo n.º 1
0
 def consume(self, word):
     """Updates the current nodes by searching for all nodes which
     are reachable from the current nodes by a path consisting of
     any number of epsilons and exactly one ``word`` label. If there
     is no such arc, we set the predictor in an invalid state (an
     empty ``cur_nodes`` list). In this case, subsequent
     ``predict_next`` calls find no outgoing arcs.

     The scores of the new nodes are renormalized by subtracting the
     combined score of all ``word`` arcs (combined with
     ``score_max_func``). For ``utils.GO_ID`` this is only done if
     ``skip_bos_weight`` is set.

     Args:
         word (int): Word on an outgoing arc from the current node
     """
     d_unconsumed = {}
     # Collect the best score of every node reachable by one ``word`` arc
     for weight, node in self.cur_nodes:
         for arc in self.cur_fst.arcs(node):
             if arc.olabel == word:
                 next_node = arc.nextstate
                 next_score = weight + self.weight_factor * w2f(arc.weight)
                 if d_unconsumed.get(next_node, utils.NEG_INF) < next_score:
                     d_unconsumed[next_node] = next_score
     if not d_unconsumed:
         # No outgoing ``word`` arc: enter the invalid state. Handled
         # explicitly because ``score_max_func`` may reject an empty
         # sequence (if it is ``max``, it raises ValueError).
         self.cur_nodes = []
         return
     # Subtract the word score from the last predict_next
     if word != utils.GO_ID or self.skip_bos_weight:
         consumed_score = self.score_max_func(d_unconsumed.values())
     else:
         consumed_score = 0.0
     # Add epsilon reachable states
     self.cur_nodes = self._follow_eps({
         node: score - consumed_score
         for node, score in d_unconsumed.items()
     })
Ejemplo n.º 2
0
 def _follow_eps(self, roots):
     """BFS to find nodes reachable from root through eps arcs. This
     traversal strategy is efficient if the triangle inquality holds 
     for weights in the graphs, i.e. for all vertices v1,v2,v3: 
     (v1,v2),(v2,v3),(v1,v3) in E => d(v1,v2)+d(v2,v3) >= d(v1,v3).
     The method still returns the correct results if the triangle
     inequality does not hold, but edges may be traversed multiple
     times which makes it more inefficient.
     """
     open_nodes = dict(roots)
     d = {}
     visited = dict(roots)
     while open_nodes:
         next_open = {}
         for node, score in open_nodes.iteritems():
             has_noneps = False
             for arc in self.cur_fst.arcs(node):
                 if arc.olabel == EPS_ID:
                     next_node = arc.nextstate
                     next_score = score + self.weight_factor * w2f(
                         arc.weight)
                     if visited.get(next_node, utils.NEG_INF) < next_score:
                         visited[next_node] = next_score
                         next_open[next_node] = next_score
                 else:
                     has_noneps = True
             if has_noneps:
                 d[node] = score
         open_nodes = next_open
     return [(weight, node) for node, weight in d.iteritems()]
Ejemplo n.º 3
0
 def _follow_eps(self, roots):
     """BFS to find nodes reachable from root through eps arcs. This
     traversal strategy is efficient if the triangle inquality holds 
     for weights in the graphs, i.e. for all vertices v1,v2,v3: 
     (v1,v2),(v2,v3),(v1,v3) in E => d(v1,v2)+d(v2,v3) >= d(v1,v3).
     The method still returns the correct results if the triangle
     inequality does not hold, but edges may be traversed multiple
     times which makes it more inefficient.
     """
     open_nodes = dict(roots)
     d = {}
     visited = dict(roots)
     while open_nodes:
         next_open = {}
         for node,score in open_nodes.iteritems():
             has_noneps = False
             for arc in self.cur_fst.arcs(node):
                 if arc.olabel == EPS_ID:
                     next_node = arc.nextstate
                     next_score = score + self.weight_factor*w2f(arc.weight)
                     if visited.get(next_node, utils.NEG_INF) < next_score:
                         visited[next_node] = next_score
                         next_open[next_node] = next_score
                 else:
                     has_noneps = True
             if has_noneps:
                 d[node] = score
         open_nodes = next_open
     return [(weight, node) for node, weight in d.iteritems()]
Ejemplo n.º 4
0
 def predict_next(self):
     """Builds the scores for the next word from the outgoing arcs of
     all current nodes.

     Epsilon arcs are skipped here: ``consume`` is expected to have
     expanded ``cur_nodes`` so that word arcs leave them directly.
     A word arc scores the node's score plus ``weight_factor`` times
     the arc weight. When several arcs carry the same word, their
     scores are combined pairwise with ``score_max_func``.

     Returns:
         dict. The result of ``finalize_posterior`` on the word ->
         score mapping (the mapping is empty if there are no current
         nodes or no word arcs).
     """
     combined = {}
     for node_score, node in self.cur_nodes:
         for arc in self.cur_fst.arcs(node):
             label = arc.olabel
             if label == EPS_ID:
                 continue
             arc_score = node_score + self.weight_factor * w2f(arc.weight)
             if label not in combined:
                 combined[label] = arc_score
             else:
                 combined[label] = self.score_max_func(combined[label],
                                                       arc_score)
     return self.finalize_posterior(combined, self.use_weights,
                                    self.normalize_scores)
Ejemplo n.º 5
0
 def consume(self, word):
     """Updates the current nodes by searching for all nodes which
     are reachable from the current nodes by a path consisting of
     any number of epsilons and exactly one ``word`` label. If there
     is no such arc, we set the predictor in an invalid state (an
     empty ``cur_nodes`` list). In this case, subsequent
     ``predict_next`` calls find no outgoing arcs.

     The scores of the new nodes are renormalized by subtracting the
     combined score of all ``word`` arcs (combined with
     ``score_max_func``). For ``utils.GO_ID`` this is only done if
     ``skip_bos_weight`` is set.

     Args:
         word (int): Word on an outgoing arc from the current node
     """
     d_unconsumed = {}
     # Collect the best score of every node reachable by one ``word`` arc
     for weight, node in self.cur_nodes:
         for arc in self.cur_fst.arcs(node):
             if arc.olabel == word:
                 next_node = arc.nextstate
                 next_score = weight + self.weight_factor * w2f(arc.weight)
                 if d_unconsumed.get(next_node, utils.NEG_INF) < next_score:
                     d_unconsumed[next_node] = next_score
     if not d_unconsumed:
         # No outgoing ``word`` arc: enter the invalid state. Handled
         # explicitly because ``score_max_func`` may reject an empty
         # sequence (if it is ``max``, it raises ValueError).
         self.cur_nodes = []
         return
     # Subtract the word score from the last predict_next
     if word != utils.GO_ID or self.skip_bos_weight:
         consumed_score = self.score_max_func(d_unconsumed.values())
     else:
         consumed_score = 0.0
     # Add epsilon reachable states
     self.cur_nodes = self._follow_eps({
         node: score - consumed_score
         for node, score in d_unconsumed.items()
     })
Ejemplo n.º 6
0
 def predict_next(self):
     """Builds the scores for the next word from the outgoing arcs of
     all current nodes.

     Epsilon arcs are skipped here: ``consume`` is expected to have
     expanded ``cur_nodes`` so that word arcs leave them directly.
     A word arc scores the node's score plus ``weight_factor`` times
     the arc weight. When several arcs carry the same word, their
     scores are combined pairwise with ``score_max_func``.

     Returns:
         dict. The result of ``finalize_posterior`` on the word ->
         score mapping (the mapping is empty if there are no current
         nodes or no word arcs).
     """
     combined = {}
     for node_score, node in self.cur_nodes:
         for arc in self.cur_fst.arcs(node):
             label = arc.olabel
             if label == EPS_ID:
                 continue
             arc_score = node_score + self.weight_factor * w2f(arc.weight)
             if label not in combined:
                 combined[label] = arc_score
             else:
                 combined[label] = self.score_max_func(combined[label],
                                                       arc_score)
     return self.finalize_posterior(combined, self.use_weights,
                                    self.normalize_scores)
Ejemplo n.º 7
0
 def consume(self, word):
     """Updates the current node by following the arc labelled with
     ``word``. If there is no such arc, an arc labelled with
     ``utils.UNK_ID`` is followed instead (if there is one). If neither
     exists, ``cur_node`` is set to ``None``, indicating that the
     predictor is in an invalid state. ``None`` and negative values of
     ``cur_node`` are both treated as invalid; in that state this
     method does nothing.

     Args:
         word (int): Word on an outgoing arc from the current node

     Returns:
         float. Weight on the traversed ``word`` arc (scaled by
         ``weight_factor``), or None if the predictor is in an invalid
         state, no arc matched, or the UNK arc was followed.
     """
     # Check None explicitly: ``None < 0`` only works on Python 2.
     if self.cur_node is None or self.cur_node < 0:
         return
     from_state = self.cur_node
     self.cur_node = None
     unk_arc = None
     for arc in self.cur_fst.arcs(from_state):
         if arc.olabel == word:
             self.cur_node = arc.nextstate
             return self.weight_factor * w2f(arc.weight)
         elif arc.olabel == utils.UNK_ID:
             unk_arc = arc
     if unk_arc is not None:
         # Fall back to the UNK arc. NOTE(review): its weight is not
         # returned -- confirm callers expect None here.
         self.cur_node = unk_arc.nextstate
Ejemplo n.º 8
0
 def consume(self, word):
     """Updates the current node by following the arc labelled with
     ``word``. If there is no such arc, an arc labelled with
     ``utils.UNK_ID`` is followed instead (if there is one). If neither
     exists, ``cur_node`` is set to ``None``, indicating that the
     predictor is in an invalid state. ``None`` and negative values of
     ``cur_node`` are both treated as invalid; in that state this
     method does nothing.

     Args:
         word (int): Word on an outgoing arc from the current node

     Returns:
         float. Weight on the traversed ``word`` arc (scaled by
         ``weight_factor``), or None if the predictor is in an invalid
         state, no arc matched, or the UNK arc was followed.
     """
     # Check None explicitly: ``None < 0`` only works on Python 2.
     if self.cur_node is None or self.cur_node < 0:
         return
     from_state = self.cur_node
     self.cur_node = None
     unk_arc = None
     for arc in self.cur_fst.arcs(from_state):
         if arc.olabel == word:
             self.cur_node = arc.nextstate
             return self.weight_factor * w2f(arc.weight)
         elif arc.olabel == utils.UNK_ID:
             unk_arc = arc
     if unk_arc is not None:
         # Fall back to the UNK arc. NOTE(review): its weight is not
         # returned -- confirm callers expect None here.
         self.cur_node = unk_arc.nextstate
Ejemplo n.º 9
0
 def estimate_future_cost(self, hypo):
     """The FST predictor comes with its own heuristic function. We
     use the shortest path in the fst as future cost estimator.

     Looks up the first outgoing arc of ``cur_node`` labelled with the
     last word of ``hypo`` and returns ``w2f`` of the precomputed
     ``distances`` entry for that arc's target state.

     Args:
         hypo: Hypothesis whose ``trgt_sentence`` ends with the word
             to look up.

     Returns:
         float. The distance estimate, or 0.0 if there is no valid
         current node, ``hypo`` has no target words yet, or no
         matching arc exists.
     """
     # State 0 is a valid FST state, so test for the invalid markers
     # (None / negative) explicitly instead of relying on truthiness.
     if self.cur_node is None or self.cur_node < 0:
         return 0.0
     if not hypo.trgt_sentence:
         return 0.0
     last_word = hypo.trgt_sentence[-1]
     for arc in self.cur_fst.arcs(self.cur_node):
         if arc.olabel == last_word:
             return w2f(self.distances[arc.nextstate])
     return 0.0
Ejemplo n.º 10
0
 def estimate_future_cost(self, hypo):
     """The FST predictor comes with its own heuristic function. We
     use the shortest path in the fst as future cost estimator.

     Looks up the first outgoing arc of ``cur_node`` labelled with the
     last word of ``hypo`` and returns ``w2f`` of the precomputed
     ``distances`` entry for that arc's target state.

     Args:
         hypo: Hypothesis whose ``trgt_sentence`` ends with the word
             to look up.

     Returns:
         float. The distance estimate, or 0.0 if there is no valid
         current node, ``hypo`` has no target words yet, or no
         matching arc exists.
     """
     # State 0 is a valid FST state, so test for the invalid markers
     # (None / negative) explicitly instead of relying on truthiness.
     if self.cur_node is None or self.cur_node < 0:
         return 0.0
     if not hypo.trgt_sentence:
         return 0.0
     last_word = hypo.trgt_sentence[-1]
     for arc in self.cur_fst.arcs(self.cur_node):
         if arc.olabel == last_word:
             return w2f(self.distances[arc.nextstate])
     return 0.0
Ejemplo n.º 11
0
 def estimate_future_cost(self, hypo):
     """The FST predictor comes with its own heuristic function. We
     use the shortest path in the fst as future cost estimator.

     For every current node, the first outgoing arc labelled with the
     last word of ``hypo`` is looked up, and ``w2f`` of the
     precomputed ``distances`` entry of its target state is collected.
     The smallest collected value is returned.

     Args:
         hypo: Hypothesis whose ``trgt_sentence`` ends with the word
             to look up.

     Returns:
         float. The minimum collected distance, or 0.0 if ``hypo`` has
         no target words or no current node has a matching arc.
     """
     if not hypo.trgt_sentence:
         return 0.0
     last_word = hypo.trgt_sentence[-1]
     dists = []
     # ``cur_nodes`` holds ``(score, node)`` pairs (this is how
     # ``consume`` builds it and ``predict_next`` reads it), and arcs
     # are enumerated with ``arcs(node)`` like in the other methods.
     for _, node in self.cur_nodes:
         for arc in self.cur_fst.arcs(node):
             if arc.olabel == last_word:
                 dists.append(w2f(self.distances[arc.nextstate]))
                 break
     return min(dists) if dists else 0.0
Ejemplo n.º 12
0
 def estimate_future_cost(self, hypo):
     """The FST predictor comes with its own heuristic function. We
     use the shortest path in the fst as future cost estimator.

     For every current node, the first outgoing arc labelled with the
     last word of ``hypo`` is looked up, and ``w2f`` of the
     precomputed ``distances`` entry of its target state is collected.
     The smallest collected value is returned.

     Args:
         hypo: Hypothesis whose ``trgt_sentence`` ends with the word
             to look up.

     Returns:
         float. The minimum collected distance, or 0.0 if ``hypo`` has
         no target words or no current node has a matching arc.
     """
     if not hypo.trgt_sentence:
         return 0.0
     last_word = hypo.trgt_sentence[-1]
     dists = []
     # ``cur_nodes`` holds ``(score, node)`` pairs (this is how
     # ``consume`` builds it and ``predict_next`` reads it), and arcs
     # are enumerated with ``arcs(node)`` like in the other methods.
     for _, node in self.cur_nodes:
         for arc in self.cur_fst.arcs(node):
             if arc.olabel == last_word:
                 dists.append(w2f(self.distances[arc.nextstate]))
                 break
     return min(dists) if dists else 0.0
Ejemplo n.º 13
0
 def predict_next(self):
     """Uses the outgoing arcs from the current node to build up the
     scores for the next word.

     Each arc contributes ``weight_factor * w2f(arc.weight)`` under its
     output label. If ``add_bos_to_eos_score`` is set, ``bos_score`` is
     added to the score of ``utils.EOS_ID``.

     Returns:
         dict. Set of words on outgoing arcs from the current node
         together with their scores (passed through
         ``finalize_posterior``), or an empty dict if we currently
         have no active node (``cur_node`` is None or negative).
     """
     # Test None explicitly: ``None < 0`` is only valid on Python 2.
     if self.cur_node is None or self.cur_node < 0:
         return {}
     # NOTE(review): assumes at most one arc per output label; with
     # duplicate labels the last arc's score wins.
     scores = {arc.olabel: self.weight_factor * w2f(arc.weight)
               for arc in self.cur_fst.arcs(self.cur_node)}
     if utils.EOS_ID in scores and self.add_bos_to_eos_score:
         scores[utils.EOS_ID] += self.bos_score
     return self.finalize_posterior(scores,
                                    self.use_weights, self.normalize_scores)
Ejemplo n.º 14
0
 def add_to_label_fst_map_recursive(self, label_fst_map, visited_nodes,
                                    root_node, acc_weight, history, func):
     """Walks the FST from ``root_node``, consuming ``history``, and
     collects arcs labelled with an NT symbol into ``label_fst_map``.

     Epsilon arcs are followed without consuming ``history``. While
     ``history`` is non-empty, only arcs labelled with ``history[0]``
     are followed (with the remaining history); the scan stops at the
     first larger label, relying on the arcs being sorted. Once
     ``history`` is exhausted, every non-epsilon arc is handled:
     an NT-labelled arc gets a fresh label (2000000000 plus the current
     size of ``label_fst_map``), which is mapped to ``get_sub_fst`` of
     the NT label and assigned to both arc labels; any other arc is
     passed to ``func(nextstate, olabel, accumulated_weight)``.

     Note: ``visited_nodes`` is maintained for each history separately,
     so a fresh dict is started whenever a history word is consumed.
     A node is expanded only the first time it is seen for a history,
     which takes the score of the first path found, not necessarily
     the globally best one.

     NOTE(review): whether assigning ``arc.ilabel``/``arc.olabel``
     modifies the FST depends on whether ``arcs()`` yields copies in
     the FST binding in use -- confirm.

     Args:
         label_fst_map (dict): Receives replacement label -> sub FST.
         visited_nodes (dict): Nodes already expanded for this history.
         root_node (int): State to start from.
         acc_weight (float): Score accumulated on the path so far.
         history (sequence): Words still to be consumed before NT arcs
             are collected.
         func (callable): Called for regular arcs once ``history`` is
             exhausted.
     """
     if root_node in visited_nodes:
         return
     visited_nodes[root_node] = True
     for arc in self.cur_fst.arcs(root_node):
         weight = acc_weight + self.weight_factor * w2f(arc.weight)
         label = arc.olabel
         if label == EPS_ID:
             # Epsilons keep the same history and visited set
             self.add_to_label_fst_map_recursive(
                 label_fst_map, visited_nodes, arc.nextstate, weight,
                 history, func)
             continue
         if history:
             if label == history[0]:
                 # Consume one history word; start a fresh visited set
                 self.add_to_label_fst_map_recursive(
                     label_fst_map, {}, arc.nextstate, weight,
                     history[1:], func)
             elif label > history[0]:
                 break  # arcs are sorted: no later arc can match
             continue
         if self.is_nt_label(label):
             new_label = 2000000000 + len(label_fst_map)
             label_fst_map[new_label] = self.get_sub_fst(label)
             arc.ilabel = new_label
             arc.olabel = new_label
         else:
             func(arc.nextstate, label, weight)
Ejemplo n.º 15
0
 def predict_next(self):
     """Uses the outgoing arcs from the current node to build up the
     scores for the next word.

     Each arc contributes ``weight_factor * w2f(arc.weight)`` under its
     output label. If ``add_bos_to_eos_score`` is set, ``bos_score`` is
     added to the score of ``utils.EOS_ID``.

     Returns:
         dict. Set of words on outgoing arcs from the current node
         together with their scores (passed through
         ``finalize_posterior``), or an empty dict if we currently
         have no active node (``cur_node`` is None or negative).
     """
     # State 0 is a valid FST state, so check the invalid markers
     # explicitly rather than relying on truthiness.
     if self.cur_node is None or self.cur_node < 0:
         return {}
     # NOTE(review): assumes at most one arc per output label; with
     # duplicate labels the last arc's score wins.
     scores = {
         arc.olabel: self.weight_factor * w2f(arc.weight)
         for arc in self.cur_fst.arcs(self.cur_node)
     }
     if utils.EOS_ID in scores and self.add_bos_to_eos_score:
         scores[utils.EOS_ID] += self.bos_score
     return self.finalize_posterior(scores, self.use_weights,
                                    self.normalize_scores)
Ejemplo n.º 16
0
 def add_to_label_fst_map_recursive(self, label_fst_map, visited_nodes,
                                    root_node, acc_weight, history, func):
     """Walks the FST from ``root_node``, consuming ``history``, and
     collects arcs labelled with an NT symbol into ``label_fst_map``.

     Epsilon arcs are followed without consuming ``history``. While
     ``history`` is non-empty, only arcs labelled with ``history[0]``
     are followed (with the remaining history); the scan stops at the
     first larger label, relying on the arcs being sorted. Once
     ``history`` is exhausted, every non-epsilon arc is handled:
     an NT-labelled arc gets a fresh label (2000000000 plus the current
     size of ``label_fst_map``), which is mapped to ``get_sub_fst`` of
     the NT label and assigned to both arc labels; any other arc is
     passed to ``func(nextstate, olabel, accumulated_weight)``.

     Note: ``visited_nodes`` is maintained for each history separately,
     so a fresh dict is started whenever a history word is consumed.
     A node is expanded only the first time it is seen for a history,
     which takes the score of the first path found, not necessarily
     the globally best one.

     NOTE(review): whether assigning ``arc.ilabel``/``arc.olabel``
     modifies the FST depends on whether ``arcs()`` yields copies in
     the FST binding in use -- confirm.

     Args:
         label_fst_map (dict): Receives replacement label -> sub FST.
         visited_nodes (dict): Nodes already expanded for this history.
         root_node (int): State to start from.
         acc_weight (float): Score accumulated on the path so far.
         history (sequence): Words still to be consumed before NT arcs
             are collected.
         func (callable): Called for regular arcs once ``history`` is
             exhausted.
     """
     if root_node in visited_nodes:
         return
     visited_nodes[root_node] = True
     for arc in self.cur_fst.arcs(root_node):
         weight = acc_weight + self.weight_factor * w2f(arc.weight)
         label = arc.olabel
         if label == EPS_ID:
             # Epsilons keep the same history and visited set
             self.add_to_label_fst_map_recursive(
                 label_fst_map, visited_nodes, arc.nextstate, weight,
                 history, func)
             continue
         if history:
             if label == history[0]:
                 # Consume one history word; start a fresh visited set
                 self.add_to_label_fst_map_recursive(
                     label_fst_map, {}, arc.nextstate, weight,
                     history[1:], func)
             elif label > history[0]:
                 break  # arcs are sorted: no later arc can match
             continue
         if self.is_nt_label(label):
             new_label = 2000000000 + len(label_fst_map)
             label_fst_map[new_label] = self.get_sub_fst(label)
             arc.ilabel = new_label
             arc.olabel = new_label
         else:
             func(arc.nextstate, label, weight)