def consume(self, word):
    """Advance the set of current nodes over one ``word`` label.

    Collects every node that can be reached from a current node via a
    single arc labelled ``word``, keeps the best score per target node,
    and then expands the result along epsilon arcs with ``_follow_eps``.

    If no outgoing arc carries ``word``, the predictor enters its invalid
    state: ``cur_nodes`` becomes empty, so subsequent ``predict_next``
    calls have no nodes to expand.

    Args:
        word (int): Word on an outgoing arc from the current nodes
    """
    d_unconsumed = {}
    # Collect the best score for each node reachable by ``word``
    for weight, node in self.cur_nodes:
        for arc in self.cur_fst.arcs(node):
            if arc.olabel == word:
                next_node = arc.nextstate
                next_score = weight + self.weight_factor * w2f(arc.weight)
                if d_unconsumed.get(next_node, utils.NEG_INF) < next_score:
                    d_unconsumed[next_node] = next_score
    if not d_unconsumed:
        # No matching arc. Bail out before calling ``score_max_func`` on
        # an empty sequence (``max`` would raise ValueError there).
        self.cur_nodes = []
        return
    # Subtract the word score from the last predict_next
    if word != utils.GO_ID or self.skip_bos_weight:
        consumed_score = self.score_max_func(d_unconsumed.values())
    else:
        consumed_score = 0.0
    # Add epsilon reachable states
    self.cur_nodes = self._follow_eps({
        node: score - consumed_score
        for node, score in d_unconsumed.items()
    })
def _follow_eps(self, roots):
    """Breadth-first expansion of ``roots`` along epsilon arcs.

    This traversal strategy is efficient if the triangle inequality holds
    for the weights in the graph, i.e. for all vertices v1,v2,v3:
    (v1,v2),(v2,v3),(v1,v3) in E => d(v1,v2)+d(v2,v3) >= d(v1,v3).
    Results stay correct without it, but arcs may then be traversed
    several times, which is less efficient.

    Args:
        roots (dict): Maps start nodes to their scores

    Returns:
        list. ``(score, node)`` tuples for every visited node that has at
        least one non-epsilon outgoing arc
    """
    frontier = dict(roots)
    best_seen = dict(roots)
    reachable = {}
    while frontier:
        upcoming = {}
        for state, state_score in frontier.iteritems():
            emits_label = False
            for arc in self.cur_fst.arcs(state):
                if arc.olabel != EPS_ID:
                    emits_label = True
                    continue
                target = arc.nextstate
                candidate = state_score + self.weight_factor * w2f(
                    arc.weight)
                # Only re-open a node when we found a strictly better score
                if best_seen.get(target, utils.NEG_INF) < candidate:
                    best_seen[target] = candidate
                    upcoming[target] = candidate
            if emits_label:
                reachable[state] = state_score
        frontier = upcoming
    return [(score, state) for state, score in reachable.iteritems()]
def _follow_eps(self, roots):
    """Find the nodes reachable from ``roots`` through epsilon arcs (BFS).

    The search proceeds level by level. A node is queued again only if a
    path with a strictly higher score to it is discovered, so the
    traversal is cheap when the triangle inequality holds for the weights
    (for all v1,v2,v3: (v1,v2),(v2,v3),(v1,v3) in E =>
    d(v1,v2)+d(v2,v3) >= d(v1,v3)). Without it the result is still
    correct, but arcs may be traversed multiple times.

    Args:
        roots (dict): node -> score of the nodes to start from

    Returns:
        list. Tuples ``(score, node)`` for each visited node with a
        non-epsilon outgoing arc
    """
    pending = dict(roots)
    top_scores = dict(roots)
    labelled = {}
    while pending:
        following = {}
        for src, src_score in pending.iteritems():
            saw_label_arc = False
            for arc in self.cur_fst.arcs(src):
                if arc.olabel == EPS_ID:
                    dest = arc.nextstate
                    dest_score = src_score + self.weight_factor * w2f(
                        arc.weight)
                    known = top_scores.get(dest, utils.NEG_INF)
                    if known < dest_score:
                        top_scores[dest] = dest_score
                        following[dest] = dest_score
                else:
                    saw_label_arc = True
            if saw_label_arc:
                labelled[src] = src_score
        pending = following
    return [(value, key) for key, value in labelled.iteritems()]
