Ejemplo n.º 1
0
  def test_parse_3(self):
    """Parses the input `FOO FOO NT` using only the rule NT -> FOO NT."""
    foo_symbol = cfg_rule.CFGSymbol(FOO, cfg_rule.TERMINAL)
    nt_symbol = cfg_rule.CFGSymbol(NT, cfg_rule.NON_TERMINAL)

    # NT -> FOO NT
    recursive_rule = cfg_rule.CFGRule(
        idx=0, lhs=NT, rhs=(foo_symbol, nt_symbol))

    # The trailing NT is given directly in the input, so the rule has to
    # apply over both (1, 3) and (0, 3).
    symbols = [foo_symbol, foo_symbol, nt_symbol]

    parses = cfg_parser.parse_symbols(
        symbols, [recursive_rule], {NT}, {NT},
        _populate_fn,
        _postprocess_fn,
        verbose=True)

    self.assertLen(parses, 1)
    self.assertEqual(parses[0], [(0, 3, 0), (1, 3, 0)])
Ejemplo n.º 2
0
  def test_parse_1(self):
    """Parses the all-terminal input FOO FOO BAR via `cfg_parser.parse`."""
    bar_symbol = cfg_rule.CFGSymbol(BAR, cfg_rule.TERMINAL)
    foo_symbol = cfg_rule.CFGSymbol(FOO, cfg_rule.TERMINAL)
    nt_symbol = cfg_rule.CFGSymbol(NT, cfg_rule.NON_TERMINAL)

    grammar = [
        # NT -> BAR
        cfg_rule.CFGRule(idx=0, lhs=NT, rhs=(bar_symbol,)),
        # NT -> FOO NT
        cfg_rule.CFGRule(idx=1, lhs=NT, rhs=(foo_symbol, nt_symbol)),
    ]

    parses = cfg_parser.parse(
        [FOO, FOO, BAR], grammar, {NT}, {NT},
        _populate_fn,
        _postprocess_fn,
        verbose=True)

    self.assertLen(parses, 1)
    self.assertEqual(parses[0], [(0, 3, 1), (1, 3, 1), (2, 3, 0)])
Ejemplo n.º 3
0
    def can_parse(self, tokens, verbose=False):
        """Return True if `tokens` can be parsed given the target CFG.

        Every QCFG nonterminal token is mapped to the PLACEHOLDER_NT
        nonterminal id. Terminal tokens must already be known to
        `self.converter`.

        Args:
          tokens: Iterable of string tokens (terminals or QCFG nonterminals).
          verbose: Print debug output if True.

        Returns:
          True if at least one parse exists. False otherwise, including when
          a terminal token is missing from `converter.terminals_to_ids`.
        """
        input_symbols = []
        terminal_ids = set()
        for token in tokens:
            if qcfg_rule.is_nt(token):
                idx = self.converter.nonterminals_to_ids[PLACEHOLDER_NT]
                input_symbols.append(
                    cfg_rule.CFGSymbol(idx, cfg_rule.NON_TERMINAL))
                continue
            if token not in self.converter.terminals_to_ids:
                if verbose:
                    print(
                        "token `%s` not in `converter.terminals_to_ids`: %s"
                        % (token, self.converter.terminals_to_ids))
                return False
            idx = self.converter.terminals_to_ids[token]
            terminal_ids.add(idx)
            input_symbols.append(cfg_rule.CFGSymbol(idx, cfg_rule.TERMINAL))

        # A rule whose RHS uses a terminal absent from the input can never
        # match, so drop it before building the parser's trie.
        def should_include(parser_rule):
            return all(symbol.type != cfg_rule.TERMINAL
                       or symbol.idx in terminal_ids
                       for symbol in parser_rule.rhs)

        filtered_rules = [
            rule for rule in self.parser_rules if should_include(rule)
        ]
        if verbose:
            print("filtered_rules:")
            for rule in filtered_rules:
                print(rule)

        # Only the existence of a parse matters, so every chart entry is
        # just a placeholder.
        def populate_fn(unused_span_begin, unused_span_end, unused_parser_rule,
                        unused_children):
            return [True]

        nonterminals = set(self.converter.nonterminals_to_ids.values())
        parses = cfg_parser.parse_symbols(input_symbols,
                                          filtered_rules,
                                          nonterminals,
                                          nonterminals,
                                          populate_fn,
                                          postprocess_fn=None,
                                          max_single_nt_applications=2,
                                          verbose=verbose)
        return bool(parses)
Ejemplo n.º 4
0
def can_parse(target_string,
              rules,
              max_single_nt_applications=2,
              verbose=False):
  """Returns True if there exists >=1 parse of target_string given rules.

  Args:
    target_string: Space-separated tokens. Tokens starting with
      NON_TERMINAL_PREFIX are nonterminals; all others are terminals.
    rules: Iterable of rules exposing a string `lhs` and a space-separated
      string `rhs`.
    max_single_nt_applications: Forwarded to `cfg_parser.parse_symbols`.
    verbose: Print debug output if True.

  Returns:
    True if at least one parse rooted at ROOT_SYMBOL spans the whole input.
    False otherwise. This includes the cases where ROOT_SYMBOL, or a
    nonterminal or terminal used in the input, is unknown to the converter;
    no rule can then derive or consume it.
  """
  tokens = target_string.split(" ")
  # Built once and shared by every conversion call below.
  allowed_terminals = set(tokens)

  # Convert rules. The converter returns None for rules whose RHS uses a
  # terminal absent from the input, since those can never match.
  converter = cfg_converter.CFGRuleConverter()
  parser_rules = []
  for rule_idx, rule in enumerate(rules):
    parser_rule = converter.convert_to_cfg_rule(
        lhs=rule.lhs,
        rhs=rule.rhs.split(" "),
        rule_idx=rule_idx,
        nonterminal_prefix=NON_TERMINAL_PREFIX,
        allowed_terminals=allowed_terminals)
    if parser_rule:
      parser_rules.append(parser_rule)

  # If no rule ever mentioned the start symbol, nothing can be derived from it.
  if ROOT_SYMBOL not in converter.nonterminals_to_ids:
    return False
  start_idx = converter.nonterminals_to_ids[ROOT_SYMBOL]
  # Kept as the mapping's values view (not a set): parse_symbols iterates it,
  # and that order drives single-NT rule application.
  nonterminals = converter.nonterminals_to_ids.values()

  input_symbols = []
  for token in tokens:
    if token.startswith(NON_TERMINAL_PREFIX):
      nt_name = token[len(NON_TERMINAL_PREFIX):]
      # No converted rule mentions this nonterminal, so no rule can consume it.
      if nt_name not in converter.nonterminals_to_ids:
        return False
      idx = converter.nonterminals_to_ids[nt_name]
      input_symbols.append(cfg_rule.CFGSymbol(idx, cfg_rule.NON_TERMINAL))
    else:
      if token not in converter.terminals_to_ids:
        return False
      idx = converter.terminals_to_ids[token]
      input_symbols.append(cfg_rule.CFGSymbol(idx, cfg_rule.TERMINAL))

  # Run parser.
  parses = cfg_parser.parse_symbols(
      input_symbols,
      parser_rules,
      nonterminals, {start_idx},
      _populate_fn,
      _postprocess_fn,
      verbose=verbose,
      max_single_nt_applications=max_single_nt_applications)

  return bool(parses)
