예제 #1
0
    def visit_FunctionDef(self, function: ast.FunctionDef):
        """
        Visitor of the function node

        Includes the method in the scope of its module. The return annotation is
        resolved to an IType and the decorators are inspected. A function
        decorated with ``Builtin.Metadata`` is handed to
        ``_read_metadata_object`` and is not registered as a method. Otherwise a
        ``Method`` is built, its body is visited inside a fresh ``SymbolScope``,
        and the symbols collected in that scope are copied into the ``Method``.

        :param function: the ``def`` node being analysed
        :return: ``Builtin.Metadata`` for a metadata function, otherwise None
        :raises NotImplementedError: if visiting the return annotation yields None
        """
        fun_args = self.visit(function.args)
        fun_rtype_symbol = self.visit(function.returns) if function.returns is not None else Type.none

        if fun_rtype_symbol is None:
            # it is a function with None return: Main(a: int) -> None:
            # (presumably visiting an explicit `None` annotation yields None)
            raise NotImplementedError

        if isinstance(fun_rtype_symbol, str):
            # the annotation was visited as an identifier: resolve it to a symbol first
            symbol = self.get_symbol(function.returns.id)
            fun_rtype_symbol = self.get_type(symbol)

        fun_return: IType = self.get_type(fun_rtype_symbol)
        fun_decorators: List[Method] = self._get_function_decorators(function)

        if Builtin.Metadata in fun_decorators:
            self._read_metadata_object(function)
            return Builtin.Metadata

        method = Method(args=fun_args, defaults=function.args.defaults, return_type=fun_return,
                        origin_node=function, is_public=Builtin.Public in fun_decorators)
        self._current_method = method
        self._scope_stack.append(SymbolScope())

        # don't evaluate constant expression - for example: string for documentation
        from boa3.constants import SYS_VERSION_INFO
        if SYS_VERSION_INFO >= (3, 8):
            function.body = [stmt for stmt in function.body
                             if not (isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant))]
        else:
            # before 3.8 literals are ast.Num (`.n`) or ast.Str / ast.Bytes (`.s`)
            function.body = [stmt for stmt in function.body
                             if not (isinstance(stmt, ast.Expr) and
                                     (hasattr(stmt.value, 'n') or hasattr(stmt.value, 's'))
                                     )]
        for stmt in function.body:
            self.visit(stmt)

        self.__include_callable(function.name, method)
        method_scope = self._scope_stack.pop()

        # variables not recorded in _annotated_variables are registered with UndefinedType;
        # every other symbol of the method scope is copied as-is
        for var_id, var in method_scope.symbols.items():
            if isinstance(var, Variable) and var_id not in self._annotated_variables:
                method.include_variable(var_id, Variable(UndefinedType, var.origin))
            else:
                method.include_symbol(var_id, var)

        self._annotated_variables.clear()
        self._current_method = None
예제 #2
0
    def visit_FunctionDef(self, function: ast.FunctionDef):
        """
        Visitor of the function node

        Registers the function as a callable of the current module. It
        resolves the return annotation, reads the decorators, builds the
        ``Method`` symbol and visits every body statement that is not a bare
        constant expression (such as a docstring).

        :param function: the ``def`` node being analysed
        :return: ``Builtin.Metadata`` when the function is a metadata function
        :raises NotImplementedError: if visiting the return annotation yields None
        """
        fun_args = self.visit(function.args)
        if function.returns is None:
            fun_rtype_symbol = Type.none
        else:
            fun_rtype_symbol = self.visit(function.returns)

        if fun_rtype_symbol is None:
            # it is a function with None return: Main(a: int) -> None:
            raise NotImplementedError

        if isinstance(fun_rtype_symbol, str):
            # identifier annotation: look the name up, then take its type
            fun_rtype_symbol = self.get_type(self.get_symbol(function.returns.id))

        fun_return: IType = self.get_type(fun_rtype_symbol)
        fun_decorators: List[Method] = self._get_function_decorators(function)

        if Builtin.Metadata in fun_decorators:
            self._read_metadata_object(function)
            return Builtin.Metadata

        is_public = Builtin.Public in fun_decorators
        method = Method(
            args=fun_args,
            defaults=function.args.defaults,
            return_type=fun_return,
            origin_node=function,
            is_public=is_public,
        )
        self._current_method = method

        # don't evaluate constant expression - for example: string for documentation
        from boa3.constants import SYS_VERSION_INFO
        has_constant_node = SYS_VERSION_INFO >= (3, 8)

        def is_constant_expr(stmt) -> bool:
            if not isinstance(stmt, ast.Expr):
                return False
            if has_constant_node:
                return isinstance(stmt.value, ast.Constant)
            # pre-3.8 literal nodes: ast.Num has `.n`, ast.Str / ast.Bytes have `.s`
            return hasattr(stmt.value, 'n') or hasattr(stmt.value, 's')

        function.body = [stmt for stmt in function.body if not is_constant_expr(stmt)]
        for stmt in function.body:
            self.visit(stmt)

        self.__include_callable(function.name, method)
        self._current_method = None
