Exemplo n.º 1
0
def generate_parallel_function(loop):
    """Build a function definition that runs ``loop`` and returns its non-locals.

    The generated function takes one positional parameter per name in
    ``loop.non_locals`` followed by a final ``proc_id`` parameter.  Its body
    is ``loop.node`` after rewriting by ``BoundsTransformer(loop)``, followed
    by a ``return [...]`` statement listing every non-local name in order.

    Parameters
    ----------
    loop
        Object exposing ``non_locals`` (a re-iterable collection of variable
        names) and ``node`` (the AST of the loop).

    Returns
    -------
    ast.FunctionDef
        The new definition, with missing source locations filled in.
    """
    # TODO: id() is only unique among live objects; a random suffix would be
    # a more robust way to avoid name clashes.
    name = "nest_fn" + str(id(loop))

    params = [ast.arg(arg=arg, annotation=None) for arg in loop.non_locals]
    params.append(ast.arg(arg="proc_id", annotation=None))
    args = ast.arguments(
        posonlyargs=[],  # required by compile() on Python 3.8-3.12
        args=params,
        vararg=None,
        kwarg=None,
        defaults=[],
        kwonlyargs=[],
        kw_defaults=[],
    )

    return_template = "return [%s]" % ", ".join(
        str(arg) for arg in loop.non_locals)

    transformed_tree = BoundsTransformer(loop).visit(loop.node)

    # Parsing the template yields a well-formed Return node wrapping a List.
    return_stmt = ast.parse(return_template).body[0]
    fun_def = ast.FunctionDef(
        name=name,
        args=args,
        body=[transformed_tree, return_stmt],
        decorator_list=[],
        returns=None,
    )
    return ast.fix_missing_locations(fun_def)
Exemplo n.º 2
0
def p_push_primary(p):  # noqa
    """push_primary : DOLLARSIGN primary"""
    # The docstring above is the grammar production for this rule and must
    # stay verbatim; documentation lives in comments instead.
    #
    # Turns ``$primary`` into a statement list that pushes the primary:
    #   * a bare name is pushed as-is,
    #   * an attribute reference is wrapped in ``concatify(...)``,
    #   * anything else becomes a ``ConcatFunction(lambda stack, stash: ...)``.
    primary = p[2]
    if not isinstance(primary, ast.Name):
        # TODO: not a very good check
        if 'stack.pop().' in astunparse.unparse(primary[0]):
            # Pushing an attributeref: drop the outer call so that only the
            # ``stack.pop().<attr>`` argument remains, then concatify it.
            p[2] = ast.Call(
                func=ast.Name(id='concatify', ctx=ast.Load()),
                args=primary[0].value.args[0:1],
                keywords=[],
            )
        else:
            stack_and_stash = ast.arguments(
                args=[ast.arg(arg='stack', annotation=None),
                      ast.arg(arg='stash', annotation=None)],
                vararg=None,
                kwonlyargs=[],
                kwarg=None,
                defaults=[],
                kw_defaults=[])
            quoted = ast.Lambda(stack_and_stash, _combine_exprs(primary))
            p[2] = ast.Call(
                func=ast.Name(id='ConcatFunction', ctx=ast.Load()),
                args=[quoted],
                keywords=[],
            )
    p[0] = [ast.Expr(_push(p[2]))]
    _set_line_info(p)
 def create_arguments(self, args=None, vararg=None, varargannotation=None,
         kwonlyargs=None, kwarg=None, kwargannotation=None, defaults=None,
         kw_defaults=None):
     """Build an ``ast.arguments`` node from plain parameter names.

     Parameters:
         args: iterable of positional parameter names (default: none).
         vararg: name of the ``*args`` parameter, or ``None``.
         varargannotation: annotation expression for ``vararg``, or ``None``.
         kwonlyargs: iterable of keyword-only parameter names (default: none).
         kwarg: name of the ``**kwargs`` parameter, or ``None``.
         kwargannotation: annotation expression for ``kwarg``, or ``None``.
         defaults: default-value expressions for the trailing positional
             parameters (default: none).
         kw_defaults: one entry per keyword-only parameter, ``None`` meaning
             "required".  Defaults to ``None`` for every keyword-only name.

     Returns:
         An ``ast.arguments`` node laid out for the running interpreter.
     """
     arg_nodes = [ast.arg(x, None) for x in (() if args is None else args)]
     kwonly_nodes = [ast.arg(x, None)
                     for x in (() if kwonlyargs is None else kwonlyargs)]
     defaults = [] if defaults is None else defaults
     if kw_defaults is None:
         # compile() insists on len(kw_defaults) == len(kwonlyargs).
         kw_defaults = [None] * len(kwonly_nodes)

     if 'varargannotation' in ast.arguments._fields:
         # Python 3.0-3.3 layout: star parameters are bare identifiers with
         # separate annotation fields.
         return ast.arguments(arg_nodes, vararg, varargannotation,
                              kwonly_nodes, kwarg, kwargannotation,
                              defaults, kw_defaults)

     def star_param(name, annotation):
         # Since Python 3.4 ``vararg``/``kwarg`` are ``ast.arg`` nodes that
         # carry their own annotation.
         if name is None or isinstance(name, ast.AST):
             return name
         return ast.arg(name, annotation)

     return ast.arguments(
         posonlyargs=[],  # required by compile() on Python 3.8-3.12
         args=arg_nodes,
         vararg=star_param(vararg, varargannotation),
         kwonlyargs=kwonly_nodes,
         kw_defaults=kw_defaults,
         kwarg=star_param(kwarg, kwargannotation),
         defaults=defaults,
     )
Exemplo n.º 4
0
    def build_init(cls):
        """Build a module AST defining ``__init__`` from ``cls._fields``.

        The generated ``__init__(self, <field>, ...)`` takes one positional
        parameter per entry of ``cls._fields`` and stores each on ``self``
        under the same name.

        Returns:
            An ``ast.Module`` containing the single function definition, with
            source locations filled in so it can be passed to ``compile``.
        """
        # build arguments objects
        params = [ast.arg(arg='self', annotation=None)]
        params.extend(
            ast.arg(arg=field, annotation=None) for field in cls._fields)

        body = [
            ast.Assign(
                targets=[ast.Attribute(
                    value=ast.Name(id='self', ctx=ast.Load()),
                    attr=field,
                    ctx=ast.Store(),
                )],
                value=ast.Name(id=field, ctx=ast.Load()),
            )
            for field in cls._fields
        ]
        if not body:
            # A function body may not be empty.
            body.append(ast.Pass())

        # Keyword arguments keep the node correct regardless of the field
        # order of the running interpreter (``posonlyargs`` was inserted
        # first in Python 3.8).
        function = ast.FunctionDef(
            name='__init__',
            args=ast.arguments(
                posonlyargs=[],
                args=params,
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[],
            ),
            body=body,
            decorator_list=[],
            returns=None,
        )

        return ast.fix_missing_locations(
            ast.Module(body=[function], type_ignores=[])
        )
Exemplo n.º 5
0
 def _check_arguments(self, fac, check):
     def arguments(args=None, vararg=None,
                   kwonlyargs=None, kwarg=None,
                   defaults=None, kw_defaults=None):
         if args is None:
             args = []
         if kwonlyargs is None:
             kwonlyargs = []
         if defaults is None:
             defaults = []
         if kw_defaults is None:
             kw_defaults = []
         args = ast.arguments(args, vararg, kwonlyargs, kw_defaults,
                              kwarg, defaults)
         return fac(args)
     args = [ast.arg("x", ast.Name("x", ast.Store()))]
     check(arguments(args=args), "must have Load context")
     check(arguments(kwonlyargs=args), "must have Load context")
     check(arguments(defaults=[ast.Num(3)]),
                    "more positional defaults than args")
     check(arguments(kw_defaults=[ast.Num(4)]),
                    "length of kwonlyargs is not the same as kw_defaults")
     args = [ast.arg("x", ast.Name("x", ast.Load()))]
     check(arguments(args=args, defaults=[ast.Name("x", ast.Store())]),
                    "must have Load context")
     args = [ast.arg("a", ast.Name("x", ast.Load())),
             ast.arg("b", ast.Name("y", ast.Load()))]
     check(arguments(kwonlyargs=args,
                       kw_defaults=[None, ast.Name("x", ast.Store())]),
                       "must have Load context")
Exemplo n.º 6
0
def make_lambda(expression, args, env=None):
    # type: (ast.Expression, List[str], Dict[str, Any]) -> types.FunctionType
    """
    Create a lambda function from an expression AST.

    The result is built as ``(lambda **env: lambda *args: EXPRESSION)(**env)``
    so that every name in `env` is captured in the closure of the returned
    function.

    Parameters
    ----------
    expression : ast.Expression
        The body of the lambda.
    args : List[str]
        A list of positional argument names
    env : Optional[Dict[str, Any]]
        Extra environment to capture in the lambda's closure.  ``None``
        (the default) means no extra names.

    Returns
    -------
    func : types.FunctionType
    """
    if env is None:
        # ``fouter(**env)`` below needs a real mapping.
        env = {}

    def signature(names):
        # Plain positional parameters only.
        return ast.arguments(
            posonlyargs=[],  # required by compile() on Python 3.8-3.12
            args=[ast.arg(arg=name, annotation=None) for name in names],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        )

    # lambda *{args}* : EXPRESSION
    lambda_ = ast.Lambda(args=signature(args), body=expression.body)
    lambda_ = ast.copy_location(lambda_, expression.body)
    # lambda **{env}** : lambda *{args}*: EXPRESSION
    outer = ast.Lambda(args=signature(env), body=lambda_)
    # ``Expression`` carries no location itself; fix_missing_locations
    # starts every unlocated child at line 1, column 0.
    exp = ast.Expression(body=outer)
    ast.fix_missing_locations(exp)
    GLOBALS = __GLOBALS.copy()
    # NOTE: emptying ``__builtins__`` limits name lookup but is not a
    # sandbox; do not feed untrusted expressions through this function.
    GLOBALS["__builtins__"] = {}
    # pylint: disable=eval-used
    fouter = eval(compile(exp, "<lambda>", "eval"), GLOBALS)
    assert isinstance(fouter, types.FunctionType)
    finner = fouter(**env)
    assert isinstance(finner, types.FunctionType)
    return finner
Exemplo n.º 7
0
def build_class():
    """
    Create the :class:`ast.ClassDef` node that wraps a whole template file.

    The class is named ``Template``, derives from ``object`` and contains a
    single entry function:

    .. function:: root(context)

        Starts the template parsing with the given context.

        :returns: Returns a generator of strings that can be joined to the
            rendered template.

    :returns: a 2-tuple with the class and the entry function
    """
    # The parameter nodes differ between the Python 2 and Python 3 AST.
    if PY3:
        signature = {
            'args': [
                ast.arg(arg=param, annotation=None)
                for param in ('self', 'context')
            ],
            'kwonlyargs': [],
            'kw_defaults': [],
        }
    else:
        signature = {
            'args': [
                ast.Name(id=param, ctx=ast.Param())
                for param in ('self', 'context')
            ],
        }
    signature.update(vararg=None, kwarg=None, defaults=[])
    parameters = ast.arguments(**signature)

    # we add an empty string to guarantee for a string and generator on
    # root level
    initial_yield = build_yield(ast.Str(s=''))

    root_func = ast.FunctionDef(
        name='root',
        args=parameters,
        body=[initial_yield],
        decorator_list=[],
    )
    klass = ast.ClassDef(
        name='Template',
        bases=[ast.Name(id='object', ctx=ast.Load())],
        keywords=[],
        starargs=None,
        kwargs=None,
        body=[root_func],
        decorator_list=[],
    )
    return klass, root_func
Exemplo n.º 8
0
    def decorator(wrapped):
        # Builds a replacement for `wrapped` that has the same named
        # parameters, *varargs and **keywords, and whose body is
        #     return wrapped(<named args>, *flatten(<varargs>), **<keywords>)
        # The replacement is generated as an AST, compiled, and turned into a
        # function object by hand.
        #
        # NOTE(review): inspect.getargspec no longer exists on recent Python 3
        # releases, and `starargs`/`kwargs` are not fields of ast.Call in the
        # current ast grammar — confirm which interpreter versions this must
        # support.
        spec = inspect.getargspec(wrapped)
        name = wrapped.__name__

        # The wrapped function must declare a *varargs parameter; it is the
        # one that gets passed through flatten().
        assert spec.varargs is not None

        # Example was generated with print ast.dump(ast.parse("def f(a, b, *args, **kwds): return call_wrapped((a, b), args, kwds)"), include_attributes=True)
        # http://code.activestate.com/recipes/578353-code-to-source-and-back/ helped a lot
        # http://stackoverflow.com/questions/10303248#29927459
        if sys.hexversion < 0x03000000:
            # Python 2 AST: parameters are Name nodes in Param context and the
            # star parameters are plain strings.
            wrapper_ast_args = ast.arguments(
                    args=[ast.Name(id=a, ctx=ast.Param(), lineno=1, col_offset=0) for a in spec.args],
                    vararg=spec.varargs,
                    kwarg=spec.keywords,
                    defaults=[]
                )
        else:
            # Python 3 AST: every parameter (including the star ones) is an
            # ast.arg node.
            wrapper_ast_args = ast.arguments(
                args=[ast.arg(arg=a, annotation=None, lineno=1, col_offset=0) for a in spec.args],
                vararg=None if spec.varargs is None else ast.arg(arg=spec.varargs, annotation=None, lineno=1, col_offset=0),
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None if spec.keywords is None else ast.arg(arg=spec.keywords, annotation=None, lineno=1, col_offset=0),
                defaults=[]
            )
        # Module holding a single function definition named like `wrapped`.
        # Defaults are left empty here and attached later through `argdefs`.
        wrapper_ast = ast.Module(body=[ast.FunctionDef(
            name=name,
            args=wrapper_ast_args,
            body=[ast.Return(value=ast.Call(
                func=ast.Name(id="wrapped", ctx=ast.Load(), lineno=1, col_offset=0),
                args=[ast.Name(id=a, ctx=ast.Load(), lineno=1, col_offset=0) for a in spec.args],
                keywords=[],
                starargs=ast.Call(
                    func=ast.Name(id="flatten", ctx=ast.Load(), lineno=1, col_offset=0),
                    args=[ast.Name(id=spec.varargs, ctx=ast.Load(), lineno=1, col_offset=0)],
                    keywords=[], starargs=None, kwargs=None, lineno=1, col_offset=0
                ),
                kwargs=None if spec.keywords is None else ast.Name(id=spec.keywords, ctx=ast.Load(), lineno=1, col_offset=0),
                lineno=1, col_offset=0
            ), lineno=1, col_offset=0)],
            decorator_list=[],
            lineno=1,
            col_offset=0
        )])
        # Compiling the module yields module-level code; the function's own
        # code object is the first code-typed constant inside it.
        wrapper_code = [c for c in compile(wrapper_ast, "<ast_in_variadic_py>", "exec").co_consts if isinstance(c, types.CodeType)][0]
        # The generated body refers to "wrapped" and "flatten" as globals, so
        # they are supplied as the new function's global namespace.  The
        # original positional defaults are re-attached via `argdefs`.
        wrapper = types.FunctionType(wrapper_code, {"wrapped": wrapped, "flatten": flatten}, argdefs=spec.defaults)

        # Copy name, docstring and other metadata from `wrapped`, then prepend
        # a note about variadic behaviour when a docstring exists.
        functools.update_wrapper(wrapper, wrapped)
        if wrapper.__doc__ is not None:
            wrapper.__doc__ = "Note that this function is variadic. See :ref:`variadic-functions`.\n\n" + wrapper.__doc__
        return wrapper
