def test_rename_symbols_function(self):
  """Renaming a function's symbol also renames its `def` statement."""
  renames = {qual_names.QN('f'): qual_names.QN('f1')}
  tree = parser.parse('def f():\n pass')
  tree = ast_util.rename_symbols(tree, renames)
  rendered = parser.unparse(tree, include_encoding_marker=False)
  self.assertEqual(rendered.strip(), 'def f1():\n pass')
def test_replace_name_with_subscript(self):
  """A subscripted QN substituted as an assignment target gets proper ctx."""
  template = """ foo = bar """
  subscripted = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))
  replaced = templates.replace(template, foo=subscripted)
  target = replaced[0].targets[0]
  # The subscript itself is stored to, while its container is only loaded.
  self.assertIsInstance(target.ctx, gast.Store)
  self.assertIsInstance(target.value.ctx, gast.Load)
def test_rename_symbols_annotations(self):
  """Annotations on the root node survive rename_symbols unchanged."""
  tree = qual_names.resolve(parser.parse_str('a[i]'))
  anno.setanno(tree, 'foo', 'bar')
  expected = anno.getanno(tree, 'foo')
  tree = ast_util.rename_symbols(
      tree, {qual_names.QN('a'): qual_names.QN('b')})
  self.assertIs(anno.getanno(tree, 'foo'), expected)
def test_rename_symbols_basic(self):
  """A simple name is renamed and stays a plain `str` identifier."""
  tree = qual_names.resolve(parser.parse('a + b'))
  renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
  tree = ast_util.rename_symbols(tree, renames)
  self.assertIsInstance(tree.value.left.id, str)
  rendered = parser.unparse(tree, include_encoding_marker=False)
  self.assertEqual(rendered.strip(), 'renamed_a + b')
def test_rename_symbols_basic(self):
  """A simple name is renamed and stays a plain `str` identifier."""
  tree = qual_names.resolve(parser.parse_str('a + b'))
  renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
  tree = ast_util.rename_symbols(tree, renames)
  renamed_left = tree.body[0].value.left
  self.assertIsInstance(renamed_left.id, str)
  rendered = compiler.ast_to_source(tree)
  self.assertEqual(rendered.strip(), 'renamed_a + b')
def visit_alias(self, node):
  """Marks the symbol bound by an import alias as modified and bound.

  Args:
    node: an `alias` node from an import statement.

  Returns:
    The visited node.
  """
  node = self.generic_visit(node)
  if node.asname is not None:
    bound_name = node.asname
  else:
    # For `import a.b.c` only the root name `a` is a real symbol operation.
    bound_name = node.name.partition('.')[0]
  symbol = qual_names.QN(bound_name)
  self.scope.modified.add(symbol)
  self.scope.bound.add(symbol)
  return node
def test_rename_symbols_basic(self):
  """A renamed tree matches both its own unparsed form and the expectation."""
  renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
  tree = ast_util.rename_symbols(
      qual_names.resolve(parser.parse('a + b')), renames)
  rendered = parser.unparse(tree, include_encoding_marker=False)
  self.assertIsInstance(tree.value.left.id, str)
  for candidate_src in (rendered, 'renamed_a + b'):
    self.assertAstMatches(tree, candidate_src)
def visit_FunctionDef(self, node):
  """Records the callable type of a function definition in `new_symbols`.

  The return annotation, if present, is resolved through `self.resolver`;
  when there is none (or it resolves to None) the return type is `Any`.

  Args:
    node: a FunctionDef node.

  Returns:
    None. A function definition is a statement and yields no value.

  Raises:
    NotImplementedError: if the function has decorators.
  """
  fn_symbol = qual_names.QN(node.name)
  if node.decorator_list:
    raise NotImplementedError('decorators: {}'.format(
        node.decorator_list))

  return_types = None
  if node.returns:
    return_types, _ = self.resolver.res_name(
        self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns))
    if __debug__:
      self._check_set(return_types)

  if return_types is None:
    return_types = {Any}

  self.new_symbols[fn_symbol] = {
      Callable[[Any], ret_type] for ret_type in return_types
  }
  return None
