def test_simple_decorator(self):
  """Checks a plain decorator is kept or stripped per `remove_decorators`."""

  def simple_decorator(f):
    return lambda a: f(a) + 1

  # The Python parser does capture decorators into the AST.
  # However, the interpreter desugars them upon load, and referring to the
  # decorated function at runtime usually loses any trace of the decorator.
  # Below is an example when that doesn't happen.
  def static_wrapper():

    @simple_decorator
    def test_fn(a):  # pylint:disable=unused-variable
      return a

  namespace = {'simple_decorator': simple_decorator}
  wrapper_node = self.parse_and_analyze(static_wrapper, namespace)
  # Drill down to the decorated `test_fn` definition inside the wrapper.
  fn_node = wrapper_node.body[0].body[0]

  # Decorator kept: its source is prepended so the compiled code can see it.
  fn_node = decorators.transform(fn_node, remove_decorators=())
  decorator_src = textwrap.dedent(tf_inspect.getsource(simple_decorator))
  kept = compiler.ast_to_object(fn_node, source_prefix=decorator_src)
  self.assertEqual(2, kept.test_fn(1))

  # Decorator removed: the bare function is compiled on its own.
  fn_node = decorators.transform(
      fn_node, remove_decorators=(simple_decorator,))
  stripped = compiler.ast_to_object(fn_node)
  self.assertEqual(1, stripped.test_fn(1))
def test_simple_decorator(self):
  """Checks a plain decorator is kept or stripped per `remove_decorators`."""

  def simple_decorator(f):
    return lambda a: f(a) + 1

  # The Python parser does capture decorators into the AST.
  # However, the interpreter desugars them upon load, and referring to the
  # decorated function at runtime usually loses any trace of the decorator.
  # Below is an example when that doesn't happen.
  def static_wrapper():

    @simple_decorator
    def test_fn(a):  # pylint:disable=unused-variable
      return a

  namespace = {'simple_decorator': simple_decorator}
  wrapper_node = self.parse_and_analyze(static_wrapper, namespace)
  # Drill down to the decorated `test_fn` definition inside the wrapper.
  fn_node = wrapper_node.body[0].body[0]

  # Decorator kept: its source is prepended so the compiled code can see it.
  fn_node = decorators.transform(fn_node, remove_decorators=())
  decorator_src = textwrap.dedent(tf_inspect.getsource(simple_decorator))
  kept = compiler.ast_to_object(fn_node, source_prefix=decorator_src)
  self.assertEqual(2, kept.test_fn(1))

  # Decorator removed: the bare function is compiled on its own.
  fn_node = decorators.transform(
      fn_node, remove_decorators=(simple_decorator,))
  stripped = compiler.ast_to_object(fn_node)
  self.assertEqual(1, stripped.test_fn(1))
def to_graph(f, arg_value_hints=None):
  """Compile a Python function into equivalent TensorFlow code.

  Args:
    f: A Python function with arbitrary arguments and return values.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    A function with a signature identical to `f`, but which when executed
    creates a TF graph that has the same functionality as the original
    function.
  """
  dependencies = conversion.ConversionMap()
  _, entry_name = conversion.object_to_graph(f, dependencies, arg_value_hints)

  # Assemble one module: the configured imports first, then every converted
  # dependency collected in the conversion map.
  module = gast.Module([])
  module.body.extend(
      parser.parse_str(line) for line in config.COMPILED_IMPORT_STATEMENTS)
  module.body.extend(dependencies.dependency_cache.values())
  compiled_module = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  compiled_module.__dict__.update(six.get_function_globals(f))

  return getattr(compiled_module, entry_name)
def to_graph(o, arg_value_hints=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    o: A Python function or class.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    A function with a signature identical to `o`, but which when executed
    creates a TF graph that has the same functionality as the original
    entity.
  """
  dependencies = conversion.ConversionMap()
  _, entry_name = conversion.object_to_graph(o, dependencies, arg_value_hints)

  # Assemble one module: the configured imports first, then every converted
  # dependency collected in the conversion map.
  module = gast.Module([])
  module.body.extend(
      parser.parse_str(line) for line in config.COMPILED_IMPORT_STATEMENTS)
  module.body.extend(dependencies.dependency_cache.values())
  compiled_module = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(o):
    compiled_module.__dict__.update(six.get_function_globals(o))

  return getattr(compiled_module, entry_name)
def test_continue_deeply_nested(self):
  """Compares the transformed function with the original for inputs 0-4."""

  def test_fn(x):
    v = []
    u = []
    w = []
    while x > 0:
      x -= 1
      if x % 2 == 0:
        if x % 3 != 0:
          u.append(x)
        else:
          w.append(x)
          continue
      v.append(x)
    return v, u, w

  parsed = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
  canonical = continue_canonicalization.transform(parsed, self.ctx)
  compiled_module = compiler.ast_to_object(canonical)

  # The transformed code must agree with the plain Python original.
  for start in range(5):
    self.assertEqual(test_fn(start), compiled_module.test_fn(start))
def test_function_decorator(self):
  """Checks a decorator produced by a factory call is kept or stripped."""

  def function_decorator():

    def decorator(f):
      return lambda a: f(a) + 1

    return decorator

  # The Python parser does capture decorators into the AST.
  # However, the interpreter desugars them on load, and referring to the
  # decorated function at runtime usually loses any trace of the decorator.
  # Below is an example when that doesn't happen.
  def static_wrapper():

    @function_decorator()
    def test_fn(a):  # pylint:disable=unused-variable
      return a

  namespace = {'function_decorator': function_decorator}
  wrapper_node = self.parse_and_analyze(static_wrapper, namespace)
  # Drill down to the decorated `test_fn` definition inside the wrapper.
  fn_node = wrapper_node.body[0].body[0]

  fn_node = decorators.transform(fn_node, remove_decorators=())
  # Since the decorator is not removed, we need to include its source
  # code. We cannot do it after the fact because decorators are executed
  # on load.
  decorator_src = textwrap.dedent(tf_inspect.getsource(function_decorator))
  kept, _ = compiler.ast_to_object(fn_node, source_prefix=decorator_src)
  self.assertEqual(2, kept.test_fn(1))

  fn_node = decorators.transform(
      fn_node, remove_decorators=(function_decorator,))
  with self.compiled(fn_node) as stripped:
    self.assertEqual(1, stripped.test_fn(1))
def test_basic_break_for_loop(self):
  """Checks the transformed `for`/`break` against its expected equivalent."""

  def test_fn(a):
    v = []
    for x in a:
      x -= 1
      if x % 2 == 0:
        break
      v.append(x)
    return v

  # The break is incompletely canonicalized for for loops. Everything is
  # in place except for the condition verification.
  def test_equiv_fn(a):
    v = []
    for x in a:
      x -= 1
      if x % 2 == 0:
        continue
      v.append(x)
    return v

  parsed = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
  canonical = break_canonicalization.transform(parsed, TestNamer())
  compiled_module = compiler.ast_to_object(canonical)

  # The break is incompletely canonicalized. Everything is in place, but
  # the loop does not break.
  for values in ([], [1], [2], [1, 2, 3, 4]):
    # Each call receives its own fresh list, as in separate literal calls.
    self.assertEqual(
        test_equiv_fn(list(values)), compiled_module.test_fn(list(values)))
def test_basic_break_for_loop(self):
  """Checks the transformed `for`/`break` against its expected equivalent."""

  def test_fn(a):
    v = []
    for x in a:
      x -= 1
      if x % 2 == 0:
        break
      v.append(x)
    return v

  # The break is incompletely canonicalized for for loops. Everything is
  # in place except for the condition verification.
  def test_equiv_fn(a):
    v = []
    for x in a:
      x -= 1
      if x % 2 == 0:
        continue
      v.append(x)
    return v

  parsed = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
  canonical = break_canonicalization.transform(parsed, TestNamer())
  compiled_module = compiler.ast_to_object(canonical)

  # The break is incompletely canonicalized. Everything is in place, but
  # the loop does not break.
  for values in ([], [1], [2], [1, 2, 3, 4]):
    # Each call receives its own fresh list, as in separate literal calls.
    self.assertEqual(
        test_equiv_fn(list(values)), compiled_module.test_fn(list(values)))
def test_continue_deeply_nested(self):
  """Compares the transformed function with the original for inputs 0-4."""

  def test_fn(x):
    v = []
    u = []
    w = []
    while x > 0:
      x -= 1
      if x % 2 == 0:
        if x % 3 != 0:
          u.append(x)
        else:
          w.append(x)
          continue
      v.append(x)
    return v, u, w

  parsed = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
  canonical = break_canonicalization.transform(parsed, TestNamer())
  compiled_module = compiler.ast_to_object(canonical)

  # The transformed code must agree with the plain Python original.
  for start in range(5):
    self.assertEqual(test_fn(start), compiled_module.test_fn(start))
def test_uncompiled_modules(self):
  """Checks calls into `math_ops`/`constant_op` still run after the transform."""

  def test_fn(a):
    a = math_ops.multiply(a, constant_op.constant(2))
    a = math_ops.add(a, constant_op.constant(1))
    return a

  node = self.parse_and_analyze(
      test_fn, {
          'math_ops': math_ops,
          'constant_op': constant_op
      },
      namer=TestNamer())
  node = call_trees.transform(node, self.ctx,
                              set(((math_ops.__name__,),
                                   (constant_op.__name__,))), ())
  result = compiler.ast_to_object(node)
  # The compiled module has no imports of its own; inject the modules that
  # `test_fn` refers to by name.
  setattr(result, 'math_ops', math_ops)
  setattr(result, 'constant_op', constant_op)

  with self.test_session() as sess:
    # Not renamed, because the converter doesn't rename the definition itself.
    # (the caller is responsible for that).
    result_tensor = result.test_fn(constant_op.constant(1))
    result_val = sess.run(result_tensor)

  # 1 * 2 + 1 == 3. `assertEqual` replaces the deprecated `assertEquals`.
  self.assertEqual(3, result_val)
def test_function_decorator(self):
  """Checks a decorator produced by a factory call is kept or stripped."""

  def function_decorator():

    def decorator(f):
      return lambda a: f(a) + 1

    return decorator

  # The Python parser does capture decorators into the AST.
  # However, the interpreter desugars them on load, and referring to the
  # decorated function at runtime usually loses any trace of the decorator.
  # Below is an example when that doesn't happen.
  def static_wrapper():

    @function_decorator()
    def test_fn(a):  # pylint:disable=unused-variable
      return a

  namespace = {'function_decorator': function_decorator}
  wrapper_node = self.parse_and_analyze(static_wrapper, namespace)
  # Drill down to the decorated `test_fn` definition inside the wrapper.
  fn_node = wrapper_node.body[0].body[0]

  fn_node = decorators.transform(fn_node, remove_decorators=())
  # Since the decorator is not removed, we need to include its source
  # code. We cannot do it after the fact because decorators are executed
  # on load.
  decorator_src = textwrap.dedent(tf_inspect.getsource(function_decorator))
  kept, _ = compiler.ast_to_object(fn_node, source_prefix=decorator_src)
  self.assertEqual(2, kept.test_fn(1))

  fn_node = decorators.transform(
      fn_node, remove_decorators=(function_decorator,))
  with self.compiled(fn_node) as stripped:
    self.assertEqual(1, stripped.test_fn(1))
def test_ast_to_object(self):
  """Builds `f(a) = a + 1` by hand and checks the module and its source."""
  signature = gast.arguments(
      args=[gast.Name('a', gast.Param(), None)],
      vararg=None,
      kwonlyargs=[],
      kwarg=None,
      defaults=[],
      kw_defaults=[])
  add_one = gast.BinOp(
      op=gast.Add(),
      left=gast.Name('a', gast.Load(), None),
      right=gast.Num(1))
  fn_def = gast.FunctionDef(
      name='f',
      args=signature,
      body=[gast.Return(add_one)],
      decorator_list=[],
      returns=None)

  module, source = compiler.ast_to_object(fn_def)

  expected_source = """
    def f(a):
      return a + 1
  """
  expected = textwrap.dedent(expected_source).strip()

  # Both the returned source and the file backing the module must match.
  self.assertEqual(expected, source.strip())
  self.assertEqual(2, module.f(1))
  with open(module.__file__, 'r') as temp_output:
    self.assertEqual(expected, temp_output.read().strip())
def test_uncompiled_modules(self):
  """Checks calls into `math_ops`/`constant_op` still run after the transform."""

  def test_fn(a):
    a = math_ops.multiply(a, constant_op.constant(2))
    a = math_ops.add(a, constant_op.constant(1))
    return a

  node = self._parse_and_analyze(test_fn, {
      'math_ops': math_ops,
      'constant_op': constant_op
  })
  node = call_trees.transform(node, TestNamer(), {},
                              set(((math_ops.__name__,),
                                   (constant_op.__name__,))), ())
  result = compiler.ast_to_object(node)
  # The compiled module has no imports of its own; inject the modules that
  # `test_fn` refers to by name.
  setattr(result, 'math_ops', math_ops)
  setattr(result, 'constant_op', constant_op)

  with self.test_session() as sess:
    # Not renamed, because the converter doesn't rename the definition itself.
    # (the caller is responsible for that).
    result_tensor = result.test_fn(constant_op.constant(1))
    result_val = sess.run(result_tensor)

  # 1 * 2 + 1 == 3. `assertEqual` replaces the deprecated `assertEquals`.
  self.assertEqual(3, result_val)
def test_ast_to_object(self):
  """Builds `f(a) = a + 1` by hand and checks the module and its file."""
  signature = gast.arguments(
      args=[gast.Name('a', gast.Param(), None)],
      vararg=None,
      kwonlyargs=[],
      kwarg=None,
      defaults=[],
      kw_defaults=[])
  add_one = gast.BinOp(
      op=gast.Add(),
      left=gast.Name('a', gast.Load(), None),
      right=gast.Num(1))
  fn_def = gast.FunctionDef(
      name='f',
      args=signature,
      body=[gast.Return(add_one)],
      decorator_list=[],
      returns=None)

  compiled_module = compiler.ast_to_object(fn_def)
  self.assertEqual(2, compiled_module.f(1))

  # The file backing the module must contain the generated source.
  with open(compiled_module.__file__, 'r') as temp_output:
    expected = textwrap.dedent("""
      def f(a):
        return a + 1
    """).strip()
    self.assertEqual(expected, temp_output.read().strip())
def to_graph(f, arg_value_hints=None):
  """Compile a Python function into equivalent TensorFlow code.

  Args:
    f: A Python function with arbitrary arguments and return values.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    A function with a signature identical to `f`, but which when executed
    creates a TF graph that has the same functionality as the original
    function.
  """
  dependencies = conversion.ConversionMap()
  _, entry_name = conversion.object_to_graph(f, dependencies, arg_value_hints)

  # Assemble one module: the configured imports first, then every converted
  # dependency collected in the conversion map.
  module = gast.Module([])
  module.body.extend(
      parser.parse_str(line) for line in config.COMPILED_IMPORT_STATEMENTS)
  module.body.extend(dependencies.dependency_cache.values())
  compiled_module = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  compiled_module.__dict__.update(six.get_function_globals(f))

  return getattr(compiled_module, entry_name)
def _remover_wrapper(self, f, remove_decorators):
  """Returns `f` recompiled after running the decorator transform on it.

  Args:
    f: The function to parse, transform and compile.
    remove_decorators: Forwarded unchanged to `decorators.transform`.

  Returns:
    The attribute named `f.__name__` of the compiled module.
  """
  parsed = self.parse_and_analyze(
      f, {
          'self_removing_decorator': self_removing_decorator,
          'simple_decorator': simple_decorator
      })
  transformed, _ = decorators.transform(
      parsed, remove_decorators=remove_decorators)
  compiled_module, _ = compiler.ast_to_object(transformed)
  return getattr(compiled_module, f.__name__)
def test_replace_tuple(self):
  """Checks a tuple of names can be substituted for a template symbol."""
  template = """
    def test_fn(a, c):
      return b,
  """

  node = templates.replace(template, b=('a', 'c'))[0]
  result = compiler.ast_to_object(node)
  # `assertEqual` replaces the deprecated `assertEquals` alias.
  self.assertEqual((2, 3), result.test_fn(2, 3))
def test_replace_tuple(self):
  """Checks a tuple of names can be substituted for a template symbol."""
  template = """
    def test_fn(a, c):
      return b,
  """

  node = templates.replace(template, b=('a', 'c'))[0]
  result = compiler.ast_to_object(node)
  # `assertEqual` replaces the deprecated `assertEquals` alias.
  self.assertEqual((2, 3), result.test_fn(2, 3))
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the
        decorator function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which when executed
    creates a TF graph that has the same functionality as the original
    entity.
  """
  dependencies = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, entry_name = conversion.entity_to_graph(e, dependencies, arg_values,
                                             arg_types)

  # Assemble one module: the configured imports first, then every converted
  # dependency collected in the conversion map.
  module = gast.Module([])
  module.body.extend(
      parser.parse_str(line) for line in config.COMPILED_IMPORT_STATEMENTS)
  module.body.extend(dependencies.dependency_cache.values())
  compiled_module, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    compiled_module.__dict__.update(six.get_function_globals(e))

  compiled_fn = getattr(compiled_module, entry_name)
  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
  return compiled_fn
def test_replace_name_with_dict(self):
  """Checks a name can be replaced by a dict expression and then indexed."""
  template = """
    def test_fn():
      return foo['bar']
  """

  source = parser.parse_expression('{\'bar\': 3}')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # `assertEqual` replaces the deprecated `assertEquals` alias.
  self.assertEqual(3, result.test_fn())
def test_keywords_to_dict(self):
  """Checks `keywords_to_dict` yields a dict node that compiles and evaluates."""
  keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
  d = ast_util.keywords_to_dict(keywords)
  # Make sure we generate a usable dict node by attaching it to a variable and
  # compiling everything.
  output = parser.parse_str('b = 3')
  output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),)
  result, _ = compiler.ast_to_object(output)
  # `a` maps to the value of `b`, which the compiled module assigns to 3.
  self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'})
def compiled(self, node, *symbols):
  """Compiles `node` and yields the resulting module with test stubs attached.

  The module gets a `tf` attribute built by `make_fake_tf(*symbols)` and a
  `py2tf_utils` attribute. If compilation or the code consuming the yielded
  module raises, the generated source is printed before re-raising.

  NOTE(review): presumably wrapped by `contextlib.contextmanager` (callers
  use it in a `with` statement) -- the decorator is not visible here.

  Args:
    node: The AST node to compile.
    *symbols: Forwarded to `self.make_fake_tf`.

  Yields:
    The compiled module.
  """
  # Placeholder shown when the failure happens before any source exists.
  source = '<compile failed>'
  try:
    module, source = compiler.ast_to_object(node)
    module.tf = self.make_fake_tf(*symbols)
    module.py2tf_utils = utils
    yield module
  except Exception:  # pylint:disable=broad-except
    print('Offending compiled code:\n%s' % source)
    raise
def _remover_wrapper(self, f, remove_decorators):
  """Returns `f` recompiled after running the decorator transform on it.

  Args:
    f: The function to parse, transform and compile.
    remove_decorators: Forwarded unchanged to `decorators.transform`.

  Returns:
    The attribute named `f.__name__` of the compiled module.
  """
  parsed = self.parse_and_analyze(
      f, {
          'self_removing_decorator': self_removing_decorator,
          'simple_decorator': simple_decorator
      })
  transformed, _ = decorators.transform(
      parsed, remove_decorators=remove_decorators)
  compiled_module, _ = compiler.ast_to_object(transformed)
  return getattr(compiled_module, f.__name__)
def test_noop(self):
  """Checks an undecorated function passes through the transform unchanged."""

  def test_fn(a):
    return a

  parsed = self.parse_and_analyze(test_fn, {})
  transformed, deps = decorators.transform(parsed, remove_decorators=())
  compiled_module, _ = compiler.ast_to_object(transformed)

  # Nothing to report as a dependency, and the function still works.
  self.assertFalse(deps)
  self.assertEqual(1, compiled_module.test_fn(1))
def test_replace_name_with_dict(self):
  """Checks a name can be replaced by a dict expression and then indexed."""
  template = """
    def test_fn():
      return foo['bar']
  """

  source = parser.parse_expression('{\'bar\': 3}')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # `assertEqual` replaces the deprecated `assertEquals` alias.
  self.assertEqual(3, result.test_fn())
def test_transform(self):
  """Checks the transformed `print(a)` is a `gast.Call` and still runs."""

  def test_fn(a):
    print(a)

  node = self._parse_and_analyze(test_fn, {'print': print})
  node = print_functions.transform(node)
  result = compiler.ast_to_object(node)

  # Smoke-run the compiled function; only absence of errors is checked.
  result.test_fn('a')
  # assertIsInstance reports the actual type on failure, unlike
  # assertTrue(isinstance(...)).
  self.assertIsInstance(node.body[0].body[0].value, gast.Call)
def test_keywords_to_dict(self):
  """Checks `keywords_to_dict` yields a dict node that compiles and evaluates."""
  keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
  d = ast_util.keywords_to_dict(keywords)
  # Make sure we generate a usable dict node by attaching it to a variable and
  # compiling everything.
  output = parser.parse_str('b = 3')
  output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),)
  result, _ = compiler.ast_to_object(output)
  # `a` maps to the value of `b`, which the compiled module assigns to 3.
  self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'})
def compiled(self, node, *symbols):
  """Compiles `node` and yields the resulting module with test stubs attached.

  The module gets a `tf` attribute built by `make_fake_tf(*symbols)` and a
  `py2tf_utils` attribute. If compilation or the code consuming the yielded
  module raises, the generated source is printed before re-raising.

  NOTE(review): presumably wrapped by `contextlib.contextmanager` (callers
  use it in a `with` statement) -- the decorator is not visible here.

  Args:
    node: The AST node to compile.
    *symbols: Forwarded to `self.make_fake_tf`.

  Yields:
    The compiled module.
  """
  # Placeholder shown when the failure happens before any source exists.
  source = '<compile failed>'
  try:
    module, source = compiler.ast_to_object(node)
    module.tf = self.make_fake_tf(*symbols)
    module.py2tf_utils = utils
    yield module
  except Exception:  # pylint:disable=broad-except
    print('Offending compiled code:\n%s' % source)
    raise
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the
        decorator function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which when executed
    creates a TF graph that has the same functionality as the original
    entity.
  """
  dependencies = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, entry_name = conversion.entity_to_graph(e, dependencies, arg_values,
                                             arg_types)

  # Assemble one module: the statements of every configured import first,
  # then every converted dependency collected in the conversion map.
  module = gast.Module([])
  module.body.extend(
      stmt
      for line in config.COMPILED_IMPORT_STATEMENTS
      for stmt in parser.parse_str(line).body)
  module.body.extend(dependencies.dependency_cache.values())
  compiled_module, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    compiled_module.__dict__.update(inspect_utils.getnamespace(e))

  compiled_fn = getattr(compiled_module, entry_name)
  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
  return compiled_fn
def test_bool_ops(self):
  """Checks a transformed `and`/`or` expression evaluates truthy in a session."""

  def test_fn(a, b, c):
    return (a or b) and (a or b or c)

  parsed = self.parse_and_analyze(test_fn, {})
  transformed = logical_expressions.transform(parsed)
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `math_ops` for this test.
  compiled_module.tf = math_ops

  with self.test_session() as sess:
    outcome = sess.run(compiled_module.test_fn(True, False, True))
    self.assertTrue(outcome)
def test_replace_variable(self):
  """Checks a variable in a function template can be renamed throughout."""

  def template(a):  # pylint:disable=unused-argument

    def test_fn(a):  # pylint:disable=unused-variable
      a += 1
      a = 2 * a + 1
      return b  # pylint:disable=undefined-variable

  node = templates.replace(
      template, a=gast.Name('b', gast.Load(), None))[0]
  result = compiler.ast_to_object(node)
  # 2 * (2 + 1) + 1 == 7. `assertEqual` replaces the deprecated alias.
  self.assertEqual(7, result.test_fn(2))
def test_noop(self):
  """Checks an undecorated function passes through the transform unchanged."""

  def test_fn(a):
    return a

  parsed = self.parse_and_analyze(test_fn, {})
  transformed, deps = decorators.transform(parsed, remove_decorators=())
  compiled_module, _ = compiler.ast_to_object(transformed)

  # Nothing to report as a dependency, and the function still works.
  self.assertFalse(deps)
  self.assertEqual(1, compiled_module.test_fn(1))
def test_replace_variable(self):
  """Checks a variable in a function template can be renamed throughout."""

  def template(a):  # pylint:disable=unused-argument

    def test_fn(a):  # pylint:disable=unused-variable
      a += 1
      a = 2 * a + 1
      return b  # pylint:disable=undefined-variable

  node = templates.replace(template, a=gast.Name('b', gast.Load(), None))[0]
  result = compiler.ast_to_object(node)
  # 2 * (2 + 1) + 1 == 7. `assertEqual` replaces the deprecated alias.
  self.assertEqual(7, result.test_fn(2))
def test_replace_function_name(self):
  """Checks the name of a function in a function template can be replaced."""

  def template(fname):  # pylint:disable=unused-argument

    def fname(a):  # pylint:disable=function-redefined
      a += 1
      a = 2 * a + 1
      return a

  node = templates.replace(
      template, fname=gast.Name('test_fn', gast.Load(), None))[0]
  result = compiler.ast_to_object(node)
  # 2 * (2 + 1) + 1 == 7. `assertEqual` replaces the deprecated alias.
  self.assertEqual(7, result.test_fn(2))
def test_transform(self):
  """Checks the transformed `print(a)` is a `gast.Call` and still runs."""

  def test_fn(a):
    print(a)

  node = self.parse_and_analyze(test_fn, {'print': print})
  node = print_functions.transform(node)
  result = compiler.ast_to_object(node)

  # Smoke-run the compiled function; only absence of errors is checked.
  result.test_fn('a')
  # assertIsInstance reports the actual type on failure, unlike
  # assertTrue(isinstance(...)).
  self.assertIsInstance(node.body[0].body[0].value, gast.Call)
def test_replace_variable(self):
  """Checks a variable in a string template can be renamed throughout."""
  template = """
    def test_fn(a):
      a += 1
      a = 2 * a + 1
      return b
  """

  node = templates.replace(template, a='b')[0]
  result, _ = compiler.ast_to_object(node)
  # 2 * (2 + 1) + 1 == 7. `assertEqual` replaces the deprecated alias.
  self.assertEqual(7, result.test_fn(2))
def test_replace_variable(self):
  """Checks a variable in a string template can be renamed throughout."""
  template = """
    def test_fn(a):
      a += 1
      a = 2 * a + 1
      return b
  """

  node = templates.replace(template, a='b')[0]
  result = compiler.ast_to_object(node)
  # 2 * (2 + 1) + 1 == 7. `assertEqual` replaces the deprecated alias.
  self.assertEqual(7, result.test_fn(2))
def test_replace_function_name(self):
  """Checks the name of a function in a string template can be replaced."""
  template = """
    def fname(a):
      a += 1
      a = 2 * a + 1
      return a
  """

  node = templates.replace(template, fname='test_fn')[0]
  result, _ = compiler.ast_to_object(node)
  # 2 * (2 + 1) + 1 == 7. `assertEqual` replaces the deprecated alias.
  self.assertEqual(7, result.test_fn(2))
def test_replace_function_name(self):
  """Checks the name of a function in a string template can be replaced."""
  template = """
    def fname(a):
      a += 1
      a = 2 * a + 1
      return a
  """

  node = templates.replace(template, fname='test_fn')[0]
  result = compiler.ast_to_object(node)
  # 2 * (2 + 1) + 1 == 7. `assertEqual` replaces the deprecated alias.
  self.assertEqual(7, result.test_fn(2))
def test_bool_ops(self):
  """Checks a transformed `and`/`or` expression evaluates truthy in a session."""

  def test_fn(a, b, c):
    return (a or b) and (a or b or c)

  parsed = self.parse_and_analyze(test_fn, {})
  transformed = logical_expressions.transform(parsed)
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `math_ops` for this test.
  compiled_module.tf = math_ops

  with self.test_session() as sess:
    outcome = sess.run(compiled_module.test_fn(True, False, True))
    self.assertTrue(outcome)
def test_equals(self):
  """Checks a transformed `==` is true for equal and false for unequal args."""

  def test_fn(a, b):
    return a == b

  parsed = self.parse_and_analyze(test_fn, {})
  transformed = logical_expressions.transform(parsed)
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `math_ops` for this test.
  compiled_module.tf = math_ops

  with self.test_session() as sess:
    same = sess.run(compiled_module.test_fn(1, 1))
    self.assertTrue(same)
    different = sess.run(compiled_module.test_fn(1, 2))
    self.assertFalse(different)
def test_len(self):
  """Checks a transformed `len` call returns 3 for a 3-element constant."""

  def test_fn(a):
    return len(a)

  parsed = self.parse_and_analyze(test_fn, {'len': len})
  transformed = builtin_functions.transform(parsed)
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `array_ops` for this test.
  compiled_module.tf = array_ops

  with self.test_session() as sess:
    length = sess.run(compiled_module.test_fn(constant_op.constant([0, 0, 0])))
    self.assertEqual(3, length)
def test_replace_function_name(self):
  """Checks the name of a function in a function template can be replaced."""

  def template(fname):  # pylint:disable=unused-argument

    def fname(a):  # pylint:disable=function-redefined
      a += 1
      a = 2 * a + 1
      return a

  node = templates.replace(
      template, fname=gast.Name('test_fn', gast.Load(), None))[0]
  result = compiler.ast_to_object(node)
  # 2 * (2 + 1) + 1 == 7. `assertEqual` replaces the deprecated alias.
  self.assertEqual(7, result.test_fn(2))
def test_equals(self):
  """Checks a transformed `==` is true for equal and false for unequal args."""

  def test_fn(a, b):
    return a == b

  parsed = self.parse_and_analyze(test_fn, {})
  transformed = logical_expressions.transform(parsed)
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `math_ops` for this test.
  compiled_module.tf = math_ops

  with self.test_session() as sess:
    same = sess.run(compiled_module.test_fn(1, 1))
    self.assertTrue(same)
    different = sess.run(compiled_module.test_fn(1, 2))
    self.assertFalse(different)
def test_replace_attribute(self):
  """Checks an attribute name can be replaced, and a non-string is rejected."""
  template = """
    def test_fn(a):
      return a.foo
  """

  node = templates.replace(template, foo='b')[0]
  result, _ = compiler.ast_to_object(node)
  # NOTE(review): `imp` is deprecated in Python 3; `types.ModuleType('test')`
  # is the usual replacement if this file drops `imp`.
  mod = imp.new_module('test')
  mod.b = 3
  # After the rename, `a.foo` reads `a.b`. `assertEqual` replaces the
  # deprecated `assertEquals` alias.
  self.assertEqual(3, result.test_fn(mod))

  with self.assertRaises(ValueError):
    templates.replace(template, foo=1)
def test_while_single_var(self):
  """Checks a transformed countdown `while` loop ends at 0 when run from 5."""

  def test_fn(n):
    while n > 0:
      n -= 1
    return n

  parsed = self._parse_and_analyze(test_fn, {})
  transformed = control_flow.transform(parsed, TestNamer())
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `control_flow_ops` here.
  compiled_module.tf = control_flow_ops

  with self.test_session() as sess:
    final_n = sess.run(compiled_module.test_fn(constant_op.constant(5)))
    self.assertEqual(0, final_n)
def test_while_single_var(self):
  """Checks a transformed countdown `while` loop ends at 0 when run from 5."""

  def test_fn(n):
    while n > 0:
      n -= 1
    return n

  parsed = self._parse_and_analyze(test_fn, {})
  transformed = control_flow.transform(parsed, TestNamer())
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `control_flow_ops` here.
  compiled_module.tf = control_flow_ops

  with self.test_session() as sess:
    final_n = sess.run(compiled_module.test_fn(constant_op.constant(5)))
    self.assertEqual(0, final_n)
def test_len(self):
  """Checks a transformed `len` call returns 3 for a 3-element constant."""

  def test_fn(a):
    return len(a)

  parsed = self.parse_and_analyze(test_fn, {'len': len})
  transformed = builtin_functions.transform(parsed, self.ctx)
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `array_ops` for this test.
  compiled_module.tf = array_ops

  with self.test_session() as sess:
    length = sess.run(compiled_module.test_fn(constant_op.constant([0, 0, 0])))
    self.assertEqual(3, length)
def test_replace_attribute(self):
  """Checks an attribute name can be replaced, and a non-string is rejected."""
  template = """
    def test_fn(a):
      return a.foo
  """

  node = templates.replace(template, foo='b')[0]
  result, _ = compiler.ast_to_object(node)
  # NOTE(review): `imp` is deprecated in Python 3; `types.ModuleType('test')`
  # is the usual replacement if this file drops `imp`.
  mod = imp.new_module('test')
  mod.b = 3
  # After the rename, `a.foo` reads `a.b`. `assertEqual` replaces the
  # deprecated `assertEquals` alias.
  self.assertEqual(3, result.test_fn(mod))

  with self.assertRaises(ValueError):
    templates.replace(template, foo=1)
def test_parser_compile_idempotent(self):
  """Checks that parsing then compiling `test_fn` reproduces its source text."""

  # The comparison below is textual, so the exact layout of `test_fn` is part
  # of the test.
  # NOTE(review): presumably the code generator's indentation has to match
  # this file's -- confirm before reformatting this function.
  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  # Left: the dedented original source. Right: the source of the function
  # obtained by parsing `test_fn` and compiling its definition node again.
  self.assertEqual(
      textwrap.dedent(tf_inspect.getsource(test_fn)),
      tf_inspect.getsource(
          compiler.ast_to_object(
              parser.parse_entity(test_fn)[0].body[0])[0].test_fn))
def test_if_single_var(self):
  """Checks a transformed `if` negates a positive input (1 becomes -1)."""

  def test_fn(n):
    if n > 0:
      n = -n
    return n

  parsed = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
  transformed = control_flow.transform(parsed, self.ctx)
  compiled_module = compiler.ast_to_object(transformed)
  # The compiled code refers to `tf`; bind it to `control_flow_ops` here.
  compiled_module.tf = control_flow_ops

  with self.test_session() as sess:
    negated = sess.run(compiled_module.test_fn(constant_op.constant(1)))
    self.assertEqual(-1, negated)
def test_replace_name_with_call(self):
  """Checks a name can be replaced by a call expression such as `f()(b)`."""
  template = """
    def test_fn():
      b = 5
      def g(a):
        return 3 * a
      def f():
        return g
      return foo
  """

  source = parser.parse_expression('f()(b)')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # f() returns g, so f()(b) is g(5) == 15. `assertEqual` replaces the
  # deprecated `assertEquals` alias.
  self.assertEqual(15, result.test_fn())
def test_basic_for(self):
  """Compares the transformed `for` loop with the original on two lists."""

  def test_fn(l):
    s = 0
    for e in l:
      s += e
    return s

  parsed = self._parse_and_analyze(test_fn, {})
  canonical = for_canonicalization.transform(parsed, TestNamer())
  compiled_module = compiler.ast_to_object(canonical)

  # A non-empty and an empty input; both versions see the same list object.
  for values in ([1, 2, 3], []):
    self.assertEqual(test_fn(values), compiled_module.test_fn(values))
def test_code_block(self):
  """Checks a placeholder statement can be replaced by a list of statements."""

  def template(block):  # pylint:disable=unused-argument

    def test_fn(a):  # pylint:disable=unused-variable
      block  # pylint:disable=pointless-statement
      return a

  # The replacement is the statement `a = a + 1`, listed twice.
  node = templates.replace(
      template,
      block=[
          gast.Assign([gast.Name('a', gast.Store(), None)],
                      gast.BinOp(
                          gast.Name('a', gast.Load(), None), gast.Add(),
                          gast.Num(1))),
      ] * 2)[0]
  result = compiler.ast_to_object(node)
  # Two increments: 1 + 1 + 1 == 3. `assertEqual` replaces the deprecated
  # `assertEquals` alias.
  self.assertEqual(3, result.test_fn(1))
def test_replace_name_with_call(self):
  """Checks a name can be replaced by a call expression such as `f()(b)`."""
  template = """
    def test_fn():
      b = 5
      def g(a):
        return 3 * a
      def f():
        return g
      return foo
  """

  source = parser.parse_expression('f()(b)')
  node = templates.replace(template, foo=source)[0]
  result, _ = compiler.ast_to_object(node)
  # f() returns g, so f()(b) is g(5) == 15. `assertEqual` replaces the
  # deprecated `assertEquals` alias.
  self.assertEqual(15, result.test_fn())
def test_basic_for(self):
  """Compares the transformed `for` loop with the original on two lists."""

  def test_fn(l):
    s = 0
    for e in l:
      s += e
    return s

  parsed = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
  canonical = for_canonicalization.transform(parsed, self.ctx)
  compiled_module = compiler.ast_to_object(canonical)

  # A non-empty and an empty input; both versions see the same list object.
  for values in ([1, 2, 3], []):
    self.assertEqual(test_fn(values), compiled_module.test_fn(values))
def test_code_block(self):
  """Checks a placeholder statement can be replaced by a list of statements."""
  template = """
    def test_fn(a):
      block
      return a
  """

  # The replacement is the statement `a = a + 1`, listed twice.
  node = templates.replace(
      template,
      block=[
          gast.Assign([
              gast.Name('a', None, None)
          ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
      ] * 2)[0]
  result = compiler.ast_to_object(node)
  # Two increments: 1 + 1 + 1 == 3. `assertEqual` replaces the deprecated
  # `assertEquals` alias.
  self.assertEqual(3, result.test_fn(1))
def test_replace_call_keyword(self):
  """Checks a call keyword can be replaced by keywords; rejects other values."""
  template = """
    def test_fn():
      def f(a, d, f):
        return a + d + f
      return f(1, kws=None)
  """

  source = parser.parse_expression('f(d=3, f=5)')
  node = templates.replace(template, kws=source.keywords)[0]
  result, _ = compiler.ast_to_object(node)
  # f(1, d=3, f=5) == 9. `assertEqual` replaces the deprecated alias.
  self.assertEqual(9, result.test_fn())

  # One assertRaises block per call: when both calls share a block, the
  # first exception leaves it and the second call is never executed.
  with self.assertRaises(ValueError):
    templates.replace(template, kws=[])
  with self.assertRaises(ValueError):
    templates.replace(template, kws=1)