def visit_ImportFrom(self, node):
    """Record the names bound by a ``from X import ...`` statement.

    Each bound name (the ``as`` alias when present, otherwise the imported
    name) is added to ``self.imports`` as ``(name, self.fileName, lineno)``
    unless that name is already recorded for the current file.  Imports in
    ``__init__.py`` files, from dunder modules (e.g. ``__future__``), of dunder
    names, and star imports are not recorded.  The node is always passed on
    to ``self.generic_visit``.
    """
    def is_dunder(name):
        return len(name) > 4 and name[:2] == '__' and name[-2:] == '__'

    # Accept both Windows and POSIX separators; the original only matched '\\'.
    if self.fileName.replace('\\', '/').endswith('/__init__.py'):
        self.generic_visit(node)
        return
    # node.module is None for ``from . import x``.
    if node.module is not None and is_dunder(node.module):
        self.generic_visit(node)
        return
    for alias in node.names:
        if is_dunder(alias.name):
            continue
        if alias.asname is not None:
            bound = alias.asname
        elif alias.name != '*':
            bound = alias.name
        else:
            continue  # a star import binds no statically known name
        already_seen = any(name == bound and file == self.fileName
                           for (name, file, _lineno) in self.imports)
        if not already_seen:
            self.imports.add((bound, self.fileName, node.lineno))
    self.generic_visit(node)
def transformHelper(command):
    """Preprocess *command* into an AST, print debug views of it and return it.

    The ``ast.dump`` output is always printed.  Printing the unparsed source
    is best-effort: an exception raised while unparsing is ignored.
    """
    node = preprocess(command)
    print(ast.dump(node))
    try:
        print(astunparse.unparse(node))
    except Exception:
        # Best-effort debug output only; a bare ``except`` would also swallow
        # KeyboardInterrupt and SystemExit.
        pass
    return node
def test_program_doesnt_return_false_after_if(self, node):
    u"""Flag ``if cond: return False`` / ``else: return True`` conditionals.

    Such a conditional can be written as ``return not cond``.  When *node* is
    an ``ast.If`` whose body returns False and whose else branch returns True
    (as decided by the two helper methods on ``self``), a message with the
    line number and the unparsed node are printed.  Always returns None.
    """
    if isinstance(node, ast.If):
        negation_only = (self.return_false_in_if_body(node.body)
                         and self.return_true_in_else_body(node.orelse))
        if negation_only:
            print("Unnecessary conditional - simply return negation of test - line {}".format(node.lineno))
            print(astunparse.unparse(node))
def assertFilesEqual(expected, actual, replace_str="", replace_with="", python_files=False):
    """Raise AssertionError when two files differ after normalisation.

    On every line of both files, *replace_str* is replaced by *replace_with*
    and trailing whitespace is removed.  With *python_files*, both contents
    are additionally round-tripped through ``ast.parse`` and
    ``astunparse.unparse`` so that pure formatting differences are ignored.
    On a mismatch, the replacement pair is logged and the error message
    includes a unified diff.
    """
    def normalised(handle):
        return [line.replace(replace_str, replace_with).rstrip() for line in handle.readlines()]

    def canonical(lines):
        return astunparse.unparse(ast.parse('\n'.join(lines))).split('\n')

    with open(expected) as exp_handle, open(actual) as act_handle:
        act_lines = normalised(act_handle)
        exp_lines = normalised(exp_handle)
    if python_files:
        act_lines = canonical(act_lines)
        exp_lines = canonical(exp_lines)
    delta = list(difflib.unified_diff(exp_lines, act_lines))
    if not delta:
        return
    ROOT_LOGGER.info("Replacements are: %s => %s", replace_str, replace_with)
    msg = "Failed asserting that two files are equal:\n%s\nversus\n%s\nDiff is:\n\n%s"
    raise AssertionError(msg % (actual, expected, "\n".join(delta)))
def test_program_doesnt_have_if_false(self, node):
    u"""Report ``if False:`` blocks, whose body can never run.

    A bare ``False`` test is recognised both as a ``Name`` (Python 2 parser)
    and as a constant (Python 3 parser).  Any other test is ignored; the
    original implementation raised AttributeError for tests without ``.id``.
    On a match, a message with the line number and the unparsed node are
    printed.  Always returns None.
    """
    if not isinstance(node, ast.If):
        return
    test = node.test
    is_false_name = getattr(test, 'id', None) == "False"
    # Constant/NameConstant expose the literal through ``.value``; for other
    # node kinds ``.value`` (if present) is a child node, never the False object.
    is_false_constant = getattr(test, 'value', None) is False
    if not (is_false_name or is_false_constant):
        return
    # ast.If always has a ``test`` field, so the original
    # "Unecessary if False statement" branch (guarded by
    # ``not hasattr(node, 'test')``) was unreachable and has been dropped.
    print("Dead code in if block - line {}".format(node.lineno))
    print(astunparse.unparse(node))
def collect_names(node):
    """Return the ordered set of leaf operands of a boolean expression.

    ``BoolOp`` and ``UnaryOp`` nodes (matched by class name) are descended
    into.  Every other node is a leaf and is wrapped as
    ``Name(<stripped unparsed source>, node)``.
    """
    kind = node.__class__.__name__
    if kind == 'UnaryOp':
        return OrderedSet() | collect_names(node.operand)
    if kind != 'BoolOp':
        return OrderedSet() | OrderedSet([Name(astunparse.unparse(node).strip(), node)])
    leaves = OrderedSet()
    for operand in node.values:
        if operand.__class__.__name__ in ('BoolOp', 'UnaryOp'):
            leaves = leaves | collect_names(operand)
        else:
            leaves = leaves | OrderedSet([Name(astunparse.unparse(operand).strip(), operand)])
    return leaves
def update_variable_domains(variable_domains, node_list):
    '''Update variable_domains with the domain size of each declared variable.

    Program variables are created in assignments of the form
    "{varName} = Var({size})" (likewise Param, Input and Output), optionally
    followed by a subscript, e.g. "{varName} = Input({size})[...]".  For each
    declaration, the literal integer {size} is stored under the unparsed
    target name.  The following are reported through error.fatal_error:
    re-declaring a name with a different size, a non-literal size, and a
    declaration without a size.  Extra size arguments are reported through
    error.error.
    '''
    declaration_kinds = ["Param", "Var", "Input", "Output"]
    for ch in node_list:
        if not (isinstance(ch, ast.Assign) and len(ch.targets) == 1):
            continue
        name = astunparse.unparse(ch.targets[0]).rstrip()
        rhs = ch.value
        if isinstance(rhs, ast.Call):
            call = rhs
        elif isinstance(rhs, ast.Subscript) and isinstance(rhs.value, ast.Call):
            call = rhs.value
        else:
            continue
        # Calls through an attribute (e.g. ``np.zeros(3)``) have no ``.id``;
        # they cannot be declarations, so skip them instead of crashing.
        decl_name = getattr(call.func, 'id', None)
        if decl_name not in declaration_kinds:
            continue
        args = call.args
        if not args:
            # Previously an IndexError on args[0].
            error.fatal_error('No size parameter in variable declaration of "%s".' % (name), ch)
            continue
        if len(args) > 1:
            error.error('More than one size parameter in variable declaration of "%s".' % (name), ch)
        size = args[0]
        if isinstance(size, ast.Num):
            if name in variable_domains and variable_domains[name] != size.n:
                error.fatal_error("Trying to reset the domain of variable '%s' to '%i' (old value '%i')." % (name, size.n, variable_domains[name]), size)
            variable_domains[name] = size.n
        else:
            error.fatal_error("Trying to declare variable '%s', but size parameters '%s' is not understood." % (name, astunparse.unparse(size).rstrip()), size)
def main():
    """Rename identifiers in each file named on the command line and print the result.

    ``env`` starts as a map from every builtin name to itself, so builtins
    keep their names.  It is threaded through successive ``replaceIdents``
    calls so that renames stay consistent across files.
    """
    try:
        import builtins as builtin_module
    except ImportError:  # Python 2
        import __builtin__ as builtin_module
    # ``__builtins__`` is the builtins *dict* (not the module) in any module
    # other than __main__, where dir() would list dict methods instead.
    env = {name: name for name in dir(builtin_module) + ['__builtins__']}
    for path in sys.argv[1:]:
        new_code, env = replaceIdents(path, env)
        print(astunparse.unparse(new_code))
def truth_table(node):
    """Build a truth table for the boolean expression *node*.

    Returns a list of tuples.  The first (header) row holds the string form
    of each operand from ``collect_names`` followed by the expression's
    source.  Each further row is one False/True assignment to the operands,
    followed by the value of ``evaluate`` for that assignment.
    """
    names = collect_names(node)
    # list(...) rather than map(...) + [...]: on Python 3 map() is an
    # iterator and cannot be concatenated with a list.
    header = tuple([str(name) for name in names] + [astunparse.unparse(node).strip()])
    rows = [header]
    for condition in itertools.product([False, True], repeat=len(names)):
        rows.append(condition + (evaluate(node, names, condition),))
    return rows
def function_from_source(source, globals_=None):
    """Build a Function object from source text that may carry __future__ imports.

    The source is unindented, parsed and compiled with ``dont_inherit=True``,
    so only the source's own ``__future__`` imports apply.  The module is
    evaluated with *globals_* (when None, ``eval`` falls back to this module's
    globals) and a fresh locals dict.  The first top-level ``def`` is taken
    from that dict, and its unparsed text is attached as ``_peval_source``.

    Raises ValueError if the source has no top-level function definition.
    """
    module = ast.parse(unindent(source))
    ast.fix_missing_locations(module)
    definitions = [stmt for stmt in module.body if type(stmt) == ast.FunctionDef]
    if not definitions:
        raise ValueError("No function definitions found in the provided source")
    tree = definitions[0]
    code_object = compile(module, '<nofile>', 'exec', dont_inherit=True)
    namespace = {}
    eval(code_object, globals_, namespace)
    function_obj = namespace[tree.name]
    function_obj._peval_source = astunparse.unparse(tree)
    return Function.from_object(function_obj)
def unparse(node, strip=None):
    """Return the source text of *node* as produced by astunparse.

    When *strip* is truthy, surrounding whitespace (including astunparse's
    trailing newline) is removed.  For ``BinOp`` nodes, the outer pair of
    parentheses that astunparse wraps around binary operations is dropped
    and the result is always stripped.
    """
    result = astunparse.unparse(node)
    if strip:
        result = result.strip()
    if isinstance(node, _ast.BinOp):
        # Strip first: with strip falsy, slicing the raw text with [1:-1]
        # removed the trailing newline instead of the closing ')'.
        inner = result.strip()
        if inner.startswith('(') and inner.endswith(')'):
            return inner[1:-1]
        return inner
    return result
def visit_TryExcept(self, node):
    """Report (code 7) a try statement with overly general exception handling.

    A report ``(7, fileName, lineno, -1)`` is appended to ``self.result`` when:
      * any handler's first statement is ``pass``; or
      * a handler catches a tuple that contains BaseException, Exception or
        StandardError; or
      * every handler is either bare (``except:``) or catches one of those
        names directly.
    The node is then passed to ``self.generic_visit``.
    """
    general_names = ["BaseException", "Exception", "StandardError"]
    report = (7, self.fileName, node.lineno, -1)
    all_general = True
    for handler in node.handlers:
        # Equivalent to unparsing the statement and comparing with "pass",
        # without rendering a potentially large compound statement.
        if isinstance(handler.body[0], _ast.Pass):
            self.result.append(report)
            self.generic_visit(node)
            return
        caught = handler.type
        if caught is None:
            continue  # a bare ``except:`` keeps the handler set general
        if isinstance(caught, _ast.Tuple):
            for element in caught.elts:
                if hasattr(element, "id") and element.id in general_names:
                    self.result.append(report)
                    self.generic_visit(node)
                    return
            all_general = False
        elif not (hasattr(caught, "id") and caught.id in general_names):
            all_general = False
    if all_general:
        self.result.append(report)
    self.generic_visit(node)
def p_push_primary(p):  # noqa
    """push_primary : DOLLARSIGN primary"""
    # The docstring above is the PLY grammar production and must stay verbatim.
    # Builds the statement that pushes p[2] onto the stack:
    #   * a bare Name is pushed as-is;
    #   * an attribute reference on the popped stack value is wrapped in concatify;
    #   * anything else becomes ConcatFunction(lambda stack, stash: <combined exprs>).
    pushed = p[2]
    if isinstance(pushed, ast.Name):
        pass  # TODO: not a very good check
    elif 'stack.pop().' in astunparse.unparse(pushed[0]):
        # we are pushing an attributeref
        # get rid of the _call to leave the stack.pop().<attr> and concatify it
        p[2] = ast.Call(func=ast.Name(id='concatify', ctx=ast.Load()),
                        args=pushed[0].value.args[0:1],
                        keywords=[])
    else:
        stack_arguments = ast.arguments(
            args=[ast.arg(arg='stack', annotation=None),
                  ast.arg(arg='stash', annotation=None)],
            vararg=None, kwonlyargs=[], kwarg=None, defaults=[], kw_defaults=[])
        p[2] = ast.Call(func=ast.Name(id='ConcatFunction', ctx=ast.Load()),
                        args=[ast.Lambda(stack_arguments, _combine_exprs(pushed))],
                        keywords=[])
    p[0] = [ast.Expr(_push(p[2]))]
    _set_line_info(p)
def __init__(self, node, filename):
    """Capture a definition node together with its rendered source.

    Stores the node, its ``lineno``, the file name and the node's ``name``.
    It also stores the raw astunparse output (``unparsed``), that output with
    surrounding whitespace removed (``source``), and the number of lines in
    ``source`` (``linecount``).
    """
    self.node = node
    self.lineno = node.lineno
    self.filename = filename
    self.name = node.name
    rendered = astunparse.unparse(node)
    self.unparsed = rendered
    # ''.join over a str rebuilds the same string, so strip it directly.
    self.source = rendered.strip()
    self.linecount = len(self.source.splitlines())
def get_name_and_const_from_test(test):
    """Split an ``==`` comparison into ``(variable node, numeric constant)``.

    Accepts ``var == N`` or ``N == var``, where ``var`` is a Name or Subscript
    node and ``N`` is a numeric literal.  Returns ``(None, None)`` for any
    other operand shapes.

    Raises Exception (message, test node) if the operator is not ``==``.
    """
    op = test.ops[0]
    if not isinstance(op, ast.Eq):
        # Comparison-operator nodes have no standalone source form, so they
        # are rendered from a table.  The original also passed an undefined
        # ``node`` here, which raised NameError instead of this Exception.
        symbols = {'NotEq': '!=', 'Lt': '<', 'LtE': '<=', 'Gt': '>', 'GtE': '>=',
                   'Is': 'is', 'IsNot': 'is not', 'In': 'in', 'NotIn': 'not in'}
        op_text = symbols.get(type(op).__name__, type(op).__name__)
        raise Exception("Tests in if can only use ==, not '%s'." % (op_text), test)
    left, right = test.left, test.comparators[0]
    variable_types = (ast.Name, ast.Subscript)
    if isinstance(left, variable_types) and isinstance(right, ast.Num):
        return (left, right.n)
    if isinstance(right, variable_types) and isinstance(left, ast.Num):
        return (right, left.n)
    return (None, None)
def visit_IfExp(self, node):
    """Report (code 10) long conditional expressions written on one line.

    Only applies when the ``else`` operand starts on the expression's own
    line.  The length is the stripped unparsed text, minus its space count,
    minus 2 (presumably the parentheses astunparse adds).  When the length
    reaches ``config['ifinrowexp']``, it is recorded in ``self.result``.
    """
    alternative = node.orelse
    if alternative and alternative.lineno == node.lineno:
        text = astunparse.unparse(node)
        length = len(text.strip()) - text.count(' ') - 2
        if length >= config['ifinrowexp']:
            self.result.append((10, self.fileName, node.lineno, length))
    self.generic_visit(node)
def get_var_name(node):
    """Return the base variable name of a Name or (nested) Subscript node.

    A ``Call`` base (as in ``Input()[10]``) yields None.  Any other base
    raises Exception.
    """
    base = node
    while isinstance(base, ast.Subscript):
        base = base.value
    if isinstance(base, ast.Name):
        return base.id
    if isinstance(base, ast.Call):
        return None  # This is the Input()[10] case
    raise Exception("Can't extract var name from '%s'" % (unparse(base).rstrip()))
def find_timeits(tree):
    """Collect the timeit functions contained in *tree*.

    :param ast.AST tree: The container of one or more timeit_funcs.
    :rtype: list[TimeitHolder]

    The ``tree`` may be:

    * an ast.Module containing the ``timeit_func``
    * an ast.Module containing a function containing the ``timeit_func``
    * the ``timeit_func`` directly as an ast.FunctionDef

    A timeit func is a function whose name starts with ``func_name_prefix``.
    When iterating a body, every non-timeit statement seen so far is
    accumulated and passed as the setup code of each later timeit func.

    Raises TypeError if *tree* (after unwrapping) has no ``body``.
    """
    def is_function(candidate):
        return isinstance(candidate, ast.FunctionDef)

    # unwrap a top module containing a single statement
    if isinstance(tree, ast.Module) and len(tree.body) == 1:
        tree = tree.body[0]
    if not hasattr(tree, 'body'):
        raise TypeError("Can't handle '%s'" % type(tree))

    # we might have the timeit func directly
    is_direct = (is_function(tree) and tree.name.startswith(func_name_prefix)
                 and not any(c for c in tree.body if is_function(c)))
    if is_direct:
        return [TimeitHolder(name=tree.name[len(func_name_prefix):],
                             setup="",
                             code=astunparse.unparse(tree.body))]

    found, setup = [], []
    for child in tree.body:
        if is_function(child) and child.name.startswith(func_name_prefix):
            found.append(TimeitHolder(name=child.name[len(func_name_prefix):],
                                      setup=astunparse.unparse(setup),
                                      code=astunparse.unparse(child.body)))
        else:
            setup.append(child)
    return found
