Exemplo n.º 1
0
    def test_rename_symbols_basic(self):
        # Parse a simple binary expression and attach qualified-name annotations.
        tree = qual_names.resolve(parser.parse_str('a + b'))

        # Rename only `a`; `b` must be left untouched.
        renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
        tree = ast_util.rename_symbols(tree, renames)

        # The renamed identifier should be a plain `str`.
        renamed_id = tree.body[0].value.left.id
        self.assertIsInstance(renamed_id, str)
        generated = compiler.ast_to_source(tree)
        self.assertEqual(generated.strip(), 'renamed_a + b')
Exemplo n.º 2
0
    def test_rename_symbols_annotations(self):
        """Renaming symbols must keep annotations attached to the root node."""
        tree = qual_names.resolve(parser.parse_str('a[i]'))
        anno.setanno(tree, 'foo', 'bar')
        expected = anno.getanno(tree, 'foo')

        tree = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('b')})

        # Identity check: the very same annotation object must survive.
        self.assertIs(anno.getanno(tree, 'foo'), expected)
Exemplo n.º 3
0
 def _check_anno_matches(self, node, anno_name, var_names):
   """Asserts that annotation `anno_name` on `node` equals the given names.

   Args:
     node: the AST node whose annotation is checked.
     anno_name: key of the annotation read with `anno.getanno`.
     var_names: a single name or an iterable of names. String entries may be
       dotted (e.g. 'a.b.c') and are converted to `qual_names.QN` objects;
       any other entry is assumed to already be a `qual_names.QN` and is used
       as-is.

   Raises:
     ValueError: if a string name contains a subscript ('[' or ']').
   """
   if isinstance(var_names, str):
     var_names = (var_names,)
   qual_vars = set()
   for var_name in var_names:
     if not isinstance(var_name, str):
       # Treat non-string entries as pre-built QNs. Dropping them would let the
       # assertion pass even when an expected name is absent from the anno.
       qual_vars.add(var_name)
       continue
     if '[' in var_name or ']' in var_name:
       raise ValueError('Annotation matching not supported with subscript.')
     # A plain name is the degenerate dotted case: reduce over no attributes
     # simply returns QN(attrs[0]).
     attrs = var_name.split('.')
     qual_vars.add(
         functools.reduce(qual_names.QN, attrs[1:], qual_names.QN(attrs[0])))
   self.assertEqual(anno.getanno(node, anno_name), qual_vars)
Exemplo n.º 4
0
    def test_rename_symbols_attributes(self):
        # `b.c` is a composite (attribute) name; every occurrence of it, on
        # either side of the assignment, should be replaced by one identifier.
        tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))

        target_qn = qual_names.from_str('b.c')
        tree = ast_util.rename_symbols(
            tree, {target_qn: qual_names.QN('renamed_b_c')})

        generated_source, _ = compiler.ast_to_source(tree)
        self.assertEqual(generated_source.strip(), 'renamed_b_c = renamed_b_c.d')
Exemplo n.º 5
0
 def visit_FunctionDef(self, node):
     """Records the function name in the enclosing scope and visits the body.

     The body is analyzed in a fresh `Scope(..., isolated=True)` whose parent
     is the enclosing scope; that body scope is attached to `node` under
     `NodeAnno.BODY_SCOPE`, and the enclosing scope is restored afterwards.
     """
     outer_scope = self.scope
     if outer_scope:
         # Defining a function writes its name into the enclosing scope.
         outer_scope.mark_write(qual_names.QN(node.name))
     inner_scope = Scope(outer_scope, isolated=True)
     self.scope = inner_scope
     self.generic_visit(node)
     anno.setanno(node, NodeAnno.BODY_SCOPE, inner_scope)
     self.scope = outer_scope
     return node
