def get_target_node(source, target, rules):
    """Return node corresponding to parses of `source` for `target`, or None.

    Args:
      source: Input string; tokenized by splitting on single spaces.
      target: Target string that a returned node must produce exactly.
      rules: Rules forwarded to `qcfg_parser.parse`.

    Returns:
      The single node whose `target_string` equals `target`, or None if the
      parser returns no such node.

    Raises:
      ValueError: If more than one returned node has a `target_string` equal
        to `target`.
    """
    tokens = source.split(" ")

    def node_fn(span_begin, span_end, rule, children):
        # Compose this node's target string from the children's target strings.
        target_string = qcfg_rule.apply_target(
            rule, [node.target_string for node in children])
        return RuleApplicationNode(rule, children, span_begin, span_end,
                                   target_string)

    def postprocess_fn(nodes):
        # NOTE(review): `filter_nodes` and `_aggregate` are defined elsewhere;
        # presumably they prune nodes incompatible with `target` and merge the
        # remaining ones -- confirm against their definitions.
        nodes = filter_nodes(nodes, target)
        return _aggregate(nodes)

    nodes = qcfg_parser.parse(tokens,
                              rules,
                              node_fn=node_fn,
                              postprocess_cell_fn=postprocess_fn)

    # Keep only nodes whose target_string matches target exactly.
    ret_nodes = [node for node in nodes if node.target_string == target]

    if not ret_nodes:
        return None

    if len(ret_nodes) > 1:
        # Report the offending example instead of raising a bare ValueError.
        raise ValueError(
            "Multiple nodes returned for target `%s` of example `%s`: %s" %
            (target, source, ret_nodes))

    return ret_nodes[0]
def run_inference(source, rules, score_fn):
    """Find the highest scoring derivation of `source` according to `score_fn`.

    Args:
      source: Input string, tokenized by splitting on single spaces.
      rules: Set of QCFGRules.
      score_fn: Callable taking (rule, span_begin, span_end) and returning a
        float score for that anchored rule application. `span_begin` and
        `span_end` are token indexes with `span_end` exclusive, and `rule`
        is a QCFGRule.

    Returns:
      A (target string, score) pair for the best derivation, or (None, None)
      when there is no derivation for `source`.

    Raises:
      ValueError: If the parser returns more than one node.
    """
    tokens = source.split(" ")
    scored_node_fn = get_node_fn(score_fn)
    candidates = qcfg_parser.parse(
        tokens,
        rules,
        node_fn=scored_node_fn,
        postprocess_cell_fn=postprocess_cell_fn,
    )

    # No node at all means the source could not be derived.
    if not candidates:
        return None, None

    # A single one-best node is expected from the parser.
    if len(candidates) > 1:
        raise ValueError(
            "Multiple nodes returned for inference: %s" % candidates)

    best = candidates[0]
    return best.target_string, best.score
Example #3
0
 def test_parse_flat(self):
   """One rule covering both tokens is expected to give a single parse."""
   flat_rule = qcfg_rule.rule_from_string("dax twice ### DAX TWICE")
   result = qcfg_parser.parse(
       ["dax", "twice"], [flat_rule], _node_fn, _postprocess_cell_fn)
   self.assertEqual(result, ["DAX TWICE"])
Example #4
0
 def test_parse(self):
   """A terminal rule combined with a one-nonterminal rule gives one parse."""
   grammar = [
       qcfg_rule.rule_from_string(rule_string)
       for rule_string in ("dax ### DAX", "NT_1 twice ### NT_1 NT_1")
   ]
   result = qcfg_parser.parse(
       ["dax", "twice"], grammar, _node_fn, _postprocess_cell_fn)
   self.assertEqual(result, ["DAX DAX"])
Example #5
0
def _get_num_all_derivations(source, rules, verbose):
    """Return total number of derivations for any target.

    Args:
      source: Input forwarded unchanged to `qcfg_parser.parse` (it is not
        split here).
      rules: Rules forwarded to `qcfg_parser.parse`.
      verbose: Forwarded to `qcfg_parser.parse` as `verbose`.

    Returns:
      The single value returned by the parser, which `postprocess_fn` builds
      by summing the per-node counts.

    Raises:
      ValueError: If the parser does not return exactly one output.
    """
    def node_fn(unused_span_begin, unused_span_end, unused_rule, children):
        """Represent nodes as integer counts of possible derivations."""
        return _aggregate_counts(children)

    def postprocess_fn(nodes):
        """Merge and sum all nodes."""
        return [sum(nodes)]

    outputs = qcfg_parser.parse(source,
                                rules,
                                node_fn=node_fn,
                                postprocess_cell_fn=postprocess_fn,
                                verbose=verbose)

    if len(outputs) != 1:
        # Report the offending example instead of raising a bare ValueError.
        raise ValueError(
            "Expected exactly 1 output for example `%s` but got %d" %
            (source, len(outputs)))
    return outputs[0]
Example #6
0
def _get_num_target_derivations(source, target, rules, verbose):
    """Count the derivations of `source` that produce `target`.

    Args:
      source: Input forwarded unchanged to `qcfg_parser.parse`.
      target: Sequence of strings; joined with single spaces to form the goal
        target string.
      rules: Rules forwarded to `qcfg_parser.parse`.
      verbose: Forwarded to `qcfg_parser.parse` as `verbose`.

    Returns:
      The count paired with the goal target string in the parser output.

    Raises:
      ValueError: If no parser output matches the goal target string.
    """
    goal = " ".join(target)

    def node_fn(unused_span_begin, unused_span_end, rule, children):
        """Build a (target string, derivation count) pair from child pairs."""
        child_targets = [child_target for child_target, _ in children]
        derived_target = qcfg_rule.apply_target(rule, child_targets)
        derivation_count = _aggregate_counts(
            [child_count for _, child_count in children])
        return (derived_target, derivation_count)

    def postprocess_fn(nodes):
        """Sum counts per target string, keeping only substrings of the goal."""
        totals = collections.defaultdict(int)
        for candidate, count in nodes:
            # Only targets that occur inside the goal string are kept.
            if candidate in goal:
                totals[candidate] += count
        return list(totals.items())

    outputs = qcfg_parser.parse(
        source,
        rules,
        node_fn=node_fn,
        postprocess_cell_fn=postprocess_fn,
        verbose=verbose,
    )

    # The first output equal to the goal carries the answer.
    for derived_target, count in outputs:
        if derived_target == goal:
            return count

    raise ValueError("No target derivation for example (%s, %s)" %
                     (source, target))
def get_merged_node(source, rules):
    """Build the single node that represents every parse of `source`.

    Args:
      source: Input string, tokenized by splitting on single spaces.
      rules: Rules forwarded to `qcfg_parser.parse`.

    Returns:
      The one node returned by the parser.

    Raises:
      ValueError: If the parser does not return exactly one node.
    """
    tokens = source.split(" ")

    def node_fn(span_begin, span_end, rule, children):
        # No target string is tracked here, so None is stored in its place.
        return RuleApplicationNode(rule, children, span_begin, span_end, None)

    def postprocess_fn(nodes):
        # Zero or one node is passed through; several are wrapped into one.
        if len(nodes) <= 1:
            return nodes
        return [AggregationNode(nodes)]

    merged = qcfg_parser.parse(
        tokens,
        rules,
        node_fn=node_fn,
        postprocess_cell_fn=postprocess_fn,
    )

    if len(merged) == 1:
        return merged[0]

    raise ValueError("example `%s` len(nodes) != 1: %s" % (source, merged))