def replace(template, **replacements):
  """Replaces placeholders in a Python template.

  AST Name and Tuple nodes always receive the context that is inferred from
  the template. However, when replacing more complex nodes (that can
  potentially contain Name children), the caller is responsible for setting
  the appropriate context.

  Args:
    template: A string representing Python code. Any symbol name that appears
      in the template code can be used as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
      that these placeholders will be replaced by. String values are also
      supported as a shorthand for AST Name nodes with the respective ID.

  Returns:
    The top-level statements of the parsed template, with the replacements
    made and qualified names resolved. This is a list of AST nodes when the
    transformed tree's `body` is a list, otherwise the single resolved node.

  Raises:
    ValueError: if the arguments are incorrect, e.g. the template is not a
      string.
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))
  tree = parser.parse_str(textwrap.dedent(template))
  # Normalize every replacement value (e.g. plain strings) into AST form.
  # `replacements` is the fresh kwargs dict, so rebinding it is local.
  replacements = {k: _convert_to_ast(v) for k, v in replacements.items()}
  results = ReplaceTransformer(replacements).visit(tree).body
  if isinstance(results, list):
    return [qual_names.resolve(r) for r in results]
  return qual_names.resolve(results)
def transform(node, ctx, default_to_null_return=True):
  """Ensure a function has only a single return, at the end.

  Runs two rewrites, each preceded by fresh qualified-name and activity
  analysis: first ConditionalReturnRewriter, then
  ReturnStatementsTransformer.

  Args:
    node: the AST to transform.
    ctx: the transformation context, passed to the analyses and rewriters.
    default_to_null_return: forwarded to ReturnStatementsTransformer as
      `allow_missing_return`.

  Returns:
    The transformed AST.
  """

  def _reanalyze(tree):
    # Refresh the annotations that the rewriters below read.
    tree = qual_names.resolve(tree)
    return activity.resolve(tree, ctx, None)

  # Note: Technically, these two could be merged into a single walk, but
  # keeping them separate helps with readability.
  tree = _reanalyze(node)
  tree = ConditionalReturnRewriter(ctx).visit(tree)
  tree = _reanalyze(tree)
  returns_rewriter = ReturnStatementsTransformer(
      ctx, allow_missing_return=default_to_null_return)
  return returns_rewriter.visit(tree)
def transform(node, ctx):
  """Applies BreakTransformer after qualified-name and activity analysis.

  Args:
    node: the AST to transform.
    ctx: the transformation context.

  Returns:
    The transformed AST.
  """
  analyzed = activity.resolve(qual_names.resolve(node), ctx, None)
  return BreakTransformer(ctx).visit(analyzed)
def prepare(self, test_fn, namespace, recursive=True):
  """Parses `test_fn` and runs the initial static analyses on it.

  `namespace` is modified in place: a 'ConversionOptions' entry is added.

  Args:
    test_fn: the function to parse.
    namespace: dict used as the entity namespace and to build the Namer.
    recursive: forwarded to converter.ConversionOptions.

  Returns:
    A tuple (node, ctx): the analyzed AST and the transformer.Context.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions
  features = ('print_function', 'division')
  fn_node, fn_source = parser.parse_entity(test_fn, future_features=features)

  fn_namer = naming.Namer(namespace)
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(recursive=recursive),
      autograph_module=None)
  info = transformer.EntityInfo(
      name=test_fn.__name__,
      source_code=fn_source,
      source_file='<fragment>',
      future_features=features,
      namespace=namespace)
  ctx = transformer.Context(info, fn_namer, program_ctx)

  origin_info.resolve_entity(fn_node, fn_source, test_fn)
  cfgs = cfg.build(fn_node)
  fn_node = qual_names.resolve(fn_node)
  fn_node = activity.resolve(fn_node, ctx, None)
  fn_node = reaching_definitions.resolve(fn_node, ctx, cfgs)
  # Copy DEFINITIONS annotations to ORIG_DEFINITIONS, presumably so later
  # passes can still see the pre-conversion definitions.
  anno.dup(fn_node, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})
  return fn_node, ctx
