def test_parse_3(self):
    """Parse `FOO FOO NT` using only the right-recursive rule `NT -> FOO NT`."""

    def terminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.TERMINAL)

    def nonterminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.NON_TERMINAL)

    # NT -> FOO NT
    recursive_rule = cfg_rule.CFGRule(
        idx=0, lhs=NT, rhs=(terminal(FOO), nonterminal(NT)))

    # The input already ends in a non-terminal, so no terminal-only rule
    # is needed to complete the derivation.
    symbols = [terminal(FOO), terminal(FOO), nonterminal(NT)]

    parses = cfg_parser.parse_symbols(
        symbols,
        [recursive_rule],
        {NT},
        {NT},
        _populate_fn,
        _postprocess_fn,
        verbose=True)

    self.assertLen(parses, 1)
    self.assertEqual(parses[0], [(0, 3, 0), (1, 3, 0)])
def test_parse_1(self):
    """Parse the terminal ids `FOO FOO BAR` with a two-rule grammar."""

    def terminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.TERMINAL)

    def nonterminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.NON_TERMINAL)

    # NT -> BAR
    base_rule = cfg_rule.CFGRule(idx=0, lhs=NT, rhs=(terminal(BAR),))
    # NT -> FOO NT
    recursive_rule = cfg_rule.CFGRule(
        idx=1, lhs=NT, rhs=(terminal(FOO), nonterminal(NT)))

    parses = cfg_parser.parse(
        [FOO, FOO, BAR],
        [base_rule, recursive_rule],
        {NT},
        {NT},
        _populate_fn,
        _postprocess_fn,
        verbose=True)

    self.assertLen(parses, 1)
    self.assertEqual(parses[0], [(0, 3, 1), (1, 3, 1), (2, 3, 0)])
def can_parse(self, tokens, verbose=False):
    """Report whether `tokens` has at least one parse under the target CFG.

    Args:
      tokens: Sequence of string tokens. Tokens for which
        `qcfg_rule.is_nt` is true are all mapped to the `PLACEHOLDER_NT`
        non-terminal; every other token is treated as a terminal.
      verbose: Print debug output if True.

    Returns:
      True if the parser returns at least one parse, otherwise False.
      False is also returned when a terminal token is unknown to
      `self.converter.terminals_to_ids`.
    """
    symbols = []
    seen_terminal_ids = set()
    for token in tokens:
        if qcfg_rule.is_nt(token):
            nt_id = self.converter.nonterminals_to_ids[PLACEHOLDER_NT]
            symbols.append(cfg_rule.CFGSymbol(nt_id, cfg_rule.NON_TERMINAL))
            continue

        if token not in self.converter.terminals_to_ids:
            if verbose:
                print(
                    "token `%s` not in `converter.terminals_to_ids`: %s" %
                    (token, self.converter.terminals_to_ids))
            return False

        terminal_id = self.converter.terminals_to_ids[token]
        seen_terminal_ids.add(terminal_id)
        symbols.append(cfg_rule.CFGSymbol(terminal_id, cfg_rule.TERMINAL))

    # A rule mentioning a terminal that is absent from the input can never
    # match, so such rules are dropped before parsing.
    def uses_only_seen_terminals(parser_rule):
        return all(
            symbol.type != cfg_rule.TERMINAL
            or symbol.idx in seen_terminal_ids
            for symbol in parser_rule.rhs)

    filtered_rules = list(filter(uses_only_seen_terminals, self.parser_rules))

    if verbose:
        print("filtered_rules:")
        for rule in filtered_rules:
            print(rule)

    # Only the existence of a parse matters, so each chart entry is a
    # constant marker.
    def populate_fn(unused_span_begin, unused_span_end, unused_parser_rule,
                    unused_children):
        return [True]

    nonterminals = set(self.converter.nonterminals_to_ids.values())
    parses = cfg_parser.parse_symbols(
        symbols,
        filtered_rules,
        nonterminals,
        nonterminals,
        populate_fn,
        postprocess_fn=None,
        max_single_nt_applications=2,
        verbose=verbose)
    return bool(parses)
def can_parse(target_string,
              rules,
              max_single_nt_applications=2,
              verbose=False):
    """Returns True if there exists >=1 parse of target_string given rules.

    Args:
      target_string: Space-separated string of tokens. Tokens starting with
        `NON_TERMINAL_PREFIX` are treated as non-terminals, all others as
        terminals.
      rules: Iterable of rules exposing `lhs` and a space-separated `rhs`.
      max_single_nt_applications: Forwarded to `cfg_parser.parse_symbols`.
      verbose: Print debug logging if True.

    Returns:
      True if at least one parse rooted at `ROOT_SYMBOL` exists. False if no
      parse exists, or if the root symbol or any input token is unknown to
      the converted grammar.
    """
    tokens = target_string.split(" ")
    # Built once; every rule is filtered against the same terminal set.
    allowed_terminals = set(tokens)

    # Convert rules.
    converter = cfg_converter.CFGRuleConverter()
    parser_rules = []
    for rule_idx, rule in enumerate(rules):
        parser_rule = converter.convert_to_cfg_rule(
            lhs=rule.lhs,
            rhs=rule.rhs.split(" "),
            rule_idx=rule_idx,
            nonterminal_prefix=NON_TERMINAL_PREFIX,
            allowed_terminals=allowed_terminals)
        if parser_rule:
            parser_rules.append(parser_rule)

    # Without the root symbol in the grammar nothing can be derived, so
    # answer False instead of raising KeyError.
    if ROOT_SYMBOL not in converter.nonterminals_to_ids:
        return False
    start_idx = converter.nonterminals_to_ids[ROOT_SYMBOL]
    nonterminals = converter.nonterminals_to_ids.values()

    input_symbols = []
    for token in tokens:
        if token.startswith(NON_TERMINAL_PREFIX):
            nt_name = token[len(NON_TERMINAL_PREFIX):]
            # Unknown non-terminals are handled like unknown terminals.
            if nt_name not in converter.nonterminals_to_ids:
                return False
            idx = converter.nonterminals_to_ids[nt_name]
            input_symbols.append(
                cfg_rule.CFGSymbol(idx, cfg_rule.NON_TERMINAL))
        else:
            if token not in converter.terminals_to_ids:
                return False
            idx = converter.terminals_to_ids[token]
            input_symbols.append(cfg_rule.CFGSymbol(idx, cfg_rule.TERMINAL))

    # Run parser.
    parses = cfg_parser.parse_symbols(
        input_symbols,
        parser_rules,
        nonterminals,
        {start_idx},
        _populate_fn,
        _postprocess_fn,
        verbose=verbose,
        max_single_nt_applications=max_single_nt_applications)
    return bool(parses)
def sample(parser_rules,
           start_idx,
           rule_values=None,
           max_recursion=1,
           nonterminal_coef=1,
           verbose=False):
    """Sample data from CFG.

    Args:
      parser_rules: A list of CFGRule instances.
      start_idx: Index of non-terminal that is start symbol.
      rule_values: An optional list of rules with 1-1 mapping to parser
        rules. The rule values can be target grammar, QCFG rules, strings,
        etc. Only used when verbose is True for debugging purpose.
      max_recursion: The recursion depth from which rules without
        non-terminals on the RHS are preferred. The depth can still be
        exceeded when a non-terminal has no such rule.
      nonterminal_coef: The scaling coefficient for rules with nonterminals;
        forwarded to `sample_rule`.
      verbose: Print debug logging if True.

    Returns:
      A nested list of CFGRule instances: the sampled rule followed by one
      nested list per non-terminal on its RHS, in RHS order.
    """
    nonterminals_to_rules = collections.defaultdict(list)
    for rule in parser_rules:
        nonterminals_to_rules[rule.lhs].append(rule)

    def expand_nonterminal(nonterminal, recursion=0):
        rules_to_sample = nonterminals_to_rules[nonterminal.idx]
        # Use `>=` rather than `==`: once the limit has been overshot
        # (because some NT had no terminating rule), deeper levels must
        # still prefer terminating rules, otherwise recursion can keep
        # growing.
        if recursion >= max_recursion:
            # Filter out the rules that have NTs on RHS.
            rules_to_sample_no_nts = [
                rule for rule in rules_to_sample
                if cfg_rule.get_arity(rule) == 0
            ]
            # If there are no rules for this NT that contain no NTs, then
            # keep recursing. In this case, we may exceed `max_recursion`.
            if rules_to_sample_no_nts:
                rules_to_sample = rules_to_sample_no_nts

        sampled_rule = sample_rule(
            rules_to_sample, nonterminal_coef=nonterminal_coef)
        if verbose and rule_values is not None:
            print("Recursion %d, Sampled rule: %s" %
                  (recursion, rule_values[sampled_rule.idx]))

        output = [sampled_rule]
        for symbol in sampled_rule.rhs:
            if symbol.type == cfg_rule.NON_TERMINAL:
                output.append(
                    expand_nonterminal(symbol, recursion=recursion + 1))
        return output

    start_symbol = cfg_rule.CFGSymbol(
        idx=start_idx, type=cfg_rule.NON_TERMINAL)
    return expand_nonterminal(start_symbol, recursion=0)
