Ejemplo n.º 1
0
def _swap_nt_order(rhs, old_to_new):
    """Return `rhs` as a tuple with its non-terminal indices remapped.

    Args:
        rhs: Iterable of rule symbols.
        old_to_new: Mapping from a non-terminal's current index to the index
            it should carry in the result.

    Returns:
        A tuple with one entry per input symbol. Symbols for which
        `qcfg_rule.is_nt_fast` is false are passed through unchanged; the
        others are replaced by `str(qcfg_rule.NT(new_index))`.
    """

    def _remap(symbol):
        # Terminals are copied through untouched.
        if not qcfg_rule.is_nt_fast(symbol):
            return symbol
        renamed_idx = old_to_new[qcfg_rule.get_nt_index(symbol)]
        return str(qcfg_rule.NT(renamed_idx))

    return tuple(_remap(symbol) for symbol in rhs)
Ejemplo n.º 2
0
def _get_free_nt_replacement(nts):
  """Return the lowest-numbered NT symbol whose index is not used by `nts`.

  Examples:
    nts = {NT_1, NT_3} returns NT_2.
    nts = {NT_1, NT_2} returns NT_3.

  Args:
    nts: Sized collection of non-terminal symbols.

  Returns:
    `str(qcfg_rule.NT(i))`, where `i` is the smallest index in
    1..len(nts)+1 that no symbol in `nts` carries.
  """
  used_indices = {qcfg_rule.get_nt_index(nt) for nt in nts}
  # There are len(nts) + 1 candidates but at most len(nts) used indices, so
  # at least one candidate is always free.
  free_indices = set(range(1, len(nts) + 2)) - used_indices
  # `min` rather than `set.pop()`: pop returns an arbitrary element, which
  # makes the result unspecified whenever more than one index is free.
  return str(qcfg_rule.NT(min(free_indices)))
Ejemplo n.º 3
0
def _convert_to_qcfg(nested_rule):
  """Flatten a nested rule into QCFG source and target symbol lists.

  Args:
    nested_rule: Sequence whose first element is a rule exposing `source` and
      `target` symbol sequences, and whose remaining elements are nested rules
      of the same form. The k-th nested rule (1-based) supplies the expansion
      for the non-terminal with index k.

  Returns:
    A `(sources, targets)` pair of lists in which every non-terminal of the
    head rule has been replaced by the flattened symbols of its nested rule.
  """
  head = nested_rule[0]

  # Flatten the children first; they are keyed by 1-based position.
  child_sources = {}
  child_targets = {}
  for position, child in enumerate(nested_rule[1:], start=1):
    child_sources[position], child_targets[position] = _convert_to_qcfg(child)

  def _expand(symbols, replacements):
    flattened = []
    for symbol in symbols:
      if not qcfg_rule.is_nt_fast(symbol):
        flattened.append(symbol)
      else:
        flattened.extend(replacements[qcfg_rule.get_nt_index(symbol)])
    return flattened

  sources = _expand(head.source, child_sources)
  targets = _expand(head.target, child_targets)
  return sources, targets
Ejemplo n.º 4
0
def canonicalize_nts(source, target, arity):
    """Renumber non-terminals so their indices follow source order.

    The first distinct non-terminal encountered in `source` gets index 1, the
    second index 2, and so on; the same renumbering is applied to `target`.

    Args:
        source: Sequence of source-side rule symbols.
        target: Sequence of target-side rule symbols.
        arity: Expected number of distinct non-terminals in `source`.

    Returns:
        A `(source, target)` pair as produced by `_swap_nt_order` with the
        renumbered non-terminals.

    Raises:
        ValueError: If the number of distinct non-terminals in `source`
            differs from `arity`.
    """
    # Distinct non-terminals in order of first appearance.
    source_nts = []
    for token in source:
        if qcfg_rule.is_nt(token) and token not in source_nts:
            source_nts.append(token)
    if len(source_nts) != arity:
        # Report the arity actually requested instead of a hard-coded "2".
        raise ValueError(
            "Bad arity %s source (found %d distinct NTs): %s" %
            (arity, len(source_nts), source))
    old_to_new = {
        qcfg_rule.get_nt_index(nt): position
        for position, nt in enumerate(source_nts, start=1)
    }
    source = _swap_nt_order(source, old_to_new)
    target = _swap_nt_order(target, old_to_new)
    return source, target