def transform_ast(self, node, ctx):
  """Runs the initial static analysis and then the conversion passes.

  The analysis (CFG, qualified names, activity, reaching definitions) runs
  once up front. Its DEFINITIONS annotations are copied to ORIG_DEFINITIONS,
  and then the converter passes are applied in a fixed order. The asserts
  and lists passes run only when the corresponding converter.Feature is
  enabled in ctx.user.options.

  Args:
    node: the AST to convert.
    ctx: the conversion context; ctx.user.options gates optional passes.

  Returns:
    The converted AST.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.
  # Presumably raises on syntax the converter does not support.
  unsupported_features_checker.verify(node)

  # Run initial analysis.
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_definitions.resolve(node, ctx, graphs)
  # Snapshot the definitions found by the initial analysis.
  anno.dup(
      node,
      {
          anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
      },
  )

  # The order of the passes below matters; do not reorder casually.
  node = functions.transform(node, ctx)
  node = directives.transform(node, ctx)
  node = break_statements.transform(node, ctx)
  if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
    node = asserts.transform(node, ctx)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = continue_statements.transform(node, ctx)
  node = return_statements.transform(node, ctx)
  if ctx.user.options.uses(converter.Feature.LISTS):
    node = lists.transform(node, ctx)
    node = slices.transform(node, ctx)
  node = call_trees.transform(node, ctx)
  node = control_flow.transform(node, ctx)
  node = conditional_expressions.transform(node, ctx)
  node = logical_expressions.transform(node, ctx)
  return node
def test_subscript_resolve(self):
  samples = """
    x[i]
    x[i.b]
    a.b[c]
    a.b[x.y]
    a[z[c]]
    a[b[c[d]]]
    a[b].c
    a.b.c[d].e.f
    a.b[c[d]].e.f
    a.b[c[d.e.f].g].h
  """
  # Each sample's resolved qualified name must print back as the source text.
  expected_strings = (
      'x[i]',
      'x[i.b]',
      'a.b[c]',
      'a.b[x.y]',
      'a[z[c]]',
      'a[b[c[d]]]',
      'a[b].c',
      'a.b.c[d].e.f',
      'a.b[c[d]].e.f',
      'a.b[c[d.e.f].g].h',
  )
  module = resolve(parser.parse_str(textwrap.dedent(samples)))
  exprs = tuple(stmt.value for stmt in module.body)
  for index, qn_string in enumerate(expected_strings):
    self.assertQNStringIs(exprs[index], qn_string)
def standard_analysis(node, context, is_initial=False): """Performs a complete static analysis of the given code. Args: node: ast.AST context: converter.EntityContext is_initial: bool, whether this is the initial analysis done on the input source code Returns: ast.AST, same as node, with the static analysis annotations added """ # TODO(mdan): Clear static analysis here. # TODO(mdan): Consider not running all analyses every time. # TODO(mdan): Don't return a node because it's modified by reference. graphs = cfg.build(node) node = qual_names.resolve(node) node = activity.resolve(node, context, None) node = reaching_definitions.resolve(node, context, graphs, AnnotatedDef) node = liveness.resolve(node, context, graphs) node = live_values.resolve(node, context, config.PYTHON_LITERALS) node = type_info.resolve(node, context) # This second call allows resolving first-order class attributes. node = live_values.resolve(node, context, config.PYTHON_LITERALS) if is_initial: anno.dup( node, { anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, }, ) return node
def test_subscript_resolve(self):
  samples = """
    x[i]
    x[i.b]
    a.b[c]
    a.b[x.y]
    a[z[c]]
    a[b[c[d]]]
    a[b].c
    a.b.c[d].e.f
    a.b[c[d]].e.f
    a.b[c[d.e.f].g].h
  """
  wanted = [
      'x[i]',
      'x[i.b]',
      'a.b[c]',
      'a.b[x.y]',
      'a[z[c]]',
      'a[b[c[d]]]',
      'a[b].c',
      'a.b.c[d].e.f',
      'a.b[c[d]].e.f',
      'a.b[c[d.e.f].g].h',
  ]
  statements = parser.parse(textwrap.dedent(samples), single_node=False)
  # Resolve each statement separately and keep only its expression value.
  exprs = tuple(resolve(stmt).value for stmt in statements)
  for position, qn_string in enumerate(wanted):
    self.assertQNStringIs(exprs[position], qn_string)
def standard_analysis(node, context, is_initial=False): """Performs a complete static analysis of the given code. Args: node: ast.AST context: converter.EntityContext is_initial: bool, whether this is the initial analysis done on the input source code Returns: ast.AST, same as node, with the static analysis annotations added """ # TODO(mdan): Clear static analysis here. # TODO(mdan): Consider not running all analyses every time. # TODO(mdan): Don't return a node because it's modified by reference. graphs = cfg.build(node) node = qual_names.resolve(node) node = activity.resolve(node, context.info, None) node = reaching_definitions.resolve(node, context.info, graphs, AnnotatedDef) node = liveness.resolve(node, context.info, graphs) node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) node = type_info.resolve(node, context.info) # This second call allows resolving first-order class attributes. node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) if is_initial: anno.dup( node, { anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, }, ) return node
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs qualified-name and activity analysis.

  Returns:
    A tuple (node, entity_info).
  """
  tree, src = parser.parse_entity(test_fn, future_features=())
  info = transformer.EntityInfo(
      source_code=src, source_file=None, future_features=(), namespace={})
  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, transformer.Context(info))
  return tree, info
def transform_ast(self, node, ctx):
  """Runs the analyses that type inference needs, then type inference.

  Returns:
    The annotated AST.
  """
  tree = activity.resolve(qual_names.resolve(node), ctx)
  cfgs = cfg.build(tree)
  for analysis in (reaching_definitions, reaching_fndefs):
    tree = analysis.resolve(tree, ctx, cfgs)
  return type_inference.resolve(tree, ctx, cfgs, self.resolver)
def transform(node, ctx):
  """Applies ControlFlowTransformer after the analyses it relies on.

  The analyses are qualified names, activity, reaching definitions (with
  AnnotatedDef) and liveness.

  Returns:
    The transformed AST.
  """
  cfgs = cfg.build(node)
  tree = qual_names.resolve(node)
  tree = activity.resolve(tree, ctx, None)
  tree = reaching_definitions.resolve(tree, ctx, cfgs, AnnotatedDef)
  tree = liveness.resolve(tree, ctx, cfgs)
  return ControlFlowTransformer(ctx).visit(tree)
def test_rename_symbols_attributes(self):
  tree = qual_names.resolve(parser.parse('b.c = b.c.d'))
  # Renaming the composite name `b.c` must also rewrite it inside `b.c.d`.
  renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
  tree = ast_util.rename_symbols(tree, renames)
  code = parser.unparse(tree, include_encoding_marker=False)
  self.assertEqual(code.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_global(self):
  tree = qual_names.resolve(parser.parse('global a, b, c'))
  # Names listed in a `global` statement are renamed too.
  renames = {qual_names.from_str('b'): qual_names.QN('renamed_b')}
  tree = ast_util.rename_symbols(tree, renames)
  code = parser.unparse(tree, include_encoding_marker=False)
  self.assertEqual(code.strip(), 'global a, renamed_b, c')
def test_rename_symbols_attributes(self):
  tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
  renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
  tree = ast_util.rename_symbols(tree, renames)
  code = compiler.ast_to_source(tree)
  self.assertEqual(code.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_annotations(self):
  tree = qual_names.resolve(parser.parse_str('a[i]'))
  anno.setanno(tree, 'foo', 'bar')
  annotation_before = anno.getanno(tree, 'foo')
  tree = ast_util.rename_symbols(
      tree, {qual_names.QN('a'): qual_names.QN('b')})
  # Renaming must keep the very same annotation object, not a copy.
  self.assertIs(anno.getanno(tree, 'foo'), annotation_before)
def test_rename_symbols_annotations(self):
  subscript = parser.parse_str('a[i]')
  subscript = qual_names.resolve(subscript)
  anno.setanno(subscript, 'foo', 'bar')
  original = anno.getanno(subscript, 'foo')
  rename_map = {qual_names.QN('a'): qual_names.QN('b')}
  subscript = ast_util.rename_symbols(subscript, rename_map)
  # The annotation must survive renaming by identity.
  self.assertIs(anno.getanno(subscript, 'foo'), original)
def test_rename_symbols_basic(self):
  tree = qual_names.resolve(parser.parse_str('a + b'))
  tree = ast_util.rename_symbols(
      tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})
  # The renamed Name node must carry a plain string id.
  left_operand = tree.body[0].value.left
  self.assertIsInstance(left_operand.id, str)
  self.assertEqual(compiler.ast_to_source(tree).strip(), 'renamed_a + b')
def convert2():
  """Prints the symbols read and modified in the body of `f2`."""
  tree, ctx = get_node_and_ctx(f2)
  tree = activity.resolve(qual_names.resolve(tree), ctx)
  body_scope = anno.getanno(tree, annos.NodeAnno.BODY_SCOPE)
  # Note: tag will be changed soon.
  print('read:', body_scope.read)
  print('modified:', body_scope.modified)
def test_rename_symbols_basic(self):
  expr = qual_names.resolve(parser.parse('a + b'))
  expr = ast_util.rename_symbols(
      expr, {qual_names.QN('a'): qual_names.QN('renamed_a')})
  # The renamed Name node must carry a plain string id.
  self.assertIsInstance(expr.value.left.id, str)
  code = parser.unparse(expr, include_encoding_marker=False)
  self.assertEqual(code.strip(), 'renamed_a + b')
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it up to reaching definitions.

  Returns:
    The annotated AST.
  """
  tree, src = parser.parse_entity(test_fn, future_features=())
  info = transformer.EntityInfo(
      source_code=src, source_file=None, future_features=(), namespace={})
  tree = qual_names.resolve(tree)
  context = transformer.Context(info)
  tree = activity.resolve(tree, context)
  cfgs = cfg.build(tree)
  return reaching_definitions.resolve(
      tree, context, cfgs, reaching_definitions.Definition)
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs qualified-name and activity analysis.

  Returns:
    A tuple (node, entity_info).
  """
  tree, src, _ = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None)
  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, transformer.Context(info))
  return tree, info