def predict_next(self):
    """Score every non-epsilon label leaving any of the current nodes.

    Epsilon arcs are skipped here: ``consume`` is responsible for keeping
    ``cur_nodes`` expanded so that all reachable word arcs start directly
    at one of the current nodes. When several arcs carry the same label,
    their scores are combined pairwise with ``score_max_func``.

    Returns:
        The result of ``finalize_posterior`` applied to the collected
        label -> score mapping (empty if there are no current nodes).
    """
    posterior = {}
    for node_score, state in self.cur_nodes:
        for arc in self.cur_fst.arcs(state):
            if arc.olabel == EPS_ID:
                continue
            candidate = node_score + self.weight_factor * w2f(arc.weight)
            label = arc.olabel
            posterior[label] = (
                self.score_max_func(posterior[label], candidate)
                if label in posterior else candidate)
    return self.finalize_posterior(posterior,
                                   self.use_weights,
                                   self.normalize_scores)
def consume(self, word):
    """Updates the current nodes by searching for all nodes which are
    reachable from the current nodes by a path consisting of any number
    of epsilons and exactly one ``word`` label.

    If there is no such arc, the predictor is put in an invalid state by
    clearing ``cur_nodes``; all subsequent ``predict_next`` calls then
    have no nodes to expand.

    Args:
        word (int): Word on an outgoing arc from the current node
    """
    # Best score per node that is reachable by consuming ``word``
    d_unconsumed = {}
    for weight, node in self.cur_nodes:
        for arc in self.cur_fst.arcs(node):
            if arc.olabel != word:
                continue
            next_node = arc.nextstate
            next_score = weight + self.weight_factor * w2f(arc.weight)
            if d_unconsumed.get(next_node, utils.NEG_INF) < next_score:
                d_unconsumed[next_node] = next_score

    if not d_unconsumed:
        # Nothing matched: ``score_max_func`` must not see an empty
        # sequence (``max`` raises ValueError on it), so stop here.
        self.cur_nodes = []
        return

    # Subtract the word score from the last predict_next
    consumed_score = 0.0
    if word != utils.GO_ID or self.skip_bos_weight:
        consumed_score = self.score_max_func(d_unconsumed.values())

    # Add epsilon reachable states
    self.cur_nodes = self._follow_eps(
        {node: score - consumed_score
         for node, score in d_unconsumed.items()})
def predict_next(self):
    """Build the next-word scores from the arcs leaving the current nodes.

    Only arcs with a non-epsilon label are considered; epsilon arcs have
    already been followed by ``consume`` when it updated ``cur_nodes``.
    Scores of arcs sharing a label are merged with ``score_max_func``.

    Returns:
        The label -> score mapping after it has been passed through
        ``finalize_posterior``; the mapping is empty if there are no
        current nodes.
    """
    word_scores = {}
    for base, node in self.cur_nodes:
        for arc in self.cur_fst.arcs(node):
            if arc.olabel != EPS_ID:
                arc_score = base + self.weight_factor * w2f(arc.weight)
                try:
                    previous = word_scores[arc.olabel]
                except KeyError:
                    # First arc seen with this label
                    word_scores[arc.olabel] = arc_score
                else:
                    word_scores[arc.olabel] = self.score_max_func(
                        previous, arc_score)
    return self.finalize_posterior(
        word_scores, self.use_weights, self.normalize_scores)
def consume(self, word):
    """Updates the current node by following the arc labelled with
    ``word``.

    If no arc carries ``word`` but an arc labelled ``utils.UNK_ID``
    exists, that arc is followed instead. If neither exists,
    ``cur_node`` is set to ``None``, which marks the predictor as being
    in an invalid state. Calling ``consume`` in the invalid state
    (``cur_node`` is ``None`` or negative) is a no-op.

    Args:
        word (int): Word on an outgoing arc from the current node

    Returns:
        float. Scaled weight on the traversed arc if an arc labelled
        ``word`` was found, ``None`` otherwise (including the UNK
        fallback and the invalid state).
    """
    # Test for None explicitly: ``None < 0`` only works through Python 2's
    # arbitrary ordering and raises TypeError on Python 3.
    if self.cur_node is None or self.cur_node < 0:
        return
    from_state = self.cur_node
    self.cur_node = None
    unk_nextstate = None
    for arc in self.cur_fst.arcs(from_state):
        if arc.olabel == word:
            self.cur_node = arc.nextstate
            return self.weight_factor * w2f(arc.weight)
        elif arc.olabel == utils.UNK_ID:
            # Remember the target only, not the arc object itself
            unk_nextstate = arc.nextstate
    if unk_nextstate is not None:
        self.cur_node = unk_nextstate