Exemplo n.º 6
0
    def test_rename_symbols(self):
        # Hand-built tuple of four elements: a, b, b.c (store), b.c.d
        a_load = ast.Name('a', ast.Load())
        b_load = ast.Name('b', ast.Load())
        b_c_store = ast.Attribute(ast.Name('b', None), 'c', ast.Store())
        b_c_d = ast.Attribute(
            ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
        tree = qual_names.resolve(
            ast.Tuple([a_load, b_load, b_c_store, b_c_d], None))

        renames = {
            qual_names.QN('a'): qual_names.QN('renamed_a'),
            qual_names.QN(qual_names.QN('b'), attr='c'):
                qual_names.QN('renamed_b_c'),
        }
        tree = ast_util.rename_symbols(tree, renames)

        # Renamed nodes must keep the original load/store context.
        elts = tree.elts
        self.assertEqual(elts[0].id, 'renamed_a')
        self.assertIsInstance(elts[0].ctx, ast.Load)
        self.assertEqual(elts[1].id, 'b')
        self.assertEqual(elts[2].id, 'renamed_b_c')
        self.assertIsInstance(elts[2].ctx, ast.Store)
        self.assertEqual(elts[3].value.id, 'renamed_b_c')
        self.assertIsInstance(elts[3].value.ctx, ast.Load)
  def visit_FunctionDef(self, node):
    """Visits a function definition, building its scopes and annotations.

    Three scopes are pushed in turn:
      * `_enter_scope(False)`: tracks the write of the function's name and the
        decorators; attached to `node` under `anno.Static.SCOPE`.
      * `_enter_scope(True)`: tracks the function definition itself (its
        arguments are visited here).
      * `_enter_scope(False)`, nested inside the previous one: tracks the body;
        attached to `node` under `NodeAnno.BODY_SCOPE`.
    NOTE(review): the boolean passed to `_enter_scope` presumably selects an
    isolated scope — confirm against `_enter_scope`.

    Returns:
      The same node, with `decorator_list`, `args` and `body` replaced by their
      visited versions.
    """
    # The FunctionDef node itself has a Scope object that tracks the creation
    # of its name, along with the usage of any decorator accompany it.
    self._enter_scope(False)
    node.decorator_list = self.visit_block(node.decorator_list)
    self.scope.mark_write(qual_names.QN(node.name))
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    self._exit_scope()

    # A separate Scope tracks the actual function definition.
    self._enter_scope(True)
    node.args = self.visit(node.args)

    # Track the body separately. This is for compatibility reasons, it may not
    # be strictly needed.
    self._enter_scope(False)
    node.body = self.visit_block(node.body)
    anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
    self._exit_scope()

    # Pop the function-definition scope entered before visiting the args.
    self._exit_scope()
    return node
Exemplo n.º 8
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method defined directly on `c` (inherited methods are
    skipped) and wraps the results in a new class definition. Each base class
    is either imported (when whitelisted) or scheduled for conversion.

    Args:
        c: the class to convert.
        program_ctx: the program context; used to create a namer and then
            updated with the names that namer produced.

    Returns:
        A tuple `(node, class_name, class_namespace)`: a `gast.Module` holding
        any base-class import statements followed by the converted class
        definition, the name of the converted class, and the merged namespaces
        returned by the conversion of each method.

    Raises:
        ValueError: if `c` has no member methods.
    """
    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    converted_members = {}
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        # class_namespace starts as an empty dict, so each method's namespace
        # is simply merged into it.
        class_namespace.update(namespace)
        converted_members[m] = node
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    bases = []
    for base in c.__bases__:
        # Compare by identity: `isinstance(object, base)` is also true when
        # `base` is `type` or an ABC whose subclass hook accepts the `object`
        # class (e.g. collections.abc.Hashable or Callable), which would have
        # silently replaced such bases with `object`.
        if base is object:
            bases.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        bases.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    output_nodes.append(
        gast.ClassDef(class_name,
                      bases=bases,
                      keywords=[],
                      body=list(converted_members.values()),
                      decorator_list=[]))
    node = gast.Module(output_nodes)

    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    node = qual_names.resolve(node)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    node = ast_util.rename_symbols(node, renames)

    return node, class_name, class_namespace
Exemplo n.º 9
0
  def visit_Expr(self, node):
    """Wraps a bare function-call statement in a control-dependency block.

    After visiting children, an expression statement whose value is a
    `gast.Call` is replaced by a
    `with ag__.utils.control_dependency_on_returns(call):` block built from a
    template. Plain (non-composite) names used by the call's arguments, other
    than `self` and `tf`, are re-bound inside that block through
    `ag__.utils.alias_tensors`; those not modified in the parent scope are
    given fresh names from the namer. The new node is annotated with
    `anno.Basic.INDENT_BLOCK_REMAINDER` = `(node.body, alias_map)`, presumably
    so a later pass can move the remaining statements into the block and
    apply the renames there — confirm against the consumer of that annotation.

    Returns:
      The visited node unchanged if its value is not a call; otherwise the new
      `with` node.
    """
    self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      # Patterns of single function calls, like:
      #   opt.minimize(loss)
      # or:
      #   tf.py_func(...)

      # First, attempt to gate future evaluation of args. If that's not
      # possible, gate all remaining statements (and that may fail too, see
      # _visit_and_reindent).
      args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
      # NOTE: We can't guard object attributes because they may not be writable.
      # In addition, avoid renaming well-known names.
      # TODO(mdan): Move these names into config.
      unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
      guarded_args = tuple(s for s in args_scope.used
                           if not s.is_composite() and s not in unguarded_names)

      # TODO(mdan): Include all arguments which depended on guarded_args too.
      # For example, the following will still cause a race:
      #   tf.assign(a, a + 1)
      #   b = a + 1
      #   tf.assign(a, a + 1)  # Control deps here should include `b`
      #   c = b + 1
      # Or maybe we should just raise an "unsafe assign" error?

      if guarded_args:
        # The aliases may need new names to avoid incorrectly making them local.
        # TODO(mdan): This is brutal. It will even rename modules - any fix?
        need_alias = tuple(
            s for s in guarded_args if s not in args_scope.parent.modified)
        aliased_new_names = tuple(
            qual_names.QN(
                self.context.namer.new_symbol(
                    s.ssf(), args_scope.parent.referenced)) for s in need_alias)
        alias_map = dict(zip(need_alias, aliased_new_names))
        # A single guarded arg is substituted as a QN; several are packed into
        # a gast.Tuple of their AST forms (names without an alias keep theirs).
        if len(guarded_args) == 1:
          s, = guarded_args
          aliased_guarded_args = alias_map.get(s, s)
        else:
          aliased_guarded_args = gast.Tuple(
              [alias_map.get(s, s).ast() for s in guarded_args], None)

        template = """
          with ag__.utils.control_dependency_on_returns(call):
            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
        """
        # Take the last statement produced by the template (presumably the
        # `with` block itself).
        control_deps_guard = templates.replace(
            template,
            call=node.value,
            aliased_guarded_args=aliased_guarded_args,
            guarded_args=guarded_args)[-1]
      else:
        alias_map = {}

        template = """
          with ag__.utils.control_dependency_on_returns(call):
            pass
        """
        control_deps_guard = templates.replace(template, call=node.value)[-1]
        # Drop the placeholder `pass` so the block body starts out empty.
        control_deps_guard.body = []

      node = control_deps_guard
      anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                   (node.body, alias_map))
    return node