def compiled(self, node, namespace, *symbols):
  """Compiles `node` and yields the module wired to fake TF/AutoGraph modules.

  This is a generator that yields exactly once; presumably a decorator outside
  this view turns it into a context manager (TODO confirm).

  Args:
    node: AST to compile via `compiler.ast_to_object`.
    namespace: name -> value pairs copied into the compiled module's globals
      after the fake `ag__` module is installed.
    *symbols: forwarded to `make_fake_mod` to build the fake `tf` module.

  Yields:
    The compiled module. Calls made through the mock `converted_call` are
    recorded as argument tuples in `self.dynamic_calls`, and each returns 7.

  Raises:
    Re-raises any exception raised during compilation or at the `yield`, after
    printing the offending AST (if compilation never produced source) or the
    generated source.
  """
  source = None
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args)
    return 7

  try:
    result, source = compiler.ast_to_object(node)
    result.tf = self.make_fake_mod('fake_tf', *symbols)

    ag_module = self.make_fake_mod('fake_ag', converted_call)
    ag_attrs = ag_module.__dict__
    ag_attrs.update(operators.__dict__)
    ag_attrs['utils'] = utils
    ag_attrs['rewrite_graph_construction_error'] = (
        errors.rewrite_graph_construction_error)

    module_attrs = result.__dict__
    module_attrs['ag__'] = ag_module
    # Same effect as assigning each (key, value) pair in turn.
    module_attrs.update(namespace.items())
    yield result
  except Exception:  # pylint:disable=broad-except
    if source is not None:
      print('Offending compiled code:\n%s' % source)
    else:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    raise
def compiled(self, node, *symbols):
  """Compiles `node` and yields the module with fake TF/AutoGraph attributes.

  This is a generator that yields exactly once; presumably a decorator outside
  this view turns it into a context manager (TODO confirm).

  Args:
    node: AST to compile via `compiler.ast_to_object`.
    *symbols: forwarded to `make_fake_mod` to build the fake `tf` module.

  Yields:
    The compiled module, with `tf`, `autograph_utils`, `autograph_api` and the
    `__ops` global set. Calls to the mock `converted_call` are recorded in
    `self.dynamic_calls`, and each returns 7.

  Raises:
    Re-raises any exception raised during compilation or at the `yield`, after
    printing the offending AST or the generated source.
  """
  source = None
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args)
    return 7

  try:
    result, source = compiler.ast_to_object(node)
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    result.autograph_utils = utils
    result.autograph_api = self.make_fake_mod('fake_api', converted_call)
    # Dunder-style global, so it is set through the module dict.
    vars(result)['__ops'] = operators
    yield result
  except Exception:  # pylint:disable=broad-except
    if source is not None:
      print('Offending compiled code:\n%s' % source)
    else:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    raise
def compiled(self, node, *symbols):
  """Compiles `node` and yields the module wired to a fake `ag__` module.

  This is a generator that yields exactly once; presumably a decorator outside
  this view turns it into a context manager (TODO confirm).

  Args:
    node: AST to compile via `compiler.ast_to_object`.
    *symbols: forwarded to `make_fake_mod` to build the fake `tf` module.

  Yields:
    The compiled module. Its `ag__` global exposes the mock `converted_call`
    (which records its argument tuple in `self.dynamic_calls` and returns 7),
    everything in `operators`, `utils`, and
    `rewrite_graph_construction_error`.

  Raises:
    Re-raises any exception raised during compilation or at the `yield`, after
    printing the offending AST or the generated source.
  """
  source = None
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args)
    return 7

  try:
    result, source = compiler.ast_to_object(node)
    result.tf = self.make_fake_mod('fake_tf', *symbols)

    ag_module = self.make_fake_mod('fake_ag', converted_call)
    ag_attrs = ag_module.__dict__
    ag_attrs.update(operators.__dict__)
    ag_attrs['utils'] = utils
    ag_attrs['rewrite_graph_construction_error'] = (
        errors.rewrite_graph_construction_error)

    result.__dict__['ag__'] = ag_module
    yield result
  except Exception:  # pylint:disable=broad-except
    if source is not None:
      print('Offending compiled code:\n%s' % source)
    else:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    raise
def test_codegen_gens(self):
  """Checks that randomly generated function ASTs compile to a module."""
  np.random.seed(0)
  for _ in range(1000):
    node = codegen.generate_random_functiondef()
    # ast_to_object returns a (module, source) pair. Asserting on the pair
    # itself could never fail, so check the compiled module instead.
    module, _ = compiler.ast_to_object(node)
    self.assertIsNotNone(
        module, 'Generated invalid AST that could not convert to source.')
def test_ast_to_object(self):
  """Compiles a hand-built `def f(a): return a + 1` AST and checks the output.

  Verifies the returned source, the behavior of the compiled function, and the
  contents of the file that backs the compiled module.
  """
  fn_args = gast.arguments(
      args=[gast.Name('a', gast.Param(), None)],
      vararg=None,
      kwonlyargs=[],
      kwarg=None,
      defaults=[],
      kw_defaults=[])
  add_one = gast.BinOp(
      op=gast.Add(),
      left=gast.Name('a', gast.Load(), None),
      right=gast.Num(1))
  node = gast.FunctionDef(
      name='f',
      args=fn_args,
      body=[gast.Return(add_one)],
      decorator_list=[],
      returns=None)

  module, source = compiler.ast_to_object(node)

  expected_source = """
    def f(a):
      return a + 1
  """
  expected = textwrap.dedent(expected_source).strip()
  self.assertEqual(expected, source.strip())
  self.assertEqual(2, module.f(1))
  with open(module.__file__, 'r') as temp_output:
    written = temp_output.read()
  self.assertEqual(expected, written.strip())
def test_ast_to_object(self):
  """Round-trips a synthetic `f(a) -> a + 1` AST through the compiler."""
  param = gast.Name('a', gast.Param(), None)
  body = [
      gast.Return(
          gast.BinOp(
              op=gast.Add(),
              left=gast.Name('a', gast.Load(), None),
              right=gast.Num(1)))
  ]
  node = gast.FunctionDef(
      name='f',
      args=gast.arguments(
          args=[param],
          vararg=None,
          kwonlyargs=[],
          kwarg=None,
          defaults=[],
          kw_defaults=[]),
      body=body,
      decorator_list=[],
      returns=None)

  module, source = compiler.ast_to_object(node)

  expected_source = """
    def f(a):
      return a + 1
  """
  want = textwrap.dedent(expected_source).strip()

  # The returned source and the module's backing file must both match.
  self.assertEqual(want, source.strip())
  self.assertEqual(2, module.f(1))
  with open(module.__file__, 'r') as temp_output:
    self.assertEqual(want, temp_output.read().strip())
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The converted entity (for a function, one with a signature identical to
    `e`), which when executed creates a TF graph that has the same
    functionality as the original entity.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  # Assemble a single module: the configured import statements first, then
  # every converted dependency.
  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    module.body.extend(parser.parse_str(import_line).body)
  for dep in conversion_map.dependency_cache.values():
    module.body.append(dep)
  compiled_node, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    for key, val in inspect_utils.getnamespace(e).items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_node.__dict__:
        compiled_node.__dict__[key] = val
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
def test_basic(self):
  """An ANF-transformed trivial function still returns the same value."""

  # Parsed from source below, so its name and body must stay as written.
  def test_function():
    a = 0
    return a

  parsed, _ = parser.parse_entity(test_function)
  fn_def = parsed.body[0]
  transformed = anf.transform(fn_def, self._simple_source_info())
  compiled_module, _ = compiler.ast_to_object(transformed)
  self.assertEqual(test_function(), compiled_module.test_function())
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The converted entity (for a function, one with a signature identical to
    `e`), which when executed creates a TF graph that has the same
    functionality as the original entity.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Dependencies are emitted in reverse cache order; presumably so that
  # definitions precede their users (TODO confirm).
  module = gast.Module([])
  for dep in reversed(program_ctx.dependency_cache.values()):
    module.body.append(dep)
  compiled_node, compiled_src = compiler.ast_to_object(
      module, source_prefix=program_ctx.required_imports)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_node.__dict__:
      compiled_node.__dict__[key] = val
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
def _remover_wrapper(self, f, remove_decorators):
  """Runs the decorators transform on `f` and returns the compiled function.

  Args:
    f: the function to parse, analyze and transform.
    remove_decorators: forwarded to `decorators.transform`.

  Returns:
    The attribute named `f.__name__` on the compiled module.
  """
  known_symbols = dict(
      self_removing_decorator=self_removing_decorator,
      simple_decorator=simple_decorator)
  analyzed = self.parse_and_analyze(f, known_symbols)
  transformed, _ = decorators.transform(
      analyzed, remove_decorators=remove_decorators)
  compiled_module, _ = compiler.ast_to_object(transformed)
  return getattr(compiled_module, f.__name__)