def test_sample(self, mock_random):
    """Sample from a two-rule grammar with the random source patched."""
    mock_random.return_value = [0]

    def terminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.TERMINAL)

    def nonterminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.NON_TERMINAL)

    # NT -> FOO NT
    recursive_rule = cfg_rule.CFGRule(
        idx=0, lhs=NT, rhs=(terminal(FOO), nonterminal(NT)))
    # NT -> BAR
    base_rule = cfg_rule.CFGRule(idx=1, lhs=NT, rhs=(terminal(BAR),))

    grammar = [recursive_rule, base_rule]
    sampled = cfg_sampler.sample(
        grammar, NT, rule_values=grammar, verbose=True)

    # One application of the recursive rule, then the terminating rule.
    self.assertEqual(sampled, [recursive_rule, [base_rule]])
def test_parse_6(self):
    """Parse `NT BAR` with two mutually recursive rules over NT and NT_2."""

    def terminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.TERMINAL)

    def nonterminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.NON_TERMINAL)

    # NT -> NT_2 BAR
    nt_rule = cfg_rule.CFGRule(
        idx=0, lhs=NT, rhs=(nonterminal(NT_2), terminal(BAR)))
    # NT_2 -> NT BAR
    nt_2_rule = cfg_rule.CFGRule(
        idx=1, lhs=NT_2, rhs=(nonterminal(NT), terminal(BAR)))

    symbols = [nonterminal(NT), terminal(BAR)]
    all_nonterminals = {NT, NT_2}

    parses = cfg_parser.parse_symbols(
        symbols,
        [nt_rule, nt_2_rule],
        all_nonterminals,
        {NT, NT_2},
        _populate_fn,
        _postprocess_fn,
        verbose=True)

    self.assertLen(parses, 1)
    self.assertEqual(parses[0], [(0, 2, 1)])
def test_parse_4(self):
    """Parse `NT FOO BAR`, which the grammar below derives in two ways."""

    def terminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.TERMINAL)

    def nonterminal(symbol_id):
        return cfg_rule.CFGSymbol(symbol_id, cfg_rule.NON_TERMINAL)

    # NT -> BAR
    bar_rule = cfg_rule.CFGRule(idx=0, lhs=NT, rhs=(terminal(BAR),))
    # NT -> NT FOO NT
    binary_rule = cfg_rule.CFGRule(
        idx=1,
        lhs=NT,
        rhs=(nonterminal(NT), terminal(FOO), nonterminal(NT)))
    # NT -> NT FOO BAR
    suffix_rule = cfg_rule.CFGRule(
        idx=2,
        lhs=NT,
        rhs=(nonterminal(NT), terminal(FOO), terminal(BAR)))

    symbols = [nonterminal(NT), terminal(FOO), terminal(BAR)]

    parses = cfg_parser.parse_symbols(
        symbols,
        [bar_rule, binary_rule, suffix_rule],
        {NT},
        {NT},
        _populate_fn,
        _postprocess_fn,
        verbose=True)

    self.assertLen(parses, 2)
    self.assertEqual(parses, [[(0, 3, 2)], [(0, 3, 1), (2, 3, 0)]])
