def wrap_with_imports(self, program):
    """Prepend the Theano/NumPy/functools/typing imports to ``program``.

    The statements added, in this order, are::

        import theano
        import theano.tensor as tt
        import numpy as np
        from functools import reduce
        from typing import List, Tuple

    ``program`` must expose a ``body`` list (e.g. an ``ast.Module``). Its
    ``body`` is replaced by a new list and ``program`` itself is returned.
    The new nodes carry no source locations.
    """
    def plain_import(module, alias=None):
        return ast.Import(names=[ast.alias(name=module, asname=alias)])

    def from_import(module, *members):
        return ast.ImportFrom(
            level=0,
            module=module,
            names=[ast.alias(name=member, asname=None) for member in members],
        )

    header = [
        plain_import('theano'),
        plain_import('theano.tensor', 'tt'),
        plain_import('numpy', 'np'),
        from_import('functools', 'reduce'),
        from_import('typing', 'List', 'Tuple'),
    ]
    program.body = header + program.body
    return program
def parse(code, nas_mode=None):
    """Annotate user code.

    Return annotated code (str) if an annotation is detected; return None if not.

    code: original user code (str)
    nas_mode: the mode of NAS, given that the NAS interface is used

    Raises RuntimeError if ``code`` cannot be parsed, or if the transformer
    signals an annotation problem with AssertionError (the message is then
    prefixed with a line number).
    """
    try:
        ast_tree = ast.parse(code)
    except Exception as exc:
        raise RuntimeError('Bad Python code') from exc

    transformer = Transformer(nas_mode)
    try:
        transformer.visit(ast_tree)
    except AssertionError as exc:
        # ast.Module has no `last_line` attribute, so reading it directly
        # raised AttributeError here and hid the real message.
        # NOTE(review): assumes Transformer records the last visited line in
        # `last_line`; falls back to the tree, then to 0.
        last_line = getattr(transformer, 'last_line',
                            getattr(ast_tree, 'last_line', 0))
        # A bare `assert cond` has empty args; don't IndexError on it.
        message = exc.args[0] if exc.args else ''
        raise RuntimeError('%d: %s' % (last_line, message)) from exc
    if not transformer.annotated:
        return None

    # `import nni` must follow any `from __future__` imports, which Python
    # requires at the top of a module.
    last_future_import = -1
    nodes = ast_tree.body
    for i, node in enumerate(nodes):
        if isinstance(node, ast.ImportFrom) and node.module == '__future__':
            last_future_import = i
    import_nni = ast.Import(names=[ast.alias(name='nni', asname=None)])
    nodes.insert(last_future_import + 1, import_nni)
    # enas, oneshot and darts modes for tensorflow need the tensorflow
    # module, so we import it here (it lands just before `import nni`).
    if nas_mode in ['enas_mode', 'oneshot_mode', 'darts_mode']:
        import_tf = ast.Import(
            names=[ast.alias(name='tensorflow', asname=None)])
        nodes.insert(last_future_import + 1, import_tf)
    return astor.to_source(ast_tree)
def add_required_imports(tree: ast.Module, required_imports: typ.Set[common.ImportDecl]) -> None:
    """Insert the imports required by fixers into ``tree`` (in place).

    Some fixers rewrite code to use modules the source module may not
    import. For example, occurrences of ``map`` might be replaced with
    ``itertools.imap``, in which case ``import itertools`` is added at
    module scope.

    Every required import is added before any other statement, because
    that statement could itself be subject to the fix that needs the
    import. As a side effect, a module may end up imported twice if the
    source imports it only after some other statement.

    Only imports missing from those found by ``parse_imports`` are added,
    in sorted order. ``__future__`` imports go at the future-import offset;
    all others go at the end of the existing import block. A declaration
    with a ``py2_module_name`` is emitted as::

        try:
            <import>
        except ImportError:
            import <py2_module_name> as <import_name or module_name>
    """
    future_offset, end_offset, present = parse_imports(tree)

    def with_py2_fallback(primary, decl):
        # Wrap `primary` so that an ImportError falls back to the py2 module.
        fallback = ast.Import(names=[
            ast.alias(name=decl.py2_module_name,
                      asname=decl.import_name or decl.module_name)
        ])
        handler = ast.ExceptHandler(
            type=ast.Name(id='ImportError', ctx=ast.Load()),
            name=None,
            body=[fallback],
        )
        return ast.Try(body=[primary], handlers=[handler], orelse=[], finalbody=[])

    for decl in sorted(required_imports - present):
        stmt: ast.stmt
        if decl.import_name is None:
            stmt = ast.Import(names=[ast.alias(name=decl.module_name, asname=None)])
        else:
            stmt = ast.ImportFrom(
                module=decl.module_name,
                level=0,
                names=[ast.alias(name=decl.import_name, asname=None)],
            )
        if decl.py2_module_name:
            stmt = with_py2_fallback(stmt, decl)

        is_future = decl.module_name == '__future__'
        tree.body.insert(future_offset if is_future else end_offset, stmt)
        if is_future:
            future_offset += 1
        # Either insertion shifts the end of the import block by one.
        end_offset += 1
def replace_safe_str(code):
    """Return ``code`` rewritten to run against the safe-string helpers.

    ``code`` is parsed and a prelude is inserted after its module docstring
    and any ``from __future__`` imports::

        import sys
        sys.path.insert(1, '<directory containing this file>')
        import safe_string as safe_string
        import safe_execute

    The tree is then rewritten by ``SafeStringVisitor``. The ``sys.path``
    statement is added only after the visitor has run, so the visitor never
    sees it. The result is returned as source text via ``ast.unparse``.
    """
    tree = ast.parse(code)
    body = tree.body

    # Statements placed before the docstring or a `from __future__` import
    # would drop the docstring or make the program a SyntaxError.
    start = 0
    if (body and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)):
        start = 1
    while (start < len(body) and isinstance(body[start], ast.ImportFrom)
           and body[start].module == '__future__'):
        start += 1

    body.insert(start, ast.Import(names=[ast.alias(name='sys')]))
    body.insert(start + 1, ast.Import(names=[ast.alias(name='safe_string', asname='safe_string')]))
    body.insert(start + 2, ast.Import(names=[ast.alias(name='safe_execute')]))
    SafeStringVisitor().visit(tree)

    # Build the literal with repr(): formatting the raw path between quotes
    # broke on paths containing quotes or backslashes (e.g. Windows paths,
    # where "\U..." is an invalid escape). Insert the statement itself, not
    # the ast.Module that ast.parse returns.
    helper_dir = os.path.dirname(__file__)
    path_stmt = ast.parse('sys.path.insert(1, {!r})'.format(helper_dir)).body[0]
    tree.body.insert(start + 1, path_stmt)
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)
def test_Import(self):
    # Each pair is (node to render, expected source text).
    expectations = [
        (ast.Import([ast.alias('spam', None)]), 'import spam'),
        (ast.Import([ast.alias('spam', 'bacon')]), 'import spam as bacon'),
        (
            ast.Import([
                ast.alias('spam', None),
                ast.alias('bacon', 'bacn'),
                ast.alias('eggs', None),
            ]),
            'import spam,bacon as bacn,eggs',
        ),
    ]
    for node, source in expectations:
        self.verify(node, source)
def test_fqdn_imports(self):
    # (bound name, imported module, asname, expected canonical name)
    cases = [
        ("os", "os", None, "os"),
        ("os.path", "os.path", None, "os.path"),
        ("path", "os.path", "path", "os.path"),
    ]
    for bound, module, asname, expected in cases:
        imported_alias = ast.alias(name=module, asname=asname)
        imported = pi.ImportedName(bound, ast.Import(names=[imported_alias]), imported_alias)
        assert imported.canonical_name == expected
def wrap_with_theano_import(self, program):
    """Return a new ``ast.Module`` with the Theano/NumPy imports placed before ``program``.

    The generated module starts with::

        import theano
        import theano.tensor as tt
        import numpy as np

    A statement node ``program`` is appended as-is. If ``program`` is itself
    an ``ast.Module``, its statements are spliced in instead, because a
    Module nested in another Module's body is not a statement and cannot be
    compiled. ``type_ignores`` is given explicitly because ``compile()``
    requires it on Python 3.8-3.12.
    """
    theano_import = ast.Import(
        names=[ast.alias(name='theano', asname=None)])
    theano_tensor_import = ast.Import(
        names=[ast.alias(name='theano.tensor', asname='tt')])
    numpy_import = ast.Import(names=[ast.alias(name='numpy', asname='np')])
    if isinstance(program, ast.Module):
        program_body = list(program.body)
    else:
        program_body = [program]
    return ast.Module(
        body=[theano_import, theano_tensor_import, numpy_import] + program_body,
        type_ignores=[])
def __init__(self, compiled_proto_path, class_name):
    """Start building a module that wraps a compiled protobuf module.

    ``self.module`` starts as::

        import <compiled_proto_path> as <class_name>
        import asyncio

    ``self.root_class`` is an empty ``class <class_name>:`` definition.
    NOTE(review): root_class is not attached to ``self.module`` here;
    presumably a later step fills it and adds it — confirm.

    ``type_ignores`` and ``keywords`` are given explicitly: ``compile()`` on
    Python 3.8-3.12 rejects nodes missing them, and ``ast.unparse`` reads
    ``ClassDef.keywords``.
    """
    self.module = ast.Module(
        body=[
            ast.Import(
                names=[ast.alias(name=compiled_proto_path, asname=class_name)]),
            ast.Import(names=[ast.alias(name='asyncio', asname=None)]),
        ],
        type_ignores=[],
    )
    self.root_class = ast.ClassDef(name=class_name,
                                   bases=[],
                                   keywords=[],
                                   body=[],
                                   decorator_list=[])
def get_kernel_embed():
    """A list of kernel embed nodes

    Returns:
        nodes (list): AST statement nodes which form the following code.

    ```
    import os
    pid = os.fork()
    if pid == 0:
        open(f'{os.environ["HOME"]}/.pynt', 'a').close()
        import IPython
        IPython.start_kernel(user_ns={**locals(), **globals(), **vars()})
    os.waitpid(pid, 0)
    ```

    This is a pure function: every call returns a fresh but structurally
    identical list. The nodes carry no source locations; run
    ``ast.fix_missing_locations`` on the enclosing tree before compiling.
    """
    def load(name):
        return ast.Name(id=name, ctx=ast.Load())

    def attr(value, name):
        return ast.Attribute(value=value, attr=name, ctx=ast.Load())

    def call(func, args=(), keywords=()):
        return ast.Call(func=func, args=list(args), keywords=list(keywords))

    # f'{os.environ["HOME"]}/.pynt'. ast.Index keeps Python 3.8 working;
    # on 3.9+ it simply returns its value. ast.Constant replaces the
    # deprecated ast.Num/ast.Str (removed in Python 3.14).
    home = ast.Subscript(
        value=attr(load('os'), 'environ'),
        slice=ast.Index(value=ast.Constant(value='HOME')),
        ctx=ast.Load(),
    )
    marker_path = ast.JoinedStr(values=[
        ast.FormattedValue(value=home, conversion=-1, format_spec=None),
        ast.Constant(value='/.pynt'),
    ])
    # open(<marker_path>, 'a').close()
    touch_marker = ast.Expr(value=call(attr(
        call(load('open'), [marker_path, ast.Constant(value='a')]), 'close')))
    # IPython.start_kernel(user_ns={**locals(), **globals(), **vars()})
    # (a None key in ast.Dict means `**value` unpacking).
    start_kernel = ast.Expr(value=call(
        attr(load('IPython'), 'start_kernel'),
        keywords=[ast.keyword(arg='user_ns', value=ast.Dict(
            keys=[None, None, None],
            values=[call(load('locals')), call(load('globals')), call(load('vars'))],
        ))],
    ))

    return [
        ast.Import(names=[ast.alias(name='os', asname=None)]),
        ast.Assign(targets=[ast.Name(id='pid', ctx=ast.Store())],
                   value=call(attr(load('os'), 'fork'))),
        ast.If(
            test=ast.Compare(left=load('pid'), ops=[ast.Eq()],
                             comparators=[ast.Constant(value=0)]),
            body=[
                touch_marker,
                ast.Import(names=[ast.alias(name='IPython', asname=None)]),
                start_kernel,
            ],
            orelse=[],
        ),
        # NOTE(review): if start_kernel returns in the child, it falls
        # through to os.waitpid(0, 0) with no children of its own, which
        # likely raises ChildProcessError — confirm this is intended.
        ast.Expr(value=call(attr(load('os'), 'waitpid'),
                            [load('pid'), ast.Constant(value=0)])),
    ]
def test_interleaved_statements(self):
    # Import statements separated by another statement must not be merged.
    first = ast.Import([ast.alias('X', None)])
    between = ast.ImportFrom('Z', [ast.alias('W', None)], 0)
    last = ast.Import([ast.alias('Y', None)])
    transformed = self.transform.visit(ast.Module([first, between, last]))

    body = transformed.body
    self.assertEqual(len(body), 3)
    expected_types = (ast.Import, ast.ImportFrom, ast.Import)
    for stmt, expected_type in zip(body, expected_types):
        self.assertIsInstance(stmt, expected_type)

    trailing = body[2]
    self.assertEqual(len(trailing.names), 1)
    self.assertEqual(trailing.names[0].name, 'Y')
def test_fqast_imports(self):
    os_name_dump = "Name(id='os', ctx=Load())"
    os_path_dump = "Attribute(value=Name(id='os', ctx=Load()), attr='path', ctx=Load())"
    # (bound name, imported module, asname, expected canonical AST dump)
    cases = [
        ("os", "os", None, os_name_dump),
        ("os.path", "os.path", None, os_path_dump),
        ("path", "os.path", "path", os_path_dump),
    ]
    for bound, module, asname, expected_dump in cases:
        imported_alias = ast.alias(name=module, asname=asname)
        imported = pi.ImportedName(bound, ast.Import(names=[imported_alias]), imported_alias)
        assert ast.dump(imported.canonical_ast) == expected_dump
def test_import(self):
    names = pi.ImportedNames()
    combined = ast.Import(
        names=[ast.alias(name="os", asname=None), ast.alias(name="sys", asname=None)]
    )
    single = ast.Import(names=[ast.alias(name="ast", asname=None)])
    for stmt in (combined, single):
        names.add_import(stmt)

    assert len(names) == 3
    for expected in ("os", "sys", "ast"):
        assert expected in names
    # Names from the same statement share its node; different statements don't.
    assert names["os"].node is names["sys"].node
    assert names["os"].node is not names["ast"].node
