Ejemplo n.º 1
0
    def apply_action(self, action):
        """Apply one decoding ``action`` to the partial AST and record it.

        * No tree yet: ``action`` must be an ``ApplyRuleAction``; its
          production becomes the root node.
        * Composite frontier field: ``ApplyRuleAction`` appends a new child
          node (stamped with ``created_time = self.t``); ``ReduceAction``
          closes an ``optional``/``multiple`` field.
        * Primitive frontier field: ``GenTokenAction`` adds a token. For
          ``string``-typed fields tokens are buffered and joined with single
          spaces once the stop signal arrives; ``ReduceAction`` closes an
          ``optional``/``multiple`` field.
        * Tree present but no frontier node: nothing is applied.

        Unless an exception is raised, ``self.t`` is incremented and
        ``action`` is appended to ``self.actions``.

        Raises:
            AssertionError: the first action is not an ``ApplyRuleAction``, or
                a ``ReduceAction`` targets a field with ``single`` cardinality.
            ValueError: the frontier field does not accept this action type.
        """
        if self.tree is None:
            # The very first action has to create the root node.  The tuple
            # argument keeps ``%`` formatting safe whatever ``action`` is.
            assert isinstance(action, ApplyRuleAction), (
                'Invalid action [%s], only ApplyRule action is valid '
                'at the beginning of decoding' % (action,))

            self.tree = AbstractSyntaxTree(action.production)
            self.update_frontier_info()
        elif self.frontier_node:
            if isinstance(self.frontier_field.type, ASDLCompositeType):
                if isinstance(action, ApplyRuleAction):
                    field_value = AbstractSyntaxTree(action.production)
                    field_value.created_time = self.t
                    self.frontier_field.add_value(field_value)
                    self.update_frontier_info()
                elif isinstance(action, ReduceAction):
                    assert self.frontier_field.cardinality in ('optional', 'multiple'), (
                        'Reduce action can only be applied on field with '
                        'optional or multiple cardinality')
                    self.frontier_field.set_finish()
                    self.update_frontier_info()
                else:
                    raise ValueError('Invalid action [%s] on field [%s]' %
                                     (action, self.frontier_field))
            else:  # fill in a primitive field
                if isinstance(action, GenTokenAction):
                    # Only fields of type ``string`` need the </primitive> stop
                    # signal; any other primitive value is a single token.
                    end_primitive = False
                    if self.frontier_field.type.name == 'string':
                        if action.is_stop_signal():
                            self.frontier_field.add_value(' '.join(
                                self._value_buffer))
                            self._value_buffer = []

                            end_primitive = True
                        else:
                            self._value_buffer.append(action.token)
                    else:
                        self.frontier_field.add_value(action.token)
                        end_primitive = True

                    # A completed value closes single/optional fields; a
                    # ``multiple`` field stays open until an explicit Reduce.
                    if end_primitive and self.frontier_field.cardinality in (
                            'single', 'optional'):
                        self.frontier_field.set_finish()
                        self.update_frontier_info()

                elif isinstance(action, ReduceAction):
                    assert self.frontier_field.cardinality in ('optional', 'multiple'), (
                        'Reduce action can only be applied on field with '
                        'optional or multiple cardinality')
                    self.frontier_field.set_finish()
                    self.update_frontier_info()
                else:
                    raise ValueError(
                        'Can only invoke GenToken or Reduce actions on primitive fields'
                    )

        self.t += 1
        self.actions.append(action)
Ejemplo n.º 2
0
    def build_change_graph(old_ast: AbstractSyntaxTree, new_ast: AbstractSyntaxTree):
        """Build an edge list connecting the nodes of ``old_ast`` and ``new_ast``.

        The root nodes are copied with ``root_node.copy()`` and every node id
        in the copies is prefixed with ``'old-'`` / ``'new-'`` so the two id
        spaces cannot collide.  Walking the new copy top-down, each composite
        subtree that ``old_ast.find_node`` locates in the old copy is replaced
        by that old subtree (so the two trees share it) and an equality link
        is recorded.  Finally both trees are traversed to collect
        parent -> child edges; a shared subtree is expanded only once.

        Args:
            old_ast: tree before the change.
            new_ast: tree after the change.

        Returns:
            ``(adjacency_list, equality_links)``: ``adjacency_list`` holds
            ``(parent_id, child_id)`` edges, ``equality_links`` holds
            ``(new_node_id, old_node_id)`` pairs.
        """
        equality_links = []

        def _modify_id(node, prefix=''):
            # Prefix the id of ``node`` and, recursively, of all its descendants.
            node.id = f'{prefix}-{node.id}'
            if isinstance(node, AbstractSyntaxNode):
                for field in node.fields:
                    for field_val in field.as_value_list:
                        _modify_id(field_val, prefix)

        old_ast_root_copy = old_ast.root_node.copy()
        _modify_id(old_ast_root_copy, 'old')

        new_ast_root_copy = new_ast.root_node.copy()
        _modify_id(new_ast_root_copy, 'new')

        old_ast = AbstractSyntaxTree(old_ast_root_copy)
        new_ast = AbstractSyntaxTree(new_ast_root_copy)

        def _search_common_sub_tree(tgt_ast_node):
            # Replace the topmost new-tree subtree found in the old tree by the
            # old one; otherwise descend into composite children.
            node_query_result = old_ast.find_node(tgt_ast_node)
            if node_query_result:
                _, src_node = node_query_result
                tgt_ast_node.parent_field.replace(tgt_ast_node, src_node)

                # register this link
                equality_links.append((tgt_ast_node.id, src_node.id))
            else:
                for field in tgt_ast_node.fields:
                    if field.type.is_composite:
                        for field_val in field.as_value_list:
                            _search_common_sub_tree(field_val)

        _search_common_sub_tree(new_ast.root_node)

        visited = set()
        adjacency_list = []

        def _visit(node, parent_node):
            # The edge is recorded before the visited check so that a subtree
            # shared by both trees gets an edge from each parent, while its
            # own children are expanded only once.
            if parent_node:
                adjacency_list.append((parent_node.id, node.id))

            if node.id in visited:
                return

            if isinstance(node, AbstractSyntaxNode):
                for field in node.fields:
                    for field_val in field.as_value_list:
                        _visit(field_val, node)

            visited.add(node.id)

        _visit(old_ast.root_node, None)
        _visit(new_ast.root_node, None)

        return adjacency_list, equality_links
Ejemplo n.º 3
0
 def parse_orderby(self, orderby_clause: list, limit: int,
                   orderby_field: RealizedField):
     """Attach an ``Asc``/``Desc`` (or ``AscLimit``/``DescLimit``) node to ``orderby_field``.

     ``orderby_clause[0]`` is the direction (``'asc'`` selects ``Asc``,
     anything else ``Desc``) and ``orderby_clause[1]`` the list of
     val_units; the col_unit ``val_unit[1]`` of each is parsed into the
     node's first field.  The ``Limit`` variant is used whenever ``limit``
     is not ``None``.
     """
     direction = 'Asc' if orderby_clause[0] == 'asc' else 'Desc'
     suffix = '' if limit is None else 'Limit'
     order_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(direction + suffix))
     target_field = order_node.fields[0]
     for col_unit in (unit[1] for unit in orderby_clause[1]):
         target_field.add_value(self.parse_col_unit(col_unit))
     orderby_field.add_value(order_node)
Ejemplo n.º 4
0
 def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
     """Attach a ``{One,Two}{NoHaving,Having}`` node to ``groupby_field``.

     At most the first two col_units of ``groupby_clause`` are encoded, one
     per leading field of the node.  When ``having_clause`` is non-empty its
     parsed conditions are stored in the node's last field.
     """
     ctr_names = ('OneNoHaving', 'TwoNoHaving', 'OneHaving', 'TwoHaving')
     # Each Having constructor sits two slots after its NoHaving twin, so the
     # 1-based column count is shifted by +1 (Having) or -1 (NoHaving).
     ctr_index = min(2, len(groupby_clause)) + (1 if having_clause else -1)
     group_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(ctr_names[ctr_index]))
     if having_clause:
         group_node.fields[-1].add_value(self.parse_conds(having_clause))
     for idx, col_unit in enumerate(groupby_clause[:2]):
         group_node.fields[idx].add_value(self.parse_col_unit(col_unit))
     groupby_field.add_value(group_node)