def consume(self, word):
    """Follow the arc labelled ``word`` from the current node.

    The lookup falls back to an arc labelled ``utils.UNK_ID`` when no arc
    carries ``word``. If there is no such arc either, ``cur_node`` is left
    as ``None``, indicating that the predictor is in an invalid state.
    Once in the invalid state (``cur_node`` is ``None`` or negative) this
    method does nothing.

    Args:
        word (int): Word on an outgoing arc from the current node

    Returns:
        float. Scaled weight on the traversed arc when an arc labelled
        ``word`` exists; ``None`` in every other case.
    """
    current = self.cur_node
    # ``None`` is what this method stores for "invalid"; comparing it with
    # ``< 0`` raises TypeError on Python 3, so check it separately.
    if current is None or current < 0:
        return
    self.cur_node = None
    fallback_state = None
    for arc in self.cur_fst.arcs(current):
        if arc.olabel == word:
            self.cur_node = arc.nextstate
            return self.weight_factor * w2f(arc.weight)
        if arc.olabel == utils.UNK_ID:
            fallback_state = arc.nextstate
    if fallback_state is not None:
        self.cur_node = fallback_state
def estimate_future_cost(self, hypo):
    """The FST predictor comes with its own heuristic function. We use
    the shortest path in the fst as future cost estimator.

    Looks for the arc leaving the current node whose label equals the
    last word of ``hypo`` and returns the precomputed distance stored in
    ``self.distances`` for that arc's target state.

    Args:
        hypo: Hypothesis whose ``trgt_sentence`` holds the word ids
              produced so far

    Returns:
        float. The distance for the matching arc's target state, or 0.0
        if the predictor has no valid current node, the hypothesis is
        empty, or no arc matches.
    """
    # State id 0 is a valid node, so test for the invalid markers
    # explicitly instead of relying on truthiness.
    if self.cur_node is None or self.cur_node < 0:
        return 0.0
    try:
        last_word = hypo.trgt_sentence[-1]
    except IndexError:
        # Empty hypothesis: there is no last word to look up
        return 0.0
    for arc in self.cur_fst.arcs(self.cur_node):
        if arc.olabel == last_word:
            return w2f(self.distances[arc.nextstate])
    return 0.0
def estimate_future_cost(self, hypo):
    """Heuristic future cost based on shortest-path distances in the FST.

    The arc leaving ``cur_node`` that is labelled with the last word of
    ``hypo.trgt_sentence`` is looked up, and the entry of
    ``self.distances`` for its target state is returned.

    Args:
        hypo: Hypothesis providing ``trgt_sentence``, the sequence of word
              ids produced so far

    Returns:
        float. Distance for the target state of the matching arc. 0.0 is
        returned when there is no valid current node, when the hypothesis
        is empty, or when no arc carries the last word.
    """
    node = self.cur_node
    # ``not node`` would also reject state id 0, which is a valid state
    if node is None or node < 0:
        return 0.0
    try:
        last_word = hypo.trgt_sentence[-1]
    except IndexError:
        # No word has been produced yet
        return 0.0
    for arc in self.cur_fst.arcs(node):
        if arc.olabel == last_word:
            return w2f(self.distances[arc.nextstate])
    return 0.0
def estimate_future_cost(self, hypo):
    """The FST predictor comes with its own heuristic function. We use
    the shortest path in the fst as future cost estimator.

    For every current node, the first outgoing arc labelled with the last
    word of ``hypo`` contributes the ``self.distances`` entry of its
    target state; the minimum over all contributions is returned.

    Args:
        hypo: Hypothesis whose ``trgt_sentence`` holds the word ids
              produced so far

    Returns:
        float. Minimum collected distance, or 0.0 if the hypothesis is
        empty or no current node has an arc with the last word.
    """
    try:
        last_word = hypo.trgt_sentence[-1]
    except IndexError:
        # Empty hypothesis: nothing to look up
        return 0.0
    dists = []
    # ``cur_nodes`` holds (score, node) tuples, and arcs are read with
    # ``cur_fst.arcs(node)`` as in ``consume`` and ``predict_next``.
    for _, node in self.cur_nodes:
        for arc in self.cur_fst.arcs(node):
            if arc.olabel == last_word:
                dists.append(w2f(self.distances[arc.nextstate]))
                break
    return min(dists) if dists else 0.0
def estimate_future_cost(self, hypo):
    """Heuristic future cost based on shortest-path distances in the FST.

    Each current node is searched for its first outgoing arc labelled
    with the last word of ``hypo.trgt_sentence``. The ``self.distances``
    entries of the target states found this way are collected and their
    minimum is returned.

    Args:
        hypo: Hypothesis providing ``trgt_sentence``, the sequence of word
              ids produced so far

    Returns:
        float. The smallest collected distance; 0.0 when the hypothesis
        is empty or when no arc with the last word was found.
    """
    try:
        last_word = hypo.trgt_sentence[-1]
    except IndexError:
        # No word has been produced yet
        return 0.0
    found = []
    # Entries of ``cur_nodes`` are (score, node) pairs; only the node id
    # is needed to look up the outgoing arcs.
    for _score, node in self.cur_nodes:
        for arc in self.cur_fst.arcs(node):
            if arc.olabel == last_word:
                found.append(w2f(self.distances[arc.nextstate]))
                break
    if not found:
        return 0.0
    return min(found)
def predict_next(self):
    """Uses the outgoing arcs from the current node to build up the
    scores for the next word.

    If several arcs share a label, the last one returned by the arc
    iterator determines the score. When ``add_bos_to_eos_score`` is set,
    ``bos_score`` is added to the score of ``utils.EOS_ID``.

    Returns:
        dict. Set of words on outgoing arcs from the current node
        together with their scores (passed through
        ``finalize_posterior``), or an empty dict if there is no valid
        current node (``cur_node`` is ``None`` or negative).
    """
    # ``None`` marks the invalid state as well; ``None < 0`` raises
    # TypeError on Python 3, so it is tested separately.
    if self.cur_node is None or self.cur_node < 0:
        return {}
    scores = {
        arc.olabel: self.weight_factor * w2f(arc.weight)
        for arc in self.cur_fst.arcs(self.cur_node)
    }
    if self.add_bos_to_eos_score and utils.EOS_ID in scores:
        scores[utils.EOS_ID] += self.bos_score
    return self.finalize_posterior(scores,
                                   self.use_weights,
                                   self.normalize_scores)
def add_to_label_fst_map_recursive(self,
                                   label_fst_map,
                                   visited_nodes,
                                   root_node,
                                   acc_weight,
                                   history,
                                   func):
    """Walk the FST from ``root_node`` along ``history`` and process the
    arcs found once the history is used up.

    Epsilon arcs are always followed. While ``history`` is non-empty,
    only arcs labelled with its first element are followed (with the
    rest of the history and a fresh visited set). Once the history is
    empty, arcs with an NT label are relabelled and their sub-FST is
    stored in ``label_fst_map``; all other arcs are handed to ``func``.

    Note: visited_nodes is maintained for each history separately

    Args:
        label_fst_map (dict): Receives replacement label -> sub-FST
        visited_nodes (dict): Nodes already expanded for this history
        root_node: Node to start from
        acc_weight: Score accumulated on the path to ``root_node``
        history (list): Labels that still have to be matched
        func: Called as ``func(nextstate, olabel, acc_weight)`` for
              regular arcs
    """
    if root_node in visited_nodes:
        # This introduces some error as we take the score of the first
        # best path with a certain history, not the globally best path.
        # For now, this error should not be significant
        return
    visited_nodes[root_node] = True
    for arc in self.cur_fst.arcs(root_node):
        path_weight = acc_weight + self.weight_factor * w2f(arc.weight)
        if arc.olabel == EPS_ID:
            # Epsilon arcs do not consume history
            self.add_to_label_fst_map_recursive(label_fst_map,
                                                visited_nodes,
                                                arc.nextstate,
                                                path_weight,
                                                history,
                                                func)
            continue
        if history:
            if arc.olabel == history[0]:
                self.add_to_label_fst_map_recursive(label_fst_map,
                                                    {},
                                                    arc.nextstate,
                                                    path_weight,
                                                    history[1:],
                                                    func)
            elif arc.olabel > history[0]:
                # FST is arc sorted, we can stop here
                break
            continue
        # No history left from here on
        if self.is_nt_label(arc.olabel):
            # Register the sub-FST under a fresh replacement label
            replace_label = len(label_fst_map) + 2000000000
            label_fst_map[replace_label] = self.get_sub_fst(arc.olabel)
            arc.ilabel = replace_label
            arc.olabel = replace_label
        else:
            # Regular arc: apply func
            func(arc.nextstate, arc.olabel, path_weight)