def convert_to_cfg_rule(self,
                        lhs,
                        rhs,
                        rule_idx,
                        nonterminal_prefix,
                        allowed_terminals=None):
    """Convert symbol strings to CFGRule.

    Args:
      lhs: String symbol for LHS.
      rhs: List of string symbols for RHS.
      rule_idx: Integer index for rule.
      nonterminal_prefix: String prefix for nonterminal symbols in `rhs`.
      allowed_terminals: Optional collection of terminal strings. When it is
        not None (an empty collection counts), the rule is rejected if `rhs`
        contains a terminal outside it.

    Returns:
      A CFGRule, or None if the rule was rejected because of
      `allowed_terminals`. The LHS, and any RHS non-terminals that precede
      the rejected terminal, have already been passed to
      `_get_nonterminal_id` by then.
    """
    rhs_symbols = []
    lhs_idx = self._get_nonterminal_id(lhs)
    for token in rhs:
        if token.startswith(nonterminal_prefix):
            symbol_idx = self._get_nonterminal_id(
                token[len(nonterminal_prefix):])
            rhs_symbols.append(
                cfg_rule.CFGSymbol(idx=symbol_idx, type=cfg_rule.NON_TERMINAL))
        else:
            # Compare against None explicitly: an empty collection is falsy
            # and would otherwise switch the filter off instead of
            # rejecting every terminal.
            if (allowed_terminals is not None
                    and token not in allowed_terminals):
                return None
            symbol_idx = self._get_terminal_id(token)
            rhs_symbols.append(
                cfg_rule.CFGSymbol(idx=symbol_idx, type=cfg_rule.TERMINAL))
    return cfg_rule.CFGRule(idx=rule_idx, lhs=lhs_idx, rhs=rhs_symbols)
def parse(input_ids,
          rules,
          nonterminals,
          start_idx_set,
          populate_fn,
          postprocess_fn,
          max_single_nt_applications=1,
          verbose=False):
    """Run the bottom-up parser on an input made only of terminal ids.

    Each id in `input_ids` is wrapped in a terminal CFGSymbol; everything
    else is forwarded unchanged to `parse_symbols`, whose result is returned.
    """
    terminal_symbols = tuple(
        cfg_rule.CFGSymbol(token_id, cfg_rule.TERMINAL)
        for token_id in input_ids)
    return parse_symbols(
        terminal_symbols,
        rules,
        nonterminals,
        start_idx_set,
        populate_fn,
        postprocess_fn,
        max_single_nt_applications=max_single_nt_applications,
        verbose=verbose)
def parse_symbols(input_symbols,
                  rules,
                  nonterminals,
                  start_idx_set,
                  populate_fn,
                  postprocess_fn,
                  max_single_nt_applications=1,
                  verbose=False):
    """Run bottom up parser.

    Let T be the type of a chart entry, as determined by what `populate_fn`
    returns. T may be a simple marker that only records that a span can be
    parsed, or a richer structure representing a parse forest.

    Args:
      input_symbols: Sequence of CFGSymbols to parse (terminals and/or
        non-terminals).
      rules: A list of CFGRule instances.
      nonterminals: Collection of non-terminal indices; each one is wrapped
        in a non-terminal CFGSymbol internally.
      start_idx_set: A set of non-terminal indices accepted as start symbol.
      populate_fn: A function that takes (span_begin (Integer),
        span_end (Integer), parser_rule (CFGRule), substitutions (List of T))
        and returns a list of objects of type T. These objects are added to
        the chart.
      postprocess_fn: A function that takes and returns a list of T. It
        post-processes each cell after it has been populated, e.g. to prune
        the chart or merge equivalent entries. Ignored if None.
      max_single_nt_applications: The maximum number of times a rule where
        the RHS is a single nonterminal symbol can be applied consecutively.
      verbose: Print debug logging if True.

    Returns:
      A list of T: the chart entries covering the whole input for every
      index in `start_idx_set`.
    """
    num_symbols = len(input_symbols)

    # Empty chart plus a trie over the rule right-hand sides.
    chart = Chart(populate_fn, postprocess_fn)
    trie_root = TrieNode()
    max_num_nts = 0
    for rule in rules:
        add_rule_to_trie(trie_root, rule)
        max_num_nts = max(max_num_nts, cfg_rule.get_num_nts(rule.rhs))

    # Fill the chart: spans are visited by increasing end, and for a fixed
    # end by decreasing begin.
    for span_end in range(1, num_symbols + 1):
        for span_begin in reversed(range(span_end)):
            # Position in the input -> partial rule matches reaching it.
            pending = collections.defaultdict(list)
            pending[span_begin].append(SearchState([], trie_root))

            # Advance the partial matches across the span.
            for idx in range(span_begin, span_end):
                # Nothing reaches this position; move on to the next one.
                if not pending[idx]:
                    continue
                current_symbol = input_symbols[idx]

                while pending[idx]:
                    state = pending[idx].pop()

                    # Option 1: consume the input symbol itself, which
                    # advances the match by exactly one position.
                    symbol_node = state.trie_node.maybe_get_child(
                        current_symbol)
                    if symbol_node:
                        pending[idx + 1].append(
                            SearchState(state.anchored_nonterminals,
                                        symbol_node))

                    # Option 2: consume an already completed sub-tree that
                    # starts here, unless the match has already anchored as
                    # many non-terminals as any rule contains.
                    completed_subtrees = chart.get_from_start(idx)
                    if len(state.anchored_nonterminals) >= max_num_nts:
                        continue
                    for nt_end, nonterminal in completed_subtrees:
                        nt_node = state.trie_node.maybe_get_child(
                            cfg_rule.CFGSymbol(nonterminal,
                                               cfg_rule.NON_TERMINAL))
                        if nt_node:
                            # The match resumes where the sub-tree ends.
                            anchored = state.anchored_nonterminals + [
                                (idx, nt_end, nonterminal)
                            ]
                            pending[nt_end].append(
                                SearchState(anchored, nt_node))

            # Matches that reached span_end and sit on a trie node holding
            # rules are complete; turn them into chart entries.
            for state in pending[span_end]:
                completed_rules = state.trie_node.values
                if not completed_rules:
                    continue
                for rule in completed_rules:
                    # One list of candidate children per anchored NT; every
                    # combination yields a chart entry.
                    children_list = [
                        chart.get_from_key(*anchored_nt)
                        for anchored_nt in state.anchored_nonterminals
                    ]
                    for children in itertools.product(*children_list):
                        chart.add(span_begin, span_end, rule, children)

            for nt in nonterminals:
                chart.postprocess(span_begin, span_end, nt)

            # Optionally apply rules whose RHS is a single NT.
            for _ in range(max_single_nt_applications):
                for nt in nonterminals:
                    # Snapshot the cell: chart.add below may modify it.
                    cell = chart.get_from_key(span_begin, span_end, nt).copy()
                    child = trie_root.maybe_get_child(
                        cfg_rule.CFGSymbol(nt, cfg_rule.NON_TERMINAL))
                    if child:
                        for rule in child.values:
                            for node in cell:
                                chart.add(span_begin, span_end, rule, [node])
                    chart.postprocess(span_begin, span_end, nt)

            if verbose:
                for nt in nonterminals:
                    cell = chart.get_from_key(span_begin, span_end, nt)
                    if cell:
                        print("Populated (%s,%s): %s - %s" %
                              (span_begin, span_end, nt, cell))

    # Collect the entries spanning the whole input for each start symbol.
    return [
        entry
        for start_idx in start_idx_set
        for entry in chart.get_from_key(0, num_symbols, start_idx)
    ]
