def setUp(self):
     """Prepare the fixture source and the expected contexts per identifier.

     ``self.all_name_ids`` maps every identifier appearing in ``test_fn`` to
     a list of freshly built gast context nodes (Param / Store / Load).
     """
     # Short aliases for the context node constructors used below.
     param, store, load = gast.Param, gast.Store, gast.Load

     self.source = """
       def test_fn(x, y):
         z = 1
         if x > y:
            z = x * x
            z = z + y
         return z
     """
     self.all_name_ids = {
         'x': [param(), load(), load(), load()],
         'y': [param(), load(), load()],
         'z': [store(), store(), load(), store(), load()],
     }
 def setUp(self):
     """Prepare the fixture source and the expected contexts per identifier.

     ``self.all_name_ids`` maps every identifier appearing in ``test_fn`` to
     a list of freshly built gast context nodes (Param / Store / Load).
     """
     # Short aliases for the context node constructors used below.
     param, store, load = gast.Param, gast.Store, gast.Load

     self.source = """
       def test_fn(x, y):
         a = 1
         x = y + a
         if x > y:
            z = x * x
            z = z + a
         else:
            z = y * y
         return z
     """
     self.all_name_ids = {
         'x': [param(), store(), load(), load(), load()],
         'a': [store(), load(), load()],
         'y': [param(), load(), load(), load(), load()],
         'z': [store(), load(), store(), store(), load()],
     }
Example #3
0
def create_while_node(condition_name, body_name, loop_var_names):
    """Build the assignment ``(vars...) = fluid.layers.while_loop(...)``.

    Args:
        condition_name: identifier passed as the first call argument.
        body_name: identifier passed as the second call argument.
        loop_var_names: iterable of identifiers; they form the list passed
            as the third call argument and the tuple that is assigned to.

    Returns:
        A ``gast.Assign`` node whose value is the ``while_loop`` call.
    """

    def _param_name(identifier):
        # Every name in this node is created with a Param context.
        return gast.Name(id=identifier,
                         ctx=gast.Param(),
                         annotation=None,
                         type_comment=None)

    condition_ref = _param_name(condition_name)
    body_ref = _param_name(body_name)
    loop_vars = [_param_name(var_name) for var_name in loop_var_names]

    # The same list of name nodes is shared by the call argument and by the
    # assignment target below.
    call_args = [
        condition_ref,
        body_ref,
        gast.List(elts=loop_vars, ctx=gast.Param()),
    ]

    callee = gast.parse('fluid.layers.while_loop').body[0].value
    loop_call = gast.Call(func=callee, args=call_args, keywords=[])
    return gast.Assign(
        targets=[gast.Tuple(elts=loop_vars, ctx=gast.Store())],
        value=loop_call)
Example #4
0
    def visit_Lambda(self, node):
        """Replace a lambda with a reference to a forged named function.

        The forged function is named ``<prefix>_lambda<N>`` and is recorded
        in ``self.lambda_functions`` and ``self.global_declarations``. The
        identifiers gathered through ``ImportedIds`` are prepended to its
        parameters; when there are any, the returned node is a
        ``functools.partial`` call binding them, otherwise a plain name.
        """
        functools_module = MODULES['functools']
        if functools_module not in self.global_declarations.values():
            self.imports.append(
                ast.Import([ast.alias('functools', mangle('functools'))]))
            self.global_declarations[mangle('functools')] = functools_module

        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        captured = self.passmanager.gather(ImportedIds, node, self.ctx)
        captured.difference_update(self.lambda_functions)  # remove current lambdas
        captured_names = sorted(captured)

        binded_args = [ast.Name(name, ast.Load(), None)
                       for name in captured_names]
        extra_params = [ast.Name(name, ast.Param(), None)
                        for name in captured_names]
        node.args.args = extra_params + node.args.args

        forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                      [ast.Return(node.body)], [], None)
        self.lambda_functions.append(forged_fdef)
        self.global_declarations[forged_name] = forged_fdef

        proxy_call = ast.Name(forged_name, ast.Load(), None)
        if not binded_args:
            return proxy_call

        partial_ref = ast.Attribute(
            ast.Name(mangle('functools'), ast.Load(), None),
            "partial", ast.Load())
        return ast.Call(partial_ref, [proxy_call] + binded_args, [])
  def test_ast_to_object(self):
    """Check the module, source text and file produced for ``f(a): a + 1``."""
    parameter = gast.Name('a', gast.Param(), None)
    signature = gast.arguments(
        args=[parameter],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    addition = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(addition)],
        decorator_list=[],
        returns=None)

    module, source, _ = compiler.ast_to_object(node)

    expected_source = """
      # coding=utf-8
      def f(a):
        return a + 1
    """
    expected = textwrap.dedent(expected_source).strip()

    # The returned source, the loaded module and the file on disk must agree.
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
    def visit_FunctionDef(self, node):
        """Move a function definition to module level and return its proxy.

        The visited definition is appended to ``self.ctx.module.body`` under
        a fresh ``pythran_<name><seed>`` name, with the identifiers gathered
        through ``ImportedIds`` prepended to its parameters. The node
        returned in its place assigns the former name to a
        ``functools.partial`` call binding those identifiers.
        """
        self.update = True
        # Register the mangled functools import once per module.
        if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import([ast.alias('functools', mangle('functools'))])
            self.ctx.module.body.insert(0, import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle('functools')] = functools_module

        self.ctx.module.body.append(node)

        former_name = node.name
        seed = 0
        new_name = "pythran_{}{}"

        # Bump the numeric suffix until the candidate name is unused.
        while new_name.format(former_name, seed) in self.identifiers:
            seed += 1

        new_name = new_name.format(former_name, seed)
        self.identifiers.add(new_name)

        ii = self.gather(ImportedIds, node)
        # Load-context names handed to functools.partial at the definition
        # site.
        binded_args = [
            ast.Name(iin, ast.Load(), None, None) for iin in sorted(ii)
        ]
        # The same identifiers become leading parameters of the function.
        node.args.args = (
            [ast.Name(iin, ast.Param(), None, None)
             for iin in sorted(ii)] + node.args.args)

        metadata.add(node, metadata.Local())

        class Renamer(ast.NodeTransformer):
            # Rewrites calls to the former name found inside the function so
            # they target the new name and pass the extra leading arguments.
            def visit_Call(self, node):
                self.generic_visit(node)
                if (isinstance(node.func, ast.Name)
                        and node.func.id == former_name):
                    node.func.id = new_name
                    node.args = ([
                        ast.Name(iin, ast.Load(), None, None)
                        for iin in sorted(ii)
                    ] + node.args)
                return node

        Renamer().visit(node)

        node.name = new_name
        self.global_declarations[node.name] = node
        proxy_call = ast.Name(new_name, ast.Load(), None, None)

        # former_name = functools.partial(new_name, <gathered identifiers>)
        new_node = ast.Assign([ast.Name(former_name, ast.Store(), None, None)],
                              ast.Call(
                                  ast.Attribute(
                                      ast.Name(mangle('functools'), ast.Load(),
                                               None, None), "partial",
                                      ast.Load()),
                                  [proxy_call] + binded_args,
                                  [],
                              ))

        # Children are visited only after the renaming above has been done.
        self.generic_visit(node)
        return new_node
Example #7
0
def outline(name, formal_parameters, out_parameters, stmts, has_return):
    """Build a function definition named ``name`` from ``stmts``.

    Args:
        name: name given to the created ``FunctionDef``.
        formal_parameters: identifiers turned into Param-context arguments.
        out_parameters: identifiers returned as a tuple by the trailing
            return statement; must be empty when ``stmts`` is an expression.
        stmts: either a single expression (wrapped in a ``Return``) or a
            list of statements used directly as the body.
        has_return: when true (statement-list case only), ``PatchReturn`` is
            run over the new function with the appended return statement.

    Returns:
        The created ``ast.FunctionDef``.
    """
    args = ast.arguments(
        [ast.Name(fp, ast.Param(), None) for fp in formal_parameters], None,
        [], [], None, [])

    if isinstance(stmts, ast.expr):
        assert not out_parameters, "no out parameters with expr"
        fdef = ast.FunctionDef(name, args, [ast.Return(stmts)], [], None)
    else:
        # ``stmts`` is passed as the body itself, so the append below also
        # extends the body of ``fdef``.
        fdef = ast.FunctionDef(name, args, stmts, [], None)

        # This is part of a huge trick that plays with delayed type inference:
        # it basically computes the return type based on out parameters, and
        # the return statement is unconditionally added so if we have other
        # returns, there will be a computation of the output type based on the
        # __combined of the regular return types and this one. The original
        # returns have been patched above to have a different type that
        # cunningly combines with this output tuple.
        #
        # This is the only trick I found to let pythran compute both the output
        # variable type and the early return type. But hey, a dirty one :-/

        stmts.append(
            ast.Return(
                ast.Tuple(
                    [ast.Name(fp, ast.Load(), None) for fp in out_parameters],
                    ast.Load())))
        if has_return:
            pr = PatchReturn(stmts[-1])
            pr.visit(fdef)

    return fdef
Example #8
0
 def visit_arg(self, node):
     """Convert an ``arg`` node into a Param-context ``gast.Name``.

     The argument name and its annotation are both passed through
     ``self._visit``; the location of ``node`` is copied onto the result.
     """
     identifier = self._visit(node.arg)
     annotation = self._visit(node.annotation)
     converted = gast.Name(identifier, gast.Param(), annotation)
     return ast.copy_location(converted, node)
Example #9
0
    def visit_Lambda(self, node):
        """Replace a lambda with an operator attribute or a forged function.

        When ``issimpleoperator`` recognises the lambda, an attribute of the
        mangled ``operator`` module is returned. Otherwise the lambda is
        turned into a ``<prefix>_lambda<N>`` function, reusing an entry of
        ``self.patterns`` when ``issamelambda`` finds a match, and the
        identifiers gathered through ``ImportedIds`` are bound with
        ``functools.partial`` when there are any.
        """
        op = issimpleoperator(node)
        if op is not None:
            # Register the mangled operator import on first use.
            if mangle('operator') not in self.global_declarations:
                import_ = ast.Import(
                    [ast.alias('operator', mangle('operator'))])
                self.imports.append(import_)
                operator_module = MODULES['operator']
                self.global_declarations[mangle('operator')] = operator_module
            return ast.Attribute(
                ast.Name(mangle('operator'), ast.Load(), None, None), op,
                ast.Load())

        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        ii = self.gather(ImportedIds, node)
        ii.difference_update(self.lambda_functions)  # remove current lambdas

        binded_args = [
            ast.Name(iin, ast.Load(), None, None) for iin in sorted(ii)
        ]
        # The gathered identifiers become leading parameters of the lambda.
        node.args.args = (
            [ast.Name(iin, ast.Param(), None, None)
             for iin in sorted(ii)] + node.args.args)
        # Reuse a previously forged function when an equivalent lambda has
        # already been recorded; the ``else`` runs only if no pattern matched.
        for patternname, pattern in self.patterns.items():
            if issamelambda(pattern, node):
                proxy_call = ast.Name(patternname, ast.Load(), None, None)
                break
        else:
            duc = ExtendedDefUseChains()
            # Record a deep copy so that later changes to ``node`` do not
            # affect the stored pattern.
            nodepattern = deepcopy(node)
            duc.visit(ast.Module([ast.Expr(nodepattern)], []))
            self.patterns[forged_name] = nodepattern, duc

            forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                          [ast.Return(node.body)], [], None,
                                          None)
            metadata.add(forged_fdef, metadata.Local())
            self.lambda_functions.append(forged_fdef)
            self.global_declarations[forged_name] = forged_fdef
            proxy_call = ast.Name(forged_name, ast.Load(), None, None)

        if binded_args:
            # functools is only registered when a partial is actually built.
            if MODULES['functools'] not in self.global_declarations.values():
                import_ = ast.Import(
                    [ast.alias('functools', mangle('functools'))])
                self.imports.append(import_)
                functools_module = MODULES['functools']
                self.global_declarations[mangle(
                    'functools')] = functools_module

            return ast.Call(
                ast.Attribute(
                    ast.Name(mangle('functools'), ast.Load(), None, None),
                    "partial", ast.Load()), [proxy_call] + binded_args, [])
        else:
            return proxy_call
Example #10
0
 def __init__(self, **kwargs):
     """Initialise the description from optional keyword settings.

     Recognised keys: ``argument_effects``, ``global_effects``,
     ``return_alias``, ``args``, ``kwonlyargs``, ``defaults``,
     ``return_range`` and ``return_range_content``. Each one falls back to
     the default given below when absent.
     """
     option = kwargs.get

     self.argument_effects = option('argument_effects',
                                    (UpdateEffect(), ) * DefaultArgNum)
     self.global_effects = option('global_effects', False)
     self.return_alias = option('return_alias',
                                lambda x: {UnboundValue})

     # Build the pieces of the argument description before assembling it.
     positional = [
         ast.Name(n, ast.Param(), None, None) for n in option('args', [])
     ]
     keyword_only = [
         ast.Name(n, ast.Param(), None, None)
         for n in option('kwonlyargs', [])
     ]
     default_nodes = [to_ast(d) for d in option('defaults', [])]
     self.args = ast.arguments(positional, [], None, keyword_only, [], None,
                               default_nodes)

     self.return_range = option("return_range",
                                lambda call: UNKNOWN_RANGE)
     self.return_range_content = option("return_range_content",
                                        lambda c: UNKNOWN_RANGE)
