Example #1
0
    def test_rename_symbols_function(self):
        """Renaming a function's own symbol rewrites the name in its `def`."""
        tree = parser.parse('def f():\n  pass')
        renames = {qual_names.QN('f'): qual_names.QN('f1')}
        tree = ast_util.rename_symbols(tree, renames)

        rendered = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(rendered.strip(), 'def f1():\n    pass')
Example #2
0
    def test_replace_name_with_subscript(self):
        """A subscript QN used as an assignment target gets Store/Load contexts."""
        template = """
        foo = bar
    """
        subscripted = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))

        assign = templates.replace(template, foo=subscripted)[0]
        target = assign.targets[0]
        # The subscript expression is written to; its container is only read.
        self.assertIsInstance(target.ctx, gast.Store)
        self.assertIsInstance(target.value.ctx, gast.Load)
    def test_rename_symbols_annotations(self):
        """Renaming symbols keeps existing annotation objects on the node."""
        tree = qual_names.resolve(parser.parse_str('a[i]'))
        anno.setanno(tree, 'foo', 'bar')
        annotation_before = anno.getanno(tree, 'foo')

        renames = {qual_names.QN('a'): qual_names.QN('b')}
        tree = ast_util.rename_symbols(tree, renames)

        # Identity check: the very same annotation object must survive.
        self.assertIs(anno.getanno(tree, 'foo'), annotation_before)
Example #4
0
    def test_rename_symbols_basic(self):
        """A simple Name is renamed and stored back as a plain str id."""
        tree = qual_names.resolve(parser.parse('a + b'))
        renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
        tree = ast_util.rename_symbols(tree, renames)

        # The new identifier must be a str, not a QN object.
        self.assertIsInstance(tree.value.left.id, str)
        rendered = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(rendered.strip(), 'renamed_a + b')
    def test_rename_symbols_basic(self):
        """A simple Name is renamed and stored back as a plain str id."""
        tree = qual_names.resolve(parser.parse_str('a + b'))
        renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
        tree = ast_util.rename_symbols(tree, renames)

        # parse_str yields a module, so the expression sits under body[0].
        self.assertIsInstance(tree.body[0].value.left.id, str)
        rendered = compiler.ast_to_source(tree)
        self.assertEqual(rendered.strip(), 'renamed_a + b')
Example #6
0
    def visit_alias(self, node):
        """Marks the name introduced by an import alias as modified and bound."""
        node = self.generic_visit(node)

        # `import a.b.c` only binds the root package name `a`; with an
        # `as` clause the alias itself is what gets bound.
        if node.asname is not None:
            bound_name = node.asname
        else:
            bound_name = node.name.split('.')[0]
        symbol = qual_names.QN(bound_name)

        self.scope.modified.add(symbol)
        self.scope.bound.add(symbol)
        return node
Example #7
0
    def test_rename_symbols_basic(self):
        """Renamed AST matches both its own unparsed form and the expected text."""
        tree = qual_names.resolve(parser.parse('a + b'))
        renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
        tree = ast_util.rename_symbols(tree, renames)
        rendered = parser.unparse(tree, include_encoding_marker=False)

        self.assertIsInstance(tree.value.left.id, str)
        self.assertAstMatches(tree, rendered)
        self.assertAstMatches(tree, 'renamed_a + b')
Example #8
0
    def visit_FunctionDef(self, node):
        """Records the callable type(s) of a function definition's name.

        The function name is mapped in `self.new_symbols` to one
        `Callable[[Any], R]` per return type R resolved from the return
        annotation, or to `Callable[[Any], Any]` when there is no annotation
        or it resolves to None.

        Raises:
          NotImplementedError: if the function has decorators.
        """
        f_name = qual_names.QN(node.name)

        if node.decorator_list:
            raise NotImplementedError('decorators: {}'.format(
                node.decorator_list))

        resolved = None
        if node.returns:
            resolved, _ = self.resolver.res_name(
                self.namespace, self.types_in.types,
                anno.Basic.QN.of(node.returns))
            if __debug__:
                self._check_set(resolved)
        return_types = {Any} if resolved is None else resolved

        self.new_symbols[f_name] = {Callable[[Any], rt] for rt in return_types}
        # A function definition is a statement, so it produces no value itself.
        return None
