def ifttt_ast_to_parse_tree_helper(s, offset):
    """Recursively parse one s-expression node of an IFTTT recipe string.

    Adapted from the IFTTT codebase.

    s: the full s-expression string, e.g. '(ROOT (IF ...) (THEN ...))'.
    offset: index into `s` where this node's '(' must be.
    Returns (node, offset): the parsed ASTNode and the index just past
    this node's closing ')'.
    Raises RuntimeError on malformed input.
    """
    if s[offset] != '(':
        # BUG FIX: `+ offset` concatenated str + int and raised TypeError
        # instead of the intended diagnostic; cast to str.
        raise RuntimeError(
            'malformed string: node did not start with open paren at '
            'position ' + str(offset))
    offset += 1

    # extract node name (type); quoted names may contain escaped characters
    name = ''
    if s[offset] == '\"':
        offset += 1
        while s[offset] != '\"':
            if s[offset] == '\\':
                offset += 1  # skip the escape, keep the escaped char
            name += s[offset]
            offset += 1
        offset += 1  # step past the closing quote
    else:
        while s[offset] != ' ' and s[offset] != ')':
            name += s[offset]
            offset += 1

    node = ASTNode(name)
    # children are space-separated until the closing paren
    while True:
        if s[offset] == ')':
            offset += 1
            return node, offset
        if s[offset] != ' ':
            # BUG FIX: same str + int concatenation as above.
            raise RuntimeError(
                'malformed string: node should have either had a '
                'close paren or a space at position ' + str(offset))
        offset += 1
        child_node, offset = ifttt_ast_to_parse_tree_helper(s, offset)
        node.add_child(child_node)
def parse_raw(code):
    """Parse raw Python source into the internal parse-tree format.

    Each top-level statement becomes one child of a synthetic 'root' node.
    """
    root_node = ASTNode('root')
    module_ast = ast.parse(code)
    for stmt in module_ast.body:
        root_node.add_child(python_ast_to_parse_tree(stmt))
    return root_node
def decode_rules_to_tree(rules, root, rule_num=0):
    """Rebuild a parse tree from a flat list of grammar rules.

    rules: sequence of rule strings as produced by the decoder; value rules
        are followed by a GenToken[<eob>] marker (see the skips below).
    root: the node to attach the decoded children to.
    rule_num: index of the rule to expand at this level.
    Returns (root, rule_num): the populated node and the index of the first
    rule not consumed by this subtree.
    """
    rule = rules[rule_num]
    node_types_labels = parse_rule(rule)
    for node_type, node_label in node_types_labels:
        if is_builtin_type(node_type):
            # next rule carries the literal value; cast it to the builtin type
            node_val = extract_val_GenToken(rules[rule_num + 1])
            rule_num += 2  # skip node_val and GenToken[<eob>]
            child_node = ASTNode(node_type, node_label, node_type(node_val))
        elif is_terminal_ast_type(node_type) or node_type == 'epsilon':
            rule_num += 1  # skip GenToken[<eob>]
            child_node = ASTNode(node_type)
        else:
            # non-terminal: recurse, letting the callee advance rule_num
            rule_num += 1
            child_node = ASTNode(node_type, node_label)
            child_node, rule_num = decode_rules_to_tree(
                rules, child_node, rule_num)
        root.add_child(child_node)
    return root, rule_num