def _display():
    """Redraw the screen with the unparsed source of the current node.

    Lines beyond the screen height are not drawn, and each line is clipped
    to one column less than the screen width.
    """
    # NOTE(review): ``globals`` must be a module-level mapping that shadows
    # the builtin function of the same name (the builtin is not subscriptable).
    screen = SCREEN
    node = globals['node']
    screen.erase()
    height, width = screen.getmaxyx()
    source_lines = unparse(node.node).split('\n')
    for row, text in zip(range(height), source_lines):
        screen.addnstr(row, 0, text, width - 1)
    screen.refresh()
def ast_to_source(tree):
    '''Return the Python source of an AST *tree* as a string.

    Trailing whitespace is removed from every line, and leading/trailing
    newlines are trimmed from the result.
    '''
    raw = astunparse.unparse(tree)
    # trim newlines and trailing spaces --- some pretty printers add it
    cleaned = [line.rstrip() for line in raw.split("\n")]
    return "\n".join(cleaned).strip("\n")
def assert_ast_equal(test_ast, expected_ast, print_ast=True):
    """Assert that two ASTs are equal, printing diagnostics when they are not.

    On a mismatch, a diff of the AST dumps is printed when *print_ast* is
    true, followed by a diff of the normalised unparsed sources.
    """
    equal = ast_equal(test_ast, expected_ast)
    if not equal:
        if print_ast:
            expected_dump = astunparse.dump(expected_ast)
            actual_dump = astunparse.dump(test_ast)
            print_diff(actual_dump, expected_dump)
        expected_src = normalize_source(unparse(expected_ast))
        actual_src = normalize_source(unparse(test_ast))
        print_diff(actual_src, expected_src)
    assert equal
def compile_and_run(filename, file_obj=None, debug=False, globals=None):
    """Parse, compile and execute a program, writing debug artefacts first.

    The program text is read from *file_obj* when given, otherwise from the
    file *filename*; *filename* is also used as the code object's file name.
    Before execution, the unparsed Python is written to ``debug.py`` and an
    AST dump to ``ast.out`` in the current directory.  The program runs in
    *globals*, or in a new empty dict when *globals* is falsy.
    """
    # The original ``with file_obj or open(...):`` never bound the opened
    # file, so ``file_obj.read()`` failed when file_obj was None.
    with (file_obj or open(filename, 'r')) as source_file:
        ast_ = parse.parse(source_file.read(), debug)
    ast.fix_missing_locations(ast_)
    with open('debug.py', 'w') as f:
        f.write(astunparse.unparse(ast_))
    with open('ast.out', 'w') as f:
        f.write('\n------------ AST DUMP ------------\n')
        f.write(astunparse.dump(ast_))
    prog = compile(ast_, filename, 'exec')
    exec(prog, globals or {})
def visit_FunctionDef(self,node):
    """Collect per-function metrics for a FunctionDef node.

    Appends to ``self.result``:
      * ``(1, fileName, lineno, argsCount)``: the number of comma-separated
        pieces of the unparsed argument list that contain no ``=``;
      * ``(2, fileName, lineno, len(lines))``: the number of distinct line
        numbers of descendant nodes.  Nested function/class definitions
        contribute their own line but are not descended into.  A
        function/class body's first statement is skipped when it is an
        expression statement (e.g. a docstring);
      * ``(3, fileName, lineno, maxlevel)``: the nesting depth of function
        definitions reached through direct children, only when > 1.
    Dunder-named functions (except ``__import__``/``__all__``) are also added
    to ``self.defmagic``.  Nested functions found while measuring depth are
    queued in ``self.scopenodes``; when those functions are visited later,
    metric 3 is not recomputed for them.

    NOTE(review): this block was reconstructed from whitespace-collapsed
    source.  The nesting around the ``defmagic`` check and inside the
    nested-scope loop was inferred -- confirm against the original file.
    """
    # argsCount
    def findCharacter(s,d):
        # Index of the first occurrence of d in s, or -1 when absent.
        try:
            value = s.index(d)
        except ValueError:
            return -1
        else:
            return value
    funcName = node.name.strip()
    p = re.compile("^(__[a-zA-Z0-9]+__)$")
    if p.match(funcName.strip()) and funcName != "__import__" and funcName != "__all__":
        self.defmagic.add((funcName,self.fileName,node.lineno))
    # NOTE(review): splitting the unparsed argument text on ',' miscounts
    # defaults that contain commas (e.g. ``b=(1, 2)``).  A function with no
    # arguments presumably unparses to just a newline and is counted as 1.
    stmt = astunparse.unparse(node.args)
    arguments = stmt.split(",")
    argsCount = 0
    for element in arguments:
        if findCharacter(element,'=') == -1:
            argsCount += 1
    self.result.append((1,self.fileName,node.lineno,argsCount))
    #function length
    lines = set()
    res = [node]
    # Breadth-first walk over descendants; `res` is the work queue.
    while len(res) >= 1:
        t = res[0]
        for n in ast.iter_child_nodes(t):
            # Skip nodes without a line number and a leading body expression
            # (typically a docstring) of a function/class.
            if not hasattr(n,'lineno') or ((isinstance(t,_ast.FunctionDef) or isinstance(t,_ast.ClassDef)) and n == t.body[0] and isinstance(n,_ast.Expr)):
                continue
            lines.add(n.lineno)
            # Nested definitions count their own line but are not expanded.
            if isinstance(n,_ast.ClassDef) or isinstance(n,_ast.FunctionDef):
                continue
            else:
                res.append(n)
        del res[0]
    self.result.append((2,self.fileName,node.lineno,len(lines)))
    #nested scope depth
    # A function already queued by an enclosing function's depth scan has
    # been accounted for; skip metric 3 for it.
    if node in self.scopenodes:
        self.scopenodes.remove(node)
        self.generic_visit(node)
        return
    dep = [[node,1]] #node,nestedlevel
    maxlevel = 1
    while len(dep) >= 1:
        t = dep[0][0]
        currentlevel = dep[0][1]
        for n in ast.iter_child_nodes(t):
            if isinstance(n,_ast.FunctionDef):
                self.scopenodes.append(n)
                dep.append([n,currentlevel+1])
            # Any function that has child nodes raises maxlevel to its level.
            maxlevel = max(maxlevel,currentlevel)
        del dep[0]
    if maxlevel>1:
        self.result.append((3,self.fileName,node.lineno,maxlevel))
    #DOC
    self.generic_visit(node)
def extend_subscript_for_input(node, extension):
    """Prepend *extension* to the index of the subscript *node*, in place.

    ``x[i]`` becomes ``x[extension, i]`` and ``x[i, j]`` becomes
    ``x[extension, i, j]``.  Both AST layouts are supported: before Python 3.9
    the index is wrapped in ``ast.Index``; from 3.9 on it is stored directly
    in ``node.slice``.  Slices (``x[a:b]``) are not supported.

    Returns the mutated *node*.  Raises Exception for unsupported indexing.
    """
    legacy_index = getattr(ast, 'Index', None)
    current = node.slice
    if legacy_index is not None and isinstance(current, legacy_index):
        wrapped = True
        idx = current.value
    elif (isinstance(current, ast.expr)
          and not isinstance(current, ast.Slice)
          and not (isinstance(current, ast.Tuple)
                   and any(isinstance(e, ast.Slice) for e in current.elts))):
        # Python 3.9+: the index expression sits directly in node.slice.
        wrapped = False
        idx = current
    else:
        raise Exception("Unhandled node indexing: '%s'" % (unparse(node).rstrip()))
    if isinstance(idx, ast.Tuple):
        new_idx = ast.Tuple([extension] + idx.elts, ast.Load())
    else:
        new_idx = ast.Tuple([extension, idx], ast.Load())
    if wrapped:
        current.value = new_idx
    else:
        node.slice = new_idx
    return node
def visit_IfExp(self, node):
    """Record (code 10) the length and line span of a conditional expression.

    The length is the stripped unparsed text, minus its space count, minus 2.
    The span is one more than the largest line offset of any descendant
    (including the node itself) from the expression's first line.
    """
    text = astunparse.unparse(node)
    length = len(text.strip()) - text.count(' ') - 2
    offsets = [child.lineno - node.lineno for child in ast.walk(node) if hasattr(child, 'lineno')]
    span = max([0] + offsets) + 1
    self.result.append((10, self.fileName, node.lineno, length, span))
    self.generic_visit(node)
def xproto_fol_to_python_validator(policy, fol, model, message, tag=None):
    """Generate Python source for a validator implementing a FOL policy.

    The policy formula *fol* is reduced with ``FOL2Python.hoist_outer``, and
    a validation function is generated from the result and unparsed.
    Note: *model* and *tag* are accepted but not used here; the generated
    function is always created with ``tag='validator'``.

    Raises Exception when *fol* is an undefined template value, and
    TrivialPolicy when a non-trivial formula reduces to 'True' or 'False'.
    """
    if isinstance(fol, jinja2.Undefined):
        raise Exception('Could not find policy:', policy)
    translator = FOL2Python()
    reduced = translator.hoist_outer(fol)
    is_trivial = reduced in ['True', 'False'] and fol != reduced
    if is_trivial:
        details = {'name': policy, 'reduced': reduced}
        raise TrivialPolicy("Policy %(name)s trivially reduces to %(reduced)s. If this is what you want, replace its contents with %(reduced)s" % details)
    validator_ast = translator.gen_validation_function(reduced, policy, message, tag='validator')
    return astunparse.unparse(validator_ast)
def translation_eq(f, truth, print_f=False):
    """Test helper: assert that ``f``'s compiled translation matches *truth*.

    ``f.compile()`` is called and ``f.translation`` is unparsed.  The result
    must equal *truth* (Python source) after dedenting it and wrapping it in
    a leading and a trailing newline.  With *print_f*, the translation is
    printed first.
    """
    f.compile()
    rendered = astunparse.unparse(f.translation)
    if print_f:
        print(rendered)
    expected = "\n" + textwrap.dedent(truth) + "\n"
    assert rendered == expected
def get_variable_domain(node):
    """Return the domain size of the variable referenced by *node*.

    *node* may be a variable-name string, an ``ast.Name``, or another
    expression.  A ``Subscript`` is first looked up by its base name.  Any
    expression not found yet is looked up by its full unparsed text.
    Lookups use the module-level ``variable_domains`` mapping; when nothing
    matches, ``error.fatal_error`` is called.
    """
    # Look up the number of values to switch on.
    if isinstance(node, ast.Name):
        return variable_domains[node.id]
    if isinstance(node, str):
        return variable_domains[node]
    if isinstance(node, ast.Subscript):
        # A base such as ``a.b`` in ``a.b[i]`` has no ``.id``; fall through to
        # the full-text lookup instead of raising AttributeError.
        base_name = getattr(node.value, 'id', None)
        if base_name is not None and base_name in variable_domains:
            return variable_domains[base_name]
    node_name = astunparse.unparse(node).rstrip()
    if node_name in variable_domains:
        return variable_domains[node_name]
    error.fatal_error("No variable domain known for expression '%s', for which we want to unroll a for/with." % (node_name), node)
def create_module(self, module, prefix, url, members=None, imports=None):
    """Generate an ontology module's source and write it to the file *module*.

    *prefix* is converted to ``str`` with ``-`` replaced by ``_``; its
    upper-cased form is the class name.  Both are passed, with *url*,
    *imports* and *members*, to ``create_ont_module``, and the resulting AST
    is unparsed into the file.
    """
    prefix = str(prefix).replace('-', '_')
    class_name = prefix.upper()
    # Same output as the Python 2 statement ``print "module create:" + module, class_name``.
    print("module create:%s %s" % (module, class_name))
    module_ast = create_ont_module(class_name, prefix, url, imports, members)
    code = astunparse.unparse(module_ast)
    # ``with`` closes the file even if the write fails.
    with open(module, "w") as out:
        out.write(code)
def set_source(self, source, function=None):
    """Record the source text of *function* (default: ``self.value``).

    *source* is a string or an ``ast.Module`` containing exactly one
    top-level statement, which must be a function definition.  Its unparsed
    text, with surrounding newlines stripped, is stored in
    ``self.source[function]``.  Violations raise AssertionError.
    """
    target = self.value if function is None else function
    assert isinstance(target, types.FunctionType)
    tree = ast.parse(source) if isinstance(source, str) else source
    assert isinstance(tree, ast.Module)
    assert isinstance(tree.body, list) and len(tree.body) == 1
    definition = tree.body[0]
    assert isinstance(definition, ast.FunctionDef)
    self.source[target] = astunparse.unparse(definition).strip('\n')
def t3st_unparse():
    """Print the round-tripped source and the AST dump of ``test_unparse_ast``."""
    tree = ast.parse(inspect.getsource(test_unparse_ast))
    print(astunparse.unparse(tree))
    # get a pretty-printed dump of the AST
    print(astunparse.dump(tree))
def test_lambda_copy_nested_captured():
    """The outer lambda's argument is renamed to arg_0, including inside the inner lambda."""
    expected = "(lambda arg_0: (lambda a: (a + arg_0)))"
    _, copied = util_run_parse("lambda b: (lambda a: a+b)")
    assert unparse(copied).strip() == expected
def check_test(test):
    """Validate a ``name == constant`` comparison and return its parts.

    Every violated rule appends a CheckError to ``self.__messages`` (``self``
    and ``node`` come from the enclosing scope).  The rules are: the operator
    must be ``==``; the left operand must be a Name; there must be exactly
    one comparator, and it must be a numeric literal.  Returns
    ``(left.id, constant)``.

    NOTE(review): the return statement still raises AttributeError when the
    left operand or comparator check failed above.
    """
    op = test.ops[0]
    if not(isinstance(op, ast.Eq)):
        # Comparison-operator nodes have no standalone source form, so they
        # are rendered from a table instead of through astunparse.
        symbols = {'NotEq': '!=', 'Lt': '<', 'LtE': '<=', 'Gt': '>', 'GtE': '>=',
                   'Is': 'is', 'IsNot': 'is not', 'In': 'in', 'NotIn': 'not in'}
        op_text = symbols.get(type(op).__name__, type(op).__name__)
        self.__messages.append(CheckError("Tests can only use ==, not '%s'." % (op_text), node))
    if not(isinstance(test.left, ast.Name)):
        self.__messages.append(CheckError("Tests have to have identifier, not '%s' as left operand." % (astunparse.unparse(test.left).rstrip()), node))
    if len(test.comparators) != 1:
        self.__messages.append(CheckError("Tests cannot have multiple comparators in test '%s'." % (astunparse.unparse(test).rstrip()), node))
    if not(isinstance(test.comparators[0], ast.Num)):
        self.__messages.append(CheckError("Tests have to have constant, not '%s' as right operand." % (astunparse.unparse(test.comparators[0]).rstrip()), node))
    return (test.left.id, test.comparators[0].n)
def unparse(node):
    """Render an AST node as Python source with surrounding whitespace removed.

    Returns None when *node* is None.
    """
    return None if node is None else astunparse.unparse(node).strip()
def __call__(self, formula):
    """Parse *formula*, run ``self.visitor`` over it, and return the stripped source."""
    tree = ast.parse(formula)
    self.visitor.visit(tree)
    rendered = unparse(tree)
    return rendered.strip()
def to_source(self):
    """Return the Python source of ``self.parsed`` as rendered by astunparse."""
    tree = self.parsed
    return astunparse.unparse(tree)
def ast_build(exprs):
    """Wrap the statement list *exprs* in a Module node and return its source."""
    module_node = Module(body=exprs)  # named so it does not shadow the ``ast`` module
    return unparse(module_node)
return node return self.prepend_print_line(node) def visit_If(self, node): if len(node.children) > 0: node = self.generic_visit(node) return self.prepend_print_line(node) # def visit_Expr(self, node): # # if isinstance(node, ast.BinOp): # # return self.append_print_line(node) # # return node tree = ParentChildNodeTransformer().visit(tree) print("***************** New Tree ****************") MyTransformer().visit(tree) ast.fix_missing_locations(tree) print(astunparse.unparse(tree)) with open('my-instrumented-program.py', 'w') as file: file.write(astunparse.unparse(tree)) print("Done and Bye!")
args=( self.visit(node.targets[0]), self.visit(node.value), ), keywords=(), starargs=(), kwargs=(), ), node) def visit_While(self, node): # print astpretty.pprint(node) return copy_location(Call( func=Name(id="tf.while_loop", ctx=Load()), args=( Lambda( args=arguments(args=[],defaults=[],vararg=[],kwarg=[]), body=self.visit(node.test), ), map(self.visit, node.body), ), keywords=(), starargs=(), kwargs=(), ), node) myast = ast.parse(txt) myast = RewriteName().visit(myast) # print astpretty.pprint(myast) print(astunparse.unparse(myast))
def unparse(something):
    """Return the astunparse rendering of *something* without trailing newlines."""
    source = astunparse.unparse(something)
    return source.rstrip("\n")
newName = self.get_variable_name(len(self.seen)) self.seen[oldName] = newName return newName if __name__ == "__main__": with open(sys.argv[1]) as f: tree = ast.parse(f.read()) obfuscator = Obfuscator() obfuscator.visit(tree) random.seed('𝒙') shuffled = list(range(len(obfuscator.constants))) random.shuffle(shuffled) shuffled_constants = [ obfuscator.constants[shuffled.index(i)] for i in range(len(obfuscator.constants)) ] out = """ import random as X X.seed('𝒙') 𝒙 = [""" + ",".join(shuffled_constants) + """] X.shuffle(𝒙) """ + astunparse.unparse(tree) with open(sys.argv[2], "w") as f: f.write(out)
def json_pprint(data, stream=None):
    """Pretty-print *data* as JSON-like text, using astunparse's layout.

    The JSON text is parsed as a Python expression and unparsed.  A regex
    then removes the trailing comma and line break before a closing bracket
    or brace that is followed by a non-space character.  *stream* defaults
    to the ``sys.stdout`` in effect at call time.
    """
    if stream is None:
        # Resolved per call: a ``sys.stdout`` default is bound once at
        # definition time and ignores later redirection.
        stream = sys.stdout
    compact = json.dumps(data)
    tree = ast.parse(compact)
    pretty = astunparse.unparse(tree)
    pretty = re.sub(r',\n\s+([\]}][^ ])', r'\1', pretty)
    stream.write(pretty)
def forking_transform(src):
    """Apply ForkingTransformer to the Python source *src* and return the new source."""
    transformer = ForkingTransformer()
    tree = ast.parse(src)
    return astunparse.unparse(transformer.visit(tree))
def check_roundtrip(self, code1, filename="internal", mode="exec"):
    """Assert that unparsing the AST of *code1* and reparsing it yields an equal AST."""
    flags = ast.PyCF_ONLY_AST
    first_tree = compile(str(code1), filename, mode, flags)
    regenerated = astunparse.unparse(first_tree)
    second_tree = compile(regenerated, filename, mode, flags)
    self.assertASTEqual(first_tree, second_tree)
def dfs(a):
    """Recursively rewrite a PyTorch AST into its Jittor equivalent, in place.

    Returns one of:
      * a replacement node for *a*;
      * the string ``'delete'`` when *a* (an import) should be removed from
        its enclosing list;
      * None to keep *a*, possibly mutated.
    Relies on the module-level ``import_flag``, ``pjmap``, ``unsupport_ops``,
    ``raise_unsupport``, ``convert_`` and ``replace``.
    """
    if isinstance(a, ast.Import):
        import_src = astunparse.unparse(a)
        if 'torch' in import_src and 'init' in import_src:
            import_flag.append('init')
            return ast.parse('from jittor import init').body[0]
        if 'torch' in import_src and 'nn' in import_src:
            import_flag.append('nn')
            return ast.parse('from jittor import nn').body[0]
        if a.names[0].name == 'torch':
            return 'delete'
    elif isinstance(a, ast.ImportFrom):
        # ``from . import x`` has module None; guard before the substring test.
        if a.module and 'torch' in a.module:
            return 'delete'
    elif isinstance(a, ast.Call):
        for idx, ag in enumerate(a.args):
            ret = dfs(ag)
            if ret is not None:
                a.args[idx] = ret
        for idx, kw in enumerate(a.keywords):
            ret = dfs(kw)
            if ret is not None:
                a.keywords[idx] = ret
        func = astunparse.unparse(a.func).strip('\n').split('.')
        prefix = '.'.join(func[0:-1])
        func_name = func[-1]
        if func_name in unsupport_ops:
            raise_unsupport(func_name)
        if func_name in pjmap.keys():
            ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
            kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
            ret = convert_(prefix, func_name, ags, kws)
            return ast.parse(ret).body[0].value
        if ".load_state_dict" in astunparse.unparse(a.func):
            a.func.attr = 'load_parameters'
        # Re-render: a.func may have just been renamed above.
        if astunparse.unparse(a.func).strip('\n').endswith(".size"):
            ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
            if len(ags) != 0:
                con = astunparse.unparse(a.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']'
            else:
                con = astunparse.unparse(a.func).replace('size', 'shape')
            return ast.parse(con).body[0].value
    elif isinstance(a, ast.Expr):
        pass
    elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name):
        replace(a)
    elif isinstance(a, ast.FunctionDef):
        if a.name == 'forward':
            a.name = 'execute'
    if hasattr(a, '__dict__'):
        for k in a.__dict__.keys():
            value = a.__dict__[k]
            if isinstance(value, list):
                # Build a new list instead of deleting while enumerating: the
                # original ``del lst[i]`` inside the loop skipped the element
                # after each deleted one and then mis-filtered the remainder.
                kept = []
                for item in value:
                    ret = dfs(item)
                    if isinstance(ret, str) and ret == 'delete':
                        continue
                    kept.append(item if ret is None else ret)
                a.__dict__[k] = kept
            else:
                ret = dfs(value)
                if ret is not None:
                    a.__dict__[k] = ret
def _unparse_clean(node):
    """Unparse *node* and delete every match of the module-level ``pattern`` regex."""
    return pattern.sub('', astunparse.unparse(node))


def _construct_statements(statements, parent):
    """Attach the statement list *statements* under *parent*.

    A bare list cannot be walked by ``ast_construct``, so it is unparsed to
    source and re-parsed as a module first.
    """
    ast_construct(ast.parse(astunparse.unparse(statements)), parent)


def _construct_branches(child, owner, value):
    """Attach the body (and else block, if any) of an If/For node under *owner*.

    With an else block, 'Then' and 'Else' children carrying *value* are
    created (nodes without a value of their own inherit the parent's);
    otherwise the body goes directly under *owner*.
    """
    if child.orelse:
        node_then = Node('Then', value)
        owner.insert_child(node_then)
        _construct_statements(child.body, node_then)
        node_else = Node('Else', value)
        owner.insert_child(node_else)
        _construct_statements(child.orelse, node_else)
    else:
        _construct_statements(child.body, owner)


def _classify_arguments(source_args):
    """Sort a cleaned, comma-separated argument list into parameter kinds.

    Returns ``(args, vararg, kwonlyargs, kwargs, defaults, kw_defaults)``:
    positional args, ``*args``, keyword-only args, ``**kwargs``, positional
    args with defaults, and keyword-only args with defaults.  Items after a
    ``*`` marker are treated as keyword-only.

    NOTE(review): splitting on ',' mis-splits defaults that contain commas
    (e.g. ``b=(1, 2)``), and any default containing ``*`` is taken as a
    vararg.  Kept for output compatibility.
    """
    args, vararg, kwonlyargs, kwargs, defaults, kw_defaults = [], [], [], [], [], []
    after_star = False
    if source_args != "":
        for item in source_args.split(','):
            if '**' in item:
                kwargs.append(item[2:])
            elif '*' in item:
                vararg.append(item[1:])
                after_star = True
            elif '=' in item and after_star:
                kw_defaults.append(item)
            elif after_star:
                kwonlyargs.append(item)
            elif '=' in item:
                defaults.append(item)
            else:
                args.append(item)
    return args, vararg, kwonlyargs, kwargs, defaults, kw_defaults


def _construct_return(child, parent):
    """Attach a 'Return' node, with one 'Return_Value' child per returned component."""
    n_return = Node('Return', 'return')
    parent.insert_child(n_return)
    # If there is a return value, process it.
    if child.value:
        return_value = _unparse_clean(child.value)
        if return_value[0] == '(' and return_value[-1] == ')':
            return_value = return_value[1:-1]
        for item in return_value.split(','):
            n_return.insert_child(Node('Return_Value', item))


def _construct_function(child, parent):
    """Attach a 'FunctionDef' node with its parameter nodes and body."""
    n_functiondef = Node('FunctionDef', pattern.sub('', child.name))
    parent.insert_child(n_functiondef)
    # Parameters: positional args, *args, keyword-only args, **kwargs, and defaults.
    groups = _classify_arguments(_unparse_clean(child.args))
    # NOTE(review): 'kwnolyarg' looks like a typo of 'kwonlyarg'; kept because
    # consumers may match on the emitted label.
    labels = ('args', 'vararg', 'kwnolyarg', 'kwargs', 'default_args', 'kw_default')
    for label, items in zip(labels, groups):
        for item in items:
            n_functiondef.insert_child(Node(label, item))
    # Function body.
    _construct_statements(child.body, n_functiondef)


def _construct_class(child, parent):
    """Attach a 'ClassDef' node with an optional 'base' child and its body."""
    node_classdef = Node('ClassDef', child.name)
    parent.insert_child(node_classdef)
    # Base classes (only names/bases; keywords, starargs etc. are ignored).
    # Unparsing the list directly concatenates multiple bases ("AB"), so
    # render each one and join with commas, keeping the trailing newline
    # astunparse emits for the single-base case.
    bases_text = ', '.join(astunparse.unparse(base).rstrip('\n') for base in child.bases) + '\n'
    bases = pattern.sub('', bases_text)
    if bases != '':
        node_classdef.insert_child(Node('base', bases))
    # Class body.
    _construct_statements(child.body, node_classdef)


def ast_construct(astbody, parent):
    """Build a simplified ``Node`` tree under *parent* from the children of *astbody*.

    Assign, Break, Expr/Call, AugAssign, Return, If, FunctionDef, ClassDef,
    For and While statements are converted; other node kinds are ignored.
    Node values are unparsed source with ``pattern`` matches removed.
    """
    for child in ast.iter_child_nodes(astbody):
        if isinstance(child, ast.Assign):
            parent.insert_child(Node('Assign', _unparse_clean(child)))
        elif isinstance(child, ast.Break):
            parent.insert_child(Node('Break', 'break'))
        elif isinstance(child, ast.Expr):
            if isinstance(child.value, ast.Call):
                parent.insert_child(Node('Call', _unparse_clean(child.value)))
            else:
                parent.insert_child(Node('Expr', _unparse_clean(child)))
        elif isinstance(child, ast.AugAssign):
            parent.insert_child(Node('AugAssign', _unparse_clean(child)))
        elif isinstance(child, ast.Return):
            _construct_return(child, parent)
        elif isinstance(child, ast.If):
            # The If node's value is its condition.
            value = _unparse_clean(child.test)
            node_if = Node('If', value)
            parent.insert_child(node_if)
            _construct_branches(child, node_if, value)
        elif isinstance(child, ast.FunctionDef):
            _construct_function(child, parent)
        elif isinstance(child, ast.ClassDef):
            _construct_class(child, parent)
        elif isinstance(child, ast.For):
            # The For node's value is its iterable.
            value = _unparse_clean(child.iter)
            node_for = Node('For', value)
            parent.insert_child(node_for)
            _construct_branches(child, node_for, value)
        elif isinstance(child, ast.While):
            node_while = Node('While', _unparse_clean(child.test))
            parent.insert_child(node_while)
            _construct_statements(child.body, node_while)
def _replace_class_methods(module_node, class_names, replacements):
    """Replace, in place, the methods of classes named in *class_names*.

    Each FunctionDef directly inside a matching top-level ClassDef whose name
    is a key of *replacements* is swapped for the mapped node.
    """
    for cls in module_node.body:
        if not (isinstance(cls, _ast.ClassDef) and cls.name in class_names):
            continue
        for pos, member in enumerate(cls.body):
            if isinstance(member, _ast.FunctionDef) and member.name in replacements:
                cls.body[pos] = replacements[member.name]


def _overwrite_with_source(handle, tree):
    """Replace the entire content of the open file *handle* with the source of *tree*."""
    source = astunparse.unparse(tree)
    handle.seek(0, 0)
    handle.truncate()
    handle.write(source)


def replace_validator():
    """Patch framework modules in place with stub methods from the performance package.

    The stubs in ``Framework_Performance/validator_and_others.py`` are looked
    up by fixed positions in its top-level body.  They replace same-named
    methods in validator.py (HostValidator, ScriptValidator), host.py
    (Build), report.py (Report) and execution_engine.py (ExecutionEngine).
    Each patched module is rewritten from its modified AST.
    """
    with open(root+'/Framework_kernel/validator.py', 'r+', encoding='utf-8') as f1, \
            open(root+'/Framework_Performance/validator_and_others.py', 'r+', encoding='utf-8') as f2, \
            open(root+'/Framework_kernel/host.py', 'r+', encoding='utf-8') as f3, \
            open(root+'/Framework_kernel/report.py', 'r+', encoding='utf-8') as f4, \
            open(root+'/Framework_kernel/execution_engine.py', 'r+', encoding='utf-8') as f5:
        validator_node = ast.parse(f1.read())
        fake_validator_node = ast.parse(f2.read())
        host_node = ast.parse(f3.read())
        report_node = ast.parse(f4.read())
        execution_engine_node = ast.parse(f5.read())
        # NOTE(review): these indices depend on the exact layout of
        # validator_and_others.py.
        validator_map_dict = {
            'validate_jenkins_server': fake_validator_node.body[5].body[0],
            'validate_build_server': fake_validator_node.body[5].body[1],
            '__validate_QTP': fake_validator_node.body[5].body[2],
            '__validate_HPDM': fake_validator_node.body[5].body[3],
            'validate_uut': fake_validator_node.body[5].body[4],
            'validate_ftp': fake_validator_node.body[5].body[5],
            'validate': fake_validator_node.body[6].body[0],
            'build_task': fake_validator_node.body[-3],
            '__load_uut_result': fake_validator_node.body[-2],
            'send_report': fake_validator_node.body[-1],
        }
        # skip validate
        for cls in validator_node.body:
            if isinstance(cls, _ast.ClassDef) and cls.name == 'HostValidator':
                print(cls)
        _replace_class_methods(validator_node, ('HostValidator', 'ScriptValidator'), validator_map_dict)
        _overwrite_with_source(f1, validator_node)
        # host: skip generating the script yml
        _replace_class_methods(host_node, ('Build',), validator_map_dict)
        _overwrite_with_source(f3, host_node)
        # report: cancel the error handler
        _replace_class_methods(report_node, ('Report',), validator_map_dict)
        _overwrite_with_source(f4, report_node)
        # execution_engine: cancel email
        _replace_class_methods(execution_engine_node, ('ExecutionEngine',), validator_map_dict)
        _overwrite_with_source(f5, execution_engine_node)
def save(self):
    """Write the source regenerated from ``self._source_tree`` to ``self.file``."""
    source_text = astunparse.unparse(self._source_tree)
    with open(self.file, 'w') as handle:
        handle.write(source_text)
print("set arg to 'hello world'") args[0] = ast.Constant(s="hello world", kind=None) ast.fix_missing_locations(args[0]) print("add another arg") last_index = len(args) print("last index: ", last_index) args.insert(last_index, ast.Constant(s="!!!!!", kind=None)) ast.fix_missing_locations(args[last_index]) print("add test call one more time.") node = ast.parse(add_arg).body[0].value args.insert(last_index + 1, node) ast.fix_missing_locations(args[last_index + 1]) v = Visitor() v.visit(tree) co = compile(tree, "log.log", "exec") exec(co) import astunparse print(astunparse.unparse(tree))
def err():
    """Raise ValueError describing the unsupported slice of ``subscript`` (from the enclosing scope)."""
    template = 'Unexpected kind of slice: {}'
    raise ValueError(template.format(astunparse.unparse(subscript)))
def visit_Assign(self, state, assgn):
    """Check one assignment statement and thread the analysis state through it.

    ``state`` is a ``(parameters, initialised, used)`` triple of sets of
    variable names. The returned triple is the updated state, or ``state``
    itself when the assignment does not change it. Problems are appended to
    ``self.__messages`` as ``CheckError`` entries rather than raised.

    Handled forms of ``name = value``:

    * numeric literal: an error if ``name`` is a declared model variable;
    * ``Var``/``Param``/``Input``/``Output`` call: declares ``name`` with the
      domain given by the first (constant) argument;
    * call to a function in ``self.__defined_functions``: visits the
      arguments and marks ``name`` as initialised;
    * anything else assigned to a declared model variable is an error.
    """
    (parameters, initialised, used) = state
    if len(assgn.targets) > 1:
        self.__messages.append(CheckError("Cannot process tuple assignment in '%s'." % (astunparse.unparse(assgn).rstrip()), assgn))
    target = assgn.targets[0]
    value = assgn.value
    if isinstance(target, ast.Name):
        assignedVar = target.id
        if isinstance(value, ast.Num):
            if self.__is_declared(assignedVar):
                self.__messages.append(CheckError("Trying to assign num literal '%i' to model variable '%s'." % (value.n, assignedVar), assgn))
        elif isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
            if (value.func.id in ["Var", "Param", "Input", "Output"]):
                # The first positional argument is the domain size. Guard against a
                # call without arguments instead of raising IndexError.
                range_arg = value.args[0] if value.args else None
                if isinstance(range_arg, ast.Num):
                    self.__set_domain(assgn, assignedVar, range_arg.n)
                else:
                    # Show the offending argument. Without one, show the whole call.
                    # (The argument lives on the Call node, not on its ``func`` Name.)
                    shown = range_arg if range_arg is not None else value
                    self.__messages.append(CheckError("Cannot declare variable/parameter with non-constant range '%s'." % (astunparse.unparse(shown).rstrip()), assgn))
                if value.func.id == "Param":
                    return (parameters | {assignedVar}, initialised, used)
                if value.func.id == "Input":
                    return (parameters, initialised | {assignedVar}, used)
                if value.func.id in ["Output"]:
                    self.__outputs.append(assignedVar)
                    return (parameters, initialised, used | {assignedVar})
            elif value.func.id in self.__defined_functions:
                for argument in value.args:
                    (_, _, used) = self.visit((parameters, initialised, used), argument)
                return (parameters, initialised | {assignedVar}, used)
            else:
                self.__messages.append(CheckError("Cannot assign unknown function to variable '%s'." % (assignedVar), assgn))
        else:
            if self.__is_declared(assignedVar):
                self.__messages.append(CheckError("Trying to assign '%s' to model variable '%s'." % (astunparse.unparse(value).rstrip(), assignedVar), assgn))
    else:
        self.__messages.append(CheckError("Cannot assign value to non-variable '%s'." % (astunparse.unparse(target).rstrip()), assgn))
    return state
def generic_visit(self, state, node):
    """Report ``node`` as unsupported and pass ``state`` through unchanged.

    ``state`` is the ``(parameters, initialised, used)`` triple threaded
    through the other ``visit_*`` methods. Callers such as ``visit_Call`` and
    ``visit_Assign`` unpack the result of ``self.visit``, so this fallback
    (presumably what ``self.visit`` dispatches to for unknown node types)
    must return a triple rather than ``None``.
    """
    self.__messages.append(CheckError("AST node '%s' unsupported." % (astunparse.unparse(node).strip()), node))
    return state
def visit_If(self, node):
    """Visit an ``if``/``elif``/``else`` chain and make its ``else`` explicit.

    The head test is visited and the node is folded with
    ``eval_const_expressions``. If that no longer yields an ``ast.If``, the
    result is simply visited (a list via ``flat_map``) and returned.

    Otherwise the ``elif`` chain is walked. Each branch's test and body are
    visited, and a test of the form ``a or b or ...`` is split into one
    branch per operand, each with a deep copy of the body. A branch test that
    is not an ``==`` comparison stops the walk, and ``node`` is returned as
    processed so far.

    When all branches test the same variable (compared by unparsed text), the
    visited ``else`` body is replaced by a chain of explicit ``var == v``
    branches. There is one branch per value ``v`` in
    ``range(get_variable_domain(var))`` that no branch tested, and each gets
    its own copy of the else body. An empty else body is left alone.

    Returns:
        The (possibly new) head ``If`` node, or the visited result of a
        folded non-``If`` node.

    Raises:
        Exception: if the branches test more than one distinct variable.
    """
    node.test = self.visit(node.test)
    node = eval_const_expressions(node)
    if not(isinstance(node, ast.If)):
        if isinstance(node, list):
            return flat_map(self.visit, node)
        else:
            return self.visit(node)

    # We want to unroll the "else: P" bit into something explicit for
    # all the cases that haven't been checked explicitly yet.
    def get_or_cases(test):
        # Flatten nested ``or`` expressions into their operand tests.
        if not(isinstance(test, ast.BoolOp) and isinstance(test.op, ast.Or)):
            return [test]
        else:
            return itertools.chain.from_iterable(get_or_cases(value) for value in test.values)

    def get_name_and_const_from_test(test):
        # Return (variable node, numeric constant) for ``var == n`` / ``n == var``,
        # or (None, None) if the operands do not have that shape.
        # NOTE(review): a non-Compare test raises AttributeError on ``test.ops``;
        # both that and the explicit raise below are swallowed by the bare
        # ``except`` in the loop, so this message never reaches the user.
        if not(isinstance(test.ops[0], ast.Eq)):
            raise Exception("Tests in if can only use ==, not '%s'." % (astunparse.unparse(test.ops[0]).rstrip()), node)
        (name, const) = (None, None)
        if (isinstance(test.left, ast.Name) or isinstance(test.left, ast.Subscript)) and isinstance(test.comparators[0], ast.Num):
            (name, const) = (test.left, test.comparators[0].n)
        elif (isinstance(test.comparators[0], ast.Name) or isinstance(test.comparators[0], ast.Subscript)) and isinstance(test.left, ast.Num):
            (name, const) = (test.comparators[0], test.left.n)
        return (name, const)

    checked_values = set()
    checked_vars = set()

    # Now walk the .orelse branches, visiting each body independently. Expand ors on the way:
    last_if = None
    current_if = node
    while True:
        # Recurse with visitor first:
        # NOTE(review): the head test was already visited above, and or-split
        # branches (whose bodies are copies of an already-visited body) are
        # visited again on the next iteration. This is only harmless if
        # self.visit is idempotent; confirm.
        current_if.test = eval_const_expressions(self.visit(current_if.test))
        current_if.body = flat_map(self.visit, current_if.body)
        # Now, unfold ors:
        test_cases = list(get_or_cases(current_if.test))
        if len(test_cases) > 1:
            # Build the replacement chain back to front so each new If's orelse
            # points at the one built before it (the last keeps the original orelse).
            else_body = current_if.orelse
            new_if_node = None
            for test_case in reversed(test_cases):
                body_copy = copy.deepcopy(current_if.body)
                new_if_node = ast.copy_location(ast.If(test=test_case, body=body_copy, orelse=else_body), node)
                else_body = [new_if_node]
            current_if = new_if_node
            # Make the change stick:
            if last_if is None:
                node = current_if
            else:
                last_if.orelse = [current_if]

        # Do our deed:
        try:
            (checked_var, checked_value) = get_name_and_const_from_test(current_if.test)
            # NOTE(review): for an ``==`` test whose operands are not var/number,
            # this records None; what unparse does with it below is not visible here.
            checked_vars.add(checked_var)
            checked_values.add(checked_value)
            # Look at the next elif:
            if len(current_if.orelse) == 1 and isinstance(current_if.orelse[0], ast.If):
                last_if = current_if
                current_if = current_if.orelse[0]
            else:
                break
        except:
            # This may happen if we couldn't cleanly identify the else case. For this, just leave things as they are:
            return node

    # We need to stringify them, to not be confused by several instances referring to the same thing:
    checked_var_strs = set(astunparse.unparse(var) for var in checked_vars)
    if len(checked_var_strs) != 1:
        raise Exception("If-else checking more than one variable (%s)." % (checked_var_strs))
    checked_var = checked_vars.pop()
    domain_to_check = set(range(get_variable_domain(checked_var)))
    still_unchecked = domain_to_check - checked_values

    else_body = flat_map(self.visit, current_if.orelse)
    if len(else_body) == 0:
        return node
    # if len(still_unchecked) > 0:
    #     print("Else for values %s of %s:\n%s" % (still_unchecked, astunparse.unparse(checked_var).rstrip(), astunparse.unparse(else_body)))
    # Replace the else with a chain of ``checked_var == value`` branches, one per
    # untested value. The branch order follows the iteration order of the set.
    for value in still_unchecked:
        # print("Inserting case %s == %i in else unfolding." % (astunparse.unparse(checked_var).rstrip(), value))
        var_node = copy.deepcopy(checked_var)
        eq_node = ast.copy_location(ast.Eq(), node)
        value_node = ast.copy_location(ast.Num(n=value), node)
        test_node = ast.copy_location(ast.Compare(var_node, [eq_node], [value_node]), node)
        case_body = copy.deepcopy(else_body)
        new_if_node = ast.copy_location(ast.If(test=test_node, body=case_body, orelse=[]), node)
        current_if.orelse = [new_if_node]
        current_if = new_if_node
    return node
def __get_domain_of_expr(self, node):
    """Return the domain size of the value expression ``node``.

    * variable name: its declared domain, via ``self.__get_domain``;
    * call to a function registered in ``self.__defined_functions``: that
      function's declared value domain;
    * numeric literal ``n``: ``n + 1``, the smallest domain ``[0..n]`` that
      contains it.

    Anything else is reported as a ``CheckError`` and yields ``0``. This
    includes calls to unknown functions and calls whose callee is not a plain
    name, which previously raised KeyError/AttributeError.
    """
    if isinstance(node, ast.Name):
        return self.__get_domain(node, node.id)
    if isinstance(node, ast.Call):
        # Look the callee up defensively so an unknown function is reported,
        # not raised as KeyError.
        callee = node.func.id if isinstance(node.func, ast.Name) else None
        func_information = self.__defined_functions.get(callee)
        if func_information is not None:
            (_, value_domain) = func_information
            return value_domain
    elif isinstance(node, ast.Num):
        # Domains are zero-based ([0..domain-1]), so literal n needs domain n + 1.
        return node.n + 1
    self.__messages.append(CheckError("Cannot determine domain of value '%s' used in assignment." % (astunparse.unparse(node).rstrip()), node))
    return 0
def test_lambda_copy_simple():
    """A copied single-argument lambda gets its parameter renamed to arg_0."""
    original, copied = util_run_parse("lambda a: a")
    rendered = unparse(copied).strip()
    assert rendered == "(lambda arg_0: arg_0)"
    # The copy must differ structurally from the source tree.
    assert ast.dump(copied) != ast.dump(original)
def visit_Call(self, state, call):
    """Check a call expression and return the updated analysis state.

    ``state`` is a ``(parameters, initialised, used)`` triple of sets of
    variable names. Two shapes of call are understood:

    * methods on a named variable: ``x.set_to_constant(v)``, ``x.set_to(v)``,
      ``x.set_as_input()``, ``x.set_as_output()`` and ``x.observe_value(v)``;
    * calls to functions registered in ``self.__defined_functions``, whose
      argument domains are compared with the declared parameter domains.

    Problems are appended to ``self.__messages`` rather than raised. A call
    that cannot be analysed returns ``state`` unchanged, so callers that
    unpack the result keep working.
    """
    (parameters, initialised, used) = state
    if isinstance(call.func, ast.Attribute):
        func = call.func
        func_name = func.attr
        if not isinstance(func.value, ast.Name):
            # Only methods called directly on a variable name (``x.f(...)``) are supported.
            self.__messages.append(CheckError("Unsupported call '%s'." % (astunparse.unparse(call).rstrip()), call))
            return state
        set_variable = func.value.id

        def has_value_argument():
            # Report arity problems. Return False only when there is no argument
            # at all, since then there is nothing to check.
            if not call.args:
                self.__messages.append(CheckError("'%s.%s' without an argument unsupported." % (set_variable, func_name), call))
                return False
            if len(call.args) > 1:
                self.__messages.append(CheckError("'%s.%s' with more than one argument unsupported." % (set_variable, func_name), call))
            return True

        if func_name in ["set_to_constant", "set_to"]:
            # Check that the children are fine:
            if not has_value_argument():
                return state
            value = call.args[0]
            (_, _, val_used_vars) = self.visit(state, value)
            if set_variable in initialised:
                self.__messages.append(CheckError("Trying to reset value of variable '%s'." % (set_variable), call))
            if set_variable in parameters:
                self.__messages.append(CheckWarning("Setting value of parameter '%s'." % (set_variable), call))
            domain = self.__get_domain(call, set_variable)
            if isinstance(value, ast.Num):
                # A literal only has to lie inside [0..domain-1].
                if value.n < 0 or value.n >= domain:
                    self.__messages.append(CheckError("Trying to set variable '%s' (domain [0..%i]) to invalid value '%i'." % (set_variable, domain - 1, value.n), call))
            else:
                value_domain = self.__get_domain_of_expr(value)
                if value_domain != domain:
                    self.__messages.append(CheckError("Trying to set variable '%s' (domain [0..%i]) to value '%s' with different domain [0..%i]." % (set_variable, domain - 1, astunparse.unparse(value).rstrip(), value_domain - 1), value))
            return (parameters, initialised | {set_variable}, val_used_vars)
        elif func_name in ["set_as_input"]:
            return (parameters, initialised | {set_variable}, used)
        elif func_name in ["set_as_output"]:
            return (parameters, initialised, used | {set_variable})
        elif func_name == "observe_value":
            # Check that the children are fine:
            if not has_value_argument():
                return state
            (_, _, val_used_vars) = self.visit(state, call.args[0])
            if set_variable not in initialised:
                self.__messages.append(CheckError("Observation of potentially uninitialised variable '%s'." % (set_variable), call))
            return (parameters, initialised, val_used_vars | {set_variable})
        self.__messages.append(CheckError("Unsupported call '%s'." % (astunparse.unparse(call).rstrip()), call))
        return state

    if not isinstance(call.func, ast.Name):
        # e.g. calling a subscript, a lambda or the result of another call.
        self.__messages.append(CheckError("Unsupported call '%s'." % (astunparse.unparse(call).rstrip()), call))
        return state
    func_name = call.func.id
    func_information = self.__defined_functions.get(func_name, None)
    if func_information is None:
        self.__messages.append(CheckError("Call to undefined functions '%s'." % (func_name), call))
        return state
    (par_domains, _) = func_information
    used_vars = used
    if len(call.args) != len(par_domains):
        self.__messages.append(CheckError("Call to %i-ary function '%s' with %i arguments." % (len(par_domains), func_name, len(call.args)), call))
    for idx, arg in enumerate(call.args):
        (_, _, used_vars) = self.visit((parameters, initialised, used_vars), arg)
        if idx >= len(par_domains):
            # Surplus argument (arity already reported). It is still visited,
            # but there is no parameter domain to compare it with.
            continue
        par_domain = par_domains[idx]
        arg_domain = self.__get_domain_of_expr(arg)
        if arg_domain != par_domain:
            if isinstance(arg, ast.Num) and (arg.n >= par_domain or arg.n < 0):
                self.__messages.append(CheckError("Parameter %i of function '%s' has domain [0..%i], but argument value '%s' is incompatible." % (idx + 1, func_name, par_domain - 1, astunparse.unparse(arg).rstrip()), arg))
            elif not(isinstance(arg, ast.Num)):
                self.__messages.append(CheckError("Parameter %i of function '%s' has domain [0..%i], but argument value '%s' has different domain [0..%i]." % (idx + 1, func_name, par_domain - 1, astunparse.unparse(arg).rstrip(), arg_domain - 1), arg))
    return (parameters, initialised, used_vars)
def test_lambda_copy_no_arg():
    """Copying an argument-less lambda yields a new, equivalent tree."""
    original, copied = util_run_parse("lambda: 1+1")
    expected_source = "(lambda : (1 + 1))"
    assert unparse(copied).strip() == expected_source
    assert copied is not original
def test_lambda_copy_nested_same_arg_name():
    """Only the outer lambda's parameter is renamed; the shadowing inner one is kept."""
    _, copied = util_run_parse("lambda a: (lambda a: a)(a)")
    rendered = unparse(copied).strip()
    assert rendered == "(lambda arg_0: (lambda a: a)(arg_0))"
def build_ignore_context_manager(ctx, stmt):
    """Lower an "ignore" ``with`` statement into a call to a generated function.

    Each keyword of the ``with`` item's context expression
    (``stmt.items[0].context_expr.keywords``) carries a string of the form
    ``"inp:<type>"`` or ``"out:<type>"``. The keyword name is the variable
    and ``<type>`` is its annotation.

    The ``with`` body is moved into a new function
    ``func_ignore_<sanitised filename>_<lineno>``. That function takes the
    inputs as parameters, returns the outputs, and is defined with
    ``@torch.jit.ignore`` via ``exec`` and stored in ``globals()``.

    Returns the AST statement
    ``<out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, ...)``, or
    just the call when there are no outputs.
    """
    InputType = namedtuple('InputType', ['name', 'ann'])
    OutputType = namedtuple('OutputType', ['name', 'ann'])

    def process_ins_outs(args):
        # Split the context manager's keywords into annotated inputs and outputs.
        # TODO: add input, output validator
        inputs = []
        outputs = []
        for keyword in args:
            # Starting python3.8 ast.Str is deprecated
            declaration = keyword.value.s if sys.version_info < (3, 8) else keyword.value.value
            kind, annotation = declaration.split(":")
            if kind == "inp":
                inputs.append(InputType(keyword.arg, annotation))
            elif kind == "out":
                outputs.append(OutputType(keyword.arg, annotation))
        return inputs, outputs

    def create_unique_name_ext(ctx, stmt):
        # Suffix built from the full source filename (non-identifier characters
        # replaced by '_') plus the line number of the original context manager.
        sanitised = re.sub(r'[^a-zA-Z0-9_]', '_', ctx.filename)
        return "{}_{}".format(sanitised, stmt.lineno)

    def build_return_ann_stmt(outputs):
        # Source text for the return-type annotation and the return statement.
        if not outputs:
            return " -> None", "return "
        if len(outputs) == 1:
            single = outputs[0]
            return " -> " + single.ann, "return " + single.name
        annotations = ", ".join(var.ann for var in outputs)
        names = ", ".join(var.name for var in outputs)
        return " -> Tuple[" + annotations + "]", "return " + names

    def build_args(args):
        return ", ".join(arg.name for arg in args)

    inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords)

    # Declaration-only source ("def name(a :T, ...) -> R: pass") for the replacement function.
    ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt)
    return_ann, return_stmt = build_return_ann_stmt(outputs)
    ignore_function_str = (
        "\ndef " + ignore_function_name
        + "(" + ", ".join(var.name + " :" + var.ann for var in inputs) + ")"
        + return_ann + ": pass"
    )

    # Parse the declaration, then replace its placeholder body with the
    # with-body (the same list object) followed by the return statement.
    ignore_function = ast.parse(ignore_function_str).body[0]
    ignore_function.body = stmt.body  # type: ignore[attr-defined]
    ignore_function.body.append(ast.parse(return_stmt).body[0])  # type: ignore[attr-defined]

    # Define the function as @torch.jit.ignore and register it in globals().
    ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
    ignore_func_str += '\nglobals()["{0}"] = {0}'.format(ignore_function_name)
    exec(ignore_func_str)  # noqa: P204

    # Build "<out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)".
    # The function is registered in torch.jit.frontend module by default.
    call_str = "torch.jit.frontend.{}({})".format(ignore_function_name, build_args(inputs))
    assign_str = build_args(outputs) + " = " + call_str if outputs else call_str
    return ast.parse(assign_str).body[0]
def _is_number_node(node):
    """Return True if *node* is a numeric literal (int/float/complex, not bool)."""
    if isinstance(node, ast.Constant):
        value = node.value
        return isinstance(value, (int, float, complex)) and not isinstance(value, bool)
    # Parsers before Python 3.8 produce ast.Num instead of ast.Constant.
    num_type = getattr(ast, "Num", None)
    return num_type is not None and isinstance(node, num_type)


def slice_node_to_tuple_of_numbers(slice_node):
    """Return the constant numeric indices of a subscript slice as a tuple.

    ``slice_node`` is the ``slice`` of an ``ast.Subscript``. Before Python
    3.9 that is an ``ast.Index`` wrapper; from 3.9 on it is the index
    expression itself. Both forms are accepted. ``a[1, 2]`` yields
    ``(1, 2)`` and ``a[3]`` yields ``(3,)``.

    Every non-constant index is reported through ``error.fatal_error``
    (presumably this aborts; if it returns, the conversion below fails on the
    offending node).
    """
    index_type = getattr(ast, "Index", None)
    if index_type is not None and isinstance(slice_node, index_type):
        index_expr = slice_node.value
    else:
        index_expr = slice_node
    if isinstance(index_expr, ast.Tuple):
        indices = list(index_expr.elts)
    else:
        indices = [index_expr]
    for index in indices:
        if not _is_number_node(index):
            error.fatal_error("Trying to use non-constant value '%s' as array index." % (astunparse.unparse(index).rstrip()), index)
    # Convert to python numbers
    return tuple(index.value if isinstance(index, ast.Constant) else index.n for index in indices)