예제 #3
0
    def visit_FunctionDef(self, node: FunctionDef):
        """Expand a function recognised by ``_is_qdef`` into its generated variants.

        Other functions are visited with ``generic_visit`` and returned as-is.
        For a qdef function, only the body is transformed (wrapped in a
        ``Suite`` for the visit). Then, depending on ``_auto_adjoint`` and
        ``_auto_controlled``, an adjoint and/or a controlled implementation is
        generated. Each one is followed by the nodes returned by
        ``_wire_adjoints`` / ``_wire_controlled``. Every generated node gets
        the source location of ``node``.

        :return: ``node`` for non-qdef functions, otherwise a list starting
            with ``node`` followed by the generated nodes
        """
        if not _is_qdef(node):
            self.generic_visit(node)
            return node

        def located(new_node):
            # give generated nodes the location of the original definition
            return copy_location(new_node, old_node=node)

        node.body = self.generic_visit(Suite(node.body)).body
        expanded = [node]

        if _auto_adjoint(node):
            adjoint = located(self._compute_adjoint(node))
            # the adjoint's generated statements are reversed in place
            adjoint.body.reverse()
            expanded.append(adjoint)
            expanded.extend(located(extra) for extra in _wire_adjoints(node, adjoint))

        if _auto_controlled(node):
            controlled = located(self._compute_controlled(node))
            expanded.append(controlled)
            expanded.extend(located(extra) for extra in _wire_controlled(node, controlled))

        return expanded
def add_function_doc_to_ast(ast_function: ast.FunctionDef) -> None:
    """Replace the docstring of *ast_function* with one generated by CrawtoFunction.

    All fields of the node, plus its current docstring, are passed to
    ``CrawtoFunction``. ``CrawtoFunction.docs`` is parsed as Python source,
    and its first statement becomes the first statement of the function
    body. Any existing docstring statement is removed first, so the function
    never ends up with two.

    Parameters
    ----------
    ast_function : ast.FunctionDef
        Function node to update; it is modified in place.

    Returns
    -------
    None
    """
    existing_doc = ast.get_docstring(ast_function)
    functiondef = CrawtoFunction(**ast_function.__dict__,
                                 doc_string=existing_doc)
    # assumes ``docs`` holds source whose first statement is the docstring
    # expression (e.g. a quoted string literal) - TODO confirm in CrawtoFunction
    function_docstring = ast.parse(functiondef.docs)
    new_doc_expr = function_docstring.body[0]
    # get_docstring() returns "" for an empty docstring, which is falsy:
    # test against None so an empty docstring statement is replaced too
    if existing_doc is not None:
        ast_function.body.pop(0)
    ast_function.body = [new_doc_expr] + ast_function.body
예제 #5
0
    def visit_FunctionDef(self, node: FunctionDef) -> Optional[AST]:
        """Flag unused locals of the function, then drop redundant ``pass``.

        A local of the function's ``_pyo_scope`` counts as unused when
        ``get_loads`` reports no loads for it. Its ``Name`` usages are then
        tagged with ``_pyo_unused = True``, and the name is removed from any
        ``global`` / ``nonlocal`` statement declaring it. After the children
        are visited, top-level ``pass`` statements are removed. A single
        ``Pass`` is kept when nothing else remains.
        """
        scope = node._pyo_scope

        for local_name in scope.locals_:
            if len(get_loads(scope, local_name)) != 0:
                continue  # the variable is read somewhere: leave it alone
            for usage in scope.locals_[local_name]:
                if isinstance(usage, Name):
                    usage._pyo_unused = True
                elif isinstance(usage, (Nonlocal, Global)):
                    # NOTE(review): a global/nonlocal left with no names is an
                    # invalid AST - confirm it is removed later
                    usage.names.remove(local_name)
                # TODO: handle Import usages

        node = self.generic_visit(node)

        remaining = [stmt for stmt in node.body if not isinstance(stmt, Pass)]
        node.body = remaining or [Pass()]
        return node
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        """Wrap the whole function body in ``with SoftAssertions() as ctx:``.

        The original statements become the body of the new ``with`` block.
        ``fix_node`` is then applied to the function, and the result is passed
        through ``generic_visit`` so nested nodes are transformed as well.
        """
        soft_assertions_call = ast.Call(
            func=ast.Name(id='SoftAssertions', ctx=ast.Load()),
            args=[],
            keywords=[],
        )
        ctx_target = ast.Name(id='ctx', ctx=ast.Store())
        # NOTE(review): a leading docstring also moves inside the `with` block
        wrapper = ast.With(
            items=[ast.withitem(context_expr=soft_assertions_call,
                                optional_vars=ctx_target)],
            body=node.body,
        )
        node.body = [wrapper]
        return self.generic_visit(fix_node(node))
예제 #7
0
    def _expand_view(self, node: ast.FunctionDef) -> Any:
        # add `contract_callback: Contract[return_type]` to method parameter
        callback_annotation = ast.Subscript(
            value=ast.Name(id='Contract', ctx=ast.Load()),
            slice=ast.Index(value=node.returns),
            ctx=ast.Load())
        callback_argument = ast.arg(arg="__callback__",
                                    annotation=callback_annotation)
        node.args.args.append(callback_argument)

        # remove the return type annotation
        node.returns = None

        # transform all return expressions into `transaction` function call
        node.body = self._transform_block(node.body)

        return node