def seq2tree_repr_to_ast_tree_helper(tree_repr, offset): """convert a seq2tree representation to AST tree""" # extract node name node_name_end = offset while node_name_end < len(tree_repr) and tree_repr[node_name_end] != ' ': node_name_end += 1 node_repr = tree_repr[offset:node_name_end] m = node_re.match(node_repr) n_type = m.group('type') n_type = type_str_to_type(n_type) n_label = m.group('label') n_value = m.group('value') if n_type in {int, float, str, bool}: n_value = n_type(n_value) n_label = None if n_label == '' else n_label n_value = None if n_value == '' else n_value node = ASTNode(n_type, label=n_label, value=n_value) offset = node_name_end if offset == len(tree_repr): return node, offset offset += 1 if tree_repr[offset] == '(': offset += 2 while True: child_node, offset = seq2tree_repr_to_ast_tree_helper(tree_repr, offset=offset) node.add_child(child_node) if offset >= len(tree_repr) or tree_repr[offset] == ')': offset += 2 break return node, offset
def unary_link_to_closure(unary_link):
    """Collapse a unary chain into a two-node closure.

    The closure keeps the chain's root type; its single child takes the
    type of the chain's first leaf and a label encoding the chain's
    productions ('@'-joined, with spaces replaced by '$').
    """
    closure = ASTNode(unary_link.type)
    productions, _ = unary_link.get_productions()
    encoded = '@'.join(str(p).replace(' ', '$') for p in productions)
    leaf = unary_link.get_leaves()[0]
    collapsed_child = ASTNode(leaf.type)
    collapsed_child.label = encoded
    closure.add_child(collapsed_child)
    return closure
def extract_unary_closure_helper(parse_tree, unary_link, last_node):
    """Collect unary chains (paths of single-child nodes) from a parse tree.

    parse_tree: the subtree being scanned.
    unary_link: root of the unary chain copy currently being built (or a
        fresh single node when no chain is in progress).
    last_node: the deepest node of that chain copy; new links attach here.
    Returns a list of chain-root ASTNodes, keeping only chains with more
    than two nodes (size > 2).
    """
    if parse_tree.is_leaf:
        # chain ends here; keep it only if it is long enough to be useful
        if unary_link and unary_link.size > 2:
            return [unary_link]
        else:
            return []
    elif len(parse_tree.children) > 1:
        # branching node terminates the current chain ...
        unary_links = []
        if unary_link and unary_link.size > 2:
            unary_links.append(unary_link)
        # ... and each child starts a fresh candidate chain
        for child in parse_tree.children:
            new_node = ASTNode(child.type)
            child_unary_links = extract_unary_closure_helper(
                child, new_node, new_node)
            unary_links.extend(child_unary_links)
        return unary_links
    else:  # has a single child: extend the current chain by one link
        child = parse_tree.children[0]
        new_node = ASTNode(child.type, label=child.label)
        last_node.add_child(new_node)
        last_node = new_node
        return extract_unary_closure_helper(child, unary_link, last_node)
def parse(code):
    """Parse Python source into the internal tree structure.

    Pipeline: canonicalize the code, parse it with `ast`, convert each
    top-level statement to the internal representation, and hang them all
    under a synthetic 'root' node.

    code: Python source text.
    Returns the 'root' ASTNode.
    """
    # NOTE: a dead, commented-out single-statement variant of this function
    # (a bare '''...''' expression) was removed here.
    root_node = ASTNode('root')
    code = canonicalize_code(code)
    py_ast = ast.parse(code)
    for p in py_ast.body:
        tree = python_ast_to_parse_tree(p)
        root_node.add_child(tree)
    return root_node
def ifttt_ast_to_parse_tree_helper(s, offset):
    """Recursively parse one s-expression node of an IFTTT recipe string.

    Adapted from the IFTTT codebase.  (Duplicate of the earlier copy in
    this file; both carry the same fix.)

    s: the full s-expression string.
    offset: index into `s` where this node's '(' must be.
    Returns (node, offset): the parsed ASTNode and the index just past
    this node's closing ')'.
    Raises RuntimeError on malformed input.
    """
    if s[offset] != '(':
        # BUG FIX: `+ offset` concatenated str + int and raised TypeError
        # instead of the intended diagnostic; cast to str.
        raise RuntimeError(
            'malformed string: node did not start with open paren at '
            'position ' + str(offset))
    offset += 1

    # extract node name (type); quoted names may contain escaped characters
    name = ''
    if s[offset] == '\"':
        offset += 1
        while s[offset] != '\"':
            if s[offset] == '\\':
                offset += 1  # skip the escape, keep the escaped char
            name += s[offset]
            offset += 1
        offset += 1  # step past the closing quote
    else:
        while s[offset] != ' ' and s[offset] != ')':
            name += s[offset]
            offset += 1

    node = ASTNode(name)
    # children are space-separated until the closing paren
    while True:
        if s[offset] == ')':
            offset += 1
            return node, offset
        if s[offset] != ' ':
            # BUG FIX: same str + int concatenation as above.
            raise RuntimeError(
                'malformed string: node should have either had a '
                'close paren or a space at position ' + str(offset))
        offset += 1
        child_node, offset = ifttt_ast_to_parse_tree_helper(s, offset)
        node.add_child(child_node)
def decode_tree_to_python_ast(decode_tree):
    """Convert a decoded internal tree into Python `ast` form.

    Expands unary closures in place, cleans up terminal values (strips a
    trailing '<eos>' marker, casts to the terminal's builtin type), then
    converts each top-level child and collects them under a 'root' node.

    decode_tree: the internal tree produced by the decoder (mutated in place).
    Returns an ASTNode whose children are Python `ast` trees.
    """
    from lang.py.unaryclosure import compressed_ast_to_normal

    ast_tree = ASTNode('root')
    compressed_ast_to_normal(decode_tree)
    for t in decode_tree.children:
        for terminal in t.get_leaves():
            if terminal.value is not None and type(terminal.value) is str:
                if terminal.value.endswith('<eos>'):
                    terminal.value = terminal.value[:-5]  # drop '<eos>'
            if terminal.type in {int, float, str, bool}:
                # cast to target data type
                terminal.value = terminal.type(terminal.value)
        # (removed several commented-out debug prints here)
        pt = parse_tree_to_python_ast(t)
        ast_tree.add_child(pt)
    return ast_tree
def compressed_ast_to_normal(parse_tree):
    """Expand unary-closure nodes back into their full chains, in place.

    A compressed node is recognized by a label containing both '@' (rule
    separator) and '$' (space placeholder).  Such a node is replaced, in
    its parent, by the chain of intermediate nodes its label encodes; the
    node's value and children move to the deepest chain node.
    """
    if parse_tree.label and '@' in parse_tree.label and '$' in parse_tree.label:
        label = parse_tree.label
        label = label.replace('$', ' ')  # undo the space encoding
        rule_reprs = label.split('@')

        first_node = last_node = None
        for rule_repr in rule_reprs:
            m = rule_regex.match(rule_repr)
            c = m.group('child')
            cl = m.group('clabel')
            c_type = type_str_to_type(c)
            # (removed unused locals: the 'parent' group, its type, and an
            # intermediate_nodes list that was appended to but never read)
            node = ASTNode(c_type, label=cl)
            if last_node:
                last_node.add_child(node)
            if not first_node:
                first_node = node
            last_node = node

        # deepest chain node inherits the compressed node's payload
        last_node.value = parse_tree.value
        for child in parse_tree.children:
            last_node.add_child(child)
            compressed_ast_to_normal(child)

        # splice the chain into the parent in place of the compressed node
        parent_node = parse_tree.parent
        assert len(parent_node.children) == 1
        del parent_node.children[0]
        parent_node.add_child(first_node)
    else:
        # iterate over a copy: recursion may splice new nodes into children
        for child in parse_tree.children[:]:
            compressed_ast_to_normal(child)
def write_to_code_file(mode, data, path_to_load, path_to_export, path_raw_code): g = data.grammar nt = {v: reverse_typename(k) for k, v in g.node_type_to_id.items()} #print(nt,g.node_type_to_id) v = data.terminal_vocab raw = [] with open(path_raw_code, 'r') as f: for line in f: raw.append(line[:-1]) with open(path_to_load, 'r') as f: l = json.load(f, encoding='utf8') l_code = [] for i in range(len(l)): # print(raw[i]) try: t = ASTNode.from_dict(l[i], nt, v) ast_tree = parse.decode_tree_to_python_ast(t) code = astor.to_source(ast_tree)[:-1] real_code = parse.de_canonicalize_code(code, raw[i]) if (mode == "hs"): real_code = " ".join(parse.tokenize_code_adv( real_code, True)).replace("\n", "#NEWLINE#").replace( "#NEWLINE# ", "").replace("#INDENT# ", "") real_code = " ".join(parse.tokenize_code_adv(real_code, False)) #print(real_code,raw[i]) l_code.append(real_code) except: print "Tree %d impossible to parse" % (i) l_code.append("") with open(path_to_export, 'w') as f: for c in l_code: f.write(c + "\n")
def __getitem__(self, lhs):
    """Return the rules whose left-hand side has `lhs`'s type.

    Rules are indexed by node type only (label/value ignored).
    Raises KeyError when no rule with that type exists.
    """
    key_node = ASTNode(lhs.type, None)  # Rules are indexed by types only
    if key_node in self.rule_index:
        return self.rule_index[key_node]
    else:
        # BUG FIX: the KeyError was constructed but never raised, so a
        # failed lookup silently returned None.
        raise KeyError('key=%s' % key_node)
def extract_unary_closure(parse_tree):
    """Find all unary chains in `parse_tree` worth compressing.

    Starts the recursive extraction from a copy of the root node, which
    serves as both the chain root and its current tail.
    """
    root_copy = ASTNode(parse_tree.type)
    return extract_unary_closure_helper(parse_tree, root_copy, root_copy)
    # NOTE(review): this `return tree` is the tail of a function whose `def`
    # lies outside this chunk; left untouched.
    return tree


if __name__ == '__main__':
    # Ad-hoc round-trip smoke test: code -> parse tree -> rules -> tree ->
    # Python ast -> code.  The first (multi-line) sample is immediately
    # overwritten by the second assignment and never used.
    code = """
class Demonwrath(SpellCard):
    def __init__(self):
        super().__init__("Demonwrath", 3, CHARACTER_CLASS.WARLOCK, CARD_RARITY.RARE)

    def use(self, player, game):
        super().use(player, game)
        targets = copy.copy(game.other_player.minions)
        targets.extend(game.current_player.minions)
        for minion in targets:
            if minion.card.minion_type is not MINION_TYPE.DEMON:
                minion.damage(player.effective_spell_damage(2), self)
"""
    code = """raise ImproperlyConfigured ( "You must define a '%s' cache" % DEFAULT_CACHE_ALIAS )"""
    parse_tree = parse(code)
    rules = parse_tree.to_rules()
    root_node = ASTNode('root')
    # rebuild the tree purely from the extracted rules
    root_node, _ = decode_rules_to_tree(rules, root_node)
    ast_tree = parse_tree_to_python_ast(root_node)
    out_code = astor.to_source(ast_tree)
    format_code = de_sugar_code(out_code, code)
    final_code = format_code.replace('\n', '')
    print(final_code)
def visit(current, parent, depth):
    """Wrap the given traversal state in a new ASTNode.

    NOTE(review): arguments are forwarded positionally as
    ASTNode(parent, current, depth) — presumably (type, label, value);
    confirm against the ASTNode constructor.
    """
    wrapped = ASTNode(parent, current, depth)
    return wrapped
def create_node_with_empty_leaf(node_name):
    """Build a node of the given type holding a single 'empty' placeholder child."""
    node = ASTNode(node_name)
    node.add_child(ASTNode("empty"))
    return node