Example #11
0
    def visit_ListComp(self, node):
        """Rewrite an optimizable list comprehension as a ``map`` call.

        Comprehensions not listed in ``self.optimizable_comprehension`` are
        only visited generically. Otherwise the element expression becomes
        the body of a lambda mapped over the generator iterator, or over an
        ``itertools.product`` of the iterators when there are several.
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterators = []
        parameters = []
        for generator in node.generators:
            iterators.append(self.make_Iterator(generator))
            parameters.append(
                ast.Name(generator.target.id, ast.Param(), None))

        if len(iterators) == 1:
            # A single generator needs no product.
            iterAST = iterators[0]
            varAST = ast.arguments([parameters[0]], None, [], [], None, [])
        else:
            self.use_itertools = True
            product_ref = ast.Attribute(
                value=ast.Name(id=mangle('itertools'),
                               ctx=ast.Load(),
                               annotation=None),
                attr='product',
                ctx=ast.Load())

            varid = parameters[0].id  # retarget this id, it's free
            renamings = {
                parameter.id: (position, )
                for position, parameter in enumerate(parameters)
            }
            node.elt = ConvertToTuple(varid, renamings).visit(node.elt)
            iterAST = ast.Call(product_ref, iterators, [])
            varAST = ast.arguments([ast.Name(varid, ast.Param(), None)],
                                   None, [], [], None, [])

        map_ref = ast.Attribute(
            value=ast.Name(id='__builtin__',
                           ctx=ast.Load(),
                           annotation=None),
            attr='map',
            ctx=ast.Load())
        mapped_lambda = ast.Lambda(varAST, node.elt)
        return ast.Call(map_ref, [mapped_lambda, iterAST], [])
Example #12
0
 def _make_annotated_arg(self, parent, identifier, annotation):
     """Build a Param-context ``gast.Name`` located at ``parent``.

     Returns ``None`` when ``identifier`` is ``None``. Otherwise both the
     identifier and the annotation are passed through ``self._visit``.
     """
     if identifier is not None:
         name_node = gast.Name(
             self._visit(identifier),
             gast.Param(),
             self._visit(annotation),
         )
         return ast.copy_location(name_node, parent)
     return None
Example #13
0
 def visit_Lambda(self, node):
     """Replace lambda arguments that yield renamings with a fresh name.

     For each argument, ``self.traverse_tuples`` fills a renaming table.
     When the table is not empty the argument is replaced by a new
     identifier and the body is rewritten through ``ConvertToTuple``.
     """
     self.generic_visit(node)
     for position, argument in enumerate(node.args.args):
         table = OrderedDict()
         self.traverse_tuples(argument, (), table)
         if not table:
             continue
         fresh_name = self.get_new_id()
         node.args.args[position] = ast.Name(fresh_name, ast.Param(), None)
         node.body = ConvertToTuple(fresh_name, table).visit(node.body)
     return node
Example #14
0
    def visit_Compare(self, node):
        """Outline a chained comparison into a forged helper function.

        Comparisons with a single operator are returned unchanged after the
        generic visit. Otherwise a ``<prefix>_compare<N>`` function is
        appended to ``self.compare_functions``; it tests each pair of
        operands in turn and returns False at the first failing test. The
        node returned in place of the comparison is a call to that function.
        """
        node = self.generic_visit(node)
        if len(node.ops) <= 1:
            return node

        # More than one operator: generate an auxiliary function that
        # lazily evaluates the needed parameters.
        captured = sorted(
            self.passmanager.gather(ImportedIds, node, self.ctx))
        call_args = [ast.Name(name, ast.Load(), None) for name in captured]

        # Name of the new function.
        forged_name = "{0}_compare{1}".format(self.prefix,
                                              len(self.compare_functions))

        # Call site.
        call = ast.Call(ast.Name(forged_name, ast.Load(), None),
                        call_args, [])

        # Signature of the new function.
        parameters = [ast.Name(name, ast.Param(), None) for name in captured]
        signature = ast.arguments(parameters, None, [], [], None, [])

        statements = []

        # Operands that are not trivially copied are stored in a temporary.
        if is_trivially_copied(node.left):
            previous = node.left
        else:
            statements.append(
                ast.Assign([ast.Name('$0', ast.Store(), None)], node.left))
            previous = ast.Name('$0', ast.Load(), None)

        for index, operand in enumerate(node.comparators):
            if is_trivially_copied(operand):
                current = operand
            else:
                temporary = '${}'.format(index + 1)
                statements.append(
                    ast.Assign([ast.Name(temporary, ast.Store(), None)],
                               operand))
                current = ast.Name(temporary, ast.Load(), None)

            test = ast.Compare(previous, [node.ops[index]], [current])
            on_failure = ast.Return(path_to_attr(('__builtin__', 'False')))
            statements.append(ast.If(test, [ast.Pass()], [on_failure]))
            previous = current

        statements.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

        self.compare_functions.append(
            ast.FunctionDef(forged_name, signature, statements, [], None))
        return call
Example #15
0
    def visitComp(self, node, make_attr):
        """Rewrite an optimizable comprehension through ``make_attr``.

        Comprehensions not listed in ``self.optimizable_comprehension`` are
        only visited generically. Otherwise ``make_attr`` is called with a
        lambda built from the element expression and with the iterator, or
        the ``itertools.product`` of the iterators when there are several.
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterators = [
            self.make_Iterator(generator) for generator in node.generators
        ]
        parameters = [
            ast.Name(generator.target.id, ast.Param(), None, None)
            for generator in node.generators
        ]

        if len(iterators) == 1:
            # A single generator needs no product.
            iterAST = iterators[0]
            varAST = ast.arguments([parameters[0]], [], None, [], [], None,
                                   [])
        else:
            self.use_itertools = True
            product_ref = ast.Attribute(
                value=ast.Name(id=mangle('itertools'),
                               ctx=ast.Load(),
                               annotation=None,
                               type_comment=None),
                attr='product',
                ctx=ast.Load())

            varid = parameters[0].id  # retarget this id, it's free
            renamings = {
                parameter.id: (position, )
                for position, parameter in enumerate(parameters)
            }
            node.elt = ConvertToTuple(varid, renamings).visit(node.elt)
            iterAST = ast.Call(product_ref, iterators, [])
            varAST = ast.arguments(
                [ast.Name(varid, ast.Param(), None, None)], [], None, [],
                [], None, [])

        mapped_lambda = ast.Lambda(varAST, node.elt)
        return make_attr(mapped_lambda, iterAST)
Example #16
0
 def visit_Print(self, node):
     """Turn a ``Print`` node into an expression calling ``print``.

     Each printed value gets a Param context, the callee is annotated with
     ``live_val`` and ``fqn``, and the ``args_scope`` annotation of the
     original node is copied onto the call.
     """
     self.generic_visit(node)
     for value in node.values:
         value.ctx = gast.Param()

     callee = gast.Name('print', gast.Load(), None)
     call_node = gast.Call(func=callee, args=node.values, keywords=[])
     anno.setanno(call_node.func, 'live_val', print)
     anno.setanno(call_node.func, 'fqn', 'print')

     args_scope = anno.getanno(node, 'args_scope')
     anno.setanno(call_node, 'args_scope', args_scope)
     return gast.Expr(call_node)
Example #17
0
    def visit_arg(self, node):
        """Convert an ``arg`` node into a Param-context ``gast.Name``.

        The annotation is passed through ``self._visit``. The type comment
        is only read from ``node`` on Python 3.8 and later; on older
        interpreters ``None`` is used instead. The location of ``node`` is
        copied onto the result.
        """
        # Compare the whole version tuple: testing ``minor`` alone gives the
        # wrong answer for any interpreter whose major version is not 3.
        if sys.version_info < (3, 8):
            extra_arg = None
        else:
            extra_arg = self._visit(node.type_comment)

        new_node = gast.Name(
            node.arg,  # micro-optimization here, don't call self._visit
            gast.Param(),
            self._visit(node.annotation),
            extra_arg  # type_comment
        )
        return ast.copy_location(new_node, node)
Example #18
0
    def visit_arg(self, node):
        """Convert an ``arg`` node into a Param-context ``gast.Name``.

        The name and annotation are passed through ``self._visit``. The type
        comment is only read from ``node`` on Python 3.8 and later; on older
        interpreters ``None`` is used instead. The location of ``node`` is
        copied onto the result.
        """
        # Compare the whole version tuple: testing ``minor`` alone gives the
        # wrong answer for any interpreter whose major version is not 3.
        if sys.version_info < (3, 8):
            extra_args = [None]
        else:
            extra_args = [self._visit(node.type_comment)]

        new_node = gast.Name(
            self._visit(node.arg),
            gast.Param(),
            self._visit(node.annotation),
            *extra_args  # type_comment
        )
        ast.copy_location(new_node, node)
        return new_node
Example #19
0
    def visit_AnyComp(self, node, comp_type, *path):
        """Outline a comprehension into a module-level function.

        The forged function is named ``<comp_type>_comprehension<N>``. It
        initialises ``__target`` by calling ``builtins.<comp_type>``, runs
        the generators nested through ``self.nest_reducer`` around a call to
        the attribute designated by ``path``, and returns ``__target``. The
        node returned in place of the comprehension calls that function
        with the identifiers gathered through ``ImportedIds``.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "{0}_comprehension{1}".format(comp_type, self.count)
        self.count += 1
        args = self.gather(ImportedIds, node)
        self.count_iter = 0

        starget = "__target"

        # Resolve the dotted ``path`` into nested attribute accesses.
        callee = ast.Name(path[0], ast.Load(), None, None)
        for attribute in path[1:]:
            callee = ast.Attribute(callee, attribute, ast.Load())

        innermost = ast.Expr(
            ast.Call(
                callee,
                [ast.Name(starget, ast.Load(), None, None), node.elt],
                [],
            ))
        body = reduce(self.nest_reducer, reversed(node.generators), innermost)

        # add extra metadata to this node
        metadata.add(body, metadata.Comprehension(starget))

        target_store = ast.Name(starget, ast.Store(), None, None)
        constructor = ast.Attribute(
            ast.Name('builtins', ast.Load(), None, None),
            comp_type,
            ast.Load())
        init = ast.Assign([target_store], ast.Call(constructor, [], []))
        result = ast.Return(ast.Name(starget, ast.Load(), None, None))

        sargs = [ast.Name(arg, ast.Param(), None, None) for arg in args]
        signature = ast.arguments(sargs, [], None, [], [], None, [])
        fd = ast.FunctionDef(name, signature, [init, body, result],
                             [], None, None)
        metadata.add(fd, metadata.Local())
        self.ctx.module.body.append(fd)

        # Fresh Load-context names are built for the call: no sharing !
        call_args = [ast.Name(arg.id, ast.Load(), None, None) for arg in sargs]
        return ast.Call(ast.Name(name, ast.Load(), None, None), call_args, [])
Example #20
0
 def make_Iterator(self, gen):
     """Return the iterator of ``gen``, filtered by its conditions if any.

     Without conditions the iterator is returned as is. Otherwise it is
     wrapped in a call to ``MODULE.IFILTER`` with a lambda whose body is
     the single condition, or the conjunction of all of them.
     """
     if not gen.ifs:
         return gen.iter

     parameters = ast.arguments(
         [ast.Name(gen.target.id, ast.Param(), None)],
         None, [], [], None, [])
     if len(gen.ifs) > 1:
         predicate = ast.BoolOp(ast.And(), gen.ifs)
     else:
         predicate = gen.ifs[0]
     ldFilter = ast.Lambda(parameters, predicate)

     filter_ref = ast.Attribute(
         value=ast.Name(id=MODULE, ctx=ast.Load(), annotation=None),
         attr=IFILTER,
         ctx=ast.Load())
     return ast.Call(filter_ref, [ldFilter, gen.iter], [])
Example #21
0
  def generate_FunctionDef(self):
    """Generate a FunctionDef node."""

    def make_parameter():
      return self.generate_Name(gast.Param())

    # Sample the parameters first, then the body, then the function name.
    parameters = self.sample_node_list(
        low=2, high=10, generator=make_parameter)
    signature = gast.arguments(parameters, None, [], [], None, [])

    # The body always ends with a generated return statement.
    statements = self.sample_node_list(
        low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement)
    statements.append(self.generate_Return())

    function_name = self.generate_Name().id
    return gast.FunctionDef(function_name, signature, statements, (), None)
Example #22
0
    def visit_BinOp(self, node):
        """Combine the results of both operands into the result of ``node``.

        Both operands are combined into a temporary placeholder node, which
        is then combined into ``node`` and removed from ``self.result``.
        """
        self.generic_visit(node)
        weak_left, weak_right = [
            self.result[operand].isweak()
            for operand in (node.left, node.right)
        ]
        weakness = [weak_left, weak_right]

        if (not isinstance(node.op, ast.Add) or not any(weakness)
                or all(weakness)):
            def combiner(x, y):
                return ExpressionType(operator_to_lambda[type(node.op)],
                                      [x, y])
        else:
            # An addition with exactly one weak operand: assumes the +
            # operator always has the same operand type on both sides.
            combiner = operator.add

        placeholder = ast.Name("#", ast.Param(), None)
        for operand in (node.left, node.right):
            self.combine(placeholder, operand, combiner)
        self.combine(node, placeholder)
        del self.result[placeholder]
  def test_load_ast(self):
    """Check the module, source text and file produced for ``f(a): a + 1``."""
    parameter = gast.Name(
        'a', ctx=gast.Param(), annotation=None, type_comment=None)
    signature = gast.arguments(
        args=[parameter],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    addition = gast.BinOp(
        op=gast.Add(),
        left=gast.Name(
            'a', ctx=gast.Load(), annotation=None, type_comment=None),
        right=gast.Constant(1, kind=None))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(addition)],
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    expected_source = """
      # coding=utf-8
      def f(a):
          return (a + 1)
    """
    expected = textwrap.dedent(expected_source).strip()

    # The returned source, the loaded module and the file on disk must agree.
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
Example #24
0
    def visit_While(self, node):
        # Replaces a ``while`` statement by the code produced from the local
        # ``template`` function: a test function, a body function and a
        # ``tf.while_loop`` call whose result is stored in the loop state.
        self.generic_visit(node)
        # Scrape out the data flow analysis
        body_scope = anno.getanno(node, 'body_scope')
        parent_scope_values = anno.getanno(node, 'parent_scope_values')
        # Loop state: names the body scope reports as modified but not
        # created.
        body_closure = tuple(body_scope.modified - body_scope.created)

        # The function below is handed to ``templates.replace`` and is never
        # called directly; its parameters are the placeholders substituted
        # by the keyword arguments of that call.
        def template(
                state_args,  # pylint:disable=unused-argument
                state_locals,
                state_results,  # pylint:disable=unused-argument
                test_name,
                test,  # pylint:disable=unused-argument
                body_name,
                body,
                state_init):
            def test_name(state_args):  # pylint:disable=function-redefined,unused-argument
                return test

            def body_name(state_args):  # pylint:disable=function-redefined,unused-argument
                body  # pylint:disable=pointless-statement
                return state_locals

            state_results = tf.while_loop(test_name, body_name, [state_init])  # pylint:disable=undefined-variable

        # Symbols for the two generated functions; ``body_scope.used`` is
        # passed to the namer (presumably so the new names do not clash with
        # it; confirm against ``new_symbol``).
        test_name = self.namer.new_symbol('loop_test', body_scope.used)
        body_name = self.namer.new_symbol('loop_body', body_scope.used)
        # NOTE(review): ``state_args`` goes through ``_tuple_or_item`` while
        # the two other state placeholders use ``_ast_tuple_or_item`` with an
        # explicit context; confirm both helpers exist and that this is
        # intended.
        node = templates.replace(
            template,
            state_args=self._tuple_or_item(
                gast.Name(n, gast.Param(), None) for n in body_closure),
            state_locals=self._ast_tuple_or_item(
                (gast.Name(n, gast.Load(), None) for n in body_closure),
                gast.Load()),
            state_results=self._ast_tuple_or_item(
                (gast.Name(n, gast.Store(), None) for n in body_closure),
                gast.Store()),
            test_name=gast.Name(test_name, gast.Load(), None),
            test=node.test,
            body_name=gast.Name(body_name, gast.Load(), None),
            body=node.body,
            state_init=[parent_scope_values.getval(n) for n in body_closure])

        return node
    def visit_Assign(self, node):
        # Rewrites assignments of the form
        #   <targets> = tfe.value_and_gradients_function(...)(...)
        # into the code produced from the local ``template`` function; any
        # other assignment is returned unchanged after the generic visit.
        self.generic_visit(node)

        val = node.value
        # Match a call whose callee is itself a call to the attribute
        # ``tfe.value_and_gradients_function``.
        if isinstance(val, gast.Call):
            if isinstance(val.func, gast.Call):
                if isinstance(val.func.func, gast.Attribute):
                    if isinstance(val.func.func.value, gast.Name):
                        if (val.func.func.value.id == 'tfe'
                                and val.func.func.attr
                                == 'value_and_gradients_function'):

                            # pylint:disable=unused-argument,undefined-variable

                            # Handed to ``templates.replace`` below and never
                            # called directly; its parameters are the
                            # placeholders that get substituted.
                            def template(loss_var, loss_fn, args, d_vars,
                                         wrt_vars):
                                loss_var = loss_fn(args)
                                d_vars = tf.gradients(loss_var, [wrt_vars])

                            # pylint:enable=unused-argument,undefined-variable

                            # How to get these values? Print out the node.
                            # The code below reads the first target as a
                            # two-element tuple and the inner call's second
                            # argument as a sequence of numeric literals.
                            loss_var = gast.Name(node.targets[0].elts[0].id,
                                                 gast.Store(), None)
                            loss_fn = gast.Name(val.func.args[0].id,
                                                gast.Load(), None)
                            args = tuple(
                                gast.Name(a.id, gast.Param(), None)
                                for a in val.args)
                            d_vars = node.targets[0].elts[1]
                            # Each literal's ``n`` indexes into the outer
                            # call's arguments.
                            wrt_vars = [
                                val.args[e.n] for e in val.func.args[1].elts
                            ]

                            node = templates.replace(template,
                                                     loss_var=loss_var,
                                                     loss_fn=loss_fn,
                                                     args=args,
                                                     d_vars=d_vars,
                                                     wrt_vars=wrt_vars)

        return node
Example #26
0
    def visit_GeneratorExp(self, node):
        """Outline a generator expression into a module-level function.

        The forged function is named ``generator_expression<N>``; its body
        is the generators nested through ``self.nest_reducer`` around a
        ``yield`` of the element. The node returned in place of the
        expression calls that function with the identifiers gathered
        through ``ImportedIds``.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "generator_expression{0}".format(self.count)
        self.count += 1
        captured = self.passmanager.gather(ImportedIds, node, self.ctx)
        self.count_iter = 0

        innermost = ast.Expr(ast.Yield(node.elt))
        loop_nest = reduce(self.nest_reducer, reversed(node.generators),
                           innermost)

        parameters = [ast.Name(ident, ast.Param(), None) for ident in captured]
        signature = ast.arguments(parameters, None, [], [], None, [])
        generator_def = ast.FunctionDef(name, signature, [loop_nest], [], None)
        self.ctx.module.body.append(generator_def)

        # Fresh Load-context names are built for the call: no sharing !
        call_args = [
            ast.Name(parameter.id, ast.Load(), None)
            for parameter in parameters
        ]
        return ast.Call(ast.Name(name, ast.Load(), None), call_args, [])