def parse(tokens,
          rules,
          node_fn,
          postprocess_cell_fn,
          max_single_nt_applications=1,
          verbose=False):
    """Run bottom up parser.

    Args:
      tokens: List of strings for input (terminals or nonterminals).
      rules: List of QCFGRule instances.
      node_fn: Function called as node_fn(span_begin, span_end, rule,
        children) that returns a "node".
      postprocess_cell_fn: Function from a list of "nodes" to "nodes".
      max_single_nt_applications: The maximum number of times a rule where
        the RHS is a single nonterminal symbol can be applied consecutively.
      verbose: Print debug output if True.

    Returns:
      A List of "node" objects for completed parses. The list is empty when
      a terminal in `tokens` does not occur in any retained rule.
    """
    if verbose:
        print("tokens: %s" % (tokens, ))
        print("rules:")
        for rule in rules:
            print(str(rule))

    # Our QCFG grammars always use a single NT symbol.
    nt_idx = 0

    # Convert to ParserRule format, keeping only rules whose source side is
    # allowed for this input.
    converter = cfg_converter.CFGRuleConverter()
    allowed_terminals = set(tokens)
    parser_rules = []
    idx_to_rule = {}
    for rule in rules:
        if not qcfg_rule.is_allowed(rule.source, allowed_terminals):
            continue
        # Retained rules are numbered consecutively from 0.
        rule_idx = len(parser_rules)
        parser_rules.append(
            converter.convert_to_cfg_rule(
                lhs=NT_IDX,
                rhs=_convert_nt(rule.source),
                rule_idx=rule_idx,
                nonterminal_prefix=NON_TERMINAL_PREFIX))
        idx_to_rule[rule_idx] = rule

    # Translate the input; give up as soon as a terminal is unknown.
    input_symbols = []
    for token in tokens:
        if qcfg_rule.is_nt(token):
            input_symbols.append(
                cfg_rule.CFGSymbol(nt_idx, cfg_rule.NON_TERMINAL))
        elif token in converter.terminals_to_ids:
            input_symbols.append(
                cfg_rule.CFGSymbol(converter.terminals_to_ids[token],
                                   cfg_rule.TERMINAL))
        else:
            if verbose:
                print("Input token does not appear in rules: %s" % token)
            return []

    # Wrap node_fn to pass original Rule instead of CFGRule.
    def populate_fn(span_begin, span_end, parser_rule, children):
        original_rule = idx_to_rule[parser_rule.idx]
        return [node_fn(span_begin, span_end, original_rule, children)]

    if verbose:
        print("parser_rules: %s" % parser_rules)

    return cfg_parser.parse_symbols(
        input_symbols,
        parser_rules,
        {nt_idx},
        {nt_idx},
        populate_fn,
        postprocess_cell_fn,
        max_single_nt_applications=max_single_nt_applications,
        verbose=verbose)
def parse(tokens, rules, node_fn, postprocess_fn, verbose=False):
    """Run bottom up parser on QCFG target using target CFG.

    Args:
      tokens: List of strings for input.
      rules: List of TargetCfgRule instances.
      node_fn: Function called as node_fn(span_begin, span_end, rule,
        children) that returns a list of "node".
      postprocess_fn: Function from a list of "nodes" to "nodes".
      verbose: Print debug output if True.

    Returns:
      A List of "node" objects for completed parses. The list is empty when
      an input token is unknown to the converted grammar.
    """
    if verbose:
        print("tokens: %s" % (tokens, ))
        print("rules:")
        for rule in rules:
            print(str(rule))

    # Terminals of the input; rules using any other terminal are dropped.
    allowed_terminals = {
        token for token in tokens
        if not token.startswith(qcfg_rule.NON_TERMINAL_PREFIX)
    }

    # Convert rules. Retained rules are numbered consecutively, so the next
    # free index always equals len(parser_rules).
    converter = cfg_converter.CFGRuleConverter()
    parser_rules = []
    idx_to_rule = {}
    for rule in rules:
        next_idx = len(parser_rules)
        parser_rule = converter.convert_to_cfg_rule(
            lhs=rule.lhs,
            rhs=rule.rhs.split(" "),
            rule_idx=next_idx,
            nonterminal_prefix=target_grammar.NON_TERMINAL_PREFIX,
            allowed_terminals=allowed_terminals)
        if parser_rule:
            parser_rules.append(parser_rule)
            idx_to_rule[next_idx] = rule

    # Add rules for every target nonterminal and QCFG nonterminal. The
    # target NTs are snapshotted first because converting the bridging
    # rules may register further non-terminals with the converter.
    target_nts = set(converter.nonterminals_to_ids.keys())
    qcfg_nts = set(qcfg_rule.get_nts(tokens))
    for target_nt in target_nts:
        for qcfg_nt in qcfg_nts:
            bridge_rule = target_grammar.TargetCfgRule(
                target_nt, _convert_qcfg_nt(qcfg_nt))
            next_idx = len(parser_rules)
            parser_rules.append(
                converter.convert_to_cfg_rule(
                    lhs=bridge_rule.lhs,
                    rhs=bridge_rule.rhs.split(" "),
                    rule_idx=next_idx,
                    nonterminal_prefix=target_grammar.NON_TERMINAL_PREFIX))
            idx_to_rule[next_idx] = bridge_rule

    # Translate the input, choosing the id table by token kind.
    input_symbols = []
    for token in tokens:
        if qcfg_rule.is_nt(token):
            symbol_ids = converter.nonterminals_to_ids
            symbol_type = cfg_rule.NON_TERMINAL
        else:
            symbol_ids = converter.terminals_to_ids
            symbol_type = cfg_rule.TERMINAL
        if token not in symbol_ids:
            return []
        input_symbols.append(
            cfg_rule.CFGSymbol(symbol_ids[token], symbol_type))

    # Wrap node_fn to pass original Rule instead of CFGRule.
    def populate_fn(span_begin, span_end, parser_rule, children):
        original_rule = idx_to_rule[parser_rule.idx]
        return node_fn(span_begin, span_end, original_rule, children)

    nonterminals = set(converter.nonterminals_to_ids.values())
    if verbose:
        print("parser_rules: %s" % parser_rules)

    return cfg_parser.parse_symbols(
        input_symbols,
        parser_rules,
        nonterminals,
        nonterminals,
        populate_fn,
        postprocess_fn,
        max_single_nt_applications=0,
        verbose=verbose)