Ejemplo n.º 5
0
def lisp_node_to_ast(grammar, lisp_tokens, start_idx):
    """Parse the lisp node whose name token is ``lisp_tokens[start_idx]``.

    Two kinds of node are recognised by name:

    * a predicate from a fixed list (``_eq``, ``select``, ``filter``, ...)
      becomes an ``apply`` node whose ``predicate`` field is the name and
      whose ``arguments`` field holds the arguments that follow it (nested
      expressions via ``lisp_expr_to_ast_helper``, other tokens as
      ``Literal`` nodes);
    * a name ending in ``id0``/``id1``/``id2`` becomes a ``Literal`` node.

    Args:
        grammar: object providing ``get_prod_by_ctr_name``.
        lisp_tokens: token sequence; presumably parentheses are separate
            tokens (TODO confirm against the tokenizer).
        start_idx: index of the node's name token.

    Returns:
        ``(ast_node, i)`` where ``i`` is the index at which the caller should
        resume parsing (it may equal ``len(lisp_tokens)``).

    Raises:
        NotImplementedError: the name is neither a known predicate nor an
            ``id0``/``id1``/``id2`` literal.
    """
    node_name = lisp_tokens[start_idx]
    i = start_idx
    if node_name in [
            '_eq', 'select', 'filter', '_parts', '_time', '_inspect',
            'between', '_and', '_or', 'renew', 'cancel'
    ]:
        # it's a predicate: build apply(predicate, arguments)
        prod = grammar.get_prod_by_ctr_name('apply')
        pred_field = RealizedField(prod['predicate'], value=node_name)

        arg_ast_nodes = []
        while True:
            i += 1
            lisp_token = lisp_tokens[i]
            if lisp_token == "(":
                # nested expression argument
                arg_ast_node, end_idx = lisp_expr_to_ast_helper(
                    grammar, lisp_tokens, i)
            elif lisp_token == ")":
                # closes this predicate's argument list
                i += 1
                break
            else:
                # any other token becomes a Literal argument; end_idx stays on
                # the token itself and the `i += 1` at the top of the next
                # iteration moves past it
                prod1 = grammar.get_prod_by_ctr_name('Literal')
                arg_ast_node, end_idx = AbstractSyntaxTree(
                    prod1,
                    [RealizedField(prod1['literal'], value=lisp_tokens[i])]), i

            arg_ast_nodes.append(arg_ast_node)

            i = end_idx
            if i >= len(lisp_tokens): break
            # NOTE(review): lisp_expr_to_ast_helper stops on (and does not
            # consume) the ')' that closes a nested expression, so after a
            # nested argument this check sees that ')' and also ends this
            # predicate's argument list -- confirm this is intended.
            if lisp_tokens[i] == ')':
                i += 1
                break

        arg_field = RealizedField(prod['arguments'], arg_ast_nodes)
        ast_node = AbstractSyntaxTree(prod, [pred_field, arg_field])
    elif node_name.endswith('id0') or node_name.endswith('id1') or node_name.endswith('id2') \
            or node_name in ['periodid0', 'periodid1']:
        # it's a literal; the 'periodid0'/'periodid1' test is already covered
        # by the endswith checks above
        prod = grammar.get_prod_by_ctr_name('Literal')

        ast_node = AbstractSyntaxTree(
            prod, [RealizedField(prod['literal'], value=node_name)])

        i += 1
    else:
        raise NotImplementedError

    return ast_node, i
def prolog_node_to_ast(grammar, prolog_tokens, start_idx):
    """Parse the Prolog term starting at ``prolog_tokens[start_idx]`` into an AST.

    * Predicates (``job``, ``language``, ``loc``, ...) must be followed by a
      ``( arg , arg ... )`` token run; each argument is parsed recursively and
      the result is an ``Apply`` node.
    * The variable names ``ANS``, ``X``, ``A``, ``B``, ``P``, ``J`` become
      ``Variable`` nodes.
    * Names ending in ``id0``/``id1``/``id2`` and a few fixed constants become
      ``Literal`` nodes.

    Returns:
        ``(ast_node, next_idx)``, where ``next_idx`` is the index of the first
        token not consumed by this term.

    Raises:
        NotImplementedError: the token matches none of the cases above.
    """
    name = prolog_tokens[start_idx]

    if name in ('job', 'language', 'loc', 'req_deg', 'application', 'area',
                'company', 'des_deg', 'des_exp', 'platform', 'recruiter',
                'req_exp', 'salary_greater_than', 'salary_less_than', 'title'):
        apply_prod = grammar.get_prod_by_ctr_name('Apply')
        predicate_field = RealizedField(apply_prod['predicate'], value=name)

        cursor = start_idx + 1
        assert prolog_tokens[cursor] == '('
        arguments = []
        while True:
            # Skip the '(' or ',' in front of the next argument and parse it.
            argument, cursor = prolog_node_to_ast(grammar, prolog_tokens,
                                                  cursor + 1)
            arguments.append(argument)
            if cursor >= len(prolog_tokens):
                break
            if prolog_tokens[cursor] == ')':
                cursor += 1
                break
            assert prolog_tokens[cursor] == ','

        arguments_field = RealizedField(apply_prod['arguments'], arguments)
        return AbstractSyntaxTree(apply_prod,
                                  [predicate_field, arguments_field]), cursor

    if name in ('ANS', 'X', 'A', 'B', 'P', 'J'):
        ctr_name, field_name = 'Variable', 'variable'
    elif name.endswith(('id0', 'id1', 'id2')) or name in (
            '20', 'hour', 'num_salary', 'year', 'year0', 'year1', 'month'):
        ctr_name, field_name = 'Literal', 'literal'
    else:
        raise NotImplementedError

    # Variables and literals are single tokens.
    leaf_prod = grammar.get_prod_by_ctr_name(ctr_name)
    leaf_node = AbstractSyntaxTree(
        leaf_prod, [RealizedField(leaf_prod[field_name], value=name)])
    return leaf_node, start_idx + 1
Ejemplo n.º 7
0
 def parse_select(self, select_clause: list, select_field: RealizedField):
     """Attach a ``SelectOne`` .. ``SelectFive`` node to ``select_field``.

     ``select_clause[1]`` is the list of ``(agg, val_unit)`` items (the
     distinct flag at index 0 is ignored).  Only the first five items are
     encoded, one per node field.  An aggregated item (``agg != 0``) becomes
     a ``Unary`` node holding the col_unit ``val_unit[1]`` with its first
     element replaced by ``agg``; a plain item is delegated to
     ``parse_val_unit``.
     """
     select_items = select_clause[1]  # list of (agg, val_unit), ignore distinct flag
     select_num = min(5, len(select_items))
     select_ctr = ['SelectOne', 'SelectTwo', 'SelectThree', 'SelectFour', 'SelectFive']
     ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(select_ctr[select_num - 1]))
     for i, (agg, val_unit) in enumerate(select_items):
         if i >= 5:
             break
         if agg != 0:  # MAX/MIN/COUNT/SUM/AVG
             val_unit_ast = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary'))
             # Unpacking works whether the col_unit is a list (decoded JSON) or
             # a tuple; `[agg] + tuple` would raise TypeError.
             col_unit = [agg, *val_unit[1][1:]]
             val_unit_ast.fields[0].add_value(self.parse_col_unit(col_unit))
         else:
             val_unit_ast = self.parse_val_unit(val_unit)
         ast_node.fields[i].add_value(val_unit_ast)
     select_field.add_value(ast_node)
Ejemplo n.º 8
0
def python_ast_to_asdl_ast(py_ast_node, grammar):
    """Recursively convert a Python ``ast`` node into an ASDL AST.

    The production is looked up by the node's class name.  For each field of
    that production the value is read from ``py_ast_node`` with ``getattr``:

    * ``None`` leaves the realized field empty;
    * single/optional fields hold one value, other fields iterate the value;
    * composite-typed values are converted recursively, primitive values are
      stored as ``str``.

    Returns:
        The ``AbstractSyntaxTree`` for ``py_ast_node``.
    """
    production = grammar.get_prod_by_ctr_name(type(py_ast_node).__name__)

    realized = []
    for field in production.fields:
        raw_value = getattr(py_ast_node, field.name)
        realized_field = RealizedField(field)
        # Compare against None explicitly: primitive values such as 0 are
        # falsy but still have to be stored.
        if raw_value is not None:
            if field.cardinality in ('single', 'optional'):
                items = [raw_value]
            else:
                items = raw_value
            composite = grammar.is_composite_type(field.type)
            for item in items:
                if composite:
                    realized_field.add_value(python_ast_to_asdl_ast(item, grammar))
                else:
                    realized_field.add_value(str(item))
        realized.append(realized_field)

    return AbstractSyntaxTree(production, realized_fields=realized)
Ejemplo n.º 9
0
def lisp_expr_to_ast_helper(grammar, lisp_tokens, start_idx=0):
    """Parse a lisp expression, or a run of sibling nodes, starting at ``start_idx``.

    A leading ``(`` is skipped.  Nested expressions and nodes are then parsed
    until a ``)`` or the end of ``lisp_tokens`` is reached; single ``' '``
    tokens between siblings are skipped.  The closing ``)`` is NOT consumed:
    the returned index points at it (or equals ``len(lisp_tokens)``).

    Returns:
        ``(node, index)``.  Several parsed siblings are wrapped in an ``And``
        node; a single one is returned unchanged.
    """
    pos = start_idx + 1 if lisp_tokens[start_idx] == '(' else start_idx

    siblings = []
    while True:
        parse = lisp_expr_to_ast_helper if lisp_tokens[pos] == '(' else lisp_node_to_ast
        sibling, pos = parse(grammar, lisp_tokens, pos)
        siblings.append(sibling)

        if pos >= len(lisp_tokens) or lisp_tokens[pos] == ')':
            break
        if lisp_tokens[pos] == ' ':
            # a space separating siblings acts as an implicit "and"
            pos += 1

    assert siblings
    if len(siblings) == 1:
        return siblings[0], pos

    and_prod = grammar.get_prod_by_ctr_name('And')
    and_node = AbstractSyntaxTree(
        and_prod, [RealizedField(and_prod['arguments'], siblings)])
    return and_node, pos
Ejemplo n.º 10
0
 def parse_groupby(self, groupby_clause: list, having_clause: list,
                   groupby_field: RealizedField):
     """Attach a ``Having`` or ``NoHaving`` node for GROUP BY to ``groupby_field``.

     Every col_unit of ``groupby_clause`` is parsed into the node's first
     field.  When ``having_clause`` is non-empty a ``Having`` node is used
     and its parsed conditions go into the second field.
     """
     if having_clause:
         ast_node = AbstractSyntaxTree(
             self.grammar.get_prod_by_ctr_name('Having'))
         col_units_field, having_field = ast_node.fields
         having_field.add_value(self.parse_conds(having_clause))
     else:
         ast_node = AbstractSyntaxTree(
             self.grammar.get_prod_by_ctr_name('NoHaving'))
         col_units_field = ast_node.fields[0]
     for col_unit in groupby_clause:
         col_units_field.add_value(self.parse_col_unit(col_unit))
     groupby_field.add_value(ast_node)
Ejemplo n.º 11
0
 def parse_from(self, from_clause: dict, from_field: RealizedField):
     """Encode the FROM clause; join conditions are ignored because the
     evaluation script does not score them.

     Plain table units yield a ``FromOneTable`` .. ``FromSixTable`` node that
     holds up to six table ids as ``int``.  A nested query (only the first
     table unit is used) yields a ``FromSQL`` node.
     """
     units = from_clause['table_units']
     first_kind = units[0][0]
     if first_kind != 'table_unit':
         assert first_kind == 'sql'
         sub_query = units[0][1]
         from_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL'))
         from_node.fields[0].add_value(self.parse_sql(sub_query))
     else:
         ctr_by_count = ['FromOneTable', 'FromTwoTable', 'FromThreeTable',
                         'FromFourTable', 'FromFiveTable', 'FromSixTable']
         kept = min(6, len(units))
         from_node = AbstractSyntaxTree(
             self.grammar.get_prod_by_ctr_name(ctr_by_count[kept - 1]))
         for pos, (_, table_id) in enumerate(units):
             if pos >= 6:
                 break
             from_node.fields[pos].add_value(int(table_id))
     from_field.add_value(from_node)
Ejemplo n.º 12
0
 def parse_select(self, select_clause: list, select_field: RealizedField):
     """Attach one node per select item to ``select_field``.

     ``select_clause[1]`` is the list of ``(agg, val_unit)`` items.  An
     aggregated item (``agg != 0``) becomes ``Unary`` wrapping an aggregation
     node (``Max``/``Min``/``Count``/``Sum``/``Avg`` chosen by ``agg``) whose
     first field holds ``int(val_unit[1][1])``.  A plain item is delegated to
     ``parse_val_unit``.  The cases ``agg(col_id1 op col_id2)`` and
     ``agg(col_id1) op agg(col_id2)`` are ignored: under an aggregation only
     ``val_unit[1][1]`` is kept.
     """
     agg_op_list = ['None', 'Max', 'Min', 'Count', 'Sum', 'Avg']
     for agg, val_unit in select_clause[1]:
         if agg != 0:  # agg(col_id)
             ast_node = AbstractSyntaxTree(
                 self.grammar.get_prod_by_ctr_name('Unary'))
             col_node = AbstractSyntaxTree(
                 self.grammar.get_prod_by_ctr_name(agg_op_list[agg]))
             col_node.fields[0].add_value(int(val_unit[1][1]))
             ast_node.fields[0].add_value(col_node)
         else:  # plain val_unit, e.g. binary_op col_id1 col_id2
             ast_node = self.parse_val_unit(val_unit)
         select_field.add_value(ast_node)
Ejemplo n.º 13
0
def regex_ast_to_asdl_ast(grammar, reg_ast):
    """Convert a regex AST into an ASDL AST built from ``grammar`` productions.

    Interior nodes are mapped to a rule through ``_NODE_CLASS_TO_RULE``:
    unary rules take ``arg``, binary rules ``left``/``right``, and
    ``RepeatAtleast`` additionally stores ``params[0]`` as a string in ``k``.
    Leaves are character classes (``<num>``, ``<let>``, ``<vow>``, ``<low>``,
    ``<cap>``, ``<any>``) or constants (``<m0>`` .. ``<m3>``).

    Raises:
        ValueError: a node class has no corresponding rule.
    """
    if not reg_ast.children:
        if reg_ast.node_class in ("<num>", "<let>", "<vow>", "<low>", "<cap>", "<any>"):
            leaf_rule = "CharClass"
        elif reg_ast.node_class in ("<m0>", "<m1>", "<m2>", "<m3>"):
            leaf_rule = "Const"
        else:
            raise ValueError("wrong node class", reg_ast.node_class)
        leaf_prod = grammar.get_prod_by_ctr_name(leaf_rule)
        return AbstractSyntaxTree(
            leaf_prod, [RealizedField(leaf_prod['arg'], reg_ast.node_class)])

    rule = _NODE_CLASS_TO_RULE[reg_ast.node_class]
    prod = grammar.get_prod_by_ctr_name(rule)
    kids = reg_ast.children

    if rule in ("Not", "Star", "StartWith", "EndWith", "Contain"):
        operand = regex_ast_to_asdl_ast(grammar, kids[0])
        return AbstractSyntaxTree(prod, [RealizedField(prod['arg'], operand)])

    if rule in ("Concat", "And", "Or"):
        lhs = regex_ast_to_asdl_ast(grammar, kids[0])
        rhs = regex_ast_to_asdl_ast(grammar, kids[1])
        return AbstractSyntaxTree(prod, [RealizedField(prod['left'], lhs),
                                         RealizedField(prod['right'], rhs)])

    if rule == "RepeatAtleast":
        # the repeat count is a primitive field stored as a string
        operand = regex_ast_to_asdl_ast(grammar, kids[0])
        count_field = RealizedField(prod['k'], str(reg_ast.params[0]))
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['arg'], operand), count_field])

    raise ValueError("wrong node class", reg_ast.node_class)
Ejemplo n.º 14
0
 def parse_from(self, from_clause: dict, from_field: RealizedField):
     """Encode the FROM clause as a ``FromTable`` or ``FromSQL`` node.

     Join conditions are ignored because the evaluation script does not
     score them.  Plain table units yield a ``FromTable`` node listing every
     table id as ``int``.  A nested query (only the first table unit is used)
     yields a ``FromSQL`` node.
     """
     units = from_clause['table_units']
     first_kind = units[0][0]
     if first_kind != 'table_unit':
         assert first_kind == 'sql'
         sub_query = units[0][1]
         from_node = AbstractSyntaxTree(
             self.grammar.get_prod_by_ctr_name('FromSQL'))
         from_node.fields[0].add_value(self.parse_sql(sub_query))
     else:
         from_node = AbstractSyntaxTree(
             self.grammar.get_prod_by_ctr_name('FromTable'))
         table_ids_field = from_node.fields[0]
         for _, table_id in units:
             table_ids_field.add_value(int(table_id))
     from_field.add_value(from_node)
Ejemplo n.º 15
0
def copy_tree_field(tree: AbstractSyntaxTree, field: RealizedField, bool_w_dummy_reduce=False):
    """Copy ``tree`` and locate the copy's counterpart of ``field``.

    The path from the root down to ``field`` is recorded as alternating
    field/node positions (located with ``find_by_id``) and then replayed on
    the re-indexed copy.

    Args:
        tree: tree to copy.
        field: a field inside ``tree``.
        bool_w_dummy_reduce: copy with ``copy_and_reindex_w_dummy_reduce``
            instead of ``copy_and_reindex_wo_dummy_reduce``.

    Returns:
        ``(new_tree, new_field)``.
    """
    new_tree = (tree.copy_and_reindex_w_dummy_reduce() if bool_w_dummy_reduce
                else tree.copy_and_reindex_wo_dummy_reduce())

    # Walk upwards from ``field``; positions are collected bottom-up.
    path = []
    current = field
    while current:
        owner = current.parent_node
        field_pos = find_by_id(owner.fields, current)
        assert field_pos != -1
        path.append(('field', field_pos))

        owner_field = owner.parent_field
        if owner_field:
            node_pos = find_by_id(owner_field.as_value_list, owner)
            assert node_pos != -1
            path.append(('node', node_pos))

        current = owner_field

    # Replay the path top-down on the copy.
    pointer = new_tree.root_node
    for step, pos in reversed(path):
        if step == 'field':
            assert isinstance(pointer, AbstractSyntaxNode)
            pointer = pointer.fields[pos]
        else:
            assert step == 'node'
            assert isinstance(pointer, RealizedField)
            pointer = pointer.as_value_list[pos]

    assert isinstance(pointer, RealizedField)
    # The copy may differ from ``tree`` (DummyReduce may have been inserted),
    # so no equality check against the inputs is made.
    return new_tree, pointer
Ejemplo n.º 16
0
 def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
     """Attach a ``{One,Two}{Asc,Desc}[Limit]`` node to ``orderby_field``.

     ``orderby_clause[0]`` is the direction (``'asc'`` selects ``Asc``,
     anything else ``Desc``) and ``orderby_clause[1]`` the list of val_units.
     At most the first two are encoded, each through its col_unit
     ``val_unit[1]``.  The ``Limit`` variant is chosen whenever ``limit`` is
     not ``None``, so a ``LIMIT 0`` is not dropped just because ``0`` is falsy.
     """
     orderby_num = min(2, len(orderby_clause[1]))
     num_str = 'One' if orderby_num == 1 else 'Two'
     order_str = 'Asc' if orderby_clause[0] == 'asc' else 'Desc'
     limit_str = '' if limit is None else 'Limit'  # e.g. OneAsc, TwoDescLimit
     ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(num_str + order_str + limit_str))
     for i, val_unit in enumerate(orderby_clause[1][:2]):
         col_unit = val_unit[1]
         ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
     orderby_field.add_value(ast_node)
Ejemplo n.º 17
0
    def __init__(self,
                 init_tree_w_dummy_reduce: AbstractSyntaxTree,
                 bool_copy_subtree=False,
                 tree=None,
                 memory=None,
                 memory_type='all_init_joint',
                 init_code_tokens=None,
                 length_norm=False):
        """Create a hypothesis that starts from ``init_tree_w_dummy_reduce``.

        Args:
            init_tree_w_dummy_reduce: initial tree.
            bool_copy_subtree: when true and ``memory`` is None, build a
                subtree memory according to ``memory_type``.
            tree: current tree; defaults to ``init_tree_w_dummy_reduce.copy()``.
            memory: used as given whenever it is not None or
                ``bool_copy_subtree`` is false.
            memory_type: ``'all_init_joint'`` (all subtrees returned by
                ``stack_subtrees``), ``'all_init_distinct'`` (the same with
                ``==``-duplicates removed, first occurrence kept) or
                ``'deleted_distinct'`` (starts empty).
            init_code_tokens: stored unchanged.
            length_norm: stored unchanged.
        """
        self.init_tree_w_dummy_reduce = init_tree_w_dummy_reduce
        self.bool_copy_subtree = bool_copy_subtree
        assert memory_type in ('all_init_joint', 'all_init_distinct',
                               'deleted_distinct')
        self.memory_type = memory_type
        self.init_code_tokens = init_code_tokens
        self.length_norm = length_norm

        self.tree = init_tree_w_dummy_reduce.copy() if tree is None else tree

        if not bool_copy_subtree or memory is not None:
            self.memory = memory
        elif memory_type == 'all_init_joint':
            self.memory = stack_subtrees(
                self.init_tree_w_dummy_reduce.root_node)
        elif memory_type == 'all_init_distinct':
            unique_subtrees = []
            for subtree in stack_subtrees(
                    self.init_tree_w_dummy_reduce.root_node):
                if subtree not in unique_subtrees:
                    unique_subtrees.append(subtree)
            self.memory = unique_subtrees
        else:  # 'deleted_distinct'
            self.memory = []

        # Edit history and accumulated score.
        self.edits = []
        self.score_per_edit = []
        self.score = 0.

        # Frontier bookkeeping, filled by update_frontier_info().
        self.repr2field = {}
        self.open_del_node_and_ids = []  # nodes available to delete
        self.open_add_fields = []  # fields open to add nodes
        # fields (esp. with single cardinality) that grammatically must be filled
        self.restricted_frontier_fields = []
        self.update_frontier_info()

        # Time-step tracking; kept after update_frontier_info() as in the
        # original initialisation order.
        self.last_edit_field_node = None  # trace the last edit
        self.t = 0
        self.stop_t = None
Ejemplo n.º 18
0
 def parse_sql_unit(self, sql: dict):
     """Build the ``SQL`` node for one query unit.

     FROM and SELECT are always parsed.  WHERE, GROUP BY and ORDER BY are
     parsed only when the corresponding entry of ``sql`` is non-empty.
     HAVING is passed along with GROUP BY and LIMIT with ORDER BY, on the
     assumption that neither occurs without its companion clause.
     """
     sql_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('SQL'))
     (from_slot, select_slot, where_slot,
      groupby_slot, orderby_slot) = sql_node.fields

     self.parse_from(sql['from'], from_slot)
     self.parse_select(sql['select'], select_slot)

     where_clause = sql['where']
     if where_clause:
         self.parse_where(where_clause, where_slot)

     groupby_clause = sql['groupBy']
     if groupby_clause:
         self.parse_groupby(groupby_clause, sql['having'], groupby_slot)

     orderby_clause = sql['orderBy']
     if orderby_clause:
         self.parse_orderby(orderby_clause, sql['limit'], orderby_slot)

     return sql_node