예제 #8
0
    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Seed the execution trace in the GraphQL context of the target function.

        For the function named ``self.function_name``, the statement
        ``info.context["trace_execution"] = []`` is inserted at the start of
        the body. It goes after the docstring, if there is one, so the
        docstring is kept. Every function is then visited recursively.

        :return: the same (possibly modified) ``node``
        """
        # FunctionDef(identifier name, arguments args, stmt* body,
        #             expr* decorator_list, expr? returns, string? type_comment)
        if node.name == self.function_name:
            context_add = ast.parse(
                "info.context[\"trace_execution\"] = []").body[0]
            # inserting before a docstring would stop it being the docstring
            insert_at = 1 if ast.get_docstring(node, clean=False) is not None else 0
            node.body = node.body[:insert_at] + [context_add] + node.body[insert_at:]
        self.generic_visit(node)
        return node
예제 #9
0
    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        prologue_body_instructions = []
        arguments = node.args.args

        # skip class instantiations and functions of 1 argument
        if len(arguments) > 1:
            ### generate argument dataclass
            arguments_spec = {
                argument_node.arg: argument_node.annotation
                for argument_node in arguments
            }
            param_dataclass_name = node.name + "Param"
            self.dataclasses.append(
                make_dataclass(param_dataclass_name, arguments_spec))

            # tuplify arguments
            param_name = node.name + "__param"

            self.env[node.name] = param_dataclass_name

            node.args.args = [
                ast.arg(arg=param_name,
                        annotation=ast.Name(id=param_dataclass_name,
                                            ctx=ast.Load()))
            ]

            # destructure tuplified arguments
            prologue_body_instructions = [
                ast.Assign(targets=[ast.Name(id=attr_name, ctx=ast.Store())],
                           value=ast.Attribute(value=ast.Name(id=param_name,
                                                              ctx=ast.Load()),
                                               attr=attr_name,
                                               ctx=ast.Load()),
                           type_comment=None)
                for attr_name in arguments_spec.keys()
            ]

        new_body = [self.visit(body_node) for body_node in node.body]

        node.body = prologue_body_instructions + new_body

        return node
예제 #10
0
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        """Instrument a function definition according to ``self.inst_type``.

        * ``log_func_exec``: ``self.log_func_exec_enable`` is set to True while
          the children of a function are visited, if its name fully matches one
          of ``inst_args["funcnames"]``. The previous value is restored
          afterwards, so a nested function does not switch logging off for
          the rest of an enclosing matched function.
        * ``log_func_entry``: one "Function Entry" node is inserted at the top
          of the body of a matching function, even if several patterns match.
        * ``log_var`` / ``log_number``: a "Variable Assign" node is prepended
          for every parameter whose name fully matches one of
          ``inst_args["varnames"]``. This covers positional-only, regular,
          keyword-only, ``*args`` and ``**kwargs`` parameters.

        Children are always visited with ``generic_visit``.

        :return: the same (possibly modified) ``node``
        """
        instrumented_nodes: List[ast.stmt] = []
        # remember the enclosing function's state so it can be restored below
        previous_exec_enable = getattr(self, "log_func_exec_enable", False)

        if self.inst_type == "log_func_exec":
            if any(re.fullmatch(funcname, node.name)
                   for funcname in self.inst_args["funcnames"]):
                self.log_func_exec_enable = True
        elif self.inst_type == "log_func_entry":
            if any(re.fullmatch(funcname, node.name)
                   for funcname in self.inst_args["funcnames"]):
                node.body.insert(
                    0, self.get_instrument_node("Function Entry", node.name))
        elif self.inst_type in ("log_var", "log_number"):
            args = node.args
            if "posonlyargs" in args._fields:
                positional = args.posonlyargs + args.args
            else:
                # python 3.6/3.7 do not have posonlyargs
                positional = args.args
            func_args_name = [a.arg for a in positional + args.kwonlyargs]
            if "vararg" in args._fields and args.vararg:
                func_args_name.append(args.vararg.arg)
            if "kwarg" in args._fields and args.kwarg:
                func_args_name.append(args.kwarg.arg)
            varname_patterns = self.inst_args["varnames"]
            for name in func_args_name:
                if any(re.fullmatch(pattern, name) for pattern in varname_patterns):
                    instrumented_nodes.append(
                        self.get_instrument_node("Variable Assign", name))

        self.generic_visit(node)

        if self.inst_type == "log_func_exec":
            self.log_func_exec_enable = previous_exec_enable
        elif instrumented_nodes:
            # only the log_var / log_number branch fills this list
            node.body = instrumented_nodes + node.body
        return node
예제 #11
0
def build_method_def(
    method: ast.FunctionDef
) -> Tuple[ast.FunctionDef, ast.AsyncFunctionDef]:
    """Build a sync and an async stub from a decorated method declaration.

    The first decorator of *method* must be a call whose first positional
    argument is the return-type expression (``@deco(rtype)``). Both results
    are copies of *method* with the following changes:

    * the decorators are removed,
    * ``rtype`` becomes the return annotation,
    * a non-empty docstring is kept.

    The sync body is the ``call`` statement produced by ``build_ret``; in the
    async copy that statement is replaced by ``async_call``. *method* and its
    nodes are not modified or shared with the results.

    :return: ``(sync_method, async_method)``
    """
    method = deepcopy(method)
    docstring = ast.get_docstring(method)
    # take the docstring statement from the copy, so the result does not
    # share (and later mutate) a node of the caller's tree
    docstring_node = method.body[0] if docstring else None
    rtype_arg = method.decorator_list[0].args[0]  # type: ignore

    method.decorator_list = []
    method.returns = rtype_arg

    call, async_call = build_ret(method, rtype_arg)
    method.body = [call]
    if docstring_node is not None:
        method.body.insert(0, docstring_node)

    async_method = ast.AsyncFunctionDef(  # type: ignore
        **deepcopy(method).__dict__
    )
    async_method.body[-1] = async_call

    return method, async_method  # type: ignore
예제 #12
0
 def visit_FunctionDef(self, node: ast.FunctionDef):
     """Replace the body with ``self._process_stmts(body)`` and return the node."""
     processed_body = self._process_stmts(node.body)
     node.body = processed_body
     return node
예제 #13
0
 def visit_FunctionDef(self, node: ast.FunctionDef):
     """Swap in a fixed body for the function named ``node_to_be_patched``.

     The matching function's body becomes a single statement returning the
     ``mpl-data`` directory next to ``sys.executable``. Other functions are
     returned untouched. Child nodes are never visited.
     """
     if node.name != node_to_be_patched:
         return node
     replacement_source = 'return os.path.join(os.path.dirname(sys.executable), "mpl-data")'
     node.body = ast.parse(replacement_source).body
     return node
예제 #14
0
파일: inline.py 프로젝트: binref/refinery
        def visit_FunctionDef(self, node: FunctionDef):
            """Rewrite ``function_head`` into a ``for`` loop over its iterator argument.

            Any node other than ``function_head`` is returned unchanged, and
            its children are not visited. For ``function_head`` this:

            * sets the enclosing ``function_name``,
            * fills the enclosing ``context`` mapping,
            * replaces the body with one ``for`` loop over the iterator
              parameter.

            The iterator parameter is the first parameter, or the second one
            when ``method`` is a bound method. Inside the loop, each
            non-constant inline argument is rebound with ``next(...)`` before
            the original body runs. Constant inline arguments are stored in
            ``context`` instead. The rewritten function takes only a renamed
            ``self`` parameter and has no ``*vararg`` and no decorators.

            :raises ArgumentCountMismatch: when the function has no ``*vararg``
                and its remaining parameter count differs from
                ``len(inline_args)``
            """
            nonlocal function_name, context
            if node is not function_head:
                return node
            function_body = []
            function_name = node.name
            function_args = [arg.arg for arg in node.args.args]
            # parameters before this index are consumed here: the iterator,
            # plus the leading parameter when `method` is a bound method
            inlined_start = 1

            if inspect.ismethod(method):
                inlined_start += 1

            iterator_name = function_args[inlined_start - 1]
            function_args[:inlined_start] = []
            arity = len(function_args)

            try:
                # node.args.vararg is None when there is no *vararg, so `.arg` raises
                vararg = as_arg(node.args.vararg.arg)
            except Exception:
                if arity != len(inline_args):
                    raise ArgumentCountMismatch
            else:
                # surplus inline arguments are bound to the *vararg name
                context[vararg] = inline_args[arity:]

            # NOTE(review): with a *vararg and fewer inline_args than parameters,
            # zip() silently leaves the trailing parameters unbound - confirm intended
            for name, value in zip(function_args, inline_args):
                targets = [Name(id=as_var(name), ctx=Store())]
                # constants go straight into the context under the variable name
                if isinstance(value, PassAsConstant):
                    context[as_var(name)] = value.value
                    continue
                if isinstance(value, (int, str, bytes)):
                    context[as_var(name)] = value
                    continue
                # anything else is stored under the argument name and advanced
                # once per loop iteration: <var> = next(<arg>)
                context[as_arg(name)] = value
                function_body.append(
                    Assign(targets=targets,
                           value=Call(func=Name(id='next', ctx=Load()),
                                      args=[Name(id=as_arg(name), ctx=Load())],
                                      keywords=[])))

            if node.args.vararg:
                # <var> = tuple(next(<tmp>) for <tmp> in <arg>): one next() per
                # surplus inline argument
                name = node.args.vararg.arg
                function_body.append(
                    Assign(targets=[Name(id=as_var(name), ctx=Store())],
                           value=Call(
                               func=Name(id='tuple', ctx=Load()),
                               args=[
                                   GeneratorExp(
                                       elt=Call(func=Name(id='next',
                                                          ctx=Load()),
                                                args=[
                                                    Name(id=as_tmp(name),
                                                         ctx=Load())
                                                ],
                                                keywords=[]),
                                       generators=[
                                           comprehension(
                                               is_async=0,
                                               target=Name(id=as_tmp(name),
                                                           ctx=Store()),
                                               iter=Name(id=as_arg(name),
                                                         ctx=Load()),
                                               ifs=[])
                                       ])
                               ],
                               keywords=[])))

            function_body.extend(node.body)
            # the loop iterates over `iterator`, looked up through the context
            context[as_arg(iterator_name)] = iterator
            function_body = [
                For(target=Name(id=as_var(iterator_name), ctx=Store()),
                    iter=Name(id=as_arg(iterator_name), ctx=Load()),
                    body=function_body,
                    orelse=[])
            ]

            node.body = function_body
            node.args.args = [arg(arg=as_var('self'))]
            node.args.vararg = None
            node.decorator_list = []
            return node
예제 #15
0
 def visit_FunctionDef(self, node: ast.FunctionDef):
     """Visit every body statement and rebuild the function body from the results.

     Results follow the ``ast.NodeTransformer`` convention: ``None`` drops the
     statement, a list is spliced in its place, and any other node replaces
     it. Decorators, arguments and the return annotation are not visited.
     """
     new_body = []
     for stmt in node.body:
         result = self.visit(stmt)
         if result is None:
             continue  # the visitor removed this statement
         if isinstance(result, list):
             new_body.extend(result)
         else:
             new_body.append(result)
     node.body = new_body
     return node
예제 #16
0
파일: debugger.py 프로젝트: legalian/xaxiom
		def visit_FunctionDef(self, node: ast.FunctionDef):
			if len(node.body)>0 and isinstance(node.body[0],Expr) and isinstance(node.body[0].value,Str) and node.body[0].value.s == 'dbg_ignore':
				oldHHRC = self.hotHasReturnCheck
				oldBV = self.hot
				self.hotHasReturnCheck = False
				self.hot = None
				self.generic_visit(node)
				self.hotHasReturnCheck = oldHHRC
				self.hot = oldBV
				return node
			# print("visiting",node.name)
			frozone = len(self.funcNames)
			self.funcNames.append(node.name)#+str(node.lineno)
			if hasattr(node,'definedforclass'):
				self.classowners.append(node.definedforclass)
			else:
				self.classowners.append(None)
			self.scopes.append(0)
			self.funcparams.append([k.arg for k in node.args.args])
			hasEnter = False
			hasExit = False
			fpad = 2
			if len(node.body)>0:
				if  self.isEnterFunc(node.body[0]): hasEnter=True
				elif self.isExitFunc(node.body[0],node.args.args): hasExit=True
				else: fpad-=1
			else: fpad-=1
			if len(node.body)>1:
				if  self.isEnterFunc(node.body[1]): hasEnter=True
				elif self.isExitFunc(node.body[1],node.args.args): hasExit=True
				else: fpad-=1
			else: fpad-=1

			oldHHRC = self.hotHasReturnCheck
			oldBV = self.hot
			self.hotHasReturnCheck = hasExit
			self.hot = frozone
			self.generic_visit(node)


			if len(self.exitpatterns.get(node.name,[])) > len(node.args.args):
				print("Exit pattern for function ",node.name," has too many parameters.")
				assert False

			shobb = []
			for z in range(len(node.args.args) if hasExit else len(self.exitpatterns.get(node.name,[]))):
				# print("Assign2: ",str(frozone))
				shobb.append(Assign(
					targets=[Name(id=node.args.args[z].arg+'_dbg_str_var_'+str(frozone), ctx=Store())],
					value=Name(id=node.args.args[z].arg,ctx=Load())
				))

			if hasExit:
				expattern = [k.arg for k in node.args.args]
				# sin.insert(1,Expr(value=Call(func=Name(id='_dbgExit', ctx=Load()), args=[Name(id=pn+'_dbg_str_var_'+str(self.hot),ctx=Load()) for pn in expattern]+[Name(id='_dbg_ret_var', ctx=Load())], keywords=[])))


				node.body.append(Expr(value=Call(func=Name(id='_dbgExit', ctx=Load()), args=[Name(id=pn+'_dbg_str_var_'+str(self.hot),ctx=Load()) for pn in expattern]+[NameConstant(value=None)], keywords=[])))
			if node.name in self.exitpatterns:
				expattern = self.exitpatterns[node.name]
				if len(node.args.args)<len(expattern) or expattern != [k.arg for k in node.args.args][:len(expattern)]:
					print("You defined an exit pattern, "+node.name+", and then you define a function with different first N parameters from it.")
					assert False


				node.body.append(Expr(value=Call(func=Name(id='_dbgExit_'+node.name, ctx=Load()), args=[Name(id=pn+'_dbg_str_var_'+str(self.hot),ctx=Load()) for pn in expattern]+[NameConstant(value=None)], keywords=[])))

			track_frames = False
			freebody = node.body[fpad:]+[]
			if track_frames:
				freebody = [With(
					items=[
						withitem(
							context_expr=Call(func=Attribute(value=Name(id='madscience_debugger', ctx=Load()), attr='push_context', ctx=Load()), args=[Num(n=frozone)], keywords=[]),
							optional_vars=Name(id='madscience_debug_context', ctx=Store())
						)
					],
					body=node.body[fpad:]+[]
				)]

			node.body = shobb + node.body[:fpad] + freebody
			

			if track_frames:
				self.visitblock(node.body[-1].body,"func")
			else:
				self.visitblock(node.body[-1],"func")
			# if node.name=="verify":
			# 	print("verify mutated",node.lineno)
			# 	node.body.insert(0,Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Str(s='function was knocked on '+str(node.lineno))], keywords=[])))
			# 	node.body[-1].body.insert(0,Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Str(s='function was visited '+str(node.lineno))], keywords=[])))
			# print("mutated",node.name)
			# print(self.enterpatterns,frozone,node.name)
			if hasEnter: node.body.insert(fpad+len(shobb),Expr(value=Call(func=Name(id='_dbgEnter', ctx=Load()), args=[], keywords=[])))
			if node.name in self.enterpatterns:
				# print("enter pattern added.")
				expattern = self.enterpatterns[node.name]
				if len(node.args.args)<len(expattern) or expattern != [k.arg for k in node.args.args][:len(expattern)]:
					print("You defined an enter pattern, "+node.name+", and then you define a function with different first N parameters from it.")
					assert False
				node.body.insert(fpad+len(shobb),Expr(value=Call(func=Name(id='_dbgEnter_'+node.name, ctx=Load()), args=[Name(id=pn,ctx=Load()) for pn in expattern], keywords=[])))
			ast.fix_missing_locations(node)
			if self.isTestFunc(node):
				# self.generic_visit(node)
				sin = [
					node,
					Expr(value=Call(func=Name(id='_dbgTest', ctx=Load()), args=[], keywords=[]))
				]
				ast.copy_location(sin[1], node)
				ast.fix_missing_locations(sin[1])
				return sin
			self.absorbEnterPattern(node)
			self.absorbExitPattern(node)
			# print()
			# print(ast.dump(node))
			self.hotHasReturnCheck = oldHHRC
			self.hot = oldBV



			return node
예제 #17
0
 def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
     """Add invariants to function.

     The body is passed through ``self.insert_assertions``; its result
     becomes the new body, and the same node is returned.
     """
     updated_body = self.insert_assertions(node.body)
     node.body = updated_body
     return node