def convert(self, induced_rule, verbose=False):
    """Convert QCFGRule to JointRule.

    The target side of `induced_rule` is parsed with the subset of
    `self.parser_rules` whose terminals all occur in the target. The
    non-terminal assignments extracted from each parse are collected into
    the resulting JointRule.

    Args:
      induced_rule: Rule whose `target` attribute is the token sequence to
        parse.
      verbose: Print debug output if True.

    Returns:
      A JointRule wrapping `induced_rule` and a frozenset of the extracted
      non-terminal assignments, or None if the target cannot be parsed.

    Raises:
      ValueError: If a terminal token is not in
        `self.converter.terminals_to_ids`.
    """
    tokens = induced_rule.target
    input_symbols = []
    terminal_ids = set()
    qcfg_idxs = []
    num_nts = 0
    for token in tokens:
        if qcfg_rule.is_nt(token):
            qcfg_idxs.append(qcfg_rule.get_nt_index(token))
            # NT placeholders are 1-indexed and numbered by position of
            # occurrence in the target.
            num_nts += 1
            qcfg_nt = NT_PLACEHOLDER % num_nts
            idx = self.converter.nonterminals_to_ids[qcfg_nt]
            input_symbols.append(
                cfg_rule.CFGSymbol(idx, cfg_rule.NON_TERMINAL))
        else:
            if token not in self.converter.terminals_to_ids:
                raise ValueError(
                    "token `%s` not in `converter.terminals_to_ids`: %s" %
                    (token, self.converter.terminals_to_ids))
            idx = self.converter.terminals_to_ids[token]
            terminal_ids.add(idx)
            input_symbols.append(cfg_rule.CFGSymbol(idx, cfg_rule.TERMINAL))

    # Filter rules that contain terminals not in the input.
    def should_include(parser_rule):
        return all(
            symbol.type != cfg_rule.TERMINAL or symbol.idx in terminal_ids
            for symbol in parser_rule.rhs)

    filtered_rules = [
        rule for rule in self.parser_rules if should_include(rule)
    ]
    if verbose:
        print("filtered_rules:")
        for rule in filtered_rules:
            print(rule)

    def populate_fn(unused_span_begin, unused_span_end, parser_rule,
                    children):
        return [ParseNode(parser_rule, children)]

    nonterminals = set(self.converter.nonterminals_to_ids.values())
    parses = cfg_parser.parse_symbols(
        input_symbols,
        filtered_rules,
        nonterminals,
        nonterminals,
        populate_fn,
        postprocess_fn=None,
        max_single_nt_applications=self.max_single_nt_applications,
        verbose=verbose)
    if not parses:
        print("Could not parse: %s" % (tokens, ))
        return None

    # Extract cfg_nts from parses; falsy results are skipped.
    cfg_nts_set = set()
    for parse_node in parses:
        cfg_nts = _get_cfg_nts(self.converter.nonterminals_to_ids,
                               self.rhs_nt_rules, parse_node, num_nts)
        cfg_nts = _rearrange_nts(cfg_nts, qcfg_idxs)
        if cfg_nts:
            cfg_nts_set.add(cfg_nts)

    return JointRule(induced_rule, frozenset(cfg_nts_set))