def test_replace_code_block(self):
  """Replaces a placeholder statement with a list of two assignments."""

  class ShouldBeReplaced(object):
    pass

  template = """
    def test_fn(a):
      block
      return a
  """

  # Builds `a = a + 1`. The `ctx` fields are given a dummy class on purpose;
  # presumably `templates.replace` is expected to overwrite them.
  target = gast.Name(
      'a', ctx=ShouldBeReplaced, annotation=None, type_comment=None)
  operand = gast.Name(
      'a', ctx=ShouldBeReplaced, annotation=None, type_comment=None)
  increment = gast.Assign(
      [target],
      gast.BinOp(operand, gast.Add(), gast.Constant(1, kind=None)),
  )

  # The same statement node is inserted twice, so test_fn(1) should yield 3.
  replaced = templates.replace(template, block=[increment] * 2)
  result, _, _ = loader.load_ast(replaced[0])
  self.assertEqual(3, result.test_fn(1))
def test_load_ast(self):
  """Loads a hand-built `def f(a): return a + 1` and checks all outputs."""
  params = gast.arguments(
      args=[gast.Name('a', gast.Param(), None)],
      vararg=None,
      kwonlyargs=[],
      kwarg=None,
      defaults=[],
      kw_defaults=[])
  add_one = gast.BinOp(
      op=gast.Add(), left=gast.Name('a', gast.Load(), None), right=gast.Num(1))
  node = gast.FunctionDef(
      name='f',
      args=params,
      body=[gast.Return(add_one)],
      decorator_list=[],
      returns=None)

  module, source, _ = loader.load_ast(node)

  expected_source = textwrap.dedent("""
    # coding=utf-8
    def f(a):
      return a + 1
  """).strip()

  # The returned source text must match.
  self.assertEqual(expected_source, source.strip())
  # The loaded module must be executable.
  self.assertEqual(2, module.f(1))
  # The file backing the module must hold the same source.
  with open(module.__file__, 'r') as temp_output:
    written = temp_output.read()
  self.assertEqual(expected_source, written.strip())
def module(self):
  """Constructs an `instructions.Module` for this `Context`.

  The result is memoized in `self._module`; later calls return the stored
  object without rebuilding it.

  Returns:
    module: An `instructions.Module` representing the batched computation
      defined by all the functions decorated with `batch` in this `Context` so
      far.
  """
  if self._module is not None:
    return self._module

  ab = dsl.ProgramBuilder()

  # First pass: declare every tagged function so that bodies built in the
  # second pass can refer to any of them.
  function_objects = [
      ab.declare_function(function.__name__, type_inference)
      for function, type_inference in self._tagged_functions
  ]

  # Second pass: transform each function, load the generated code, and run the
  # resulting builder against the shared program builder.
  for function, _ in self._tagged_functions:
    name = function.__name__
    node, ctx = _parse_and_analyze(function, self.function_names())
    known_names = self.function_names()
    scoped_names = [
        scoped_name for scoped_name, _ in _environment(function, [name])
    ]
    transformer = _AutoBatchingTransformer(known_names, scoped_names, ctx)
    node = transformer.visit(node)

    builder_module, _, _ = loader.load_ast(node)
    # Inject the function's environment into the freshly loaded module.
    module_vars = builder_module.__dict__
    for scoped_name, val in _environment(function, [name]):
      module_vars[scoped_name] = val

    build = getattr(builder_module, name)
    build(ab, function_objects)

  self._module = ab.module()
  return self._module
def test_basic(self):
  """Checks that the ANF transform preserves a trivial function's result."""

  def test_function():
    a = 0
    return a

  parsed, _ = parser.parse_entity(test_function, future_features=())
  transformed = anf.transform(parsed, self._simple_context())
  loaded, _, _ = loader.load_ast(transformed)

  # The transformed code must return what the original returns.
  self.assertEqual(test_function(), loaded.test_function())
def test_keywords_to_dict(self):
  """Converts call keywords into a dict node and evaluates the result."""
  call = parser.parse_expression("f(a=b, c=1, d='e')")
  dict_node = ast_util.keywords_to_dict(call.keywords)

  # To verify that the generated dict node is usable, return it from a
  # function and compile the whole thing.
  fn_node = parser.parse('def f(b): pass')
  fn_node.body.append(ast.Return(dict_node))
  loaded, _, _ = loader.load_ast(fn_node)

  # `b` is bound to the argument, so the keyword `a=b` maps to 3.
  self.assertDictEqual(loaded.f(3), {'a': 3, 'c': 1, 'd': 'e'})
def test_replace_name_with_dict(self):
  """Replaces a subscripted name in the template with a dict literal."""
  template = """
    def test_fn():
      return foo['bar']
  """
  dict_literal = parser.parse_expression("{'bar': 3}")

  replaced = templates.replace(template, foo=dict_literal)
  loaded, _, _ = loader.load_ast(replaced[0])

  self.assertEqual(3, loaded.test_fn())
def _get_source(self, node): try: source, _ = loader.load_ast(node) return source # pylint: disable=broad-except # This function is used for error reporting. If an exception occurs here, # it should be suppressed, in favor of emitting as informative a message # about the original error as possible. except Exception: return '<could not convert AST to source>'
def test_replace_tuple(self):
  """Replaces a single name with a tuple of names."""
  template = """
    def test_fn(a, c):
      return b,
  """
  replaced = templates.replace(template, b=('a', 'c'))
  loaded, _, _ = loader.load_ast(replaced[0])

  # `b` expands to `a, c`, so both arguments come back as a tuple.
  self.assertEqual((2, 3), loaded.test_fn(2, 3))
def test_replace_variable(self):
  """Renames a variable throughout the template body."""
  template = """
    def test_fn(a):
      a += 1
      a = 2 * a + 1
      return b
  """
  replaced = templates.replace(template, a='b')
  loaded, _, _ = loader.load_ast(replaced[0])

  # With every `a` renamed to `b`: (2 + 1) * 2 + 1 == 7.
  self.assertEqual(7, loaded.test_fn(2))
def test_replace_function_name(self):
  """Renames the template's function definition."""
  template = """
    def fname(a):
      a += 1
      a = 2 * a + 1
      return a
  """
  replaced = templates.replace(template, fname='test_fn')
  loaded, _, _ = loader.load_ast(replaced[0])

  # The function must be reachable under its new name: (2 + 1) * 2 + 1 == 7.
  self.assertEqual(7, loaded.test_fn(2))
def test_parse_load_identity(self):
  """Checks that parsing then loading reproduces the source text exactly."""

  # The text of this function is compared verbatim against regenerated code,
  # so its formatting must not be changed and no comments may be added inside.
  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  node, _ = parser.parse_entity(test_fn, future_features=())
  module, _, _ = loader.load_ast(node)

  original_src = textwrap.dedent(tf_inspect.getsource(test_fn))
  reloaded_src = tf_inspect.getsource(module.test_fn)
  self.assertEqual(original_src, reloaded_src)
def test_replace_attribute(self):
  """Replaces an attribute name, and rejects a non-name replacement."""
  # `imp.new_module` is deprecated and `imp` was removed in Python 3.12;
  # `types.ModuleType` is the equivalent constructor. Imported locally so this
  # test does not depend on the file-level import list.
  import types  # pylint: disable=g-import-not-at-top

  template = """
    def test_fn(a):
      return a.foo
  """

  node = templates.replace(template, foo='b')[0]
  result, _, _ = loader.load_ast(node)

  # Any object with a `b` attribute works; a throwaway module is used here.
  mod = types.ModuleType('test')
  mod.b = 3
  self.assertEqual(3, result.test_fn(mod))

  # An attribute name cannot be replaced by a non-name value.
  with self.assertRaises(ValueError):
    templates.replace(template, foo=1)
def compiled(self, node, namespace, symbols=()):
  """Loads `node` as a module wired up with fake `tf` and `ag__` modules.

  This is a generator that yields the loaded module once. Presumably it is
  wrapped by a context-manager decorator outside this block (TODO confirm).
  It also resets `self.dynamic_calls`, which records the `(args, kwargs)` of
  every call routed through the fake `converted_call`.

  Args:
    node: The AST to load.
    namespace: Mapping of extra names to inject into the loaded module.
    symbols: Symbols passed on to `self.make_fake_mod` for the fake `tf`
      module.

  Yields:
    The loaded module.

  Raises:
    Exception: Any failure is re-raised after printing the offending source,
      or the offending AST when no source was produced.
  """
  source = None

  self.dynamic_calls = []

  # See api.converted_call
  def converted_call(f,
                     args,
                     kwargs,
                     unused_opts=None,
                     unused_function_ctx=None):
    """Mock version of api.converted_call."""
    # Record the call exactly as received, before normalizing kwargs.
    self.dynamic_calls.append((args, kwargs))
    if kwargs is None:
      kwargs = {}
    return f(*args, **kwargs)

  def fake_autograph_artifact(f):
    setattr(f, 'fake_autograph_artifact', True)
    return f

  try:
    result, source, source_map = loader.load_ast(
        node, include_source_map=True)
    # TODO(mdan): Move the unparsing from converter into pyct and reuse here.

    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)

    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # Explicit assignments below come after the bulk updates, so they win over
    # any same-named entries copied from `operators` or `special_functions`.
    fake_ag_vars = fake_ag.__dict__
    fake_ag_vars.update(operators.__dict__)
    fake_ag_vars.update(special_functions.__dict__)
    fake_ag.ConversionOptions = converter.ConversionOptions
    fake_ag.Feature = converter.Feature
    fake_ag.utils = utils
    fake_ag.FunctionScope = function_wrappers.FunctionScope
    fake_ag.autograph_artifact = fake_autograph_artifact

    result.ag__ = fake_ag
    result.ag_source_map__ = source_map

    result_vars = result.__dict__
    for key, value in namespace.items():
      result_vars[key] = value

    yield result
  except Exception:  # pylint:disable=broad-except
    # Show the most readable form available, then propagate the error.
    if source is not None:
      print('Offending source code:\n%s' % source)
    else:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    raise
def test_replace_name_with_call(self):
  """Replaces a name with a nested call expression."""
  template = """
    def test_fn():
      b = 5
      def g(a):
        return 3 * a
      def f():
        return g
      return foo
  """
  call_expr = parser.parse_expression('f()(b)')

  replaced = templates.replace(template, foo=call_expr)
  loaded, _, _ = loader.load_ast(replaced[0])

  # `f()(b)` evaluates to `g(5)`, which is 15.
  self.assertEqual(15, loaded.test_fn())
def test_parse_load_identity(self):
  """Checks that parsing then loading yields AST-equivalent source."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = (x + 1)
    return b

  node, _ = parser.parse_entity(test_fn, future_features=())
  module, _, _ = loader.load_ast(node)

  reloaded_src = tf_inspect.getsource(module.test_fn)
  original_src = textwrap.dedent(tf_inspect.getsource(test_fn))

  # Both the regenerated and the original source must match the parsed AST.
  for candidate_src in (reloaded_src, original_src):
    self.assertAstMatches(node, candidate_src)
def test_replace_call_keyword(self):
  """Replaces a placeholder keyword argument with real keywords."""
  template = """
    def test_fn():
      def f(a, d, f):
        return a + d + f
      return f(1, kws=None)
  """

  source = parser.parse_expression('f(d=3, f=5)')
  node = templates.replace(template, kws=source.keywords)[0]
  result, _, _ = loader.load_ast(node)
  # f(1, d=3, f=5) == 1 + 3 + 5.
  self.assertEqual(9, result.test_fn())

  # Each invalid replacement gets its own assertRaises block. When both calls
  # shared one block, the first exception ended the block and the second call
  # was never executed.
  with self.assertRaises(ValueError):
    templates.replace(template, kws=[])
  with self.assertRaises(ValueError):
    templates.replace(template, kws=1)
def test_parse_load_identity(self):
  """Checks that parsing then loading reproduces the source, modulo indent."""

  # The text of this function is compared against regenerated code, so its
  # formatting must not be changed and no comments may be added inside.
  def test_fn(x):
    a = True
    b = ''
    if a:
      b = (x + 1)
    return b

  node, _ = parser.parse_entity(test_fn, future_features=())
  module, _, _ = loader.load_ast(node)

  # astunparse uses fixed 4-space indenting, so each 4-space level is mapped
  # back to 2 spaces before comparing. The replace call previously mapped a
  # single space to a single space, which is a no-op.
  # NOTE(review): this assumes the file is indented with 2 spaces -- confirm.
  self.assertEqual(
      textwrap.dedent(tf_inspect.getsource(test_fn)),
      tf_inspect.getsource(module.test_fn).replace('    ', '  '))
def test_load_ast(self):
  """Loads a hand-built `def f(a): return (a + 1)` and checks all outputs."""
  param = gast.Name(
      'a', ctx=gast.Param(), annotation=None, type_comment=None)
  params = gast.arguments(
      args=[param],
      posonlyargs=[],
      vararg=None,
      kwonlyargs=[],
      kw_defaults=[],
      kwarg=None,
      defaults=[])
  add_one = gast.BinOp(
      op=gast.Add(),
      left=gast.Name(
          'a', ctx=gast.Load(), annotation=None, type_comment=None),
      right=gast.Constant(1, kind=None))
  node = gast.FunctionDef(
      name='f',
      args=params,
      body=[gast.Return(add_one)],
      decorator_list=[],
      returns=None,
      type_comment=None)

  module, source, _ = loader.load_ast(node)

  expected_node_src = textwrap.dedent("""
    # coding=utf-8
    def f(a):
        return (a + 1)
  """)

  # Both the returned and the expected source must match the AST.
  self.assertAstMatches(node, source)
  self.assertAstMatches(node, expected_node_src)
  # The loaded module must be executable.
  self.assertEqual(2, module.f(1))
  # The file backing the module must match the AST as well.
  with open(module.__file__, 'r') as temp_output:
    self.assertAstMatches(node, temp_output.read())
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
  """Returns a (possibly cached) factory for the converted result of entity.

  Args:
    entity: The object to convert; also the first-level cache key.
    program_ctx: Conversion context. Its `options` form part of the cache
      subkey.
    free_nonglobal_var_names: Iterable of names of the entity's free
      non-global variables. They form part of the cache subkey and are passed
      on to the factory wrapper.

  Returns:
    A `_ConvertedEntityFactoryInfo`, either taken from the cache or newly
    created and stored in it.
  """
  # The cache subkey encompasses any conversion options on which the generated
  # code may depend.
  # The cached factory includes the necessary definitions to distinguish
  # between the global and non-global free variables. For this reason, the
  # cache subkey includes the names of the free non-globals.
  subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))

  # Lookup, conversion and insertion all happen while holding the lock.
  with _CACHE_LOCK:
    # The cache values are _ConvertedEntityFactoryInfo objects.
    if _CACHE.has(entity, subkey):
      # TODO(mdan): Check whether the module is still loaded.
      factory_info = _CACHE[entity][subkey]
      logging.log(3, 'Cache hit for entity %s subkey %s: %s', entity, subkey,
                  factory_info)
    else:
      logging.log(1, 'Entity %s is not cached for subkey %s', entity, subkey)

      nodes, converted_name, entity_info = convert_entity_to_ast(
          entity, program_ctx)

      # Reserve names for the two factory layers that cannot clash with the
      # entity's namespace.
      namer = naming.Namer(entity_info.namespace)
      factory_factory_name = namer.new_symbol(
          'create_converted_entity_factory', ())
      factory_name = namer.new_symbol('create_converted_entity', ())

      wrapped_nodes = _wrap_into_dynamic_factory(
          nodes, converted_name, factory_factory_name, factory_name,
          free_nonglobal_var_names, entity_info.future_features)

      module, _, source_map = loader.load_ast(
          wrapped_nodes, include_source_map=True)

      factory_info = _ConvertedEntityFactoryInfo(
          module_name=module.__name__,
          converted_name=converted_name,
          factory_factory_name=factory_factory_name,
          source_map=source_map)
      _CACHE[entity][subkey] = factory_info

    return factory_info
def create(self,
           nodes,
           namer,
           inner_factory_name='inner_factory',
           outer_factory_name='outer_factory',
           future_features=()):
  """Initializes a function.

  Wraps `nodes` into a factory, loads the generated code, and stores the
  result of calling the outer factory on this object.

  Args:
    nodes: The AST nodes to wrap into a factory.
    namer: Object whose `new_symbol` method turns the requested factory names
      into the names actually used.
    inner_factory_name: Requested name for the inner factory.
    outer_factory_name: Requested name for the outer factory.
    future_features: Passed through to `_wrap_into_factory`.

  Raises:
    ValueError: If this object has already been initialized.
  """
  if self._unbound_factory is not None:
    raise ValueError('double initialization; create a new object instead')

  inner_symbol = namer.new_symbol(inner_factory_name, ())
  outer_symbol = namer.new_symbol(outer_factory_name, ())

  wrapped_nodes = _wrap_into_factory(nodes, self._name, inner_symbol,
                                     outer_symbol, self._freevars,
                                     self._extra_locals.keys(),
                                     future_features)

  module, _, source_map = loader.load_ast(
      wrapped_nodes, include_source_map=True)

  # The outer factory takes no arguments; its result is what gets stored.
  self._unbound_factory = getattr(module, outer_symbol)()
  self.module = module
  self.source_map = source_map
def test_to_ast(self):
  """Round-trips `ConversionOptions` through its AST representation."""
  opts = converter.ConversionOptions()

  template = '''
  def test_fn():
    return opts_ast
  '''
  opts_packed = templates.replace(template, opts_ast=opts.to_ast())

  reparsed, _, _ = loader.load_ast(opts_packed)
  # The generated code refers to `ag__`, so provide a stand-in module.
  fake_ag = self.make_fake_mod('fake_ag', converter.ConversionOptions,
                               converter.Feature)
  reparsed.__dict__['ag__'] = fake_ag

  reparsed_opts = reparsed.test_fn()

  self.assertEqual(opts.recursive, reparsed_opts.recursive)
  # NOTE(review): this checks the original `opts`, not `reparsed_opts` --
  # confirm that is intended.
  self.assertEqual(opts.user_requested, False)
  self.assertEqual(opts.internal_convert_user_code,
                   reparsed_opts.internal_convert_user_code)
  self.assertEqual(opts.optional_features, reparsed_opts.optional_features)
def test_to_ast(self):
  """Round-trips `ConversionOptions` through its AST representation."""
  # `imp.new_module` is deprecated and `imp` was removed in Python 3.12;
  # `types.ModuleType` is the equivalent constructor. Imported locally so this
  # test does not depend on the file-level import list.
  import types  # pylint: disable=g-import-not-at-top

  opts = converter.ConversionOptions()
  opts_ast = opts.to_ast()

  template = '''
  def f():
    return opts_ast
  '''
  opts_packed = templates.replace(template, opts_ast=opts_ast)

  reparsed, _, _ = loader.load_ast(opts_packed)
  # The generated code refers to `ag__`, so provide a stand-in module exposing
  # the two names this test sets up.
  fake_ag = types.ModuleType('fake_ag')
  fake_ag.ConversionOptions = converter.ConversionOptions
  fake_ag.Feature = converter.Feature
  reparsed.ag__ = fake_ag

  reparsed_opts = reparsed.f()

  self.assertEqual(opts.recursive, reparsed_opts.recursive)
  # NOTE(review): this checks the original `opts`, not `reparsed_opts` --
  # confirm that is intended.
  self.assertEqual(opts.user_requested, False)
  self.assertEqual(opts.internal_convert_user_code,
                   reparsed_opts.internal_convert_user_code)
  self.assertEqual(opts.optional_features, reparsed_opts.optional_features)
def debug_print_src(self, node):
  """Helper method useful for debugging. Prints the AST as code.

  Args:
    node: The AST node to print.

  Returns:
    `node`, unchanged, so the call can be inserted into an expression.
  """
  if __debug__:
    # Other callers in this file unpack `loader.load_ast` as
    # (module, source, source_map). Print only the source, not the whole
    # tuple.
    _, source, _ = loader.load_ast(node)
    print(source)
  return node