Ejemplo n.º 5
0
def sample(parser_rules,
           start_idx,
           rule_values=None,
           max_recursion=1,
           nonterminal_coef=1,
           verbose=False):
    """Sample a derivation tree from a CFG, starting at `start_idx`.

    Args:
      parser_rules: A list of CFGRule instances.
      start_idx: Index of the nonterminal used as the start symbol.
      rule_values: Optional list aligned 1-1 with `parser_rules` (target
        grammar rules, QCFG rules, strings, ...). It is only read to print
        debug output when `verbose` is True.
      max_recursion: Depth at which sampling switches to rules with no
        nonterminals on the RHS, provided any such rule exists for the
        current nonterminal. If none exists, expansion continues past this
        depth.
      nonterminal_coef: Forwarded to `sample_rule` as its scaling coefficient
        for rules with nonterminals.
      verbose: Print debug logging if True.

    Returns:
      A nested list. The first element is the sampled CFGRule, followed by one
      nested list per nonterminal in that rule's RHS, in RHS order.
    """
    rules_by_lhs = collections.defaultdict(list)
    for parser_rule in parser_rules:
        rules_by_lhs[parser_rule.lhs].append(parser_rule)

    def _expand(symbol, depth=0):
        candidates = rules_by_lhs[symbol.idx]
        if depth == max_recursion:
            # Prefer rules without RHS nonterminals so the recursion stops.
            nt_free = [
                candidate for candidate in candidates
                if cfg_rule.get_arity(candidate) == 0
            ]
            if nt_free:
                candidates = nt_free

        chosen = sample_rule(candidates, nonterminal_coef=nonterminal_coef)
        if verbose and rule_values is not None:
            print("Recursion %d, Sampled rule: %s" %
                  (depth, rule_values[chosen.idx]))

        subtrees = [
            _expand(rhs_symbol, depth=depth + 1)
            for rhs_symbol in chosen.rhs
            if rhs_symbol.type == cfg_rule.NON_TERMINAL
        ]
        return [chosen] + subtrees

    root = cfg_rule.CFGSymbol(idx=start_idx, type=cfg_rule.NON_TERMINAL)
    return _expand(root, depth=0)
Ejemplo n.º 6
0
    def test_sample(self, mock_random):
        """Samples from a two-rule grammar with the patched source fixed."""
        # NOTE(review): what `mock_random` patches is set by a decorator that
        # is not visible here; it is made to return [0].
        mock_random.return_value = [0]

        foo = cfg_rule.CFGSymbol(FOO, cfg_rule.TERMINAL)
        bar = cfg_rule.CFGSymbol(BAR, cfg_rule.TERMINAL)
        nt = cfg_rule.CFGSymbol(NT, cfg_rule.NON_TERMINAL)

        recursive_rule = cfg_rule.CFGRule(idx=0, lhs=NT, rhs=(foo, nt))  # NT -> FOO NT
        base_rule = cfg_rule.CFGRule(idx=1, lhs=NT, rhs=(bar, ))  # NT -> BAR

        output = cfg_sampler.sample([recursive_rule, base_rule],
                                    NT,
                                    rule_values=[recursive_rule, base_rule],
                                    verbose=True)
        self.assertEqual(output, [recursive_rule, [base_rule]])
Ejemplo n.º 7
0
  def test_parse_6(self):
    """Mutually recursive NT / NT_2 rules applied to the input `NT BAR`."""
    bar = cfg_rule.CFGSymbol(BAR, cfg_rule.TERMINAL)
    nt = cfg_rule.CFGSymbol(NT, cfg_rule.NON_TERMINAL)
    nt_2 = cfg_rule.CFGSymbol(NT_2, cfg_rule.NON_TERMINAL)

    grammar = [
        # NT -> NT_2 BAR
        cfg_rule.CFGRule(idx=0, lhs=NT, rhs=(nt_2, bar)),
        # NT_2 -> NT BAR
        cfg_rule.CFGRule(idx=1, lhs=NT_2, rhs=(nt, bar)),
    ]

    parses = cfg_parser.parse_symbols(
        [nt, bar], grammar, {NT, NT_2},
        {NT, NT_2},
        _populate_fn,
        _postprocess_fn,
        verbose=True)

    self.assertLen(parses, 1)
    self.assertEqual(parses[0], [(0, 2, 1)])
Ejemplo n.º 8
0
  def test_parse_4(self):
    """The ambiguous input `NT FOO BAR` yields two distinct parses."""
    foo = cfg_rule.CFGSymbol(FOO, cfg_rule.TERMINAL)
    bar = cfg_rule.CFGSymbol(BAR, cfg_rule.TERMINAL)
    nt = cfg_rule.CFGSymbol(NT, cfg_rule.NON_TERMINAL)

    grammar = [
        # NT -> BAR
        cfg_rule.CFGRule(idx=0, lhs=NT, rhs=(bar,)),
        # NT -> NT FOO NT
        cfg_rule.CFGRule(idx=1, lhs=NT, rhs=(nt, foo, nt)),
        # NT -> NT FOO BAR
        cfg_rule.CFGRule(idx=2, lhs=NT, rhs=(nt, foo, bar)),
    ]

    parses = cfg_parser.parse_symbols(
        [nt, foo, bar], grammar, {NT},
        {NT},
        _populate_fn,
        _postprocess_fn,
        verbose=True)

    self.assertLen(parses, 2)
    self.assertEqual(parses, [[(0, 3, 2)], [(0, 3, 1), (2, 3, 0)]])
Ejemplo n.º 9
0
    def convert_to_cfg_rule(self,
                            lhs,
                            rhs,
                            rule_idx,
                            nonterminal_prefix,
                            allowed_terminals=None):
        """Convert symbol strings to a CFGRule.

        Args:
          lhs: String symbol for the LHS.
          rhs: List of string symbols for the RHS.
          rule_idx: Integer index for the rule.
          nonterminal_prefix: String prefix marking nonterminal symbols in
            `rhs`. The prefix is stripped before the id lookup.
          allowed_terminals: If truthy, return None when `rhs` contains a
            terminal not in this collection. NOTE: an empty collection is
            falsy and therefore disables filtering.

        Returns:
          A CFGRule whose `rhs` is a list of CFGSymbols, or None if a terminal
          was rejected by `allowed_terminals`. Ids for the LHS and for any RHS
          symbols seen before the rejected terminal are still registered with
          this converter even when None is returned.
        """
        lhs_idx = self._get_nonterminal_id(lhs)
        prefix_len = len(nonterminal_prefix)
        rhs_symbols = []
        for token in rhs:
            if not token.startswith(nonterminal_prefix):
                if allowed_terminals and token not in allowed_terminals:
                    return None
                rhs_symbols.append(
                    cfg_rule.CFGSymbol(idx=self._get_terminal_id(token),
                                       type=cfg_rule.TERMINAL))
                continue
            rhs_symbols.append(
                cfg_rule.CFGSymbol(
                    idx=self._get_nonterminal_id(token[prefix_len:]),
                    type=cfg_rule.NON_TERMINAL))
        return cfg_rule.CFGRule(idx=rule_idx, lhs=lhs_idx, rhs=rhs_symbols)
Ejemplo n.º 10
0
def parse(input_ids,
          rules,
          nonterminals,
          start_idx_set,
          populate_fn,
          postprocess_fn,
          max_single_nt_applications=1,
          verbose=False):
    """Run bottom up parser where all inputs are terminals.

    Each id in `input_ids` is wrapped as a terminal CFGSymbol, and the call is
    delegated to `parse_symbols`. See that function for the remaining
    arguments and the return value.
    """
    terminal_symbols = tuple(
        cfg_rule.CFGSymbol(terminal_id, cfg_rule.TERMINAL)
        for terminal_id in input_ids)
    return parse_symbols(
        terminal_symbols, rules, nonterminals, start_idx_set, populate_fn,
        postprocess_fn,
        max_single_nt_applications=max_single_nt_applications,
        verbose=verbose)
Ejemplo n.º 11
0
def parse_symbols(input_symbols,
                  rules,
                  nonterminals,
                  start_idx_set,
                  populate_fn,
                  postprocess_fn,
                  max_single_nt_applications=1,
                  verbose=False):
    """Run a bottom-up chart parser over a sequence of CFGSymbols.

    Let T be an arbitrary type for chart entries, specified by the return type
    of populate_fn. Examples of T are simple types that only indicate the
    presence of a parse for a given span, or more complex structures that
    represent parse forests.

    Spans are processed in order of increasing end position, and for each
    end, in order of decreasing begin position. Therefore every strictly
    smaller sub-span is already in the chart when a span is populated.

    Args:
      input_symbols: Sequence of CFGSymbols to parse. Entries may be terminals
        or nonterminals; both are matched against rule RHS symbols in the trie.
      rules: A list of CFGRule instances.
      nonterminals: Collection of integer nonterminal ids (each one is wrapped
        as CFGSymbol(nt, NON_TERMINAL) below). Its iteration order is the
        order in which cells are postprocessed and single-NT rules applied.
      start_idx_set: A set of nonterminal ids accepted as roots of a complete
        parse.
      populate_fn: A function that takes (span_begin (Integer), span_end
        (Integer), parser_rule (CFGRule), substitutions (List of T)) and
        returns a list of objects of type T. These objects are added to the
        chart. Depending on what information is desired about completed
        parses, T can be anything from a simple count to a complex parse
        forest object.
      postprocess_fn: A function that takes and returns a list of T. It
        post-processes each cell after the cell has been populated, which is
        useful for pruning the chart or merging equivalent entries. Ignored
        if None.
      max_single_nt_applications: The number of passes that apply rules whose
        RHS is a single nonterminal symbol to each span.
      verbose: Print debug logging if True.

    Returns:
      A list of T: the entries of every cell (0, len(input_symbols), s) for s
      in start_idx_set, concatenated.
    """
    input_len = len(input_symbols)

    # Initialize the empty chart.
    chart = Chart(populate_fn, postprocess_fn)

    # Initialize Trie of rules.
    trie_root = TrieNode()
    # Largest number of RHS nonterminals in any rule. A partial match already
    # holding this many anchored nonterminals cannot extend with another one.
    max_num_nts = 0
    for rule in rules:
        add_rule_to_trie(trie_root, rule)
        max_num_nts = max(max_num_nts, cfg_rule.get_num_nts(rule.rhs))

    # Populate the chart.
    for span_end in range(1, input_len + 1):
        for span_begin in range(span_end - 1, -1, -1):

            # Map from input position to the partial matches (SearchStates)
            # that have consumed the input up to that position.
            search_map = collections.defaultdict(list)
            search_map[span_begin].append(SearchState([], trie_root))

            # Iterate across every input token in the span range to find rule matches.
            for idx in range(span_begin, span_end):

                # Skip positions that no partial match has reached.
                if not search_map[idx]:
                    continue

                # The input symbol at idx (a terminal or an input nonterminal).
                terminal_symbol = input_symbols[idx]

                # Iterate through partial matches.
                while search_map[idx]:
                    search_state = search_map[idx].pop()

                    # Consider matching terminal.
                    new_trie_node = search_state.trie_node.maybe_get_child(
                        terminal_symbol)
                    if new_trie_node:
                        # Found a match for the terminal in the Trie.
                        # Add a partial match to search_map with idx incremented by 1 token.
                        new_search_state = SearchState(
                            search_state.anchored_nonterminals, new_trie_node)
                        search_map[idx + 1].append(new_search_state)

                    # Consider matching non-terminal.
                    # Chart cells are only filled for spans processed earlier,
                    # so every nt_end here is at most span_end.
                    nonterminal_tuples = chart.get_from_start(idx)
                    if len(search_state.anchored_nonterminals) < max_num_nts:
                        # Iterate through lower chart entries with a completed sub-tree
                        # that starts at the current index.
                        for nt_end, nonterminal in nonterminal_tuples:
                            nonterminal_symbol = cfg_rule.CFGSymbol(
                                nonterminal, cfg_rule.NON_TERMINAL)
                            new_trie_node = search_state.trie_node.maybe_get_child(
                                nonterminal_symbol)
                            if new_trie_node:
                                # Found a match for the non-terminal in the Trie.
                                # Add a partial match to search_map with idx set to the end
                                # of the sub-tree span.
                                # Copy so that sibling search states do not
                                # share the same list.
                                new_anchored_nonterminals = search_state.anchored_nonterminals[:]
                                new_anchored_nonterminals.append(
                                    (idx, nt_end, nonterminal))
                                search_map[nt_end].append(
                                    SearchState(new_anchored_nonterminals,
                                                new_trie_node))

            # Loop through search_map for completed matches at span_end.
            for search_state in search_map[span_end]:
                # Get the ParserRule(s) associated with the particular Trie path.
                matched_rules = search_state.trie_node.values
                if not matched_rules:
                    continue

                for rule in matched_rules:
                    # Given the ParserRule and anchored nonterminal positions, generate
                    # new chart entries and add chart.
                    children_list = []
                    for anchored_nt in search_state.anchored_nonterminals:
                        children = chart.get_from_key(*anchored_nt)
                        children_list.append(children)
                    # One chart addition per combination of child entries
                    # (Cartesian product over the anchored cells).
                    for children in itertools.product(*children_list):
                        chart.add(span_begin, span_end, rule, children)

            for nt in nonterminals:
                chart.postprocess(span_begin, span_end, nt)

            # Optionally apply rule where RHS is a single NT.
            # Each pass re-applies unary rules to every entry currently in a
            # cell, including entries already expanded in a previous pass.
            # NOTE(review): presumably postprocess_fn is relied on to merge
            # any resulting duplicates; confirm.
            for _ in range(max_single_nt_applications):
                for nt in nonterminals:
                    # Copy cell since we are mutating it during iteration below.
                    cell = chart.get_from_key(span_begin, span_end, nt).copy()
                    nt_symbol = cfg_rule.CFGSymbol(nt, cfg_rule.NON_TERMINAL)
                    child = trie_root.maybe_get_child(nt_symbol)
                    if child:
                        # NOTE(review): `child.values` holds every rule whose
                        # RHS *starts* with nt, which assumes add_rule_to_trie
                        # stores only complete paths; confirm it cannot
                        # include longer rules here.
                        single_nt_rules = child.values
                        for rule in single_nt_rules:
                            for node in cell:
                                chart.add(span_begin, span_end, rule, [node])
                    # NOTE(review): this postprocesses the `nt` cell, but the
                    # additions above go to the cells of `rule.lhs`; confirm
                    # that this is intended.
                    chart.postprocess(span_begin, span_end, nt)

            if verbose:
                for nt in nonterminals:
                    cell = chart.get_from_key(span_begin, span_end, nt)
                    if cell:
                        print("Populated (%s,%s): %s - %s" %
                              (span_begin, span_end, nt, cell))

    # Return completed parses.
    parses = []
    for start_idx in start_idx_set:
        parses.extend(chart.get_from_key(0, input_len, start_idx))
    return parses
Ejemplo n.º 12
0
def parse(tokens,
          rules,
          node_fn,
          postprocess_cell_fn,
          max_single_nt_applications=1,
          verbose=False):
    """Run a bottom-up parser over QCFG source tokens.

    Args:
      tokens: List of strings for input (terminals or nonterminals).
      rules: List of QCFGRule instances. Rules rejected by
        `qcfg_rule.is_allowed` for the input's tokens are skipped.
      node_fn: Function with input arguments (span_begin, span_end, rule,
        children) that returns a "node". It receives the original QCFGRule,
        not the converted CFGRule.
      postprocess_cell_fn: Function from a list of "nodes" to "nodes".
      max_single_nt_applications: The maximum number of times a rule where the
        RHS is a single nonterminal symbol can be applied consecutively.
      verbose: Print debug output if True.

    Returns:
      A list of "node" objects for completed parses. The list is empty if a
      terminal token appears in no converted rule.
    """
    if verbose:
        print("tokens: %s" % (tokens, ))
        print("rules:")
        for rule in rules:
            print(str(rule))

    # Our QCFG grammars always use a single NT symbol, so every input
    # nonterminal maps to this one parser id.
    nt_idx = 0

    # Convert the allowed rules to CFGRule format, keeping a map back to the
    # original rule. Each rule's index is its position in `parser_rules`.
    converter = cfg_converter.CFGRuleConverter()
    allowed_terminals = set(tokens)
    parser_rules = []
    idx_to_rule = {}
    for rule in rules:
        if qcfg_rule.is_allowed(rule.source, allowed_terminals):
            position = len(parser_rules)
            parser_rules.append(
                converter.convert_to_cfg_rule(
                    lhs=NT_IDX,
                    rhs=_convert_nt(rule.source),
                    rule_idx=position,
                    nonterminal_prefix=NON_TERMINAL_PREFIX))
            idx_to_rule[position] = rule

    # Map tokens to parser symbols. Stop at the first terminal that no rule
    # uses, because nothing could match it.
    input_symbols = []
    for token in tokens:
        if qcfg_rule.is_nt(token):
            input_symbols.append(
                cfg_rule.CFGSymbol(nt_idx, cfg_rule.NON_TERMINAL))
        elif token in converter.terminals_to_ids:
            input_symbols.append(
                cfg_rule.CFGSymbol(converter.terminals_to_ids[token],
                                   cfg_rule.TERMINAL))
        else:
            if verbose:
                print("Input token does not appear in rules: %s" % token)
            return []

    def populate_fn(span_begin, span_end, parser_rule, children):
        # Hand node_fn the original QCFG rule instead of the CFGRule.
        return [
            node_fn(span_begin, span_end, idx_to_rule[parser_rule.idx],
                    children)
        ]

    if verbose:
        print("parser_rules: %s" % parser_rules)

    return cfg_parser.parse_symbols(
        input_symbols,
        parser_rules,
        {nt_idx}, {nt_idx},
        populate_fn,
        postprocess_cell_fn,
        max_single_nt_applications=max_single_nt_applications,
        verbose=verbose)
Ejemplo n.º 13
0
def parse(tokens, rules, node_fn, postprocess_fn, verbose=False):
    """Run bottom up parser on QCFG target using target CFG.

    For every target-CFG nonterminal and every QCFG nonterminal in `tokens`,
    a bridging rule `target_nt -> qcfg_nt` is added. This lets a QCFG
    nonterminal in the input stand in for any target nonterminal.

    Args:
      tokens: List of strings for input.
      rules: List of TargetCfgRule instances.
      node_fn: Function with input arguments (span_begin, span_end, rule,
        children) that returns a list of "node". It receives the original
        TargetCfgRule, not the converted CFGRule.
      postprocess_fn: Function from a list of "nodes" to "nodes".
      verbose: Print debug output if True.

    Returns:
      A list of "node" objects for completed parses. The list is empty if an
      input token is unknown to the converted grammar.
    """
    if verbose:
        print("tokens: %s" % (tokens, ))
        print("rules:")
        for rule in rules:
            print(str(rule))
    # Terminal tokens of the input. Rules using any other terminal can never
    # match, so conversion drops them. The set is built once rather than
    # once per rule.
    allowed_terminals = {
        token for token in tokens
        if not token.startswith(qcfg_rule.NON_TERMINAL_PREFIX)
    }

    # Convert rules. A rule's index is its position in `parser_rules`.
    converter = cfg_converter.CFGRuleConverter()
    parser_rules = []
    idx_to_rule = {}
    for rule in rules:
        rule_idx = len(parser_rules)
        parser_rule = converter.convert_to_cfg_rule(
            lhs=rule.lhs,
            rhs=rule.rhs.split(" "),
            rule_idx=rule_idx,
            nonterminal_prefix=target_grammar.NON_TERMINAL_PREFIX,
            allowed_terminals=allowed_terminals)
        if parser_rule:
            parser_rules.append(parser_rule)
            idx_to_rule[rule_idx] = rule

    # Add rules for every target nonterminal and QCFG nonterminal.
    # Snapshot the target nonterminals first, because converting the bridging
    # rules registers the QCFG nonterminals as well. Insertion order (rather
    # than set order) keeps rule indices, and therefore the order of the
    # returned parses, reproducible across interpreter runs.
    target_nts = list(converter.nonterminals_to_ids.keys())
    qcfg_nts = list(dict.fromkeys(qcfg_rule.get_nts(tokens)))
    for target_nt in target_nts:
        for qcfg_nt in qcfg_nts:
            rule = target_grammar.TargetCfgRule(target_nt,
                                                _convert_qcfg_nt(qcfg_nt))
            rule_idx = len(parser_rules)
            parser_rule = converter.convert_to_cfg_rule(
                lhs=rule.lhs,
                rhs=rule.rhs.split(" "),
                rule_idx=rule_idx,
                nonterminal_prefix=target_grammar.NON_TERMINAL_PREFIX)
            parser_rules.append(parser_rule)
            idx_to_rule[rule_idx] = rule

    input_symbols = []
    for token in tokens:
        if qcfg_rule.is_nt(token):
            if token not in converter.nonterminals_to_ids:
                return []
            idx = converter.nonterminals_to_ids[token]
            input_symbols.append(cfg_rule.CFGSymbol(idx,
                                                    cfg_rule.NON_TERMINAL))
        else:
            if token not in converter.terminals_to_ids:
                return []
            idx = converter.terminals_to_ids[token]
            input_symbols.append(cfg_rule.CFGSymbol(idx, cfg_rule.TERMINAL))

    # Wrap node_fn to pass original Rule instead of CFGRule.
    def populate_fn(span_begin, span_end, parser_rule, children):
        rule = idx_to_rule[parser_rule.idx]
        return node_fn(span_begin, span_end, rule, children)

    nonterminals = set(converter.nonterminals_to_ids.values())
    if verbose:
        print("parser_rules: %s" % parser_rules)

    # Any nonterminal may root a parse. Single-NT rule chaining is disabled.
    parses = cfg_parser.parse_symbols(input_symbols,
                                      parser_rules,
                                      nonterminals,
                                      nonterminals,
                                      populate_fn,
                                      postprocess_fn,
                                      max_single_nt_applications=0,
                                      verbose=verbose)
    return parses
Ejemplo n.º 14
0
    def convert(self, induced_rule, verbose=False):
        """Convert a QCFGRule to a JointRule.

        The rule's target tokens are parsed with the target CFG. QCFG
        nonterminals are replaced, in order of appearance, by the 1-indexed
        NT_PLACEHOLDER nonterminals.

        Args:
          induced_rule: QCFG rule whose `target` tokens are parsed.
          verbose: Print debug output if True.

        Returns:
          A JointRule pairing `induced_rule` with the frozenset of non-empty
          CFG-nonterminal tuples extracted from the parses. Returns None (and
          prints a message) if the target cannot be parsed.

        Raises:
          ValueError: If a terminal token is missing from
            `converter.terminals_to_ids`.
        """
        tokens = induced_rule.target
        input_symbols = []
        terminal_ids = set()
        qcfg_idxs = []
        num_nts = 0
        for token in tokens:
            if qcfg_rule.is_nt(token):
                qcfg_idxs.append(qcfg_rule.get_nt_index(token))
                # NT placeholders are 1-indexed.
                qcfg_nt = NT_PLACEHOLDER % (num_nts + 1)
                num_nts += 1
                idx = self.converter.nonterminals_to_ids[qcfg_nt]
                input_symbols.append(
                    cfg_rule.CFGSymbol(idx, cfg_rule.NON_TERMINAL))
                continue
            if token not in self.converter.terminals_to_ids:
                raise ValueError(
                    "token `%s` not in `converter.terminals_to_ids`: %s" %
                    (token, self.converter.terminals_to_ids))
            idx = self.converter.terminals_to_ids[token]
            terminal_ids.add(idx)
            input_symbols.append(cfg_rule.CFGSymbol(idx, cfg_rule.TERMINAL))

        # A rule whose RHS uses a terminal absent from the input can never
        # match, so drop it before parsing.
        def should_include(parser_rule):
            return all(symbol.type != cfg_rule.TERMINAL
                       or symbol.idx in terminal_ids
                       for symbol in parser_rule.rhs)

        filtered_rules = [
            rule for rule in self.parser_rules if should_include(rule)
        ]
        if verbose:
            print("filtered_rules:")
            for rule in filtered_rules:
                print(rule)

        def populate_fn(unused_span_begin, unused_span_end, parser_rule,
                        children):
            return [ParseNode(parser_rule, children)]

        nonterminals = set(self.converter.nonterminals_to_ids.values())
        parses = cfg_parser.parse_symbols(
            input_symbols,
            filtered_rules,
            nonterminals,
            nonterminals,
            populate_fn,
            postprocess_fn=None,
            max_single_nt_applications=self.max_single_nt_applications,
            verbose=verbose)
        if not parses:
            print("Could not parse: %s" % (tokens, ))
            return None

        # Extract cfg_nts from parses, reordered by the QCFG nonterminal
        # indices collected above (via _rearrange_nts). Empty results are
        # skipped.
        cfg_nts_set = set()
        for parse_node in parses:
            cfg_nts = _get_cfg_nts(self.converter.nonterminals_to_ids,
                                   self.rhs_nt_rules, parse_node, num_nts)
            cfg_nts = _rearrange_nts(cfg_nts, qcfg_idxs)
            if cfg_nts:
                cfg_nts_set.add(cfg_nts)

        return JointRule(induced_rule, frozenset(cfg_nts_set))