def apply_action(self, action):
    """Apply one decoding action to this hypothesis and advance the time step.

    * If no tree exists yet, ``action`` must be an ``ApplyRuleAction``; its
      production becomes the root of the tree.
    * Otherwise, if there is a frontier node, the action is applied to the
      current frontier field: ``ApplyRuleAction``/``ReduceAction`` on composite
      fields, ``GenTokenAction``/``ReduceAction`` on primitive fields.
    * If the tree exists but there is no frontier node, only the bookkeeping
      at the end (``self.t`` and ``self.actions``) is updated.

    Raises:
        AssertionError: if the first action is not an ``ApplyRuleAction`` or a
            ``ReduceAction`` targets a field whose cardinality is ``single``.
        ValueError: if the action type is not valid for the frontier field.
    """
    if self.tree is None:
        # The placeholder is now actually filled in; previously the raw
        # '%s' was emitted because the message was never formatted.
        assert isinstance(action, ApplyRuleAction), (
            'Invalid action [%s], only ApplyRule action is valid '
            'at the beginning of decoding' % (action,))

        self.tree = AbstractSyntaxTree(action.production)
        self.update_frontier_info()
    elif self.frontier_node:
        if isinstance(self.frontier_field.type, ASDLCompositeType):
            if isinstance(action, ApplyRuleAction):
                field_value = AbstractSyntaxTree(action.production)
                field_value.created_time = self.t
                self.frontier_field.add_value(field_value)
                self.update_frontier_info()
            elif isinstance(action, ReduceAction):
                assert self.frontier_field.cardinality in ('optional', 'multiple'), \
                    'Reduce action can only be ' \
                    'applied on field with multiple ' \
                    'cardinality'
                self.frontier_field.set_finish()
                self.update_frontier_info()
            else:
                raise ValueError('Invalid action [%s] on field [%s]' %
                                 (action, self.frontier_field))
        else:
            # fill in a primitive field
            if isinstance(action, GenTokenAction):
                # only field of type string requires termination signal </primitive>
                end_primitive = False
                if self.frontier_field.type.name == 'string':
                    if action.is_stop_signal():
                        # flush the buffered tokens as one space-joined value
                        self.frontier_field.add_value(' '.join(
                            self._value_buffer))
                        self._value_buffer = []
                        end_primitive = True
                    else:
                        self._value_buffer.append(action.token)
                else:
                    self.frontier_field.add_value(action.token)
                    end_primitive = True

                # fields with 'multiple' cardinality stay open until a Reduce
                if end_primitive and self.frontier_field.cardinality in (
                        'single', 'optional'):
                    self.frontier_field.set_finish()
                    self.update_frontier_info()
            elif isinstance(action, ReduceAction):
                assert self.frontier_field.cardinality in ('optional', 'multiple'), \
                    'Reduce action can only be ' \
                    'applied on field with multiple ' \
                    'cardinality'
                self.frontier_field.set_finish()
                self.update_frontier_info()
            else:
                raise ValueError(
                    'Can only invoke GenToken or Reduce actions on primitive fields'
                )

    self.t += 1
    self.actions.append(action)
def build_change_graph(old_ast: AbstractSyntaxTree, new_ast: AbstractSyntaxTree):
    """Link sub-trees shared between two ASTs and walk the resulting graph.

    Both roots are copied and every node id is prefixed with ``old-`` /
    ``new-``.  Each node of the new tree that ``old_ast.find_node`` locates in
    the old tree is replaced by the old node and the pair of ids is recorded.
    Finally both trees are traversed to collect ``(parent id, child id)``
    edges.

    NOTE(review): the collected links and edges are not returned; the
    function currently always returns ``None``.
    """

    def _prefix_ids(node, prefix=''):
        node.id = f'{prefix}-{node.id}'
        if not isinstance(node, AbstractSyntaxNode):
            return
        for field in node.fields:
            for child in field.as_value_list:
                _prefix_ids(child, prefix)

    old_root = old_ast.root_node.copy()
    _prefix_ids(old_root, 'old')

    new_root = new_ast.root_node.copy()
    _prefix_ids(new_root, 'new')

    old_ast = AbstractSyntaxTree(old_root)
    new_ast = AbstractSyntaxTree(new_root)

    equality_links = []

    def _link_shared_subtrees(new_node):
        match = old_ast.find_node(new_node)
        if match:
            _, old_node = match
            new_node.parent_field.replace(new_node, old_node)
            # register this link
            equality_links.append((new_node.id, old_node.id))
            return

        for field in new_node.fields:
            if field.type.is_composite:
                for child in field.as_value_list:
                    _link_shared_subtrees(child)

    _link_shared_subtrees(new_ast.root_node)

    seen_ids = set()
    adjacency_list = []

    def _collect_edges(node, parent_node):
        # the edge is recorded even when the node was already expanded
        if parent_node:
            adjacency_list.append((parent_node.id, node.id))

        if node.id in seen_ids:
            return

        if isinstance(node, AbstractSyntaxNode):
            for field in node.fields:
                for child in field.as_value_list:
                    _collect_edges(child, node)

        seen_ids.add(node.id)

    _collect_edges(old_ast.root_node, None)
    _collect_edges(new_ast.root_node, None)
def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
    """Build the ORDER BY node and attach it to ``orderby_field``.

    The constructor is ``Asc``/``Desc`` (chosen by ``orderby_clause[0]``),
    with a ``Limit`` suffix whenever ``limit`` is not ``None``.  Every entry
    of ``orderby_clause[1]`` contributes one column unit to the node's first
    field.
    """
    direction = 'Asc' if orderby_clause[0] == 'asc' else 'Desc'
    suffix = '' if limit is None else 'Limit'
    ast_node = AbstractSyntaxTree(
        self.grammar.get_prod_by_ctr_name(direction + suffix))

    target_field = ast_node.fields[0]
    for val_unit in orderby_clause[1]:
        col_unit_ast = self.parse_col_unit(val_unit[1])
        target_field.add_value(col_unit_ast)

    orderby_field.add_value(ast_node)
def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
    """Build the GROUP BY node and attach it to ``groupby_field``.

    At most the first two grouping columns are kept.  With a HAVING clause
    the ``OneHaving``/``TwoHaving`` constructors are used and the parsed
    conditions go into the node's last field; otherwise
    ``OneNoHaving``/``TwoNoHaving`` are used.

    NOTE(review): the constructor index is derived from the clause length, so
    this presumably expects a non-empty ``groupby_clause`` -- confirm with the
    caller.
    """
    ctr_names = ['OneNoHaving', 'TwoNoHaving', 'OneHaving', 'TwoHaving']
    kept = min(2, len(groupby_clause))
    has_having = bool(having_clause)

    ctr_index = kept + 1 if has_having else kept - 1
    ast_node = AbstractSyntaxTree(
        self.grammar.get_prod_by_ctr_name(ctr_names[ctr_index]))

    if has_having:
        ast_node.fields[-1].add_value(self.parse_conds(having_clause))

    for slot, col_unit in enumerate(groupby_clause[:2]):
        ast_node.fields[slot].add_value(self.parse_col_unit(col_unit))

    groupby_field.add_value(ast_node)