Example #9
0
    def visit_ClassDef(self, node):
        """Records scopes for a class definition.

        Two scopes are used: an outer one (recorded on `node`) for the class
        name and for the decorators, bases and keywords, and a separate one
        for the class body visited via `generic_visit`.
        """
        # The ClassDef node itself has a Scope object that tracks the creation
        # of its name, along with the usage of any decorator accompanying it.
        self._enter_scope(False)
        node.decorator_list = self.visit_block(node.decorator_list)
        self.scope.modified.add(qual_names.QN(node.name))
        self.scope.bound.add(qual_names.QN(node.name))
        # Bases and keywords are visited here, in the outer scope, rather than
        # in the class-body scope below.
        node.bases = self.visit_block(node.bases)
        node.keywords = self.visit_block(node.keywords)
        self._exit_and_record_scope(node)

        # A separate Scope tracks the actual class definition.
        # NOTE(review): generic_visit presumably walks every child field,
        # including decorator_list/bases/keywords already visited above -
        # confirm that visiting them again in this scope is intended.
        self._enter_scope(True)
        node = self.generic_visit(node)
        self._exit_scope()
        return node
Example #10
0
 def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
   # Test stub: checks the argument metadata received for each expected
   # function, then reports '<arg>_type' as the argument's type.
   if f_name == 'foo':
     test_self.assertTrue(f_is_local)
     if name == qual_names.QN('x'):
       test_self.assertEqual(type_anno, qual_names.QN('float'))
     elif name == qual_names.QN('y'):
       test_self.assertIsNone(type_anno)
     else:
       test_self.fail('unexpected argument {} for {}'.format(name, f_name))
   elif f_name == 'test_fn':
     test_self.assertFalse(f_is_local)
     test_self.assertEqual(name, qual_names.QN('a'))
     test_self.assertEqual(type_anno, qual_names.QN('int'))
   else:
     test_self.fail('unexpected function name {}'.format(f_name))
   type_name = str(name) + '_type'
   return {type_name}
Example #11
0
    def visit_FunctionDef(self, node):
        """Records scopes for a function definition.

        Annotates `node` with its defining-context scope (anno.Static.SCOPE)
        and its body scope (NodeAnno.BODY_SCOPE). Arguments are visited in a
        scope enclosing the body scope.
        """
        # The FunctionDef node itself has a Scope object that tracks the creation
        # of its name, along with the usage of any decorator accompanying it.
        self._enter_scope(False)
        node.decorator_list = self.visit_block(node.decorator_list)
        self.scope.mark_modified(qual_names.QN(node.name))
        anno.setanno(node, anno.Static.SCOPE, self.scope)
        self._exit_scope()

        # A separate Scope tracks the actual function definition.
        self._enter_scope(True)
        # Must not already be inside function-def arguments or a lambda here.
        assert not (self._in_function_def_args or self.state[_Lambda].level)
        # Flag set only while the arguments node is visited; presumably read by
        # other visit_* methods to treat those names as parameters - confirm.
        self._in_function_def_args = True
        node.args = self.visit(node.args)
        self._in_function_def_args = False

        # Track the body separately. This is for compatibility reasons, it may not
        # be strictly needed.
        self._enter_scope(False)
        node.body = self.visit_block(node.body)
        anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
        self._exit_scope()

        self._exit_scope()
        return node
Example #12
0
    def visit_node(self, node):
        """Recomputes the definitions flowing into and out of a CFG node.

        Args:
          node: a CFG node; its `ast_node` and `prev` attributes are read.

        Returns:
          True if the outgoing definitions differ from those stored by the
          previous visit of `node`, False otherwise.

        Raises:
          ValueError: if a `global`/`nonlocal` name already has an incoming
            definition.
        """
        prev_defs_out = self.out[node]

        # Incoming definitions: any extra state seeded for this AST node,
        # merged with the outgoing definitions of every CFG predecessor.
        defs_in = _NodeState(self.extra_in.get(node.ast_node, None))
        for n in node.prev:
            defs_in |= self.out[n]

        if anno.hasanno(node.ast_node, anno.Static.SCOPE):
            node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
            # The definition objects created by each node must be singletons because
            # their ids are used in equality checks.
            if node not in self.gen_map:
                node_symbols = {}
                for s in node_scope.modified:
                    def_ = self._definition_factory()
                    if s in node_scope.params:
                        # Link the definition to its parameter through a weak
                        # reference, so the definition does not keep it alive.
                        def_.param_of = weakref.ref(node_scope.params[s])
                    node_symbols[s] = def_
                self.gen_map[node] = _NodeState(node_symbols)

            gen = self.gen_map[node]
            # Symbols modified or deleted here kill their incoming definitions;
            # the node's own generated definitions replace them.
            kill = node_scope.modified | node_scope.deleted
            defs_out = gen | (defs_in - kill)

        elif isinstance(node.ast_node, (gast.Global, gast.Nonlocal)):
            # Special case for global and nonlocal: they generate a definition,
            # but are not tracked by activity analysis.
            if node not in self.gen_map:
                node_symbols = {}
                for s in node.ast_node.names:
                    qn = qual_names.QN(s)
                    if qn in defs_in.value:
                        # In Python 2, this is a syntax warning. In Python 3, it's an error.
                        raise ValueError(
                            '"{}" is assigned before global definition'.format(
                                s))
                    def_ = self._definition_factory()
                    node_symbols[qn] = def_
                self.gen_map[node] = _NodeState(node_symbols)

            gen = self.gen_map[node]
            # Nothing is killed: the new definitions are added to incoming ones.
            defs_out = defs_in | gen

        else:
            # Nodes that don't have a scope annotation are assumed not to touch any
            # symbols.
            # This Name node below is a literal name, e.g. False
            # This can also happen if activity.py forgot to annotate the node with a
            # scope object.
            assert isinstance(node.ast_node,
                              (gast.Name, gast.Break, gast.Continue,
                               gast.Raise, gast.Pass)), (node.ast_node, node)
            defs_out = defs_in

        self.in_[node] = defs_in
        self.out[node] = defs_out

        # TODO(mdan): Move this to the superclass?
        return prev_defs_out != defs_out
Example #13
0
 def visit_Nonlocal(self, node):
     # A `nonlocal` statement gets its own recorded scope in which every
     # listed name is marked as both read and bound.
     self._enter_scope(False)
     for symbol in map(qual_names.QN, node.names):
         self.scope.read.add(symbol)
         self.scope.bound.add(symbol)
     self._exit_and_record_scope(node)
     return node
    def test_rename_symbols_attributes(self):
        """Renaming `b.c` also rewrites it where it is the base of `b.c.d`."""
        tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
        mapping = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
        tree = ast_util.rename_symbols(tree, mapping)

        generated = compiler.ast_to_source(tree)
        self.assertEqual(generated.strip(), 'renamed_b_c = renamed_b_c.d')
Example #15
0
    def test_rename_symbols_attributes(self):
        """Renaming `b.c` also rewrites it where it is the base of `b.c.d`."""
        tree = qual_names.resolve(parser.parse('b.c = b.c.d'))
        mapping = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
        tree = ast_util.rename_symbols(tree, mapping)

        generated = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(generated.strip(), 'renamed_b_c = renamed_b_c.d')
Example #16
0
    def test_rename_symbols_global(self):
        """Names listed in a `global` statement are renamed as well."""
        tree = qual_names.resolve(parser.parse('global a, b, c'))
        mapping = {qual_names.from_str('b'): qual_names.QN('renamed_b')}
        tree = ast_util.rename_symbols(tree, mapping)

        generated = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(generated.strip(), 'global a, renamed_b, c')
Example #17
0
  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    """Resolves the result types and side effects of a function call.

    Args:
      ns: not read by this implementation.
      types_ns: not read by this implementation.
      node: the call node; its `func` QN annotation and, for some AutoGraph
        builtins, its positional `args` nodes are inspected.
      f_type: the callee's resolved type, compared against one-element tuples
        of `TFRTypes` markers.
      args: resolved argument types; only `args[0]` is read (for
        `ag__.for_stmt`).
      keywords: not read by this implementation.

    Returns:
      A pair `(types, side_effects)`: a set of result types, and either None
      or (for `ag__.if_stmt`) a dict mapping symbol QNs to sets of types.

    Raises:
      NotImplementedError: if the call is not one of the handled cases.
    """
    name = anno.Basic.QN.of(node.func)
    if f_type == (TFRTypes.AG_BUILTIN_FUNC,):

      if name == QN(QN('ag__'), attr='if_stmt'):
        # args[6] is read as the number of outputs and args[5] as a sequence
        # of constants naming the output symbols; each of the first `nouts`
        # names is typed as a tensor.
        nouts = node.args[6].value
        # TODO(mdan): Look at the actual types out of if_body.
        side_effects = {
            qual_names.QN(n.value): {TFRTypes.TENSOR}
            for n in node.args[5].elts[:nouts]
        }
        return {type(None)}, side_effects

      if name == QN(QN('ag__'), attr='for_stmt'):
        # args[2] must be a Name for the loop body function. Each body function
        # may be registered only once; its origin is kept for the error message.
        assert isinstance(node.args[2], ast.Name)
        body_fn_name = str(anno.Basic.QN.of(node.args[2]))
        assert body_fn_name not in self._for_loop_body_fns, (
            'Previously used here: {}. Are you reusing the Resolver across '
            'transformations?').format(self._for_loop_body_fns[body_fn_name])
        self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node)

        # The iterated value must be a tensor list, tensor or int list; it is
        # remembered as the loop target type for the body function.
        iterated_type = args[0]
        assert iterated_type & {
            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, List[int]
        }, (
            iterated_type)
        self._for_loop_target_types[body_fn_name] = iterated_type

        return {type(None)}, None

      # TODO(mdan): Actually resolve the type here instead.
      ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None)
      if ret_type is not None:
        return {ret_type}, None
      raise NotImplementedError('return type of {}'.format(name))

    elif f_type == (TFRTypes.TF_RAW_OP,):
      # Output types come from the op definition: a single output yields its
      # type directly, multiple outputs yield a tuple of types.
      op_name = name.qn[1]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
      # Only range() and len() are handled; other builtins fall through to
      # the NotImplementedError below.
      assert name.is_simple()
      if name == QN('range'):
        return {List[int]}, None

      if name == QN('len'):
        return {TFRTypes.INDEX}, None

    elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,):
      return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None

    raise NotImplementedError('Function:', name, f_type)
Example #18
0
  def visit_ClassDef(self, node):
    """Records scopes for a class definition.

    The outer scope (stored as anno.Static.SCOPE on `node`) covers the class
    name and its decorators; the class body is visited in a separate scope.
    """
    # The ClassDef node itself has a Scope object that tracks the creation
    # of its name, along with the usage of any decorator accompanying it.
    self._enter_scope(False)
    node.decorator_list = self.visit_block(node.decorator_list)
    self.scope.mark_modified(qual_names.QN(node.name))
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    self._exit_scope()

    # A separate Scope tracks the actual class definition.
    self._enter_scope(True)
    # Must not already be inside function-def arguments or a lambda here.
    assert not (self._in_function_def_args or self.state[_Lambda].level)
    node = self.generic_visit(node)
    self._exit_scope()
    return node
Example #19
0
    def visit_FunctionDef(self, node):
        """Records scopes for a function definition.

        The defining-context scope (function name and decorators) is recorded
        on `node`; the body scope is recorded under NodeAnno.BODY_SCOPE.
        Arguments are visited in a scope enclosing the body scope.
        """
        # The FunctionDef node itself has a Scope object that tracks the creation
        # of its name, along with the usage of any decorator accompanying it.
        self._enter_scope(False)
        node.decorator_list = self.visit_block(node.decorator_list)
        function_name = qual_names.QN(node.name)
        self.scope.modified.add(function_name)
        self.scope.bound.add(function_name)
        self._exit_and_record_scope(node)

        # A separate Scope tracks the actual function definition.
        self._enter_scope(True)
        node.args = self.visit(node.args)

        # Track the body separately. This is for compatibility reasons, it may not
        # be strictly needed.
        self._enter_scope(False)
        node.body = self.visit_block(node.body)
        self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)

        self._exit_scope()
        return node