def visit_ClassDef(self, node):
  """Visits a class definition using two nested scopes.

  The outer scope is recorded on the ClassDef node. It covers the class name
  and everything evaluated in the defining context: decorators, bases and
  keywords. A separate inner scope covers the class body.
  """
  self._enter_scope(False)
  node.decorator_list = self.visit_block(node.decorator_list)
  class_symbol = qual_names.QN(node.name)
  self.scope.modified.add(class_symbol)
  self.scope.bound.add(class_symbol)
  node.bases = self.visit_block(node.bases)
  node.keywords = self.visit_block(node.keywords)
  self._exit_and_record_scope(node)

  # The class body gets its own scope.
  self._enter_scope(True)
  node = self.generic_visit(node)
  self._exit_scope()
  return node
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
  """Fake argument resolver that validates the arguments it is queried for.

  Returns:
    A single-element set holding the fake type string `<name>_type`.
  """
  if f_name == 'foo':
    test_self.assertTrue(f_is_local)
    if name == qual_names.QN('x'):
      test_self.assertEqual(type_anno, qual_names.QN('float'))
    elif name == qual_names.QN('y'):
      test_self.assertIsNone(type_anno)
    else:
      test_self.fail('unexpected argument {} for {}'.format(name, f_name))
  elif f_name == 'test_fn':
    test_self.assertFalse(f_is_local)
    test_self.assertEqual(name, qual_names.QN('a'))
    test_self.assertEqual(type_anno, qual_names.QN('int'))
  else:
    test_self.fail('unexpected function name {}'.format(f_name))

  fake_type = str(name) + '_type'
  return {fake_type}
def visit_FunctionDef(self, node):
  """Visits a function definition using three nested scopes.

  1. A scope stored under `anno.Static.SCOPE`, covering the function name
     and its decorators, both of which live in the defining context.
  2. A scope for the function itself, in which the arguments are visited.
  3. A nested scope for the body, stored under `NodeAnno.BODY_SCOPE`.
  """
  # Scope 1: name creation and decorator usage.
  self._enter_scope(False)
  node.decorator_list = self.visit_block(node.decorator_list)
  self.scope.mark_modified(qual_names.QN(node.name))
  anno.setanno(node, anno.Static.SCOPE, self.scope)
  self._exit_scope()

  # Scope 2: the function definition proper.
  self._enter_scope(True)
  # Function args must not be visited while already inside args or a lambda.
  assert not self._in_function_def_args and not self.state[_Lambda].level
  self._in_function_def_args = True
  node.args = self.visit(node.args)
  self._in_function_def_args = False

  # Scope 3: the body, tracked separately for compatibility reasons; it may
  # not be strictly needed.
  self._enter_scope(False)
  node.body = self.visit_block(node.body)
  anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
  self._exit_scope()

  self._exit_scope()
  return node
def visit_node(self, node):
  """Recomputes the reaching definitions into and out of one CFG node.

  Args:
    node: the CFG node to process; its `ast_node` and `prev` are read.

  Returns:
    True if the node's outgoing definitions differ from those recorded on
    the previous visit. Presumably the base class uses this to decide
    whether successors must be revisited.

  Raises:
    ValueError: if a `global`/`nonlocal` name already has a reaching
      definition when the statement is reached.
  """
  prev_defs_out = self.out[node]

  # Start from any extra definitions registered for this AST node, then
  # merge the outgoing definitions of every predecessor.
  defs_in = _NodeState(self.extra_in.get(node.ast_node, None))
  for n in node.prev:
    defs_in |= self.out[n]

  if anno.hasanno(node.ast_node, anno.Static.SCOPE):
    node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
    # The definition objects created by each node must be singletons because
    # their ids are used in equality checks.
    if node not in self.gen_map:
      node_symbols = {}
      for s in node_scope.modified:
        def_ = self._definition_factory()
        # Definitions of function parameters keep a weak back-reference to
        # the parameter they came from.
        if s in node_scope.params:
          def_.param_of = weakref.ref(node_scope.params[s])
        node_symbols[s] = def_
      self.gen_map[node] = _NodeState(node_symbols)

    gen = self.gen_map[node]
    # Any modification or deletion in this node kills incoming definitions
    # of the same symbols.
    kill = node_scope.modified | node_scope.deleted
    defs_out = gen | (defs_in - kill)

  elif isinstance(node.ast_node, (gast.Global, gast.Nonlocal)):
    # Special case for global and nonlocal: they generate a definition,
    # but are not tracked by activity analysis.
    if node not in self.gen_map:
      node_symbols = {}
      for s in node.ast_node.names:
        qn = qual_names.QN(s)
        if qn in defs_in.value:
          # In Python 2, this is a syntax warning. In Python 3, it's an error.
          raise ValueError(
              '"{}" is assigned before global definition'.format(
                  s))
        def_ = self._definition_factory()
        node_symbols[qn] = def_
      self.gen_map[node] = _NodeState(node_symbols)

    gen = self.gen_map[node]
    # Nothing is killed here; the new definitions are simply added.
    defs_out = defs_in | gen

  else:
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols.
    # This Name node below is a literal name, e.g. False
    # This can also happen if activity.py forgot to annotate the node with a
    # scope object.
    assert isinstance(node.ast_node,
                      (gast.Name, gast.Break, gast.Continue, gast.Raise,
                       gast.Pass)), (node.ast_node, node)
    defs_out = defs_in

  self.in_[node] = defs_in
  self.out[node] = defs_out

  # TODO(mdan): Move this to the superclass?
  return prev_defs_out != defs_out
def visit_Nonlocal(self, node):
  """Records every name in a `nonlocal` statement as both read and bound.

  The names are collected in a dedicated scope that is recorded on the node.
  """
  self._enter_scope(False)
  for symbol in map(qual_names.QN, node.names):
    self.scope.read.add(symbol)
    self.scope.bound.add(symbol)
  self._exit_and_record_scope(node)
  return node
def test_rename_symbols_attributes(self):
  """An attribute chain can be renamed as a single composite symbol."""
  tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
  renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
  tree = ast_util.rename_symbols(tree, renames)
  rendered = compiler.ast_to_source(tree)
  self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_attributes(self):
  """An attribute chain can be renamed as a single composite symbol."""
  tree = qual_names.resolve(parser.parse('b.c = b.c.d'))
  renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
  tree = ast_util.rename_symbols(tree, renames)
  rendered = parser.unparse(tree, include_encoding_marker=False)
  self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_global(self):
  """Renaming a symbol also rewrites its entry in a `global` statement."""
  tree = qual_names.resolve(parser.parse('global a, b, c'))
  renames = {qual_names.from_str('b'): qual_names.QN('renamed_b')}
  tree = ast_util.rename_symbols(tree, renames)
  rendered = parser.unparse(tree, include_encoding_marker=False)
  self.assertEqual(rendered.strip(), 'global a, renamed_b, c')
def res_call(self, ns, types_ns, node, f_type, args, keywords):
  """Resolves the return type, and any side effects, of a function call.

  Args:
    ns: namespace (unused here).
    types_ns: types namespace (unused here).
    node: the Call node.
    f_type: tuple holding the resolved type of the callee.
    args: resolved types of the positional arguments.
    keywords: resolved types of the keyword arguments (unused here).

  Returns:
    A pair `(types, side_effects)`. `types` is a set of return types.
    `side_effects` is either None or a dict mapping QNs to sets of types.

  Raises:
    NotImplementedError: if the return type of the call is not known.
  """
  callee = anno.Basic.QN.of(node.func)

  if f_type == (TFRTypes.AG_BUILTIN_FUNC,):
    if callee == QN(QN('ag__'), attr='if_stmt'):
      out_count = node.args[6].value
      # TODO(mdan): Look at the actual types out of if_body.
      modified = {
          qual_names.QN(out.value): {TFRTypes.TENSOR}
          for out in node.args[5].elts[:out_count]
      }
      return {type(None)}, modified

    if callee == QN(QN('ag__'), attr='for_stmt'):
      assert isinstance(node.args[2], ast.Name)
      loop_body_name = str(anno.Basic.QN.of(node.args[2]))
      assert loop_body_name not in self._for_loop_body_fns, (
          'Previously used here: {}. Are you reusing the Resolver across '
          'transformations?').format(self._for_loop_body_fns[loop_body_name])
      self._for_loop_body_fns[loop_body_name] = anno.Basic.ORIGIN.of(node)

      iter_types = args[0]
      assert iter_types & {
          TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, List[int]
      }, (
          iter_types)
      self._for_loop_target_types[loop_body_name] = iter_types
      return {type(None)}, None

    # TODO(mdan): Actually resolve the type here instead.
    fixed_type = _AG_FIXED_RETURN_TYPE.get(callee.qn[1], None)
    if fixed_type is None:
      raise NotImplementedError('return type of {}'.format(callee))
    return {fixed_type}, None

  if f_type == (TFRTypes.TF_RAW_OP,):
    op_def, _ = self._op_defs.lookup(callee.qn[1])
    outputs = op_def.output_arg
    if len(outputs) == 1:
      return {_get_type_from_proto(outputs[0])}, None
    return ({tuple(_get_type_from_proto(out) for out in outputs)}, None)

  if f_type == (TFRTypes.PY_BUILTIN_FUNC,):
    assert callee.is_simple()
    if callee == QN('range'):
      return {List[int]}, None
    if callee == QN('len'):
      return {TFRTypes.INDEX}, None
  elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,):
    return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None

  raise NotImplementedError('Function:', callee, f_type)
def visit_ClassDef(self, node):
  """Visits a class definition using two nested scopes.

  The first scope is stored under `anno.Static.SCOPE`. It covers the class
  name and its decorators, both of which belong to the defining context. A
  second scope covers the class definition itself.
  """
  # Outer scope: name creation and decorator usage.
  self._enter_scope(False)
  node.decorator_list = self.visit_block(node.decorator_list)
  self.scope.mark_modified(qual_names.QN(node.name))
  anno.setanno(node, anno.Static.SCOPE, self.scope)
  self._exit_scope()

  # Inner scope: the class definition proper.
  self._enter_scope(True)
  assert not self._in_function_def_args and not self.state[_Lambda].level
  node = self.generic_visit(node)
  self._exit_scope()
  return node
def visit_FunctionDef(self, node):
  """Visits a function definition using nested scopes.

  The scope recorded on the node covers the function name and decorators.
  A second scope, entered with `_enter_scope(True)`, covers the arguments.
  A nested scope, recorded under `NodeAnno.BODY_SCOPE`, covers the body.
  """
  # Name creation and decorator usage happen in the defining context.
  self._enter_scope(False)
  node.decorator_list = self.visit_block(node.decorator_list)
  fn_symbol = qual_names.QN(node.name)
  self.scope.modified.add(fn_symbol)
  self.scope.bound.add(fn_symbol)
  self._exit_and_record_scope(node)

  # The function definition proper.
  self._enter_scope(True)
  node.args = self.visit(node.args)

  # The body is tracked separately for compatibility reasons; it may not be
  # strictly needed.
  self._enter_scope(False)
  node.body = self.visit_block(node.body)
  self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)

  self._exit_scope()
  return node
def visit_FunctionDef(self, node):
  """Visits a function definition, splitting its symbols across scopes.

  Scopes recorded:
    * on `node`: name creation, decorators, return and argument annotations
      (everything evaluated in the defining context);
    * on `node.args`: the argument declarations;
    * `NodeAnno.BODY_SCOPE`: the function body;
    * `NodeAnno.ARGS_AND_BODY_SCOPE`: arguments and body together.
  """
  with self.state[_FunctionOrClass] as fn:
    fn.node = node

    # Defining context: decorators, annotations and the function's name.
    self._enter_scope(False)
    node.decorator_list = self.visit_block(node.decorator_list)
    if node.returns:
      node.returns = self._process_annotation(node.returns)
    # Argument annotations (including defaults) affect the defining context.
    node = self._visit_arg_annotations(node)
    fn_symbol = qual_names.QN(node.name)
    self.scope.modified.add(fn_symbol)
    self.scope.bound.add(fn_symbol)
    self._exit_and_record_scope(node)

    # Enclosing scope for the function definition proper.
    self._enter_scope(True)

    # The arguments get their own scope, which the CFG uses. Argument
    # declarations only affect the function itself, never the defining
    # context.
    self._enter_scope(False)
    node = self._visit_arg_declarations(node)
    self._exit_and_record_scope(node.args)

    # The body is tracked separately for compatibility reasons; it may not
    # be strictly needed.
    self._enter_scope(False)
    node.body = self.visit_block(node.body)
    self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)

    self._exit_and_record_scope(node, NodeAnno.ARGS_AND_BODY_SCOPE)
    return node
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
  """Fake resolver: argument `a` is an int, every other argument a float."""
  return {int} if name == qual_names.QN('a') else {float}
def res_name(self, ns, types_ns, name):
  """Fake name resolver that only expects to be asked about `g`."""
  test_self.assertEqual(name, qual_names.QN('g'))
  g_type = Callable[[Callable], None]
  return {g_type}, g
def res_call(self, ns, types_ns, node, f_type, args, keywords):
  """Fake call resolver: a call to `g` returns None and makes `x` a str."""
  test_self.assertEqual(node.func.id, 'g')
  test_self.assertEqual(f_type, (Callable[[Callable], None],))
  side_effects = {qual_names.QN('x'): {str}}
  return None, side_effects
def _live_tensors(f, attr_name="inputs"):
  """Returns the indices of the used inputs.

  Note: This currently only handles direct index accesses e.g. op.inputs[1].
  If the function has slicing or list comprehension on attr_name then returns
  _ALL. This ensures that this is correct even if inefficient.

  Args:
    f: A grad function, taking the op as first argument.
    attr_name: op attr to track. "inputs" or "outputs".

  Returns:
    Either one of:
      * set of integers representing individual indices of inputs used
      * the value _ALL, if indices are used but cannot be determined which
      * empty set, if no inputs are used
  """
  node, _ = parser.parse_entity(f, ())
  entity_info = transformer.EntityInfo(
      name=f.__name__,
      source_code=None,
      source_file=None,
      future_features=(),
      namespace=sys.modules[f.__module__].__dict__)
  ctx = transformer.Context(entity_info, None, None)

  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_fndefs.resolve(node, ctx, graphs)
  node = liveness.resolve(node, ctx, graphs)

  op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
  op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name)

  special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name,))
  node = special_tracker.visit(node)

  live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN)
  inputs_outputs_used_qns = set()
  for v in special_tracker.complex_reads:
    # Complicated patterns like op.inputs[:3]. Could be smarter about them
    # if they matter much.
    if v == op_inputs_outputs_name:
      return _ALL
  for v in live_vars_in:
    if v in special_tracker.reads:
      if (v.has_subscript() and v.parent == op_inputs_outputs_name):
        inputs_outputs_used_qns.add(v)
      elif v == op_inputs_outputs_name:
        # When op.{attr_name} is used directly, assume all tensors are
        # used for now. In that case, no point digging further.
        # TODO(mdan): We can descend into tuple expansions.
        return _ALL

  # Tensors may also be used by helper functions that receive the op.
  function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name)
  node = function_calls_tracker.visit(node)

  input_output_indices = set()

  for called_f in function_calls_tracker.calls:
    child_indices = _live_tensors(called_f, attr_name=attr_name)
    if child_indices is _ALL:
      return _ALL
    input_output_indices |= child_indices

  for v in inputs_outputs_used_qns:
    assert v.has_subscript()
    _, subscript = v.qn
    if not subscript.is_simple():
      # Not a number, assuming it can be anything.
      return _ALL
    subscript_val, = subscript.qn
    # Only an integer literal index can be tracked precisely. Anything else
    # (a name, or a literal of another type) may refer to any tensor. Note
    # the `or`: with `and`, a non-Literal QN would be dereferenced for
    # `.value`, and a non-int Literal would be added as an index.
    if (not isinstance(subscript_val, qual_names.Literal) or
        not isinstance(subscript_val.value, int)):
      # Not a number, assuming it can be anything.
      return _ALL
    input_output_indices.add(subscript_val.value)
  return input_output_indices
def visit_Expr(self, node):
  """Wraps a standalone call statement in a control-dependency guard.

  For an expression statement that is a call, emits
  `with ag__.utils.control_dependency_on_returns(call):`. When the call reads
  plain (non-composite) symbols other than `self` and `tf`, the guard's body
  re-binds those symbols through `ag__.utils.alias_tensors`. Otherwise the
  guard is emitted with an empty body.

  Args:
    node: an Expr node.

  Returns:
    The original node if it does not wrap a call. Otherwise the generated
    `with` node, annotated with `anno.Basic.INDENT_BLOCK_REMAINDER` set to
    `(body, alias_map)`. Presumably a later pass moves the remaining
    statements of the enclosing block into that body, applying `alias_map`.
  """
  self.generic_visit(node)
  if isinstance(node.value, gast.Call):
    # Patterns of single function calls, like:
    #   opt.minimize(loss)
    # or:
    #   tf.py_func(...)

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
    # NOTE: We can't guard object attributes because they may not be writable.
    # In addition, avoid renaming well-known names.
    # TODO(mdan): Move these names into config.
    unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
    guarded_args = tuple(s for s in args_scope.read
                         if not s.is_composite() and s not in unguarded_names)

    # TODO(mdan): Include all arguments which depended on guarded_args too.
    # For example, the following will still cause a race:
    #   tf.assign(a, a + 1)
    #   b = a + 1
    #   tf.assign(a, a + 1)  # Control deps here should include `b`
    #   c = b + 1
    # Or maybe we should just raise an "unsafe assign" error?

    if guarded_args:
      # The aliases may need new names to avoid incorrectly making them local.
      # TODO(mdan): This is brutal. It will even rename modules - any fix?
      # Only symbols not modified in the enclosing scope are given new names.
      need_alias = tuple(
          s for s in guarded_args if s not in args_scope.parent.modified)
      aliased_new_names = tuple(
          qual_names.QN(
              self.ctx.namer.new_symbol(
                  s.ssf(), args_scope.parent.referenced)) for s in need_alias)
      alias_map = dict(zip(need_alias, aliased_new_names))
      # A single guarded symbol is assigned directly; several are unpacked
      # from a tuple target.
      if len(guarded_args) == 1:
        s, = guarded_args
        aliased_guarded_args = alias_map.get(s, s)
      else:
        aliased_guarded_args = gast.Tuple(
            [alias_map.get(s, s).ast() for s in guarded_args], None)

      template = """ with ag__.utils.control_dependency_on_returns(call): aliased_guarded_args = ag__.utils.alias_tensors(guarded_args) """
      control_deps_guard = templates.replace(
          template,
          call=node.value,
          aliased_guarded_args=aliased_guarded_args,
          guarded_args=guarded_args)[-1]
    else:
      alias_map = {}

      template = """ with ag__.utils.control_dependency_on_returns(call): pass """
      control_deps_guard = templates.replace(template, call=node.value)[-1]
      # Drop the placeholder `pass` so the body starts out empty.
      control_deps_guard.body = []

    node = control_deps_guard
    anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                 (node.body, alias_map))
  return node
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts each method defined directly on `c` and wraps the results in a
  new ClassDef. Each whitelisted base class is imported under a fresh alias.

  Args:
    c: the class to convert.
    program_ctx: conversion context forwarded to `function_to_graph`.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`:
      * `output_nodes`: the base-class import nodes, followed by the new
        ClassDef;
      * `class_name`: the name of the converted class;
      * `class_namespace`: the merged namespaces of the converted methods.

  Raises:
    ValueError: if `c` has no member methods.
    NotImplementedError: if a base class other than `object` is not
      whitelisted for graph conversion.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    nodes, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        do_rename=False)
    # class_namespace starts out as a dict, so always merge into it.
    class_namespace.update(namespace)
    converted_members[m] = nodes[0]
  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Identity check: `isinstance(object, base)` would also hold for `type`
    # and for ABCs like `collections.abc.Hashable`, silently dropping them.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Converts each method defined directly on `c` and wraps the results in a
  new ClassDef. Each whitelisted base class is imported under a fresh alias.

  Args:
    c: the class to convert.
    program_ctx: conversion context forwarded to `convert_func_to_ast`.

  Returns:
    A tuple `(output_nodes, class_name, entity_info)`:
      * `output_nodes`: the base-class import nodes, followed by the new
        ClassDef;
      * `class_name`: the name of the converted class;
      * `entity_info`: a forged EntityInfo carrying the merged namespace and
        the methods' shared future features.

  Raises:
    ValueError: if `c` has no member methods, or its methods were built with
      mismatched future features.
    NotImplementedError: if a base class other than `object` is not
      whitelisted for graph conversion.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  class_namespace = {}
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m, program_ctx=program_ctx, do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) ^ frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: it has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        entity_info.future_features))
  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Identity check: `isinstance(object, base)` would also hold for `type`
    # and for ABCs like `collections.abc.Hashable`, silently dropping them.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info
def _process_list_of_strings(self, names):
  """Renames, in place, each string in `names` that has an entry in name_map.

  Args:
    names: a mutable list of symbol-name strings. It is modified in place.

  Returns:
    The same `names` list.
  """
  for i, name in enumerate(names):
    qn = qual_names.QN(name)
    if qn in self.name_map:
      names[i] = str(self.name_map[qn])
  return names
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts each method defined directly on `c` and wraps the results in a
  new ClassDef. A whitelisted base class is imported under a fresh alias.
  Any other base class is given a compiled class name, which marks it for
  conversion.

  Args:
    c: the class to convert.
    program_ctx: conversion context; supplies the namer and receives the
      updated name map.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`:
      * `output_nodes`: the base-class import nodes, followed by the new
        ClassDef;
      * `class_name`: the name of the converted class;
      * `class_namespace`: the merged namespaces of the converted methods.

  Raises:
    ValueError: if `c` has no member methods.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # class_namespace starts out as a dict, so always merge into it.
    class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Identity check: `isinstance(object, base)` would also hold for `type`
    # and for ABCs like `collections.abc.Hashable`, silently dropping them.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
assert isinstance(other, _SymbolTable) result = _SymbolTable(self) for s, other_types in other.value.items(): if s not in result.value: self_types = set() result.value[s] = self_types else: self_types = result.value[s] self_types.update(other_types) return result def __repr__(self): return 'SymbolTable {}'.format(self.value) _GETITEM = qual_names.QN('__getitem__') _HANDLERS = { gast.Eq: qual_names.QN('__eq__'), gast.NotEq: qual_names.QN('__ne__'), gast.Lt: qual_names.QN('__lt__'), gast.LtE: qual_names.QN('__le__'), gast.Gt: qual_names.QN('__gt__'), gast.GtE: qual_names.QN('__ge__'), gast.In: