def setUp(self):
     """Fixture where `z` is assigned in both branches of an if/else.

     `all_name_ids` maps each name in `self.source` to the expected `ctx`
     node of each of its occurrences, in order.
     """
     self.source = """
       def test_fn(x, y):
         a = 1
         x = y + a
         if x > y:
            z = x * x
            z = z + a
         else:
            z = y * y
         return z
     """
     self.all_name_ids = {
         'x': [gast.Param(), gast.Store()] + [gast.Load() for _ in range(3)],
         'a': [gast.Store()] + [gast.Load() for _ in range(2)],
         'y': [gast.Param()] + [gast.Load() for _ in range(4)],
         'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store(),
               gast.Load()],
     }
 def setUp(self):
     """Fixture where `z` is initialised before an `if` that has no `else`.

     `all_name_ids` maps each name in `self.source` to the expected `ctx`
     node of each of its occurrences, in order.
     """
     self.source = """
       def test_fn(x, y):
         z = 1
         if x > y:
            z = x * x
            z = z + y
         return z
     """
     self.all_name_ids = {
         'x': [gast.Param()] + [gast.Load() for _ in range(3)],
         'y': [gast.Param()] + [gast.Load() for _ in range(2)],
         'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store(),
               gast.Load()],
     }
Exemple #3
0
    def test_unparse(self):
        """`parser.unparse` renders an if/else that assigns `a` in each branch."""
        def store_a():
            # A fresh Store-context `a` for each branch's assignment target.
            return gast.Name('a', ctx=gast.Store(), annotation=None,
                             type_comment=None)

        load_b = gast.Name('b', ctx=gast.Load(), annotation=None,
                           type_comment=None)
        node = gast.If(
            test=gast.Constant(1, kind=None),
            body=[gast.Assign(targets=[store_a()], value=load_b)],
            orelse=[gast.Assign(targets=[store_a()],
                                value=gast.Constant('c', kind=None))])

        source = parser.unparse(node, indentation='  ')
        expected = textwrap.dedent("""
            # coding=utf-8
            if 1:
                a = b
            else:
                a = 'c'
        """).strip()
        self.assertEqual(expected, source.strip())
Exemple #4
0
 def trivialize_slice(self, node):
     """Reduce a subscript slice to simple names, hoisting parts into temps.

     - ``gast.Slice``: binds a fresh temporary (named by ``self.namer``) to
       ``slice(lower, upper, step)``, where each missing (falsy) bound becomes
       the name ``None`` and every present bound is ``self.trivialize``-d.
     - ``gast.ExtSlice``: trivializes each dimension recursively, binds a
       fresh temporary to the tuple of the resulting names.
     - ``gast.Index``: delegates to ``self.trivialize(node.value)``.

     Args:
       node: the slice part of a subscript.

     Returns:
       For ``Index``, whatever ``self.trivialize`` returns; otherwise a
       Load-context ``gast.Name`` referring to the new temporary.

     Raises:
       ValueError: if ``node`` is of any other type.
     """
     if isinstance(node, gast.Slice):
         name = self.namer.name(node)
         target = gast.Name(id=name, ctx=gast.Store(), annotation=None)
         # The assignment is prepended before its value exists; ``value`` is
         # filled in below, after the bounds have been trivialized.
         # NOTE(review): this relies on ``self.prepend`` placing statements
         # prepended later (the bound temporaries) ahead of ``stmt`` —
         # confirm against ``prepend``'s implementation.
         stmt = gast.Assign(targets=[target], value=None)
         self.prepend(stmt)
         stmt.value = gast.Call(
             func=gast.Name(id='slice', ctx=gast.Load(), annotation=None),
             args=[
                 self.trivialize(arg) if arg else gast.Name(
                     id='None', ctx=gast.Load(), annotation=None)
                 for arg in [node.lower, node.upper, node.step]
             ],
             keywords=[])
         return gast.Name(id=name, ctx=gast.Load(), annotation=None)
     elif isinstance(node, gast.ExtSlice):
         name = self.namer.name(node)
         target = gast.Name(id=name, ctx=gast.Store(), annotation=None)
         # Same prepend-then-fill pattern as the Slice branch.
         stmt = gast.Assign(targets=[target], value=None)
         self.prepend(stmt)
         # Each dimension becomes its own temporary; keep only their names.
         dim_names = [self.trivialize_slice(s).id for s in node.dims]
         stmt.value = gast.Tuple(elts=[
             gast.Name(id=n, ctx=gast.Load(), annotation=None)
             for n in dim_names
         ],
                                 ctx=gast.Load())
         return gast.Name(id=name, ctx=gast.Load(), annotation=None)
     elif isinstance(node, gast.Index):
         return self.trivialize(node.value)
     else:
         raise ValueError(node)
 def visit_Return(self, node):
     """Replace a ``return`` with assignments to the return flag and value.

     Marks the innermost entry of ``self.func_returned_stack`` as having
     returned and builds::

         <returned_flag><depth> = True
         <returned_value_key> = <returned value, or None for a bare return>

     where ``<depth>`` is the current length of ``func_returned_stack``.

     Returns:
       The two-statement list, or, when ``generic_visit`` produced a
       ``gast.If``, that ``If`` with its body replaced by the list.
     """
     modified_node = self.generic_visit(node)
     if node.value is not None:
         returned_value = node.value
     else:
         returned_value = gast.Constant(value=None, kind=None)

     self.func_returned_stack[-1] = True
     flag_name = self.returned_flag + str(len(self.func_returned_stack))

     def _store(name):
         return gast.Name(id=name, ctx=gast.Store(), annotation=None,
                          type_comment=None)

     replacement = [
         gast.Assign(targets=[_store(flag_name)],
                     value=gast.Constant(value=True, kind=None)),
         gast.Assign(targets=[_store(self.returned_value_key)],
                     value=returned_value),
     ]

     # TODO: Add location to returned value.
     if not isinstance(modified_node, gast.If):
         return replacement
     modified_node.body = replacement
     return modified_node
 def visit_FunctionDef(self, node):
     """Add the return-value/return-flag prologue and a final ``return``.

     Pops this function's entry from ``self.func_returned_stack``; the body
     then becomes::

         <returned_value_key> = None
         <returned_flag><depth> = False   # only if the popped entry is truthy
         ...original body...
         return <returned_value_key>

     where ``<depth>`` is the stack length before the pop.
     """
     modified_node = self.generic_visit(node)
     depth = len(self.func_returned_stack)
     has_return = self.func_returned_stack.pop()

     def _name(name, ctx):
         return gast.Name(id=name, ctx=ctx, annotation=None, type_comment=None)

     prologue = [
         gast.Assign(targets=[_name(self.returned_value_key, gast.Store())],
                     value=gast.Constant(value=None, kind=None))
     ]
     if has_return:
         prologue.append(
             gast.Assign(
                 targets=[_name(self.returned_flag + str(depth), gast.Store())],
                 value=gast.Constant(value=False, kind=None)))

     # Slice assignment edits the existing body list in place.
     node.body[0:0] = prologue
     node.body.append(
         gast.Return(value=_name(self.returned_value_key, gast.Load())))
     return modified_node
    def visit_For(self, node):
        """Reset this loop's flags and leave the loop on break or return.

        After visiting the children, pops this loop's entries from
        ``self.for_continued_stack`` and ``self.for_breaked_stack``. A truthy
        entry inserts ``<flag><depth> = False`` at the start of ``node.body``
        (``<depth>`` is the stack length before the pop). The exit condition
        is the break flag (if any), or-ed with the enclosing function's return
        flag when the innermost ``self.func_returned_stack`` entry is truthy.
        If there is an exit condition, the body ends with::

            <keepgoing_flag> = not <cond>
            if <cond>:
                break

        Returns:
          The result of ``self.generic_visit(node)``. ``node.body`` is
          modified in place.
        """
        def _flag(name, ctx):
            return gast.Name(id=name, ctx=ctx, annotation=None,
                             type_comment=None)

        def _reset_to_false(name):
            return gast.Assign(targets=[_flag(name, gast.Store())],
                               value=gast.Constant(value=False, kind=None))

        modified_node = self.generic_visit(node)

        continue_depth = len(self.for_continued_stack)
        if self.for_continued_stack.pop():
            node.body.insert(
                0, _reset_to_false(self.continued_flag + str(continue_depth)))

        exit_conditions = []
        break_depth = len(self.for_breaked_stack)
        if self.for_breaked_stack.pop():
            break_name = self.breaked_flag + str(break_depth)
            node.body.insert(0, _reset_to_false(break_name))
            exit_conditions.append(_flag(break_name, gast.Load()))

        if self.func_returned_stack and self.func_returned_stack[-1]:
            return_name = self.returned_flag + str(len(self.func_returned_stack))
            exit_conditions.append(_flag(return_name, gast.Load()))

        if not exit_conditions:
            return modified_node

        # The same ``cond`` node is used by both appended statements.
        if len(exit_conditions) == 1:
            cond = exit_conditions[0]
        else:
            cond = gast.BoolOp(op=gast.Or(), values=exit_conditions)
        node.body.append(
            gast.Assign(targets=[_flag(self.keepgoing_flag, gast.Store())],
                        value=gast.UnaryOp(op=gast.Not(), operand=cond)))
        node.body.append(gast.If(test=cond, body=[gast.Break()], orelse=[]))
        return modified_node