Example #20
0
    def visit_FunctionDef(self, node):
        """Records scopes for a function definition.

        Scopes are recorded on `node` (defining context), on `node.args`
        (argument declarations), and under NodeAnno.BODY_SCOPE and
        NodeAnno.ARGS_AND_BODY_SCOPE.
        """
        with self.state[_FunctionOrClass] as fn:
            fn.node = node
            # The FunctionDef node itself has a Scope object that tracks the creation
            # of its name, along with the usage of any decorator accompanying it.
            self._enter_scope(False)
            node.decorator_list = self.visit_block(node.decorator_list)
            if node.returns:
                node.returns = self._process_annotation(node.returns)
            # Argument annotations (including defaults) affect the defining context.
            node = self._visit_arg_annotations(node)

            function_name = qual_names.QN(node.name)
            self.scope.modified.add(function_name)
            self.scope.bound.add(function_name)
            self._exit_and_record_scope(node)

            # A separate Scope tracks the actual function definition.
            self._enter_scope(True)

            # Keep a separate scope for the arguments node, which is used in the CFG.
            self._enter_scope(False)

            # Arg declarations only affect the function itself, and have no effect
            # in the defining context whatsoever.
            node = self._visit_arg_declarations(node)

            self._exit_and_record_scope(node.args)

            # Track the body separately. This is for compatibility reasons, it may not
            # be strictly needed.
            self._enter_scope(False)
            node.body = self.visit_block(node.body)
            self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)

            self._exit_and_record_scope(node, NodeAnno.ARGS_AND_BODY_SCOPE)
            return node
Example #21
0
 def res_arg(self, ns, types_ns, f_name, name, type_anno,
             f_is_local):
     # Test stub: argument `a` resolves to int, any other argument to float.
     is_a = name == qual_names.QN('a')
     return {int} if is_a else {float}
Example #22
0
 def res_name(self, ns, types_ns, name):
     # Test stub: only `g` is expected; returns its callable type together
     # with the `g` object itself.
     test_self.assertEqual(name, qual_names.QN('g'))
     g_type = Callable[[Callable], None]
     return {g_type}, g
Example #23
0
 def res_call(self, ns, types_ns, node, f_type, args, keywords):
     # Test stub: only calls to `g` are expected. Returns no result type and
     # a mapping of `x` to {str} as the second element.
     test_self.assertEqual(node.func.id, 'g')
     expected_f_type = (Callable[[Callable], None], )
     test_self.assertEqual(f_type, expected_f_type)
     side_effects = {qual_names.QN('x'): {str}}
     return None, side_effects
