def compiled(self, node, namespace, *symbols):
  """Compiles `node` into a module and yields it, patched for testing.

  NOTE(review): this is a generator (it yields) — presumably wrapped with
  `contextlib.contextmanager` at the definition site; confirm.

  Args:
    node: AST node (or list of nodes) to compile into a module.
    namespace: dict, extra symbols injected into the compiled module.
    *symbols: symbols exposed on the fake `tf` module seen by compiled code.

  Yields:
    The compiled module, with fake `tf` and `ag__` modules attached.
  """
  source = None
  # Records the args of every call routed through the mocked
  # api.converted_call, for later inspection by the test.
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args[3:])  # args only; see api.converted_call
    return RESULT_OF_MOCK_CONVERTED_CALL

  try:
    result, source, source_map = compiler.ast_to_object(
        node, include_source_map=True)
    # TODO(mdan): Move the unparsing from converter into pyct and reuse here.
    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # The fake ag module exposes the real operators and special functions,
    # so converted code that calls them behaves normally.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__.update(special_functions.__dict__)
    fake_ag.ConversionOptions = converter.ConversionOptions
    fake_ag.Feature = converter.Feature
    fake_ag.utils = utils
    fake_ag.function_scope = function_wrapping.function_scope
    result.ag__ = fake_ag
    result.ag_source_map__ = source_map
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # Dump whatever stage was reached, to aid debugging the failing test.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def compiled(self, node, namespace, *symbols):
  """Compiles `node` into a module and yields it, patched for testing.

  NOTE(review): this is a generator (it yields) — presumably wrapped with
  `contextlib.contextmanager` at the definition site; confirm.

  Args:
    node: AST node (or list of nodes) to compile into a module.
    namespace: dict, extra symbols injected into the compiled module.
    *symbols: symbols exposed on the fake `tf` module seen by compiled code.

  Yields:
    The compiled module, with fake `tf` and `ag__` modules attached.
  """
  source = None
  # Records all calls routed through the mocked api.converted_call.
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args)
    return 7

  try:
    result, source = compiler.ast_to_object(node, include_source_map=True)
    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # The fake ag module exposes the real operators and special functions,
    # so converted code that calls them behaves normally.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__.update(special_functions.__dict__)
    fake_ag.__dict__['utils'] = utils
    fake_ag.__dict__['rewrite_graph_construction_error'] = (
        errors.rewrite_graph_construction_error)
    fake_ag.__dict__['function_scope'] = function_wrapping.function_scope
    result.__dict__['ag__'] = fake_ag
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # Dump whatever stage was reached, to aid debugging the failing test.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def test_ast_to_object(self):
  """Checks ast_to_object round-trips a simple function: source and exec."""
  node = gast.FunctionDef(
      name='f',
      args=gast.arguments(
          args=[gast.Name('a', gast.Param(), None)],
          vararg=None,
          kwonlyargs=[],
          kwarg=None,
          defaults=[],
          kw_defaults=[]),
      body=[
          gast.Return(
              gast.BinOp(
                  op=gast.Add(),
                  left=gast.Name('a', gast.Load(), None),
                  right=gast.Num(1)))
      ],
      decorator_list=[],
      returns=None)
  module, source = compiler.ast_to_object(node)
  # NOTE(review): the exact internal indentation of this literal was lost in
  # formatting; after dedent().strip() it must match the generated source —
  # verify against the unparser's output.
  expected_source = """
    def f(a):
      return a + 1
  """
  self.assertEqual(
      textwrap.dedent(expected_source).strip(),
      source.strip())
  self.assertEqual(2, module.f(1))
  # The compiled module is backed by a real temp file on disk.
  with open(module.__file__, 'r') as temp_output:
    self.assertEqual(
        textwrap.dedent(expected_source).strip(),
        temp_output.read().strip())
def module(self):
  """Constructs an `instructions.Module` for this `Context`.

  Returns:
    module: An `instructions.Module` representing the batched computation
      defined by all the functions decorated with `batch` in this `Context`
      so far.
  """
  # Memoized: the module is only built once per Context.
  if self._module is not None:
    return self._module
  ab = dsl.ProgramBuilder()
  function_objects = []
  # First pass: declare every tagged function up front so that references
  # between them (including recursion) resolve during the second pass.
  for function, type_inference in self._tagged_functions:
    declared = ab.declare_function(function.__name__, type_inference)
    function_objects.append(declared)
  # Second pass: transform each function's AST, compile it, and run the
  # resulting builder against the shared ProgramBuilder.
  for function, _ in self._tagged_functions:
    name = function.__name__
    node, ctx = _parse_and_analyze(function, self.function_names())
    # print(compiler.ast_to_source(node, indentation='  '))
    node = _AutoBatchingTransformer(
        self.function_names(),
        [scoped_name for scoped_name, _ in _environment(function, [name])],
        ctx).visit(node)
    # print(compiler.ast_to_source(node, indentation='  '))
    builder_module, _, _ = compiler.ast_to_object(node)
    # The compiled code needs the original function's free variables in scope.
    for scoped_name, val in _environment(function, [name]):
      builder_module.__dict__[scoped_name] = val
    builder = getattr(builder_module, name)
    builder(ab, function_objects)
  self._module = ab.module()
  return self._module