Ejemplo n.º 5
0
    def convert(self, induced_rule, verbose=False):
        """Build a JointRule from a QCFGRule by parsing its target side.

        The target tokens of `induced_rule` are mapped to CFG symbols and
        parsed with the subset of `self.parser_rules` whose terminals all
        occur in the target. The values extracted from each parse are
        collected into the resulting JointRule.

        Args:
            induced_rule: Rule whose `target` token sequence is parsed.
            verbose: If true, print the filtered parser rules and pass the
                flag on to the parser.

        Returns:
            `JointRule(induced_rule, frozenset(...))` holding every truthy
            value extracted from the parses, or None (after printing a
            message) when the target cannot be parsed.

        Raises:
            ValueError: If a terminal token is missing from
                `self.converter.terminals_to_ids`.
        """
        target_tokens = induced_rule.target
        converter = self.converter

        symbols = []
        seen_terminal_ids = set()
        nt_indices = []
        # NOTE(review): `rhs` is filled in below but never read in this method.
        rhs = []
        num_nts = 0
        for token in target_tokens:
            if not qcfg_rule.is_nt(token):
                if token not in converter.terminals_to_ids:
                    raise ValueError(
                        "token `%s` not in `converter.terminals_to_ids`: %s" %
                        (token, converter.terminals_to_ids))
                rhs.append(token)
                terminal_id = converter.terminals_to_ids[token]
                seen_terminal_ids.add(terminal_id)
                symbols.append(
                    cfg_rule.CFGSymbol(terminal_id, cfg_rule.TERMINAL))
                continue

            nt_indices.append(qcfg_rule.get_nt_index(token))
            # Placeholders are numbered from 1 in order of appearance.
            num_nts += 1
            placeholder = NT_PLACEHOLDER % num_nts
            rhs.append(JOINT_NT)
            nt_id = converter.nonterminals_to_ids[placeholder]
            symbols.append(cfg_rule.CFGSymbol(nt_id, cfg_rule.NON_TERMINAL))

        # Keep only parser rules whose terminals all appear in the input.
        def uses_only_seen_terminals(parser_rule):
            return all(
                not (symbol.type == cfg_rule.TERMINAL
                     and symbol.idx not in seen_terminal_ids)
                for symbol in parser_rule.rhs)

        filtered_rules = [
            candidate for candidate in self.parser_rules
            if uses_only_seen_terminals(candidate)
        ]
        if verbose:
            print("filtered_rules:")
            for candidate in filtered_rules:
                print(candidate)

        def populate_fn(unused_span_begin, unused_span_end, parser_rule,
                        children):
            return [ParseNode(parser_rule, children)]

        nonterminal_ids = set(converter.nonterminals_to_ids.values())
        parses = cfg_parser.parse_symbols(
            symbols,
            filtered_rules,
            nonterminal_ids,
            nonterminal_ids,
            populate_fn,
            postprocess_fn=None,
            max_single_nt_applications=self.max_single_nt_applications,
            verbose=verbose)
        if not parses:
            print("Could not parse: %s" % (target_tokens, ))
            return None

        # Gather the rearranged NT assignment of every parse, skipping
        # falsy results.
        collected = set()
        for node in parses:
            assignment = _get_cfg_nts(converter.nonterminals_to_ids,
                                      self.rhs_nt_rules, node, num_nts)
            assignment = _rearrange_nts(assignment, nt_indices)
            if assignment:
                collected.add(assignment)

        return JointRule(induced_rule, frozenset(collected))