def initial_analysis(self, node, ctx):
  """Runs the initial analyses and snapshots the resulting definitions.

  Returns:
    The annotated AST.
  """
  cfgs = cfg.build(node)
  tree = activity.resolve(qual_names.resolve(node), ctx, None)
  tree = reaching_definitions.resolve(tree, ctx, cfgs)
  # Copy DEFINITIONS to ORIG_DEFINITIONS before any conversion runs.
  anno.dup(tree, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})
  return tree
def test_rename_symbols_basic(self):
  expr = qual_names.resolve(parser.parse('a + b'))
  rename_map = {qual_names.QN('a'): qual_names.QN('renamed_a')}
  expr = ast_util.rename_symbols(expr, rename_map)
  unparsed = parser.unparse(expr, include_encoding_marker=False)
  expected_node_src = 'renamed_a + b'

  self.assertIsInstance(expr.value.left.id, str)
  # The renamed tree must match both its own unparsed form and the
  # expected source.
  self.assertAstMatches(expr, unparsed)
  self.assertAstMatches(expr, expected_node_src)
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs qualified-name and activity analysis.

  Returns:
    A tuple (node, entity_info).
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = activity.resolve(qual_names.resolve(tree), info)
  return tree, info
def convert4():
  """Prints the variables live into the first two statements of `f4`."""
  tree, ctx = get_node_and_ctx(f4)
  tree = qual_names.resolve(tree)
  graphs = cfg.build(tree)
  tree = activity.resolve(tree, ctx)
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    tree = analysis.resolve(tree, ctx, graphs)
  first_stmt, second_stmt = tree.body[0], tree.body[1]
  print('live into `b = a + 1`:',
        anno.getanno(first_stmt, anno.Static.LIVE_VARS_IN))
  print('live into `return b`:',
        anno.getanno(second_stmt, anno.Static.LIVE_VARS_IN))
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it with liveness information.

  Returns:
    The annotated AST.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = activity.resolve(qual_names.resolve(tree), info)
  # liveness annotates `tree`; its return value is intentionally unused.
  liveness.resolve(tree, info, cfg.build(tree))
  return tree
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it with liveness information.

  Returns:
    The annotated AST.
  """
  tree, _, src = parser.parse_entity(test_fn, future_imports=())
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None)
  tree = qual_names.resolve(tree)
  context = transformer.Context(info)
  tree = activity.resolve(tree, context)
  # liveness annotates `tree`; its return value is intentionally unused.
  liveness.resolve(tree, context, cfg.build(tree))
  return tree
def mlir_gen_internal(node, entity_info):
  """Returns mlir module for unprocessed node `node`.

  Runs qualified-name, activity, reaching-definition, reaching-fndef and
  liveness analyses, then feeds the annotated tree to MLIRGen.
  """
  fresh_namer = naming.Namer({})
  cfgs = cfg.build(node)
  ctx = transformer.Context(entity_info, fresh_namer, None)
  tree = activity.resolve(qual_names.resolve(node), ctx)
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    tree = analysis.resolve(tree, ctx, cfgs)
  generator = MLIRGen(ctx)
  generator.visit(tree)
  return generator.prog
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs qualified-name and activity analysis.

  Returns:
    A tuple (node, entity_info).
  """
  # TODO(mdan): Use a custom FunctionTransformer here.
  tree, src = parser.parse_entity(test_fn, future_features=())
  info = transformer.EntityInfo(
      name=test_fn.__name__,
      source_code=src,
      source_file=None,
      future_features=(),
      namespace={})
  tree = qual_names.resolve(tree)
  context = transformer.Context(info, naming.Namer({}), None)
  tree = activity.resolve(tree, context)
  return tree, info
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it up to reaching definitions.

  Returns:
    The annotated AST.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = activity.resolve(qual_names.resolve(tree), info)
  return reaching_definitions.resolve(
      tree, info, cfg.build(tree), reaching_definitions.Definition)
def transform(node, ctx):
  """Transform function calls to their compiled counterparts.

  Resolves qualified names on `node`, then applies CallTreeTransformer.

  Args:
    node: AST
    ctx: EntityContext

  Returns:
    The transformed AST. Only the node is returned; no set of newly
    generated names is returned alongside it.
  """
  node = qual_names.resolve(node)
  node = CallTreeTransformer(ctx).visit(node)
  return node
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it with liveness information.

  Returns:
    The annotated AST.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = qual_names.resolve(tree)
  context = transformer.Context(info)
  tree = activity.resolve(tree, context)
  # liveness annotates `tree`; its return value is intentionally unused.
  liveness.resolve(tree, context, cfg.build(tree))
  return tree
def transform_ast(self, node, ctx):
  """Lowers `node` and generates TFR code from it.

  Applies the Python-to-TF passes, runs the static analyses, infers types
  with TFRTypeResolver, and returns TFRGen's code buffer.
  """
  tree = _apply_py_to_tf_passes(node, ctx)
  # TODO(mdan): Enable this.
  # node = anf.transform(node, ctx)

  cfgs = cfg.build(tree)
  tree = activity.resolve(qual_names.resolve(tree), ctx)
  for analysis in (reaching_definitions, reaching_fndefs):
    tree = analysis.resolve(tree, ctx, cfgs)
  tree = type_inference.resolve(
      tree, ctx, cfgs, TFRTypeResolver(self._op_defs))

  generator = TFRGen(ctx, self._op_defs)
  generator.visit(tree)
  return generator.code_buffer
def test_function_calls(self):
  samples = """
    a.b
    a.b()
    a().b
    z[i]
    z[i]()
    z()[i]
  """
  module = resolve(parser.parse_str(textwrap.dedent(samples)))
  (attr, attr_call, call_attr, subscript, subscript_call,
   call_subscript) = (stmt.value for stmt in module.body)
  self.assertQNStringIs(attr, 'a.b')
  self.assertQNStringIs(attr_call.func, 'a.b')
  self.assertQNStringIs(call_attr.value.func, 'a')
  self.assertQNStringIs(subscript, 'z[i]')
  self.assertQNStringIs(subscript_call.func, 'z[i]')
  self.assertQNStringIs(call_subscript.value.func, 'z')
def test_function_calls(self):
  samples = """
    a.b
    a.b()
    a().b
    z[i]
    z[i]()
    z()[i]
  """
  statements = parser.parse(textwrap.dedent(samples), single_node=False)
  (attr, attr_call, call_attr, subscript, subscript_call,
   call_subscript) = (resolve(stmt).value for stmt in statements)
  self.assertQNStringIs(attr, 'a.b')
  self.assertQNStringIs(attr_call.func, 'a.b')
  self.assertQNStringIs(call_attr.value.func, 'a')
  self.assertQNStringIs(subscript, 'z[i]')
  self.assertQNStringIs(subscript_call.func, 'z[i]')
  self.assertQNStringIs(call_subscript.value.func, 'z')
def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks.

  Returns:
    The value of the single generated Expr statement, or the generated Name
    node itself.

  Raises:
    ValueError: if the template does not generate exactly one node, or if
      that node is neither a gast.Expr nor a gast.Name.
  """
  generated = replace(template, **replacements)
  if len(generated) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')
  only_node = qual_names.resolve(generated[0])
  if isinstance(only_node, gast.Expr):
    return only_node.value
  if isinstance(only_node, gast.Name):
    return only_node
  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % only_node)
def test_resolve(self):
  samples = """
    a
    a.b
    (c, d.e)
    [f, (g.h.i)]
    j(k, l)
  """
  module = resolve(parser.parse_str(textwrap.dedent(samples)))
  name, attribute, tuple_expr, list_expr, call = (
      stmt.value for stmt in module.body)
  self.assertQNStringIs(name, 'a')
  self.assertQNStringIs(attribute, 'a.b')
  self.assertQNStringIs(tuple_expr.elts[0], 'c')
  self.assertQNStringIs(tuple_expr.elts[1], 'd.e')
  self.assertQNStringIs(list_expr.elts[0], 'f')
  self.assertQNStringIs(list_expr.elts[1], 'g.h.i')
  self.assertQNStringIs(call.func, 'j')
  self.assertQNStringIs(call.args[0], 'k')
  self.assertQNStringIs(call.args[1], 'l')
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
  """Parses `test_fn` and runs analyses up to live values and type info.

  Args:
    test_fn: the function to parse.
    namespace: passed to EntityInfo as the entity namespace.
    arg_types: passed to EntityInfo as `arg_types`.

  Returns:
    The annotated AST.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=None)
  tree = qual_names.resolve(tree)
  cfgs = cfg.build(tree)
  context = transformer.Context(info)
  tree = activity.resolve(tree, context)
  tree = reaching_definitions.resolve(
      tree, context, cfgs, reaching_definitions.Definition)
  tree = live_values.resolve(tree, context, {})
  tree = type_info.resolve(tree, context)
  # A second live_values pass after type_info, with a fresh literals dict.
  tree = live_values.resolve(tree, context, {})
  return tree
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Each method defined directly on `c` is converted with function_to_graph.
  The converted methods are then assembled into a new ClassDef. Base
  classes are imported when whitelisted, and otherwise marked for
  conversion.

  Args:
    c: the class to convert.
    program_ctx: the program context. It supplies the Namer and is updated
      with that namer's name map.

  Returns:
    A tuple (output_nodes, class_name, class_namespace):
      output_nodes: ImportFrom nodes for whitelisted bases, followed by the
        converted ClassDef.
      class_name: the name of the converted class.
      class_namespace: the merged namespaces of the converted methods.

  Raises:
    ValueError: if `c` has no member methods.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # All methods share one namespace; later methods may overwrite entries.
    class_namespace.update(namespace)
    converted_members[m] = node[0]

  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Identity check: `isinstance(object, base)` would also be true for
    # `type` and for ABCs such as collections.abc.Hashable, and would
    # silently replace those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Each method defined directly on `c` is converted with convert_func_to_ast
  and placed into a new ClassDef. Only `object` and whitelisted base classes
  are supported.

  Args:
    c: the class to convert.
    program_ctx: the program context, forwarded to convert_func_to_ast.

  Returns:
    A tuple (output_nodes, class_name, entity_info):
      output_nodes: ImportFrom nodes for whitelisted bases, followed by the
        converted ClassDef.
      class_name: the name of the converted class.
      entity_info: a transformer.EntityInfo carrying the merged namespace and
        the methods' common future features.

  Raises:
    ValueError: if `c` has no member methods, or its methods were built with
      mismatched future features.
    NotImplementedError: if a base class is neither `object` nor whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  class_namespace = {}
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m, program_ctx=program_ctx, do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) ^ frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: it has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        entity_info.future_features))

  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Any other base (besides `object`)
  # raises NotImplementedError.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Identity check: `isinstance(object, base)` would also be true for
    # `type` and for ABCs such as collections.abc.Hashable, and would
    # silently replace those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info