def test_replace_code_block(self):
  """The `block` placeholder accepts a list of statement nodes."""

  class ShouldBeReplaced(object):
    pass

  template = """
    def test_fn(a):
      block
      return a
  """
  # A statement equivalent to `a = a + 1`; replacement fixes up the contexts.
  increment = gast.Assign(
      [
          gast.Name(
              'a', ctx=ShouldBeReplaced, annotation=None, type_comment=None)
      ],
      gast.BinOp(
          gast.Name(
              'a', ctx=ShouldBeReplaced, annotation=None, type_comment=None),
          gast.Add(), gast.Constant(1, kind=None)),
  )
  node = templates.replace(template, block=[increment] * 2)[0]
  result, _, _ = compiler.ast_to_object(node)
  # Two increments: 1 -> 2 -> 3.
  self.assertEqual(3, result.test_fn(1))
def compiled(self, node, namespace, *symbols):
  """Compiles `node` into a module and yields it, patched for testing.

  NOTE(review): this is a generator (it yields) — presumably wrapped with
  `contextlib.contextmanager` at the definition site; confirm.

  Args:
    node: AST node (or list of nodes) to compile into a module.
    namespace: dict, extra symbols injected into the compiled module.
    *symbols: symbols exposed on the fake `tf` module seen by compiled code.

  Yields:
    The compiled module, with fake `tf` and `ag__` modules attached.
  """
  source = None
  # Records all calls routed through the mocked api.converted_call.
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args)
    return 7

  try:
    result, source = compiler.ast_to_object(node, include_source_map=True)
    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # Expose the real operators so converted code behaves normally.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__['utils'] = utils
    fake_ag.__dict__['rewrite_graph_construction_error'] = (
        errors.rewrite_graph_construction_error)
    fake_ag.__dict__[
        'function_scope'] = function_wrapping.function_scope
    result.__dict__['ag__'] = fake_ag
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # Dump whatever stage was reached, to aid debugging the failing test.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def test_codegen_gens(self):
  """Fuzz test: randomly generated FunctionDefs must all compile."""
  np.random.seed(0)  # fixed seed keeps the fuzz deterministic
  num_trials = 1000
  for _ in range(num_trials):
    candidate = codegen.generate_random_functiondef()
    compiled = compiler.ast_to_object(candidate)
    self.assertIsNotNone(
        compiled, 'Generated invalid AST that could not convert to source.')
def test_ast_to_object(self):
  """Checks ast_to_object round-trips a simple function: source and exec."""
  node = gast.FunctionDef(
      name='f',
      args=gast.arguments(
          args=[gast.Name('a', gast.Param(), None)],
          vararg=None,
          kwonlyargs=[],
          kwarg=None,
          defaults=[],
          kw_defaults=[]),
      body=[
          gast.Return(
              gast.BinOp(
                  op=gast.Add(),
                  left=gast.Name('a', gast.Load(), None),
                  right=gast.Num(1)))
      ],
      decorator_list=[],
      returns=None)
  module, source, _ = compiler.ast_to_object(node)
  # NOTE(review): the exact internal indentation of this literal was lost in
  # formatting; after dedent().strip() it must match the generated source,
  # including the coding header emitted by ast_to_object.
  expected_source = """
    # coding=utf-8
    def f(a):
      return a + 1
  """
  self.assertEqual(
      textwrap.dedent(expected_source).strip(),
      source.strip())
  self.assertEqual(2, module.f(1))
  # The compiled module is backed by a real temp file on disk.
  with open(module.__file__, 'r') as temp_output:
    self.assertEqual(
        textwrap.dedent(expected_source).strip(),
        temp_output.read().strip())
def test_basic(self):
  """The ANF transform must preserve the behavior of a trivial function."""

  def test_function():
    a = 0
    return a

  node, _ = parser.parse_entity(test_function)
  transformed = anf.transform(node.body[0], self._simple_source_info())
  module, _ = compiler.ast_to_object(transformed)
  self.assertEqual(test_function(), module.test_function())