def _generate_trait_file(self):
    """Generate the traits.py file which contains marshmallow schemas for
    the various traits, for example ExpandableSchema, QuerySchema etc.

    Only used to serialize the query parameters for now.

    Returns an ``ast.Module`` whose body is the fixed marshmallow /
    commercetools.helpers imports followed by ``self._trait_nodes``.
    """
    nodes = [
        # ast.Import has no `level` field (only ImportFrom does); passing it
        # is deprecated since Python 3.13.
        ast.Import(names=[ast.alias(name="marshmallow", asname=None)]),
        ast.ImportFrom(
            module="marshmallow",
            names=[ast.alias(name="fields", asname=None)],
            level=0,
        ),
        ast.ImportFrom(
            module="commercetools.helpers",
            names=[ast.alias(name="RemoveEmptyValuesMixin", asname=None)],
            level=0,
        ),
        ast.ImportFrom(
            module="commercetools.helpers",
            names=[ast.alias(name="OptionalList", asname=None)],
            level=0,
        ),
    ]
    nodes.extend(self._trait_nodes)
    # compile() on Python 3.8-3.12 requires Module.type_ignores.
    return ast.Module(body=nodes, type_ignores=[])
def _build_st_import_statement():
    """Return a fresh AST node for ``import streamlit as __streamlit__``."""
    streamlit_alias = ast.alias(name='streamlit', asname='__streamlit__')
    return ast.Import(names=[streamlit_alias])
def augment_ast(root):
    """Make a Pygame Zero program start itself, when appropriate.

    Reads the ``PGZERO_MODE`` environment variable, which must not be
    ``"False"`` (expected values are ``"True"`` or ``"auto"``). Nothing is
    changed when:

    * ``pgzero`` cannot be imported (in ``"True"`` mode a warning is printed
      to stderr first);
    * in ``"auto"`` mode, there is no top-level ``draw`` function, or the
      program itself calls ``draw()``.

    In ``"True"`` mode a missing ``draw`` only produces a warning. Otherwise
    ``root`` is modified in place:

    * ``import pgzrun as __pgzrun`` is inserted after the module docstring
      and any ``from __future__`` imports;
    * ``__pgzrun.go()`` is appended at the end.

    Both inserted statements get ``tags = {"ignore"}``.
    """
    mode = os.environ.get("PGZERO_MODE", "False")
    assert mode != "False"

    warning_prelude = "WARNING: Pygame Zero mode is turned on (Run → Pygame Zero mode)"
    try:
        import pgzero  # @UnusedImport
    except ImportError:
        if mode == "True":
            print(
                warning_prelude
                + ",\nbut pgzero module is not found. Running program in regular mode.\n",
                file=sys.stderr,
            )
        else:
            assert mode == "auto"

        return

    # Check if draw is defined
    for stmt in root.body:
        if isinstance(stmt, ast.FunctionDef) and stmt.name == "draw":
            break
    else:
        if mode == "auto":
            return
        else:
            print(
                warning_prelude
                + ",\nbut your program doesn't look like usual Pygame Zero program\n"
                + "(draw function is missing).\n",
                file=sys.stderr,
            )

    # need more checks in auto mode
    if mode == "auto":
        # check that draw method is not called in the code
        for node in ast.walk(root):
            if (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id == "draw"
            ):
                return

    # "import pgzrun as __pgzrun"
    imp = ast.Import([ast.alias("pgzrun", "__pgzrun")])
    imp.lineno = 0
    imp.col_offset = 0
    ast.fix_missing_locations(imp)
    imp.tags = {"ignore"}

    # Insert after the docstring and `from __future__` imports: a statement
    # before the docstring demotes it to a plain expression, and one before
    # a future import makes the program a SyntaxError.
    body = root.body
    insert_at = 0
    if (
        body
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        insert_at = 1
    while (
        insert_at < len(body)
        and isinstance(body[insert_at], ast.ImportFrom)
        and body[insert_at].module == "__future__"
    ):
        insert_at += 1
    body.insert(insert_at, imp)

    # append "__pgzrun.go()"
    go = ast.Expr(
        ast.Call(ast.Attribute(ast.Name("__pgzrun", ast.Load()), "go", ast.Load()), [], [])
    )
    go.lineno = 1000000
    go.col_offset = 0
    ast.fix_missing_locations(go)
    go.tags = {"ignore"}
    root.body.append(go)
def visit_Import(self, node):
    """Redirect ``import`` statements of shadowed top-level packages.

    For each alias whose first dotted component is in ``self._topnames``,
    the module name is prefixed with ``self._sourceHashDot``:

    * ``import x.y.z`` becomes ``import <hash>.x as x, <hash>.x.y.z``, so
      ``x`` is still bound in the importing namespace exactly as the plain
      import would have bound it;
    * ``import x.y as w`` becomes ``import <hash>.x.y as w``.

    Other aliases are kept as they are. Returns a new ``ast.Import`` that
    carries ``node``'s source location.
    """
    # Import(alias* names)
    def relocated(name, asname, original):
        return ast.copy_location(ast.alias(name, asname), original)

    rewritten = []
    for alias in node.names:
        top = alias.name.split('.', 1)[0]
        if top not in self._topnames:
            rewritten.append(alias)
        elif alias.asname is not None:
            rewritten.append(relocated(self._sourceHashDot + alias.name, alias.asname, alias))
        else:
            rewritten.append(relocated(self._sourceHashDot + top, top, alias))
            rewritten.append(relocated(self._sourceHashDot + alias.name, None, alias))
    return ast.copy_location(ast.Import(rewritten), node)
def find_dialectimport_ast(self, tree):
    '''Find the first dialect-import statement by scanning the AST `tree`.

    Transform the dialect-import into `import ...`, where `...` is the absolute
    module name the dialects are being imported from. As a side effect, import
    the dialect definition module.

    Primarily meant to be called with `tree` the AST of a module that
    uses dialects, but works with any `tree` that has a `body` attribute.

    A dialect-import is a statement of the form::

        from ... import dialects, ...

    Return value is a tuple `(module_absname, bindings)`: the absolute name of
    the dialect definition module, and a dict `{dialectname: class, ...}` with
    all bindings collected from that one dialect-import. Each binding is a
    dialect, so usually there is just one. If `tree` contains no
    dialect-import, return `("", {})` and leave `tree` unchanged.
    '''
    for index, statement in enumerate(tree.body):
        if ismacroimport(statement, magicname="dialects"):
            break
    else:
        # No dialect-import anywhere in the body.
        return "", {}

    module_absname, bindings = get_macros(statement, filename=self.filename,
                                          reload=False, allow_asname=False)
    # Remove all names to prevent dialects being used as regular run-time objects.
    # Always use an absolute import, for the unhygienic expose API guarantee.
    tree.body[index] = ast.copy_location(
        ast.Import(names=[ast.alias(name=module_absname, asname=None)]),
        statement)
    return module_absname, bindings
def _insert_imports(self, f, f_body, free_vars):
    """Prepend ``import`` statements for the modules ``f`` refers to.

    Modules are collected from ``f.__globals__`` and from ``free_vars``
    (a sequence of ``(name, value)`` pairs). Module-valued free variables
    are removed from the returned free-variable list. An import is emitted
    for a name only if ``self.imports`` is a bool or contains that name,
    and the name is not in ``_exclude``. A module bound under a different
    name is imported with ``as <name>``. The imports go after a leading
    docstring, if ``f_body`` starts with one.

    Returns ``(new_body, remaining_free_vars)``.
    """
    add_imports = [(k, v) for k, v in f.__globals__.items()
                   if isinstance(v, ModuleType)]

    remaining_free_vars = []
    for k, v in free_vars:
        if isinstance(v, ModuleType):
            add_imports.append((k, v))
        else:
            remaining_free_vars.append((k, v))

    # Guard against an empty body, which previously raised IndexError.
    if (f_body and isinstance(f_body[0], ast.Expr)
            and isinstance(f_body[0].value, _ast_str_types)):
        f_docstring = f_body[:1]
        f_body = f_body[1:]
    else:
        f_docstring = []

    import_stmts = [
        ast.Import(names=[ast.alias(name=v.__name__,
                                    asname=k if k != v.__name__ else None)])
        for k, v in add_imports
        if (isinstance(self.imports, bool) or k in self.imports) and k not in _exclude
    ]
    return f_docstring + import_stmts + f_body, remaining_free_vars
def visit_Module(self, node):
    '''
    Inject tracing logic at the top of the module.

    After visiting the children, inserts::

        import ftracer
        ftracer.set_trace(<target>, <run>)

    after the module docstring and any ``from __future__`` imports. The two
    arguments are ``ast.Name`` nodes whose ids are
    ``quoted(self.target_mpath)`` and ``quoted(self.run_mpath)``.
    NOTE(review): if ``quoted`` wraps the path in quotes (as its name
    suggests), these only act as string literals once the tree is turned
    back into source text; compiled directly they would be name lookups.
    Locations are filled in with ``ast.fix_missing_locations``.
    '''
    self.generic_visit(node)

    # import ftracer
    import_tracer = ast.Import([ast.alias('ftracer', None)])

    # ftracer.set_trace(<target>, <run>); every Name gets an explicit ctx,
    # which compile() requires on Python < 3.13.
    set_trace = ast.Attribute(ast.Name('ftracer', ast.Load()), 'set_trace', ast.Load())
    call = ast.Call(func=set_trace,
                    args=[
                        ast.Name(quoted(self.target_mpath), ast.Load()),
                        ast.Name(quoted(self.run_mpath), ast.Load()),
                    ],
                    keywords=[])
    prebody = [import_tracer, ast.Expr(call)]

    # Keep the docstring and `from __future__` imports first; statements
    # before a future import are a SyntaxError.
    body = node.body
    start = 0
    if (body and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)):
        start = 1
    while (start < len(body) and isinstance(body[start], ast.ImportFrom)
           and body[start].module == '__future__'):
        start += 1

    node.body = body[:start] + prebody + body[start:]
    ast.fix_missing_locations(node)
    return node
def test_get_at_root(self) -> None:
    """Tests that `get_at_root` finds the root-level imports of the eval mock"""
    mock_file = path.join(
        path.dirname(__file__), "mocks", "eval{extsep}py".format(extsep=extsep)
    )
    with open(mock_file) as f:
        source = f.read()
    imports = get_at_root(ast.parse(source), (Import, ImportFrom))

    self.assertIsInstance(imports, list)
    self.assertEqual(len(imports), 1)
    expected = ast.Import(
        names=[
            ast.alias(
                asname=None,
                name="cdd.tests.mocks",
                identifier=None,
                identifier_name=None,
            )
        ],
        alias=None,
    )
    self.assertTrue(cmp_ast(imports[0], expected))
def parse(code):
    """Annotate user code.

    Return annotated code (str) if an annotation is detected; return None if not.

    code: original user code (str)

    Raises RuntimeError if ``code`` cannot be parsed, or if the transformer
    signals an annotation problem with AssertionError (the message is then
    prefixed with a line number).
    """
    try:
        ast_tree = ast.parse(code)
    except Exception as exc:
        raise RuntimeError('Bad Python code') from exc

    transformer = Transformer()
    try:
        transformer.visit(ast_tree)
    except AssertionError as exc:
        # ast.Module has no `last_line` attribute, so reading it directly
        # raised AttributeError here and hid the real message.
        # NOTE(review): assumes Transformer records the last visited line in
        # `last_line`; falls back to the tree, then to 0.
        last_line = getattr(transformer, 'last_line',
                            getattr(ast_tree, 'last_line', 0))
        # A bare `assert cond` has empty args; don't IndexError on it.
        message = exc.args[0] if exc.args else ''
        raise RuntimeError('%d: %s' % (last_line, message)) from exc
    if not transformer.annotated:
        return None

    # `import nni` must follow any `from __future__` imports, which Python
    # requires at the top of a module.
    last_future_import = -1
    nodes = ast_tree.body
    for i, node in enumerate(nodes):
        if isinstance(node, ast.ImportFrom) and node.module == '__future__':
            last_future_import = i
    import_nni = ast.Import(names=[ast.alias(name='nni', asname=None)])
    nodes.insert(last_future_import + 1, import_nni)
    return astor.to_source(ast_tree)
def stubs_for_pydantic(models: Collection[Type[pd.BaseModel]], clsname: str = None) -> ast.Module:
    """ Generate stubs for Pydantic models

    Returns an ``ast.Module`` containing, in order: ``from __future__ import
    annotations``, ``import pydantic``, the merged imports the models need,
    ``NoneType = type(None)`` and one stub per model. If ``clsname`` is
    given, the model stubs are nested inside ``class <clsname>:``.

    Example:
        ast.unparse(stubs_for_pydantic([db.User]))
    """
    model_infos = [ModelInfo.from_pydantic_model(model) for model in models]
    ast_models = [model_info.to_ast() for model_info in model_infos]
    # NOTE(review): inserted as a single body element; assumes to_ast()
    # returns one statement — confirm.
    ast_imports = merge_imports(model_infos).to_ast()

    if clsname:
        ast_models = [
            ast.ClassDef(clsname, bases=[], decorator_list=[], keywords=[], body=ast_models)
        ]

    return ast.Module([
        ast.ImportFrom('__future__', [ast.alias('annotations')], level=0),
        ast.Import([ast.alias('pydantic')]),
        ast_imports,
        # Splice the parsed statement(s); ast.parse returns a Module, which
        # is not a valid statement inside another Module's body.
        *ast.parse('NoneType = type(None)').body,
        *ast_models,
    ], type_ignores=[])
def _update_Import(self, node, stmt_list, idx):
    """Rewrite the ``import`` statement ``node`` (at ``stmt_list[idx]``) away from ``self._from_mod``.

    Every alias importing exactly ``self._from_mod`` is removed from
    ``node``; if no aliases remain, the statement is deleted from
    ``stmt_list``. Then, for each removed alias, a replacement statement is
    inserted at ``idx``:

    * ``from <_to_mod> import <alias>`` when both ``_to_mod`` and ``_to_id``
      are set;
    * ``import <_to_mod> [as <alias.asname>]`` when only ``_to_mod`` is set;
    * nothing otherwise (the import is dropped).

    Each insertion is at ``idx``, so replacements end up in reverse order
    of the removed aliases. New nodes copy ``node``'s source location.
    """
    matching = [alias for alias in node.names if alias.name == self._from_mod]
    if not matching:
        return
    # Rebuild the list in place rather than `del node.names[i]` with indices
    # taken from an unmodified copy: after the first deletion those indices
    # were stale, removing the wrong alias or raising IndexError.
    node.names[:] = [alias for alias in node.names if alias.name != self._from_mod]
    if not node.names:
        del stmt_list[idx]

    if self._to_mod and self._to_id:
        # NOTE(review): `_to_id` is only tested here, never used; the alias
        # keeps its original name. Confirm this is intended.
        for alias in matching:
            new_node = ast.ImportFrom(module=self._to_mod, level=0, names=[alias])
            stmt_list.insert(idx, ast.copy_location(new_node, node))
    elif self._to_mod:
        for alias in matching:
            new_node = ast.Import(names=[ast.alias(self._to_mod, alias.asname)])
            stmt_list.insert(idx, ast.copy_location(new_node, node))
def generate_code_file(mod_body, file, imports, external_functions_source=False, names="#"):
    """Prepend the needed imports to ``mod_body`` and return the module as source text.

    ``mod_body`` is modified in place. Every statement is inserted at index
    0, so the final order is: ``from <external_functions_source> import *``
    (if given), then ``imports.from_imports`` in reverse order, then
    ``imports.as_imports`` in reverse order, then the original body.

    ``imports.as_imports`` holds ``(module, name)`` pairs rendered as
    ``import module as name``; ``imports.from_imports`` holds ``(module,
    name)`` pairs rendered as ``from module import name``. The body is
    wrapped with ``wrap_module``, unparsed, and prefixed with ``names``.

    NOTE(review): ``file`` is not used here. ``names`` is prepended with no
    newline, so with the default ``"#"`` the first unparsed line is
    commented out — confirm this is intended.
    """
    for (module, name) in imports.as_imports:
        # ast.Import has no `level` field (only ImportFrom does); passing it
        # is deprecated since Python 3.13.
        mod_body.insert(
            0, ast.Import(names=[ast.alias(name=module, asname=name)]))
    for (module, name) in imports.from_imports:
        mod_body.insert(
            0, ast.ImportFrom(module=module,
                              names=[ast.alias(name=name, asname=None)],
                              level=0))
    if external_functions_source:
        mod_body.insert(
            0, ast.ImportFrom(module=external_functions_source,
                              names=[ast.alias(name='*', asname=None)],
                              level=0))
    mod = wrap_module(mod_body)
    print('Generating Source')
    source = names + ast.unparse(mod)
    return source
def visit_Module(self, node):
    """Visit the module and add ``import numpy`` when a visit requested it.

    ``self.need_import`` is reset to False and the children are visited;
    if ``self.need_import`` is True afterwards (presumably set by other
    ``visit_*`` methods), ``import numpy`` is inserted after the module
    docstring and any ``from __future__`` imports. Returns ``node``.
    """
    self.need_import = False
    self.generic_visit(node)
    if self.need_import:
        import_numpy = ast.Import(names=[ast.alias(name='numpy', asname=None)])
        # Inserting at index 0 demoted the docstring to a plain expression
        # and put the import before `from __future__` (a SyntaxError).
        body = node.body
        position = 0
        if (body and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Constant)
                and isinstance(body[0].value.value, str)):
            position = 1
        while (position < len(body) and isinstance(body[position], ast.ImportFrom)
               and body[position].module == '__future__'):
            position += 1
        body.insert(position, import_numpy)
    return node
def test_import():
    built_from_attr = import_(alias.bar)
    expected_node = ast.Import(names=[Alias(name='bar', asname=None)])
    assert eq(built_from_attr, expected_node)
    # A plain string name builds the same node as the attribute form.
    assert eq(import_('bar'), import_(alias.bar))
    assert sourcify(import_('bar')) == 'import bar'
def instrument(self, sourcefile, inst_sourcefile, function):
    """Instrument ``function`` from ``sourcefile`` and write the module to ``inst_sourcefile``.

    Steps:

    * parse the source;
    * insert ``import covgen.wrapper as covw`` after the module docstring
      and any ``from __future__`` imports;
    * find the top-level ``def <function>``;
    * run ``self.collect_int_constants`` and ``self.visit`` on it;
    * write the whole module back out with ``astor``.

    Returns ``(function_node, total_branches)``. ``total_branches`` maps
    every ``(branch_id, outcome)`` pair, for ids ``1 .. self.branch_id - 1``
    and outcomes ``True``/``False``, to None.

    Raises AssertionError if no top-level function named ``function`` exists.
    """
    import tokenize

    # tokenize.open honours a PEP 263 coding cookie (default UTF-8);
    # plain open() would use the locale encoding.
    with tokenize.open(sourcefile) as source_file:
        source = source_file.read()
    root = ast.parse(source)

    # Insert 'import covgen.wrapper as covw' after the docstring and any
    # `from __future__` imports; placing it before a future import makes
    # the instrumented file a SyntaxError.
    import_node = ast.Import(
        names=[ast.alias(name='covgen.wrapper', asname='covw')])
    body = root.body
    position = 0
    if (body and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)):
        position = 1
    while (position < len(body) and isinstance(body[position], ast.ImportFrom)
           and body[position].module == '__future__'):
        position += 1
    body.insert(position, import_node)
    ast.fix_missing_locations(root)

    function_node = next(
        (stmt for stmt in root.body
         if isinstance(stmt, ast.FunctionDef) and stmt.name == function),
        None,
    )
    assert function_node, 'no top-level function %r in %s' % (function, sourcefile)

    self.collect_int_constants(function_node)
    self.visit(function_node)

    total_branches = dict.fromkeys(
        itertools.product(range(1, self.branch_id), [True, False]))

    with open(inst_sourcefile, 'w', encoding='utf-8') as instrumented:
        instrumented.write(astor.to_source(root))

    return function_node, total_branches
def init_globals(opts, input_file):
    """Build the global namespace used to evaluate user expressions.

    The namespace is a copy of ``builtins`` plus:

    * ``pp``: a ``Stream`` of ``Line`` objects over ``input_file``, with
      trailing newline/carriage-return characters stripped from each line;
    * every module named in ``opts.imports``, imported into the namespace;
    * whatever the ``opts.evals`` snippets define (a ``SyntaxError`` in one
      is re-raised as ``Exit``);
    * ``files``: one stream per path in ``opts.files``, plus ``ff`` for the
      first of them when there is at least one. These files stay open
      because the returned streams read from them lazily.

    Each directory in ``opts.import_paths`` is made absolute and prepended
    to ``sys.path`` if not already present.
    """
    def make_stream(f):
        return Stream(imap(lambda x: Line(x.rstrip('\n\r')), iter(f)))

    pp = make_stream(input_file)
    globs = builtins.copy()
    globs['pp'] = pp

    for path in opts.import_paths:
        path = os.path.abspath(path)
        if path not in sys.path:
            sys.path.insert(0, path)

    for import_mod in opts.imports:
        import_node = ast.Import(
            names=[ast.alias(name=import_mod, asname=None)])
        # compile() on Python 3.8-3.12 requires Module.type_ignores; without
        # it every requested import failed with TypeError.
        module = ast.Module(body=[import_node], type_ignores=[])
        code = compile(
            ast.fix_missing_locations(module),
            'import %s' % (import_mod, ), 'exec')
        eval(code, globs)

    for eval_str in opts.evals:
        try:
            _exec(eval_str, globs)
        except SyntaxError as e:
            raise Exit("got error: %s\nwhile evaluating: %s" % (e, eval_str))

    files = [make_stream(open(f)) for f in opts.files]
    globs['files'] = files
    if len(files) > 0:
        # convenience for single file operation
        globs['ff'] = files[0]

    return globs
def enum_to_class(enum: Enum) -> ResolvedClassResult:
    """Convert an Enum into an AST class definition.

    Produces, roughly::

        @enum.unique
        class <enum.name>(enum.Enum):
            <docstring from enum.doc, if any>
            <render_enum_name(symbol)> = '<symbol>'   # one per symbol, sorted

    The result also lists ``import enum`` as a required import and has an
    empty ``new_frontier``.
    """
    enum_import = ast.Import(names=[ast.alias(name='enum', asname=None)])

    class_body = []
    if enum.doc:
        class_body.append(docstring_declaration(enum.doc))

    # Members are plain assignment statements: an ast.Assign wrapped in
    # ast.Expr is not a valid AST. Ordering by symbol matches the previous
    # sort key (the assigned string value).
    class_body.extend(
        ast.Assign(targets=[render_enum_name(symbol)], value=ast.Constant(value=symbol))
        for symbol in sorted(enum.symbols)
    )

    enum_class = ast.ClassDef(
        name=enum.name,
        bases=[ast.Attribute(value=ast.Name(id='enum', ctx=ast.Load()),
                             attr='Enum', ctx=ast.Load())],
        keywords=[],
        body=class_body,
        decorator_list=[
            ast.Attribute(value=ast.Name(id='enum', ctx=ast.Load()),
                          attr='unique', ctx=ast.Load())
        ]  # just for signalling purposes
    )

    return ResolvedClassResult(
        resolved_class=enum_class,
        imports=[enum_import],
        new_frontier=[],
    )
def visit_Import(self, node):
    """Rename every module bound by an ``import`` statement.

    Each ``import name [as alias]`` becomes ``import name as rename(name)``.
    Before any renaming, every existing ``as`` alias is recorded as
    ``aliases[alias] = name``. ``aliases`` and ``rename`` come from the
    enclosing scope. The new statement and its aliases copy the source
    locations of the originals, so the returned node keeps ``node``'s line
    information.
    """
    for original in node.names:
        if original.asname:
            aliases[original.asname] = original.name
    renamed = [
        ast.copy_location(
            ast.alias(name=original.name, asname=(rename(original.name))),
            original)
        for original in node.names
    ]
    return ast.copy_location(ast.Import(names=renamed), node)