def test_keywords_to_dict(self):
  """keywords_to_dict output evaluates to the expected dict when compiled."""
  call = parser.parse_expression('f(a=b, c=1, d=\'e\')')
  dict_node = ast_util.keywords_to_dict(call.keywords)

  # Make sure we generate a usable dict node by attaching it to a variable and
  # compiling everything.
  program = parser.parse_str('b = 3')
  target = ast.Name(id='d', ctx=ast.Store())
  program.body.append(ast.Assign([target], dict_node))

  compiled_module, _ = compiler.ast_to_object(program)
  self.assertDictEqual(compiled_module.d, {'a': 3, 'c': 1, 'd': 'e'})
def test_keywords_to_dict(self):
  """A function returning the keywords_to_dict node yields the expected dict."""
  call = parser.parse_expression('f(a=b, c=1, d=\'e\')')
  dict_node = ast_util.keywords_to_dict(call.keywords)

  # Make sure we generate a usable dict node by returning it from a function
  # whose parameter `b` supplies the referenced name, then compiling it.
  fn_def = parser.parse_str('def f(b): pass').body[0]
  fn_def.body.append(ast.Return(dict_node))

  compiled_module, _ = compiler.ast_to_object(fn_def)
  self.assertDictEqual(compiled_module.f(3), {'a': 3, 'c': 1, 'd': 'e'})
def test_keywords_to_dict(self):
  """Compiling `d = <keywords_to_dict node>` binds `d` to the expected dict."""
  keyword_nodes = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
  generated = ast_util.keywords_to_dict(keyword_nodes)

  # Make sure we generate a usable dict node by attaching it to a variable and
  # compiling everything.
  module_node = parser.parse_str('b = 3')
  assign = ast.Assign([ast.Name(id='d', ctx=ast.Store())], generated)
  module_node.body.extend([assign])

  compiled_module, _ = compiler.ast_to_object(module_node)
  expected = {'a': 3, 'c': 1, 'd': 'e'}
  self.assertDictEqual(compiled_module.d, expected)
def test_replace_name_with_dict(self):
  """A name in a template can be replaced by a dict literal expression."""
  template = """
    def test_fn():
      return foo['bar']
  """

  source = parser.parse_expression('{\'bar\': 3}')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEqual: the assertEquals alias was removed in Python 3.12.
  self.assertEqual(3, result.test_fn())
def test_replace_tuple(self):
  """A template name can be replaced by a tuple of names."""
  template = """
    def test_fn(a, c):
      return b,
  """

  node = templates.replace(template, b=('a', 'c'))[0]
  result, _ = compiler.ast_to_object(node)
  # assertEqual: the assertEquals alias was removed in Python 3.12.
  self.assertEqual((2, 3), result.test_fn(2, 3))
def test_noop(self):
  """An undecorated function still computes the same result after transform."""

  # Parsed from source, so its name and body must stay as written.
  def test_fn(a):
    return a

  analyzed = self.parse_and_analyze(test_fn, {})
  transformed = decorators.transform(analyzed, self.ctx)
  compiled_module, _ = compiler.ast_to_object(transformed)
  self.assertEqual(1, compiled_module.test_fn(1))
def test_noop(self):
  """With nothing to remove, the transform reports no deps and keeps behavior."""

  # Parsed from source, so its name and body must stay as written.
  def test_fn(a):
    return a

  analyzed = self.parse_and_analyze(test_fn, {})
  transformed, extra_deps = decorators.transform(
      analyzed, remove_decorators=())
  compiled_module, _ = compiler.ast_to_object(transformed)
  self.assertFalse(extra_deps)
  self.assertEqual(1, compiled_module.test_fn(1))
def test_replace_variable(self):
  """Renaming `a` to `b` rewrites every occurrence in the template."""
  template = """
    def test_fn(a):
      a += 1
      a = 2 * a + 1
      return b
  """

  node = templates.replace(template, a='b')[0]
  result, _ = compiler.ast_to_object(node)
  # assertEqual: the assertEquals alias was removed in Python 3.12.
  self.assertEqual(7, result.test_fn(2))
def test_replace_function_name(self):
  """The function name in a template can be replaced."""
  template = """
    def fname(a):
      a += 1
      a = 2 * a + 1
      return a
  """

  node = templates.replace(template, fname='test_fn')[0]
  result, _ = compiler.ast_to_object(node)
  # assertEqual: the assertEquals alias was removed in Python 3.12.
  self.assertEqual(7, result.test_fn(2))
def test_parser_compile_idempotent(self):
  """Parsing then compiling a function reproduces its original source text."""

  # The source text of this function is compared verbatim against the compiler
  # output, so its formatting must stay exactly as written.
  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  original_src = textwrap.dedent(tf_inspect.getsource(test_fn))
  fn_node = parser.parse_entity(test_fn)[0].body[0]
  compiled_module = compiler.ast_to_object(fn_node)[0]
  roundtrip_src = tf_inspect.getsource(compiled_module.test_fn)
  self.assertEqual(original_src, roundtrip_src)
def test_replace_attribute(self):
  """An attribute name can be replaced by a string; non-strings are rejected."""
  # types.ModuleType replaces imp.new_module: the `imp` module was removed in
  # Python 3.12.
  import types  # pylint:disable=g-import-not-at-top

  template = """
    def test_fn(a):
      return a.foo
  """

  node = templates.replace(template, foo='b')[0]
  result, _ = compiler.ast_to_object(node)
  mod = types.ModuleType('test')
  mod.b = 3
  self.assertEqual(3, result.test_fn(mod))

  with self.assertRaises(ValueError):
    templates.replace(template, foo=1)
def test_replace_name_with_call(self):
  """A name in a template can be replaced by a call expression."""
  template = """
    def test_fn():
      b = 5
      def g(a):
        return 3 * a
      def f():
        return g
      return foo
  """

  source = parser.parse_expression('f()(b)')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEqual: the assertEquals alias was removed in Python 3.12.
  self.assertEqual(15, result.test_fn())
def _transform(self, f, autograph_decorators):
  """Applies the decorators converter to `f` and returns the compiled result.

  Args:
    f: the function to convert.
    autograph_decorators: forwarded to `self.prepare`.

  Returns:
    The attribute named `f.__name__` on the compiled module. The program's
    additional imports are prepended to the compiled source.
  """
  known_symbols = {
      'self_transform_decorator': self_transform_decorator,
      'simple_decorator': simple_decorator,
      'converter_testing': converter_testing,
  }
  prepared, ctx = self.prepare(
      f,
      known_symbols,
      recursive=False,
      autograph_decorators=autograph_decorators)
  transformed = decorators.transform(prepared, ctx)
  prefix = '\n'.join(ctx.program.additional_imports)
  compiled_module, _ = compiler.ast_to_object(transformed, source_prefix=prefix)
  return getattr(compiled_module, f.__name__)
def test_replace_code_block(self):
  """A placeholder statement can be replaced by a list of statements."""
  template = """
    def test_fn(a):
      block
      return a
  """

  node = templates.replace(
      template,
      block=[
          gast.Assign([
              gast.Name('a', None, None)
          ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
      ] * 2)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEqual: the assertEquals alias was removed in Python 3.12.
  self.assertEqual(3, result.test_fn(1))
def test_replace_call_keyword(self):
  """A keyword placeholder can be replaced by keywords; invalid values raise."""
  template = """
    def test_fn():
      def f(a, d, f):
        return a + d + f
      return f(1, kws=None)
  """

  source = parser.parse_expression('f(d=3, f=5)')
  node = templates.replace(template, kws=source.keywords)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEqual: the assertEquals alias was removed in Python 3.12.
  self.assertEqual(9, result.test_fn())

  # Each invalid replacement gets its own context. Inside a single
  # assertRaises block, the second call never runs once the first one raises.
  with self.assertRaises(ValueError):
    templates.replace(template, kws=[])
  with self.assertRaises(ValueError):
    templates.replace(template, kws=1)
def test_replace_code_block(self):
  """A placeholder statement can be replaced by a repeated statement list."""
  template = """
    def test_fn(a):
      block
      return a
  """

  increment = gast.Assign([gast.Name('a', None, None)],
                          gast.BinOp(
                              gast.Name('a', None, None), gast.Add(),
                              gast.Num(1)))
  node = templates.replace(template, block=[increment] * 2)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEqual: the assertEquals alias was removed in Python 3.12.
  self.assertEqual(3, result.test_fn(1))
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  nodes = []
  for dep in reversed(program_ctx.dependency_cache.values()):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The converted entity (for a function, one with a signature identical to
    `e`), which when executed creates a TF graph that has the same
    functionality as the original entity.

  Raises:
    ValueError: If the converted function defines or refers to symbol names
        that are reserved for AutoGraph, e.g. if the compiled entity already
        has an `ag_source_map` attribute.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  nodes = []
  for dep in reversed(program_ctx.dependency_cache.values()):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The converted entity (for a function, one with a signature identical to
    `e`), which when executed creates a TF graph that has the same
    functionality as the original entity.

  Raises:
    ValueError: If the converted function defines or refers to symbol names
        that are reserved for AutoGraph, e.g. if the compiled function already
        has an `ag_source_map` attribute.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  nodes = []
  for dep in reversed(program_ctx.dependency_cache.values()):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled_fn = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled_fn, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled_fn, source_map_attribute_name))
  setattr(compiled_fn, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn