def get_target_node(source, target, rules):
  """Return node corresponding to parses for target, or None.

  Args:
    source: Input string; it is split on single spaces into tokens.
    target: Target string that a root node must produce exactly.
    rules: Rules passed through to `qcfg_parser.parse`.

  Returns:
    The single root node whose `target_string` equals `target`, or None if
    no root node produces `target`.

  Raises:
    ValueError: If more than one root node produces `target`.
  """
  tokens = source.split(" ")

  def node_fn(span_begin, span_end, rule, children):
    # Build this node's target string from the children's target strings.
    target_string = qcfg_rule.apply_target(
        rule, [node.target_string for node in children])
    return RuleApplicationNode(rule, children, span_begin, span_end,
                               target_string)

  def postprocess_fn(nodes):
    # Prune with `filter_nodes` against the target, then merge via
    # `_aggregate`.
    nodes = filter_nodes(nodes, target)
    return _aggregate(nodes)

  nodes = qcfg_parser.parse(
      tokens, rules, node_fn=node_fn, postprocess_cell_fn=postprocess_fn)

  # Filter for nodes where target_string matches target exactly.
  ret_nodes = [node for node in nodes if node.target_string == target]
  if not ret_nodes:
    return None
  if len(ret_nodes) > 1:
    # Include context so the failure is diagnosable, matching the other
    # parse helpers in this module.
    raise ValueError("Multiple nodes for target `%s` of source `%s`: %s" %
                     (target, source, ret_nodes))
  return ret_nodes[0]
def run_inference(source, rules, score_fn):
  """Find the highest-scoring derivation of `source` under `score_fn`.

  Args:
    source: Input string; tokens are obtained by splitting on single spaces.
    rules: Set of QCFGRules.
    score_fn: Callable taking (rule, span_begin, span_end) and returning a
      float score for that anchored rule application. `span_begin` and
      `span_end` are token indexes with `span_end` exclusive, and `rule` is a
      QCFGRule.

  Returns:
    A (target string, score) tuple for the best-scoring derivation, or
    (None, None) when `source` has no derivation.

  Raises:
    ValueError: If the parser returns more than one root node.
  """
  parsed = qcfg_parser.parse(
      source.split(" "),
      rules,
      node_fn=get_node_fn(score_fn),
      postprocess_cell_fn=postprocess_cell_fn)

  # No derivation at all for this source.
  if not parsed:
    return None, None
  if len(parsed) > 1:
    raise ValueError("Multiple nodes returned for inference: %s" % parsed)

  best = parsed[0]
  return best.target_string, best.score
def test_parse_flat(self):
  """A single rule covering the whole input yields its target directly."""
  rule_strings = ["dax twice ### DAX TWICE"]
  rules = [qcfg_rule.rule_from_string(text) for text in rule_strings]
  parses = qcfg_parser.parse(["dax", "twice"], rules, _node_fn,
                             _postprocess_cell_fn)
  self.assertEqual(parses, ["DAX TWICE"])
def test_parse(self):
  """NT_1 is filled by `dax` and its target is repeated by the outer rule."""
  rules = [
      qcfg_rule.rule_from_string(text)
      for text in ("dax ### DAX", "NT_1 twice ### NT_1 NT_1")
  ]
  parses = qcfg_parser.parse(["dax", "twice"], rules, _node_fn,
                             _postprocess_cell_fn)
  self.assertEqual(parses, ["DAX DAX"])
def _get_num_all_derivations(source, rules, verbose):
  """Return total number of derivations for any target.

  Args:
    source: Passed directly to `qcfg_parser.parse` (it is not split here).
    rules: Rules passed to `qcfg_parser.parse`.
    verbose: Passed through to `qcfg_parser.parse`.

  Returns:
    The count held by the parser's single root output.

  Raises:
    ValueError: If the parser does not return exactly one output.
  """

  def node_fn(unused_span_begin, unused_span_end, unused_rule, children):
    """Represent nodes as integer counts of possible derivations."""
    return _aggregate_counts(children)

  def postprocess_fn(nodes):
    """Merge and sum all nodes."""
    # Collapse every count handed to this hook into a single summed count.
    return [sum(nodes)]

  outputs = qcfg_parser.parse(
      source,
      rules,
      node_fn=node_fn,
      postprocess_cell_fn=postprocess_fn,
      verbose=verbose)
  if len(outputs) != 1:
    # Report what was returned rather than raising a bare ValueError.
    raise ValueError("Expected exactly 1 output for source `%s`, got %d: %s" %
                     (source, len(outputs), outputs))
  return outputs[0]
def _get_num_target_derivations(source, target, rules, verbose):
  """Count the derivations of `source` whose target equals `target`.

  Args:
    source: Passed directly to `qcfg_parser.parse`.
    target: Sequence of target tokens; joined with single spaces to form the
      goal string.
    rules: Rules passed to `qcfg_parser.parse`.
    verbose: Passed through to `qcfg_parser.parse`.

  Returns:
    The count attached to the root output whose target string equals the
    goal string.

  Raises:
    ValueError: If no root output produces the goal string.
  """
  goal = " ".join(target)

  def node_fn(unused_span_begin, unused_span_end, rule, children):
    """Represent nodes as (target string, int count of possible derivations)."""
    child_strings = [child_string for child_string, _ in children]
    child_counts = [child_count for _, child_count in children]
    return (qcfg_rule.apply_target(rule, child_strings),
            _aggregate_counts(child_counts))

  def postprocess_fn(nodes):
    """Discard nodes that cannot reach goal and sum counts per target string."""
    totals = collections.defaultdict(int)
    for node_string, node_count in nodes:
      # Pruning only: keep strings that occur in the goal. This is a
      # character-level substring test, so it may keep some strings that are
      # not whole-token spans of the goal; the final match below is exact.
      if node_string in goal:
        totals[node_string] += node_count
    return list(totals.items())

  outputs = qcfg_parser.parse(
      source,
      rules,
      node_fn=node_fn,
      postprocess_cell_fn=postprocess_fn,
      verbose=verbose)
  for output_string, output_count in outputs:
    if output_string == goal:
      return output_count
  raise ValueError("No target derivation for example (%s, %s)" %
                   (source, target))
def get_merged_node(source, rules):
  """Return a single node that represents every parse of `source`.

  Args:
    source: Input string; it is split on single spaces into tokens.
    rules: Rules passed to `qcfg_parser.parse`.

  Returns:
    The sole root node returned by the parser.

  Raises:
    ValueError: If the parser does not return exactly one root node.
  """

  def node_fn(span_begin, span_end, rule, children):
    # Target strings are not needed for merging, so none is computed.
    return RuleApplicationNode(rule, children, span_begin, span_end, None)

  def postprocess_fn(nodes):
    # Wrap several nodes in one AggregationNode; pass zero or one through.
    return [AggregationNode(nodes)] if len(nodes) > 1 else nodes

  roots = qcfg_parser.parse(
      source.split(" "),
      rules,
      node_fn=node_fn,
      postprocess_cell_fn=postprocess_fn)
  if len(roots) != 1:
    raise ValueError("example `%s` len(nodes) != 1: %s" % (source, roots))
  return roots[0]