Exemplo n.º 1
0
    def _update_name_to_var_shape(self, node):
        """Record shape aliases introduced by an assignment.

        Updates ``self.name_to_var_shape`` so that the target(s) of ``node``
        map to the shape expression they now hold.

        Recognised right-hand sides:
          * a name already present in ``self.name_to_var_shape``;
          * an attribute accepted by ``self.is_var_shape`` (e.g. ``x.shape``);
          * for a single (non-tuple) target only, a subscript of such an
            attribute (e.g. ``x.shape[0]``).

        For a tuple target (``a, b = ...``) element ``idx`` is mapped to a
        ``<shape>[idx]`` subscript node.

        Args:
            node: a ``gast.Assign`` node; only ``node.targets[0]`` is used.

        Returns:
            True if at least one entry of ``self.name_to_var_shape`` was
            written, False otherwise.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        value_node = node.value

        if isinstance(target_node, gast.Tuple):
            # Resolve the source shape once, before any target is written.
            # Python evaluates the right-hand side before binding targets, so
            # in ``x, y = x`` every element must index the ORIGINAL shape of
            # ``x``. Re-reading the mapping inside the loop would see the entry
            # rebound by the first iteration and produce ``x_shape[0][1]``.
            source_shape_node = None
            if (isinstance(value_node, gast.Name)
                    and value_node.id in self.name_to_var_shape):
                source_shape_node = self.name_to_var_shape[value_node.id]

            has_updated = False
            for idx, element in enumerate(target_node.elts):
                target_id = ast_to_source_code(element).strip()

                if source_shape_node is not None:
                    index_value_node = gast.Constant(value=idx, kind=None)
                    slice_index_node = gast.Index(value=index_value_node)
                    sub_node = gast.Subscript(value=source_shape_node,
                                              slice=slice_index_node,
                                              ctx=gast.Load())
                    self.name_to_var_shape[target_id] = sub_node
                    has_updated = True
                if isinstance(value_node, gast.Attribute):
                    if self.is_var_shape(value_node):  # eg: x.shape
                        index_value_node = gast.Constant(value=idx, kind=None)
                        slice_index_node = gast.Index(value=index_value_node)
                        sub_node = gast.Subscript(value=value_node,
                                                  slice=slice_index_node,
                                                  ctx=gast.Load())
                        self.name_to_var_shape[target_id] = sub_node
                        has_updated = True

            return has_updated
        else:
            target_id = ast_to_source_code(target_node).strip()

            if isinstance(value_node, gast.Name):
                if value_node.id in self.name_to_var_shape:
                    self.name_to_var_shape[target_id] = self.name_to_var_shape[
                        value_node.id]
                    return True
            if isinstance(value_node, gast.Attribute):
                if self.is_var_shape(value_node):  # eg: x.shape
                    self.name_to_var_shape[target_id] = value_node
                    return True
            if isinstance(value_node, gast.Subscript):
                if isinstance(value_node.value, gast.Attribute):
                    if self.is_var_shape(value_node.value):  # eg: x.shape[0]
                        self.name_to_var_shape[target_id] = value_node
                        return True
        return False
Exemplo n.º 2
0
 def visit_Assign(self, node):
     """Visit an assignment while trivializing its right-hand side.

     Behaviour depends on the first target:
       * Subscript / Attribute: the value is passed through
         ``self.trivialize`` and the target itself is visited.
       * Tuple: the value is visited, the tuple target is replaced by one
         fresh name obtained from ``self.namer.name(...)``, and for every
         element a statement ``elt = <fresh>[i]`` is built, marked with
         ``self.mark`` and handed to ``self.append``.
       * Name: no special handling before ``generic_visit``.
       * anything else: ``ValueError`` is raised.

     ``self.trivializing`` and ``self.namer.target`` are set for the duration
     of the visit and reset at the end. The reset is not in a ``finally``, so
     both stay set if an exception escapes.

     Args:
       node: the ``gast.Assign`` node being visited.

     Returns:
       The node returned by ``self.generic_visit``.
     """
     # self.src holds quoting.unquote(node) -- presumably the statement's
     # source text; confirm against quoting.unquote.
     self.src = quoting.unquote(node)
     self.mark(node)
     self.trivializing = True
     # Expose the current target to the namer while fresh names are created.
     self.namer.target = node.targets[0]
     if isinstance(node.targets[0], (gast.Subscript, gast.Attribute)):
         node.value = self.trivialize(node.value)
         node.targets[0] = self.visit(node.targets[0])
     elif isinstance(node.targets[0], gast.Tuple):
         node.value = self.visit(node.value)
         # Bind the whole value to a single fresh name, then unpack it with
         # one ``elt = name[i]`` statement per tuple element.
         name = self.namer.name(node.targets[0])
         target = gast.Name(id=name, ctx=gast.Store(), annotation=None)
         for i, elt in enumerate(node.targets[0].elts):
             stmt = gast.Assign(targets=[elt],
                                value=gast.Subscript(
                                    value=gast.Name(id=name,
                                                    ctx=gast.Load(),
                                                    annotation=None),
                                    slice=gast.Index(value=gast.Num(n=i)),
                                    ctx=gast.Load()))
             self.mark(stmt)
             self.append(stmt)
         node.targets[0] = target
     elif not isinstance(node.targets[0], gast.Name):
         raise ValueError
     # NOTE(review): in the Tuple branch node.value was already visited above
     # and generic_visit visits it again -- confirm this double visit is
     # intended / idempotent.
     node = self.generic_visit(node)
     self.namer.target = None
     self.trivializing = False
     return node
Exemplo n.º 3
0
 def visit_Assign(self, node):
     """Replace tuple/list assignment targets by per-element assignments.

     For every Tuple/List target that ``self.traverse_tuples`` reports
     renamings for, the target is replaced by a single store node, and one
     ``rename = <load>[i0][i1]...`` assignment is emitted per renaming.

     The right-hand side is reused directly when it is a Name or Attribute.
     Otherwise ``node`` itself is kept as the first statement, assigning the
     value to a fresh name from ``self.get_new_id()``.

     Args:
       node: the assignment node being visited.

     Returns:
       The list of statements that replaces ``node``. ``node`` itself is
       returned when that list would be empty.
     """
     self.generic_visit(node)
     # if the rhs is an identifier, we don't need to duplicate it
     # otherwise, better duplicate it...
     no_tmp = isinstance(node.value, (ast.Name, ast.Attribute))
     # When a temporary is needed, the original (retargeted) assignment is
     # kept as the first statement; otherwise it is dropped.
     extra_assign = [] if no_tmp else [node]
     for i, t in enumerate(node.targets):
         if isinstance(t, ast.Tuple) or isinstance(t, ast.List):
             # renamings: rename (a str name or a target AST node) -> the
             # sequence of constant indices used to reach it (see reduce below).
             renamings = OrderedDict()
             self.traverse_tuples(t, (), renamings)
             if renamings:
                 if no_tmp:
                     gstore = deepcopy(node.value)
                 else:
                     gstore = ast.Name(self.get_new_id(), ast.Store(), None,
                                       None)
                 gload = deepcopy(gstore)
                 gload.ctx = ast.Load()
                 node.targets[i] = gstore
                 for rename, state in renamings.items():
                     # Build gload[state[0]][state[1]]...
                     # NOTE(review): the same gload object is the base of every
                     # chain built here; confirm later passes never mutate it.
                     nnode = reduce(
                         lambda x, y: ast.Subscript(x, ast.Constant(
                             y, None), ast.Load()), state, gload)
                     if isinstance(rename, str):
                         extra_assign.append(
                             ast.Assign([
                                 ast.Name(rename, ast.Store(), None, None)
                             ], nnode, None))
                     else:
                         extra_assign.append(
                             ast.Assign([rename], nnode, None))
     return extra_assign or node
Exemplo n.º 4
0
 def ast(self):
     """Build a gast node for this qualified name.

     Returns a ``Subscript`` when ``self.has_subscript()``, an ``Attribute``
     when ``self.has_attr()``, and otherwise a ``Name`` built from
     ``self.qn[0]``. Parent parts are converted recursively through
     ``self.parent.ast()``.
     """
     # The caller must adjust the context appropriately: every node built
     # here is created with ctx=None.
     if self.has_subscript():
         # NOTE(review): the slice is a plain Python str, not an AST node
         # (e.g. wrapped in gast.Index); such a tree cannot be compiled or
         # unparsed as-is -- confirm callers only inspect it.
         return gast.Subscript(self.parent.ast(), str(self.qn[-1]), None)
     if self.has_attr():
         return gast.Attribute(self.parent.ast(), self.qn[-1], None)
     return gast.Name(self.qn[0], None, None)
Exemplo n.º 5
0
    def ast(self):
        """Return a fresh gast node equivalent to this qualified name.

        Subscript and attribute forms are rebuilt on top of the parent's node
        (``self.parent.ast()``). A simple name becomes ``gast.Name`` when its
        base is a ``str``. It becomes ``gast.Constant`` when the base is a
        ``StringLiteral`` or ``NumberLiteral``.

        Every Name/Attribute/Subscript produced carries the
        ``CallerMustSetThis`` placeholder as ``ctx``; callers replace it.
        """
        if self.has_subscript():
            parent_node = self.parent.ast()
            index_node = gast.Index(self.qn[-1].ast())
            return gast.Subscript(value=parent_node,
                                  slice=index_node,
                                  ctx=CallerMustSetThis)

        if self.has_attr():
            parent_node = self.parent.ast()
            return gast.Attribute(value=parent_node,
                                  attr=self.qn[-1],
                                  ctx=CallerMustSetThis)

        root = self.qn[0]
        if isinstance(root, str):
            return gast.Name(root,
                             ctx=CallerMustSetThis,
                             annotation=None,
                             type_comment=None)
        if isinstance(root, (StringLiteral, NumberLiteral)):
            return gast.Constant(root.value, kind=None)
        assert False, ('the constructor should prevent types other than '
                       'str, StringLiteral and NumberLiteral')
Exemplo n.º 6
0
    def visit_Assign(self, node):
        """Annotate constructor calls and record assigned values in the scope.

        After visiting the children, if the value is a call whose callee has a
        ``live_val`` annotation that is a class, the call node is annotated
        with ``type`` and ``type_fqn``. Every target is then registered in
        ``self.scope`` by its ``id``. Plain targets map to the value. Element
        ``i`` of a tuple target maps to a subscript of the value.

        Args:
          node: the ``gast.Assign`` node being visited.

        Returns:
          ``node`` (visited in place).
        """
        self.generic_visit(node)
        if isinstance(node.value, gast.Call):
            target = node.value.func
            if anno.hasanno(target, 'live_val'):
                target_obj = anno.getanno(target, 'live_val')
                if tf_inspect.isclass(target_obj):
                    # This is then a constructor.
                    anno.setanno(node.value, 'type', target_obj)
                    anno.setanno(node.value, 'type_fqn',
                                 anno.getanno(target, 'fqn'))
                    # TODO(mdan): Raise an error if constructor has side effects.
                    # We can have a whitelist of no-side-effects constructors.
                    # We can also step inside the constructor and further analyze.

        for n in node.targets:
            if isinstance(n, gast.Tuple):
                for i, e in enumerate(n.elts):
                    # NOTE(review): ``e.id`` assumes every element is a Name;
                    # nested tuples / attributes / starred elements raise
                    # AttributeError. The subscript also wraps a raw int in
                    # gast.Index and uses a Store ctx for what is read as a
                    # value -- confirm consumers expect this.
                    self.scope.setval(
                        e.id,
                        gast.Subscript(node.value,
                                       gast.Index(i),
                                       ctx=gast.Store()))
            else:
                # NOTE(review): ``n.id`` assumes a Name target; attribute or
                # subscript targets raise AttributeError here.
                self.scope.setval(n.id, node.value)

        return node
Exemplo n.º 7
0
 def visit_Attribute(self, node):
     """Rewrite attribute accesses that must become ``builtins.getattr`` calls.

     After visiting the children:
       * attributes named in ``methods`` are returned unchanged;
       * ``<imported module alias>.attr`` is validated against ``MODULES``
         (raising ``PythranSyntaxError`` if unsupported) and the alias is
         replaced by the real module id;
       * attributes not listed in ``attributes`` are returned unchanged;
       * any other attribute becomes ``builtins.getattr(obj, 'attr')``. In a
         Store context it becomes ``builtins.getattr(obj, 'attr')[:]``, which
         is only allowed for ``real``/``imag``.

     ``self.update`` is set whenever the tree is modified.
     """
     node = self.generic_visit(node)

     # Method names are not getattr candidates.
     if node.attr in methods:
         return node

     base = node.value
     if isinstance(base, ast.Name) and base.id in self.imports:
         # Attribute of an imported module: check it and undo aliasing.
         module_id = self.imports[base.id]
         supported = MODULES[self.renamer(module_id, MODULES)[1]]
         if node.attr not in supported:
             msg = ("`" + node.attr + "' is not a member of " + module_id +
                    " or Pythran does not support it")
             raise PythranSyntaxError(msg, node)
         base.id = module_id  # patch module aliasing
         self.update = True
         return node

     # Attributes Pythran does not model are left alone.
     if node.attr not in attributes:
         return node

     self.update = True
     getattr_func = ast.Attribute(ast.Name('builtins', ast.Load(), None, None),
                                  'getattr', ast.Load())
     call = ast.Call(getattr_func,
                     [node.value, ast.Constant(node.attr, None)], [])
     if not isinstance(node.ctx, ast.Store):
         return call
     # A call cannot be assigned to; slicing it yields a valid target. This
     # only happens for the real/imag parts of an ndarray.
     assert node.attr in ('real', 'imag'), "only store to imag/real"
     return ast.Subscript(call, ast.Slice(None, None, None), node.ctx)
Exemplo n.º 8
0
    def _process_variable_assignment(self, source, targets):
        """Annotate constructor calls and record assigned values in the scope.

        When ``source`` is a call whose callee carries a ``live_val``
        annotation that is a class, ``source`` is annotated with
        ``is_constructor``, ``type`` and ``type_fqn``.

        Each target is then stored in ``self.scope`` under its
        ``anno.Basic.QN`` annotation. Names and attributes map to ``source``,
        and element ``i`` of a tuple target maps to a subscript of ``source``.

        Raises:
          ValueError: for any other kind of target.
        """
        if isinstance(source, gast.Call):
            callee = source.func
            callee_obj = (anno.getanno(callee, 'live_val')
                          if anno.hasanno(callee, 'live_val') else None)
            if callee_obj is not None and tf_inspect.isclass(callee_obj):
                anno.setanno(source, 'is_constructor', True)
                anno.setanno(source, 'type', callee_obj)
                anno.setanno(source, 'type_fqn', anno.getanno(callee, 'fqn'))
                # TODO(mdan): Raise an error if constructor has side effects.
                # We can have a whitelist of no-side-effects constructors.
                # We can also step inside the constructor and further analyze.

        for target in targets:
            if isinstance(target, gast.Tuple):
                for idx, element in enumerate(target.elts):
                    element_qn = anno.getanno(element, anno.Basic.QN)
                    element_value = gast.Subscript(source, gast.Index(idx),
                                                   ctx=gast.Store())
                    self.scope.setval(element_qn, element_value)
                continue
            if not isinstance(target, (gast.Name, gast.Attribute)):
                raise ValueError('Dont know how to handle assignment to %s' %
                                 target)
            self.scope.setval(anno.getanno(target, anno.Basic.QN), source)
Exemplo n.º 9
0
    def visit_Subscript(self, node):
        '''
        Compute the aliases of a subscript expression.

        Resulting node alias stores the subscript relationship if we don't know
        anything about the subscripted node.

        >>> from pythran import passmanager
        >>> pm = passmanager.PassManager('demo')
        >>> module = ast.parse('def foo(a): return a[0]')
        >>> result = pm.gather(Aliases, module)
        >>> Aliases.dump(result, filter=ast.Subscript)
        a[0] => ['a[0]']

        If we know something about the container, e.g. in case of a list, we
        can use this information to get more accurate information:

        >>> module = ast.parse('def foo(a, b, c): return [a, b][c]')
        >>> result = pm.gather(Aliases, module)
        >>> Aliases.dump(result, filter=ast.Subscript)
        [a, b][c] => ['a', 'b']

        Moreover, in case of a tuple indexed by a constant value, we can
        further refine the aliasing information:

        >>> fun = """
        ... def f(a, b): return a, b
        ... def foo(a, b): return f(a, b)[0]"""
        >>> module = ast.parse(fun)
        >>> result = pm.gather(Aliases, module)
        >>> Aliases.dump(result, filter=ast.Subscript)
        f(a, b)[0] => ['a']

        Nothing is done for slices, even if the indices are known :-/

        >>> module = ast.parse('def foo(a, b, c): return [a, b, c][1:]')
        >>> result = pm.gather(Aliases, module)
        >>> Aliases.dump(result, filter=ast.Subscript)
        [a, b, c][1:] => ['<unbound-value>']
        '''
        if isinstance(node.slice, ast.Index):
            aliases = set()
            # Visit the index for its own bookkeeping; its aliases are unused.
            self.visit(node.slice)
            value_aliases = self.visit(node.value)
            for alias in value_aliases:
                if isinstance(alias, ContainerOf):
                    # Slicing a container does not yield one of its elements.
                    if isinstance(node.slice.value, ast.Slice):
                        continue
                    # A constant index selects only the matching containee.
                    if isnum(node.slice.value):
                        if node.slice.value.value != alias.index:
                            continue
                    # FIXME: what if the index is a slice variable...
                    aliases.add(alias.containee)
                elif isinstance(getattr(alias, 'ctx', None), ast.Param):
                    # Subscript of a function parameter: keep it symbolic.
                    aliases.add(ast.Subscript(alias, node.slice, node.ctx))
            # Nothing learned: record None rather than an empty set.
            if not aliases:
                aliases = None
        else:
            # could be enhanced through better handling of containers
            aliases = None
            self.generic_visit(node)
        return self.add(node, aliases)
Exemplo n.º 10
0
def create_convert_shape_node(var_shape_node,
                              slice_node=None,
                              in_control_flow=False):
    """Translate a shape expression into a ``convert_var_shape`` call node.

    Args:
        var_shape_node: a ``gast.Attribute`` such as ``x.shape``, or a
            ``gast.Subscript`` of one such as ``x.shape[0]``.
        slice_node: optional ``gast.Subscript`` whose ``.slice`` selects from
            the shape. A numeric slice is passed as the extra ``idx``
            argument. Any other slice is applied to the resulting call, as in
            ``convert_var_shape(x)[slice]``.
        in_control_flow: value written into the generated call's
            ``in_control_flow=`` keyword.

    Returns:
        A gast expression node.
    """
    assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))

    if isinstance(var_shape_node, gast.Subscript):
        # ``x.shape[...]``: recurse on a copy, using the copy itself as the
        # slice carrier for its inner attribute.
        copied = copy.deepcopy(var_shape_node)
        return create_convert_shape_node(copied.value, copied,
                                         in_control_flow)

    # Here var_shape_node is an Attribute like ``x.shape``.
    call_args = [ast_to_source_code(var_shape_node.value).strip()]
    if slice_node is not None and slice_is_num(slice_node):
        # Simple numeric index (e.g. 1, -2): pass it as the ``idx`` argument.
        call_args.append(ast_to_source_code(slice_node.slice).strip())

    call_source = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
        ",".join(call_args), in_control_flow)
    api_shape_node = gast.parse(call_source).body[0].value

    if slice_node is None or slice_is_num(slice_node):
        return api_shape_node
    # Bounded slice (e.g. 2:-1): subscript the call result.
    return gast.Subscript(value=api_shape_node,
                          slice=slice_node.slice,
                          ctx=gast.Load())
Exemplo n.º 11
0
            def parametrize(exp):
                """Return a function mapping call arguments to aliases of ``exp``.

                ``exp`` is one alias recorded for the function ``node``. The
                returned callable receives the argument nodes of a call site
                and returns the set of alias nodes ``exp`` denotes there.
                Formal parameters are replaced by the matching argument, or by
                their default value when the call does not supply them.
                Anything unrecognised maps to the unbound value set.
                """
                # constant(?) or global -> no change
                if isinstance(exp, (ast.Index, Intrinsic, ast.FunctionDef)):
                    return lambda _: {exp}
                elif isinstance(exp, ContainerOf):
                    pcontainee = parametrize(exp.containee)
                    index = exp.index
                    return lambda args: {
                        ContainerOf(pc, index)
                        for pc in pcontainee(args)
                    }
                elif isinstance(exp, ast.Name):
                    try:
                        w = node.args.args.index(exp)
                    except ValueError:
                        # Not a formal parameter of ``node``.
                        return lambda _: self.get_unbound_value_set()

                    def return_alias(args):
                        if w < len(args):
                            return {args[w]}
                        # ``defaults`` holds the defaults of the LAST
                        # len(defaults) formal parameters, so parameter ``w``
                        # owns ``defaults[w - first_default]``. Offsetting by
                        # len(args) instead selects the wrong default whenever
                        # some, but not all, defaulted parameters are passed.
                        first_default = (len(node.args.args) -
                                         len(node.args.defaults))
                        if w >= first_default:
                            return {node.args.defaults[w - first_default]}
                        # Required parameter missing from the call: unknown.
                        return self.get_unbound_value_set()
                    return return_alias
                elif isinstance(exp, ast.Subscript):
                    values = parametrize(exp.value)
                    slices = parametrize(exp.slice)
                    return lambda args: {
                        ast.Subscript(value, slc, ast.Load())
                        for value in values(args)
                        for slc in slices(args)}
                else:
                    return lambda _: self.get_unbound_value_set()
Exemplo n.º 12
0
    def visit_For(self, node):
        """Replace a tuple/list loop target by a single fresh name.

        When ``self.traverse_tuples`` reports renamings for the target, the
        target becomes a fresh name ``t`` (from ``self.get_new_id()``). One
        assignment ``rename = t[i0][i1]...`` per renaming is prepended to the
        loop body.

        The assignments are prepended in the order ``renamings`` yields them.
        Previously each was inserted at index 0, which reversed that order.
        A target depending on an earlier one (``for i, a[i] in it``) would
        then have read ``i`` before it was assigned, unlike Python's own
        left-to-right unpacking.

        Args:
          node: the ``For`` node being visited.

        Returns:
          ``node``, modified in place and then generically visited.
        """
        target = node.target
        if isinstance(target, (ast.Tuple, ast.List)):
            renamings = OrderedDict()
            self.traverse_tuples(target, (), renamings)
            if renamings:
                gtarget = self.get_new_id()
                node.target = ast.Name(gtarget, node.target.ctx, None)
                unpacking = []
                for rename, state in renamings.items():
                    # Build gtarget[state[0]][state[1]]...
                    nnode = reduce(
                        lambda x, y: ast.Subscript(
                            x,
                            ast.Index(ast.Num(y)),
                            ast.Load()),
                        state,
                        ast.Name(gtarget, ast.Load(), None))
                    if isinstance(rename, str):
                        unpacking.append(
                            ast.Assign(
                                [ast.Name(rename, ast.Store(), None)],
                                nnode))
                    else:
                        unpacking.append(ast.Assign([rename], nnode))
                # Prepend all unpacking statements, preserving their order.
                node.body[0:0] = unpacking

        self.generic_visit(node)
        return node
Exemplo n.º 13
0
 def visit_Assign(self, node):
     """Split tuple/list assignment targets into per-element assignments.

     Each Tuple/List target with renamings (from ``self.traverse_tuples``) is
     replaced by a single name. One ``rename = name[i0][i1]...`` assignment is
     emitted per renaming. A right-hand side that is a plain Name is indexed
     directly. Otherwise a fresh name from ``self.get_new_id()`` is used and
     ``node`` is kept as the first statement so that it binds that name.

     Returns:
       The replacement statement list, or ``node`` if that list is empty.
     """
     self.generic_visit(node)
     value_is_name = isinstance(node.value, ast.Name)
     new_stmts = [] if value_is_name else [node]

     def build_access(base_id, path):
         # base_id[path[0]][path[1]]... in Load context.
         access = ast.Name(base_id, ast.Load(), None)
         for index in path:
             access = ast.Subscript(access, ast.Index(ast.Num(index)),
                                    ast.Load())
         return access

     for pos, tgt in enumerate(node.targets):
         if not isinstance(tgt, (ast.Tuple, ast.List)):
             continue
         renamings = OrderedDict()
         self.traverse_tuples(tgt, (), renamings)
         if not renamings:
             continue
         base_id = node.value.id if value_is_name else self.get_new_id()
         node.targets[pos] = ast.Name(base_id, node.targets[pos].ctx, None)
         for rename, path in renamings.items():
             access = build_access(base_id, path)
             if isinstance(rename, str):
                 lhs = ast.Name(rename, ast.Store(), None)
             else:
                 lhs = rename
             new_stmts.append(ast.Assign([lhs], access))
     return new_stmts or node
Exemplo n.º 14
0
 def _process_tuple_assignment(self, source, t):
     """Record the value bound to each element of the tuple target ``t``.

     Element ``i`` of ``t`` is bound to ``source[i]``. Nested tuple targets
     recurse on ``source[i]``, so ``a, (b, c) = x`` records ``a -> x[0]``,
     ``b -> x[1][0]`` and ``c -> x[1][1]``. Previously the recursion reused
     ``source`` itself, yielding ``b -> x[0]``.

     Args:
       source: AST node of the value being unpacked.
       t: a ``gast.Tuple`` target; its non-tuple elements are expected to
         carry an ``anno.Basic.QN`` annotation.
     """
     for i, e in enumerate(t.elts):
         # NOTE(review): raw int inside gast.Index and Store ctx are kept from
         # the original -- confirm consumers of self.scope expect this form.
         element_source = gast.Subscript(source, gast.Index(i),
                                         ctx=gast.Store())
         if isinstance(e, gast.Tuple):
             # The inner tuple unpacks ``source[i]``, not ``source``.
             self._process_tuple_assignment(element_source, e)
         else:
             self.scope.setval(anno.getanno(e, anno.Basic.QN), element_source)
Exemplo n.º 15
0
 def visit_Name(self, node):
     """Replace a renamed variable by its access path into the tuple.

     Names listed in ``self.renamings`` become
     ``<self.tuple_id>[i0][i1]...`` built from the recorded index sequence,
     with ``node``'s ctx. Any other name is returned unchanged.
     """
     if node.id not in self.renamings:
         return node
     path = self.renamings[node.id]
     access = ast.Name(self.tuple_id, ast.Load(), None, None)
     for index in path:
         access = ast.Subscript(access, ast.Constant(index, None),
                                ast.Load())
     access.ctx = node.ctx
     return access
Exemplo n.º 16
0
    def visit_Attribute(self, node):
        """Rewrite attribute accesses that Pythran models as function calls.

        After visiting the children:
          * a method name whose nearest non-Attribute ancestor is a Call is
            left to ``visit_Call``;
          * any other method access on an object (per ``self.baseobj``)
            becomes ``functools.partial(<method function>, obj)``, and the
            method's module is added to ``self.to_import``;
          * ``<imported module alias>.attr`` is validated against ``MODULES``
            (raising ``PythranSyntaxError`` if unsupported) and un-aliased;
          * attributes not listed in ``attributes`` are returned unchanged;
          * anything else becomes ``builtins.getattr(obj, 'attr')``, sliced
            with ``[:]`` in Store context (``real``/``imag`` only).

        ``self.update`` is set whenever the tree is modified.
        """
        node = self.generic_visit(node)

        if node.attr in methods:
            # Skip enclosing Attribute nodes; if the first other ancestor is
            # a Call, visit_Call takes care of this method.
            enclosing = next(
                (anc for anc in reversed(self.ancestors.get(node, ()))
                 if not isinstance(anc, ast.Attribute)),
                None)
            if isinstance(enclosing, ast.Call):
                return node

            # A bound method used without being called.
            obj = self.baseobj(node)
            if obj is None:
                return node
            self.update = True
            mod = methods[node.attr][0]
            self.to_import.add(mangle(mod[0]))
            func = self.attr_to_func(node)
            partial_attr = ast.Attribute(
                ast.Name(mangle('functools'), ast.Load(), None, None),
                "partial", ast.Load())
            return ast.Call(partial_attr, [func, obj], [])

        base = node.value
        if isinstance(base, ast.Name) and base.id in self.imports:
            module_id = self.imports[base.id]
            supported = MODULES[self.renamer(module_id, MODULES)[1]]
            if node.attr not in supported:
                msg = ("`" + node.attr + "' is not a member of " +
                       demangle(module_id) + " or Pythran does not support it")
                raise PythranSyntaxError(msg, node)
            base.id = module_id  # patch module aliasing
            self.update = True
            return node

        if node.attr not in attributes:
            return node

        self.update = True
        getattr_func = ast.Attribute(
            ast.Name('builtins', ast.Load(), None, None),
            'getattr', ast.Load())
        call = ast.Call(getattr_func,
                        [node.value, ast.Constant(node.attr, None)], [])
        if not isinstance(node.ctx, ast.Store):
            return call
        # A call is not a valid assignment target; slicing it is. This only
        # arises for the real/imag parts of an ndarray.
        assert node.attr in ('real', 'imag'), "only store to imag/real"
        return ast.Subscript(call, ast.Slice(None, None, None), node.ctx)
Exemplo n.º 17
0
 def _build_var_slice_node(self):
     """Return ``<self.iter_node>[<self.iter_idx_name>]`` as a Load subscript."""
     container = self.iter_node
     index_ref = gast.Name(id=self.iter_idx_name,
                           ctx=gast.Load(),
                           annotation=None,
                           type_comment=None)
     return gast.Subscript(value=container,
                           slice=gast.Index(value=index_ref),
                           ctx=gast.Load())
Exemplo n.º 18
0
 def visit_Subscript(self, node):
     """Rebuild a Subscript node with each field converted by ``self._visit``.

     Fields are converted in the order slice, value, ctx. The source location
     is copied from ``node`` with ``gast.copy_location``, after which the end
     line/column fields are cleared.
     """
     converted_slice = self._visit(node.slice)
     converted_value = self._visit(node.value)
     converted_ctx = self._visit(node.ctx)
     result = gast.Subscript(converted_value, converted_slice, converted_ctx)
     gast.copy_location(result, node)
     result.end_lineno = None
     result.end_col_offset = None
     return result
Exemplo n.º 19
0
 def _build_assign_var_slice_node(self):
     """Create ``<fresh> = <self.iter_node>[<self.iter_idx_name>]``.

     ``<fresh>`` comes from ``unique_name.generate(FOR_ITER_VAR_NAME_PREFIX)``.

     Returns:
       The ``(target_node, assign_node)`` pair from ``create_assign_node``.
     """
     container = self.iter_node
     index_ref = gast.Name(id=self.iter_idx_name,
                           ctx=gast.Load(),
                           annotation=None,
                           type_comment=None)
     indexed = gast.Subscript(value=container,
                              slice=gast.Index(value=index_ref),
                              ctx=gast.Load())
     fresh_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX)
     target_node, assign_node = create_assign_node(fresh_name, indexed)
     return target_node, assign_node
Exemplo n.º 20
0
 def visit_Assign(self, node):
     """Combine the assigned value into every assignment target.

     The value is visited first. Each target is then passed to
     ``self.combine(target, value, register=True, ...)``, with
     ``aliasing_type`` enabled only for plain names. Targets in
     ``self.curr_locals_declaration`` get their result wrapped by
     ``self.get_qualifier``.

     For a subscript target ``c[i] = v``, when
     ``self.visit_AssignedSubscript`` returns true, the same value is also
     combined into ``alias[i]`` for every strict alias of ``c``.

     Args:
       node: the assignment node being visited.
     """
     self.visit(node.value)
     for t in node.targets:
         # We don't support subscript aliasing
         self.combine(t, node.value, register=True,
                      aliasing_type=isinstance(t, ast.Name))
         if t in self.curr_locals_declaration:
             self.result[t] = self.get_qualifier(t)(self.result[t])
         if isinstance(t, ast.Subscript):
             if self.visit_AssignedSubscript(t):
                 for alias in self.strict_aliases[t.value]:
                     # The fake target must index the alias with the
                     # original slice (``t.slice``), not with the container
                     # (``t.value``) as before.
                     fake = ast.Subscript(alias, t.slice, ast.Store())
                     self.combine(fake, node.value, register=True)
Exemplo n.º 21
0
    def apply_to_single_assignments(self, targets, values, apply_fn):
        """Applies a function to each individual assignment.

        This function can process a possibly-unpacked (e.g. a, b = c, d)
        assignment. It tries to break down the unpacking if possible. In
        effect, it has the same effect as passing the assigned values in SSA
        form to apply_fn.

        Examples:

        The following will result in apply_fn(a, c), apply_fn(b, d):

            a, b = c, d

        The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

            a, b = c

        The following will result in apply_fn(a, (b, c)):

            a = b, c

        Nested tuple/list targets are broken down recursively. It uses the
        visitor pattern to allow subclasses to process single assignments
        individually.

        Args:
          targets: a list or tuple of AST nodes, or a single AST node. Should
            be used with the targets field of an ast.Assign node.
          values: an AST node.
          apply_fn: a callable invoked as apply_fn(target, value) for each
            single assignment; its return value is ignored.
        """
        if not isinstance(targets, (list, tuple)):
            targets = (targets, )
        for target in targets:
            if isinstance(target, (gast.Tuple, gast.List)):
                for i in range(len(target.elts)):
                    target_el = target.elts[i]
                    if isinstance(values, (gast.Tuple, gast.List)):
                        # NOTE(review): raises IndexError when the value
                        # sequence is shorter than the target (e.g. starred
                        # values) -- confirm callers never pass such input.
                        value_el = values.elts[i]
                    else:
                        # NOTE(review): a raw int is wrapped in gast.Index and
                        # the ctx is Store although the node is read as a
                        # value -- confirm apply_fn implementations expect this.
                        value_el = gast.Subscript(values,
                                                  gast.Index(i),
                                                  ctx=gast.Store())
                    self.apply_to_single_assignments(target_el, value_el,
                                                     apply_fn)
            else:
                # TODO(mdan): Look into allowing to rewrite the AST here.
                apply_fn(target, values)
Exemplo n.º 22
0
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
    """Build a ``choose_shape_attr_or_api(...)`` call expression node.

    The generated call receives ``attr_shape_name``, an
    ``eval_if_exist_else_none('<api_shape_name>', globals())`` expression,
    and, when ``slice_node`` holds a numeric slice, that slice's source text.
    A non-numeric slice is instead applied to the call result as a subscript.

    Args:
      attr_shape_name: source text inserted as the first call argument.
      api_shape_name: name embedded as a string literal in the
        ``eval_if_exist_else_none`` argument.
      slice_node: optional ``gast.Subscript`` whose ``.slice`` is used.

    Returns:
      A gast expression node.
    """
    call_args = [
        attr_shape_name,
        "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
            api_shape_name),
    ]
    if slice_node is not None and slice_is_num(slice_node):
        call_args.append(ast_to_source_code(slice_node.slice).strip())

    choose_source = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
        ",".join(call_args))
    choose_shape_node = gast.parse(choose_source).body[0].value

    if slice_node is None or slice_is_num(slice_node):
        return choose_shape_node
    return gast.Subscript(value=choose_shape_node,
                          slice=slice_node.slice,
                          ctx=gast.Load())
Exemplo n.º 23
0
 def _build_compare_node(self):
     """Return the node the loop index is compared against.

     When ``self.is_for_range_iter()`` is true, this is ``self.iter_args[0]``
     if ``self.args_length == 1`` and ``self.iter_args[1]`` otherwise
     (presumably the ``range`` stop value). Otherwise it is
     ``<self.iter_var_shape_name>[0]``.
     """
     if not self.is_for_range_iter():
         shape_ref = gast.Name(id=self.iter_var_shape_name,
                               ctx=gast.Load(),
                               annotation=None,
                               type_comment=None)
         first_dim = gast.Index(value=gast.Constant(value=0, kind=None))
         return gast.Subscript(value=shape_ref, slice=first_dim,
                               ctx=gast.Load())
     if self.args_length == 1:
         return self.iter_args[0]
     return self.iter_args[1]
Exemplo n.º 24
0
    def visit_Assign(self, node):
        """
        Forward-substitute a slice assigned to a name into its indexed uses.

        ``name = x[slice]`` is rewritten to ``name = x``. Every use
        ``name[k]`` becomes ``name[slice][k]``, but only when every use of
        ``name`` in ``self.use_def_chain`` is a subscript by a numeric
        index. Sets ``self.update`` when the rewrite happens.

        >>> import gast as ast
        >>> from pythran import passmanager, backend
        >>> pm = passmanager.PassManager("test")
        >>> node = ast.parse("def foo(a): b = a[2:];return b[0], b[1]")
        >>> _, node = pm.apply(ForwardSubstitution, node)
        >>> print(pm.dump(backend.Python, node))
        def foo(a):
            b = a
            return (b[2:][0], b[2:][1])
        """
        # Only ``name = something[slice]`` is a candidate.
        if not isinstance(node.value, ast.Subscript):
            return node
        if not isinstance(node.value.slice, ast.Slice):
            return node
        if not isinstance(node.targets[0], ast.Name):
            return node
        name = node.targets[0].id
        udgraph = self.use_def_chain[name]
        # Collect the AST nodes of every use ('U') of ``name``.
        uses = [
            udgraph.node[n]['name'] for n in udgraph.nodes()
            if udgraph.node[n]['action'] == 'U'
        ]
        can_fold = True
        targets = []
        for use in uses:
            # Each use must be directly subscripted by a numeric index.
            parent = self.ancestors[use][-1]
            if not isinstance(parent, ast.Subscript):
                can_fold = False
                break
            if not isinstance(parent.slice, ast.Index):
                can_fold = False
                break
            if not isinstance(parent.slice.value, ast.Num):
                can_fold = False
                break

            targets.append(parent)

        if can_fold:
            slice_ = node.value.slice
            # Drop the slice from the assignment ...
            node.value = node.value.value
            # ... and re-apply a copy of it at every use.
            for target in targets:
                target.value = ast.Subscript(target.value, deepcopy(slice_),
                                             ast.Load())
            self.update = True
        return node
Exemplo n.º 25
0
def apply_to_single_assignments(targets, values, apply_fn):
    """Call ``apply_fn`` once per elementary assignment.

    Unpacking assignments are broken down as far as possible, which has the
    same effect as presenting the assigned values to ``apply_fn`` in SSA form:

      * ``a, b = c, d`` calls ``apply_fn(a, c)`` and ``apply_fn(b, d)``;
      * ``a, b = c`` calls ``apply_fn(a, c[0])`` and ``apply_fn(b, c[1])``;
      * ``a = b, c`` calls ``apply_fn(a, (b, c))``.

    Nested tuple/list targets are processed recursively.

    Args:
      targets: a list or tuple of AST nodes, or a single AST node, as found
        in the ``targets`` field of an ``ast.Assign``.
      values: the assigned AST node.
      apply_fn: callable invoked as ``apply_fn(target, value)``; its return
        value is ignored.
    """
    if not isinstance(targets, (list, tuple)):
        targets = (targets, )
    for target in targets:
        if not isinstance(target, (gast.Tuple, gast.List)):
            apply_fn(target, values)
            continue
        values_unpacked = isinstance(values, (gast.Tuple, gast.List))
        for i, target_el in enumerate(target.elts):
            if values_unpacked:
                value_el = values.elts[i]
            else:
                index_node = parsing.parse_expression(str(i))
                value_el = gast.Subscript(values,
                                          gast.Index(index_node),
                                          ctx=gast.Load())
            apply_to_single_assignments(target_el, value_el, apply_fn)
Exemplo n.º 26
0
def create_grad(node, namer, tangent=False):
    """Given a variable, create a variable for the gradient.

    Args:
      node: A node to create a gradient for, can be a normal variable (`x`), a
          subscript (`x[i]`) or a string node holding a variable name (`'x'`).
      namer: The namer object which will determine the name to use for the
          gradient (queried via `namer.grad(varname, tangent)`).
      tangent: Whether a tangent (instead of adjoint) is created.

    Returns:
      node: A node representing the gradient with the correct name e.g. the
          gradient of `x[i]` is `dx[i]`. For a string input a string node
          holding the gradient's name is returned.

          Note that this returns an invalid node: the `ctx` attribute of the
          returned Name/Subscript is None. It is assumed that this attribute
          is filled in later.

          For a Name input, the returned Name has an `adjoint_var` annotation
          referring to the node it is an adjoint of; for a Subscript input the
          annotated Name is the returned subscript's `value`.

    Raises:
      TypeError: if `node` (or, recursively, the base of a subscript) is not a
          Subscript, Name or Str node.
    """
    if not isinstance(node, (gast.Subscript, gast.Name, gast.Str)):
        raise TypeError(
            'create_grad expects a Subscript, Name or Str node, got %s' %
            type(node).__name__)

    # Redirect to the node recorded in the `temp_var` annotation, if any.
    if anno.hasanno(node, 'temp_var'):
        return create_grad(anno.getanno(node, 'temp_var'), namer, tangent)

    if isinstance(node, gast.Subscript):
        grad_node = create_grad(node.value, namer, tangent=tangent)
        # The base of a subscript is always read (even in `dx[i] = ...`).
        grad_node.ctx = gast.Load()
        return gast.Subscript(value=grad_node, slice=node.slice, ctx=None)

    if isinstance(node, gast.Str):
        grad_node = create_grad(gast.Name(id=node.s, ctx=None,
                                          annotation=None),
                                namer,
                                tangent=tangent)
        return gast.Str(grad_node.id)

    # Plain Name: ask the namer for the gradient name and record which
    # variable this is the adjoint of.
    name = namer.grad(node.id, tangent)
    grad_node = gast.Name(id=name, ctx=None, annotation=None)
    anno.setanno(grad_node, 'adjoint_var', node)
    return grad_node
Exemplo n.º 27
0
    def ast(self):
        """Returns a gast node that evaluates to this qualified name.

        The caller must adjust the context appropriately: the `ctx` of a
        returned Subscript, Attribute or Name node is left as None.

        Returns:
          A gast.Subscript when this name has a subscript, a gast.Attribute
          when it has an attribute, otherwise a gast.Name, gast.Str or
          gast.Num built from the base component `self.qn[0]`.

        Raises:
          AssertionError: if the base component is not a str, StringLiteral
            or NumberLiteral. Raised explicitly rather than via `assert` so
            the check is not stripped under `python -O` (which would make
            this method silently return None).
        """
        if self.has_subscript():
            return gast.Subscript(self.parent.ast(),
                                  gast.Index(self.qn[-1].ast()), None)
        if self.has_attr():
            return gast.Attribute(self.parent.ast(), self.qn[-1], None)

        base = self.qn[0]
        if isinstance(base, str):
            return gast.Name(base, None, None)
        if isinstance(base, StringLiteral):
            return gast.Str(base.value)
        if isinstance(base, NumberLiteral):
            return gast.Num(base.value)
        raise AssertionError('the constructor should prevent types other than '
                             'str, StringLiteral and NumberLiteral')
Exemplo n.º 28
0
 def visit_AugAssign(self, node):
     """Handles an augmented assignment such as `x += y`.

     The right-hand side is visited first. If the target is a subscript that
     `visit_AssignedSubscript` accepts, every strict alias of the subscripted
     value is combined through a fake Subscript node. The target itself is
     then combined, with aliasing enabled only for a plain Name target.
     """
     self.visit(node.value)

     def accumulate(lhs, rhs):
         # Current value plus the expression type of `lhs <op> rhs` for this
         # node's operator.
         return lhs + self.builder.ExpressionType(
             operator_to_lambda[type(node.op)], (lhs, rhs))

     target = node.target
     if isinstance(target, ast.Subscript) and self.visit_AssignedSubscript(target):
         for alias in self.strict_aliases[target.value]:
             # NOTE(review): `target.value` sits in the slice position of the
             # fake Subscript; presumably the slice is irrelevant for a fake
             # node — confirm.
             fake = ast.Subscript(alias, target.value, ast.Store())
             # Fake node: no further aliasing is checked through it.
             self.combine(fake, node.value, accumulate, register=True)
     # Aliasing through subscripts is not supported, so aliasing is only
     # requested when the target is a plain Name.
     self.combine(target, node.value, accumulate, register=True,
                  aliasing_type=isinstance(target, ast.Name))
Exemplo n.º 29
0
  def _process_variable_assignment(self, source, targets):
    """Records in the current scope the value each assignment target receives.

    If `source` is a call whose callee has a `live_val` annotation that is a
    class, the call node is annotated with `type` and `type_fqn`.

    Args:
      source: the node whose value is assigned.
      targets: iterable of target nodes. Tuple and List targets are unpacked
        element-wise, element i receiving `source[i]`; nested targets such as
        `(a, b), c = x` are handled recursively.
    """
    if isinstance(source, gast.Call):
      func = source.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if tf_inspect.isclass(func_obj):
          # This is then a constructor.
          anno.setanno(source, 'type', func_obj)
          anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    for t in targets:
      if isinstance(t, (gast.Tuple, gast.List)):
        for i, e in enumerate(t.elts):
          # `source[i]` is read, not written, hence a Load context.
          # NOTE(review): gast.Index receives the raw int `i`, not an AST
          # constant node; kept as in the original.
          item = gast.Subscript(source, gast.Index(i), ctx=gast.Load())
          # Recursing lets nested tuple/list targets unpack further; plain
          # names land in the `else` branch below.
          self._process_variable_assignment(item, (e,))
      else:
        self.scope.setval(t.id, source)
Exemplo n.º 30
0
    def _process_variable_assignment(self, source, targets):
        """Binds every assignment target to the value it receives in scope.

        When `source` is a call whose callee has a `live_val` annotation that
        is a class, the call node is marked as a constructor through the
        `is_constructor`, `type` and `type_fqn` annotations.

        Args:
          source: the node whose value is assigned.
          targets: the assignment targets; more than one target means a
            multiple assignment. Tuple/List targets are unpacked recursively:
            statically (`a, b = c, d`) when `source` is itself a Tuple/List,
            dynamically through a synthesized `source[i]` otherwise.

        Raises:
          ValueError: if a target is not a Tuple, List, Name or Attribute.
        """
        if isinstance(source, gast.Call) and anno.hasanno(source.func,
                                                          'live_val'):
            callee = source.func
            klass = anno.getanno(callee, 'live_val')
            if tf_inspect.isclass(klass):
                # TODO(mdan): Raise an error if constructor has side effects.
                # We could keep a whitelist of side-effect-free constructors,
                # or step inside the constructor and analyze it further.
                anno.setanno(source, 'is_constructor', True)
                anno.setanno(source, 'type', klass)
                anno.setanno(source, 'type_fqn', anno.getanno(callee, 'fqn'))

        for target in targets:
            if isinstance(target, (gast.Name, gast.Attribute)):
                self.scope.setval(anno.getanno(target, anno.Basic.QN), source)
                continue
            if not isinstance(target, (gast.Tuple, gast.List)):
                raise ValueError('assignment target has unknown type: %s' %
                                 target)
            # Unpacking: a literal Tuple/List source is split statically;
            # anything else gets one `source[i]` subscript per element.
            literal_source = isinstance(source, (gast.Tuple, gast.List))
            for idx, sub_target in enumerate(target.elts):
                if literal_source:
                    sub_source = source.elts[idx]
                else:
                    sub_source = gast.Subscript(source, gast.Index(idx),
                                                ctx=None)
                self._process_variable_assignment(sub_source, (sub_target,))