def test_parse_lambda_in_expression(self):
  l = (
      lambda x: lambda y: x + y + 1,
      lambda x: lambda y: x + y + 2,
  )

  node, source = parser.parse_entity(l[0], future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda x: (lambda y: ((x + y) + 1)))')
  self.assertMatchesWithPotentialGarbage(
      source, 'lambda x: lambda y: x + y + 1', ',')

  node, source = parser.parse_entity(l[0](0), future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda y: ((x + y) + 1))')
  self.assertMatchesWithPotentialGarbage(
      source, 'lambda y: x + y + 1', ',')

  node, source = parser.parse_entity(l[1], future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda x: (lambda y: ((x + y) + 2)))')
  self.assertMatchesWithPotentialGarbage(
      source, 'lambda x: lambda y: x + y + 2', ',')

  node, source = parser.parse_entity(l[1](0), future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda y: ((x + y) + 2))')
  self.assertMatchesWithPotentialGarbage(
      source, 'lambda y: x + y + 2', ',')

def __repr__(self):
  if isinstance(self.ast_node, gast.FunctionDef):
    return 'def %s' % self.ast_node.name
  elif isinstance(self.ast_node, gast.ClassDef):
    return 'class %s' % self.ast_node.name
  elif isinstance(self.ast_node, gast.withitem):
    return parser.unparse(
        self.ast_node.context_expr, include_encoding_marker=False).strip()
  return parser.unparse(
      self.ast_node, include_encoding_marker=False).strip()

def assert_body_anfs_as_expected(self, expected_fn, test_fn):
  # Testing the code bodies only. Wrapping them in functions so the
  # syntax highlights nicely, but Python doesn't try to execute the
  # statements.
  node, _ = parser.parse_entity(test_fn, future_features=())
  orig_source = parser.unparse(node, indentation='  ')
  orig_str = textwrap.dedent(orig_source).strip()
  config = [(anf.ANY, anf.LEAVE)]  # Configuration to transform nothing.
  node = anf.transform(node, self._simple_context(), config=config)
  new_source = parser.unparse(node, indentation='  ')
  new_str = textwrap.dedent(new_source).strip()
  self.assertEqual(orig_str, new_str)

def test_parse_lambda_resolution_by_signature(self):
  l = lambda x: lambda x, y: x + y

  node, source = parser.parse_entity(l, future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda x: (lambda x, y: (x + y)))')
  self.assertEqual(source, 'lambda x: lambda x, y: x + y')

  node, source = parser.parse_entity(l(0), future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda x, y: (x + y))')
  self.assertEqual(source, 'lambda x, y: x + y')

def get_definition_directive(self, node, directive, arg, default):
  """Returns the unique directive argument for a symbol.

  See lang/directives.py for details on directives.

  Example:
    # Given a directive in the code:
    ag.foo_directive(bar, baz=1)

    # One can write for an AST node Name(id='bar'):
    get_definition_directive(node, ag.foo_directive, 'baz', default=None)

  Args:
    node: ast.AST, the node representing the symbol for which the directive
      argument is needed.
    directive: Callable[..., Any], the directive to search.
    arg: str, the directive argument to return.
    default: Any, the value to return when the symbol carries no annotation
      for the requested directive argument.

  Raises:
    ValueError: if conflicting annotations have been found
  """
  defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
  if not defs:
    return default

  arg_values_found = []
  for def_ in defs:
    if (directive in def_.directives and arg in def_.directives[directive]):
      arg_values_found.append(def_.directives[directive][arg])

  if not arg_values_found:
    return default

  if len(arg_values_found) == 1:
    return arg_values_found[0]

  # If multiple annotations reach the symbol, they must all match. If they
  # do, return any of them.
  first_value = arg_values_found[0]
  for other_value in arg_values_found[1:]:
    if not ast_util.matches(first_value, other_value):
      qn = anno.getanno(node, anno.Basic.QN)
      raise ValueError(
          '%s has ambiguous annotations for %s(%s): %s, %s' %
          (qn, directive.__name__, arg,
           parser.unparse(other_value).strip(),
           parser.unparse(first_value).strip()))
  return first_value

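# A minimal usage sketch for `get_definition_directive` (hypothetical call
# site inside a transformer.Base subclass; `directives.set_loop_options` and
# its `parallel_iterations` argument are autograph's real loop directive).
# The returned value is the AST expression that was passed to the directive
# call, so it can be unparsed or matched:
#
#   parallelism = self.get_definition_directive(
#       node, directives.set_loop_options, 'parallel_iterations',
#       default=None)
#   if parallelism is not None:
#     print(parser.unparse(parallelism, include_encoding_marker=False))
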
def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals):
  """Returns the transformed function factory for a given input."""
  if self._cache.has(fn, cache_subkey):
    return self._cached_factory(fn, cache_subkey)

  with self._cache_lock:
    # Check again under lock.
    if self._cache.has(fn, cache_subkey):
      return self._cached_factory(fn, cache_subkey)

    logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
    nodes, ctx = self._transform_function(fn, user_context)

    if logging.has_verbosity(2):
      logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))

    factory = _TransformedFnFactory(
        ctx.info.name, fn.__code__.co_freevars, extra_locals)
    factory.create(nodes, ctx.namer, future_features=ctx.info.future_features)
    self._cache[fn][cache_subkey] = factory
  return factory

def do_parse_and_test(lam, **unused_kwargs):
  node, source = parser.parse_entity(lam, future_features=())

  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda x: x)')
  self.assertEqual(source, 'lambda x: x, named_arg=1)')

def test_parse_lambda_complex_body(self):
  l = lambda x: (  # pylint:disable=g-long-lambda
      x.y(
          [],
          x.z,
          (),
          x[0:2],
      ),
      x.u,
      'abc',
      1,
  )

  node, source = parser.parse_entity(l, future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      "(lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1))")
  self.assertEqual(
      source, ('lambda x: (  # pylint:disable=g-long-lambda\n'
               '      x.y(\n'
               '          [],\n'
               '          x.z,\n'
               '          (),\n'
               '          x[0:2],\n'
               '      ),\n'
               '      x.u,\n'
               '      \'abc\',\n'
               '      1,'))

def test_ext_slice_roundtrip(self):
  def ext_slice(n):
    return n[:, :], n[0, :], n[:, 0]

  node, _ = parser.parse_entity(ext_slice, future_features=())
  source = parser.unparse(node)

  self.assertAstMatches(node, source, expr=False)

def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, namespace):
      * ast: An AST representing an entity with interface equivalent to `o`,
        but which, when executed, creates a TF graph.
      * new_name: The symbol name under which the new entity can be found.
      * namespace: A dict mapping all symbols visible to the converted entity,
        keyed by their symbol name.

  Raises:
    NotImplementedError: if entity is of a type that is not yet supported.
  """
  logging.log(1, 'Converting %s', o)

  nodes, name, entity_info = convert_func_to_ast(o, program_ctx)

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info

def assertLambdaNodes(self, matching_nodes, expected_bodies):
  self.assertEqual(len(matching_nodes), len(expected_bodies))
  for node in matching_nodes:
    self.assertIsInstance(node, gast.Lambda)
    self.assertIn(
        parser.unparse(node.body, include_encoding_marker=False).strip(),
        expected_bodies)

def test_unparse(self):
  node = gast.If(
      test=gast.Constant(1, kind=None),
      body=[
          gast.Assign(
              targets=[
                  gast.Name(
                      'a', ctx=gast.Store(), annotation=None,
                      type_comment=None)
              ],
              value=gast.Name(
                  'b', ctx=gast.Load(), annotation=None, type_comment=None))
      ],
      orelse=[
          gast.Assign(
              targets=[
                  gast.Name(
                      'a', ctx=gast.Store(), annotation=None,
                      type_comment=None)
              ],
              value=gast.Constant('c', kind=None))
      ])

  source = parser.unparse(node, indentation='  ')
  self.assertEqual(
      textwrap.dedent("""
          # coding=utf-8
          if 1:
              a = b
          else:
              a = 'c'
      """).strip(), source.strip())

def test_parse_lambda_complex_body(self):
  l = lambda x: (  # pylint:disable=g-long-lambda
      x.y(
          [],
          x.z,
          (),
          x[0:2],
      ),
      x.u,
      'abc',
      1,
  )

  node, source = parser.parse_entity(l, future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      "(lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1))")

  base_source = ('lambda x: (  # pylint:disable=g-long-lambda\n'
                 '      x.y(\n'
                 '          [],\n'
                 '          x.z,\n'
                 '          (),\n'
                 '          x[0:2],\n'
                 '      ),\n'
                 '      x.u,\n'
                 '      \'abc\',\n'
                 '      1,')
  # The complete source includes the trailing parenthesis. But that is only
  # detected in runtimes which correctly track end_lineno for ASTs.
  self.assertIn(source, (base_source, base_source + '\n  )'))

def test_rename_symbols_function(self):
  node = parser.parse('def f():\n  pass')
  node = ast_util.rename_symbols(
      node, {qual_names.QN('f'): qual_names.QN('f1')})
  source = parser.unparse(node, include_encoding_marker=False)
  self.assertEqual(source.strip(), 'def f1():\n    pass')

def test_parse_lambda_prefix_cleanup(self):
  lambda_lam = lambda x: x + 1

  node, source = parser.parse_entity(lambda_lam, future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda x: (x + 1))')
  self.assertEqual(source, 'lambda x: x + 1')

def test_rename_symbols_attributes(self):
  node = parser.parse('b.c = b.c.d')
  node = qual_names.resolve(node)
  node = ast_util.rename_symbols(
      node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})
  source = parser.unparse(node, include_encoding_marker=False)
  self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')

def test_rename_symbols_global(self):
  node = parser.parse('global a, b, c')
  node = qual_names.resolve(node)
  node = ast_util.rename_symbols(
      node, {qual_names.from_str('b'): qual_names.QN('renamed_b')})
  source = parser.unparse(node, include_encoding_marker=False)
  self.assertEqual(source.strip(), 'global a, renamed_b, c')

def transform_function(self, fn, user_context):
  """Transforms a function.

  See GenericTranspiler.transform_function.

  This overload wraps the parent's `transform_function`, adding caching and
  facilities to instantiate the output as a Python object. It also adds
  facilities to make new symbols available to the generated Python code,
  visible as local variables - see `get_extra_locals`.

  Args:
    fn: A function or lambda.
    user_context: An opaque object (may be None) that is forwarded to
      transform_ast, through the ctx.user_context argument.

  Returns:
    A tuple:
      * A function or lambda with the same signature and closure as `fn`
      * The temporary module into which the transformed function was loaded
      * The source map as a
        Dict[origin_info.LineLocation, origin_info.OriginInfo]
  """
  cache_subkey = self.get_caching_key(user_context)

  if self._cache.has(fn, cache_subkey):
    # Fast path: use a lock-free check.
    factory = self._cached_factory(fn, cache_subkey)

  else:
    with self._cache_lock:
      # Check again under lock.
      if self._cache.has(fn, cache_subkey):
        factory = self._cached_factory(fn, cache_subkey)

      else:
        logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
        # TODO(mdan): Confusing overloading pattern. Fix.
        nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)

        if logging.has_verbosity(2):
          logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))

        factory = _PythonFnFactory(
            ctx.info.name, fn.__code__.co_freevars, self.get_extra_locals())
        factory.create(
            nodes, ctx.namer, future_features=ctx.info.future_features)
        self._cache[fn][cache_subkey] = factory

  transformed_fn = factory.instantiate(
      globals_=fn.__globals__,
      closure=fn.__closure__ or (),
      defaults=fn.__defaults__,
      kwdefaults=getattr(fn, '__kwdefaults__', None))
  return transformed_fn, factory.module, factory.source_map

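# A minimal sketch of using this overload (assumes the surrounding PyToPy
# class; `NoopTranspiler` and `f` are illustrative only). The no-op transform
# exercises the caching and instantiation machinery without changing the
# code:
class NoopTranspiler(PyToPy):

  def get_caching_key(self, user_context):
    # All inputs are transformed identically, so one cache entry suffices.
    return 'default'

  def get_extra_locals(self):
    # No extra symbols are made visible to the generated code.
    return {}

  def transform_ast(self, node, ctx):
    # Returns the AST unchanged.
    return node


def f(x):
  return x + 1

new_f, module, source_map = NoopTranspiler().transform_function(f, None)
assert new_f(1) == 2
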
def test_parse_lambda_resolution_by_location(self):
  _ = lambda x: x + 1
  l = lambda x: x + 1
  _ = lambda x: x + 1

  node, source = parser.parse_entity(l, future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda x: (x + 1))')
  self.assertEqual(source, 'lambda x: x + 1')

def test_rename_symbols_basic(self):
  node = parser.parse('a + b')
  node = qual_names.resolve(node)
  node = ast_util.rename_symbols(
      node, {qual_names.QN('a'): qual_names.QN('renamed_a')})

  self.assertIsInstance(node.value.left.id, str)
  source = parser.unparse(node, include_encoding_marker=False)
  self.assertEqual(source.strip(), 'renamed_a + b')

def visit_If(self, node):
  self.emit('if ')
  # This is just for simplicity. A real generator will walk the tree and
  # emit proper code.
  self.emit(parser.unparse(node.test, include_encoding_marker=False))
  self.emit(' {\n')
  self.visit_block(node.body)
  self.emit('} else {\n')
  self.visit_block(node.orelse)
  self.emit('}\n')

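# For illustration (hypothetical input), the emitter above turns
#
#   if x > 0:
#     a = 1
#   else:
#     a = 2
#
# into roughly the following, with the test expression unparsed verbatim
# (the unparser parenthesizes expressions) and the bodies handled by
# `visit_block`:
#
#   if (x > 0) {
#     ...
#   } else {
#     ...
#   }
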
def test_parse_lambda_multiline(self):
  l = (
      lambda x: lambda y: x + y  # pylint:disable=g-long-lambda
      - 1)

  node, source = parser.parse_entity(l, future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda x: (lambda y: ((x + y) - 1)))')
  self.assertMatchesWithPotentialGarbage(
      source,
      ('lambda x: lambda y: x + y  # pylint:disable=g-long-lambda\n'
       '      - 1'), ')')

  node, source = parser.parse_entity(l(0), future_features=())
  self.assertEqual(
      parser.unparse(node, include_encoding_marker=False),
      '(lambda y: ((x + y) - 1))')
  self.assertMatchesWithPotentialGarbage(
      source,
      ('lambda y: x + y  # pylint:disable=g-long-lambda\n'
       '      - 1'), ')')

def test_rename_symbols_basic(self):
  node = parser.parse('a + b')
  node = qual_names.resolve(node)
  node = ast_util.rename_symbols(
      node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
  source = parser.unparse(node, include_encoding_marker=False)
  expected_node_src = 'renamed_a + b'

  self.assertIsInstance(node.value.left.id, str)
  self.assertAstMatches(node, source)
  self.assertAstMatches(node, expected_node_src)

def test_convert_entity_to_ast_function_with_defaults(self):
  b = 2
  c = 1

  def f(a, d=c + 1):
    return a + b + d

  program_ctx = self._simple_program_ctx()
  nodes, name, _ = conversion.convert_entity_to_ast(f, program_ctx)
  fn_node, = nodes
  self.assertIsInstance(fn_node, gast.FunctionDef)
  self.assertEqual('tf__f', name)
  self.assertEqual(
      parser.unparse(
          fn_node.args.defaults[0], include_encoding_marker=False).strip(),
      'None')

def visit_IfExp(self, node):
  template = '''
      ag__.if_exp(
          test,
          lambda: true_expr,
          lambda: false_expr,
          expr_repr)
  '''
  expr_repr = parser.unparse(
      node.test, include_encoding_marker=False).strip()
  return templates.replace_as_expression(
      template,
      test=node.test,
      true_expr=node.body,
      false_expr=node.orelse,
      expr_repr=gast.Constant(expr_repr, kind=None))

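# For illustration: under the template above, a conditional expression such
# as `x if c(x) else y` is rewritten into roughly
#
#   ag__.if_exp(c(x), lambda: x, lambda: y, '(c(x))')
#
# The lambdas defer evaluation of both branches, and `expr_repr` carries a
# printable form of the test expression, e.g. for diagnostics.
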
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, namespace):
      * ast: An AST representing an entity with interface equivalent to `o`,
        but which, when executed, creates a TF graph.
      * new_name: The symbol name under which the new entity can be found.
      * namespace: A dict mapping all symbols visible to the converted entity,
        keyed by their symbol name.

  Raises:
    NotImplementedError: if entity is of a type that is not yet supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif tf_inspect.ismethod(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif hasattr(o, '__class__'):
    # Note: this should only be raised when attempting to convert the object
    # directly. converted_call should still support it.
    raise NotImplementedError(
        'cannot convert entity "{}": object conversion is not yet'
        ' supported.'.format(o))
  else:
    raise NotImplementedError(
        'Entity "%s" has unsupported type "%s". Only functions and classes '
        'are supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info

def test_unparse(self):
  node = gast.If(
      test=gast.Num(1),
      body=[
          gast.Assign(
              targets=[gast.Name('a', gast.Store(), None)],
              value=gast.Name('b', gast.Load(), None))
      ],
      orelse=[
          gast.Assign(
              targets=[gast.Name('a', gast.Store(), None)],
              value=gast.Str('c'))
      ])

  source = parser.unparse(node, indentation='  ')
  self.assertEqual(
      textwrap.dedent("""
          # coding=utf-8
          if 1:
            a = b
          else:
            a = 'c'
      """).strip(), source.strip())

def load_ast(nodes,
             indentation='  ',
             include_source_map=False,
             delete_on_exit=True):
  """Loads the given AST as a Python module.

  Compiling the AST code this way ensures that the source code is readable by
  e.g. `pdb` or `inspect`.

  Args:
    nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST
      object.
    indentation: Text, the string to use for indentation.
    include_source_map: bool, whether to return a source map.
    delete_on_exit: bool, whether to delete the temporary file used for
      compilation on exit.

  Returns:
    Tuple[module, Text, Dict[LineLocation, OriginInfo]], containing:
    the module containing the unparsed nodes, the source code corresponding
    to nodes, and the source map. If include_source_map is False, the source
    map will be None.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  source = parser.unparse(nodes, indentation=indentation)
  module, _ = load_source(source, delete_on_exit)

  if include_source_map:
    source_map = origin_info.create_source_map(nodes, source, module.__file__)
  else:
    source_map = None

  # TODO(mdan): Return a structured object.
  return module, source, source_map

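# A minimal usage sketch of `load_ast` (assumes the pyct `parser` module used
# elsewhere in this file; `double` is a hypothetical example function):
example_node = parser.parse('def double(x):\n  return x * 2')
example_module, example_source, _ = load_ast(example_node)
assert example_module.double(3) == 6
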
def visit_Return(self, node):
  self.emit(parser.unparse(node, include_encoding_marker=False))
  self.emit('\n')

def debug_print_src(self, node):
  """Helper method useful for debugging. Prints the AST as code."""
  if __debug__:
    print(parser.unparse(node))
  return node