def sql_to_parse_tree(rule_list, doPrint=False, debug=False):
    """Rebuild a parse tree from a stack of SQL grammar expansions.

    rule_list: list of expansions, consumed from the end (pop()); each entry
        is [parent, child1, child2, ...] where children are either
        'LexToken(type,value)' strings or non-terminal names.
    doPrint, debug: verbosity flags.
    Returns the tree wrapped by add_root().

    NOTE(review): a BFS queue of pending nodes matches each popped expansion
    to its `front` node by type; 'empty' placeholder leaves are deleted when
    a node gains real children.
    """
    queue = []
    level = 0
    if doPrint:
        print("sql to parse tree")
    for rule_idx in range(len(rule_list)):
        #while rule_list:
        current_level = rule_list.pop()
        parent = current_level[0]
        if debug:
            print("parent: " + str(parent))
        length = len(current_level)
        list_of_children = []
        # build child nodes: LexTokens become typed value leaves, other
        # names become nodes with an 'empty' placeholder leaf
        for node_idx in range(1, length):
            child = current_level[node_idx]
            if child[0:8] == "LexToken":
                # 'LexToken(type,value)' -> split the inner 'type,value'
                temp = child[9:len(child) - 1].split(",")
                child_node = ASTNode(node_type=temp[0], value=temp[1])
                list_of_children.append(child_node)
            else:
                #print "not lex: " + child
                child_node = create_node_with_empty_leaf(child)
                #child_node = ASTNode(child)
                list_of_children.append(child_node)
        if debug:
            print("list of children: " + str(list_of_children))
        # find the node this expansion attaches to
        if queue:
            front = queue.pop(0)
            if debug:
                print("front if" + str(front.print_with_level()))
        else:
            # first expansion creates the root
            root = ASTNode(parent, level=1)
            front = root
        if debug:
            print("queue: " + str(queue))
        # skip queued nodes until one matches this expansion's parent type
        while front.type != parent and queue:
            front = queue.pop(0)
            if debug:
                print("front inside while:" + str(front.print_with_level()))
        if debug:
            print("old front: " + str(front.print_with_level())
                  )  # + "rule_idx: " + str(rule_idx))
        # front is about to get real children: drop its placeholder leaf
        try:
            if rule_idx > 0:
                #print "here"
                front.__delitem__("empty")
        except:
            pass
        for child in list_of_children:
            level = front.level + 1
            child.level = level
            front.add_child(child)
        if debug:
            print("new front: " + str(front.print_with_level()))
            print("queue before extension: " + str(queue))
        # prepend (not extend): children must be matched before older
        # pending nodes since rule_list is consumed most-recent-first
        #queue.extend(reversed(list_of_children))
        reversed_children = list(reversed(list_of_children))
        queue = reversed_children + queue
        if debug:
            print "last queue: " + str(queue)
    #print("root: " + str(root))
    #pointer
    #print root
    queue = []
    tree = add_root(root)
    return tree
def sql_ast_to_parse_tree(node):
    """Convert a sqlite-parser style JSON AST (dict) into an ASTNode tree.

    node: a dict with at least a 'type' key; 'literal' and 'identifier'
        nodes become value leaves, every other key becomes a child subtree
        (lists get a '*'-suffixed node type).

    NOTE(review): if a value is neither string/bool, dict, nor list, no
    `child` is assigned for that key, so `tree.add_child(child)` re-adds
    the previous child (or raises NameError on the first key) — likely a
    latent bug; confirm before relying on this path.
    """
    if isinstance(node, basestring):
        # unexpected bare string; printed for debugging, then falls through
        # to the dict indexing below (which would raise TypeError)
        print(node)
    node_type = node["type"]
    if node_type == "literal":
        return ASTNode(node_type, label=node["variant"], value=node["value"])
    if node_type == "identifier":
        return ASTNode(node_type, label=node["variant"], value=node["name"])
    tree = ASTNode(node_type)
    for key in node:
        if key == "type":
            if node[key] == "literal":
                print(node)
            continue
        if isinstance(node[key], basestring) or isinstance(node[key], bool):
            child = ASTNode(key, value=node[key])
        elif isinstance(node[key], dict):
            child = ASTNode(key)
            child.add_child(sql_ast_to_parse_tree(node[key]))
        elif isinstance(node[key], list):
            # list-valued fields get a starred node holding one child each
            child = ASTNode(key + "*")
            for item in node[key]:
                child.add_child(sql_ast_to_parse_tree(item))
        else:
            print key, node[key]
        tree.add_child(child)
    return tree
def python_ast_to_parse_tree(node):
    """Convert a Python `ast` node into the internal ASTNode representation.

    node: any ast.AST instance.
    Returns an ASTNode whose type is the ast class itself; each non-empty,
    non-blacklisted field becomes a labeled child.
    """
    assert isinstance(node, ast.AST)

    node_type = type(node)
    tree = ASTNode(node_type)

    # it's a leaf AST node, e.g., ADD, Break, etc.
    if len(node._fields) == 0:
        return tree

    # if it's a compositional AST node with empty fields
    if is_compositional_leaf(node):
        epsilon = ASTNode('epsilon')
        tree.add_child(epsilon)
        return tree

    # per-field metadata: declared type and whether the field is a list
    fields_info = PY_AST_NODE_FIELDS[node_type.__name__]

    for field_name, field_value in ast.iter_fields(node):
        # remove ctx stuff
        if field_name in NODE_FIELD_BLACK_LIST:
            continue

        # omit empty fields, including empty lists
        if field_value is None or (isinstance(field_value, list) and len(field_value) == 0):
            continue

        # now it's not empty!
        field_type = fields_info[field_name]['type']
        is_list_field = fields_info[field_name]['is_list']

        if isinstance(field_value, ast.AST):
            # single nested AST node: wrap and recurse
            child = ASTNode(field_type, field_name)
            child.add_child(python_ast_to_parse_tree(field_value))
        elif type(field_value) is str or type(field_value) is int or \
                type(field_value) is float or type(field_value) is object or \
                type(field_value) is bool:
            # primitive literal: keyed by its runtime type, not field_type
            # if field_type != type(field_value):
            #     print 'expect [%s] type, got [%s]' % (field_type, type(field_value))
            child = ASTNode(type(field_value), field_name, value=field_value)
        elif is_list_field:
            # list field: node type is the element type name plus '*'
            list_node_type = typename(field_type) + '*'
            child = ASTNode(list_node_type, field_name)
            for n in field_value:
                if field_type in {
                        ast.comprehension, ast.excepthandler, ast.arguments,
                        ast.keyword, ast.alias
                }:
                    # these element types attach directly, without a wrapper
                    child.add_child(python_ast_to_parse_tree(n))
                else:
                    intermediate_node = ASTNode(field_type)
                    if field_type is str:
                        intermediate_node.value = n
                    else:
                        intermediate_node.add_child(
                            python_ast_to_parse_tree(n))
                    child.add_child(intermediate_node)
        else:
            raise RuntimeError('unknown AST node field!')

        tree.add_child(child)

    return tree
def add_root(tree):
    """Return `tree` wrapped under a fresh 'root' node."""
    wrapper = ASTNode('root')
    wrapper.add_child(tree)
    return wrapper
        # NOTE(review): these returns are the tail of a function (apparently
        # reverse_typename) whose `def` lies outside this chunk; left as-is.
        return reverse_typename(t[:-1])
    else:
        return vars(ast)[t]


if __name__ == '__main__':
    # Decode a saved JSON results file back into Python ast trees and
    # run the project's evaluation on them.
    flag = "hs"
    path_to_load = "../data/exp/results/test_hs_10_iter.json"
    if flag == "django":
        train_data, dev_data, test_data = deserialize_from_file("../../django.cleaned.dataset.freq5.par_info.refact.space_only.bin")
    elif flag == "hs":
        train_data, dev_data, test_data = deserialize_from_file("../../hs.freq3.pre_suf.unary_closure.bin")
    data = test_data
    g = data.grammar
    # map node-type ids back to actual types/names
    nt = {v: reverse_typename(k) for k, v in g.node_type_to_id.items()}
    #print(nt,g.node_type_to_id)
    v = data.terminal_vocab
    results = []
    with open(path_to_load, 'r') as f:
        l = json.load(f, encoding='utf8')
        for i in range(len(l)):
            t = ASTNode.from_dict(l[i], nt, v)
            ast_tree = parse.decode_tree_to_python_ast(t)
            results.append(ast_tree)
    evaluate_decode_results(flag, test_data, results, verbose=True)