Exemplo n.º 9
0
def p_funcdef(p):  # noqa
    """funcdef : DEF funcname COLON suite"""
    # The docstring above is the grammar production for this rule and must
    # stay verbatim.  The rule produces a function definition taking
    # ``(stack, stash)``, decorated with ``ConcatFunction``.
    params = [ast.arg(arg=param, annotation=None)
              for param in ('stack', 'stash')]
    signature = ast.arguments(
        args=params,
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    func_name, suite = p[2], p[4]
    decorators = [ast.Name(id='ConcatFunction', ctx=ast.Load())]
    p[0] = ast.FunctionDef(func_name, signature, suite, decorators, None)
    _set_line_info(p)
 def make_function_def(self, body, name):
     """Return a ``FunctionDef`` called ``name`` with the given ``body``.

     The generated function takes exactly two positional parameters:
     ``self`` and the one named by ``CONTEXT_ARG_NAME``.  It has no
     decorators.
     """
     params = [ast.arg(arg=param) for param in ('self', CONTEXT_ARG_NAME)]
     signature = ast.arguments(
         args=params,
         kwonlyargs=[],
         kw_defaults=[],
         defaults=[],
     )
     return ast.FunctionDef(
         name=name,
         args=signature,
         body=body,
         decorator_list=[],
     )
Exemplo n.º 11
0
    def _functiondef(self, name, args, body, lineno, col_offset):
        args = [ast.arg(arg=c.car.name, annotation=None) for c in args]
        body = self._body(c.car for c in body)

        # Rewrite return
        body[-1] = ast.Return(
            value=body[-1].value,
            lineno=body[-1].lineno,
            col_offset=body[-1].col_offset
        )

        return ast.FunctionDef(
            name=name,
            args=ast.arguments(
                args=args,
                defaults=[],
                kw_defaults=[],
                kwarg=None,
                kwargannotation=None,
                kwonlyargs=[],
                vararg=None,
                varargannotation=None
            ),
            body=body,
            returns=None,
            decorator_list=[],
            lineno=lineno,
            col_offset=col_offset
        )
Exemplo n.º 12
0
def _visit_local(gen_sym, node, to_mangle, mangled):
    """
    Replacing known variables with literal values
    """
    is_name = type(node) == ast.Name

    node_id = node.id if is_name else node.arg

    if node_id in to_mangle:

        if node_id in mangled:
            mangled_id = mangled[node_id]
        else:
            mangled_id, gen_sym = gen_sym('mangled')
            mangled = mangled.set(node_id, mangled_id)

        if is_name:
            new_node = ast.Name(id=mangled_id, ctx=node.ctx)
        else:
            new_node = ast.arg(arg=mangled_id, annotation=node.annotation)

    else:
        new_node = node

    return gen_sym, new_node, mangled
Exemplo n.º 13
0
    def compile_Module(self, node):
        """Compile a module node and register the result on the main compiler.

        The compiled result is stored in ``self.main_compiler.modules``:
        prepended when the module is named ``"builtins"``, appended otherwise.
        Nothing is returned.
        """
        self.is_builtins = self.module_name == "builtins"

        body = node.body
        if self.print_module_result:
            # Replace the last body item by ``print(<last item>)``; an empty
            # body prints the name ``None`` instead.
            try:
                last_body_item = body[-1]
            except IndexError:
                last_body_item = ast.Name("None", ast.Load())
            print_fn = ast.Name("print", ast.Load())
            # NOTE(review): positional constructor arguments follow the field
            # order of the ast module in use; five positional values for
            # ast.Call presumably target an older grammar — confirm.
            last_body_item = ast.Call(print_fn, [last_body_item], None, None, None)
            body = body[:-1] + [last_body_item]

        if not self.bare and not self.is_builtins:
            # Wrap the body in an unnamed function taking the local module
            # name as its only parameter and pass it, together with the module
            # name, to ``__registermodule__``.
            module_name = ast.Str(self.module_name)
            # NOTE(review): eight positional values for ast.arguments likewise
            # presumably match an older field layout — confirm.
            args = ast.arguments([ast.arg(self.local_module_name.id, None)], None, None, None, None, None, None, None)
            func = ast.FunctionDef(name = '', args = args, body = body, decorator_list = [], returns = None)
            to_call = ast.Name("__registermodule__", ast.Load())
            call = ast.Call(to_call, [module_name, func], None, None, None)
            result = self.compile_node(call)
        else:
            # Bare or builtins module: compile the statements directly inside
            # a fresh context and emit that context's variables first.
            context = self.context_stack.new()
            result = self.compile_node(body)
            result = self.compile_statement_list([context.get_vars(True), JSCode(result)])
            self.context_stack.pop()

        if self.is_builtins:
            # The builtins module goes to the front of the module list.
            self.main_compiler.modules = [result] + self.main_compiler.modules
        else:
            self.main_compiler.modules.append(result)
Exemplo n.º 14
0
def _make_fn(name, chain_fn, args, defaults):
    """Create a function ``name(_self, *args)`` that forwards to ``chain_fn``.

    The function is generated as an AST so that it has a real signature:
    ``_self`` followed by the names in ``args``, with ``defaults`` applied to
    the trailing parameters.  Its body is ``return _chain(_self, <args>...)``
    where ``_chain`` resolves to ``chain_fn``.

    Parameters:
        name: name of the generated function.
        chain_fn: callable receiving every argument of the generated function.
        args: iterable of parameter names (excluding ``_self``).
        defaults: sequence of default values for the trailing parameters.

    Returns:
        The newly created function object.
    """
    args_with_self = ['_self'] + list(args)
    arguments = [_ast.Name(id=arg, ctx=_ast.Load()) for arg in args_with_self]
    # Defaults are referenced by name and resolved from ``locals_`` when the
    # ``def`` statement executes.
    defs = [_ast.Name(id='_def{0}'.format(idx), ctx=_ast.Load()) for idx, _ in enumerate(defaults)]
    if _PY2:
        parameters = _ast.arguments(args=[_ast.Name(id=arg, ctx=_ast.Param()) for arg in args_with_self],
                                    defaults=defs)
        module_fields = {}
    else:
        parameters = _ast.arguments(posonlyargs=[],  # required on Python 3.8-3.12
                                    args=[_ast.arg(arg=arg) for arg in args_with_self],
                                    kwonlyargs=[],
                                    defaults=defs,
                                    kw_defaults=[])
        # ``Module.type_ignores`` is required by compile() on Python 3.8-3.12.
        module_fields = {'type_ignores': []}
    forward_call = _ast.Call(func=_ast.Name(id='_chain', ctx=_ast.Load()),
                             args=arguments,
                             keywords=[])
    function_node = _ast.FunctionDef(name=name,
                                     args=parameters,
                                     body=[_ast.Return(value=forward_call)],
                                     decorator_list=[])
    module_node = _ast.Module(body=[function_node], **module_fields)
    module_node = _ast.fix_missing_locations(module_node)

    # compile the ast
    code = compile(module_node, '<string>', 'exec')

    # and eval it in the right context
    globals_ = {'_chain': chain_fn}
    locals_ = dict(('_def{0}'.format(idx), value) for idx, value in enumerate(defaults))
    eval(code, globals_, locals_)

    # extract our function from the newly created module
    return locals_[name]
Exemplo n.º 15
0
    def __call__(self, selection):
        """Parse ``selection`` and compile it into a callable of ``atom``.

        Raises ``ValueError`` (with a caret line built from the parser's
        reported location) when ``selection`` cannot be parsed.

        Returns a ``_ParsedSelection`` built from the compiled callable, the
        regenerated source text and the transformed AST.
        """
        if not self.is_initialized:
            self._initialize()

        try:
            parse_result = self.expression.parseString(selection, parseAll=True)
        except ParseException as e:
            msg = str(e)
            headline = "%s: %s" % (msg, selection)
            # NOTE(review): the constant 12 presumably offsets a prefix added
            # when the exception is displayed — confirm.
            pointer = " " * (12 + len("%s: " % msg) + (e.loc)) + "^^^"
            raise ValueError('\n'.join([headline, pointer]))

        # Change __ATOM__ in function bodies. It must bind to the arg
        # name specified below (i.e. 'atom')
        astnode = self.transformer.visit(deepcopy(parse_result[0].ast()))

        # The lone 'atom' parameter is spelled differently in the Python 2
        # and Python 3 AST.
        if PY2:
            params = [ast.Name(id='atom', ctx=ast.Param())]
            extra_fields = {}
        else:
            params = [ast.arg(arg='atom', annotation=None)]
            extra_fields = {'kwonlyargs': [], 'kw_defaults': []}
        signature = ast.arguments(args=params, vararg=None, kwarg=None,
                                  defaults=[], **extra_fields)

        func = ast.Expression(body=ast.Lambda(signature, astnode))
        source = codegen.to_source(astnode)

        code = compile(ast.fix_missing_locations(func), '<string>', mode='eval')
        expr = eval(code, SELECTION_GLOBALS)
        return _ParsedSelection(expr, source, astnode)
Exemplo n.º 16
0
 def test_lambda(self):
     # Stepping a lambda expression should yield that lambda as the value,
     # with its parameter list and body intact.
     lamb_expr = Lambda(args=ast.arguments(args=[ast.arg(arg='x')]), body=Num(n=3))
     env = Scope([])
     k = Done()
     val = step(lamb_expr, env, k)[0]
     # assertIsInstance replaces the removed ``assert_`` alias and reports
     # the offending value on failure.
     self.assertIsInstance(val, Lambda)
     self.assertEqual(val.args.args[0].arg, 'x')
     self.assertEqual(val.body.n, 3)
Exemplo n.º 17
0
 def _make_arg(self, node):
     if node is None:
         return None
     new_node = ast.arg(
         self._visit(node.id),
         self._visit(node.annotation),
     )
     return ast.copy_location(new_node, node)
Exemplo n.º 18
0
def _create_ast_lambda(names, body):
    if sys.version_info >= (3, 0):  # change in AST structure for Python 3
        args = [ast.arg(arg=name, annotation=None) for name in names]
    else:
        args = [ast.Name(id=name, ctx=ast.Load()) for name in names]

    return ast.Lambda(args=ast.arguments(
        args=args, vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=body)
Exemplo n.º 19
0
def compile_define(p):
    """Compile a ``define`` form into an assignment or a function definition.

    When ``p[1]`` is a list, its first element names the function and the
    rest name its parameters; ``p[2:]`` is the body, and a trailing
    expression statement becomes the return value.  Otherwise ``p[1]`` names
    a variable that is assigned the value built from ``p[2]``.
    """
    target = p[1]
    if not isinstance(target, list):
        # Variable definition.
        variable = ast.Name(pydent(target.name), ast.Store())
        return ast.Assign([variable], build_ast(p[2]))

    # Function definition.
    body = [make_stmt(build_ast(form)) for form in p[2:]]
    final = body[-1]
    if isinstance(final, ast.Expr):
        body[-1] = ast.Return(final.value)
    func_name = pydent(target[0].name)
    params = [ast.arg(arg=symbol.name) for symbol in target[1:]]
    signature = ast.arguments(args=params, kwonlyargs=[], defaults=[],
                              kw_defaults=[])
    return ast.FunctionDef(func_name, signature, body, [], None)
Exemplo n.º 20
0
 def _generate_self(self):
     return ast.arguments(
         args=[ast.arg(arg='self', annotation=None)],
         vararg=None,
         kwonlyargs=[],
         kw_defaults=[],
         kwarg=None,
         defaults=[]
     )
Exemplo n.º 21
0
 def _translate_python_arguments(self, tvars: list) -> pyast.arguments:
     """Translate typed variable list to python 'arguments' AST"""
     return pyast.arguments(
         args=list(pyast.arg(arg=tv.var.name, annotation=None) for tv in tvars),
         vararg=None,
         kwonlyargs=[],
         kw_defaults=[],
         kwarg=None,
         defaults=[])
Exemplo n.º 22
0
def _lower_array_expr(lowerer, expr):
    '''Lower an array expression built by RewriteArrayExprs.

    The expression is turned into a small Python function (one parameter per
    variable used by the expression), that function is compiled for scalar
    element types, and the compiled function is applied through
    ``npyimpl.numpy_ufunc_kernel``, whose result is returned.
    '''
    # Function name derived from the expression's hash; a leading '-' in the
    # hex string is replaced so the name stays a valid identifier.
    expr_name = "__numba_array_expr_%s" % (hex(hash(expr)).replace("-", "_"))
    # Unique variables of the expression, ordered by name.
    expr_args = sorted(set(expr.list_vars()), key=lambda x: x.name)
    expr_arg_names = [arg.name for arg in expr_args]
    if hasattr(ast, "arg"):
        # Should be Python 3.x
        ast_args = [ast.arg(arg_name, None)
                    for arg_name in expr_arg_names]
    else:
        # Should be Python 2.x
        ast_args = [ast.Name(arg_name, ast.Param())
                    for arg_name in expr_arg_names]
    # Parse a stub function to ensure the AST is populated with
    # reasonable defaults for the Python version.
    # NOTE(review): expr_args[0] assumes the expression has at least one
    # variable — confirm against callers.
    ast_module = ast.parse('def {0}(): return'.format(expr_name),
                           expr_args[0].loc.filename, 'exec')
    assert hasattr(ast_module, 'body') and len(ast_module.body) == 1
    ast_fn = ast_module.body[0]
    # Fill in the stub: real parameter list, and the translated expression as
    # the value of its single ``return`` statement.  The translation also
    # yields the namespace the function is executed in.
    ast_fn.args.args = ast_args
    ast_fn.body[0].value, namespace = _arr_expr_to_ast(expr.expr)
    ast.fix_missing_locations(ast_module)
    code_obj = compile(ast_module, expr_args[0].loc.filename, 'exec')
    six.exec_(code_obj, namespace)
    impl = namespace[expr_name]

    context = lowerer.context
    builder = lowerer.builder
    # Signature of the whole expression, over the types of its variables.
    outer_sig = expr.ty(*(lowerer.typeof(name) for name in expr_arg_names))
    # Signature of the generated function: array arguments are replaced by
    # their dtype, all other arguments keep their type.
    inner_sig_args = []
    for argty in outer_sig.args:
        if isinstance(argty, types.Array):
            inner_sig_args.append(argty.dtype)
        else:
            inner_sig_args.append(argty)
    inner_sig = outer_sig.return_type.dtype(*inner_sig_args)

    # Only namespace entries with the "__ufunc_or_dufunc_" prefix are passed
    # on as locals to the compilation below.
    _locals = dict((name, value)
                   for name, value in namespace.items()
                   if name.startswith("__ufunc_or_dufunc_"))
    cres = context.compile_only_no_cache(builder, impl, inner_sig,
                                         locals=_locals)

    class ExprKernel(npyimpl._Kernel):
        # Kernel that casts its arguments to the inner signature, calls the
        # compiled function, and casts the result back to the outer return
        # type.
        def generate(self, *args):
            arg_zip = zip(args, self.outer_sig.args, inner_sig.args)
            cast_args = [self.cast(val, inty, outty)
                         for val, inty, outty in arg_zip]
            result = self.context.call_internal(
                builder, cres.fndesc, inner_sig, cast_args)
            return self.cast(result, inner_sig.return_type,
                             self.outer_sig.return_type)

    # Load the actual argument values and hand everything to the ufunc
    # kernel driver.
    args = [lowerer.loadvar(name) for name in expr_arg_names]
    return npyimpl.numpy_ufunc_kernel(
        context, builder, outer_sig, args, ExprKernel, explicit_output=False)
Exemplo n.º 23
0
 def _generate_argument(self, name):
     return ast.arguments(
         args=[ast.arg(arg=name, annotation=None)],
         vararg=None,
         kwonlyargs=[],
         kw_defaults=[],
         kwarg=None,
         defaults=[]
     )
Exemplo n.º 24
0
def query(tree, gen_sym, **kw):
    """Expand a query expression into code that executes it and fetches all rows.

    ``tree`` is run through ``process`` and ``expand_let_bindings``; the
    result is passed as the single argument of a generated lambda whose
    parameter is a fresh symbol from ``gen_sym``.
    """
    x = process(tree)
    x = expand_let_bindings.recurse(x)
    sym = gen_sym()
    # Shape of the generated code:
    #   (lambda <sym>: <sym>.bind.execute(<sym>).fetchall())(<x>)
    # The quoted lambda is written with a placeholder parameter ``query``;
    # its parameter list is replaced below so it binds the fresh symbol.
    new_tree = hq[(lambda query: name[sym].bind.execute(
        name[sym]).fetchall())(ast_literal[x])]
    # Keywords keep each value in the intended field regardless of the
    # interpreter's field order (``posonlyargs`` came first from Python 3.8).
    new_tree.func.args = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg=sym, annotation=None)],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    return new_tree
Exemplo n.º 25
0
 def test_call(self):
     # Stepping a call of ``lambda x: x`` with argument 3 should produce the
     # lambda as the value and an Earg continuation holding the argument.
     call_expr = Call(func=Lambda(args=ast.arguments(args=[ast.arg(arg='x')]), body=Name(id='x')), args=[Num(n=3)])
     env = Scope([])
     k = Done()
     val = step(call_expr, env, k)[0]
     earg = step(call_expr, env, k)[2]
     self.assertEqual(val.args.args[0].arg, 'x')
     self.assertEqual(val.body.id, 'x')
     # assertIsInstance replaces the removed ``assert_`` alias and reports
     # the offending value on failure.
     self.assertIsInstance(earg, Earg)
     self.assertEqual(earg.expr.n, 3)
Exemplo n.º 26
0
 def add_binding_args_to_func(cls, args, func):
     """Alter the definition of 'func' to accept 'args'.

     The source of ``func`` is re-parsed, its parameter list is replaced by
     its original first parameter followed by one parameter per entry of
     ``args`` (converted with ``str``), and the result is recompiled in
     ``func``'s global namespace.

     Parameters:
         cls: class owning ``func``; substituted for a free variable of
             ``func`` that has the same name as the class.
         args: iterable of extra parameter names.
         func: function whose source is available to ``inspect.getsource``.

     Returns:
         The newly built function.
     """
     # Get the AST of the function and add extra argument nodes.
     funcast = ast.parse(inspect.getsource(func).strip())
     funcdef = funcast.body[0]
     funcname = funcdef.name
     funcargs = funcdef.args
     funcargs.args = [ ast.arg(funcargs.args[0].arg, None) ]
     funcargs.args.extend(ast.arg(str(a), None) for a in args)
     # The hand-made arg nodes have no line/column information, which
     # compile() requires on every located node.
     ast.fix_missing_locations(funcast)
     env = dict()  # An environment in which to evaluate the modded AST.
     if len(func.__code__.co_freevars) < 1:
         # If the function doesn't close over any free variables,
         # it can be evaluated as is.
         exec(compile(funcast, '<generated>', 'exec'),
              func.__globals__, env)
         newfunc = env[funcname]
     else:
         # TODO: what happens if a freevar is also a binding arg?
         # Have to build a closure.
         clsname = cls.__name__
         freevars = tuple(func.__code__.co_freevars)
         # Pull the values out of the existing closure.
         # When a function is defined in a class and refers to
         # to the class as a free variable it is not included
         # in the closure and has to be added manually.
         closvals = [cell.cell_contents if var != clsname else cls
                     for var, cell in
                     zip(freevars, func.__closure__)]
         # Create a wrapper function to bind the closure values.
         wrapperargs = ', '.join(freevars)
         wrapper = ast.parse("def wrapper(%s):\n"
                             "  def %s(): pass\n"
                             "  return %s" %
                             (wrapperargs, funcname, funcname))
         wrapperfunc = wrapper.body[0]
         # Replace the body of the wrapper with the target function.
         wrapperfunc.body[0] = funcdef
         # Evaluate the resulting AST and call the
         # wrapper function with the closure values.
         exec(compile(wrapper, '<generated>', 'exec'),
              func.__globals__, env)
         newfunc = env['wrapper'](*closvals)
     return newfunc
Exemplo n.º 27
0
def translate_lambda(compiler, lambda_):
	"""Translate ``lambda_`` into a Python ``Lambda`` AST node.

	Each entry of ``lambda_.params`` becomes an un-annotated positional
	parameter; the body is ``compiler.translate(lambda_.body)``.
	"""
	signature = python_ast.arguments(
		posonlyargs=[],  # required by compile() on Python 3.8-3.12
		args=[python_ast.arg(arg=n, annotation=None) for n in lambda_.params],
		vararg=None,
		kwonlyargs=[],
		kw_defaults=[],
		kwarg=None,
		defaults=[]
	)
	return python_ast.Lambda(
		args=signature,
		body=compiler.translate(lambda_.body)
	)
Exemplo n.º 28
0
    def takes_only_self(self):
        """
        Return an argument list node that takes only ``self``.

        """

        return ast.arguments(
            posonlyargs=[],  # required by compile() on Python 3.8-3.12
            args=[ast.arg(arg="self")],
            defaults=[],
            kw_defaults=[],
            kwonlyargs=[],
        )
Exemplo n.º 29
0
 def visit_ListComp(self, t):
     """Rewrite a list comprehension as a call of a one-parameter function.

     The function, named ``<listcomp>``, takes an accumulator called ``.0``,
     runs the comprehension's ``for``/``if`` clauses as nested statements
     that append each element to it, and returns it.  The call passes an
     empty list as the accumulator.
     """
     result_append = ast.Attribute(ast.Name('.0', load), 'append', load)
     body = ast.Expr(Call(result_append, [t.elt]))
     # Wrap from the innermost clause outwards.
     for loop in reversed(t.generators):
         for test in reversed(loop.ifs):
             body = ast.If(test, [body], [])
         body = ast.For(loop.target, loop.iter, [body], [])
     fn = [body,
           ast.Return(ast.Name('.0', load))]
     # Keywords keep each value in the intended field regardless of the
     # interpreter's field order (``posonlyargs`` came first from Python 3.8).
     args = ast.arguments(
         posonlyargs=[],
         args=[ast.arg('.0', None)],
         vararg=None,
         kwonlyargs=[],
         kw_defaults=[],
         kwarg=None,
         defaults=[],
     )
     return Call(Function('<listcomp>', args, fn),
                 [ast.List([], load)])
Exemplo n.º 30
0
def ast_lambda(name, body):
    """Build an ``ast.Lambda`` with a single parameter for the running Python.

    :param name: an ``ast.Name`` node naming the parameter; used as-is on
        Python 2 and converted to an ``ast.arg`` (via ``name.id``) on
        Python 3.
    :param body: the expression node used as the lambda body.
    :returns: the ``ast.Lambda`` node, or whatever falls through after
        ``invalid_python_version()`` for any other ``PYTHON_VERSION``.
    """
    # Compare by value: ``is`` on int literals only works thanks to
    # CPython's small-int cache and triggers a SyntaxWarning on 3.8+.
    if PYTHON_VERSION == 2:
        return ast.Lambda(args=ast.arguments(args=[name],
                          defaults=[]), body=body)
    elif PYTHON_VERSION == 3:
        return ast.Lambda(args=ast.arguments(posonlyargs=[],  # required on 3.8+
                                             args=[ast.arg(arg=name.id)],
                                             defaults=[],
                                             kwonlyargs=[],
                                             kw_defaults=[]),
                          body=body)
    else:
        invalid_python_version()
Exemplo n.º 31
0
    def visit_FunctionDef(self, node):
        """Rewrite a function definition for staging.

        Steps, in order:

        * link ``node`` to the enclosing definition via ``node.parent`` and
          make it the current ``self.fundef`` while its children are visited;
        * record the function name in ``self.recs`` when it carries a
          ``rep_fun`` decorator;
        * give every parameter whose ``locals`` count exceeds one a fresh
          name, and emit ``_assign(original, fresh)`` at the top of the body;
        * pre-initialise every local whose count exceeds one with ``_var()``;
        * wrap the body in ``try/except NonLocalReturnValue as r`` that
          returns ``r.value``.

        NOTE(review): the rebuilt signature keeps only the positional
        parameters (no defaults, ``*args``, keyword-only or ``**kwargs``),
        and only plain-name decorators other than ``lms``/``rep_fun`` are
        retained - confirm this is intended.

        :param node: the ``ast.FunctionDef`` being visited; must carry a
            ``locals`` mapping of name to count.
        :returns: the replacement ``ast.FunctionDef``.
        """
        node.parent = self.fundef
        self.fundef = node

        # Compare decorator names by value; ``is`` on str literals relies on
        # interning and is not guaranteed to match equal strings.
        if any(
                isinstance(deco, ast.Name) and deco.id == 'rep_fun'
                for deco in node.decorator_list):
            self.recs.append(node.name)

        self.generic_visit(node)

        r_args = {}

        for arg in node.args.args:
            arg_name = arg.arg
            try:
                if self.fundef.locals[arg_name] > 1:
                    r_args[arg_name] = self.freshName('x')
            except KeyError:
                # Parameter has no entry in ``locals``: nothing to rename.
                pass

        # generate code to pre-initialize staged vars
        # we stage all vars that are written to more than once
        inits = [
            ast.Assign(targets=[ast.Name(id=var_name, ctx=ast.Store())],
                       value=ast.Call(
                           func=ast.Name(id='_var', ctx=ast.Load()),
                           args=[],
                           keywords=[]))
            for var_name in node.locals if node.locals[var_name] > 1
        ]

        # _assign(<original parameter>, <fresh parameter>) for each rename
        a_nodes = [
            ast.Expr(
                ast.Call(
                    func=ast.Name(id='_assign', ctx=ast.Load()),
                    args=[ast.Name(id=orig_name, ctx=ast.Load()),
                          ast.Name(id=r_args[orig_name], ctx=ast.Load())],
                    keywords=[]))
            for orig_name in r_args
        ]

        new_params = [
            ast.arg(arg=r_args[param.arg], annotation=None)
            if param.arg in r_args else param
            for param in node.args.args
        ]

        kept_decorators = [
            deco for deco in node.decorator_list
            if isinstance(deco, ast.Name)
            and deco.id != 'lms' and deco.id != 'rep_fun'
        ]

        new_node = ast.FunctionDef(
            name=node.name,
            args=ast.arguments(args=new_params,
                               vararg=None,
                               kwonlyargs=[],
                               kwarg=None,
                               defaults=[],
                               kw_defaults=[]),  # node.args,
            body=[
                ast.Try(body=inits + a_nodes + node.body,
                        handlers=[
                            ast.ExceptHandler(
                                type=ast.Name(id='NonLocalReturnValue',
                                              ctx=ast.Load()),
                                name='r',
                                body=[
                                    ast.Return(value=ast.Attribute(
                                        value=ast.Name(id='r', ctx=ast.Load()),
                                        attr='value',
                                        ctx=ast.Load()))
                                ])
                        ],
                        orelse=[],
                        finalbody=[])
            ],
            decorator_list=kept_decorators)
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        # Leaving this definition: restore the enclosing one.
        self.fundef = node.parent
        return new_node
Exemplo n.º 32
0
    def make_continuation(owner, callcc, contbody):
        """Build the continuation function for one ``call_cc[...]`` use site.

        Parameters:
            owner: the enclosing function definition node, or a falsy value
                when the ``call_cc[]`` appears at the top level. Its type
                (regular/async) and ``name`` are reused for the continuation.
            callcc: the ``call_cc[...]`` node; decomposed by
                ``analyze_callcc`` and also used as the source-location
                reference for the generated nodes.
            contbody: list of statements that becomes the continuation body.

        Returns:
            A list of statements: the continuation's function definition,
            followed by the statement that invokes the captured call (a
            ``Return`` of a ``jump`` inside a def, a plain ``Expr`` at the
            top level; wrapped in an ``If`` when a condition is present).
        """
        targets, starget, condition, thecall, altcall = analyze_callcc(callcc)

        # no-args special case: allow but ignore one arg so there won't be arity errors
        # from a "return None"-generated None being passed into the cc
        # (in Python, a function always has a return value, though it may be None)
        if not targets and not starget:
            targets = ["_ignored_arg"]
            posargdefaults = [q[None]]
        else:
            posargdefaults = []

        # Name the continuation: f_cont, f_cont1, f_cont2, ...
        # if multiple call_cc[]s in the same function body.
        if owner:
            # TODO: robustness: use regexes, strip suf and any numbers at the end, until no match.
            # return prefix of s before the first occurrence of suf.
            def strip_suffix(s, suf):
                n = s.find(suf)
                if n == -1:
                    return s
                return s[:n]

            basename = "{}_cont".format(strip_suffix(owner.name, "_cont"))
        else:
            basename = "cont"
        contname = gen_sym(basename)

        # Set our captured continuation as the cc of f and g in
        #   call_cc[f(...)]
        #   call_cc[f(...) if p else g(...)]
        def prepare_call(tree):
            if tree:
                # Prepend so the generated cc= comes before any user keywords.
                tree.keywords = [keyword(arg="cc", value=q[name[contname]])
                                 ] + tree.keywords
            else:  # no call means proceed to cont directly, with args set to None
                tree = q[name[contname](*([None] * u[len(targets)]),
                                        cc=name["cc"])]
            return tree

        thecall = prepare_call(thecall)
        if condition:
            altcall = prepare_call(altcall)

        # Create the continuation function, set contbody as its body.
        #
        # Any return statements in the body have already been transformed,
        # because they appear literally in the code at the use site,
        # and our main processing logic runs the return statement transformer
        # before transforming call_cc[].
        FDef = type(
            owner
        ) if owner else FunctionDef  # use same type (regular/async) as parent function
        locref = callcc  # bad but no better source location reference node available
        non = q[None]
        non = copy_location(non, locref)
        # Default for the ``_pcc`` keyword-only parameter: the current ``cc``
        # unless it is ``identity``, in which case ``None``.
        maybe_capture = IfExp(test=hq[name["cc"] is not identity],
                              body=q[name["cc"]],
                              orelse=non,
                              lineno=locref.lineno,
                              col_offset=locref.col_offset)
        funcdef = FDef(
            name=contname,
            args=arguments(args=[arg(arg=x) for x in targets],
                           kwonlyargs=[arg(arg="cc"),
                                       arg(arg="_pcc")],
                           vararg=(arg(arg=starget) if starget else None),
                           kwarg=None,
                           defaults=posargdefaults,
                           kw_defaults=[hq[identity], maybe_capture]),
            body=contbody,
            decorator_list=[],  # patched later by transform_def
            returns=None,  # return annotation not used here
            lineno=locref.lineno,
            col_offset=locref.col_offset)

        # in the output stmts, define the continuation function...
        newstmts = [funcdef]
        if owner:  # ...and tail-call it (if currently inside a def)

            # Turn ``f(args...)`` into ``jump(f, args...)`` in place.
            def jumpify(tree):
                tree.args = [tree.func] + tree.args
                tree.func = hq[jump]

            jumpify(thecall)
            if condition:
                jumpify(altcall)
                newstmts.append(
                    If(test=condition,
                       body=[Return(value=q[ast_literal[thecall]])],
                       orelse=[Return(value=q[ast_literal[altcall]])]))
            else:
                newstmts.append(Return(value=q[ast_literal[thecall]]))
        else:  # ...and call it normally (if at the top level)
            if condition:
                newstmts.append(
                    If(test=condition,
                       body=[Expr(value=q[ast_literal[thecall]])],
                       orelse=[Expr(value=q[ast_literal[altcall]])]))
            else:
                newstmts.append(Expr(value=q[ast_literal[thecall]]))
        return newstmts
Exemplo n.º 33
0
    def __build__init(self):
        """Build the AST of ``def __init__(self, *args, **kwargs)``.

        The generated body is the single statement
        ``super().__init__(*args, **kwargs)``.  The shape of that call node
        depends on the interpreter: Python 3.4 uses the ``starargs`` /
        ``kwargs`` fields of ``ast.Call``, Python 3.5-3.7 use ``ast.Starred``
        and a ``keyword`` with ``arg=None``.  Any other version prints the
        version info and raises ``RuntimeError``.

        :returns: the ``ast.FunctionDef`` node for ``__init__``.
        """
        py_version = sys.version_info[:2]
        if py_version not in ((3, 4), (3, 5), (3, 6), (3, 7)):
            print("Version:", sys.version_info)
            raise RuntimeError(
                "This script only functions on python 3.4, 3.5, 3.6, or 3.7. Active python version {}.{}"
                .format(*sys.version_info))

        # super().__init__
        init_attr = ast.Attribute(
            value=ast.Call(func=ast.Name(id='super', ctx=ast.Load()),
                           args=[],
                           keywords=[]),
            attr='__init__',
            ctx=ast.Load())

        if py_version == (3, 4):
            super_func = ast.Call(
                func=init_attr,
                args=[],
                keywords=[],
                starargs=ast.Name(id='args', ctx=ast.Load()),
                kwargs=ast.Name(id='kwargs', ctx=ast.Load()),
            )
        else:
            star_args = ast.Starred(value=ast.Name(id='args', ctx=ast.Load()),
                                    ctx=ast.Load())
            double_star_kwargs = ast.keyword(
                arg=None,
                value=ast.Name(id='kwargs', ctx=ast.Load()),
                ctx=ast.Load())
            super_func = ast.Call(
                func=init_attr,
                args=[star_args],
                keywords=[double_star_kwargs],
            )

        # The statement's line is requested before the function's line,
        # keeping the order of the ``__get_line`` calls unchanged.
        super_init = ast.Expr(
            value=super_func,
            lineno=self.__get_line(),
            col_offset=0,
        )

        sig = ast.arguments(args=[ast.arg('self', None)],
                            vararg=ast.arg(arg='args', annotation=None),
                            kwarg=ast.arg(arg='kwargs', annotation=None),
                            varargannotation=None,
                            kwonlyargs=[],
                            kwargannotation=None,
                            defaults=[],
                            kw_defaults=[])

        return ast.FunctionDef(
            name="__init__",
            args=sig,
            body=[super_init],
            decorator_list=[],
            lineno=self.__get_line(),
            col_offset=0,
        )
Exemplo n.º 34
0
def compile_func(gen: 'Generator',
                 func: Callable,
                 strategy: Strategy,
                 with_hooks: bool = False) -> Callable:
    """
    The compilation basically assigns functionality to each of the operator calls as
    governed by the semantics (strategy). Memoization is done with the keys as the `func`,
    the class of the `strategy` and the `with_hooks` argument.

    The source of `func` is re-parsed, every call to a known operator, known method,
    `Generator` or `CallGenerator` is rewritten in the AST, four keyword-only parameters
    (execution environment, strategy, model, hook; all defaulting to ``None``) are added
    to the signature, and the result is compiled and executed in a namespace built from
    the closure variables of `func`.

    Args:
        gen (Generator): The generator object containing the function to compile
        func (Callable): The function to compile
        strategy (Strategy): The strategy governing the behavior of the operators.
            A `PartialReplayStrategy` is replaced by its `backup_strategy`.
        with_hooks (bool): Whether support for hooks is required

    Returns:
        The compiled function

    """

    if isinstance(strategy, PartialReplayStrategy):
        strategy = strategy.backup_strategy

    if with_hooks:
        cache = CompilationCache.WITH_HOOKS[strategy.__class__]
    else:
        cache = CompilationCache.WITHOUT_HOOKS[strategy.__class__]

    if func in cache:
        return cache[func]

    #  NOTE(review): placeholder entry, presumably marking ``func`` as in progress.
    #  It stays ``None`` if anything below raises - confirm that is acceptable.
    cache[func] = None

    source_code, start_lineno = inspect.getsourcelines(func)
    source_code = ''.join(source_code)
    f_ast = astutils.parse(textwrap.dedent(source_code))

    # This matches up line numbers with original file and is thus super useful for debugging
    ast.increment_lineno(f_ast, start_lineno - 1)

    #  Remove the ``@generator`` decorator to avoid recursive compilation
    f_ast.decorator_list = [
        d for d in f_ast.decorator_list
        if (not isinstance(d, ast.Name) or d.id != 'generator') and (
            not isinstance(d, ast.Attribute) or d.attr != 'generator') and (
                not (isinstance(d, ast.Call) and isinstance(d.func, ast.Name))
                or d.func.id != 'generator')
    ]

    #  Get all the external dependencies of this function.
    #  We rely on a modified closure function adopted from the ``inspect`` library.
    closure_vars = getclosurevars_recursive(func, f_ast)
    g = {**closure_vars.nonlocals.copy(), **closure_vars.globals.copy()}
    known_ops: Set[str] = strategy.get_known_ops()
    known_methods: Set[str] = strategy.get_known_methods()
    op_info_constructor = OpInfoConstructor()
    delayed_compilations: List[Tuple[Generator, str]] = []

    ops = {}
    handlers = {}
    op_infos = {}
    op_idx: int = 0
    composition_cnt: int = 0
    for n in astutils.preorder_traversal(f_ast):
        if isinstance(n, ast.Call) and isinstance(
                n.func, ast.Name) and n.func.id in known_ops:
            #  Rename the function call, and assign a new function to be called during execution.
            #  This new function is determined by the semantics (strategy) being used for compilation.
            #  Also determine if there any eligible hooks for this operator call.
            op_idx += 1
            handler_idx = len(handlers)
            op_info: OpInfo = op_info_constructor.get(n, gen.name, gen.group)

            n.keywords.append(
                ast.keyword(arg='model',
                            value=ast.Name(_GEN_MODEL_VAR, ast.Load())))

            n.keywords.append(
                ast.keyword(arg='op_info',
                            value=ast.Name(f"_op_info_{op_idx}", ast.Load())))
            op_infos[f"_op_info_{op_idx}"] = op_info

            n.keywords.append(
                ast.keyword(arg='handler',
                            value=ast.Name(f"_handler_{handler_idx}",
                                           ast.Load())))
            handler = strategy.get_op_handler(op_info)
            handlers[f"_handler_{handler_idx}"] = handler

            if not with_hooks:
                n.func = astutils.parse(
                    f"{_GEN_STRATEGY_VAR}.generic_op").value
            else:
                n.keywords.append(
                    ast.keyword(arg=_GEN_HOOK_VAR,
                                value=ast.Name(_GEN_HOOK_VAR, ctx=ast.Load())))
                n.keywords.append(
                    ast.keyword(arg=_GEN_STRATEGY_VAR,
                                value=ast.Name(_GEN_STRATEGY_VAR,
                                               ctx=ast.Load())))

                n.func.id = _GEN_HOOK_WRAPPER
                ops[_GEN_HOOK_WRAPPER] = hook_wrapper

            if returns_lambda(handler):
                #  Turn ``op(a, b, **kw)`` into ``op(a, b, **kw)(a)``: the inner call
                #  receives everything, the outer call only the first argument.
                n.func = ast.Call(func=n.func,
                                  args=n.args[:],
                                  keywords=n.keywords[:])
                n.keywords = []
                n.args = [n.args[0]]

            ast.fix_missing_locations(n)

        elif isinstance(n, ast.Call) and isinstance(
                n.func, ast.Name) and n.func.id in known_methods:
            #  Similar in spirit to the known_ops case, just much less fancy stuff to do.
            #  Only need to get the right handler which we will achieve by simply making this
            #  a method call instead of a regular call.
            n.func = ast.Attribute(value=ast.Name(_GEN_STRATEGY_VAR,
                                                  ctx=ast.Load()),
                                   attr=n.func.id,
                                   ctx=ast.Load())
            ast.fix_missing_locations(n)

        elif isinstance(n, ast.Call):
            #  Try to check if it is a call to a Generator
            #  TODO : Can we be more sophisticated in our static analysis here
            try:
                function = eval(astunparse.unparse(n.func), g)
            except Exception:
                #  Best-effort static resolution: any ordinary failure means "not a
                #  generator call". KeyboardInterrupt/SystemExit are no longer swallowed.
                continue

            if isinstance(function, Generator):
                call_id = f"{_GEN_COMPOSITION_ID}_{composition_cnt}"
                composition_cnt += 1
                #  NOTE(review): only renames the call when ``n.func`` is an ``ast.Name``;
                #  for an attribute access this merely sets an unused attribute - confirm.
                n.func.id = call_id
                n.keywords.append(
                    ast.keyword(arg=_GEN_EXEC_ENV_VAR,
                                value=ast.Name(_GEN_EXEC_ENV_VAR, ast.Load())))
                n.keywords.append(
                    ast.keyword(arg=_GEN_STRATEGY_VAR,
                                value=ast.Name(_GEN_STRATEGY_VAR, ast.Load())))
                n.keywords.append(
                    ast.keyword(arg=_GEN_MODEL_VAR,
                                value=ast.Name(_GEN_MODEL_VAR, ast.Load())))
                n.keywords.append(
                    ast.keyword(arg=_GEN_HOOK_VAR,
                                value=ast.Name(_GEN_HOOK_VAR, ast.Load())))
                ast.fix_missing_locations(n)

                #  We delay compilation to handle mutually recursive generators
                delayed_compilations.append((function, call_id))

            elif function is CallGenerator:
                #  Unwrap ``CallGenerator(f(...))`` into ``f(..., <exec env>=<exec env>)``
                wrapped_func = n.args[0]
                n.func = wrapped_func.func
                n.args = wrapped_func.args[:]
                n.keywords = wrapped_func.keywords[:]
                n.keywords.append(
                    ast.keyword(arg=_GEN_EXEC_ENV_VAR,
                                value=ast.Name(_GEN_EXEC_ENV_VAR, ast.Load())))
                ast.fix_missing_locations(n)

    #  ``ast.Constant`` replaces the deprecated ``ast.NameConstant`` for the ``None`` defaults.

    #  Add the execution environment argument to the function
    f_ast.args.kwonlyargs.append(
        ast.arg(arg=_GEN_EXEC_ENV_VAR, annotation=None))
    f_ast.args.kw_defaults.append(ast.Constant(value=None))

    #  Add the strategy argument to the function
    f_ast.args.kwonlyargs.append(
        ast.arg(arg=_GEN_STRATEGY_VAR, annotation=None))
    f_ast.args.kw_defaults.append(ast.Constant(value=None))

    #  Add the model argument to the function
    f_ast.args.kwonlyargs.append(ast.arg(arg=_GEN_MODEL_VAR, annotation=None))
    f_ast.args.kw_defaults.append(ast.Constant(value=None))

    #  Add the hook argument to the function
    f_ast.args.kwonlyargs.append(ast.arg(arg=_GEN_HOOK_VAR, annotation=None))
    f_ast.args.kw_defaults.append(ast.Constant(value=None))
    ast.fix_missing_locations(f_ast)

    #  New name so it doesn't clash with original
    func_name = f"{_GEN_COMPILED_TARGET_ID}_{len(cache)}"

    g.update(ops)
    g.update(handlers)
    g.update(op_infos)

    module = ast.Module()
    module.body = [f_ast]

    #  Passing ``g`` to exec allows us to execute all the new functions
    #  we assigned to every operator call in the previous AST walk
    filename = inspect.getabsfile(func)
    exec(compile(module, filename=filename, mode="exec"), g)
    result = g[func.__name__]
    g["__name__"] = filename

    if inspect.ismethod(func):
        result = result.__get__(func.__self__, func.__self__.__class__)

    #  Restore the correct namespace so that tracebacks contain actual function names
    g[gen.name] = gen
    g[func_name] = result

    cache[func] = result

    #  Handle the delayed compilations now that we have populated the cache
    for sub_gen, call_id in delayed_compilations:
        compiled_func = compile_func(sub_gen, sub_gen.func, strategy,
                                     with_hooks)
        if sub_gen.caching and isinstance(strategy, DfsStrategy):
            #  Add instructions for using cached result if any
            g[call_id] = cache_wrapper(compiled_func)

        else:
            g[call_id] = compiled_func

    return result
Exemplo n.º 35
0
def res_python_setup(res: Property) -> Tuple[Callable[[EntityFixup], object], str]:
    """Compile a property block describing an expression into a function.

    Children of ``res`` are interpreted as follows:

    * ``$name`` - declares a variable; its value names a converter that must
      be a key of ``FUNC_GLOBALS`` (compared case-insensitively).
    * ``op`` - the expression source (``$`` characters are stripped).
    * ``resultvar`` - the destination name, returned unchanged.

    The generated function takes one argument ``_fixup``, assigns
    ``var = converter(_fixup['var'])`` for every declared variable in
    declaration order, and returns the value of the expression.

    :param res: the property block to read.
    :returns: ``(compiled_function, result_var)``.
    :raises Exception: for an unknown converter or key, a missing ``op`` or
        ``resultvar``, or a variable name starting with ``_``.
    """
    variables = {}
    variable_order = []
    code = None
    result_var = None
    for child in res:
        if child.name.startswith('$'):
            if child.value.casefold() not in FUNC_GLOBALS:
                raise Exception('Invalid variable type! ({})'.format(child.value))
            variables[child.name[1:]] = child.value.casefold()
            variable_order.append(child.name[1:])
        elif child.name == 'op':
            code = child.value
        elif child.name == 'resultvar':
            result_var = child.value
        else:
            raise Exception('Invalid key "{}"'.format(child.real_name))
    if not code:
        raise Exception('No operation specified!')
    if not result_var:
        raise Exception('No destination specified!')

    for name in variables:
        if name.startswith('_'):
            raise Exception('"{}" is not permitted as a variable name!'.format(name))

    # Allow $ in the variable names..
    code = code.replace('$', '')

    # Now process the code to convert it into a function taking variables
    # and returning them.
    # We also need to whitelist operations for security.

    expression = ast.parse(
        code,
        '<bee2_op>',
        mode='eval',
    ).body

    # The checker is the security gate: it runs before anything is compiled.
    Checker(variable_order).visit(expression)

    # For each variable, do
    # var = func(_fixup['var'])
    # ast.Constant replaces the deprecated ast.Str; the legacy
    # starargs/kwargs fields of ast.Call are no longer passed.
    statements: List[ast.AST] = [
        ast.Assign(
            targets=[ast.Name(id=var_name, ctx=ast.Store())],
            value=ast.Call(
                func=ast.Name(id=variables[var_name], ctx=ast.Load()),
                args=[
                    ast.Subscript(
                        value=ast.Name(id='_fixup', ctx=ast.Load()),
                        slice=ast.Index(value=ast.Constant(value=var_name)),
                        ctx=ast.Load(),
                    ),
                ],
                keywords=[],
            )
        )
        for var_name in variable_order
    ]
    # The last statement returns the target expression.
    statements.append(ast.Return(expression, lineno=len(variable_order)+1, col_offset=0))

    args = ast.arguments(
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    # Py 3.8+, make it pos-only.
    if 'posonlyargs' in args._fields:
        args.posonlyargs = [ast.arg('_fixup', None)]
        args.args = []
    else:  # Just make it a regular arg.
        args.args = [ast.arg('_fixup', None)]

    # ast.Module carries no position attributes, so none are passed.
    func = ast.Module(
        body=[
            ast.FunctionDef(
                name='_bee2_generated_func',
                args=args,
                body=statements,
                decorator_list=[],
            ),
        ],
    )
    # Python 3.8 also
    if 'type_ignores' in func._fields:
        func.type_ignores = []

    # Fill in lineno and col_offset
    ast.fix_missing_locations(func)

    ns: Dict[str, Any] = {}
    # The code object is a module ('exec' mode), so exec() is the matching call.
    exec(compile(func, '<bee2_op>', mode='exec'), FUNC_GLOBALS, ns)
    compiled_func = ns['_bee2_generated_func']
    compiled_func.__name__ = '<bee2_func>'
    return compiled_func, result_var
Exemplo n.º 36
0
 def visitArgs(self, ctx: vjjParser.ArgsContext):
     """Return one ``ast.arg`` per ID token of ``ctx``, in token order."""
     params = []
     for token in ctx.ID():
         params.append(ast.arg(arg=str(token)))
     return params
Exemplo n.º 37
0
def parseTranslatedSource(source, lineMap, filename):
	"""Parse translated source text into a syntax tree.

	A ``SyntaxError`` is converted with ``PythonParseError.fromSyntaxError``
	using ``lineMap``.  The original error is chained as the cause only when
	``showInternalBacktrace`` is set; otherwise the cause is suppressed.
	"""
	try:
		return parse(source, filename=filename)
	except SyntaxError as e:
		if showInternalBacktrace:
			cause = e
		else:
			cause = None
		raise PythonParseError.fromSyntaxError(e, lineMap) from cause

### TRANSLATION PHASE FOUR: modifying the parse tree

# Pre-built argument lists: one with no parameters at all, one taking
# only ``self``.
noArgs = ast.arguments(
	args=[],
	defaults=[],
	vararg=None,
	kwarg=None,
	kwonlyargs=[],
	kw_defaults=[],
)
selfArg = ast.arguments(
	args=[ast.arg(arg='self', annotation=None)],
	defaults=[],
	vararg=None,
	kwarg=None,
	kwonlyargs=[],
	kw_defaults=[],
)

# Python 3.8 introduced the positional-only parameter list; each node
# gets its own empty list.
# TODO cleaner way to handle this?
if sys.version_info >= (3, 8):
	noArgs.posonlyargs = []
	selfArg.posonlyargs = []

class AttributeFinder(NodeVisitor):
	"""Visitor that collects the references to attributes with a given name."""
	@staticmethod
	def find(target, node):
		"""Visit ``node`` with a finder for ``target`` and return its ``attributes``."""
		finder = AttributeFinder(target)
		finder.visit(node)
		return finder.attributes
Exemplo n.º 38
0
 def visitarg(self, n, *args):
     """Rebuild an ``ast.arg``, dispatching on its annotation when present.

     A falsy annotation is replaced by ``None``; otherwise the annotation
     is passed to ``self.dispatch`` together with the extra ``args``.
     """
     if n.annotation:
         annotation = self.dispatch(n.annotation, *args)
     else:
         annotation = None
     return ast.arg(arg=n.arg, annotation=annotation)
Exemplo n.º 39
0
def generate_repr_method(params, cls_name, docstring_format):
    """
    Build a `__repr__` method that prints every param via `str.format`

    The returned function's body is a docstring followed by
    `return "<cls_name>(<k>={<k>!r}, ...)".format(<k>=self.<k>, ...)`.

    :param params: an `OrderedDict` of form
        OrderedDict[str, {'typ': str, 'doc': Optional[str], 'default': Any}]
    :type params: ```OrderedDict```

    :param cls_name: Name of class
    :type cls_name: ```str```

    :param docstring_format: Format of docstring
    :type docstring_format: ```Literal['rest', 'numpydoc', 'google']```

    :returns: `__repr__` method
    :rtype: ```FunctionDef```
    """
    keys = tuple(params.keys())

    signature = arguments(
        posonlyargs=[],
        arg=None,
        args=[
            arg(arg="self",
                annotation=None,
                expr=None,
                identifier_arg=None,
                **maybe_type_comment)
        ],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
        vararg=None,
        kwarg=None,
    )

    # Only "rest" selects the reST docstring; every other format falls back
    # to the google one.
    if docstring_format == "rest":
        docstring = docstring_repr_str
    else:
        docstring = docstring_repr_google_str
    docstring_stmt = Expr(set_value(docstring))

    # e.g. "Cls(a={a!r}, b={b!r})"
    template = "{cls_name}({format_args})".format(
        cls_name=cls_name,
        format_args=", ".join("{0}={{{0}!r}}".format(key) for key in keys),
    )
    # One `key=self.key` keyword per param for the `.format(...)` call
    format_keywords = [
        keyword(
            arg=key,
            value=Attribute(Name("self", Load()), key, Load()),
            identifier=None,
        )
        for key in keys
    ]
    return_stmt = Return(
        value=Call(
            func=Attribute(set_value(template), "format", Load()),
            args=[],
            keywords=format_keywords,
            expr=None,
            expr_func=None,
        ),
        expr=None,
    )

    return FunctionDef(
        name="__repr__",
        args=signature,
        body=[docstring_stmt, return_stmt],
        decorator_list=[],
        arguments_args=None,
        identifier_name=None,
        stmt=None,
        lineno=None,
        returns=None,
        **maybe_type_comment)
Exemplo n.º 40
0
async def meval(code, globs, **kwargs):
    """Evaluate ``code`` as the body of an async function and collect results.

    The source is wrapped into ``async def tmp(*, <kwargs>, <globs arg>)``.
    Without a top-level ``return``, the value of the last expression
    statement and of every non-call expression statement is collected;
    with one, the returned value is the only result (a bare ``return``
    yields no result).  Awaitable results are awaited, ``None`` entries
    other than the last are dropped.

    WARNING: this executes arbitrary code; never pass untrusted input.

    :param code: Python source text to run.
    :param globs: mapping that must contain ``__name__`` and ``__package__``;
        both are copied into the function's globals before the code runs.
        The mapping itself is not modified.
    :param kwargs: values exposed to the code as keyword-only parameters.
    :returns: ``None`` when there are no results, the single result when
        there is exactly one, otherwise the list of results.
    """
    # Note to self: please don't set globals here as they will be lost.
    # Don't clutter locals
    locs = {}
    # Restore globals later
    globs = globs.copy()
    # This code saves __name__ and __package into a kwarg passed to the func.
    # It is set before the users code runs to make sure relative imports work
    global_args = "_globs"
    while global_args in globs or global_args in kwargs:
        # Make sure there's no name collision, just keep prepending _s
        global_args = "_" + global_args
    kwargs[global_args] = {
        glob: globs[glob] for glob in ("__name__", "__package__")
    }

    # ``mode`` must be passed by keyword: the second positional parameter of
    # ast.parse is the filename.
    root = ast.parse(code, mode="exec")
    code = root.body

    # Pick a name for the result list that clashes with nothing.  Collect
    # every identifier first so that each candidate is checked against all
    # of them (a single early-exit scan can miss a clash after renaming).
    used_names = {
        node.id for node in ast.walk(root) if isinstance(node, ast.Name)
    }
    ret_name = "_ret"
    while ret_name in globs or ret_name in used_names:
        ret_name = "_" + ret_name

    if not code:
        return None

    if not any(isinstance(node, ast.Return) for node in code):
        last = len(code) - 1
        for i, stmt in enumerate(code):
            if isinstance(stmt, ast.Expr) and (
                    i == last or not isinstance(stmt.value, ast.Call)):
                # <ret_name>.append(<expression>), located where the
                # replaced statement was.
                code[i] = ast.copy_location(
                    ast.Expr(
                        ast.Call(func=ast.Attribute(value=ast.Name(
                            id=ret_name, ctx=ast.Load()),
                                                    attr="append",
                                                    ctx=ast.Load()),
                                 args=[stmt.value],
                                 keywords=[])), stmt)
    else:
        for node in code:
            if isinstance(node, ast.Return):
                # A bare ``return`` has no value node; wrapping ``None``
                # into the list would produce an uncompilable tree.
                elts = [] if node.value is None else [node.value]
                node.value = ast.List(elts=elts, ctx=ast.Load())

    code.append(
        ast.copy_location(
            ast.Return(value=ast.Name(id=ret_name, ctx=ast.Load())), code[-1]))

    # globals().update(**<global_args>)
    glob_copy = ast.Expr(
        ast.Call(
            func=ast.Attribute(value=ast.Call(func=ast.Name(id="globals",
                                                            ctx=ast.Load()),
                                              args=[],
                                              keywords=[]),
                               attr="update",
                               ctx=ast.Load()),
            args=[],
            keywords=[
                ast.keyword(arg=None,
                            value=ast.Name(id=global_args, ctx=ast.Load()))
            ]))
    ast.fix_missing_locations(glob_copy)
    code.insert(0, glob_copy)
    # <ret_name> = []
    ret_decl = ast.Assign(targets=[ast.Name(id=ret_name, ctx=ast.Store())],
                          value=ast.List(elts=[], ctx=ast.Load()))
    ast.fix_missing_locations(ret_decl)
    code.insert(1, ret_decl)

    # Every kwarg becomes a required keyword-only parameter.
    kwonly = [
        ast.fix_missing_locations(ast.arg(name, None)) for name in kwargs
    ]
    args = ast.arguments(args=[],
                         vararg=None,
                         kwonlyargs=kwonly,
                         kwarg=None,
                         defaults=[],
                         kw_defaults=[None] * len(kwonly))
    args.posonlyargs = []
    fun = ast.AsyncFunctionDef(name="tmp",
                               args=args,
                               body=code,
                               decorator_list=[],
                               returns=None)
    ast.fix_missing_locations(fun)
    # Parsing an empty string yields a correctly initialised Module node
    # for the running interpreter.
    mod = ast.parse("")
    mod.body = [fun]
    comp = compile(mod, "<string>", "exec")

    exec(comp, {}, locs)

    r = await locs["tmp"](**kwargs)
    for i, item in enumerate(r):
        if hasattr(item, "__await__"):
            r[i] = await item  # workaround for 3.5
    # Drop None results, but always keep the last entry.
    i = 0
    while i < len(r) - 1:
        if r[i] is None:
            del r[i]
        else:
            i += 1
    if len(r) == 1:
        [r] = r
    elif not r:
        r = None
    return r
Exemplo n.º 41
0
def _translate_all_expression_to_a_module(
        generator_exp: ast.GeneratorExp, generated_function_name: str,
        name_to_value: Mapping[str, Any]) -> ast.Module:
    """
    Build the module AST that traces an ``all`` quantifier over a generator expression.

    The module holds a single function (``async def`` if any comprehension is
    asynchronous) which replays the comprehensions as nested loops. For every
    element it stores the element expression in a uniquely named variable and,
    as soon as that value is falsy, returns ``(value, ((name, value), ...))``
    listing the names stored by the comprehension targets.

    :param generator_exp: generator expression to be translated
    :param generated_function_name: name given to the generated tracing function
    :param name_to_value:
        mapping of variable names to their resolved values; the sorted names
        become the parameters of the generated function
    :return: module containing only the generated function
    """
    assert generated_function_name not in name_to_value
    assert not hasattr(builtins, generated_function_name)

    # Names bound by the targets of the comprehensions
    stored_names = _collect_stored_names(
        generator.target for generator in generator_exp.generators)

    assert generated_function_name not in stored_names

    result_id = 'icontract_tracing_all_result_{}'.format(uuid.uuid4().hex)

    def load_result():
        """Create a fresh node reading the traced result variable."""
        return ast.Name(id=result_id, ctx=ast.Load())

    # The statements are assembled from the inner-most block outwards.
    result_assignment = ast.Assign(
        targets=[ast.Name(id=result_id, ctx=ast.Store())],
        value=generator_exp.elt)

    name_value_pairs = [
        ast.Tuple(elts=[
            ast.Constant(value=stored_name, kind=None),
            ast.Name(id=stored_name, ctx=ast.Load())
        ],
                  ctx=ast.Load()) for stored_name in stored_names
    ]

    exceptional_return = ast.Return(
        ast.Tuple(elts=[
            load_result(),
            ast.Tuple(elts=name_value_pairs, ctx=ast.Load())
        ],
                  ctx=ast.Load()))

    # The happy return is not expected to run; it is kept as a safety net in
    # case a later refactoring forgets to check for that edge case.
    happy_return = ast.Return(
        ast.Tuple(elts=[load_result(),
                        ast.Constant(value=None, kind=None)],
                  ctx=ast.Load()))

    critical_if = ast.If(test=load_result(),
                         body=[ast.Pass()],
                         orelse=[exceptional_return])

    # Inner block which becomes the body of the next outer block
    block = None  # type: Optional[List[ast.stmt]]
    for comprehension in reversed(generator_exp.generators):
        if block is None:
            # Only the inner-most comprehension starts from scratch.
            block = [result_assignment, critical_if]

        for condition in reversed(comprehension.ifs):
            block = [ast.If(test=condition, body=block, orelse=[])]

        loop_cls = ast.AsyncFor if comprehension.is_async else ast.For
        block = [
            loop_cls(target=comprehension.target,
                     iter=comprehension.iter,
                     body=block,
                     orelse=[])
        ]

    assert block is not None

    block.append(happy_return)

    # Everything is in place to generate the function.
    is_async = any(comprehension.is_async
                   for comprehension in generator_exp.generators)

    args = [
        ast.arg(arg=name, annotation=None)
        for name in sorted(name_to_value.keys())
    ]

    if sys.version_info < (3, 5):
        raise NotImplementedError(
            "Python versions below 3.5 not supported, got: {}".format(
                sys.version_info))

    arguments_fields = dict(args=args,
                            kwonlyargs=[],
                            kw_defaults=[],
                            defaults=[],
                            vararg=None,
                            kwarg=None)
    module_fields = {}
    if sys.version_info >= (3, 8):
        # These fields are only passed on Python 3.8 and later.
        arguments_fields['posonlyargs'] = []
        module_fields['type_ignores'] = []

    func_def_cls = ast.AsyncFunctionDef if is_async else ast.FunctionDef
    func_def_node = func_def_cls(name=generated_function_name,
                                 args=ast.arguments(**arguments_fields),
                                 decorator_list=[],
                                 body=block)

    module_node = ast.Module(body=[func_def_node], **module_fields)

    ast.fix_missing_locations(module_node)

    return module_node
Exemplo n.º 42
0
 def make_arg(name):
     """Return a function-parameter node called ``name`` for the running Python.

     On Python 3 this is an ``ast.arg`` without annotation; otherwise an
     ``ast.Name`` in ``Param`` context placed at line 1, column 0.
     """
     if sys.version_info < (3, 0):
         return ast.Name(id=name, ctx=ast.Param(), lineno=1, col_offset=0)
     return ast.arg(arg=name, annotation=None)
Exemplo n.º 43
0
def make_arg(key, annotation=None):
    """Create an ``ast.arg`` named ``key``, positioned at line 0, column 0."""
    return ast.arg(key, annotation, lineno=0, col_offset=0)
Exemplo n.º 44
0
def test_it_processes_supported_nodes(node, expected):
    """Check that an argument annotated with ``node`` yields ``expected``."""
    annotated_arg = ast.arg(annotation=node)
    assert get_annotation_value(annotated_arg) == expected
Exemplo n.º 45
0
 def signature_spec_arg(self, node, var, write_comma, prefix):
     """Emit the argument stored in attribute ``var`` of ``node``, if it is set.

     When ``node`` also carries a ``<var>annotation`` attribute, the value is
     first wrapped into an ``ast.arg`` together with that annotation. The
     result is handed to ``self.signature_arg`` with ``None`` as the second
     argument.
     """
     spec = getattr(node, var)
     if not spec:
         return
     annotation_attr = var + 'annotation'
     if hasattr(node, annotation_attr):
         spec = ast.arg(spec, getattr(node, annotation_attr))
     self.signature_arg(spec, None, write_comma, prefix)
Exemplo n.º 46
0
        [ast.Constant(value=False), False],
        [ast.Constant(value=None), None],
        [ast.Name(id="identifier"), "identifier"],
    ],
)
def test_it_processes_supported_nodes(node, expected):
    """Check that an argument annotated with ``node`` yields ``expected``."""
    annotated_arg = ast.arg(annotation=node)
    assert get_annotation_value(annotated_arg) == expected


@pytest.mark.parametrize(
    "node",
    [
        ast.arg(annotation=mock.Mock()),
        ast.arg(annotation=""),
    ],
)
def test_it_logs_errors_for_unsupported_node_types(mocker, node):
    """Check that an unsupported annotation type is reported through ``logger.error``."""
    # Patch the module logger so the call can be observed.
    patched_logger = mocker.patch("sphinx_ast_autodoc.utils.logger")

    get_annotation_value(node)

    assert patched_logger.error.called is True


def test_it_logs_warnings_for_unsupported_subscript_nodes(mocker):
    """Tests that it logs a warning for unsupported ast.Subscript nodes"""
Exemplo n.º 47
0
    def __build_function(self, dom_name, full_name, func_params):
        """
        Build the ``ast.FunctionDef`` of a method that forwards to ``self.synchronous_command``.

        The generated function is named ``"{full_name}_{func_name}"``. It takes
        ``self`` plus one positional parameter for every non-optional entry of
        ``func_params["parameters"]``; when optional entries exist it also takes
        ``**kwargs`` and asserts that every passed keyword is one of the optional
        names. Its body ends by storing the result of
        ``self.synchronous_command("{dom_name}.{func_name}", ...)`` in
        ``subdom_funcs`` and returning it.

        :param dom_name: prefix of the command name sent to ``synchronous_command``;
            also handed to the description-string builder
        :param full_name: prefix of the generated function's name
        :param func_params: dict with a mandatory ``'name'`` key and an optional
            ``'parameters'`` list of dicts, which are read for ``'name'``,
            ``'optional'`` and ``'type'``
        :return: the assembled ``ast.FunctionDef``
        :raises RuntimeError: if the running interpreter is not Python 3.4 - 3.7
        """

        assert 'name' in func_params
        func_name = func_params['name']

        docstr = self.__build_desc_string(dom_name, func_name, func_params)

        # Signature parameters of the generated function; ``self`` always comes first.
        args = [ast.arg('self', None)]
        # Keyword arguments forwarded to ``self.synchronous_command``.
        message_params = []
        func_body = []

        if docstr:
            # Emitted as the first statement of the body, i.e. the generated docstring.
            func_body.append(ast.Expr(ast.Str("\n" + docstr + "\n\t\t")))

        for param in func_params.get("parameters", []):

            argname = param['name']

            param_optional = param.get("optional", False)

            # Required parameters become positional parameters of the generated
            # function and are forwarded by keyword under the same name.
            if param_optional is False:
                message_params.append(
                    ast.keyword(argname, ast.Name(id=argname, ctx=ast.Load())))
                args.append(ast.arg(argname, None))
                if self.do_debug_prints:
                    func_body.append(self.__build_debug_print(
                        argname, argname))

            # A check is added only for types listed in CHECKS; optional
            # parameters go through the conditional variant of the check builder.
            param_type = param.get("type", None)
            if param_type in CHECKS:
                if param_optional:
                    check = self.__build_conditional_arg_check(
                        argname, CHECKS[param_type])
                else:
                    check = self.__build_unconditional_arg_check(
                        argname, CHECKS[param_type])

                if check:
                    func_body.append(check)

        optional_params = [
            param.get("name") for param in func_params.get("parameters", [])
            if param.get("optional", False)
        ]
        func_kwargs = None
        if len(optional_params):

            # Generated code: ``expected = [<optional parameter names>]``
            # NOTE(review): ``ast.Str`` is given a ``ctx`` keyword here, which
            # looks unnecessary for a string literal node - confirm.
            value = ast.List(elts=[
                ast.Str(s=param, ctx=ast.Store()) for param in optional_params
            ],
                             ctx=ast.Load())
            create_list = ast.Assign(
                targets=[ast.Name(id='expected', ctx=ast.Store())],
                value=value)

            func_body.append(create_list)

            # Generated code: ``passed_keys = list(kwargs.keys())``
            passed_arg_list = ast.Assign(
                targets=[ast.Name(id='passed_keys', ctx=ast.Store())],
                value=ast.Call(func=ast.Name(id='list', ctx=ast.Load()),
                               args=[
                                   ast.Call(func=ast.Attribute(value=ast.Name(
                                       id='kwargs', ctx=ast.Load()),
                                                               attr='keys',
                                                               ctx=ast.Load()),
                                            args=[],
                                            keywords=[])
                               ],
                               keywords=[]))

            func_body.append(passed_arg_list)

            # Generated code:
            # ``assert all([key in expected for key in passed_keys]), <message> % passed_keys``
            comprehension = ast.comprehension(target=ast.Name(id='key',
                                                              ctx=ast.Store()),
                                              iter=ast.Name(id='passed_keys',
                                                            ctx=ast.Load()),
                                              ifs=[],
                                              is_async=False)
            comparator = ast.Name(id='expected', ctx=ast.Load())

            listcomp = ast.ListComp(elt=ast.Compare(left=ast.Name(
                id='key', ctx=ast.Load()),
                                                    ops=[ast.In()],
                                                    comparators=[comparator]),
                                    generators=[comprehension])

            # The allowed names are formatted into the message now; the ``%s``
            # is left for the generated code to fill with ``passed_keys``.
            check_message = ast.BinOp(left=ast.Str(
                s="Allowed kwargs are {}. Passed kwargs: %s".format(
                    optional_params)),
                                      op=ast.Mod(),
                                      right=ast.Name(id='passed_keys',
                                                     ctx=ast.Load()),
                                      lineno=self.__get_line())

            kwarg_check = ast.Assert(test=ast.Call(func=ast.Name(
                id='all', ctx=ast.Load()),
                                                   args=[listcomp],
                                                   keywords=[]),
                                     msg=check_message)
            func_body.append(kwarg_check)

            func_kwargs = ast.Name(id='kwargs', ctx=ast.Load())

        # Command name handed to ``synchronous_command`` as its only positional argument.
        fname = "{}.{}".format(dom_name, func_name)
        fname = ast.Str(s=fname, ctx=ast.Load())


        # The two supported branches differ in how ``**kwargs`` is attached to
        # the call: as a keyword with ``arg=None`` on 3.5 - 3.7, and through the
        # ``kwargs`` field of ``ast.Call`` on 3.4.
        if (sys.version_info[0], sys.version_info[1]) == (3, 5) or \
         (sys.version_info[0], sys.version_info[1]) == (3, 6) or \
         (sys.version_info[0], sys.version_info[1]) == (3, 7):

            # More irritating minor semantic differences in the AST between 3.4 and 3.5
            if func_kwargs:
                message_params.append(
                    ast.keyword(arg=None,
                                value=ast.Name(id='kwargs', ctx=ast.Load())))

            communicate_call = ast.Call(func=ast.Attribute(
                value=ast.Name(id='self', ctx=ast.Load()),
                ctx=ast.Load(),
                attr='synchronous_command'),
                                        args=[fname],
                                        keywords=message_params)

        elif (sys.version_info[0], sys.version_info[1]) == (3, 4):

            communicate_call = ast.Call(func=ast.Attribute(
                value=ast.Name(id='self', ctx=ast.Load()),
                ctx=ast.Load(),
                attr='synchronous_command'),
                                        args=[fname],
                                        kwargs=func_kwargs,
                                        keywords=message_params)
        else:
            # Any other interpreter version is rejected outright.
            print("Version:", sys.version_info)
            raise RuntimeError(
                "This script only functions on python 3.4, 3.5, 3.6, or 3.7. Active python version {}.{}"
                .format(*sys.version_info))

        # Generated code: ``subdom_funcs = self.synchronous_command(...)`` / ``return subdom_funcs``
        do_communicate = ast.Assign(
            targets=[ast.Name(id='subdom_funcs', ctx=ast.Store())],
            value=communicate_call)
        func_ret = ast.Return(
            value=ast.Name(id='subdom_funcs', ctx=ast.Load()))

        if len(optional_params) and self.do_debug_prints:
            func_body.append(self.__build_debug_print('kwargs', 'kwargs'))

        func_body.append(do_communicate)
        func_body.append(func_ret)

        # ``**kwargs`` is only part of the signature when optional parameters exist.
        if len(optional_params):
            kwarg = ast.arg(arg='kwargs', annotation=None)
        else:
            kwarg = None

        # NOTE(review): ``varargannotation`` / ``kwargannotation`` look like
        # fields of an older ``ast.arguments`` layout - confirm they are still needed.
        sig = ast.arguments(args=args,
                            vararg=None,
                            varargannotation=None,
                            kwonlyargs=[],
                            kwarg=kwarg,
                            kwargannotation=None,
                            defaults=[],
                            kw_defaults=[])

        func = ast.FunctionDef(
            name="{}_{}".format(full_name, func_name),
            args=sig,
            body=func_body,
            decorator_list=[],
            lineno=self.__get_line(),
            col_offset=0,
        )

        return func
Exemplo n.º 48
0
 def from_arg(symbol: str):
     """Return an un-annotated ``arg`` node for the parameter named ``symbol``.

     The node used to be built and then discarded, so the function always
     returned ``None``; it is now returned, as the typed siblings
     ``from_arg_typed`` and ``from_arg_typed_expr`` do.
     """
     return arg(arg=symbol, annotation=None)
Exemplo n.º 49
0
import ast

# Pre-built AST fragments shared by the code generator.

# ``ast.Constant`` replaces the deprecated ``ast.NameConstant`` alias.
ASTNone = ast.Constant(value=None)
dirFunc = ast.Name(id="dir", ctx=ast.Load())
strAST = ast.Name(id="str", ctx=ast.Load())
getAttrFunc = ast.Name(id="getattr", ctx=ast.Load())

ASTSelf = ast.Name(id="self", ctx=ast.Load())
astSelfArg = ast.arg(arg="self", annotation=None, type_comment=None)
ASTSelfClass = ast.Attribute(value=ASTSelf, attr="__class__", ctx=ast.Load())
AST__slots__ = ast.Name(id="__slots__", ctx=ast.Load())
# ``__slots__ = ()``
emptySlots = ast.Assign(targets=[ast.Name(id="__slots__", ctx=ast.Store())],
                        value=ast.Tuple(elts=[], ctx=ast.Load()),
                        type_comment=None)
ASTTypeError = ast.Name(id="TypeError", ctx=ast.Load())
typingAST = ast.Name(id="typing", ctx=ast.Load())
# The ``typing.*`` attribute nodes carry an explicit Load context like every
# other expression node above, so they do not depend on constructor defaults.
typingOptionalAST = ast.Attribute(value=typingAST, attr="Optional", ctx=ast.Load())
typingIterableAST = ast.Attribute(value=typingAST, attr="Iterable", ctx=ast.Load())
typingUnionAST = ast.Attribute(value=typingAST, attr="Union", ctx=ast.Load())
Exemplo n.º 50
0
def build_ast_class():
    """
    Build, print and return the ``ClassDef`` node of a small demo class.

    The class is named ``Test-Class``. Its body is a string-literal statement
    followed by ``__init__(self, *args, **kwargs)``, whose only statement calls
    ``super().__init__`` with ``self`` as positional argument and with ``args``
    and ``kwargs`` handed over through the ``starargs`` / ``kwargs`` keywords
    of ``ast.Call``. Line and column numbers are assigned by hand to the
    statement-level nodes (class: line 1, ``__init__``: line 2, call: line 3).
    """
    # Signature of __init__: (self, *args, **kwargs)
    init_signature = ast.arguments(
        args=[ast.arg('self', None)],
        vararg=ast.arg(arg='args', annotation=None),
        kwarg=ast.arg(arg='kwargs', annotation=None),
        varargannotation=None,
        kwonlyargs=[],
        kwargannotation=None,
        defaults=[],
        kw_defaults=[],
    )

    # super().__init__
    super_call = ast.Call(func=ast.Name(id='super', ctx=ast.Load()),
                          args=[],
                          keywords=[])
    super_init_attr = ast.Attribute(value=super_call,
                                    attr='__init__',
                                    ctx=ast.Load())

    # Statement calling super().__init__ with the received arguments
    call_statement = ast.Expr(
        value=ast.Call(
            func=super_init_attr,
            args=[ast.Name(id='self', ctx=ast.Load())],
            starargs=ast.Name(id='args', ctx=ast.Load()),
            kwargs=ast.Name(id='kwargs', ctx=ast.Load()),
            keywords=[],
        ),
        lineno=3,
        col_offset=0,
    )

    init_method = ast.FunctionDef(
        name="__init__",
        args=init_signature,
        body=[call_statement],
        decorator_list=[],
        lineno=2,
        col_offset=0,
    )

    leading_string = ast.Expr(value=ast.Str(s='\n\n\t'))

    interface_class = ast.ClassDef(
        name="Test-Class",
        bases=[],
        body=[leading_string, init_method],
        keywords=[],
        decorator_list=[],
        starargs=None,
        kwargs=None,
        lineno=1,
        col_offset=0,
    )
    print("Interface class:", interface_class)
    return interface_class
Exemplo n.º 51
0
 def visit_arguments(self, a):
     """Append a parameter ``k`` to the ``args`` list of ``a`` and return ``a``.

     The added parameter carries the list ``['a -> a']`` as its annotation.
     """
     extra_param = ast.arg(arg='k', annotation=['a -> a'])
     a.args.append(extra_param)
     return a
Exemplo n.º 52
0
def _lower_array_expr(lowerer, expr):
    '''Lower an array expression built by RewriteArrayExprs.

    The expression is turned into the AST of a stub Python function, that
    function is compiled and then passed to ``numpy_ufunc_kernel`` as the body
    of an element-wise kernel over the expression's variables.
    '''
    expr_name = "__numba_array_expr_%s" % (hex(hash(expr)).replace("-", "_"))
    expr_var_list = expr.list_vars()

    # Maps each legalized parameter name to (original name, variable); only
    # the first variable producing a given legalized name is recorded.
    expr_var_map = {}
    for var in expr_var_list:
        original_name = var.name
        legal_name = original_name.replace("$", "_").replace(".", "_")
        expr_var_map.setdefault(legal_name, (original_name, var))
        # The variable itself is renamed in place.
        var.name = legal_name
    expr_filename = expr_var_list[0].loc.filename

    # Parameters: names used inside the generated function.
    expr_params = sorted(expr_var_map)
    # Arguments: the corresponding original names outside of it.
    expr_args = [expr_var_map[param][0] for param in expr_params]

    if hasattr(ast, "arg"):
        # Should be Python 3.x
        ast_args = [ast.arg(param, None) for param in expr_params]
    else:
        # Should be Python 2.x
        ast_args = [ast.Name(param, ast.Param()) for param in expr_params]

    # Start from a parsed stub so the remaining AST fields hold sensible
    # defaults for the running Python version.
    stub_source = 'def {0}(): return'.format(expr_name)
    ast_module = ast.parse(stub_source, expr_filename, 'exec')
    assert hasattr(ast_module, 'body') and len(ast_module.body) == 1
    ast_fn = ast_module.body[0]
    ast_fn.args.args = ast_args
    return_value, namespace = _arr_expr_to_ast(expr.expr)
    ast_fn.body[0].value = return_value
    ast.fix_missing_locations(ast_module)

    code_obj = compile(ast_module, expr_filename, 'exec')
    six.exec_(code_obj, namespace)
    impl = namespace[expr_name]

    context = lowerer.context
    builder = lowerer.builder
    outer_sig = expr.ty(*(lowerer.typeof(name) for name in expr_args))
    # Array arguments are replaced by their dtype in the inner signature.
    inner_sig_args = [
        argty.dtype if isinstance(argty, types.Array) else argty
        for argty in outer_sig.args
    ]
    inner_sig = outer_sig.return_type.dtype(*inner_sig_args)

    cres = context.compile_only_no_cache(builder, impl, inner_sig)

    class ExprKernel(npyimpl._Kernel):
        def generate(self, *args):
            # Cast each value from its outer type to the inner type, call the
            # compiled function, then cast the result back.
            cast_args = [
                self.cast(val, inty, outty)
                for val, inty, outty in zip(args, self.outer_sig.args,
                                            inner_sig.args)
            ]
            result = self.context.call_internal(
                builder, cres.fndesc, inner_sig, cast_args)
            return self.cast(result, inner_sig.return_type,
                             self.outer_sig.return_type)

    args = [lowerer.loadvar(name) for name in expr_args]
    return npyimpl.numpy_ufunc_kernel(
        context, builder, outer_sig, args, ExprKernel, explicit_output=False)
Exemplo n.º 53
0
        tree = parse(source, filename=filename)
        return tree
    except SyntaxError as e:
        cause = e if showInternalBacktrace else None
        raise PythonParseError.fromSyntaxError(e, lineMap) from cause


### TRANSLATION PHASE FOUR: modifying the parse tree

# Reusable signatures for generated functions: no parameters, and ``self``
# only. ``posonlyargs`` is given explicitly because Python 3.8 added it to
# ``ast.arguments`` as a field that must be present when the node is compiled.
noArgs = ast.arguments(posonlyargs=[],
                       args=[],
                       vararg=None,
                       kwonlyargs=[],
                       kw_defaults=[],
                       kwarg=None,
                       defaults=[])
selfArg = ast.arguments(posonlyargs=[],
                        args=[ast.arg(arg='self', annotation=None)],
                        vararg=None,
                        kwonlyargs=[],
                        kw_defaults=[],
                        kwarg=None,
                        defaults=[])


class AttributeFinder(NodeVisitor):
    """Helper visitor used to find all referenced attributes of a given name."""
    @staticmethod
    def find(target, node):
        """Visit ``node`` with a fresh finder for ``target`` and return its ``attributes``."""
        finder = AttributeFinder(target)
        finder.visit(node)
        return finder.attributes
Exemplo n.º 54
0
 def visitParameter(self, ctx: PlSqlParser.ParameterContext):
     """Return an un-annotated ``ast.arg`` named after the first visited child result."""
     # Only the first item produced by the children is used; the rest is ignored.
     name, *_ = self.visitChildren(ctx)
     return ast.arg(arg=name, annotation=None)
Exemplo n.º 55
0
 def create_def(self, func_name: str, arguments: [str], body):
     """Wrapper over function definition AST node, whose constructor is inconvenient.

     :param func_name: name of the function to define
     :param arguments: names of the positional parameters, in order
     :param body: list of statement nodes forming the function body
     :return: ``ast.FunctionDef`` without decorators, defaults or return annotation
     """
     # Every field is passed by keyword: Python 3.8 inserted ``posonlyargs``
     # as the first field of ``ast.arguments``, so a positional call shifts
     # each value into the wrong slot and yields an unusable signature.
     signature = ast.arguments(
         posonlyargs=[],
         args=[ast.arg(arg=argument, annotation=None) for argument in arguments],
         vararg=None,
         kwonlyargs=[],
         kw_defaults=[],
         kwarg=None,
         defaults=[])
     return ast.FunctionDef(name=func_name,
                            args=signature,
                            body=body,
                            decorator_list=[],
                            returns=None)
Exemplo n.º 56
0
    def visit_For(self, node):
        """Replace a ``for`` loop by a function holding its body plus a call driving it.

        Loops of the shape ``for i, (a, b) in enumerate(<...loader...>)`` become a
        function of the three loop variables and a call to ``_for_dataloader``
        with the loader and that function. Every other loop becomes a function
        of its single target name and a call to ``_for`` with the iterable and
        that function. In both cases ``[function_def, call_statement]`` is returned.
        """
        self.generic_visit(node)

        #| recognize the pattern of PyTorch's DataLoader |#
        def isPyTorchDataLoader(tgt, iter):
            if not (isinstance(tgt, ast.Tuple) and len(tgt.elts) == 2):
                return False
            index, batch = tgt.elts
            if not isinstance(index, ast.Name):
                return False
            if not (isinstance(batch, ast.Tuple) and len(batch.elts) == 2):
                return False
            if not isinstance(batch.elts[0], ast.Name):
                return False
            if not isinstance(batch.elts[1], ast.Name):
                return False
            if not isinstance(iter, ast.Call):
                return False
            return iter.func.id == 'enumerate' and 'loader' in iter.args[0].id

        #| Transforms the target names to list of strings (tuples become sub-lists) |#
        def targetToList(tgt):
            names = []
            for item in tgt:
                if isinstance(item, ast.Name):
                    names.append(item.id)
                elif isinstance(item, ast.Tuple):
                    names.append(targetToList(item.elts))
                else:
                    raise NotImplementedError
            return names

        #| Same as above with one level of nesting flattened |#
        def targetToFlatList(tgt):
            flat = []
            for entry in targetToList(tgt):
                if isinstance(entry, list):
                    flat.extend(entry)
                else:
                    flat.append(entry)
            return flat

        if isPyTorchDataLoader(node.target, node.iter):
            loader_fn_name = self.freshName("forfunc")
            loader_fn_params = [
                ast.arg(arg=param_name, annotation=None)
                for param_name in targetToFlatList(node.target.elts)
            ]
            loader_fn = ast.FunctionDef(
                name=loader_fn_name,
                args=ast.arguments(args=loader_fn_params,
                                   vararg=None,
                                   kwonlyargs=[],
                                   kwarg=None,
                                   defaults=[],
                                   kw_defaults=[]),
                body=node.body,
                decorator_list=[])
            ast.fix_missing_locations(loader_fn)

            # _for_dataloader(<loader>, <function>)
            loader_call = ast.Call(
                func=ast.Name(id='_for_dataloader', ctx=ast.Load()),
                args=[
                    node.iter.args[0],
                    ast.Name(id=loader_fn_name, ctx=ast.Load())
                ],
                keywords=[])
            loader_stmt = ast.Expr(loader_call)
            ast.fix_missing_locations(loader_stmt)
            return [loader_fn, loader_stmt]

        body_fn_name = self.freshName("body")
        body_fn_signature = ast.arguments(
            args=[ast.arg(arg=node.target.id, annotation=None)],
            vararg=None,
            kwonlyargs=[],
            kwarg=None,
            defaults=[],
            kw_defaults=[])
        body_fn = ast.FunctionDef(name=body_fn_name,
                                  args=body_fn_signature,
                                  body=node.body,
                                  decorator_list=[],
                                  returns=None)
        ast.fix_missing_locations(body_fn)

        # _for(<iterable>, <function>)
        for_call = ast.Call(
            func=ast.Name(id='_for', ctx=ast.Load()),
            args=[node.iter,
                  ast.Name(id=body_fn_name, ctx=ast.Load())],
            keywords=[])
        for_stmt = ast.Expr(for_call)
        ast.copy_location(for_stmt, node)
        ast.fix_missing_locations(for_stmt)
        return [body_fn, for_stmt]
Exemplo n.º 57
0
 def from_arg_typed(symbol: str, typed: str):
     """Return an ``arg`` node for ``symbol`` annotated with the name ``typed``."""
     annotation = Name(id=typed, ctx=Load())
     return arg(arg=symbol, annotation=annotation)
Exemplo n.º 58
0
 def from_arg_typed_expr(symbol: str, typed: str):
     """Return an ``arg`` node for ``symbol`` using ``typed`` itself as the annotation."""
     annotated = arg(arg=symbol, annotation=typed)
     return annotated
Exemplo n.º 59
0
def _lower_array_expr(lowerer, expr):
    """Lower an array expression built by RewriteArrayExprs.

    Three steps are performed:

    1. A Python function is synthesised as an AST from ``expr.expr``, with
       one parameter per distinct variable used by the expression.
    2. The AST is compiled and executed to obtain that function object.
    3. The function is compiled as a subroutine and wrapped in a kernel
       class that is handed to ``npyimpl.numpy_ufunc_kernel``.

    Returns whatever ``npyimpl.numpy_ufunc_kernel`` returns.
    """
    hash_tag = hex(hash(expr)).replace("-", "_")
    expr_name = "__numba_array_expr_%s" % (hash_tag,)
    expr_filename = expr.loc.filename

    # A variable may occur several times in the expression; it must still
    # map to a single parameter, so deduplicate and order by name.
    unique_vars = sorted(set(expr.list_vars()), key=lambda v: v.name)

    # Names of the variables as seen outside the generated closure.
    expr_args = [v.name for v in unique_vars]

    # 1. Create an AST tree from the array expression.
    # A stub function is parsed (rather than built node by node) so the AST
    # is populated with reasonable defaults for the running Python version.
    stub_source = 'def {0}(): return'.format(expr_name)
    with _legalize_parameter_names(unique_vars) as expr_params:
        ast_args = [ast.arg(param_name, None) for param_name in expr_params]
        ast_module = ast.parse(stub_source, expr_filename, 'exec')
        assert hasattr(ast_module, 'body') and len(ast_module.body) == 1
        ast_fn = ast_module.body[0]
        ast_fn.args.args = ast_args
        # Replace the value of the stub's bare ``return`` with the
        # expression AST; the helper also supplies the exec namespace.
        return_value, namespace = _arr_expr_to_ast(expr.expr)
        ast_fn.body[0].value = return_value
        ast.fix_missing_locations(ast_module)

    # 2. Compile the AST module and extract the Python function.
    module_code = compile(ast_module, expr_filename, 'exec')
    exec(module_code, namespace)
    impl = namespace[expr_name]

    # 3. Now compile a ufunc using the Python function as kernel.
    context = lowerer.context
    builder = lowerer.builder
    outer_sig = expr.ty(*(lowerer.typeof(name) for name in expr_args))

    def _kernel_arg_type(argty):
        # Unwrap Optional first, then reduce arrays to their dtype; any
        # other type is passed through unchanged.
        if isinstance(argty, types.Optional):
            argty = argty.type
        return argty.dtype if isinstance(argty, types.Array) else argty

    inner_sig_args = [_kernel_arg_type(argty) for argty in outer_sig.args]
    inner_sig = outer_sig.return_type.dtype(*inner_sig_args)

    # Follow the Numpy error model.  Note this also allows e.g. vectorizing
    # division (issue #1223).
    flags = compiler.Flags()
    flags.set('error_model', 'numpy')
    cres = context.compile_subroutine(
        builder, impl, inner_sig, flags=flags, caching=False)

    # Create kernel subclass calling our native function
    from numba.np import npyimpl

    class ExprKernel(npyimpl._Kernel):
        def generate(self, *args):
            # Cast every incoming value from its outer type to the type
            # the compiled kernel expects.
            cast_args = []
            for val, inty, outty in zip(args, self.outer_sig.args,
                                        inner_sig.args):
                cast_args.append(self.cast(val, inty, outty))
            result = self.context.call_internal(
                builder, cres.fndesc, inner_sig, cast_args)
            # Cast the kernel result back to the outer return type.
            return self.cast(
                result, inner_sig.return_type, self.outer_sig.return_type)

    # create a fake ufunc object which is enough to trick numpy_ufunc_kernel
    ufunc = SimpleNamespace(nin=len(expr_args), nout=1, __name__=expr_name)
    ufunc.nargs = ufunc.nin + ufunc.nout

    loaded_args = [lowerer.loadvar(name) for name in expr_args]
    return npyimpl.numpy_ufunc_kernel(
        context, builder, outer_sig, loaded_args, ufunc, ExprKernel)
Exemplo n.º 60
0
    class Patterns:
        """
        Stores the pattern nodes / templates to be used for extracting information from the Python ast.

        Patterns are a 1-to-1 mapping from context and Python ast node to GTScript ast node. Context is encoded in the
        field types and all understood semantics are encoded in the structure.

        Each attribute is a partially specified Python ``ast`` node in which some fields are ``Capture(...)``
        placeholders instead of concrete values.
        """

        # NOTE(review): ``Capture`` is defined elsewhere; presumably the string
        # passed to it names the field under which the matched value is
        # reported -- confirm against the matcher implementation.

        # A bare name; its identifier is captured as "name".
        Symbol = ast.Name(id=Capture("name"))

        # ``with`` item of the form ``computation(<order>)``.
        IterationOrder = ast.withitem(
            context_expr=ast.Call(func=ast.Name(id="computation"),
                                  args=[ast.Name(id=Capture("order"))]))

        # Any constant literal; its value is captured as "value".
        Constant = ast.Constant(value=Capture("value"))

        # ``with`` item of the form ``interval(<start>, <stop>)``.
        Interval = ast.withitem(context_expr=ast.Call(
            func=ast.Name(id="interval"),
            args=[Capture("start"), Capture("stop")]))

        # ``with`` item of the form ``location(<location_type>) [as <name>]``.
        # The default for "name" is built by calling
        # ``UIDGenerator.sequential_id`` right here in the class body, i.e.
        # once when this class is created, so every use of the default sees
        # that same node -- this is what the TODO below refers to.
        # TODO(tehrengruber): this needs to be a function, since the uid must be generated each time
        LocationSpecification = ast.withitem(
            context_expr=ast.Call(func=ast.Name(id="location"),
                                  args=[ast.Name(id=Capture("location_type"))
                                        ]),
            optional_vars=Capture(
                "name",
                default=ast.Name(id=UIDGenerator.sequential_id(
                    prefix="location"))),
        )

        # Subscript with a single name as index, e.g. ``value[index]``.
        # NOTE(review): ``ast.Index`` is deprecated since Python 3.9, where
        # calling it returns the wrapped value directly -- confirm the
        # matcher copes with both AST shapes.
        SubscriptSingle = ast.Subscript(
            value=Capture("value"),
            slice=ast.Index(value=ast.Name(id=Capture("index"))))

        # Subscript with a tuple index, e.g. ``value[i, j]``; the tuple
        # elements are captured as "indices".
        SubscriptMultiple = ast.Subscript(
            value=Capture("value"),
            slice=ast.Index(value=ast.Tuple(elts=Capture("indices"))))

        # Binary operation; operator and both operands are captured.
        BinaryOp = ast.BinOp(op=Capture("op"),
                             left=Capture("left"),
                             right=Capture("right"))

        # Call whose callee is a plain name (captured as "func").
        Call = ast.Call(args=Capture("args"),
                        func=ast.Name(id=Capture("func")))

        # Single ``for <target> in <iterator>`` clause of a comprehension.
        LocationComprehension = ast.comprehension(target=Capture("target"),
                                                  iter=Capture("iterator"))

        # Generator expression; its clauses and element are captured.
        Generator = ast.GeneratorExp(generators=Capture("generators"),
                                     elt=Capture("elt"))

        # Assignment with exactly one target.
        Assign = ast.Assign(targets=[Capture("target")],
                            value=Capture("value"))

        # ``with`` statement; its items are captured as "iteration_spec".
        Stencil = ast.With(items=Capture("iteration_spec"),
                           body=Capture("body"))

        # ``pass`` statement.
        Pass = ast.Pass()

        # Function parameter with its annotation captured as "type_".
        Argument = ast.arg(arg=Capture("name"), annotation=Capture("type_"))

        # Function definition; parameters, body and name are captured.
        Computation = ast.FunctionDef(
            args=ast.arguments(args=Capture("arguments")),
            body=Capture("stencils"),
            name=Capture("name"),
        )