Exemple #8
0
 def visit_Assign(self, node):
     """Split tuple/list unpacking targets into one assignment per element.

     For each ``Tuple``/``List`` target, ``self.traverse_tuples`` fills
     ``renamings`` (rename -> state). When it is non-empty, the target is
     replaced by a single holder expression. For every renaming, an
     ``Assign`` is generated that reads the element as a chain of constant
     subscripts on the holder, one subscript per entry of ``state``.

     * If ``node.value`` is a ``Name`` or ``Attribute`` (``no_tmp``), a deep
       copy of it is the holder and ``node`` is left out of the returned list.
     * Otherwise a fresh name from ``self.get_new_id()`` is the holder, and
       ``node`` (now assigning to that name) comes first in the result.

     Returns:
       The list of statements to emit, or ``node`` itself when that list is
       empty (``no_tmp`` and no renamings at all).
     """
     self.generic_visit(node)
     # if the rhs is an identifier, we don't need to duplicate it
     # otherwise, better duplicate it...
     no_tmp = isinstance(node.value, (ast.Name, ast.Attribute))
     extra_assign = [] if no_tmp else [node]
     for i, t in enumerate(node.targets):
         if isinstance(t, ast.Tuple) or isinstance(t, ast.List):
             renamings = OrderedDict()
             self.traverse_tuples(t, (), renamings)
             if renamings:
                 if no_tmp:
                     gstore = deepcopy(node.value)
                 else:
                     gstore = ast.Name(self.get_new_id(), ast.Store(), None,
                                       None)
                 # Load-context twin of the holder, used as the base of every
                 # element read below.
                 gload = deepcopy(gstore)
                 gload.ctx = ast.Load()
                 # Under no_tmp, node is not in extra_assign, so this
                 # retargeting is not emitted.
                 node.targets[i] = gstore
                 for rename, state in renamings.items():
                     # Fold state into holder[state[0]][state[1]]...;
                     # presumably state is the element's index path inside
                     # the target — confirm with traverse_tuples.
                     nnode = reduce(
                         lambda x, y: ast.Subscript(x, ast.Constant(
                             y, None), ast.Load()), state, gload)
                     if isinstance(rename, str):
                         extra_assign.append(
                             ast.Assign([
                                 ast.Name(rename, ast.Store(), None, None)
                             ], nnode, None))
                     else:
                         extra_assign.append(
                             ast.Assign([rename], nnode, None))
     return extra_assign or node
Exemple #9
0
  def visit_FunctionDef(self, node):
    """Intercepts function definitions.

    Converts function definitions to the corresponding `ProgramBuilder.function`
    construction:

    - each positional argument becomes a
      `_tfp_autobatching_context_.param(...)` declaration;
    - the visited body runs inside a
      `with _tfp_autobatching_context_.define_function(...)` block;
    - that block is wrapped in a function taking the context and the
      available functions, which returns the defined function.

    Args:
      node: An `ast.AST` node representing the function to convert.

    Returns:
      node: An updated node, representing the result.

    Raises:
      ValueError: If the input node does not adhere to the restrictions,
        e.g., failing to have a `return` statement at the end.
    """
    final_stmt = node.body[-1]
    if not isinstance(final_stmt, gast.Return):
      raise ValueError(
          'Last node in function body should be Return, not {}.'.format(
              final_stmt))

    # One `_tfp_autobatching_context_.param()` declaration per argument.
    param_declarations = [
        templates.replace(
            'target = _tfp_autobatching_context_.param(name=target_name)',
            target=arg.id,
            target_name=gast_util.Str(arg.id))[0]
        for arg in node.args.args
    ]

    node = self.generic_visit(node)
    node.body = param_declarations + node.body

    # The wrapper receives the `ProgramBuilder` and the callable
    # `instruction.Function`s as ordinary Python parameters.
    available_names = [
        gast_util.Name(name, ctx=gast.Store(), annotation=None)
        for name in self.known_functions
    ]
    return templates.replace(
        '''
        def func(_tfp_autobatching_context_,
                 _tfp_autobatching_available_functions_):
          names = _tfp_autobatching_available_functions_
          with _tfp_autobatching_context_.define_function(func):
            body
          return func''',
        func=node.name,
        names=gast.List(available_names, ctx=gast.Store()),
        body=node.body)[0]
Exemple #10
0
    def visit_Compare(self, node):
        """Lower a chained comparison into a call to a generated function.

        A comparison with more than one operator (e.g. ``a < b < c``) is
        replaced by a call to a new function named
        ``<prefix>_compare<k>``. That function takes every name returned by
        ``ImportedIds`` for ``node`` (sorted) as a parameter. Its body:

        - checks each pairwise comparison in order;
        - returns ``__builtin__.False`` at the first failing comparison;
        - returns ``__builtin__.True`` otherwise.

        Operands rejected by ``is_trivially_copied`` are first stored in
        ``$0``, ``$1``, ... For those operands this has two effects:

        - an operand shared by two comparisons is evaluated only once;
        - it is evaluated only after the previous comparison succeeded.

        The new ``FunctionDef`` is appended to ``self.compare_functions``.

        Returns:
          The call node, or the (child-visited) comparison unchanged when it
          has a single operator.
        """
        node = self.generic_visit(node)
        if len(node.ops) > 1:
            # in case we have more than one compare operator
            # we generate an auxiliary function
            # that lazily evaluates the needed parameters
            imported_ids = self.passmanager.gather(ImportedIds, node, self.ctx)
            imported_ids = sorted(imported_ids)
            binded_args = [ast.Name(i, ast.Load(), None) for i in imported_ids]

            # name of the new function
            forged_name = "{0}_compare{1}".format(self.prefix,
                                                  len(self.compare_functions))

            # call site
            call = ast.Call(ast.Name(forged_name, ast.Load(), None),
                            binded_args, [])

            # new function
            arg_names = [ast.Name(i, ast.Param(), None) for i in imported_ids]
            args = ast.arguments(arg_names, None, [], [], None, [])

            body = []  # iteratively fill the body (yeah, feel your body!)

            if is_trivially_copied(node.left):
                prev_holder = node.left
            else:
                body.append(
                    ast.Assign([ast.Name('$0', ast.Store(), None)], node.left))
                prev_holder = ast.Name('$0', ast.Load(), None)

            for i, exp in enumerate(node.comparators):
                if is_trivially_copied(exp):
                    holder = exp
                else:
                    # Store comparator i in $(i+1); this assignment is placed
                    # after the previous comparison's early return.
                    body.append(
                        ast.Assign(
                            [ast.Name('${}'.format(i + 1), ast.Store(), None)],
                            exp))
                    holder = ast.Name('${}'.format(i + 1), ast.Load(), None)
                cond = ast.Compare(prev_holder, [node.ops[i]], [holder])
                # if <a op b>: pass  else: return False
                body.append(
                    ast.If(
                        cond, [ast.Pass()],
                        [ast.Return(path_to_attr(('__builtin__', 'False')))]))
                # The right operand becomes the left one of the next link.
                prev_holder = holder

            body.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

            forged_fdef = ast.FunctionDef(forged_name, args, body, [], None)
            self.compare_functions.append(forged_fdef)

            return call
        else:
            return node