Example #24
0
def _live_tensors(f, attr_name="inputs"):
    """Returns the indices of the used inputs.

    Note: This currently only handles direct index accesses e.g. op.inputs[1].
    If the function has slicing or list comprehension on attr_name then returns
    _ALL. This ensures that the result is correct even if inefficient.

    Args:
      f: A grad function, taking the op as first argument.
      attr_name: op attr to track. "inputs" or "outputs".

    Returns:
      Either one of:
        * set of integers representing individual indices of inputs used
        * the value _ALL, if indices are used but cannot be determined which
        * empty set, if no inputs are used
    """
    node, _ = parser.parse_entity(f, ())
    entity_info = transformer.EntityInfo(
        name=f.__name__,
        source_code=None,
        source_file=None,
        future_features=(),
        namespace=sys.modules[f.__module__].__dict__)
    ctx = transformer.Context(entity_info, None, None)

    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = liveness.resolve(node, ctx, graphs)

    # The op is the grad function's first positional argument; track reads of
    # `<op>.<attr_name>` and of its subscripts.
    op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
    op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name)

    special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name, ))
    node = special_tracker.visit(node)

    live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN)
    inputs_outputs_used_qns = set()
    for v in special_tracker.complex_reads:
        # Complicated patterns like op.inputs[:3]. Could be smarter about them
        # if they matter much.
        if v == op_inputs_outputs_name:
            return _ALL
    for v in live_vars_in:
        if v in special_tracker.reads:
            if (v.has_subscript() and v.parent == op_inputs_outputs_name):
                inputs_outputs_used_qns.add(v)
            elif v == op_inputs_outputs_name:
                # When op.{attr_name} is used directly, assume all tensors are
                # used for now. In that case, no point digging further.
                # TODO(mdan): We can descend into tuple expansions.
                return _ALL

    # Functions called with the op as an argument may use further tensors;
    # analyze them recursively and merge their results.
    function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name)
    node = function_calls_tracker.visit(node)

    input_output_indices = set()

    for called_f in function_calls_tracker.calls:
        child_indices = _live_tensors(called_f, attr_name=attr_name)
        if child_indices is _ALL:
            return _ALL
        input_output_indices |= child_indices

    for v in inputs_outputs_used_qns:
        assert v.has_subscript()
        _, subscript = v.qn
        if not subscript.is_simple():
            # Not a number, assuming it can be anything.
            return _ALL
        subscript_val, = subscript.qn
        # Only integer literals (e.g. op.inputs[1]) identify a single index.
        # A symbol subscript (op.inputs[i]) is not a Literal and has no
        # `.value`, so the Literal check must short-circuit via `or`; a
        # non-integer literal is equally untrackable.
        if (not isinstance(subscript_val, qual_names.Literal)
                or not isinstance(subscript_val.value, int)):
            # Not a number, assuming it can be anything.
            return _ALL
        input_output_indices.add(subscript_val.value)
    return input_output_indices
    def visit_Expr(self, node):
        """Guards a standalone call statement with control dependencies.

        When the expression statement's value is a call, it is replaced by a
        `with ag__.utils.control_dependency_on_returns(call):` node. Simple
        (non-composite) symbols read by the call's arguments, other than `self`
        and `tf`, are re-bound inside that block through
        `ag__.utils.alias_tensors`. Those not modified by the enclosing scope
        get fresh names. The returned node carries an INDENT_BLOCK_REMAINDER
        annotation of `(body, alias_map)`; presumably a later pass moves the
        remaining statements into that body - confirm.

        Other expression statements are returned after visiting their children.
        """
        self.generic_visit(node)
        if isinstance(node.value, gast.Call):
            # Patterns of single function calls, like:
            #   opt.minimize(loss)
            # or:
            #   tf.py_func(...)

            # First, attempt to gate future evaluation of args. If that's not
            # possible, gate all remaining statements (and that may fail too, see
            # _visit_and_reindent).
            args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
            # NOTE: We can't guard object attributes because they may not be writable.
            # In addition, avoid renaming well-known names.
            # TODO(mdan): Move these names into config.
            unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
            guarded_args = tuple(
                s for s in args_scope.read
                if not s.is_composite() and s not in unguarded_names)

            # TODO(mdan): Include all arguments which depended on guarded_args too.
            # For example, the following will still cause a race:
            #   tf.assign(a, a + 1)
            #   b = a + 1
            #   tf.assign(a, a + 1)  # Control deps here should include `b`
            #   c = b + 1
            # Or maybe we should just raise an "unsafe assign" error?

            if guarded_args:
                # The aliases may need new names to avoid incorrectly making them local.
                # TODO(mdan): This is brutal. It will even rename modules - any fix?
                need_alias = tuple(s for s in guarded_args
                                   if s not in args_scope.parent.modified)
                aliased_new_names = tuple(
                    qual_names.QN(
                        self.ctx.namer.new_symbol(
                            s.ssf(), args_scope.parent.referenced))
                    for s in need_alias)
                alias_map = dict(zip(need_alias, aliased_new_names))
                # A single guarded symbol is assigned directly; several are
                # unpacked from a tuple target.
                if len(guarded_args) == 1:
                    s, = guarded_args
                    aliased_guarded_args = alias_map.get(s, s)
                else:
                    aliased_guarded_args = gast.Tuple(
                        [alias_map.get(s, s).ast() for s in guarded_args],
                        None)

                template = """
          with ag__.utils.control_dependency_on_returns(call):
            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
        """
                control_deps_guard = templates.replace(
                    template,
                    call=node.value,
                    aliased_guarded_args=aliased_guarded_args,
                    guarded_args=guarded_args)[-1]
            else:
                alias_map = {}

                template = """
          with ag__.utils.control_dependency_on_returns(call):
            pass
        """
                control_deps_guard = templates.replace(template,
                                                       call=node.value)[-1]
                # Drop the placeholder `pass`; the body starts out empty.
                control_deps_guard.body = []

            node = control_deps_guard
            anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                         (node.body, alias_map))
        return node
Example #26
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method defined directly on `c` (inherited ones are skipped)
    and wraps the converted methods in a new class definition.

    Args:
      c: the class to convert.
      program_ctx: conversion context forwarded to `function_to_graph`.

    Returns:
      A tuple `(output_nodes, class_name, class_namespace)`. `output_nodes`
      holds one import node per whitelisted base class followed by the new
      `ClassDef`, and `class_name` is the name of the converted class.
      `class_namespace` merges the namespaces of all converted methods.

    Raises:
      ValueError: if `c` has no member methods.
      NotImplementedError: if a base class other than `object` is not
        whitelisted for graph conversion.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    converted_members = {}
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        class_namespace.update(namespace)
        converted_members[m] = nodes[0]
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Identity check: `isinstance(object, base)` is also true for bases
        # such as `type` or `collections.abc.Hashable`, which would otherwise
        # be silently replaced by `object`.
        if base is object:
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
Example #27
0
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes.

    Converts every method defined directly on `c` (inherited ones are skipped)
    and wraps the converted methods in a new class definition.

    Args:
      c: the class to convert.
      program_ctx: conversion context forwarded to `convert_func_to_ast`.

    Returns:
      A tuple `(output_nodes, class_name, entity_info)`. `output_nodes` holds
      one import node per whitelisted base class followed by the new
      `ClassDef`, and `entity_info` carries the merged namespace and the
      future features shared by the converted methods.

    Raises:
      ValueError: if `c` has no member methods, or its methods were built with
        different future features.
      NotImplementedError: if a base class other than `object` is not
        whitelisted for graph conversion.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
    # The assumption that one namespace suffices for all methods only holds if
    # all methods were defined in the same module.
    # If, instead, functions are imported from multiple modules and then spliced
    # into the class, then each function has its own globals and __future__
    # imports that need to stay separate.

    # For example, C's methods could both have `global x` statements referring to
    # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
    # from mod1 import f1
    # from mod2 import f2
    # class C(object):
    #   method1 = f1
    #   method2 = f2

    converted_members = {}
    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        (node, ), _, entity_info = convert_func_to_ast(m,
                                                       program_ctx=program_ctx,
                                                       do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        if future_features is None:
            future_features = entity_info.future_features
        elif frozenset(future_features) ^ frozenset(
                entity_info.future_features):
            # Note: we can support this case if ever needed.
            raise ValueError(
                'cannot convert {}: it has methods built with mismatched future'
                ' features: {} and {}'.format(c, future_features,
                                              entity_info.future_features))
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Identity check: `isinstance(object, base)` is also true for bases
        # such as `type` or `collections.abc.Hashable`, which would otherwise
        # be silently replaced by `object`.
        if base is object:
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(source_code=None,
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=class_namespace)

    return output_nodes, class_name, entity_info
Example #28
0
 def _process_list_of_strings(self, names):
     # Rewrites `names` in place: every entry whose QN is a key of
     # self.name_map is replaced by the string form of its mapped name.
     # The same list object is returned.
     for position, raw in enumerate(names):
         symbol = qual_names.QN(raw)
         if symbol not in self.name_map:
             continue
         names[position] = str(self.name_map[symbol])
     return names
Example #29
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method defined directly on `c` (inherited ones are skipped)
    and wraps the converted methods in a new class definition.

    Args:
      c: the class to convert.
      program_ctx: conversion context; supplies the namer and receives its
        name map via `update_name_map`.

    Returns:
      A tuple `(output_nodes, class_name, class_namespace)`. `output_nodes`
      holds one import node per whitelisted base class followed by the new
      `ClassDef`, and `class_name` is the name of the converted class.
      `class_namespace` merges the namespaces of all converted methods.

    Raises:
      ValueError: if `c` has no member methods.
    """
    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    converted_members = {}
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        class_namespace.update(namespace)
        converted_members[m] = node[0]
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Identity check: `isinstance(object, base)` is also true for bases
        # such as `type` or `collections.abc.Hashable`, which would otherwise
        # be silently replaced by `object`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
Example #30
0
        assert isinstance(other, _SymbolTable)
        result = _SymbolTable(self)
        for s, other_types in other.value.items():
            if s not in result.value:
                self_types = set()
                result.value[s] = self_types
            else:
                self_types = result.value[s]
            self_types.update(other_types)
        return result

    def __repr__(self):
        """Shows the symbol-to-types mapping held by this table."""
        contents = self.value
        return 'SymbolTable {}'.format(contents)


# Qualified name of the `__getitem__` special method; presumably used to
# resolve subscript reads - confirm at the use sites.
_GETITEM = qual_names.QN('__getitem__')

_HANDLERS = {
    gast.Eq:
    qual_names.QN('__eq__'),
    gast.NotEq:
    qual_names.QN('__ne__'),
    gast.Lt:
    qual_names.QN('__lt__'),
    gast.LtE:
    qual_names.QN('__le__'),
    gast.Gt:
    qual_names.QN('__gt__'),
    gast.GtE:
    qual_names.QN('__ge__'),
    gast.In: