def test_fails_with_bad_types(self):
  """Parameter types that TensorFlow computations cannot accept raise TypeError.

  Federated, functional, placement and abstract types (and structs containing
  them) are each passed as the declared parameter type, and the wrapper is
  expected to reject them at wrap time with a descriptive message.
  """
  function = computation_types.FunctionType(
      None, computation_types.TensorType(tf.int32))
  federated = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  tuple_on_function = computation_types.StructType([federated, function])

  def foo(x):
    del x  # Unused.

  # Each entry is (declared parameter type, regex expected in the TypeError).
  # Every fragment is a raw string so `\(` / `\)` reach `re` as escaped parens
  # rather than as (invalid) Python string escape sequences.
  bad_cases = [
      (federated,
       r'you have attempted to create one with the type {int32}@CLIENTS'),
      (function,
       r'you have attempted to create one with the type \( -> int32\)'),
      (computation_types.PlacementType(),
       r'you have attempted to create one with the type placement'),
      (computation_types.AbstractType('T'),
       r'you have attempted to create one with the type T'),
      (tuple_on_function,
       r'you have attempted to create one with the type <{int32}@CLIENTS,\( '
       r'-> int32\)>'),
  ]
  for bad_type, expected_message in bad_cases:
    with self.assertRaisesRegex(TypeError, expected_message):
      computation_wrapper_instances.tensorflow_wrapper(foo, bad_type)
def test_tf_wrapper_with_tf_add(self):
  """Wrapping `tf.add` yields a serialized graph that adds its two inputs."""
  wrapped = computation_wrapper_instances.tensorflow_wrapper(
      tf.add, (tf.int32, tf.int32))
  self.assertEqual(str(wrapped.type_signature), '(<int32,int32> -> int32)')
  # TODO(b/113112885): Remove this protected member access.
  proto = wrapped._computation_proto  # pylint: disable=protected-access
  self.assertEqual(proto.WhichOneof('computation'), 'tensorflow')
  tf_part = proto.tensorflow
  lhs = tf.compat.v1.placeholder(tf.int32)
  rhs = tf.compat.v1.placeholder(tf.int32)
  # Re-import the serialized graph, feeding its two parameter tensors from
  # the placeholders and fetching its single result tensor.
  graph_def = serialization_utils.unpack_graph_def(tf_part.graph_def)
  input_map = {
      tf_part.parameter.tuple.element[0].tensor.tensor_name: lhs,
      tf_part.parameter.tuple.element[1].tensor.tensor_name: rhs,
  }
  fetches = tf.import_graph_def(
      graph_def, input_map, [tf_part.result.tensor.tensor_name])
  with self.session() as sess:
    sums = [
        sess.run(fetches, feed_dict={lhs: n, rhs: 3})
        for n in (1, 20, 5, 10, 30)
    ]
    self.assertEqual(sums, [[4], [23], [8], [13], [33]])
def test_invoke_with_no_arg_fn(self):
  """A zero-argument Python function wraps into a `( -> int32)` computation."""

  def foo():
    return 10

  wrapped = computation_wrapper_instances.tensorflow_wrapper(foo)
  signature = wrapped.type_signature.compact_representation()
  self.assertEqual(signature, '( -> int32)')
def test_invoke_with_typed_fn(self):
  """Declaring a `tf.int32` parameter type yields `(int32 -> bool)`."""

  def foo(x):
    return x > 10

  wrapped = computation_wrapper_instances.tensorflow_wrapper(foo, tf.int32)
  signature = wrapped.type_signature.compact_representation()
  self.assertEqual(signature, '(int32 -> bool)')
def test_takes_structured_tuple_typed(self):
  """Nested containers in a declared signature keep their Python kinds."""
  MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

  @tf.function
  def foo(x, t, l, odict, my_type):
    # Inside the traced function each argument should arrive as the Python
    # container kind that was used to declare its type.
    self.assertIsInstance(x, tf.Tensor)
    self.assertIsInstance(t, tuple)
    self.assertIsInstance(l, list)
    self.assertIsInstance(odict, collections.OrderedDict)
    self.assertIsInstance(my_type, MyType)
    return x + t[0] + l[0] + odict['foo'] + my_type.x

  param_types = [
      tf.int32,
      (tf.int32, tf.int32),
      [tf.int32, tf.int32],
      collections.OrderedDict([('foo', tf.int32), ('bar', tf.int32)]),
      MyType(tf.int32, tf.int32),
  ]
  wrapped = computation_wrapper_instances.tensorflow_wrapper(foo, param_types)
  expected = ('(<x=int32,t=<int32,int32>,l=<int32,int32>,'
              'odict=<foo=int32,bar=int32>,my_type=<x=int32,y=int32>> -> int32)')
  self.assertEqual(wrapped.type_signature.compact_representation(), expected)
def test_takes_tuple_typed(self):
  """A 2-tuple parameter type produces an unnamed `<int32,int32>` struct."""

  @tf.function
  def foo(t):
    return t[0] + t[1]

  pair_type = (tf.int32, tf.int32)
  wrapped = computation_wrapper_instances.tensorflow_wrapper(foo, pair_type)
  self.assertEqual(
      wrapped.type_signature.compact_representation(),
      '(<int32,int32> -> int32)')
def test_invoke_with_polymorphic_lambda(self):
  """An untyped lambda specializes separately for each argument dtype."""
  polymorphic = computation_wrapper_instances.tensorflow_wrapper(
      lambda x: x > 10)
  cases = (
      (tf.int32, '(int32 -> bool)'),
      (tf.float32, '(float32 -> bool)'),
  )
  for dtype, expected in cases:
    concrete = polymorphic.fn_for_argument_type(
        computation_types.TensorType(dtype))
    self.assertEqual(concrete.type_signature.compact_representation(), expected)
def test_takes_namedtuple_typed(self):
  """A namedtuple-typed parameter is passed through as that namedtuple."""
  MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

  @tf.function
  def foo(x):
    self.assertIsInstance(x, MyType)
    return x.x + x.y

  arg_type = MyType(tf.int32, tf.int32)
  wrapped = computation_wrapper_instances.tensorflow_wrapper(foo, arg_type)
  self.assertEqual(
      wrapped.type_signature.compact_representation(),
      '(<x=int32,y=int32> -> int32)')
def test_with_variable(self):
  """A `tf.function` that owns a `tf.Variable` can still be wrapped."""
  created = []

  @tf.function(autograph=False)
  def foo(x):
    # Build the variable lazily and only once; later traces reuse it.
    if not created:
      created.append(tf.Variable(0))
    var = created[0]
    var.assign(1)
    return var + x

  wrapped = computation_wrapper_instances.tensorflow_wrapper(foo, tf.int32)
  self.assertEqual(
      wrapped.type_signature.compact_representation(), '(int32 -> int32)')
def test_returns_tuple_structured(self):
  """Nested Python containers in the result map to nested TFF structs."""
  MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

  @tf.function
  def foo():
    return (
        1,
        (2, 3.0),
        [4, 5.0],
        collections.OrderedDict([('foo', 6), ('bar', 7.0)]),
        MyType(True, False),
    )

  wrapped = computation_wrapper_instances.tensorflow_wrapper(foo)
  expected = ('( -> <int32,<int32,float32>,<int32,float32>,'
              '<foo=int32,bar=float32>,<x=bool,y=bool>>)')
  self.assertEqual(wrapped.type_signature.compact_representation(), expected)
def test_takes_tuple_polymorphic(self):
  """An untyped pair-consuming function specializes per element dtype."""

  def foo(t):
    return t[0] + t[1]

  polymorphic = computation_wrapper_instances.tensorflow_wrapper(foo)
  cases = (
      (tf.int32, '(<int32,int32> -> int32)'),
      (tf.float32, '(<float32,float32> -> float32)'),
  )
  for dtype, expected in cases:
    arg_type = computation_types.StructType(
        [computation_types.TensorType(dtype) for _ in range(2)])
    concrete = polymorphic.fn_for_argument_type(arg_type)
    self.assertEqual(concrete.type_signature.compact_representation(), expected)
def test_takes_namedtuple_polymorphic(self):
  """An untyped namedtuple-consuming function specializes per dtype."""
  MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

  @tf.function
  def foo(t):
    self.assertIsInstance(t, MyType)
    return t.x + t.y

  polymorphic = computation_wrapper_instances.tensorflow_wrapper(foo)
  cases = (
      (tf.int32, '(<x=int32,y=int32> -> int32)'),
      (tf.float32, '(<x=float32,y=float32> -> float32)'),
  )
  for dtype, expected in cases:
    arg_type = computation_types.StructWithPythonType(
        [('x', dtype), ('y', dtype)], MyType)
    concrete = polymorphic.fn_for_argument_type(arg_type)
    self.assertEqual(concrete.type_signature.compact_representation(), expected)
def test_takes_structured_tuple_polymorphic(self):
  """An untyped function over nested containers specializes per dtype."""
  MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

  @tf.function
  def foo(x, t, l, odict, my_type):
    # Each argument should arrive as the container kind used in the type.
    self.assertIsInstance(x, tf.Tensor)
    self.assertIsInstance(t, tuple)
    self.assertIsInstance(l, list)
    self.assertIsInstance(odict, collections.OrderedDict)
    self.assertIsInstance(my_type, MyType)
    return x + t[0] + l[0] + odict['foo'] + my_type.x

  def nested_arg_type(dtype):
    # Same nesting shape for every case; only the leaf dtype varies.
    return computation_types.to_type([
        dtype,
        (dtype, dtype),
        [dtype, dtype],
        collections.OrderedDict([('foo', dtype), ('bar', dtype)]),
        MyType(dtype, dtype),
    ])

  polymorphic = computation_wrapper_instances.tensorflow_wrapper(foo)
  cases = (
      (tf.int32,
       '(<int32,<int32,int32>,<int32,int32>,'
       '<foo=int32,bar=int32>,<x=int32,y=int32>> -> int32)'),
      (tf.float32,
       '(<float32,<float32,float32>,<float32,float32>,'
       '<foo=float32,bar=float32>,<x=float32,y=float32>> -> float32)'),
  )
  for dtype, expected in cases:
    concrete = polymorphic.fn_for_argument_type(nested_arg_type(dtype))
    self.assertEqual(concrete.type_signature.compact_representation(), expected)
def tf_computation(*args):
  """Decorates/wraps Python functions and defuns as TFF TensorFlow computations.

  This symbol can be used as either a decorator or a wrapper applied to a
  function given to it as an argument. The supported patterns and examples of
  usage are as follows:

  1. Convert an existing function inline into a TFF computation. This is the
     simplest mode of usage, and how one can embed existing non-TFF code for
     use with the TFF framework. In this mode, one invokes
     `tff.tf_computation` with a pair of arguments, the first being a
     function/defun that contains the logic, and the second being the TFF type
     of the parameter:

     ```python
     foo = tff.tf_computation(lambda x: x > 10, tf.int32)
     ```

     After executing the above code snippet, `foo` becomes an instance of the
     abstract base class `Computation`. Like all computations, it has the
     `type_signature` property:

     ```python
     str(foo.type_signature) == '(int32 -> bool)'
     ```

     The function passed as a parameter doesn't have to be a lambda; it can
     also be an existing Python function or a defun. Here's how to construct a
     computation from the standard TensorFlow operator `tf.add`:

     ```python
     foo = tff.tf_computation(tf.add, (tf.int32, tf.int32))
     ```

     The resulting type signature is as expected:

     ```python
     str(foo.type_signature) == '(<int32,int32> -> int32)'
     ```

     If one intends to create a computation that doesn't accept any
     arguments, the type argument is simply omitted. The function must be a
     no-argument function as well:

     ```python
     foo = tff.tf_computation(lambda: tf.constant(10))
     ```

  2. Decorate a Python function or a TensorFlow defun with a TFF type to wrap
     it as a TFF computation. The only difference between this mode of usage
     and the one mentioned above is that instead of passing the function/defun
     as an argument, `tff.tf_computation` along with the optional type
     specifier is written above the function/defun's body.

     Here's an example of a computation that accepts a parameter:

     ```python
     @tff.tf_computation(tf.int32)
     def foo(x):
       return x > 10
     ```

     One can think of this mode of usage as merely syntactic sugar for the
     example already given earlier:

     ```python
     foo = tff.tf_computation(lambda x: x > 10, tf.int32)
     ```

     Here's an example of a no-parameter computation:

     ```python
     @tff.tf_computation
     def foo():
       return tf.constant(10)
     ```

     Again, this is merely syntactic sugar for the example given earlier:

     ```python
     foo = tff.tf_computation(lambda: tf.constant(10))
     ```

     If the Python function has multiple decorators, `tff.tf_computation`
     should be the outermost one (the one that appears first in the
     sequence).

  3. Create a polymorphic callable to be instantiated based on arguments,
     similarly to TensorFlow defuns that have been defined without an input
     signature.

     This mode of usage is symmetric to those above. One simply omits the
     type specifier, and applies `tff.tf_computation` as a decorator or
     wrapper to a function/defun that does expect parameters.

     Here's an example of wrapping a lambda as a polymorphic callable:

     ```python
     foo = tff.tf_computation(lambda x, y: x > y)
     ```

     The resulting `foo` can be used in the same ways as if its type had been
     declared; the corresponding computation is simply created on demand, in
     the same way as polymorphic TensorFlow defuns create and cache concrete
     function definitions for each combination of argument types.

     ```python
     ...foo(1, 2)...
     ...foo(0.5, 0.3)...
     ```

     Here's an example of creating a polymorphic callable via decorator:

     ```python
     @tff.tf_computation
     def foo(x, y):
       return x > y
     ```

     The syntax is symmetric to all examples already shown.

  Args:
    *args: Either a function/defun, or a TFF type spec, or both (function
      first), or neither, as documented in the 3 patterns and examples of
      usage above.

  Returns:
    If invoked with a function as an argument, returns a TFF computation
    constructed from that function. If that function expects parameters but
    no type spec was given (pattern 3), the result is instead a polymorphic
    callable that creates concrete computations on demand for each argument
    type. If called without a function, as in the typical decorator style of
    usage, returns a callable that expects to be called with the function
    definition supplied as a parameter; see the patterns and examples of usage
    above.

  Raises:
    TypeError: If the argument is neither a callable nor a type spec (for
      example, a plain integer), or if the declared parameter type cannot be
      used in a TensorFlow computation (for example, federated, functional,
      placement, or abstract types).
  """
  # All argument-pattern dispatch (function and/or type spec, decorator vs.
  # direct call) is delegated to the shared TensorFlow wrapper instance.
  return computation_wrapper_instances.tensorflow_wrapper(*args)
def test_error_on_non_callable_non_type(self):
  """Wrapping a value that is neither callable nor a type spec raises TypeError.

  The raised traceback is checked against a golden file rather than by
  message regex.
  """
  # NOTE(review): the golden file presumably records the traceback text,
  # possibly including the source line of the call below; editing that call
  # likely requires regenerating `non_callable_non_type_traceback.expected`.
  with golden.check_raises_traceback(
      'non_callable_non_type_traceback.expected', TypeError):
    computation_wrapper_instances.tensorflow_wrapper(5)