Exemple #11
0
    def get_for_args_stmts(self, iter_name, args_list):
        '''
        Returns 3 gast stmt nodes for argument.
        1. Initialization of iterate variable
        2. Condition for the loop
        3. Statement for changing of iterate variable during the loop

        ``args_list`` holds the 1 to 3 argument nodes of a ``range(...)``
        call. The loop condition follows ``range`` semantics:

        - ``iter < stop`` for a positive step;
        - ``iter > stop`` for a step written with a unary minus or as a
          negative constant.

        A step whose sign cannot be read from the AST (e.g. a variable) is
        treated as positive.

        NOTE(TODO): Python allows to access iteration variable after loop, such
           as "for i in range(10)" will create i = 9 after the loop. But using
           current conversion will make i = 10. We should find a way to change it
        '''
        len_range_args = len(args_list)
        assert 1 <= len_range_args <= 3, "range() function takes 1 to 3 arguments"
        if len_range_args == 1:
            init_stmt = get_constant_variable_node(iter_name, 0)
        else:
            init_stmt = gast.Assign(
                targets=[
                    gast.Name(
                        id=iter_name,
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=args_list[0])

        range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
        if len_range_args == 3:
            step_node = args_list[2]
        else:
            step_node = gast.Constant(value=1, kind=None)

        # The condition is not `iter + step <= stop`: that form drops the last
        # element whenever step > 1 (range(0, 10, 3) would skip 9, and
        # range(0, 1, 2) would not run at all) and never iterates for a
        # negative step.
        is_negative_step = (
            (isinstance(step_node, gast.UnaryOp)
             and isinstance(step_node.op, gast.USub))
            or (isinstance(step_node, gast.Constant)
                and isinstance(step_node.value, (int, float))
                and step_node.value < 0))
        cond_stmt = gast.Compare(
            left=gast.Name(
                id=iter_name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None),
            ops=[gast.Gt() if is_negative_step else gast.Lt()],
            comparators=[range_max_node])

        change_stmt = gast.AugAssign(
            target=gast.Name(
                id=iter_name,
                ctx=gast.Store(),
                annotation=None,
                type_comment=None),
            op=gast.Add(),
            value=step_node)

        return init_stmt, cond_stmt, change_stmt
 def visit_For(self, node):
     """Reset break/continue flags each iteration and break on a break flag.

     Pops this loop's flag-name lists from ``self.for_continue_stack`` and
     ``self.for_breaked_stack`` and inserts ``<flag> = False`` at the start of
     ``node.body`` for each name. If there are break flags, the loop body
     gets::

         <keepgoing_flag> = not <cond>
         if <cond>:
             break

     ``<cond>`` is the single break flag, or the ``or`` of all of them. The
     loop that receives these statements is ``modified_node`` when it is a
     ``For``. It is the first statement of ``modified_node.body`` when
     ``modified_node`` is an ``If`` whose body starts with a ``For``.
     Otherwise nothing is appended.
     """
     modified_node = self.generic_visit(node)

     def _reset(flag):
         return gast.Assign(
             targets=[gast.Name(id=flag, ctx=gast.Store(), annotation=None)],
             value=gast.NameConstant(value=False))

     for flag in self.for_continue_stack.pop():
         node.body.insert(0, _reset(flag))

     break_conditions = []
     for flag in self.for_breaked_stack.pop():
         node.body.insert(0, _reset(flag))
         break_conditions.append(
             gast.Name(id=flag, ctx=gast.Load(), annotation=None))
     if not break_conditions:
         return modified_node

     if len(break_conditions) == 1:
         cond = break_conditions[0]
     else:
         cond = gast.BoolOp(op=gast.Or(), values=break_conditions)

     if isinstance(modified_node, gast.For):
         loop = modified_node
     elif (isinstance(modified_node, gast.If)
           and isinstance(modified_node.body[0], gast.For)):
         loop = modified_node.body[0]
     else:
         return modified_node

     # The same ``cond`` node is used by both appended statements.
     loop.body.append(
         gast.Assign(targets=[
             gast.Name(id=self.keepgoing_flag,
                       ctx=gast.Store(),
                       annotation=None)
         ],
                     value=gast.UnaryOp(op=gast.Not(), operand=cond)))
     loop.body.append(gast.If(test=cond, body=[gast.Break()], orelse=[]))
     return modified_node
Exemple #13
0
 def visit_Assign(self, node):
     """Split tuple/list unpacking targets into one assignment per element.

     For each ``Tuple``/``List`` target, ``self.traverse_tuples`` fills
     ``renamings`` (rename -> state). When it is non-empty, the target is
     replaced by a ``Name`` holder. For every renaming, an ``Assign`` is
     generated that reads the element as a chain of ``Index(Num(...))``
     subscripts on the holder.

     * If ``node.value`` is a ``Name`` (``no_tmp``), that name is the holder
       and ``node`` is left out of the returned list.
     * Otherwise a fresh name from ``self.get_new_id()`` is the holder, and
       ``node`` (now assigning to that name) comes first in the result.

     Returns:
       The list of statements to emit, or ``node`` itself when that list is
       empty (``no_tmp`` and no renamings at all).
     """
     self.generic_visit(node)
     # if the rhs is an identifier, we don't need to duplicate it
     # otherwise, better duplicate it...
     no_tmp = isinstance(node.value, ast.Name)
     extra_assign = [] if no_tmp else [node]
     for i, t in enumerate(node.targets):
         if isinstance(t, ast.Tuple) or isinstance(t, ast.List):
             renamings = OrderedDict()
             self.traverse_tuples(t, (), renamings)
             if renamings:
                 gtarget = node.value.id if no_tmp else self.get_new_id()
                 # The new holder keeps the ctx of the tuple/list it replaces.
                 node.targets[i] = ast.Name(gtarget,
                                            node.targets[i].ctx,
                                            None)
                 for rename, state in renamings.items():
                     # Fold state into holder[state[0]][state[1]]...;
                     # presumably state is the element's index path inside
                     # the target — confirm with traverse_tuples.
                     nnode = reduce(
                         lambda x, y: ast.Subscript(
                             x,
                             ast.Index(ast.Num(y)),
                             ast.Load()),
                         state,
                         ast.Name(gtarget, ast.Load(), None))
                     if isinstance(rename, str):
                         extra_assign.append(
                             ast.Assign(
                                 [ast.Name(rename, ast.Store(), None)],
                                 nnode))
                     else:
                         extra_assign.append(ast.Assign([rename], nnode))
     return extra_assign or node
Exemple #14
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Parses `f`, isolates its single matching definition, converts it with
    `node_to_graph` and assigns the name under which it is compiled.

    Args:
      f: the function (possibly a lambda) to convert.
      program_ctx: program context. It provides the namer and
        `autograph_module`, and receives the namer's name map.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo` and
        `namer.compiled_function_name`.

    Returns:
      A tuple `([node], new_name, namespace)`. For a lambda, `node` is the
      assignment `new_name = <converted lambda>`. Otherwise it is the
      converted definition, renamed to `new_name` if the namer renamed it.

    Raises:
      ValueError: if the parsed source does not contain exactly one
        definition matching `f`.
    """
    parsed_module, source = parser.parse_entity(f)

    # In general, the output of inspect.getsource is inexact because it uses
    # regex matching to adjust the exact location around the line number that
    # CPython records. This is particularly problematic for lambda functions,
    # where the entire containing lines are returned.
    matches = ast_util.find_matching_definitions(parsed_module.body[0], f)
    if len(matches) != 1:
        if f.__name__ == '<lambda>':
            template = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.')
        else:
            template = (
                'Unable to identify source code of function {}. The source code'
                ' reported by Python did not include exactly one matching signature:'
                '\n{}\n. This is an extremely rare occurrence. Please report it to'
                ' the TensorFlow team.')
        raise ValueError(template.format(f, source))
    node, = matches

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    context = converter.EntityContext(
        namer,
        transformer.EntityInfo(source_code=source,
                               source_file='<fragment>',
                               namespace=namespace,
                               arg_values=arg_values,
                               arg_types=arg_types,
                               owner_type=owner_type),
        program_ctx)
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        # A lambda has no name of its own; bind it to a fresh symbol.
        new_name = namer.new_symbol('tf__lambda', ())
        node = gast.Assign(targets=[gast.Name(new_name, gast.Store(), None)],
                           value=node)
    else:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name, did_rename = namer.compiled_function_name(
            f.__name__, f, owner_type)
        if not did_rename:
            new_name = f.__name__
            assert node.name == new_name
        else:
            node.name = new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.
    return [node], new_name, namespace
Exemple #15
0
 def synchronize_lcds(self, node):
     """Append ``name = name._synchronize()`` for each loop-carried name.

     A name is treated as an LCD (presumably "loop-carried dependency") when
     a top-level ``Assign`` in ``node.body`` stores to it (first target)
     after that statement or an earlier top-level one has loaded it.
     Attributes are fused with ``FuseAttributes`` for the analysis and split
     again with ``SplitAttributes`` before the statements are appended.

     Args:
       node: a node with a ``body`` list.

     Returns:
       The transformed node. Its ``body`` gains one statement per LCD, in
       the order the LCDs were first stored.

     Raises:
       NotImplementedError: if an LCD is stored to by two top-level
         assignments.
     """
     node = FuseAttributes().visit(node)
     loaded = set()
     # An insertion-ordered dict rather than a set: iterating a set of str
     # depends on the hash seed, which would make the order of the emitted
     # statements differ from run to run.
     lcds = {}
     for child in node.body:
         loaded.update(
             n.id for n in gast.walk(child)
             if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load))
         if isinstance(child, gast.Assign):
             name = child.targets[0].id
             if name in loaded:
                 if name in lcds:
                     raise NotImplementedError("cannot process LCD "
                                               "stored to twice")
                 lcds[name] = None
     node = SplitAttributes().visit(node)
     node.body.extend(
         gast.Assign(
             [gast.Name(name, gast.Store(), None)],
             gast.Call(
                 gast.Attribute(
                     gast.Name(name, gast.Load(), None),
                     gast.Name('_synchronize', gast.Load(), None), None),
                 [], []))
         for name in lcds)
     return node
Exemple #16
0
def get_annotations(object_def, namespace):
    """Create the annotations from a definition node.

    Builds an ``annotations = {...}`` assignment filled from ``object_def``,
    unparses it and executes the resulting source in ``namespace``.

    Args:
      object_def: an ``ast.FunctionDef`` or ``ast.ClassDef`` node.
      namespace: dict used as the globals of the ``exec``. It is modified in
        place: any ``"__builtins__"`` entry is removed before execution
        (``exec`` then inserts its own), and ``"annotations"`` is bound.

    Returns:
      The object bound to ``"annotations"`` in ``namespace``.

    Raises:
      NotImplementedError: if ``object_def`` is any other kind of node.
    """
    ast_annotations = ast.Assign(
        targets=[extast.Name("annotations", ast.Store())],
        value=ast.Dict(keys=[], values=[]),
        type_comment=None,
    )

    if isinstance(object_def, ast.FunctionDef):
        fill = _fill_ast_annotations_function
    elif isinstance(object_def, ast.ClassDef):
        fill = _fill_ast_annotations_class
    else:
        raise NotImplementedError
    fill(object_def, ast_annotations)

    source = extast.unparse(ast_annotations)
    namespace.pop("__builtins__", None)
    exec(source, namespace)
    return namespace["annotations"]
Exemple #17
0
    def _replace_after_node_to_if_in_stmt_list(
            self, stmt_list, node, return_name, parent_node_of_return):
        i = index_in_list(stmt_list, node)
        if i < 0 or i >= len(stmt_list):
            return False
        if i == len(stmt_list) - 1:
            # No need to add, we consider this as added successfully
            return True

        if_stmt = gast.If(test=gast.UnaryOp(
            op=gast.Not(),
            operand=gast.Name(
                id=return_name,
                ctx=gast.Store(),
                annotation=None,
                type_comment=None)),
                          body=stmt_list[i + 1:],
                          orelse=[])

        stmt_list[i + 1:] = [if_stmt]

        # Here assume that the parent node of return is gast.If
        if isinstance(parent_node_of_return, gast.If):
            # Prepend control flow boolean nodes such as '__return@1 = False'
            node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format(
                return_name,
                ast_to_source_code(parent_node_of_return.test).strip())
            assign_false_node = gast.parse(node_str).body[0]

            stmt_list[i:i] = [assign_false_node]
        return True
Exemple #18
0
    def _process_variable_assignment(self, source, targets):
        """Record in ``self.scope`` the value assigned to each target.

        If ``source`` is a call whose callee has a ``live_val`` annotation
        that is a class, ``source`` is annotated with ``is_constructor``,
        ``type`` and ``type_fqn``.

        Each target is then bound via ``self.scope.setval`` under its
        ``anno.Basic.QN`` annotation:

        - ``Name``/``Attribute`` targets are bound to ``source``;
        - elements of a ``Tuple`` target are bound to a subscript of
          ``source`` at the element's position.

        Raises:
          ValueError: for any other kind of target.
        """
        if isinstance(source, gast.Call) and anno.hasanno(source.func,
                                                          'live_val'):
            callee = anno.getanno(source.func, 'live_val')
            if tf_inspect.isclass(callee):
                anno.setanno(source, 'is_constructor', True)
                anno.setanno(source, 'type', callee)
                anno.setanno(source, 'type_fqn',
                             anno.getanno(source.func, 'fqn'))
                # TODO(mdan): Raise an error if constructor has side effects.
                # We can have a whitelist of no-side-effects constructors.
                # We can also step inside the constructor and further analyze.

        for target in targets:
            if isinstance(target, (gast.Name, gast.Attribute)):
                self.scope.setval(anno.getanno(target, anno.Basic.QN), source)
            elif isinstance(target, gast.Tuple):
                for position, element in enumerate(target.elts):
                    self.scope.setval(
                        anno.getanno(element, anno.Basic.QN),
                        gast.Subscript(source, gast.Index(position),
                                       ctx=gast.Store()))
            else:
                raise ValueError('Dont know how to handle assignment to %s' %
                                 target)
Exemple #19
0
    def visit_Assign(self, node):
        """Record in ``self.scope`` the value bound to each assignment target.

        After visiting the children: if the value is a call whose callee has
        a ``live_val`` annotation that is a class, the call is annotated with
        ``type`` and ``type_fqn``. Each element of a ``Tuple`` target is then
        bound (by ``id``) to a subscript of the value at its position. Every
        other target is bound by its ``id`` to the value itself.

        NOTE(review): a target without ``id`` (e.g. an attribute or subscript)
        would raise ``AttributeError`` here; confirm callers never pass one.

        Returns:
          ``node``.
        """
        self.generic_visit(node)
        value = node.value
        if isinstance(value, gast.Call) and anno.hasanno(value.func,
                                                         'live_val'):
            callee = anno.getanno(value.func, 'live_val')
            if tf_inspect.isclass(callee):
                # This is then a constructor.
                anno.setanno(value, 'type', callee)
                anno.setanno(value, 'type_fqn', anno.getanno(value.func, 'fqn'))
                # TODO (mdan): Raise an error if constructor has side effects.
                # We can have a whitelist of no-side-effects constructors.
                # We can also step inside the constructor and further analyze.

        for target in node.targets:
            if not isinstance(target, gast.Tuple):
                self.scope.setval(target.id, value)
                continue
            for position, element in enumerate(target.elts):
                self.scope.setval(
                    element.id,
                    gast.Subscript(value, gast.Index(position),
                                   ctx=gast.Store()))

        return node
Exemple #20
0
    def visit_Call(self, node):
        """
        Replace function call by inlined function's body.

        We can inline if it aliases on only one function.

        The call is inlined when ``self.aliases[node.func]`` holds exactly one
        ``FunctionDef`` whose name is in ``self.inlinable``. A deep copy of
        that definition is used. Each of its parameters is bound, through an
        assignment appended to ``self.defs``, to a fresh variable
        ``__pythran_inline<func><param><call_count>``. That variable holds the
        matching call argument or, for trailing missing arguments, the
        matching default. The first statement of the copied body, rewritten
        by ``Inliner`` to use those variables, replaces the call.
        ``self.update`` is set and ``self.call_count`` is incremented.

        Returns:
          The inlined node, or ``node`` unchanged when the call is not
          inlinable.
        """
        func_aliases = self.aliases[node.func]
        if len(func_aliases) != 1:
            return node
        function_def = next(iter(func_aliases))
        if not (isinstance(function_def, ast.FunctionDef)
                and function_def.name in self.inlinable):
            return node

        self.update = True
        to_inline = copy.deepcopy(self.inlinable[function_def.name])
        params = to_inline.args.args
        # Build a new list: extending ``node.args`` in place would mutate the
        # call node. Only the defaults of the missing trailing parameters are
        # needed.
        missing = len(params) - len(node.args)
        defaults = to_inline.args.defaults[-missing:] if missing > 0 else []
        values = list(node.args) + list(defaults)

        arg_to_value = dict()
        for arg_fun, arg_call in zip(params, values):
            v_name = "__pythran_inline{}{}{}".format(
                function_def.name, arg_fun.id, self.call_count)
            new_var = ast.Name(id=v_name,
                               ctx=ast.Store(),
                               annotation=None,
                               type_comment=None)
            self.defs.append(
                ast.Assign(targets=[new_var], value=arg_call))
            arg_to_value[arg_fun.id] = ast.Name(id=v_name,
                                                ctx=ast.Load(),
                                                annotation=None,
                                                type_comment=None)
        self.call_count += 1
        return Inliner(arg_to_value).visit(to_inline.body[0])
Exemple #21
0
 def visit_loop(self, node, update_mask=None):
     """Route loop-carried stores through ``_update`` and synchronize them.

     A name is treated as an LCD (presumably "loop-carried dependency") when
     a top-level ``Assign`` in ``node.body`` stores to it after that statement
     or an earlier top-level one has loaded it. The rewrite has two parts:

     - each such ``name = expr`` becomes
       ``name = name._update(expr, update_mask)``;
     - ``name = name._synchronize()`` is appended to ``node.body`` for each
       LCD, in the order the LCDs were first stored.

     Attributes are fused with ``FuseAttributes`` for the analysis and split
     again with ``SplitAttributes`` afterwards.

     Args:
       node: a node with a ``body`` list.
       update_mask: expression node passed as the second argument of every
         ``_update`` call. ``None`` (the default) means a fresh
         ``gast.NameConstant(value=None)`` created for this call.

     Returns:
       The transformed node.

     Raises:
       NotImplementedError: if a top-level assignment has several targets,
         or an LCD is stored to by two top-level assignments.
     """
     # Build the default per call: a node created once in the signature
     # would be shared by every tree this method rewrites.
     if update_mask is None:
         update_mask = gast.NameConstant(value=None)
     node = FuseAttributes().visit(node)
     loaded = set()
     # An insertion-ordered dict rather than a set: iterating a set of str
     # depends on the hash seed, which would make the order of the emitted
     # statements differ from run to run.
     stores = {}
     for child in node.body:
         # Loads are collected before child.value is rewritten below.
         loaded.update(
             n.id for n in gast.walk(child)
             if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load))
         if isinstance(child, gast.Assign):
             if len(child.targets) > 1:
                 raise NotImplementedError("cannot process LCD that is "
                                           "part of multiple assignment")
             name = child.targets[0].id
             if name in loaded:
                 if name in stores:
                     raise NotImplementedError("cannot process LCD "
                                               "stored to twice")
                 # $var = $expr -> $var = $var._update($expr)
                 child.value = gast.Call(
                     gast.Attribute(gast.Name(name, gast.Load(), None),
                                    gast.Name('_update', gast.Load(), None),
                                    None), [child.value, update_mask], [])
                 stores[name] = None
     node = SplitAttributes().visit(node)
     node.body.extend(
         gast.Assign(
             [gast.Name(name, gast.Store(), None)],
             gast.Call(
                 gast.Attribute(
                     gast.Name(name, gast.Load(), None),
                     gast.Name('_synchronize', gast.Load(), None), None),
                 [], []))
         for name in stores)
     return node
Exemple #22
0
def create_while_node(condition_name, body_name, loop_var_names):
    """Build ``(v1, v2, ...) = fluid.layers.while_loop(cond, body, [v1, v2, ...])``.

    Args:
        condition_name: name of the loop-condition function passed as the
            first argument of ``while_loop``.
        body_name: name of the loop-body function passed as the second
            argument.
        loop_var_names: names of the loop variables; they are passed as a
            list literal and re-bound from the call's result.

    Returns:
        A ``gast.Assign`` whose target is a tuple of the loop variables and
        whose value is the ``fluid.layers.while_loop`` call.
    """

    def _name(name_id, ctx):
        return gast.Name(id=name_id, ctx=ctx, annotation=None,
                         type_comment=None)

    # Arguments are read, so they get Load context.  Each occurrence of a
    # loop variable gets its own node: the same AST object must not appear
    # both in the call's argument list and in the assignment target.
    while_args = [
        _name(condition_name, gast.Load()),
        _name(body_name, gast.Load()),
        gast.List(elts=[_name(var_name, gast.Load())
                        for var_name in loop_var_names],
                  ctx=gast.Load()),
    ]

    while_func_id = gast.parse('fluid.layers.while_loop').body[0].value
    while_node = gast.Call(func=while_func_id, args=while_args, keywords=[])

    assign_targets = [_name(var_name, gast.Store())
                      for var_name in loop_var_names]
    assign_node = gast.Assign(
        targets=[gast.Tuple(elts=assign_targets, ctx=gast.Store())],
        value=while_node)
    return assign_node
Exemple #23
0
    def visit_For(self, node):
        """Normalize a ``for`` loop whose target is a tuple or a list.

        ``self.traverse_tuples`` fills an ordered mapping whose keys are
        either a name string or a target node, and whose values are index
        paths into the loop target.  When that mapping is non-empty:

        - the loop target is replaced by a fresh name from
          ``self.get_new_id()``;
        - one assignment per entry is inserted at the front of the body,
          subscripting the fresh name along the entry's index path.

        The loop is then visited recursively and returned.
        """
        target = node.target
        if isinstance(target, (ast.Tuple, ast.List)):
            renamings = OrderedDict()
            self.traverse_tuples(target, (), renamings)
            if renamings:
                fresh = self.get_new_id()
                node.target = ast.Name(fresh, target.ctx, None)

                def subscript(base, index):
                    return ast.Subscript(base,
                                         ast.Index(ast.Num(index)),
                                         ast.Load())

                for renamed, path in renamings.items():
                    value = reduce(subscript, path,
                                   ast.Name(fresh, ast.Load(), None))
                    if isinstance(renamed, str):
                        lhs = ast.Name(renamed, ast.Store(), None)
                    else:
                        lhs = renamed
                    node.body.insert(0, ast.Assign([lhs], value))

        self.generic_visit(node)
        return node
    def visit_FunctionDef(self, node):
        """Hoist ``node`` to module level and replace it with a partial.

        Steps:

        - Ensure a mangled ``functools`` import is declared at the top of the
          module.
        - Append the function to the module body under a fresh name
          ``pythran_<name><seed>`` that is not yet in ``self.identifiers``.
        - Prepend the identifiers found by ``ImportedIds`` to its parameters,
          and rewrite calls to the old name inside the function so that they
          pass those identifiers too.
        - Return an assignment that binds the old name to
          ``functools.partial(<new name>, *captured identifiers)``.
        """
        self.update = True
        functools_module = MODULES['functools']
        if functools_module not in self.global_declarations.values():
            self.ctx.module.body.insert(
                0, ast.Import([ast.alias('functools', mangle('functools'))]))
            self.global_declarations[mangle('functools')] = functools_module

        self.ctx.module.body.append(node)

        former_name = node.name
        template = "pythran_{}{}"
        seed = 0
        while template.format(former_name, seed) in self.identifiers:
            seed += 1
        new_name = template.format(former_name, seed)
        self.identifiers.add(new_name)

        # Gathered before the parameter list is extended below.
        captured = sorted(self.gather(ImportedIds, node))

        def load_captured():
            # Fresh nodes on every call: AST nodes must not be shared.
            return [ast.Name(iin, ast.Load(), None, None) for iin in captured]

        bound_args = load_captured()
        node.args.args = (
            [ast.Name(iin, ast.Param(), None, None) for iin in captured]
            + node.args.args)

        metadata.add(node, metadata.Local())

        class Renamer(ast.NodeTransformer):
            def visit_Call(self, call):
                self.generic_visit(call)
                is_self_call = (isinstance(call.func, ast.Name)
                                and call.func.id == former_name)
                if is_self_call:
                    call.func.id = new_name
                    call.args = load_captured() + call.args
                return call

        Renamer().visit(node)

        node.name = new_name
        self.global_declarations[new_name] = node
        proxy_call = ast.Name(new_name, ast.Load(), None, None)

        partial_attr = ast.Attribute(
            ast.Name(mangle('functools'), ast.Load(), None, None),
            "partial",
            ast.Load())
        new_node = ast.Assign(
            [ast.Name(former_name, ast.Store(), None, None)],
            ast.Call(partial_attr, [proxy_call] + bound_args, []))

        self.generic_visit(node)
        return new_node
Exemple #25
0
 def visit_Assign(self, node):
     """Normalize an assignment so that its single target is simple.

     - Subscript/attribute target: the value is trivialized and the target
       visited.
     - Tuple target: the value is assigned to a fresh name from
       ``self.namer``, and one ``elt = fresh[i]`` statement is appended per
       tuple element.
     - Any other non-``Name`` target raises ``ValueError``.

     The node is then visited generically and returned.
     """
     self.src = quoting.unquote(node)
     self.mark(node)
     self.trivializing = True
     self.namer.target = node.targets[0]

     if isinstance(node.targets[0], gast.Tuple):
         node.value = self.visit(node.value)
         tmp_name = self.namer.name(node.targets[0])
         replacement = gast.Name(id=tmp_name, ctx=gast.Store(),
                                 annotation=None)
         for index, element in enumerate(node.targets[0].elts):
             # element = tmp_name[index]
             unpack = gast.Assign(
                 targets=[element],
                 value=gast.Subscript(
                     value=gast.Name(id=tmp_name, ctx=gast.Load(),
                                     annotation=None),
                     slice=gast.Index(value=gast.Num(n=index)),
                     ctx=gast.Load()))
             self.mark(unpack)
             self.append(unpack)
         node.targets[0] = replacement
     elif isinstance(node.targets[0], (gast.Subscript, gast.Attribute)):
         node.value = self.trivialize(node.value)
         node.targets[0] = self.visit(node.targets[0])
     elif not isinstance(node.targets[0], gast.Name):
         raise ValueError

     node = self.generic_visit(node)
     self.namer.target = None
     self.trivializing = False
     return node
    def make_control_flow_handlers(self, cont_n, status_n, expected_return,
                                   has_cont, has_break):
        '''
        Build the statements that run after a static_if.

        If ``expected_return`` is non-empty, the result held in ``cont_n`` is
        unpacked into the ``expected_return`` targets.  When ``has_cont`` /
        ``has_break`` is set, the statement list is guarded by a test on
        ``status_n``: equal to LOOP_CONT gives unpack + ``continue``, and
        equal to LOOP_BREAK gives unpack + ``break``.  The break check is the
        outermost.  Returns the resulting list of statements.
        '''
        if expected_return:
            unpack = [ast.Assign(
                [ast.Tuple(expected_return, ast.Store())],
                ast.Name(cont_n, ast.Load(), None, None), None)]
        else:
            unpack = []

        def guard(status_value, jump, fallback):
            # if status_n == status_value: <copy of unpack>; jump
            # else: fallback
            test = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(status_value, None)])
            return [ast.If(test, deepcopy(unpack) + [jump], fallback)]

        handlers = unpack
        if has_cont:
            handlers = guard(LOOP_CONT, ast.Continue(), handlers)
        if has_break:
            handlers = guard(LOOP_BREAK, ast.Break(), handlers)
        return handlers
Exemple #27
0
 def _build_enum_increase_node(self):
     """Return an AST node for ``<self.enum_idx_name> += 1``."""
     counter = gast.Name(id=self.enum_idx_name,
                         ctx=gast.Store(),
                         annotation=None,
                         type_comment=None)
     one = gast.Constant(value=1, kind=None)
     return gast.AugAssign(target=counter, op=gast.Add(), value=one)
Exemple #28
0
def create_assign_node(name, node):
    """
    Build the assignment ``name = node``.

    Args:
        name: identifier passed to `generate_name_node` to build the target.
        node: AST expression used as the assigned value.

    Returns:
        A ``(target, assign_node)`` pair: the Store-context target built by
        `generate_name_node`, and the `gast.Assign` that uses it.
    """
    target = generate_name_node(name, ctx=gast.Store())
    return target, gast.Assign(targets=[target], value=node)
Exemple #29
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Specialization of `convert_entity_to_ast` for callable functions.

    Args:
        f: the function or lambda to convert.
        program_ctx: program context; its ``autograph_module`` is added to
            the function's namespace.
        do_rename: when True, a non-lambda function gets a name from
            ``namer.function_name``; otherwise it keeps ``f.__name__``.

    Returns:
        A tuple ``((node,), new_name, entity_info)``.  ``node`` is either the
        converted function definition, or, for lambdas, an assignment of the
        converted lambda to ``new_name``.

    Raises:
        ValueError: if the source line of a lambda does not contain exactly
            one matching lambda definition.
    """
    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # Parsed AST should contain future imports and one function def node.

    # inspect.getsource is inexact for lambdas: it returns the whole
    # containing line, which may hold several candidates, e.g.
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(node, f)
        if len(candidates) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        (node,) = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    if isinstance(node, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
    else:
        new_name = (namer.function_name(f.__name__) if do_rename
                    else f.__name__)

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = converter.EntityContext(namer, entity_info, program_ctx,
                                      new_name)
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        # Lambdas have no name of their own; bind the result to new_name.
        target = gast.Name(new_name,
                           ctx=gast.Store(),
                           annotation=None,
                           type_comment=None)
        node = gast.Assign(targets=[target], value=node)
    elif do_rename:
        node.name = new_name
    else:
        assert node.name == new_name

    return (node,), new_name, entity_info
Exemple #30
0
 def _process_tuple_assignment(self, source, t):
     """Record, for each name in tuple target ``t``, the part of ``source`` it gets.

     For ``a, (b, c) = x`` this records ``a -> x[0]``, ``b -> x[1][0]`` and
     ``c -> x[1][1]`` via ``self.scope.setval``.  Each entry is keyed by the
     element's ``anno.Basic.QN`` annotation.

     Args:
       source: AST expression for the value being destructured.
       t: a ``gast.Tuple`` target; nested tuples are handled recursively.
     """
     for i, e in enumerate(t.elts):
         # NOTE(review): Index wraps a raw int and the ctx is Store, exactly
         # as before; consumers of scope values may rely on this shape.
         item = gast.Subscript(source, gast.Index(i), ctx=gast.Store())
         if isinstance(e, gast.Tuple):
             # A nested target destructures source[i], not source itself.
             self._process_tuple_assignment(item, e)
         else:
             self.scope.setval(anno.getanno(e, anno.Basic.QN), item)