def predict_next(self):
    """Build the next-word scores from the arcs leaving the current node.

    Every outgoing arc contributes its scaled weight under its output
    label; if labels repeat, the arc iterated last wins. ``bos_score`` is
    added to the ``utils.EOS_ID`` entry when ``add_bos_to_eos_score`` is
    set.

    Returns:
        dict. Words on outgoing arcs from the current node with their
        scores, after ``finalize_posterior``. An empty dict is returned
        if there is no valid current node, i.e. ``cur_node`` is ``None``
        or negative.
    """
    node = self.cur_node
    # Do not use ``not node`` here: state id 0 is a valid state and must
    # still produce predictions.
    if node is None or node < 0:
        return {}
    scores = {}
    for arc in self.cur_fst.arcs(node):
        scores[arc.olabel] = self.weight_factor * w2f(arc.weight)
    if self.add_bos_to_eos_score and utils.EOS_ID in scores:
        scores[utils.EOS_ID] += self.bos_score
    return self.finalize_posterior(scores,
                                   self.use_weights,
                                   self.normalize_scores)
def add_to_label_fst_map_recursive(self, label_fst_map, visited_nodes,
                                   root_node, acc_weight, history, func):
    """Adds arcs to ``label_fst_map`` if they are labeled with an NT
    symbol and reachable from ``root_node`` via ``history``.

    The traversal follows epsilon arcs unconditionally and consumes one
    element of ``history`` per matching non-epsilon arc. Arcs reached
    with an empty history are either registered in ``label_fst_map``
    (NT labels, which are replaced by a newly generated label) or passed
    to ``func`` together with the accumulated score.

    Note: visited_nodes is maintained for each history separately

    Args:
        label_fst_map (dict): Filled with replacement label -> sub-FST
        visited_nodes (dict): Nodes already expanded for this history
        root_node: Node at which the traversal starts
        acc_weight: Score accumulated before reaching ``root_node``
        history (list): Remaining labels to match
        func: Callback ``func(nextstate, olabel, acc_weight)`` applied
              to regular arcs
    """
    already_expanded = root_node in visited_nodes
    if already_expanded:
        # This introduces some error as we take the score of the first
        # best path with a certain history, not the globally best path.
        # For now, this error should not be significant
        return
    visited_nodes[root_node] = True
    recurse = self.add_to_label_fst_map_recursive
    for arc in self.cur_fst.arcs(root_node):
        total = acc_weight + self.weight_factor * w2f(arc.weight)
        if arc.olabel == EPS_ID:  # Follow epsilon edges
            recurse(label_fst_map, visited_nodes, arc.nextstate, total,
                    history, func)
        elif not history:
            if not self.is_nt_label(arc.olabel):
                # This is a regular arc and we have no history left
                func(arc.nextstate, arc.olabel, total)
            else:
                # Add to label_fst_map
                replace_label = len(label_fst_map) + 2000000000
                label_fst_map[replace_label] = self.get_sub_fst(
                    arc.olabel)
                arc.ilabel = replace_label
                arc.olabel = replace_label
        elif arc.olabel == history[0]:  # history is not empty
            # A new history gets its own visited set
            recurse(label_fst_map, {}, arc.nextstate, total, history[1:],
                    func)
        elif arc.olabel > history[0]:
            # FST is arc sorted, we can stop here
            break