def test_basic(self):
  """The ANF transform must preserve the behavior of a trivial function."""

  def test_function():
    a = 0
    return a

  fn_node, _, _ = parser.parse_entity(test_function, future_imports=())
  transformed = anf.transform(fn_node, self._simple_context())
  compiled_module, _ = compiler.ast_to_object(transformed)
  self.assertEqual(test_function(), compiled_module.test_function())
def test_keywords_to_dict(self):
  """keywords_to_dict must produce a dict node that evaluates correctly."""
  keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
  dict_node = ast_util.keywords_to_dict(keywords)
  # Make sure we generate a usable dict node by attaching it to a variable and
  # compiling everything.
  fn_node = parser.parse_str('def f(b): pass').body[0]
  fn_node.body.append(ast.Return(dict_node))
  module, _ = compiler.ast_to_object(fn_node)
  self.assertDictEqual({'a': 3, 'c': 1, 'd': 'e'}, module.f(3))
def test_basic(self):
  """The ANF transform must preserve the behavior of a trivial function."""

  def test_function():
    a = 0
    return a

  parsed, _ = parser.parse_entity(test_function)
  anf_node = anf.transform(parsed.body[0], self._simple_source_info())
  compiled_module, _ = compiler.ast_to_object(anf_node)
  self.assertEqual(test_function(), compiled_module.test_function())
def test_keywords_to_dict(self):
  """keywords_to_dict must produce a dict node that evaluates correctly."""
  call_node = parser.parse_expression('f(a=b, c=1, d=\'e\')')
  d = ast_util.keywords_to_dict(call_node.keywords)
  # Make sure we generate a usable dict node by attaching it to a variable and
  # compiling everything.
  node = parser.parse_str('def f(b): pass').body[0]
  node.body.append(ast.Return(d))
  result, _ = compiler.ast_to_object(node)
  self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'})
def test_replace_name_with_dict(self):
  """A name placeholder may be replaced by a dict literal expression."""
  template = """
    def test_fn():
      return foo['bar']
  """
  source = parser.parse_expression('{\'bar\': 3}')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(3, result.test_fn())
def test_replace_tuple(self):
  """A name placeholder may be replaced by a tuple of names."""
  template = """
    def test_fn(a, c):
      return b,
  """
  node = templates.replace(template, b=('a', 'c'))[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual((2, 3), result.test_fn(2, 3))
def test_replace_tuple(self):
  """A name placeholder may be replaced by a tuple of names."""
  template = """
    def test_fn(a, c):
      return b,
  """
  node = templates.replace(template, b=('a', 'c'))[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual((2, 3), result.test_fn(2, 3))
def test_replace_name_with_dict(self):
  """A name placeholder may be replaced by a dict literal expression."""
  template = """
    def test_fn():
      return foo['bar']
  """
  source = parser.parse_expression('{\'bar\': 3}')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(3, result.test_fn())
def test_replace_function_name(self):
  """A placeholder in a function-name position may be replaced."""
  template = """
    def fname(a):
      a += 1
      a = 2 * a + 1
      return a
  """
  node = templates.replace(template, fname='test_fn')[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(7, result.test_fn(2))
def test_replace_variable(self):
  """A placeholder in a variable position may be replaced by another name."""
  template = """
    def test_fn(a):
      a += 1
      a = 2 * a + 1
      return b
  """
  node = templates.replace(template, a='b')[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(7, result.test_fn(2))
def test_replace_variable(self):
  """A placeholder in a variable position may be replaced by another name."""
  template = """
    def test_fn(a):
      a += 1
      a = 2 * a + 1
      return b
  """
  node = templates.replace(template, a='b')[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(7, result.test_fn(2))
def test_replace_function_name(self):
  """A placeholder in a function-name position may be replaced."""
  template = """
    def fname(a):
      a += 1
      a = 2 * a + 1
      return a
  """
  node = templates.replace(template, fname='test_fn')[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(7, result.test_fn(2))
def _transform(self, f, strip_decorators):
  """Runs the decorators transform on `f` and returns the compiled result."""
  test_namespace = {
      'self_transform_decorator': self_transform_decorator,
      'simple_decorator': simple_decorator,
      'converter_testing': converter_testing,
  }
  node, ctx = self.prepare(
      f, test_namespace, recursive=False, strip_decorators=strip_decorators)
  transformed = decorators.transform(node, ctx)
  # The transform may require extra imports; prepend them to the output.
  import_line = '\n'.join(ctx.program.additional_imports)
  module, _ = compiler.ast_to_object(transformed, source_prefix=import_line)
  return getattr(module, f.__name__)
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names): """Returns a (possibly cached) factory for the converted result of entity.""" # The cache key is the entity's code object if it defined one, otherwise it's # the entity itself. Keying by the code object allows caching of functions # that are dynamically created e.g. in a loop. if hasattr(entity, '__code__'): key = entity.__code__ else: key = entity # The cache subkey encompases any conversion options on which the generated # code may depend. # The cached factory includes the necessary definitions to distinguish # between the global and non-global free variables. For this reason, the # cache subkey includes the names of the free non-globals. subkey = (program_ctx.options, frozenset(free_nonglobal_var_names)) with _CACHE_LOCK: # The cache values are _ConvertedEntityFactoryInfo objects. if _CACHE.has(key, subkey): # TODO(mdan): Check whether the module is still loaded. converted_entity_info = _CACHE[key][subkey] logging.log(3, 'Cache hit for entity %s key %s subkey %s: %s', entity, key, subkey, converted_entity_info) return converted_entity_info logging.log(1, 'Entity %s is not cached for key %s subkey %s', entity, key, subkey) nodes, converted_name, entity_info = convert_entity_to_ast( entity, program_ctx) namer = naming.Namer(entity_info.namespace) factory_factory_name = namer.new_symbol('create_converted_entity_factory', ()) factory_name = namer.new_symbol('create_converted_entity', ()) nodes = _wrap_into_dynamic_factory(nodes, converted_name, factory_factory_name, factory_name, free_nonglobal_var_names, entity_info.future_features) module, _, source_map = compiler.ast_to_object( nodes, include_source_map=True) module_name = module.__name__ converted_entity_info = _ConvertedEntityFactoryInfo( module_name=module_name, converted_name=converted_name, factory_factory_name=factory_factory_name, source_map=source_map) _CACHE[key][subkey] = converted_entity_info return converted_entity_info
def test_parser_compile_idempotent(self):
  """Parsing then recompiling a function must reproduce its source."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  node, _, _ = parser.parse_entity(test_fn, future_imports=())
  recompiled = compiler.ast_to_object([node])[0]
  self.assertEqual(
      textwrap.dedent(tf_inspect.getsource(test_fn)),
      tf_inspect.getsource(recompiled.test_fn))
def test_parser_compile_identity(self):
  """Parsing then recompiling a function must reproduce its source."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  parsed, _ = parser.parse_entity(test_fn, future_features=())
  recompiled, _, _ = compiler.ast_to_object(parsed)
  self.assertEqual(
      textwrap.dedent(tf_inspect.getsource(test_fn)),
      tf_inspect.getsource(recompiled.test_fn))
def test_parser_compile_idempotent(self):
  """Parsing then recompiling a function must reproduce its source."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  _, _, all_nodes = parser.parse_entity(test_fn)
  recompiled_module = compiler.ast_to_object(all_nodes)[0]
  self.assertEqual(
      textwrap.dedent(tf_inspect.getsource(test_fn)),
      tf_inspect.getsource(recompiled_module.test_fn))
def _transform(self, f, strip_decorators):
  """Runs the decorators transform on `f` and returns the compiled result."""
  ns = {
      'self_transform_decorator': self_transform_decorator,
      'simple_decorator': simple_decorator,
      'converter_testing': converter_testing,
  }
  node, ctx = self.prepare(
      f, ns, recursive=False, strip_decorators=strip_decorators)
  node = decorators.transform(node, ctx)
  # The transform may require extra imports; prepend them to the output.
  prefix = '\n'.join(ctx.program.additional_imports)
  compiled_module, _ = compiler.ast_to_object(node, source_prefix=prefix)
  return getattr(compiled_module, f.__name__)
def test_replace_attribute(self):
  """A placeholder in attribute position is replaced; non-names are rejected."""
  template = """
    def test_fn(a):
      return a.foo
  """
  node = templates.replace(template, foo='b')[0]
  result, _ = compiler.ast_to_object(node)
  mod = imp.new_module('test')
  mod.b = 3
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(3, result.test_fn(mod))
  with self.assertRaises(ValueError):
    templates.replace(template, foo=1)
def test_parser_compile_idempotent(self):
  """Parsing then recompiling a function must reproduce its source."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  # Break the original one-liner into named steps for readability.
  module_node = parser.parse_entity(test_fn)[0]
  fn_node = module_node.body[0]
  recompiled_module = compiler.ast_to_object(fn_node)[0]
  self.assertEqual(
      textwrap.dedent(tf_inspect.getsource(test_fn)),
      tf_inspect.getsource(recompiled_module.test_fn))
def test_replace_attribute(self):
  """A placeholder in attribute position is replaced; non-names are rejected."""
  template = """
    def test_fn(a):
      return a.foo
  """
  node = templates.replace(template, foo='b')[0]
  result, _ = compiler.ast_to_object(node)
  mod = imp.new_module('test')
  mod.b = 3
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(3, result.test_fn(mod))
  with self.assertRaises(ValueError):
    templates.replace(template, foo=1)
def test_replace_name_with_call(self):
  """A name placeholder may be replaced by an arbitrary call expression."""
  template = """
    def test_fn():
      b = 5
      def g(a):
        return 3 * a
      def f():
        return g
      return foo
  """
  source = parser.parse_expression('f()(b)')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(15, result.test_fn())
def test_replace_name_with_call(self):
  """A name placeholder may be replaced by an arbitrary call expression."""
  template = """
    def test_fn():
      b = 5
      def g(a):
        return 3 * a
      def f():
        return g
      return foo
  """
  source = parser.parse_expression('f()(b)')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(15, result.test_fn())
def test_parser_compile_identity(self):
  """Parsing then recompiling a function must reproduce its source."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  fn_node, _ = parser.parse_entity(test_fn, future_features=())
  compiled_module, _, _ = compiler.ast_to_object(fn_node)
  expected = textwrap.dedent(tf_inspect.getsource(test_fn))
  self.assertEqual(expected, tf_inspect.getsource(compiled_module.test_fn))
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names): """Returns a (possibly cached) factory for the converted result of entity.""" # The cache key is the entity's code object if it defined one, otherwise it's # the entity itself. Keying by the code object allows caching of functions # that are dynamically created e.g. in a loop. if hasattr(entity, '__code__'): key = entity.__code__ else: key = entity # The cache subkey encompases any conversion options on which the generated # code may depend. # The cached factory includes the necessary definitions to distinguish # between the global and non-global free variables. For this reason, the # cache subkey includes the names of the free non-globals. subkey = (program_ctx.options, frozenset(free_nonglobal_var_names)) # The cache values are _ConvertedEntityFactoryInfo objects. if _CACHE.has(key, subkey): # TODO(mdan): Check whether the module is still loaded. converted_entity_info = _CACHE[key][subkey] logging.log(3, 'Cache hit for entity %s key %s subkey %s: %s', entity, key, subkey, converted_entity_info) return converted_entity_info logging.log(1, 'Entity %s is not cached for key %s subkey %s', entity, key, subkey) nodes, converted_name, entity_info = convert_entity_to_ast( entity, program_ctx) namer = naming.Namer(entity_info.namespace) factory_factory_name = namer.new_symbol('create_converted_entity_factory', ()) factory_name = namer.new_symbol('create_converted_entity', ()) nodes = _wrap_into_dynamic_factory( nodes, converted_name, factory_factory_name, factory_name, free_nonglobal_var_names, entity_info.future_features) module, _, source_map = compiler.ast_to_object(nodes, include_source_map=True) module_name = module.__name__ converted_entity_info = _ConvertedEntityFactoryInfo( module_name=module_name, converted_name=converted_name, factory_factory_name=factory_factory_name, source_map=source_map) _CACHE[key][subkey] = converted_entity_info return converted_entity_info
def test_replace_code_block(self):
  """The `block` placeholder accepts a list of statement nodes."""
  template = """
    def test_fn(a):
      block
      return a
  """
  # Two copies of `a = a + 1`: 1 -> 2 -> 3.
  node = templates.replace(
      template,
      block=[
          gast.Assign([gast.Name('a', None, None)],
                      gast.BinOp(
                          gast.Name('a', None, None), gast.Add(),
                          gast.Num(1))),
      ] * 2)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(3, result.test_fn(1))
def test_replace_call_keyword(self):
  """The `kws` placeholder accepts keyword nodes; other values are rejected."""
  template = """
    def test_fn():
      def f(a, d, f):
        return a + d + f
      return f(1, kws=None)
  """
  source = parser.parse_expression('f(d=3, f=5)')
  node = templates.replace(template, kws=source.keywords)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(9, result.test_fn())

  with self.assertRaises(ValueError):
    templates.replace(template, kws=[])
    templates.replace(template, kws=1)
def test_replace_code_block(self):
  """The `block` placeholder accepts a list of statement nodes."""
  template = """
    def test_fn(a):
      block
      return a
  """
  # Two copies of `a = a + 1`: 1 -> 2 -> 3.
  node = templates.replace(
      template,
      block=[
          gast.Assign([
              gast.Name('a', None, None)
          ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
      ] * 2)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(3, result.test_fn(1))
def test_replace_call_keyword(self):
  """The `kws` placeholder accepts keyword nodes; other values are rejected."""
  template = """
    def test_fn():
      def f(a, d, f):
        return a + d + f
      return f(1, kws=None)
  """
  source = parser.parse_expression('f(d=3, f=5)')
  node = templates.replace(template, kws=source.keywords)[0]
  result, _ = compiler.ast_to_object(node)
  # assertEquals is a deprecated alias of assertEqual.
  self.assertEqual(9, result.test_fn())

  with self.assertRaises(ValueError):
    templates.replace(template, kws=[])
    templates.replace(template, kws=1)
def compiled(self, node, namespace, symbols=()):
  """Compiles `node` into a module and yields it, patched for testing.

  NOTE(review): this is a generator (it yields) — presumably wrapped with
  `contextlib.contextmanager` at the definition site; confirm.

  Args:
    node: AST node (or list of nodes) to compile into a module.
    namespace: dict, extra symbols injected into the compiled module.
    symbols: tuple, symbols exposed on the fake `tf` module.

  Yields:
    The compiled module, with fake `tf` and `ag__` modules attached.
  """
  source = None
  # Records (args, kwargs) of every call routed through the mock below.
  self.dynamic_calls = []

  # See api.converted_call
  def converted_call(f, args, kwargs, unused_opts=None,
                     unused_function_ctx=None):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append((args, kwargs))
    if kwargs is None:
      kwargs = {}
    return f(*args, **kwargs)

  try:
    result, source, source_map = compiler.ast_to_object(
        node, include_source_map=True)
    # TODO(mdan): Move the unparsing from converter into pyct and reuse here.
    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # Expose the real operators and special functions so converted code
    # that calls them behaves normally.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__.update(special_functions.__dict__)
    fake_ag.ConversionOptions = converter.ConversionOptions
    fake_ag.Feature = converter.Feature
    fake_ag.utils = utils
    fake_ag.FunctionScope = function_wrappers.FunctionScope
    result.ag__ = fake_ag
    result.ag_source_map__ = source_map
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # Dump whatever stage was reached, to aid debugging the failing test.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def compiled(self, node, namespace, *symbols):
  """Compiles `node` into a module and yields it, patched for testing.

  NOTE(review): this is a generator (it yields) — presumably wrapped with
  `contextlib.contextmanager` at the definition site; confirm.

  Args:
    node: AST node (or list of nodes) to compile into a module.
    namespace: dict, extra symbols injected into the compiled module.
    *symbols: symbols exposed on the fake `tf` module seen by compiled code.

  Yields:
    The compiled module, with fake `tf` and `ag__` modules attached.
  """
  source = None
  # Records all calls routed through the mocked api.converted_call.
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args)
    return 7

  class ConversionOptions(object):
    """Mock version of api.ConversionOptions."""

    def __init__(self, recursive):
      self.recursive = recursive

    @classmethod
    def new(cls, recursive):
      # Bug fix: the factory previously constructed the instance but did not
      # return it, so callers received None instead of the options object.
      return cls(recursive)

  try:
    result, source = compiler.ast_to_object(node, include_source_map=True)
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call, ConversionOptions)
    # Expose the real operators so converted code behaves normally.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__['utils'] = utils
    fake_ag.__dict__['rewrite_graph_construction_error'] = (
        errors.rewrite_graph_construction_error)
    result.__dict__['ag__'] = fake_ag
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # Dump whatever stage was reached, to aid debugging the failing test.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def compiled(self, node, namespace, *symbols):
  """Compiles `node` into a module and yields it, patched for testing.

  NOTE(review): this is a generator (it yields) — presumably wrapped with
  `contextlib.contextmanager` at the definition site; confirm.

  Args:
    node: AST node (or list of nodes) to compile into a module.
    namespace: dict, extra symbols injected into the compiled module.
    *symbols: symbols exposed on the fake `tf` module seen by compiled code.

  Yields:
    The compiled module, with fake `tf` and `ag__` modules attached.
  """
  source = None
  # Records the args of every call routed through the mocked
  # api.converted_call, for later inspection by the test.
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(
        args[3:])  # args only; see api.converted_call
    return RESULT_OF_MOCK_CONVERTED_CALL

  try:
    result, source, source_map = compiler.ast_to_object(
        node, include_source_map=True)
    # TODO(mdan): Move the unparsing from converter into pyct and reuse here.
    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # Expose the real operators and special functions so converted code
    # that calls them behaves normally.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__.update(special_functions.__dict__)
    fake_ag.ConversionOptions = converter.ConversionOptions
    fake_ag.Feature = converter.Feature
    fake_ag.utils = utils
    fake_ag.rewrite_graph_construction_error = (
        errors.rewrite_graph_construction_error)
    fake_ag.function_scope = function_wrapping.function_scope
    result.ag__ = fake_ag
    result.ag_source_map__ = source_map
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # Dump whatever stage was reached, to aid debugging the failing test.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def test_to_ast(self):
  """ConversionOptions must survive a round trip through to_ast()."""
  opts = converter.ConversionOptions()
  opts_ast = opts.to_ast()

  template = '''
    def test_fn():
      return opts_ast
  '''
  opts_packed = templates.replace(template, opts_ast=opts_ast)

  reparsed, _, _ = compiler.ast_to_object(opts_packed)
  # The generated code references ag__; provide a fake with the real types.
  reparsed.__dict__['ag__'] = self.make_fake_mod(
      'fake_ag', converter.ConversionOptions, converter.Feature)

  reparsed_opts = reparsed.test_fn()

  self.assertEqual(opts.recursive, reparsed_opts.recursive)
  # Bug fix: this previously compared opts.user_requested against the literal
  # False, which only tested the default value — not the round trip.
  self.assertEqual(opts.user_requested, reparsed_opts.user_requested)
  self.assertEqual(opts.internal_convert_user_code,
                   reparsed_opts.internal_convert_user_code)
  self.assertEqual(opts.optional_features, reparsed_opts.optional_features)
def test_to_ast(self):
  """ConversionOptions must survive a round trip through to_ast()."""
  opts = converter.ConversionOptions()
  packed_ast = opts.to_ast()

  template = '''
    def test_fn():
      return opts_ast
  '''
  fn_nodes = templates.replace(template, opts_ast=packed_ast)

  module, _, _ = compiler.ast_to_object(fn_nodes)
  # The generated code references ag__; provide a fake with the real types.
  module.__dict__['ag__'] = self.make_fake_mod(
      'fake_ag', converter.ConversionOptions, converter.Feature)

  round_tripped = module.test_fn()

  for attr in ('recursive', 'force_conversion', 'internal_convert_user_code',
               'optional_features'):
    self.assertEqual(getattr(opts, attr), getattr(round_tripped, attr))
def test_to_ast(self):
  """ConversionOptions must survive a round trip through to_ast()."""
  opts = converter.ConversionOptions()

  # to_ast requires a full EntityContext; build a minimal one.
  namer = converter_testing.FakeNamer()
  program_ctx = converter.ProgramContext(
      options=opts,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=())
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file='<fragment>',
      namespace={},
      arg_values=None,
      arg_types={},
      owner_type=None)
  ctx = converter.EntityContext(namer, entity_info, program_ctx)
  packed_ast = opts.to_ast(ctx)

  template = '''
    def test_fn():
      return opts_ast
  '''
  fn_nodes = templates.replace(template, opts_ast=packed_ast)

  module, _ = compiler.ast_to_object(fn_nodes)
  # The generated code references ag__; provide a fake with the real types.
  module.__dict__['ag__'] = self.make_fake_mod(
      'fake_ag', converter.ConversionOptions, converter.Feature)

  round_tripped = module.test_fn()

  for attr in ('recursive', 'verbose', 'force_conversion',
               'internal_convert_user_code', 'optional_features'):
    self.assertEqual(getattr(opts, attr), getattr(round_tripped, attr))
def test_to_ast(self):
  """ConversionOptions must survive a round trip through to_ast()."""
  options = converter.ConversionOptions()

  # to_ast requires a full EntityContext; build a minimal one.
  fake_namer = converter_testing.FakeNamer()
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=())
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file='<fragment>',
      namespace={},
      arg_values=None,
      arg_types={},
      owner_type=None)
  entity_ctx = converter.EntityContext(fake_namer, entity_info, program_ctx)
  options_ast = options.to_ast(entity_ctx)

  template = '''
    def test_fn():
      return opts_ast
  '''
  packed_nodes = templates.replace(template, opts_ast=options_ast)

  reparsed, _ = compiler.ast_to_object(packed_nodes)
  # The generated code references ag__; provide a fake with the real types.
  reparsed.__dict__['ag__'] = self.make_fake_mod(
      'fake_ag', converter.ConversionOptions, converter.Feature)

  restored = reparsed.test_fn()

  self.assertEqual(options.recursive, restored.recursive)
  self.assertEqual(options.verbose, restored.verbose)
  self.assertEqual(options.force_conversion, restored.force_conversion)
  self.assertEqual(options.internal_convert_user_code,
                   restored.internal_convert_user_code)
  self.assertEqual(options.optional_features, restored.optional_features)
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    strip_decorators: Tuple[Callable], same as
      ConversionOptions.strip_decorators.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted.
  """
  if strip_decorators is None:
    strip_decorators = ()
  # The AutoGraph entry points must never be converted themselves.
  strip_decorators += (convert, do_not_convert, converted_call)

  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=strip_decorators,
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Compile dependencies in reverse conversion order so that definitions
  # precede their uses in the generated module.
  nodes = []
  for dep in reversed(program_ctx.conversion_order):
    nodes.extend(program_ctx.dependency_cache[dep])

  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    # Bug fix: the message previously read "because is has an attribute".
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Optional dict of value hints for symbols including function
      arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable,
      rather than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of optional
      features in the conversion process.
    experimental_strip_decorators: A tuple specifying decorators that should be
      excluded from the compiled output. By default, when converting a function
      before the decorators are applied, the compiled output will include those
      decorators.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  if experimental_strip_decorators is None:
    experimental_strip_decorators = ()
  # The AutoGraph entry points must never be converted themselves.
  experimental_strip_decorators += (convert, do_not_convert, converted_call)

  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=experimental_verbose,
          strip_decorators=experimental_strip_decorators,
          optional_features=experimental_optional_features),
      partial_types=experimental_partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                  arg_values, arg_types)

  # Compile dependencies in reverse conversion order so that definitions
  # precede their uses in the generated module.
  nodes = []
  for dep in reversed(program_ctx.conversion_order):
    nodes.extend(program_ctx.dependency_cache[dep])

  compiled_module, _ = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  for key, val in program_ctx.additional_symbols.items():
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  if tf_inspect.isfunction(entity):
    compiled.__defaults__ = entity.__defaults__

  if hasattr(compiled, '__globals__'):
    # Remove self to avoid circular references. This will probably only work
    # so long as the function is not reentrant.
    del compiled.__globals__[name]

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    # Bug fix: the message previously read "because is has an attribute".
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  return compiled
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable,
      rather than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            optional_features=experimental_optional_features),
        autograph_module=tf_inspect.getmodule(to_graph))
    nodes, name, namespace = conversion.entity_to_graph(
        entity, program_ctx, arg_values, arg_types)

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Plain Python functions carry their default argument values in
    # __defaults__; propagate them so the converted function signature
    # behaves like the original. Classes and some callables do not have it.
    if hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__
    else:
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                compiled_module.__dict__.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
      # TODO(znado): change input problem errors into TransformError
      # Fixed typo in the error message: "is has" -> "it has".
      raise ValueError('cannot convert %s because it has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    # Internal errors are reported (not re-raised) through the errors module.
    errors.report_internal_error(entity, e)
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    strip_decorators: Tuple[Callable], same as
      ConversionOptions.strip_decorators.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function is e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted.
  """
  if strip_decorators is None:
    strip_decorators = ()
  # The AutoGraph API entry points must never be converted themselves.
  strip_decorators += (convert, do_not_convert, converted_call)

  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=verbose,
          strip_decorators=strip_decorators),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Dependencies are compiled in reverse topological order so that each
  # definition precedes its uses.
  nodes = []
  for dep in reversed(program_ctx.conversion_order):
    nodes.extend(program_ctx.dependency_cache[dep])

  # The compiled source is not needed here; discard it. (Was previously bound
  # to an unused local, `compiled_src`.)
  compiled_module, _ = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val

  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    # Fixed typo in the error message: "is has" -> "it has".
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))

  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  return compiled
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable,
      rather than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: A tuple specifying decorators that should be
      excluded from the compiled output. By default, when converting a
      function before the decorators are applied, the compiled output will
      include those decorators.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    if experimental_strip_decorators is None:
      experimental_strip_decorators = ()
    # The AutoGraph API entry points must never be converted themselves.
    experimental_strip_decorators += (convert, do_not_convert, converted_call)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=experimental_verbose,
            strip_decorators=experimental_strip_decorators,
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    # Dependencies are compiled in reverse topological order so that each
    # definition precedes its uses.
    nodes = []
    for dep in reversed(program_ctx.conversion_order):
      nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    for key, val in program_ctx.additional_symbols.items():
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Plain Python functions carry their default argument values in
    # __defaults__; propagate them so the converted function signature
    # behaves like the original. Classes and some callables do not have it.
    if hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__
    else:
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                compiled_module.__dict__.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
      # TODO(znado): change input problem errors into TransformError
      # Fixed typo in the error message: "is has" -> "it has".
      raise ValueError('cannot convert %s because it has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    # Internal errors are reported (not re-raised) through the errors module.
    errors.report_internal_error(entity, e)