def lisp_node_to_ast(grammar, lisp_tokens, start_idx):
    """Parse the node starting at ``lisp_tokens[start_idx]`` into an AST.

    The token at ``start_idx`` is either one of the known predicate names, in
    which case an ``apply`` node is built from the arguments that follow it,
    or a name ending in ``id0``/``id1``/``id2``, in which case a ``Literal``
    node is built.

    Returns:
        A ``(ast_node, i)`` tuple; ``i`` is the token index reached after the
        node was parsed.

    Raises:
        NotImplementedError: if the token is neither a known predicate nor a
            recognised literal.
    """
    node_name = lisp_tokens[start_idx]
    i = start_idx
    if node_name in [
            '_eq', 'select', 'filter', '_parts', '_time', '_inspect',
            'between', '_and', '_or', 'renew', 'cancel'
    ]:
        # it's a predicate
        prod = grammar.get_prod_by_ctr_name('apply')
        pred_field = RealizedField(prod['predicate'], value=node_name)

        arg_ast_nodes = []
        while True:
            # step past the predicate name / the previously handled token
            i += 1
            lisp_token = lisp_tokens[i]
            if lisp_token == "(":
                # nested expression: delegate; end_idx is the index the
                # helper stopped at
                arg_ast_node, end_idx = lisp_expr_to_ast_helper(
                    grammar, lisp_tokens, i)
            elif lisp_token == ")":
                # closing parenthesis of this predicate: consume it and stop
                i += 1
                break
            else:
                # any other token becomes a Literal argument; end_idx stays
                # on that token
                prod1 = grammar.get_prod_by_ctr_name('Literal')
                arg_ast_node, end_idx = AbstractSyntaxTree(
                    prod1,
                    [RealizedField(prod1['literal'],
                                   value=lisp_tokens[i])]), i

            arg_ast_nodes.append(arg_ast_node)

            i = end_idx
            if i >= len(lisp_tokens): break
            if lisp_tokens[i] == ')':
                # the argument list is closed: consume ')' and stop
                i += 1
                break

        arg_field = RealizedField(prod['arguments'], arg_ast_nodes)
        ast_node = AbstractSyntaxTree(prod, [pred_field, arg_field])
    # NOTE(review): 'periodid0'/'periodid1' already end with 'id0'/'id1', so
    # the membership test below looks redundant.
    elif node_name.endswith('id0') or node_name.endswith('id1') or node_name.endswith('id2') \
            or node_name in ['periodid0', 'periodid1']:
        # it's a literal
        prod = grammar.get_prod_by_ctr_name('Literal')
        ast_node = AbstractSyntaxTree(
            prod, [RealizedField(prod['literal'], value=node_name)])
        i += 1
    else:
        raise NotImplementedError

    return ast_node, i
def prolog_node_to_ast(grammar, prolog_tokens, start_idx):
    """Parse the prolog term starting at ``prolog_tokens[start_idx]``.

    A known predicate name is parsed together with its parenthesised,
    comma-separated arguments into an ``Apply`` node; a known variable name
    becomes a ``Variable`` node; a recognised constant becomes a ``Literal``
    node.

    Returns:
        ``(ast_node, next_index)`` where ``next_index`` is the index right
        after the parsed term.

    Raises:
        NotImplementedError: if the token is not recognised.
        AssertionError: if a predicate is not followed by ``(`` or its
            arguments are not separated by ``,``.
    """
    predicate_names = (
        'job', 'language', 'loc', 'req_deg', 'application', 'area',
        'company', 'des_deg', 'des_exp', 'platform', 'recruiter', 'req_exp',
        'salary_greater_than', 'salary_less_than', 'title',
    )
    variable_names = ('ANS', 'X', 'A', 'B', 'P', 'J')
    literal_suffixes = ('id0', 'id1', 'id2')
    literal_names = ('20', 'hour', 'num_salary', 'year', 'year0', 'year1',
                     'month')

    token = prolog_tokens[start_idx]
    cursor = start_idx

    if token in predicate_names:
        prod = grammar.get_prod_by_ctr_name('Apply')
        pred_field = RealizedField(prod['predicate'], value=token)

        cursor += 1
        assert prolog_tokens[cursor] == '('

        arguments = []
        while True:
            # skip '(' on the first pass and ',' on later passes
            argument, cursor = prolog_node_to_ast(grammar, prolog_tokens,
                                                  cursor + 1)
            arguments.append(argument)

            if cursor >= len(prolog_tokens):
                break
            if prolog_tokens[cursor] == ')':
                cursor += 1
                break
            assert prolog_tokens[cursor] == ','

        arg_field = RealizedField(prod['arguments'], arguments)
        return AbstractSyntaxTree(prod, [pred_field, arg_field]), cursor

    if token in variable_names:
        prod = grammar.get_prod_by_ctr_name('Variable')
        variable = AbstractSyntaxTree(
            prod, [RealizedField(prod['variable'], value=token)])
        return variable, cursor + 1

    if token.endswith(literal_suffixes) or token in literal_names:
        prod = grammar.get_prod_by_ctr_name('Literal')
        literal = AbstractSyntaxTree(
            prod, [RealizedField(prod['literal'], value=token)])
        return literal, cursor + 1

    raise NotImplementedError
def parse_select(self, select_clause: list, select_field: RealizedField):
    """Build the SELECT node and attach it to ``select_field``.

    ``select_clause[1]`` is the list of ``(agg, val_unit)`` pairs; the
    distinct flag in ``select_clause[0]`` is ignored.  At most five items are
    kept and the constructor (``SelectOne`` .. ``SelectFive``) is chosen from
    the number of items.  An aggregated item is stored as a ``Unary`` value
    unit whose column unit carries the aggregator.
    """
    items = select_clause[1]
    ctr_names = ['SelectOne', 'SelectTwo', 'SelectThree', 'SelectFour',
                 'SelectFive']
    ctr_name = ctr_names[min(5, len(items)) - 1]
    ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(ctr_name))

    for slot, (agg, val_unit) in enumerate(items):
        if slot >= 5:
            break
        if agg != 0:
            # MAX/MIN/COUNT/SUM/AVG: push the aggregator into the column unit
            unit_ast = AbstractSyntaxTree(
                self.grammar.get_prod_by_ctr_name('Unary'))
            aggregated_col_unit = [agg] + val_unit[1][1:]
            unit_ast.fields[0].add_value(
                self.parse_col_unit(aggregated_col_unit))
        else:
            unit_ast = self.parse_val_unit(val_unit)
        ast_node.fields[slot].add_value(unit_ast)

    select_field.add_value(ast_node)
def python_ast_to_asdl_ast(py_ast_node, grammar):
    """Recursively convert a Python ``ast`` node into an ASDL syntax tree.

    The production is looked up by the class name of ``py_ast_node``.  For
    every field of the production the attribute with the same name is read
    from the Python node; composite values are converted recursively and
    primitive values are stored as ``str``.  ``None`` values leave the
    realized field empty.
    """
    # node should be composite
    production = grammar.get_prod_by_ctr_name(type(py_ast_node).__name__)

    realized = []
    for field in production.fields:
        raw_value = getattr(py_ast_node, field.name)
        asdl_field = RealizedField(field)

        # compare against None explicitly: a legitimate value may be 0
        if raw_value is not None:
            if field.cardinality in ('single', 'optional'):
                candidates = [raw_value]
            else:
                # field with multiple cardinality
                candidates = raw_value

            if grammar.is_composite_type(field.type):
                for candidate in candidates:
                    asdl_field.add_value(
                        python_ast_to_asdl_ast(candidate, grammar))
            else:
                for candidate in candidates:
                    asdl_field.add_value(str(candidate))

        realized.append(asdl_field)

    return AbstractSyntaxTree(production, realized_fields=realized)
def lisp_expr_to_ast_helper(grammar, lisp_tokens, start_idx=0):
    """Parse a (possibly parenthesised) sequence of lisp nodes.

    A leading ``(`` at ``start_idx`` is skipped.  Nodes are then parsed one
    after another -- nested expressions recursively, everything else through
    ``lisp_node_to_ast`` -- until the tokens run out or a ``)`` is reached
    (which is not consumed here).  Several parsed nodes are combined under an
    ``And`` production; a single node is returned as is.

    Returns:
        ``(ast_node, index)`` where ``index`` is the position at which parsing
        stopped.
    """
    pos = start_idx + 1 if lisp_tokens[start_idx] == '(' else start_idx

    nodes = []
    while True:
        if lisp_tokens[pos] == '(':
            parse = lisp_expr_to_ast_helper
        else:
            parse = lisp_node_to_ast
        node, pos = parse(grammar, lisp_tokens, pos)
        nodes.append(node)

        if pos >= len(lisp_tokens) or lisp_tokens[pos] == ')':
            break

        if lisp_tokens[pos] == ' ':
            # a blank separates conjuncts ("and")
            pos += 1

    assert nodes
    if len(nodes) == 1:
        return nodes[0], pos

    prod = grammar.get_prod_by_ctr_name('And')
    conjunction = AbstractSyntaxTree(
        prod, [RealizedField(prod['arguments'], nodes)])
    return conjunction, pos
def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
    """Build the GROUP BY node and attach it to ``groupby_field``.

    With a non-empty ``having_clause`` a ``Having`` node is created and the
    parsed conditions are stored in its second field; otherwise a
    ``NoHaving`` node is created.  In both cases every column unit of
    ``groupby_clause`` is parsed and appended to the node's first field.

    The previous implementation also collected the column ids into a list
    that was never read; that dead loop has been removed.
    """
    if having_clause:
        ast_node = AbstractSyntaxTree(
            self.grammar.get_prod_by_ctr_name('Having'))
        col_units_field, having_fields = ast_node.fields
        having_fields.add_value(self.parse_conds(having_clause))
    else:
        ast_node = AbstractSyntaxTree(
            self.grammar.get_prod_by_ctr_name('NoHaving'))
        col_units_field = ast_node.fields[0]

    for col_unit in groupby_clause:
        col_units_field.add_value(self.parse_col_unit(col_unit))

    groupby_field.add_value(ast_node)
def parse_from(self, from_clause: dict, from_field: RealizedField):
    """Build the FROM node and attach it to ``from_field``.

    Only ``from_clause['table_units']`` is read; join conditions are ignored
    because the evaluation script does not look at them.  Plain tables use
    ``FromOneTable`` .. ``FromSixTable`` (at most six tables are kept, each
    stored as an ``int`` id); a nested query uses ``FromSQL``.

    Raises:
        AssertionError: if the first table unit is neither ``'table_unit'``
            nor ``'sql'``.
    """
    table_units = from_clause['table_units']
    unit_kind = table_units[0][0]

    if unit_kind != 'table_unit':
        assert unit_kind == 'sql'
        nested_sql = table_units[0][1]
        ast_node = AbstractSyntaxTree(
            self.grammar.get_prod_by_ctr_name('FromSQL'))
        ast_node.fields[0].add_value(self.parse_sql(nested_sql))
    else:
        ctr_names = ['FromOneTable', 'FromTwoTable', 'FromThreeTable',
                     'FromFourTable', 'FromFiveTable', 'FromSixTable']
        ctr_name = ctr_names[min(6, len(table_units)) - 1]
        ast_node = AbstractSyntaxTree(
            self.grammar.get_prod_by_ctr_name(ctr_name))
        for slot, (_, tab_id) in enumerate(table_units):
            if slot >= 6:
                break
            ast_node.fields[slot].add_value(int(tab_id))

    from_field.add_value(ast_node)
def parse_select(self, select_clause: list, select_field: RealizedField):
    """Parse every selected item and append it to ``select_field``.

    ``select_clause[1]`` is the list of ``(agg, val_unit)`` pairs.  An
    aggregated item (``agg != 0``) is stored as a ``Unary`` node wrapping an
    aggregator node (``Max``/``Min``/``Count``/``Sum``/``Avg``) whose only
    field holds the column id as an ``int``.  Non-aggregated items are
    delegated to ``self.parse_val_unit``.

    The cases ``agg(col_id1 op col_id2)`` and ``agg(col_id1) op agg(col_id2)``
    are ignored.

    The unused local list of unit operator names was removed.
    """
    # indexed by the aggregator id; index 0 ('None') is never looked up here
    agg_op_list = ['None', 'Max', 'Min', 'Count', 'Sum', 'Avg']

    for agg, val_unit in select_clause[1]:
        if agg != 0:
            # agg col_id
            ast_node = AbstractSyntaxTree(
                self.grammar.get_prod_by_ctr_name('Unary'))
            col_node = AbstractSyntaxTree(
                self.grammar.get_prod_by_ctr_name(agg_op_list[agg]))
            col_node.fields[0].add_value(int(val_unit[1][1]))
            ast_node.fields[0].add_value(col_node)
        else:
            # binary_op col_id1 col_id2
            ast_node = self.parse_val_unit(val_unit)
        select_field.add_value(ast_node)
def regex_ast_to_asdl_ast(grammar, reg_ast):
    """Recursively convert a regex AST node into an ASDL syntax tree.

    Leaf nodes (no children) map to ``CharClass`` or ``Const`` depending on
    their ``node_class``.  Inner nodes are mapped to a rule through
    ``_NODE_CLASS_TO_RULE`` and converted as unary, binary or
    ``RepeatAtleast`` nodes.

    Raises:
        ValueError: if the node class / rule is not supported.
    """
    if not reg_ast.children:
        if reg_ast.node_class in [
                "<num>", "<let>", "<vow>", "<low>", "<cap>", "<any>"
        ]:
            leaf_rule = "CharClass"
        elif reg_ast.node_class in ["<m0>", "<m1>", "<m2>", "<m3>"]:
            leaf_rule = "Const"
        else:
            raise ValueError("wrong node class", reg_ast.node_class)
        leaf_prod = grammar.get_prod_by_ctr_name(leaf_rule)
        return AbstractSyntaxTree(
            leaf_prod, [RealizedField(leaf_prod['arg'], reg_ast.node_class)])

    rule = _NODE_CLASS_TO_RULE[reg_ast.node_class]
    prod = grammar.get_prod_by_ctr_name(rule)

    if rule in ["Not", "Star", "StartWith", "EndWith", "Contain"]:
        # unary
        operand = regex_ast_to_asdl_ast(grammar, reg_ast.children[0])
        return AbstractSyntaxTree(prod,
                                  [RealizedField(prod['arg'], operand)])

    if rule in ["Concat", "And", "Or"]:
        lhs = regex_ast_to_asdl_ast(grammar, reg_ast.children[0])
        rhs = regex_ast_to_asdl_ast(grammar, reg_ast.children[1])
        return AbstractSyntaxTree(prod, [
            RealizedField(prod['left'], lhs),
            RealizedField(prod['right'], rhs)
        ])

    if rule in ["RepeatAtleast"]:
        operand = regex_ast_to_asdl_ast(grammar, reg_ast.children[0])
        # the repetition count is stored as a string primitive
        count_field = RealizedField(prod['k'], str(reg_ast.params[0]))
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['arg'], operand), count_field])

    raise ValueError("wrong node class", reg_ast.node_class)
def parse_from(self, from_clause: dict, from_field: RealizedField):
    """Build the FROM node and attach it to ``from_field``.

    Only ``from_clause['table_units']`` is read; join conditions are ignored
    because the evaluation script does not look at them.  Plain tables are
    collected, as ``int`` ids, in the first field of a ``FromTable`` node; a
    nested query is parsed into a ``FromSQL`` node.

    Raises:
        AssertionError: if the first table unit is neither ``'table_unit'``
            nor ``'sql'``.
    """
    table_units = from_clause['table_units']
    unit_kind = table_units[0][0]

    if unit_kind != 'table_unit':
        assert unit_kind == 'sql'
        nested_sql = table_units[0][1]
        ast_node = AbstractSyntaxTree(
            self.grammar.get_prod_by_ctr_name('FromSQL'))
        ast_node.fields[0].add_value(self.parse_sql(nested_sql))
    else:
        ast_node = AbstractSyntaxTree(
            self.grammar.get_prod_by_ctr_name('FromTable'))
        tables_field = ast_node.fields[0]
        for _, table_id in table_units:
            tables_field.add_value(int(table_id))

    from_field.add_value(ast_node)
def copy_tree_field(tree: AbstractSyntaxTree, field: RealizedField, bool_w_dummy_reduce=False):
    """Copy ``tree`` and locate the counterpart of ``field`` in the copy.

    The path from ``field`` up to the root is recorded as alternating
    ``('field', index)`` / ``('node', index)`` steps and then replayed,
    root first, on the copied tree.

    Args:
        tree: tree to copy.
        field: a realized field inside ``tree``.
        bool_w_dummy_reduce: selects ``copy_and_reindex_w_dummy_reduce``
            instead of ``copy_and_reindex_wo_dummy_reduce`` for the copy.

    Returns:
        ``(new_tree, new_field)``.
    """
    if bool_w_dummy_reduce:
        new_tree = tree.copy_and_reindex_w_dummy_reduce()
    else:
        new_tree = tree.copy_and_reindex_wo_dummy_reduce()

    # collected leaf-to-root; replayed in reverse below
    steps = []
    current = field
    while current:
        owner = current.parent_node
        field_pos = find_by_id(owner.fields, current)
        assert field_pos != -1
        steps.append(('field', field_pos))

        owner_field = owner.parent_field
        if owner_field:
            node_pos = find_by_id(owner_field.as_value_list, owner)
            assert node_pos != -1
            steps.append(('node', node_pos))
        current = owner_field

    pointer = new_tree.root_node
    for kind, position in reversed(steps):
        if kind == 'field':
            assert isinstance(pointer, AbstractSyntaxNode)
            pointer = pointer.fields[position]
        else:
            assert kind == 'node'
            assert isinstance(pointer, RealizedField)
            pointer = pointer.as_value_list[position]

    assert isinstance(pointer, RealizedField)
    # equality with the source tree/field is not asserted: a DummyReduce may
    # have been inserted by the copy
    return new_tree, pointer
def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
    """Build the ORDER BY node and attach it to ``orderby_field``.

    The constructor name is composed of the number of kept columns
    (``One``/``Two``), the direction (``Asc``/``Desc``) and ``Limit`` when
    ``limit`` is truthy, e.g. ``OneAsc`` or ``TwoDescLimit``.  Only the first
    two value units are used, and only their column unit (``val_unit[1]``) is
    parsed.
    """
    val_units = orderby_clause[1]

    count_part = 'One' if len(val_units) == 1 else 'Two'
    direction_part = 'Asc' if orderby_clause[0] == 'asc' else 'Desc'
    limit_part = 'Limit' if limit else ''
    ctr_name = count_part + direction_part + limit_part

    ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(ctr_name))
    for slot, val_unit in enumerate(val_units[:2]):
        ast_node.fields[slot].add_value(self.parse_col_unit(val_unit[1]))

    orderby_field.add_value(ast_node)
def __init__(self, init_tree_w_dummy_reduce: AbstractSyntaxTree, bool_copy_subtree=False, tree=None,
             memory=None, memory_type='all_init_joint', init_code_tokens=None, length_norm=False):
    """Set up the edit state for a tree.

    Args:
        init_tree_w_dummy_reduce: the initial tree; it is copied to become the
            working tree when ``tree`` is not given.
        bool_copy_subtree: when true and ``memory`` is ``None``, the memory is
            derived from the initial tree according to ``memory_type``.
        tree: optional working tree to use as is.
        memory: optional precomputed memory, stored as is when given.
        memory_type: one of ``'all_init_joint'``, ``'all_init_distinct'`` or
            ``'deleted_distinct'``.
        init_code_tokens: stored unchanged.
        length_norm: stored unchanged.

    Raises:
        AssertionError: if ``memory_type`` is not one of the accepted values.
    """
    self.init_tree_w_dummy_reduce = init_tree_w_dummy_reduce
    self.bool_copy_subtree = bool_copy_subtree

    assert memory_type in ('all_init_joint', 'all_init_distinct',
                           'deleted_distinct')
    self.memory_type = memory_type

    self.init_code_tokens = init_code_tokens
    self.length_norm = length_norm

    self.tree = init_tree_w_dummy_reduce.copy() if tree is None else tree

    if not bool_copy_subtree or memory is not None:
        self.memory = memory
    elif self.memory_type == 'all_init_joint':
        self.memory = stack_subtrees(self.init_tree_w_dummy_reduce.root_node)
    elif self.memory_type == 'all_init_distinct':
        # keep the first occurrence of every sub-tree, preserving order
        distinct_nodes = []
        for node in stack_subtrees(self.init_tree_w_dummy_reduce.root_node):
            if node not in distinct_nodes:
                distinct_nodes.append(node)
        self.memory = distinct_nodes
    else:
        self.memory = []

    self.edits = []
    self.score_per_edit = []
    self.score = 0.
    self.repr2field = {}

    self.open_del_node_and_ids = []  # nodes available to delete
    self.open_add_fields = []  # fields open to add nodes
    # fields (esp. with single cardinality) grammatically need to fill
    self.restricted_frontier_fields = []
    self.update_frontier_info()

    self.last_edit_field_node = None  # trace the last edit

    # record the current time step
    self.t = 0
    self.stop_t = None
def parse_sql_unit(self, sql: dict):
    """Convert one SQL unit into an ``SQL`` syntax tree node.

    FROM and SELECT are always parsed; WHERE, GROUP BY (together with HAVING)
    and ORDER BY (together with LIMIT) are parsed only when the respective
    entry of ``sql`` is non-empty.

    Returns:
        The populated ``SQL`` node.
    """
    ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('SQL'))
    (from_field, select_field, where_field,
     groupby_field, orderby_field) = ast_node.fields

    self.parse_from(sql['from'], from_field)
    self.parse_select(sql['select'], select_field)

    where_clause = sql['where']
    if where_clause:
        self.parse_where(where_clause, where_field)

    groupby_clause = sql['groupBy']
    if groupby_clause:
        # if having clause is not empty, groupBy must exist
        self.parse_groupby(groupby_clause, sql['having'], groupby_field)

    orderby_clause = sql['orderBy']
    if orderby_clause:
        # if limit is not None, orderBy is not empty
        self.parse_orderby(orderby_clause, sql['limit'], orderby_field)

    return ast_node
def get_ast_from_json_obj(self, json_obj: Dict):
    """Build an ``AbstractSyntaxTree`` from a deserialized JSON object.

    Every entry carries a ``'Constructor'`` name.  ``'SyntaxToken'`` entries
    become ``SyntaxToken`` leaves (or ``None`` when their ``'Value'`` is
    null); all other entries become ``AbstractSyntaxNode`` objects built from
    their ``'Fields'``.  A list-valued field is wrapped in an intermediate
    node whose production is looked up by the field's type name.  Node ids
    are handed out in increasing order, starting at 0, as nodes are visited.
    """
    # FIXME: cyclic import
    from asdl.asdl_ast import AbstractSyntaxNode, RealizedField, SyntaxToken, AbstractSyntaxTree

    def _close_open_fields(node):
        # FIXME: have a global mark_finished method!
        for realized in node.fields:
            if realized.cardinality in ('multiple', 'optional'):
                realized._not_single_cardinality_finished = True

    def _build(entry, parent_field, next_id):
        if entry is None:
            return None, next_id

        ctr_name = entry['Constructor']

        # terminal case
        if ctr_name == 'SyntaxToken':
            token_value = entry['Value']
            if token_value is None:
                # optional field whose value is null
                return None, next_id
            token = SyntaxToken(parent_field.type,
                                token_value,
                                position=entry['Position'],
                                id=next_id)
            return token, next_id + 1

        raw_fields = entry['Fields']
        # the node takes its id before any of its descendants
        node_id = next_id
        next_id += 1

        prod = self.get_prod_by_ctr_name(ctr_name)
        realized_fields = []
        for field in prod.constructor.fields:
            raw_value = raw_fields[field.name]

            if isinstance(raw_value, list):
                assert 'SyntaxList' in field.type.name

                list_node_id = next_id
                next_id += 1

                list_prod = self.get_prod_by_ctr_name(field.type.name)
                element_field = list_prod.constructor.fields[0]

                elements = []
                for raw_element in raw_value:
                    element, next_id = _build(raw_element, element_field,
                                              next_id)
                    elements.append(element)

                child = AbstractSyntaxNode(
                    list_prod, [RealizedField(element_field, elements)],
                    id=list_node_id)
                _close_open_fields(child)
            else:
                # the child is an AST or terminal SyntaxNode
                child, next_id = _build(raw_value, field, next_id)

            realized_fields.append(RealizedField(field, child))

        node = AbstractSyntaxNode(prod, realized_fields, id=node_id)
        _close_open_fields(node)
        return node, next_id

    root, _ = _build(json_obj, None, 0)
    return AbstractSyntaxTree(root)
def streg_ast_to_asdl_ast(grammar, reg_ast):
    """Recursively convert a StructuredRegex AST node into an ASDL tree.

    Leaf nodes (no children) map to ``CharClass``, ``ConstSym`` or ``Token``
    depending on their ``node_class``.  Inner nodes are mapped to a rule
    through ``_NODE_CLASS_TO_RULE`` and converted according to the rule's
    arity and parameters.

    Raises:
        ValueError: if the node class / rule is not supported.
    """
    if not reg_ast.children:
        if reg_ast.node_class in [
                "<num>", "<let>", "<spec>", "<low>", "<cap>", "<any>"
        ]:
            leaf_rule = "CharClass"
        elif reg_ast.node_class.startswith(
                "const") and reg_ast.node_class[5:].isdigit():
            leaf_rule = "ConstSym"
        elif reg_ast.node_class.startswith(
                "<") and reg_ast.node_class.endswith(">"):
            leaf_rule = "Token"
        else:
            raise ValueError("wrong node class", reg_ast.node_class)
        leaf_prod = grammar.get_prod_by_ctr_name(leaf_rule)
        return AbstractSyntaxTree(
            leaf_prod, [RealizedField(leaf_prod['arg'], reg_ast.node_class)])

    rule = _NODE_CLASS_TO_RULE[reg_ast.node_class]
    prod = grammar.get_prod_by_ctr_name(rule)

    if rule in [
            "Not", "Star", "StartWith", "EndWith", "Contain", "NotCC",
            "Optional"
    ]:
        # unary
        operand = streg_ast_to_asdl_ast(grammar, reg_ast.children[0])
        return AbstractSyntaxTree(prod,
                                  [RealizedField(prod['arg'], operand)])

    if rule in ["Concat", "And", "Or"]:
        lhs = streg_ast_to_asdl_ast(grammar, reg_ast.children[0])
        rhs = streg_ast_to_asdl_ast(grammar, reg_ast.children[1])
        return AbstractSyntaxTree(prod, [
            RealizedField(prod['left'], lhs),
            RealizedField(prod['right'], rhs)
        ])

    if rule in ["RepeatAtleast", "Repeat"]:
        operand = streg_ast_to_asdl_ast(grammar, reg_ast.children[0])
        # repetition counts are stored as string primitives
        count_field = RealizedField(prod['k'], str(reg_ast.params[0]))
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['arg'], operand), count_field])

    if rule in ["RepeatRange"]:
        operand = streg_ast_to_asdl_ast(grammar, reg_ast.children[0])
        lower_field = RealizedField(prod['k1'], str(reg_ast.params[0]))
        upper_field = RealizedField(prod['k2'], str(reg_ast.params[1]))
        return AbstractSyntaxTree(prod, [
            RealizedField(prod['arg'], operand), lower_field, upper_field
        ])

    if rule in ["String"]:
        # the string content is the node class of the only child
        return AbstractSyntaxTree(
            prod,
            [RealizedField(prod['arg'], reg_ast.children[0].node_class)])

    raise ValueError("wrong node class", reg_ast.node_class)
def pdf_to_ast(grammar, lf_node):
    """Recursively convert a PDF logical-form node into an ASDL syntax tree.

    The node name decides the production, checked in this order:

    * ``obj...``  -> ``Objective(id name, expr* hdr)``
    * a known key -> ``Apply(pred predicate, expr* arguments)``
    * ``S...`` / ``I...`` / ``H...`` -> ``Variable`` of type
      ``string`` / ``int`` / ``header`` (the prefix letter is stripped)
    * ``R...``    -> ``Reference(id ref)`` pointing at ``'obj' + name[1:]``

    Raises:
        NotImplementedError: for any other node name.
    """
    name = lf_node.name

    if name.startswith('obj'):
        # obj = Objective(id name, expr* hdr)
        prod = grammar.get_prod_by_ctr_name('Objective')
        id_field = RealizedField(prod['name'], value=name)
        headers = [pdf_to_ast(grammar, child) for child in lf_node.children]
        hdr_field = RealizedField(prod['hdr'], headers)
        return AbstractSyntaxTree(prod, [id_field, hdr_field])

    if name in [
            'Type', 'SubType', 'Size', 'Length', 'Kids', 'Parent', 'Count',
            'Limits', 'Range', 'Filter', 'Domain', 'FuncType', 'Pages',
            'MediaBox', 'Resources'
    ]:
        # expr -> Apply(pred predicate, expr* arguments)
        prod = grammar.get_prod_by_ctr_name('Apply')
        pred_field = RealizedField(prod['predicate'], value=name)
        arguments = [pdf_to_ast(grammar, child) for child in lf_node.children]
        arg_field = RealizedField(prod['arguments'], arguments)
        return AbstractSyntaxTree(prod, [pred_field, arg_field])

    # expr = Variable(var_type type, var variable)
    variable_type = {'S': 'string', 'I': 'int', 'H': 'header'}.get(name[:1])
    if variable_type is not None:
        prod = grammar.get_prod_by_ctr_name('Variable')
        type_field = RealizedField(prod['type'], value=variable_type)
        var_field = RealizedField(prod['variable'], value=name[1:])
        return AbstractSyntaxTree(prod, [type_field, var_field])

    if name.startswith('R'):
        # expr = Reference(id ref)
        prod = grammar.get_prod_by_ctr_name('Reference')
        ref_field = RealizedField(prod['ref'], value='obj' + name[1:])
        return AbstractSyntaxTree(prod, [ref_field])

    raise NotImplementedError
def logical_form_to_ast(grammar, lf_node):
    """Recursively convert a lambda-calculus logical form into an ASDL tree.

    The node name selects the production: ``lambda``, the
    ``argmax``/``argmin``/``sum`` aggregations, ``and``/``or``, ``not``, the
    comparison operators, known predicates (``Apply``), ``$`` variables,
    entities, numbers and the ``the``/``exists``/``max``/``min``/``count``
    binders.

    Raises:
        NotImplementedError: if the node name is not recognised.
    """
    name = lf_node.name
    children = lf_node.children

    def convert(node):
        return logical_form_to_ast(grammar, node)

    if name == 'lambda':
        # expr -> Lambda(var variable, var_type type, expr body)
        prod = grammar.get_prod_by_ctr_name('Lambda')
        var_field = RealizedField(prod['variable'], children[0].name)
        var_type_field = RealizedField(prod['type'], children[1].name)
        body_field = RealizedField(prod['body'], convert(children[2]))  # of type expr
        return AbstractSyntaxTree(prod,
                                  [var_field, var_type_field, body_field])

    if name in ('argmax', 'argmin', 'sum'):
        # expr -> Argmax|Sum(var variable, expr domain, expr body)
        prod = grammar.get_prod_by_ctr_name(name.title())
        var_field = RealizedField(prod['variable'], children[0].name)
        domain_field = RealizedField(prod['domain'], convert(children[1]))
        body_field = RealizedField(prod['body'], convert(children[2]))
        return AbstractSyntaxTree(prod, [var_field, domain_field, body_field])

    if name in ('and', 'or'):
        # expr -> And(expr* arguments) | Or(expr* arguments)
        prod = grammar.get_prod_by_ctr_name(name.title())
        arguments = [convert(child) for child in children]
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['arguments'], arguments)])

    if name == 'not':
        # expr -> Not(expr argument)
        prod = grammar.get_prod_by_ctr_name('Not')
        argument = convert(children[0])
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['argument'], argument)])

    if name in ('>', '=', '<'):
        # expr -> Compare(cmp_op op, expr left, expr right)
        prod = grammar.get_prod_by_ctr_name('Compare')
        op_name = {'>': 'GreaterThan', '=': 'Equal', '<': 'LessThan'}[name]
        op_field = RealizedField(
            prod['op'],
            AbstractSyntaxTree(grammar.get_prod_by_ctr_name(op_name)))
        left_field = RealizedField(prod['left'], convert(children[0]))
        right_field = RealizedField(prod['right'], convert(children[1]))
        return AbstractSyntaxTree(prod, [op_field, left_field, right_field])

    if name in [
            'jet', 'flight', 'from_airport', 'airport', 'airline',
            'airline_name', 'class_type', 'aircraft_code', 'aircraft_code:t',
            'from', 'to', 'day', 'month', 'year', 'arrival_time', 'limousine',
            'departure_time', 'meal', 'meal:t', 'meal_code', 'during_day',
            'tomorrow', 'daily', 'time_elapsed', 'time_zone_code',
            'booking_class:t', 'booking_class', 'economy', 'ground_fare',
            'class_of_service', 'capacity', 'weekday', 'today', 'turboprop',
            'aircraft', 'air_taxi_operation', 'month_return', 'day_return',
            'day_number_return', 'minimum_connection_time',
            'during_day_arrival', 'connecting', 'minutes_distant', 'named',
            'miles_distant', 'approx_arrival_time', 'approx_return_time',
            'approx_departure_time', 'has_stops', 'day_after_tomorrow',
            'manufacturer', 'discounted', 'overnight', 'nonstop', 'has_meal',
            'round_trip', 'oneway', 'loc:t', 'ground_transport', 'to_city',
            'flight_number', 'equals:t', 'abbrev', 'equals', 'rapid_transit',
            'stop_arrival_time', 'arrival_month', 'cost', 'fare', 'services',
            'fare_basis_code', 'rental_car', 'city', 'stop', 'day_number',
            'days_from_today', 'after_day', 'before_day', 'airline:e',
            'stops', 'month_arrival', 'day_number_arrival', 'day_arrival',
            'taxi', 'next_days', 'restriction_code', 'tomorrow_arrival',
            'tonight', 'population:i', 'state:t', 'next_to:t', 'elevation:i',
            'size:i', 'capital:t', 'len:i', 'city:t', 'named:t', 'river:t',
            'place:t', 'capital:c', 'major:t', 'town:t', 'mountain:t',
            'lake:t', 'area:i', 'density:i', 'high_point:t', 'elevation:t',
            'population:t', 'in:t'
    ]:
        # expr -> Apply(pred predicate, expr* arguments)
        prod = grammar.get_prod_by_ctr_name('Apply')
        pred_field = RealizedField(prod['predicate'], value=name)
        arguments = [convert(child) for child in children]
        arg_field = RealizedField(prod['arguments'], arguments)
        return AbstractSyntaxTree(prod, [pred_field, arg_field])

    if name.startswith('$'):
        prod = grammar.get_prod_by_ctr_name('Variable')
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['variable'], value=name)])

    entity_markers = (':ap', ':fb', ':mf', ':me', ':cl', ':pd', ':dc', ':al')
    entity_names = [
        'yr0', 'do0', 'fb1', 'rc0', 'ci0', 'fn0', 'ap0', 'al1', 'al2', 'ap1',
        'ci1', 'ci2', 'ci3', 'st0', 'ti0', 'ti1', 'da0', 'da1', 'da2', 'da3',
        'da4', 'al0', 'fb0', 'dn0', 'dn1', 'mn0', 'ac0', 'fn1', 'st1', 'st2',
        'c0', 'm0', 's0', 'r0', 'n0', 'co0', 'usa:co', 'death_valley:lo',
        's1', 'colorado:n'
    ]
    if any(marker in name for marker in entity_markers) or name in entity_names:
        prod = grammar.get_prod_by_ctr_name('Entity')
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['entity'], value=name)])

    if name.endswith((':i', ':hr')):
        prod = grammar.get_prod_by_ctr_name('Number')
        return AbstractSyntaxTree(
            prod, [RealizedField(prod['number'], value=name)])

    if name in ('the', 'exists', 'max', 'min', 'count'):
        # expr -> The(var variable, expr body)
        # expr -> Exists(var variable, expr body), likewise Max/Min/Count
        prod = grammar.get_prod_by_ctr_name(name.title())
        var_field = RealizedField(prod['variable'], children[0].name)
        body_field = RealizedField(prod['body'], convert(children[1]))
        return AbstractSyntaxTree(prod, [var_field, body_field])

    raise NotImplementedError
class Hypothesis(object):
    """A partially constructed AST plus the actions applied to build it.

    The hypothesis tracks a "frontier": the first unfinished field found by a
    depth-first walk of the tree. Each call to ``apply_action`` fills in (or
    closes) that field and then re-computes the frontier.
    """

    def __init__(self):
        # root of the tree under construction; None until the first action
        self.tree = None
        # every action passed to apply_action, in application order
        self.actions = []
        # score of the hypothesis; this class only stores and copies it
        self.score = 0.

        # node / field the next action will operate on; both are None when
        # there is no tree yet or no unfinished field remains
        self.frontier_node = None
        self.frontier_field = None
        # tokens gathered so far for the 'string' primitive being generated
        self._value_buffer = []

        # record the current time step
        self.t = 0

    def apply_action(self, action) -> None:
        """Apply one action to the hypothesis in place.

        * With no tree yet, ``action`` must be an ``ApplyRuleAction``; its
          production becomes the root of the tree.
        * With a frontier on a composite field, an ``ApplyRuleAction`` adds a
          new sub-tree (stamped with ``created_time = self.t``) and a
          ``ReduceAction`` closes the field.
        * With a frontier on a primitive field, a ``GenTokenAction`` adds a
          token and a ``ReduceAction`` closes the field. Fields whose type
          name is ``'string'`` buffer tokens until the action reports a stop
          signal, then store the space-joined buffer as a single value.
        * With a tree but no frontier node, nothing in the tree changes.

        In every case that does not raise, the time step ``t`` is incremented
        and ``action`` is appended to ``actions``.

        Raises:
            AssertionError: if the first action is not an ``ApplyRuleAction``,
                or a ``ReduceAction`` targets a field whose cardinality is
                neither ``'optional'`` nor ``'multiple'``.
            ValueError: if the action type is not valid for the frontier field.
        """
        if self.tree is None:
            assert isinstance(action, ApplyRuleAction), (
                'Invalid action [%s], only ApplyRule action is valid '
                'at the beginning of decoding' % (action,))

            self.tree = AbstractSyntaxTree(action.production)
            self.update_frontier_info()
        elif self.frontier_node:
            if isinstance(self.frontier_field.type, ASDLCompositeType):
                if isinstance(action, ApplyRuleAction):
                    field_value = AbstractSyntaxTree(action.production)
                    field_value.created_time = self.t
                    self.frontier_field.add_value(field_value)
                    self.update_frontier_info()
                elif isinstance(action, ReduceAction):
                    self._reduce_frontier_field()
                else:
                    raise ValueError('Invalid action [%s] on field [%s]' %
                                     (action, self.frontier_field))
            else:  # fill in a primitive field
                if isinstance(action, GenTokenAction):
                    # only field of type string requires termination signal </primitive>
                    end_primitive = False
                    if self.frontier_field.type.name == 'string':
                        if action.is_stop_signal():
                            self.frontier_field.add_value(
                                ' '.join(self._value_buffer))
                            self._value_buffer = []
                            end_primitive = True
                        else:
                            self._value_buffer.append(action.token)
                    else:
                        self.frontier_field.add_value(action.token)
                        end_primitive = True

                    # a 'multiple' field stays open until an explicit Reduce
                    if end_primitive and self.frontier_field.cardinality in (
                            'single', 'optional'):
                        self.frontier_field.set_finish()
                        self.update_frontier_info()
                elif isinstance(action, ReduceAction):
                    self._reduce_frontier_field()
                else:
                    raise ValueError(
                        'Can only invoke GenToken or Reduce actions on primitive fields'
                    )

        self.t += 1
        self.actions.append(action)

    def _reduce_frontier_field(self) -> None:
        """Close the frontier field in response to a Reduce action."""
        assert self.frontier_field.cardinality in ('optional', 'multiple'), (
            'Reduce action can only be applied on field with optional or '
            'multiple cardinality')
        self.frontier_field.set_finish()
        self.update_frontier_info()

    def update_frontier_info(self) -> None:
        """Recompute ``frontier_node`` and ``frontier_field`` from the tree.

        Fields are visited depth-first in declaration order; the children of
        a composite field that already has a value are searched before the
        field itself is considered. The first field that is not ``finished``
        becomes the frontier. Both attributes are set to None when no such
        field exists (or when there is no tree).
        """

        def _find_frontier_node_and_field(tree_node):
            if not tree_node:
                return None

            for field in tree_node.fields:
                # if it's an intermediate node, check its children
                if isinstance(field.type, ASDLCompositeType) and field.value:
                    if field.cardinality in ('single', 'optional'):
                        iter_values = [field.value]
                    else:
                        iter_values = field.value

                    for child_node in iter_values:
                        result = _find_frontier_node_and_field(child_node)
                        if result:
                            return result

                # now all its possible children are checked
                if not field.finished:
                    return tree_node, field

            return None

        frontier_info = _find_frontier_node_and_field(self.tree)
        if frontier_info:
            self.frontier_node, self.frontier_field = frontier_info
        else:
            self.frontier_node, self.frontier_field = None, None

    def clone_and_apply_action(self, action) -> 'Hypothesis':
        """Return a copy of this hypothesis with ``action`` applied to it.

        The receiver itself is left untouched.
        """
        new_hyp = self.copy()
        new_hyp.apply_action(action)

        return new_hyp

    def copy(self) -> 'Hypothesis':
        """Return a new hypothesis carrying the same state.

        The tree is duplicated through its own ``copy()`` method, the action
        list and token buffer are shallow-copied, and the frontier is
        recomputed against the copied tree so that it does not point into
        the original.
        """
        new_hyp = Hypothesis()
        if self.tree:
            new_hyp.tree = self.tree.copy()

        new_hyp.actions = list(self.actions)
        new_hyp.score = self.score
        new_hyp._value_buffer = list(self._value_buffer)
        new_hyp.t = self.t

        new_hyp.update_frontier_info()

        return new_hyp

    @property
    def completed(self) -> bool:
        """True when a tree exists and no unfinished field remains."""
        # bool() keeps the result a real boolean even when there is no tree
        return bool(self.tree and self.frontier_field is None)
def pdf_to_ast(grammar, x, tr):
    """Recursively convert a parsed PDF value into an AbstractSyntaxTree.

    Args:
        grammar: object providing ``get_prod_by_ctr_name``; the returned
            productions are indexed by field name.
        x: the PDF value to convert. Handled types, tested in this order:
            ``PdfDict``, ``PdfObject``, ``PdfArray``, ``PdfString``,
            ``BasePdfName``.
        tr: mutable list holding the dictionary keys on the current descent
            path. A key is appended before recursing into its value and
            removed again afterwards.

    Returns:
        The AbstractSyntaxTree built for ``x``.

    Raises:
        NotImplementedError: if ``tr`` already holds 300 or more entries, if
            an array is nested directly inside an array, or if ``x`` is of
            an unhandled type (its type is printed first).
    """
    if len(tr) >= 300:
        raise NotImplementedError

    def _value_node(ctr_name, raw, convert=str):
        # Build a node whose single 'value' field holds convert(raw).
        production = grammar.get_prod_by_ctr_name(ctr_name)
        value_field = RealizedField(production['value'], value=convert(raw))
        return AbstractSyntaxTree(production, [value_field])

    if isinstance(x, PdfDict):
        dict_prod = grammar.get_prod_by_ctr_name('PdfDict')
        entries = []
        for key in x:
            apply_prod = grammar.get_prod_by_ctr_name('Apply')
            name_field = RealizedField(apply_prod['name'], value=str(key))
            if key in ('/Parent', '/P', '/Dest', '/Prev'):
                # these keys are not followed: their 'op' field gets no value
                op_field = RealizedField(apply_prod['op'])
            else:
                tr.append(key)
                child = pdf_to_ast(grammar, x[key], tr)
                tr.pop()
                op_field = RealizedField(apply_prod['op'], value=child)
            entries.append(
                AbstractSyntaxTree(apply_prod, [name_field, op_field]))

        # a non-empty stream is stored as one extra PdfString entry
        if x.stream:
            entries.append(_value_node('PdfString', x.stream))

        args_field = RealizedField(dict_prod['args'], entries)
        return AbstractSyntaxTree(dict_prod, [args_field])

    if isinstance(x, PdfObject):
        return _value_node('PdfObject', x)

    if isinstance(x, PdfArray):
        dict_children = []
        scalar_items = []
        for item in x:
            if isinstance(item, PdfDict):
                dict_children.append(pdf_to_ast(grammar, item, tr))
            elif isinstance(item, PdfArray):
                raise NotImplementedError
            elif isinstance(item, (BasePdfName, PdfObject)):
                scalar_items.append(str(item))
            # items of any other type are ignored

        if dict_children:
            # when the array holds any dictionary, only the dictionaries
            # are kept; the collected scalar items are not used
            array_prod = grammar.get_prod_by_ctr_name('PdfArray')
            args_field = RealizedField(array_prod['args'], dict_children)
            return AbstractSyntaxTree(array_prod, [args_field])
        return _value_node('PdfList', scalar_items, convert=tuple)

    if isinstance(x, PdfString):
        return _value_node('PdfString', x)

    if isinstance(x, BasePdfName):
        return _value_node('BasePdfName', x)

    print(type(x))
    raise NotImplementedError
def prolog_expr_to_ast_helper(grammar, prolog_tokens, start_idx=0):
    """Parse one (optionally parenthesised) Prolog expression from a token list.

    Terms separated by ``','`` are collected as conjuncts; a ``';'`` combines
    everything parsed so far (left) with the recursively parsed remainder
    (right) into an ``Or`` node; a ``'\\+'`` prefix wraps the following
    term in a ``Not`` node. Parsing stops at the end of the token list or
    just after a closing ``')'``.

    Args:
        grammar: object providing ``get_prod_by_ctr_name``.
        prolog_tokens: sequence of token strings.
        start_idx: index of the first token to read. A ``'('`` at this
            position is skipped.

    Returns:
        A ``(node, next_idx)`` pair: the resulting AST (an ``And`` node when
        more than one conjunct was collected, otherwise the single node) and
        the index of the first token that was not consumed.
    """
    pos = start_idx
    if prolog_tokens[pos] == '(':
        pos += 1

    conjuncts = []
    while True:
        token = prolog_tokens[pos]
        if token == '\\+':
            # expr -> Not(expr argument)
            not_prod = grammar.get_prod_by_ctr_name('Not')
            pos += 1
            if prolog_tokens[pos] == '(':
                parse_operand = prolog_expr_to_ast_helper
            else:
                parse_operand = prolog_node_to_ast
            operand, pos = parse_operand(grammar, prolog_tokens, pos)

            assert operand.production.type.name == 'expr'
            conjuncts.append(
                AbstractSyntaxTree(
                    not_prod, [RealizedField(not_prod['argument'], operand)]))
        else:
            # a '(' opens a nested expression, anything else is a plain node
            if token == '(':
                parse_term = prolog_expr_to_ast_helper
            else:
                parse_term = prolog_node_to_ast
            term, pos = parse_term(grammar, prolog_tokens, pos)
            conjuncts.append(term)

        if pos >= len(prolog_tokens):
            break
        if prolog_tokens[pos] == ')':
            pos += 1
            break

        separator = prolog_tokens[pos]
        if separator == ',':
            # and
            pos += 1
        elif separator == ';':
            # Or
            or_prod = grammar.get_prod_by_ctr_name('Or')

            assert conjuncts
            if len(conjuncts) == 1:
                left_operand = conjuncts[0]
            else:
                # several conjuncts before the ';' form a single And operand
                and_prod = grammar.get_prod_by_ctr_name('And')
                left_operand = AbstractSyntaxTree(
                    and_prod,
                    [RealizedField(and_prod['arguments'], conjuncts)])

            # get the right ast node
            pos += 1
            right_operand, pos = prolog_expr_to_ast_helper(
                grammar, prolog_tokens, pos)

            or_node = AbstractSyntaxTree(or_prod, [
                RealizedField(or_prod['left'], left_operand),
                RealizedField(or_prod['right'], right_operand)
            ])
            conjuncts = [or_node]

            if pos >= len(prolog_tokens):
                break
            if prolog_tokens[pos] == ')':
                pos += 1
                break

    assert conjuncts
    if len(conjuncts) == 1:
        return conjuncts[0], pos

    and_prod = grammar.get_prod_by_ctr_name('And')
    conjunction = AbstractSyntaxTree(
        and_prod, [RealizedField(and_prod['arguments'], conjuncts)])
    return conjunction, pos