Ejemplo n.º 19
0
    def get_ast_from_json_obj(self, json_obj: Dict):
        """Build an ``AbstractSyntaxTree`` from an already-decoded JSON object.

        ``json_obj`` is the root entry (a dict), not a JSON string.  Every
        entry has a ``'Constructor'`` key:

        * ``'SyntaxToken'`` entries become ``SyntaxToken`` leaves built from
          ``'Value'`` and ``'Position'``; a ``None`` value yields no token.
        * Any other constructor is looked up with ``get_prod_by_ctr_name`` and
          its ``'Fields'`` dict supplies one value per production field.  A
          list value is wrapped in an extra node whose production is looked
          up by the field's type name (asserted to be a ``SyntaxList`` type);
          the list items go into that node's first field.

        Node ids are assigned from 0 in pre-order: a node receives its id
        before its children, and a list wrapper before its items.  Every
        ``multiple``/``optional`` field of each created node is marked finished.

        Args:
            json_obj: root entry of the serialized tree.

        Returns:
            AbstractSyntaxTree: tree rooted at the decoded root node.
        """
        # FIXME: cyclic import
        from asdl.asdl_ast import AbstractSyntaxNode, RealizedField, SyntaxToken, AbstractSyntaxTree

        def get_subtree(entry, parent_field, next_available_id):
            # Returns (node, token or None, next unused id).
            if entry is None:
                return None, next_available_id

            constructor_name = entry['Constructor']

            # terminal case
            if constructor_name == 'SyntaxToken':
                if entry['Value'] is None:
                    return None, next_available_id  # return None for optional field whose value is null

                token = SyntaxToken(parent_field.type,
                                    entry['Value'],
                                    position=entry['Position'],
                                    id=next_available_id)
                next_available_id += 1

                return token, next_available_id

            field_entries = entry['Fields']
            # reserve this node's id before any child ids (pre-order)
            node_id = next_available_id
            next_available_id += 1
            prod = self.get_prod_by_ctr_name(constructor_name)
            realized_fields = []
            for field in prod.constructor.fields:
                field_value = field_entries[field.name]

                if isinstance(field_value, list):
                    # A list is wrapped in an intermediate SyntaxList node.
                    assert 'SyntaxList' in field.type.name

                    # the wrapper node's id is reserved before its items' ids
                    sub_ast_id = next_available_id
                    next_available_id += 1

                    sub_ast_prod = self.get_prod_by_ctr_name(field.type.name)
                    sub_ast_constr_field = sub_ast_prod.constructor.fields[0]
                    sub_ast_field_values = []
                    for field_child_entry in field_value:
                        # note: a None item is appended as None
                        child_sub_ast, next_available_id = get_subtree(
                            field_child_entry,
                            sub_ast_constr_field,
                            next_available_id=next_available_id)
                        sub_ast_field_values.append(child_sub_ast)

                    sub_ast = AbstractSyntaxNode(sub_ast_prod, [
                        RealizedField(sub_ast_constr_field,
                                      sub_ast_field_values)
                    ],
                                                 id=sub_ast_id)

                    # FIXME: have a global mark_finished method!
                    for sub_ast_field in sub_ast.fields:
                        if sub_ast_field.cardinality in ('multiple',
                                                         'optional'):
                            sub_ast_field._not_single_cardinality_finished = True

                    realized_field = RealizedField(field, sub_ast)
                else:
                    # if the child is an AST or terminal SyntaxNode
                    sub_ast, next_available_id = get_subtree(
                        field_value, field, next_available_id)
                    realized_field = RealizedField(field, sub_ast)

                realized_fields.append(realized_field)

            ast_node = AbstractSyntaxNode(prod, realized_fields, id=node_id)
            # the tree is fully deserialized, so non-single fields are closed
            for field in ast_node.fields:
                if field.cardinality in ('multiple', 'optional'):
                    field._not_single_cardinality_finished = True

            return ast_node, next_available_id

        ast_root, _ = get_subtree(json_obj,
                                  parent_field=None,
                                  next_available_id=0)
        ast = AbstractSyntaxTree(ast_root)

        return ast
Ejemplo n.º 20
0
def streg_ast_to_asdl_ast(grammar, reg_ast):
    """Convert a structured-regex AST into an ASDL AST built from ``grammar``.

    Interior nodes are mapped to a rule through ``_NODE_CLASS_TO_RULE``:

    * unary rules take ``arg``;
    * binary rules take ``left``/``right``;
    * ``RepeatAtleast``/``Repeat`` also store ``params[0]`` as a string in
      ``k``, and ``RepeatRange`` stores ``params[0]``/``params[1]`` in
      ``k1``/``k2``;
    * ``String`` stores its first child's node class in ``arg``.

    Leaves become ``CharClass`` (``<num>``, ``<let>``, ``<spec>``, ``<low>``,
    ``<cap>``, ``<any>``), ``ConstSym`` (``const`` followed by digits) or
    ``Token`` (any other ``<...>``).

    Raises:
        ValueError: a node class has no corresponding rule.
    """
    kids = reg_ast.children
    if not kids:
        label = reg_ast.node_class
        if label in ("<num>", "<let>", "<spec>", "<low>", "<cap>", "<any>"):
            leaf_rule = "CharClass"
        elif label.startswith("const") and label[5:].isdigit():
            leaf_rule = "ConstSym"
        elif label.startswith("<") and label.endswith(">"):
            leaf_rule = "Token"
        else:
            raise ValueError("wrong node class", label)
        leaf_prod = grammar.get_prod_by_ctr_name(leaf_rule)
        return AbstractSyntaxTree(
            leaf_prod, [RealizedField(leaf_prod['arg'], label)])

    rule = _NODE_CLASS_TO_RULE[reg_ast.node_class]
    prod = grammar.get_prod_by_ctr_name(rule)

    if rule in ("Not", "Star", "StartWith", "EndWith", "Contain", "NotCC",
                "Optional"):
        operand = streg_ast_to_asdl_ast(grammar, kids[0])
        return AbstractSyntaxTree(prod, [RealizedField(prod['arg'], operand)])

    if rule in ("Concat", "And", "Or"):
        lhs = streg_ast_to_asdl_ast(grammar, kids[0])
        rhs = streg_ast_to_asdl_ast(grammar, kids[1])
        return AbstractSyntaxTree(prod, [RealizedField(prod['left'], lhs),
                                         RealizedField(prod['right'], rhs)])

    if rule in ("RepeatAtleast", "Repeat"):
        # the repeat count is a primitive field stored as a string
        operand = streg_ast_to_asdl_ast(grammar, kids[0])
        count_field = RealizedField(prod['k'], str(reg_ast.params[0]))
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['arg'], operand), count_field])

    if rule == "RepeatRange":
        operand = streg_ast_to_asdl_ast(grammar, kids[0])
        low_field = RealizedField(prod['k1'], str(reg_ast.params[0]))
        high_field = RealizedField(prod['k2'], str(reg_ast.params[1]))
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['arg'], operand), low_field, high_field])

    if rule == "String":
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['arg'], kids[0].node_class)])

    raise ValueError("wrong node class", reg_ast.node_class)
Ejemplo n.º 21
0
# Predicates that are encoded as ``Apply(pred predicate, expr* arguments)``.
_PDF_APPLY_PREDICATES = frozenset([
    'Type', 'SubType', 'Size', 'Length', 'Kids', 'Parent', 'Count',
    'Limits', 'Range', 'Filter', 'Domain', 'FuncType', 'Pages',
    'MediaBox', 'Resources',
])

# First letter of a variable name -> value of the ``type`` field of ``Variable``.
_PDF_VAR_TYPES = {'S': 'string', 'I': 'int', 'H': 'header'}


def pdf_to_ast(grammar, lf_node):
    """Convert a PDF logical-form tree into an ASDL ``AbstractSyntaxTree``.

    Dispatch is on ``lf_node.name``; the checks run in this order:

    * ``obj...``            -> ``Objective(id name, expr* hdr)``; the children
      are converted recursively into ``hdr``.
    * a known predicate     -> ``Apply(pred predicate, expr* arguments)``; the
      children are converted recursively into ``arguments``.
    * ``S``/``I``/``H`` + x -> ``Variable`` with type string/int/header and
      variable name ``x``.
    * ``R`` + x             -> ``Reference(id ref)`` with ``ref = 'obj' + x``.

    The predicate check has to come before the prefix checks because several
    predicates (``Size``, ``SubType``, ``Range``, ``Resources``) begin with
    ``S`` or ``R``.

    Args:
        grammar: ASDL grammar providing ``get_prod_by_ctr_name``.
        lf_node: logical-form node exposing ``name`` and ``children``.

    Returns:
        The ``AbstractSyntaxTree`` rooted at ``lf_node``.

    Raises:
        NotImplementedError: if ``lf_node.name`` matches none of the forms
            above; the message names the node.
    """
    name = lf_node.name

    if name.startswith('obj'):
        # obj = Objective(id name, expr* hdr)
        prod = grammar.get_prod_by_ctr_name('Objective')
        id_field = RealizedField(prod['name'], value=name)
        hdr_field = RealizedField(
            prod['hdr'],
            [pdf_to_ast(grammar, child) for child in lf_node.children])
        return AbstractSyntaxTree(prod, [id_field, hdr_field])

    if name in _PDF_APPLY_PREDICATES:
        # expr -> Apply(pred predicate, expr* arguments)
        prod = grammar.get_prod_by_ctr_name('Apply')
        pred_field = RealizedField(prod['predicate'], value=name)
        arg_field = RealizedField(
            prod['arguments'],
            [pdf_to_ast(grammar, child) for child in lf_node.children])
        return AbstractSyntaxTree(prod, [pred_field, arg_field])

    var_type = _PDF_VAR_TYPES.get(name[:1])
    if var_type is not None:
        # expr = Variable(var_type type, var variable)
        prod = grammar.get_prod_by_ctr_name('Variable')
        var_type_field = RealizedField(prod['type'], value=var_type)
        var_field = RealizedField(prod['variable'], value=name[1:])
        return AbstractSyntaxTree(prod, [var_type_field, var_field])

    if name.startswith('R'):
        # expr = Reference(id ref)
        prod = grammar.get_prod_by_ctr_name('Reference')
        ref_field = RealizedField(prod['ref'], value='obj' + name[1:])
        return AbstractSyntaxTree(prod, [ref_field])

    raise NotImplementedError('unsupported logical-form node: %r' % (name,))
Ejemplo n.º 22
0
def logical_form_to_ast(grammar, lf_node):
    if lf_node.name == 'lambda':
        # expr -> Lambda(var variable, var_type type, expr body)
        prod = grammar.get_prod_by_ctr_name('Lambda')

        var_node = lf_node.children[0]
        var_field = RealizedField(prod['variable'], var_node.name)

        var_type_node = lf_node.children[1]
        var_type_field = RealizedField(prod['type'], var_type_node.name)

        body_node = lf_node.children[2]
        body_ast_node = logical_form_to_ast(grammar, body_node)  # of type expr
        body_field = RealizedField(prod['body'], body_ast_node)

        ast_node = AbstractSyntaxTree(prod,
                                      [var_field, var_type_field, body_field])
    elif lf_node.name == 'argmax' or lf_node.name == 'argmin' or lf_node.name == 'sum':
        # expr -> Argmax|Sum(var variable, expr domain, expr body)

        prod = grammar.get_prod_by_ctr_name(lf_node.name.title())

        var_node = lf_node.children[0]
        var_field = RealizedField(prod['variable'], var_node.name)

        domain_node = lf_node.children[1]
        domain_ast_node = logical_form_to_ast(grammar, domain_node)
        domain_field = RealizedField(prod['domain'], domain_ast_node)

        body_node = lf_node.children[2]
        body_ast_node = logical_form_to_ast(grammar, body_node)
        body_field = RealizedField(prod['body'], body_ast_node)

        ast_node = AbstractSyntaxTree(prod,
                                      [var_field, domain_field, body_field])
    elif lf_node.name == 'and' or lf_node.name == 'or':
        # expr -> And(expr* arguments) | Or(expr* arguments)
        prod = grammar.get_prod_by_ctr_name(lf_node.name.title())

        arg_ast_nodes = []
        for arg_node in lf_node.children:
            arg_ast_node = logical_form_to_ast(grammar, arg_node)
            arg_ast_nodes.append(arg_ast_node)

        ast_node = AbstractSyntaxTree(
            prod, [RealizedField(prod['arguments'], arg_ast_nodes)])
    elif lf_node.name == 'not':
        # expr -> Not(expr argument)
        prod = grammar.get_prod_by_ctr_name('Not')

        arg_ast_node = logical_form_to_ast(grammar, lf_node.children[0])

        ast_node = AbstractSyntaxTree(
            prod, [RealizedField(prod['argument'], arg_ast_node)])
    elif lf_node.name == '>' or lf_node.name == '=' or lf_node.name == '<':
        # expr -> Compare(cmp_op op, expr left, expr right)
        prod = grammar.get_prod_by_ctr_name('Compare')
        op_name = 'GreaterThan' if lf_node.name == '>' else 'Equal' if lf_node.name == '=' else 'LessThan'
        op_field = RealizedField(
            prod['op'],
            AbstractSyntaxTree(grammar.get_prod_by_ctr_name(op_name)))

        left_node = lf_node.children[0]
        left_ast_node = logical_form_to_ast(grammar, left_node)
        left_field = RealizedField(prod['left'], left_ast_node)

        right_node = lf_node.children[1]
        right_ast_node = logical_form_to_ast(grammar, right_node)
        right_field = RealizedField(prod['right'], right_ast_node)

        ast_node = AbstractSyntaxTree(prod,
                                      [op_field, left_field, right_field])
    elif lf_node.name in [
            'jet', 'flight', 'from_airport', 'airport', 'airline',
            'airline_name', 'class_type', 'aircraft_code', 'aircraft_code:t',
            'from', 'to', 'day', 'month', 'year', 'arrival_time', 'limousine',
            'departure_time', 'meal', 'meal:t', 'meal_code', 'during_day',
            'tomorrow', 'daily', 'time_elapsed', 'time_zone_code',
            'booking_class:t', 'booking_class', 'economy', 'ground_fare',
            'class_of_service', 'capacity', 'weekday', 'today', 'turboprop',
            'aircraft', 'air_taxi_operation', 'month_return', 'day_return',
            'day_number_return', 'minimum_connection_time',
            'during_day_arrival', 'connecting', 'minutes_distant', 'named',
            'miles_distant', 'approx_arrival_time', 'approx_return_time',
            'approx_departure_time', 'has_stops', 'day_after_tomorrow',
            'manufacturer', 'discounted', 'overnight', 'nonstop', 'has_meal',
            'round_trip', 'oneway', 'loc:t', 'ground_transport', 'to_city',
            'flight_number', 'equals:t', 'abbrev', 'equals', 'rapid_transit',
            'stop_arrival_time', 'arrival_month', 'cost', 'fare', 'services',
            'fare_basis_code', 'rental_car', 'city', 'stop', 'day_number',
            'days_from_today', 'after_day', 'before_day', 'airline:e', 'stops',
            'month_arrival', 'day_number_arrival', 'day_arrival', 'taxi',
            'next_days', 'restriction_code', 'tomorrow_arrival', 'tonight',
            'population:i', 'state:t', 'next_to:t', 'elevation:i', 'size:i',
            'capital:t', 'len:i', 'city:t', 'named:t', 'river:t', 'place:t',
            'capital:c', 'major:t', 'town:t', 'mountain:t', 'lake:t', 'area:i',
            'density:i', 'high_point:t', 'elevation:t', 'population:t', 'in:t'
    ]:
        # expr -> Apply(pred predicate, expr* arguments)
        prod = grammar.get_prod_by_ctr_name('Apply')

        pred_field = RealizedField(prod['predicate'], value=lf_node.name)

        arg_ast_nodes = []
        for arg_node in lf_node.children:
            arg_ast_node = logical_form_to_ast(grammar, arg_node)
            arg_ast_nodes.append(arg_ast_node)
        arg_field = RealizedField(prod['arguments'], arg_ast_nodes)

        ast_node = AbstractSyntaxTree(prod, [pred_field, arg_field])
    elif lf_node.name.startswith('$'):
        prod = grammar.get_prod_by_ctr_name('Variable')
        ast_node = AbstractSyntaxTree(
            prod, [RealizedField(prod['variable'], value=lf_node.name)])
    elif ':ap' in lf_node.name or ':fb' in lf_node.name or ':mf' in lf_node.name or \
                    ':me' in lf_node.name or ':cl' in lf_node.name or ':pd' in lf_node.name or \
                    ':dc' in lf_node.name or ':al' in lf_node.name or \
                    lf_node.name in ['yr0', 'do0', 'fb1', 'rc0', 'ci0', 'fn0', 'ap0', 'al1', 'al2', 'ap1', 'ci1',
                                     'ci2', 'ci3', 'st0', 'ti0', 'ti1', 'da0', 'da1', 'da2', 'da3', 'da4', 'al0',
                                     'fb0', 'dn0', 'dn1', 'mn0', 'ac0', 'fn1', 'st1', 'st2',
                                     'c0', 'm0', 's0', 'r0', 'n0', 'co0', 'usa:co', 'death_valley:lo', 's1',
                                     'colorado:n']:
        prod = grammar.get_prod_by_ctr_name('Entity')
        ast_node = AbstractSyntaxTree(
            prod, [RealizedField(prod['entity'], value=lf_node.name)])
    elif lf_node.name.endswith(':i') or lf_node.name.endswith(':hr'):
        prod = grammar.get_prod_by_ctr_name('Number')
        ast_node = AbstractSyntaxTree(
            prod, [RealizedField(prod['number'], value=lf_node.name)])
    elif lf_node.name == 'the':
        # expr -> The(var variable, expr body)
        prod = grammar.get_prod_by_ctr_name('The')

        var_node = lf_node.children[0]
        var_field = RealizedField(prod['variable'], var_node.name)

        body_node = lf_node.children[1]
        body_ast_node = logical_form_to_ast(grammar, body_node)
        body_field = RealizedField(prod['body'], body_ast_node)

        ast_node = AbstractSyntaxTree(prod, [var_field, body_field])
    elif lf_node.name == 'exists' or lf_node.name == 'max' or lf_node.name == 'min' or lf_node.name == 'count':
        # expr -> Exists(var variable, expr body)
        prod = grammar.get_prod_by_ctr_name(lf_node.name.title())

        var_node = lf_node.children[0]
        var_field = RealizedField(prod['variable'], var_node.name)

        body_node = lf_node.children[1]
        body_ast_node = logical_form_to_ast(grammar, body_node)
        body_field = RealizedField(prod['body'], body_ast_node)

        ast_node = AbstractSyntaxTree(prod, [var_field, body_field])
    else:
        raise NotImplementedError

    return ast_node
Ejemplo n.º 23
0
class Hypothesis(object):
    """A partial ASDL derivation built one transition action at a time.

    The hypothesis owns a possibly incomplete ``AbstractSyntaxTree`` and
    tracks its *frontier*, the next field to fill (see
    ``update_frontier_info``). Each ``apply_action`` call fills or closes
    that field and advances the time step.

    Attributes:
        tree: root ``AbstractSyntaxTree``; ``None`` until the first action.
        actions: every action applied so far, in order.
        score: hypothesis score; initialised and copied here but never
            updated by this class.
        frontier_node: node that owns ``frontier_field``, or ``None``.
        frontier_field: next field to fill, or ``None`` when no unfinished
            field remains.
        t: number of actions applied so far (current time step).
    """

    def __init__(self):
        self.tree = None
        self.actions = []
        self.score = 0.
        self.frontier_node = None
        self.frontier_field = None
        # Tokens of a 'string'-typed primitive field whose stop signal has
        # not been generated yet.
        self._value_buffer = []

        # record the current time step
        self.t = 0

    def apply_action(self, action):
        """Apply ``action`` at the current frontier, then record it and advance ``t``.

        The first action must be an ``ApplyRuleAction``; it creates the root.
        After that:

        * On a composite-typed frontier field, ``ApplyRuleAction`` adds a new
          child node and ``ReduceAction`` closes the field.
        * On a primitive field, ``GenTokenAction`` adds a token and
          ``ReduceAction`` closes the field. Fields whose type is named
          ``string`` buffer tokens until a stop-signal token, then store the
          buffered tokens joined by single spaces.

        If a tree exists but no frontier is left, the action is only recorded.

        Raises:
            AssertionError: if the first action is not an ``ApplyRuleAction``,
                or a ``ReduceAction`` targets a field whose cardinality is
                neither ``optional`` nor ``multiple``.
            ValueError: if the action kind is not valid for the frontier field.
        """
        if self.tree is None:
            assert isinstance(action, ApplyRuleAction), (
                'Invalid action [%s], only ApplyRule action is valid '
                'at the beginning of decoding' % (action,))

            self.tree = AbstractSyntaxTree(action.production)
            self.update_frontier_info()
        elif self.frontier_node:
            if isinstance(self.frontier_field.type, ASDLCompositeType):
                if isinstance(action, ApplyRuleAction):
                    field_value = AbstractSyntaxTree(action.production)
                    # Remember the time step at which this node was created.
                    field_value.created_time = self.t
                    self.frontier_field.add_value(field_value)
                    self.update_frontier_info()
                elif isinstance(action, ReduceAction):
                    self._reduce_frontier_field()
                else:
                    raise ValueError('Invalid action [%s] on field [%s]' %
                                     (action, self.frontier_field))
            else:  # fill in a primitive field
                if isinstance(action, GenTokenAction):
                    # Only fields of type 'string' need the </primitive> stop
                    # signal; any other primitive takes exactly one token.
                    end_primitive = False
                    if self.frontier_field.type.name == 'string':
                        if action.is_stop_signal():
                            self.frontier_field.add_value(' '.join(
                                self._value_buffer))
                            self._value_buffer = []

                            end_primitive = True
                        else:
                            self._value_buffer.append(action.token)
                    else:
                        self.frontier_field.add_value(action.token)
                        end_primitive = True

                    # 'multiple' fields stay open for more values until an
                    # explicit ReduceAction closes them.
                    if end_primitive and self.frontier_field.cardinality in (
                            'single', 'optional'):
                        self.frontier_field.set_finish()
                        self.update_frontier_info()

                elif isinstance(action, ReduceAction):
                    self._reduce_frontier_field()
                else:
                    raise ValueError(
                        'Can only invoke GenToken or Reduce actions on primitive fields'
                    )

        self.t += 1
        self.actions.append(action)

    def _reduce_frontier_field(self):
        """Close the current frontier field in response to a ``ReduceAction``."""
        assert self.frontier_field.cardinality in ('optional', 'multiple'), (
            'Reduce action can only be applied on field with optional or '
            'multiple cardinality')
        self.frontier_field.set_finish()
        self.update_frontier_info()

    def update_frontier_info(self):
        """Recompute ``frontier_node`` and ``frontier_field`` from ``tree``.

        The fields are walked depth-first and left to right. A composite
        field's existing child nodes are searched before the field itself is
        considered. The first unfinished field found becomes the frontier.
        Both attributes are ``None`` when nothing is left, or when there is
        no tree.
        """
        def _find_frontier_node_and_field(tree_node):
            if not tree_node:
                return None
            for field in tree_node.fields:
                # Descend into already-created children first, so the
                # deepest unfinished field is filled before its ancestors.
                if isinstance(field.type, ASDLCompositeType) and field.value:
                    if field.cardinality in ('single', 'optional'):
                        child_nodes = [field.value]
                    else:
                        child_nodes = field.value

                    for child_node in child_nodes:
                        result = _find_frontier_node_and_field(child_node)
                        if result:
                            return result

                # Every child of this field is complete; now check the field.
                if not field.finished:
                    return tree_node, field

            return None

        frontier_info = _find_frontier_node_and_field(self.tree)
        if frontier_info:
            self.frontier_node, self.frontier_field = frontier_info
        else:
            self.frontier_node, self.frontier_field = None, None

    def clone_and_apply_action(self, action):
        """Return a copy of this hypothesis with ``action`` applied; ``self`` is untouched."""
        new_hyp = self.copy()
        new_hyp.apply_action(action)

        return new_hyp

    def copy(self):
        """Return an independent copy of this hypothesis.

        The tree is copied with ``tree.copy()``, and the action list and
        token buffer are shallow-copied. The frontier is recomputed against
        the copied tree, so it refers to the copy's nodes.
        """
        new_hyp = Hypothesis()
        if self.tree:
            new_hyp.tree = self.tree.copy()

        new_hyp.actions = list(self.actions)
        new_hyp.score = self.score
        new_hyp._value_buffer = list(self._value_buffer)
        new_hyp.t = self.t

        new_hyp.update_frontier_info()

        return new_hyp

    @property
    def completed(self):
        """``True`` once a tree exists and it has no unfinished field left."""
        return bool(self.tree) and self.frontier_field is None
Ejemplo n.º 24
0
def pdf_to_ast(grammar, x, tr, max_depth=300):
    """Convert a PDF object graph into an ASDL ``AbstractSyntaxTree``.

    Supported inputs are listed in the order they are checked. The class
    names are those used below; presumably they come from pdfrw's object
    model, so confirm against the file's imports.

    * ``PdfDict``   -> ``PdfDict(args)``. Each key gives one
      ``Apply(name, op)`` whose ``op`` is the converted value. A trailing
      ``PdfString`` holding ``str(x.stream)`` is added when the dict has a
      stream. Values of ``/Parent``, ``/P``, ``/Dest`` and ``/Prev`` are not
      followed, and their ``op`` is left empty.
    * ``PdfObject`` -> ``PdfObject(value=str(x))``.
    * ``PdfArray``  -> ``PdfArray(args)`` of the converted dicts if the array
      contains any ``PdfDict``. Otherwise it becomes ``PdfList(value=...)``,
      a tuple of the ``str()`` of its name/object elements.
    * ``PdfString`` / ``BasePdfName`` -> same-named constructor with ``str(x)``.

    Args:
        grammar: ASDL grammar providing ``get_prod_by_ctr_name``.
        x: the PDF object to convert.
        tr: list of dict keys on the path from the root to ``x``, used as a
            depth guard. Keys are pushed while recursing into a dict value
            and popped afterwards, also when the recursive call raises.
            Array elements are recursed into without extending ``tr``.
        max_depth: conversion is refused once ``len(tr)`` reaches this value
            (default 300).

    Returns:
        The ``AbstractSyntaxTree`` for ``x``.

    Raises:
        NotImplementedError: when ``len(tr)`` reaches ``max_depth``, for a
            ``PdfArray`` directly inside a ``PdfArray``, or for an
            unsupported object type.
    """
    if len(tr) >= max_depth:
        raise NotImplementedError(
            'PDF object nesting reached max_depth=%d' % max_depth)

    if isinstance(x, PdfDict):
        prod = grammar.get_prod_by_ctr_name('PdfDict')

        ast_nodes = []
        for key in x:
            apply_prod = grammar.get_prod_by_ctr_name('Apply')
            name_field = RealizedField(apply_prod['name'], value=str(key))

            if key in ('/Parent', '/P', '/Dest', '/Prev'):
                # Not followed (presumably these link back into the object
                # graph and would create cycles); the op field stays empty.
                op_field = RealizedField(apply_prod['op'])
            else:
                tr.append(key)
                try:
                    value_node = pdf_to_ast(grammar, x[key], tr,
                                            max_depth=max_depth)
                finally:
                    # Restore the path even if the recursive call raised.
                    tr.pop()
                op_field = RealizedField(apply_prod['op'], value=value_node)

            ast_nodes.append(
                AbstractSyntaxTree(apply_prod, [name_field, op_field]))

        if x.stream:
            stream_prod = grammar.get_prod_by_ctr_name('PdfString')
            stream_field = RealizedField(stream_prod['value'],
                                         value=str(x.stream))
            ast_nodes.append(AbstractSyntaxTree(stream_prod, [stream_field]))

        return AbstractSyntaxTree(prod,
                                  [RealizedField(prod['args'], ast_nodes)])

    if isinstance(x, PdfObject):
        prod = grammar.get_prod_by_ctr_name('PdfObject')
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['value'], value=str(x))])

    if isinstance(x, PdfArray):
        dict_nodes = []
        list_nodes = []
        for item in x:
            if isinstance(item, PdfDict):
                dict_nodes.append(
                    pdf_to_ast(grammar, item, tr, max_depth=max_depth))
            elif isinstance(item, PdfArray):
                raise NotImplementedError('nested PdfArray is not supported')
            elif isinstance(item, (BasePdfName, PdfObject)):
                list_nodes.append(str(item))
            # NOTE(review): elements of any other type are skipped silently;
            # confirm this is intended.

        if dict_nodes:
            # NOTE(review): name/object elements collected in list_nodes are
            # discarded when the array also contains dicts.
            prod = grammar.get_prod_by_ctr_name('PdfArray')
            return AbstractSyntaxTree(
                prod, [RealizedField(prod['args'], dict_nodes)])

        prod = grammar.get_prod_by_ctr_name('PdfList')
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['value'], value=tuple(list_nodes))])

    if isinstance(x, PdfString):
        prod = grammar.get_prod_by_ctr_name('PdfString')
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['value'], value=str(x))])

    if isinstance(x, BasePdfName):
        prod = grammar.get_prod_by_ctr_name('BasePdfName')
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['value'], value=str(x))])

    raise NotImplementedError('unsupported PDF object type: %r' % (type(x),))
def prolog_expr_to_ast_helper(grammar, prolog_tokens, start_idx=0):
    """Parse a Prolog-style conjunction/disjunction beginning at ``start_idx``.

    An optional leading ``(`` is consumed. Terms are then parsed one after
    another until a ``)`` (which is consumed) or the end of ``prolog_tokens``:

    * the negation token ``\\+`` followed by a term -> ``Not(expr argument)``;
    * ``(`` -> a nested group, parsed by a recursive call;
    * anything else -> handed to ``prolog_node_to_ast``.

    Terms separated by ``,`` are collected and, if there is more than one,
    wrapped in ``And(expr* arguments)``. A term followed directly by a token
    other than ``,``, ``;`` or ``)`` is also collected, because the loop just
    parses that token as the next term.

    On ``;``, everything parsed so far becomes the ``left`` operand of
    ``Or(expr left, expr right)``, wrapped in ``And`` when there are several
    terms. The ``right`` operand is parsed by a recursive call that starts
    right after the ``;``.

    Args:
        grammar: ASDL grammar providing ``get_prod_by_ctr_name``.
        prolog_tokens: sequence of token strings.
        start_idx: index at which parsing starts.

    Returns:
        ``(ast_node, next_idx)``, where ``next_idx`` is the index of the first
        token not consumed.

    NOTE(review): after an ``Or`` is built, the loop continues only when the
    next token is neither ``)`` nor end of input. The next iteration then
    starts at that token without consuming a separator, so a ``,`` there
    would be passed to ``prolog_node_to_ast``. This can happen, for example,
    when the right operand begins with ``(`` and therefore stops at its own
    ``)``. Presumably the data never reaches that state; TODO confirm.
    """
    i = start_idx
    # A leading '(' opens this group; the matching ')' is consumed below.
    if prolog_tokens[i] == '(':
        i += 1

    parsed_nodes = []
    while True:
        if prolog_tokens[i] == '\\+':
            # expr -> Not(expr argument)
            prod = grammar.get_prod_by_ctr_name('Not')
            i += 1
            if prolog_tokens[i] == '(':
                arg_ast_node, end_idx = prolog_expr_to_ast_helper(
                    grammar, prolog_tokens, i)
            else:
                arg_ast_node, end_idx = prolog_node_to_ast(
                    grammar, prolog_tokens, i)
            i = end_idx

            assert arg_ast_node.production.type.name == 'expr'
            ast_node = AbstractSyntaxTree(
                prod, [RealizedField(prod['argument'], arg_ast_node)])

            parsed_nodes.append(ast_node)
        elif prolog_tokens[i] == '(':
            # nested group
            ast_node, end_idx = prolog_expr_to_ast_helper(
                grammar, prolog_tokens, i)
            parsed_nodes.append(ast_node)
            i = end_idx
        else:
            # single term (predicate / atom), parsed by the node-level helper
            ast_node, end_idx = prolog_node_to_ast(grammar, prolog_tokens, i)
            parsed_nodes.append(ast_node)
            i = end_idx

        # End of input or end of this group: stop collecting terms.
        if i >= len(prolog_tokens): break
        if prolog_tokens[i] == ')':
            i += 1
            break

        if prolog_tokens[i] == ',':
            # conjunction separator: the terms are combined into And below
            i += 1
        elif prolog_tokens[i] == ';':
            # disjunction: everything parsed so far is the left operand of Or
            prod = grammar.get_prod_by_ctr_name('Or')

            assert parsed_nodes
            if len(parsed_nodes) == 1:
                left_ast_node = parsed_nodes[0]
            else:
                left_expr_prod = grammar.get_prod_by_ctr_name('And')
                left_ast_node = AbstractSyntaxTree(
                    left_expr_prod,
                    [RealizedField(left_expr_prod['arguments'], parsed_nodes)])
                parsed_nodes = []

            # get the right ast node
            i += 1
            right_ast_node, end_idx = prolog_expr_to_ast_helper(
                grammar, prolog_tokens, i)

            ast_node = AbstractSyntaxTree(prod, [
                RealizedField(prod['left'], left_ast_node),
                RealizedField(prod['right'], right_ast_node)
            ])

            i = end_idx
            # The Or node replaces everything collected so far.
            parsed_nodes = [ast_node]

            if i >= len(prolog_tokens): break
            if prolog_tokens[i] == ')':
                i += 1
                break

    assert parsed_nodes
    # Several remaining terms form an implicit conjunction.
    if len(parsed_nodes) > 1:
        prod = grammar.get_prod_by_ctr_name('And')
        return_node = AbstractSyntaxTree(
            prod, [RealizedField(prod['arguments'], parsed_nodes)])
    else:
        return_node = parsed_nodes[0]

    return return_node, i