def _wrap_to_py_func_no_return(self, node):
  """Replaces a call with a `tf.py_func` invocation of a generated wrapper.

  A fresh Python wrapper function is generated that forwards its positional
  parameters to the original callee and returns the constant 1. The call is
  then expressed as `tf.py_func(wrapper, [<original args>], [tf.int64])`.
  The constant return value accompanies the single `tf.int64` output declared
  in the template; the callee's own result is discarded.

  Only positional arguments (`node.args`) are forwarded; see the TODOs below.

  Args:
    node: the call node. `node.func` must carry an `anno.Basic.QN` annotation
      and `node` itself a `NodeAnno.ARGS_SCOPE` annotation.

  Returns:
    A `(wrapper_def, call_expr)` pair produced by `templates.replace`. The
    wrapper definition is annotated with `anno.Basic.SKIP_PROCESSING`.
  """
  callee_qn = anno.getanno(node.func, anno.Basic.QN)
  args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  namer = self.context.namer
  # Every generated name must avoid the symbols already referenced here.
  reserved = args_scope.referenced
  wrapper_name = namer.new_symbol(callee_qn.ssf(), reserved)

  def _fresh_param_name(arg):
    # Derive the parameter name from the argument's qualified name when it
    # has one; arguments without a QN annotation fall back to 'arg'.
    if anno.hasanno(arg, anno.Basic.QN):
      base_qn = anno.getanno(arg, anno.Basic.QN)
    else:
      base_qn = qual_names.QN('arg')
    return namer.new_symbol(base_qn.ssf(), reserved)

  # Names are requested in argument order, exactly once per argument.
  wrapper_args = [_fresh_param_name(arg) for arg in node.args]

  # TODO(mdan): Properly handle varargs, kwargs, etc.
  # TODO(mdan): This is best handled as a dynamic dispatch.
  # That way we can separate tensors from non-tensor args.
  template = """
    def wrapper(wrapper_args):
      call(wrapper_args)
      return 1
    tf.py_func(wrapper, original_args, [tf.int64])
  """
  replacement = templates.replace(
      template,
      call=node.func,
      wrapper=wrapper_name,
      original_args=gast.List(elts=node.args, ctx=None),
      wrapper_args=wrapper_args)
  wrapper_def, call_expr = replacement
  # The generated wrapper must not be converted again by later passes.
  anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
  return wrapper_def, call_expr
def test_basic(self):
  """Checks QN construction, string forms, AST conversion and equality."""
  # A simple, non-composite name has no parent.
  a = qual_names.QN('a')
  self.assertEqual(a.qn, ('a',))
  self.assertEqual(str(a), 'a')
  self.assertEqual(a.ssf(), 'a')
  self.assertEqual(a.ast().id, 'a')
  self.assertFalse(a.is_composite())
  with self.assertRaises(ValueError):
    _ = a.parent

  # A composite name is built from a parent QN plus an attribute.
  a_b = qual_names.QN(a, 'b')
  self.assertEqual(a_b.qn, ('a', 'b'))
  self.assertEqual(str(a_b), 'a.b')
  self.assertEqual(a_b.ssf(), 'a_b')
  self.assertEqual(a_b.ast().value.id, 'a')
  self.assertEqual(a_b.ast().attr, 'b')
  self.assertTrue(a_b.is_composite())
  self.assertEqual(a_b.parent.qn, ('a',))

  # Copy-construction from an existing simple QN. The copy itself must also
  # refuse to produce a parent (previously this re-checked `a.parent`, so
  # the copy's behaviour was never exercised).
  a2 = qual_names.QN(a)
  self.assertEqual(a2.qn, ('a',))
  with self.assertRaises(ValueError):
    _ = a2.parent

  # Copy-construction from an existing composite QN.
  a_b2 = qual_names.QN(a_b)
  self.assertEqual(a_b2.qn, ('a', 'b'))
  self.assertEqual(a_b2.parent.qn, ('a',))

  # Equality is by value, not by identity. The explicit `==` checks exercise
  # __eq__ directly rather than relying on __ne__ (as assertNotEqual would).
  self.assertTrue(a2 == a)
  self.assertIsNot(a2, a)
  self.assertTrue(a_b.parent == a)
  self.assertTrue(a_b2.parent == a)
  self.assertTrue(a_b2 == a_b)
  self.assertIsNot(a_b2, a_b)
  self.assertFalse(a_b2 == a)

  # An attribute may only be attached to a QN base, not to a plain string.
  with self.assertRaises(ValueError):
    qual_names.QN('a', 'b')
def test_rename_symbols(self):
  """Checks that rename_symbols rewrites simple names and attribute chains.

  The renamed nodes must keep the expression context (Load/Store) of the
  node they replace, and unrelated names must be left untouched.
  """
  node = ast.Tuple([
      ast.Name('a', ast.Load()),
      ast.Name('b', ast.Load()),
      ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
      ast.Attribute(
          ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
  ], None)
  node = qual_names.resolve(node)

  node = ast_util.rename_symbols(
      node, {
          qual_names.QN('a'): qual_names.QN('renamed_a'),
          qual_names.QN('b.c'): qual_names.QN('renamed_b_c'),
      })

  # Simple name renamed; context preserved.
  self.assertEqual(node.elts[0].id, 'renamed_a')
  self.assertIsInstance(node.elts[0].ctx, ast.Load)
  # `b` alone is not in the rename map.
  self.assertEqual(node.elts[1].id, 'b')
  # The whole attribute chain `b.c` collapses into a single renamed name.
  self.assertEqual(node.elts[2].id, 'renamed_b_c')
  self.assertIsInstance(node.elts[2].ctx, ast.Store)
  # Inside `b.c.d`, only the `b.c` prefix is renamed.
  self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
  self.assertIsInstance(node.elts[3].value.ctx, ast.Load)
def visit_Expr(self, node):
  """Wraps a bare call statement in a control-dependency `with` statement.

  For an expression statement whose value is a `gast.Call`, the statement is
  replaced by a node built from the template

    with py2tf_utils.control_dependency_on_returns(tf, call):
      aliased_guarded_args = py2tf_utils.alias_tensors(tf, guarded_args)

  or by an empty-bodied `with` when no arguments can be guarded. The new node
  is annotated with `anno.Basic.INDENT_BLOCK_REMAINDER`, holding its body list
  and the alias map. Presumably a later pass moves the statements that follow
  the call into that body (TODO confirm against the consumer of the
  annotation).

  Args:
    node: the expression-statement node being visited.

  Returns:
    `node` unchanged (after visiting its children) when its value is not a
    call; otherwise the new `with` statement node.
  """
  # Visit child nodes before transforming this statement.
  self.generic_visit(node)
  if isinstance(node.value, gast.Call):
    # Patterns of single function calls, like:
    #   opt.minimize(loss)
    # or:
    #   tf.py_func(...)

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
    # NOTE: We can't guard object attributes because they may not be writable.
    # NOTE(review): if `args_scope.used` is an unordered set, the order of
    # `guarded_args` -- and hence of the generated alias names and tuple
    # elements -- may vary between runs; confirm whether that matters.
    guarded_args = tuple(s for s in args_scope.used if not s.is_composite())

    # TODO(mdan): Include all arguments which depended on guarded_args too.
    # For example, the following will still cause a race:
    #   tf.assign(a, a + 1)
    #   b = a + 1
    #   tf.assign(a, a + 1)  # Control deps here should include `b`
    #   c = b + 1
    # Or maybe we should just raise an "unsafe assign" error?

    if guarded_args:
      # The aliases may need new names to avoid incorrectly making them local.
      # TODO(mdan): This is brutal. It will even rename modules - any fix?
      # Only symbols not modified in the enclosing scope get a fresh name;
      # the rest keep their own name via `alias_map.get(s, s)` below.
      need_alias = tuple(
          s for s in guarded_args if s not in args_scope.parent.modified)
      aliased_new_names = tuple(
          qual_names.QN(
              self.context.namer.new_symbol(
                  s.ssf(), args_scope.parent.referenced)) for s in need_alias)
      alias_map = dict(zip(need_alias, aliased_new_names))
      # A single guarded symbol is substituted as-is; several are packed into
      # a tuple, which becomes the (unpacking) assignment target.
      if len(guarded_args) == 1:
        s, = guarded_args
        aliased_guarded_args = alias_map.get(s, s)
      else:
        aliased_guarded_args = gast.Tuple(
            [alias_map.get(s, s).ast() for s in guarded_args], None)

      template = """
        with py2tf_utils.control_dependency_on_returns(tf, call):
          aliased_guarded_args = py2tf_utils.alias_tensors(tf, guarded_args)
      """
      # `[-1]` takes the last node produced from the template (the `with`).
      control_deps_guard = templates.replace(
          template,
          call=node.value,
          aliased_guarded_args=aliased_guarded_args,
          guarded_args=guarded_args)[-1]
    else:
      alias_map = {}

      template = """
        with py2tf_utils.control_dependency_on_returns(tf, call):
          pass
      """
      control_deps_guard = templates.replace(template, call=node.value)[-1]
      # Drop the placeholder `pass` so the body starts out empty.
      control_deps_guard.body = []

    node = control_deps_guard
    # The annotation holds a reference to the same `body` list object, not a
    # copy, so later additions to it land inside the `with` statement.
    anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                 (node.body, alias_map))
  return node
def test_hashable(self):
  """Checks that QNs work as dict keys and are looked up by value."""
  d = {qual_names.QN('a'): 'a', qual_names.QN('b'): 'b'}
  # Lookups use freshly constructed QNs, so these exercise __hash__/__eq__
  # rather than object identity.
  self.assertEqual(d[qual_names.QN('a')], 'a')
  self.assertEqual(d[qual_names.QN('b')], 'b')
  self.assertNotIn(qual_names.QN('c'), d)