Example #27
0
def save_arguments(module_name, elements):
    """Recursively save argument names and default values.

    Walks ``elements``; nested dictionaries are treated as submodules and
    visited recursively, every other value is treated as a signature whose
    ``args.args`` (and ``args.defaults`` when the callable has defaults) are
    filled in from introspection of the matching Python object.

    Args:
        module_name: tuple of name components; joined with ``"."`` to form
            the name handed to ``__import__``.
        elements: mapping from attribute name to either a nested mapping
            (submodule) or a signature object to update in place.

    Raises:
        AssertionError: if a signature already has arguments recorded.
    """
    # ``inspect.getargspec`` was removed in Python 3.11.  The helper is
    # resolved here, outside of the ``try`` below, because that ``try``
    # swallows AttributeError: looking up the missing attribute inside it
    # would silently turn this whole function into a no-op.
    get_spec = getattr(inspect, "getfullargspec", None) or inspect.getargspec

    for elem, signature in elements.items():
        if isinstance(signature, dict):  # Submodule case
            save_arguments(module_name + (elem, ), signature)
            continue

        # use introspection to get the Python obj
        try:
            # NOTE(review): with a dotted name and no ``fromlist``,
            # ``__import__`` returns the top-level package, so ``elem`` is
            # looked up there rather than on the submodule -- confirm this
            # is intended.
            themodule = __import__(".".join(module_name))
            obj = getattr(themodule, elem)
            spec = get_spec(obj)
            assert not signature.args.args
            signature.args.args = [
                ast.Name(arg, ast.Param(), None) for arg in spec.args
            ]
            if spec.defaults:
                signature.args.defaults = [
                    to_ast(default) for default in spec.defaults
                ]
        except (AttributeError, ImportError, TypeError, ToNotEval):
            # Best effort: objects that cannot be imported, found or
            # introspected simply keep their signature untouched.
            pass
Example #28
0
    def visit_GeneratorExp(self, node):
        """Rewrite an optimizable generator expression as a lazy map call.

        Nodes that are not listed in ``self.optimizable_comprehension`` are
        only visited generically.  Otherwise the expression becomes
        ``MODULE.IMAP(lambda <targets>: <elt>, <iterable>)``: with a single
        comprehension clause the iterable is that clause's iterator, with
        several clauses the iterators are combined through
        ``itertools.product`` and the lambda takes one tuple of targets.
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterators = [self.make_Iterator(gen) for gen in node.generators]
        targets = [
            ast.Name(gen.target.id, ast.Param(), None)
            for gen in node.generators
        ]

        def load_attribute(owner, attr):
            # Build the node for ``owner.attr`` in load context.
            owner_node = ast.Name(id=owner, ctx=ast.Load(), annotation=None)
            return ast.Attribute(value=owner_node, attr=attr, ctx=ast.Load())

        if len(iterators) == 1:
            # If dim = 1, product is useless
            source = iterators[0]
            lambda_params = [targets[0]]
        else:
            product = load_attribute('itertools', 'product')
            source = ast.Call(product, iterators, [])
            lambda_params = [ast.Tuple(targets, ast.Store())]
        lambda_args = ast.arguments(lambda_params, None, [], [], None, [])

        mapper = load_attribute(MODULE, IMAP)
        mapped_function = ast.Lambda(lambda_args, node.elt)
        return ast.Call(mapper, [mapped_function, source], [])
Example #29
0
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Wraps an AST into the body of a factory with consistent lexical context.

  `nodes` is expected to define a symbol named `entity_name`. Wrapping it in
  two nested factories gives the transformed entity the same lexical scoping
  as the source entity while still allowing extra parametrization.

  Responsibilities of the inner factory:

   * it dynamically creates the entity represented by `nodes`;
   * it is parametrized by a custom set of arguments;
   * its closure is identical to that of the transformed entity;
   * it has local variables named like `args`, which `nodes` may use as
     additional parameters;
   * it returns the variable given by `entity_name`.

  Responsibilities of the outer factory:

   * it is niladic and has no closure;
   * it creates the lexical scope required by the inner factory, so that the
     loaded code has the given configuration for closure/globals;
   * it returns the inner factory.

  The generated code looks roughly like this:

      from __future__ import future_feature_1
      from __future__ import future_feature_2
      ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The lexical scope is produced with dummy symbol declarations that create
  local variables in the body of the outer factory, so that the Python parser
  correctly marks them as free non-global variables upon load (that is, it
  creates cell slots for each symbol). These symbols are initialized with
  None, but their values are not expected to be used; instead, the caller is
  expected to replace them with the cells of the source entity. For more
  details, see:
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: Tuple[ast.AST], the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional arguments for the
      inner factory. Useful to configure variables that the converted code can
      use. Typically, these are modules.
    future_features: Iterable[Text], names of future statements to associate
      the code with.

  Returns:
    ast.AST
  """
  # One `<name> = None` placeholder statement per closure variable.
  dummy_closure_defs = []
  for closure_var in closure_vars:
    closure_decl_template = """
      var_name = None
    """
    dummy_closure_defs += templates.replace(
        closure_decl_template, var_name=closure_var)

  # A single `from __future__ import ...` node, or nothing at all.
  future_imports = (
      gast.ImportFrom(
          module='__future__',
          names=[
              gast.alias(name=feature, asname=None)
              for feature in future_features
          ],
          level=0) if future_features else [])

  factory_arg_nodes = [
      gast.Name(arg, ctx=gast.Param(), annotation=None, type_comment=None)
      for arg in factory_args
  ]

  factory_template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      factory_template,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name,
      factory_args=factory_arg_nodes,
      future_imports=future_imports,
      inner_factory_name=inner_factory_name,
      outer_factory_name=outer_factory_name)
Example #30
0
    def get_while_stmt_nodes(self, node):
        """Expand a ``while`` statement into the statements that replace it.

        When ``self.name_visitor`` does not consider ``node`` a control-flow
        loop, the node is returned unchanged in a one-element list.
        Otherwise the result holds, in order: variable set-up statements, a
        condition function, a body function and the node produced by
        ``create_while_node`` from those two functions and the loop variables.

        Note that the return statement of the body function is appended to
        ``node.body`` in place.
        """
        # TODO: consider while - else in python
        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)

        def build_params():
            # Parameter list shared by both generated functions: one
            # positional parameter per loop variable.
            return gast.arguments(
                args=[
                    gast.Name(id=var,
                              ctx=gast.Param(),
                              annotation=None,
                              type_comment=None) for var in loop_var_names
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=None,
                kwarg=None,
                defaults=[])

        def rename_dotted_vars(func_node):
            # Loop variables whose name contains "." are replaced inside
            # ``func_node`` by a freshly generated identifier.
            for var in loop_var_names:
                if "." in var:
                    RenameTransformer(func_node).rename(
                        var, unique_name.generate(GENERATE_VARIABLE_PREFIX))

        # Python lets a variable created inside the loop be used after it:
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # so such variables get a static variable created up front.
        new_stmts = [
            create_static_variable_gast_node(var) for var in create_var_names
            if "." not in var
        ]

        # In dygraph, ``while x < 10`` has to compare a static tensor with 10,
        # so every loop variable is converted to a static variable.
        new_stmts.extend(
            to_static_variable_gast_node(var) for var in loop_var_names)

        cond_value_node = LogicalOpTransformer(node.test).transform()

        cond_name = unique_name.generate(WHILE_CONDITION_PREFIX)
        condition_func_node = gast.FunctionDef(
            name=cond_name,
            args=build_params(),
            body=[gast.Return(value=cond_value_node)],
            decorator_list=[],
            returns=None,
            type_comment=None)
        rename_dotted_vars(condition_func_node)
        new_stmts.append(condition_func_node)

        # The loop body becomes a function returning the loop variables.
        node.body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_name = unique_name.generate(WHILE_BODY_PREFIX)
        body_func_node = gast.FunctionDef(
            name=body_name,
            args=build_params(),
            body=node.body,
            decorator_list=[],
            returns=None,
            type_comment=None)
        rename_dotted_vars(body_func_node)
        new_stmts.append(body_func_node)

        new_stmts.append(
            create_while_node(condition_func_node.name, body_func_node.name,
                              loop_var_names))
        return new_stmts