def test_get_signature_with_class_property(self):
  """A property access yields a plain value, which has no signature."""

  class C:

    @property
    def x(self):
      return 99

  instance = C()
  # `instance.x` evaluates the property and produces the int 99 rather than a
  # callable, so `get_signature` is expected to reject it.
  with self.assertRaises(TypeError):
    function_utils.get_signature(instance.x)
def test_get_callargs_for_signature(self, fn, args, kwargs):
  """Checks that `function_utils.get_signature(fn)` binds like `inspect`.

  The reference binding (or reference `TypeError`) is computed from
  `inspect.signature(fn)`. The signature under test, obtained from
  `function_utils.get_signature(fn)`, must bind `args`/`kwargs` to the same
  arguments, or raise `TypeError` whenever the reference does.

  Args:
    fn: The Python callable whose signature is exercised.
    args: Positional arguments to bind.
    kwargs: Keyword arguments to bind.
  """
  signature = function_utils.get_signature(fn)
  # Keep the reference signature under its own name. Reusing `signature`
  # here would shadow the one under test and make the comparison vacuous.
  expected_signature = inspect.signature(fn)
  expected_error = None
  try:
    expected_callargs = expected_signature.bind(*args, **kwargs).arguments
  except TypeError as e:
    expected_error = e
    expected_callargs = None

  result_callargs = None
  if expected_error is None:
    try:
      # Record the tested binding so a failure message can report it.
      result_callargs = signature.bind(*args, **kwargs).arguments
      self.assertEqual(result_callargs, expected_callargs)
    except (TypeError, AssertionError) as test_err:
      raise AssertionError(
          'With signature `{!s}`, args {!s}, kwargs {!s}, expected bound '
          'args {!s} and error {!s}, tested function returned {!s} and the '
          'test has failed with message: {!s}'.format(
              signature, args, kwargs, expected_callargs, expected_error,
              result_callargs, test_err)) from test_err
  else:
    with self.assertRaises(TypeError):
      _ = signature.bind(*args, **kwargs)
def test_get_defun_argspec_with_untyped_non_eager_defun(self):
  # A tf.function built without an input signature is expected to report the
  # same parameters as the wrapped lambda, including the *z catch-all.
  fn = tf.function(lambda x, y, *z: None)
  positional = inspect.Parameter.POSITIONAL_OR_KEYWORD
  expected = collections.OrderedDict([
      ('x', inspect.Parameter('x', positional)),
      ('y', inspect.Parameter('y', positional)),
      ('z', inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL)),
  ])
  actual = collections.OrderedDict(function_utils.get_signature(fn).parameters)
  self.assertEqual(actual, expected)
def test_as_wrapper_with_classmethod(self):
  """`cls` is bound on a classmethod, so only `x` should remain."""

  class C:

    @classmethod
    def foo(cls, x):
      return x * 2

  params = function_utils.get_signature(C.foo).parameters
  only_x = inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD)
  self.assertEqual(params, collections.OrderedDict(x=only_x))
def test_get_signature_with_class_instance_method(self):
  """A bound instance method should expose only its non-`self` parameter."""

  class C:

    def __init__(self, x):
      self._x = x

    def foo(self, y):
      return self._x * y

  bound_method = C(5).foo
  params = function_utils.get_signature(bound_method).parameters
  only_y = inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD)
  self.assertEqual(params, collections.OrderedDict(y=only_y))
def test_get_defun_argspec_with_typed_non_eager_defun(self):
  # With an explicit input signature, tf.function does not allow **kwargs or
  # default values but does allow *args. The input signature may extend into
  # the *args portion: four specs for two named parameters plus *z.
  input_signature = (
      tf.TensorSpec(None, tf.int32),
      tf.TensorSpec(None, tf.bool),
      tf.TensorSpec(None, tf.float32),
      tf.TensorSpec(None, tf.float32),
  )
  fn = tf.function(lambda x, y, *z: None, input_signature)
  positional = inspect.Parameter.POSITIONAL_OR_KEYWORD
  expected = collections.OrderedDict([
      ('x', inspect.Parameter('x', positional)),
      ('y', inspect.Parameter('y', positional)),
      ('z', inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL)),
  ])
  actual = collections.OrderedDict(function_utils.get_signature(fn).parameters)
  self.assertEqual(actual, expected)
class FunctionUtilsTest(test.TestCase, parameterized.TestCase):
  """Tests for the signature, packing and wrapping helpers in `function_utils`."""

  def test_get_defun_argspec_with_typed_non_eager_defun(self):
    # In a tf.function with a defined input signature, **kwargs or default
    # values are not allowed, but *args are, and the input signature may overlap
    # with *args.
    fn = tf.function(lambda x, y, *z: None, (
        tf.TensorSpec(None, tf.int32),
        tf.TensorSpec(None, tf.bool),
        tf.TensorSpec(None, tf.float32),
        tf.TensorSpec(None, tf.float32),
    ))
    self.assertEqual(
        collections.OrderedDict(function_utils.get_signature(fn).parameters),
        collections.OrderedDict(
            x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
        ))

  def test_get_defun_argspec_with_untyped_non_eager_defun(self):
    # In a tf.function with no input signature, the same restrictions as in a
    # typed eager function apply.
    fn = tf.function(lambda x, y, *z: None)
    self.assertEqual(
        collections.OrderedDict(function_utils.get_signature(fn).parameters),
        collections.OrderedDict(
            x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
        ))

  def test_get_signature_with_class_instance_method(self):

    class C:

      def __init__(self, x):
        self._x = x

      def foo(self, y):
        return self._x * y

    c = C(5)
    signature = function_utils.get_signature(c.foo)
    self.assertEqual(
        signature.parameters,
        collections.OrderedDict(y=inspect.Parameter(
            'y', inspect.Parameter.POSITIONAL_OR_KEYWORD)))

  def test_get_signature_with_class_property(self):

    class C:

      @property
      def x(self):
        return 99

    c = C()
    # `c.x` is the int 99, not a callable.
    with self.assertRaises(TypeError):
      function_utils.get_signature(c.x)

  def test_as_wrapper_with_classmethod(self):

    class C:

      @classmethod
      def foo(cls, x):
        return x * 2

    signature = function_utils.get_signature(C.foo)
    self.assertEqual(
        signature.parameters,
        collections.OrderedDict(x=inspect.Parameter(
            'x', inspect.Parameter.POSITIONAL_OR_KEYWORD)))

  # pyformat: disable
  @parameterized.parameters(
      itertools.product(
          # Values of 'fn' to test.
          [
              lambda: None,
              lambda a: None,
              lambda a, b: None,
              lambda *a: None,
              lambda **a: None,
              lambda *a, **b: None,
              lambda a, *b: None,
              lambda a, **b: None,
              lambda a, b, **c: None,
              lambda a, b=10: None,
              lambda a, b=10, c=20: None,
              lambda a, b=10, *c: None,
              lambda a, b=10, **c: None,
              lambda a, b=10, *c, **d: None,
              lambda a, b, c=10, *d: None,
              lambda a=10, b=20, c=30, **d: None
          ],
          # Values of 'args' to test.
          [[], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4]],
          # Values of 'kwargs' to test.
          [{}, {'b': 100}, {'name': 'foo'}, {'b': 100, 'name': 'foo'}]))
  # pyformat: enable
  def test_get_callargs_for_signature(self, fn, args, kwargs):
    signature = function_utils.get_signature(fn)
    # The reference signature gets its own name. Reusing `signature` would
    # shadow the one under test and make the comparison vacuous.
    expected_signature = inspect.signature(fn)
    expected_error = None
    try:
      expected_callargs = expected_signature.bind(*args, **kwargs).arguments
    except TypeError as e:
      expected_error = e
      expected_callargs = None

    result_callargs = None
    if expected_error is None:
      try:
        result_callargs = signature.bind(*args, **kwargs).arguments
        self.assertEqual(result_callargs, expected_callargs)
      except (TypeError, AssertionError) as test_err:
        raise AssertionError(
            'With signature `{!s}`, args {!s}, kwargs {!s}, expected bound '
            'args {!s} and error {!s}, tested function returned {!s} and the '
            'test has failed with message: {!s}'.format(
                signature, args, kwargs, expected_callargs, expected_error,
                result_callargs, test_err)) from test_err
    else:
      with self.assertRaises(TypeError):
        _ = signature.bind(*args, **kwargs)

  # pyformat: disable
  @parameterized.named_parameters(
      ('args_only',
       function_utils.get_signature(lambda a: None),
       [tf.int32],
       collections.OrderedDict()),
      ('args_and_kwargs_unnamed',
       function_utils.get_signature(lambda a, b=True: None),
       [tf.int32, tf.bool],
       collections.OrderedDict()),
      ('args_and_kwargs_named',
       function_utils.get_signature(lambda a, b=True: None),
       [tf.int32],
       collections.OrderedDict(b=tf.bool)),
      ('args_and_kwargs_default_int',
       function_utils.get_signature(lambda a=10, b=True: None),
       [tf.int32],
       collections.OrderedDict(b=tf.bool)),
  )
  # pyformat: enable
  def test_is_signature_compatible_with_types_true(self, signature, args,
                                                   kwargs):
    # Each case supplies a list of positional types and a dict of keyword
    # types. They must be unpacked, not forwarded as two positional values.
    self.assertTrue(
        function_utils.is_signature_compatible_with_types(
            signature, *args, **kwargs))

  # pyformat: disable
  @parameterized.named_parameters(
      ('args_only',
       function_utils.get_signature(lambda a=True: None),
       [tf.int32],
       collections.OrderedDict()),
      ('args_and_kwargs',
       function_utils.get_signature(lambda a=10, b=True: None),
       [tf.bool],
       collections.OrderedDict(b=tf.bool)),
  )
  # pyformat: enable
  def test_is_signature_compatible_with_types_false(self, signature, args,
                                                    kwargs):
    self.assertFalse(
        function_utils.is_signature_compatible_with_types(
            signature, *args, **kwargs))

  # pyformat: disable
  @parameterized.parameters(
      (tf.int32, False),
      ([tf.int32, tf.int32], True),
      ([tf.int32, ('b', tf.int32)], True),
      ([('a', tf.int32), ('b', tf.int32)], True),
      ([('a', tf.int32), tf.int32], False),
      (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), True),
      (anonymous_tuple.AnonymousTuple([('a', 1), (None, 2)]), False))
  # pyformat: enable
  def test_is_argument_tuple(self, arg, expected_result):
    self.assertEqual(function_utils.is_argument_tuple(arg), expected_result)

  # pyformat: disable
  @parameterized.parameters(
      (anonymous_tuple.AnonymousTuple([(None, 1)]), [1], {}),
      (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), [1], {'a': 2}))
  # pyformat: enable
  def test_unpack_args_from_anonymous_tuple(self, tuple_with_args,
                                            expected_args, expected_kwargs):
    self.assertEqual(
        function_utils.unpack_args_from_tuple(tuple_with_args),
        (expected_args, expected_kwargs))

  # pyformat: disable
  @parameterized.parameters(
      ([tf.int32], [tf.int32], {}),
      ([('a', tf.int32)], [], {'a': tf.int32}),
      ([tf.int32, tf.bool], [tf.int32, tf.bool], {}),
      ([tf.int32, ('b', tf.bool)], [tf.int32], {'b': tf.bool}),
      ([('a', tf.int32), ('b', tf.bool)], [], {'a': tf.int32, 'b': tf.bool}))
  # pyformat: enable
  def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                       expected_kwargs):
    args, kwargs = function_utils.unpack_args_from_tuple(tuple_with_args)
    self.assertEqual(len(args), len(expected_args))
    # Lengths were just asserted equal, so zip pairs every element.
    for arg, expected_arg in zip(args, expected_args):
      self.assertTrue(
          arg.is_equivalent_to(computation_types.to_type(expected_arg)))
    self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys()))
    for k, v in kwargs.items():
      self.assertTrue(
          v.is_equivalent_to(computation_types.to_type(expected_kwargs[k])))

  def test_pack_args_into_anonymous_tuple_without_type_spec(self):
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple([1], {'a': 10}),
        anonymous_tuple.AnonymousTuple([(None, 1), ('a', 10)]))
    # Keyword order is not asserted, so either ordering is accepted.
    self.assertIn(
        function_utils.pack_args_into_anonymous_tuple([1, 2], {
            'a': 10,
            'b': 20
        }), [
            anonymous_tuple.AnonymousTuple([
                (None, 1),
                (None, 2),
                ('a', 10),
                ('b', 20),
            ]),
            anonymous_tuple.AnonymousTuple([
                (None, 1),
                (None, 2),
                ('b', 20),
                ('a', 10),
            ])
        ])
    self.assertIn(
        function_utils.pack_args_into_anonymous_tuple([], {
            'a': 10,
            'b': 20
        }), [
            anonymous_tuple.AnonymousTuple([('a', 10), ('b', 20)]),
            anonymous_tuple.AnonymousTuple([('b', 20), ('a', 10)])
        ])
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple([1], {}),
        anonymous_tuple.AnonymousTuple([(None, 1)]))

  # pyformat: disable
  @parameterized.parameters(
      ([1], {}, [tf.int32], [(None, 1)]),
      ([1, True], {}, [tf.int32, tf.bool], [(None, 1), (None, True)]),
      ([1, True], {}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([1], {'y': True}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([], {'x': 1, 'y': True}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([], collections.OrderedDict([('y', True), ('x', 1)]),
       [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]))
  # pyformat: enable
  def test_pack_args_into_anonymous_tuple_with_type_spec_expect_success(
      self, args, kwargs, type_spec, elements):
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple(
            args, kwargs, type_spec, NoopIngestContextForTest()),
        anonymous_tuple.AnonymousTuple(elements))

  # pyformat: disable
  @parameterized.parameters(
      ([1], {}, [(tf.bool)]),
      ([], {'x': 1, 'y': True}, [(tf.int32), (tf.bool)]))
  # pyformat: enable
  def test_pack_args_into_anonymous_tuple_with_type_spec_expect_failure(
      self, args, kwargs, type_spec):
    with self.assertRaises(TypeError):
      function_utils.pack_args_into_anonymous_tuple(
          args, kwargs, type_spec, NoopIngestContextForTest())

  # pyformat: disable
  @parameterized.parameters(
      (None, [], {}, 'None'),
      (tf.int32, [1], {}, '1'),
      ([tf.int32, tf.bool], [1, True], {}, '<1,True>'),
      ([('x', tf.int32), ('y', tf.bool)], [1, True], {}, '<x=1,y=True>'),
      ([('x', tf.int32), ('y', tf.bool)], [1], {'y': True}, '<x=1,y=True>'),
      ([tf.int32, tf.bool],
       [anonymous_tuple.AnonymousTuple([(None, 1), (None, True)])],
       {},
       '<1,True>'))
  # pyformat: enable
  def test_pack_args(self, parameter_type, args, kwargs, expected_value_string):
    self.assertEqual(
        str(
            function_utils.pack_args(parameter_type, args, kwargs,
                                     NoopIngestContextForTest())),
        expected_value_string)

  # pyformat: disable
  @parameterized.parameters(
      (1, lambda: 10, None, None, None, 10),
      (2, lambda x=1: x + 10, None, None, None, 11),
      (3, lambda x=1: x + 10, tf.int32, None, 20, 30),
      (4, lambda x, y: x + y, [tf.int32, tf.int32], None,
       anonymous_tuple.AnonymousTuple([('x', 5), ('y', 6)]), 11),
      (5, lambda *args: str(args), [tf.int32, tf.int32], True,
       anonymous_tuple.AnonymousTuple([(None, 5), (None, 6)]), '(5, 6)'),
      (6, lambda *args: str(args), [('x', tf.int32), ('y', tf.int32)], False,
       anonymous_tuple.AnonymousTuple([('x', 5), ('y', 6)]),
       '(AnonymousTuple([(\'x\', 5), (\'y\', 6)]),)'),
      (
          7,
          lambda x: str(x),  # pylint: disable=unnecessary-lambda
          [tf.int32],
          None,
          anonymous_tuple.AnonymousTuple([(None, 10)]),
          '[10]'))
  # pyformat: enable
  def test_wrap_as_zero_or_one_arg_callable(self, unused_index, fn,
                                            parameter_type, unpack, arg,
                                            expected_result):
    wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        fn, parameter_type, unpack)
    # No-parameter wrappers are invoked without an argument.
    actual_result = wrapped_fn(arg) if parameter_type else wrapped_fn()
    self.assertEqual(actual_result, expected_result)
def _wrap(fn, parameter_type, wrapper_fn):
  """Wraps a possibly-polymorphic `fn` in `wrapper_fn`.

  If `parameter_type` is `None` and `fn` takes any arguments (even with default
  values), `fn` is inferred to be polymorphic. It is not passed to `wrapper_fn`
  until invocation time, when concrete parameter types are available.

  `wrapper_fn` must accept three positional arguments and one defaulted argument
  `name`:

  * `target_fn`, the Python function to be wrapped.

  * `parameter_type`, the optional type of the computation's parameter (an
    instance of `computation_types.Type`).

  * `unpack`, an argument which will be passed on to
    `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping `target_fn`.
    See that function for details.

  * Optional `name`, the name of the function that is being wrapped (only for
    debugging purposes).

  Args:
    fn: The function or defun to wrap as a computation.
    parameter_type: Optional type of any arguments to `fn`.
    wrapper_fn: The Python callable that performs actual wrapping. The object to
      be returned by this function should be an instance of a
      `ConcreteFunction`.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types. The returned function still may accept multiple
    arguments (it has not yet had
    `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

  Raises:
    TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
      constructs something that isn't a ConcreteFunction.
  """
  # Some callables (e.g. certain wrapped objects) have no `__name__`.
  fn_name = getattr(fn, '__name__', None)
  signature = function_utils.get_signature(fn)
  parameter_type = computation_types.to_type(parameter_type)
  if parameter_type is None and signature.parameters:
    # There is no TFF type specification, and the function/defun declares
    # parameters. Create a polymorphic template.
    def _wrap_polymorphic(parameter_type: computation_types.Type,
                          unpack: Optional[bool]):
      return wrapper_fn(fn, parameter_type, unpack=unpack, name=fn_name)

    polymorphic_fn = function_utils.PolymorphicFunction(_wrap_polymorphic)

    # When applying a decorator, the __doc__ attribute with the documentation
    # in triple-quotes is not automatically transferred from the function on
    # which it was applied to the wrapped object, so we must transfer it here
    # explicitly.
    polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
    return polymorphic_fn

  # Either we have a concrete parameter type, or this is a no-arg function.
  # Pass the name here too, matching the polymorphic path above, so concrete
  # computations carry the same debugging name.
  concrete_fn = wrapper_fn(fn, parameter_type, unpack=None, name=fn_name)
  py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                          'value returned by the wrapper')
  if (concrete_fn.type_signature.parameter is not None and
      not concrete_fn.type_signature.parameter.is_equivalent_to(
          parameter_type)):
    raise TypeError(
        'Expected a concrete function that takes parameter {}, got one '
        'that takes {}.'.format(
            str(parameter_type), str(concrete_fn.type_signature.parameter)))
  # When applying a decorator, the __doc__ attribute with the documentation
  # in triple-quotes is not automatically transferred from the function on
  # which it was applied to the wrapped object, so we must transfer it here
  # explicitly.
  concrete_fn.__doc__ = getattr(fn, '__doc__', None)
  return concrete_fn
def serialize_tf2_as_tf_computation(target, parameter_type, unpack=None):
  """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function or `tf.function`, with arguments
      matching the 'parameter_type'.
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `computation_types.Type`, or something that's convertible
      to it by `computation_types.to_type()`.
    unpack: Whether to always unpack the parameter_type. Necessary for support
      of polymorphic tf2_computations.

  Returns:
    A tuple of (`pb.Computation`, `computation_types.FunctionType`). The
    computation has the `pb.TensorFlow` variant set. The type is built from
    `parameter_type` and the result type inferred while tracing `target`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  py_typecheck.check_callable(target)
  parameter_type = computation_types.to_type(parameter_type)
  signature = function_utils.get_signature(target)
  if signature.parameters and parameter_type is None:
    raise ValueError(
        'Expected the target to declare no parameters, found {!r}.'.format(
            signature.parameters))

  # In the codepath for TF V1 based serialization (tff.tf_computation),
  # we get the "wrapped" function to serialize. Here, target is the
  # raw function to be wrapped; however, we still need to know if
  # the parameter_type should be unpacked into multiple args and kwargs
  # in order to construct the TensorSpecs to be passed in the call
  # to get_concrete_fn below.
  unpack = function_utils.infer_unpack_needed(target, parameter_type, unpack)
  arg_typespecs, kwarg_typespecs, parameter_binding = (
      tensorflow_utils.get_tf_typespec_and_binding(
          parameter_type,
          arg_names=list(signature.parameters.keys()),
          unpack=unpack))

  # Pseudo-global to be appended to once when target_poly below is traced.
  type_and_binding_slot = []

  # N.B. To serialize a tf.function or eager python code,
  # the return type must be a flat list, tuple, or dict. However, the
  # tff.tf_computation must be able to handle structured inputs and outputs.
  # Thus, we intercept the result of calling the original target fn, introspect
  # its structure to create a result_type and bindings, and then return a
  # flat dict output. It is this new "unpacked" tf.function that we will
  # serialize using tf.saved_model.save.
  #
  # TODO(b/117428091): The return type limitation is primarily a limitation of
  # SignatureDefs and therefore of the signatures argument to
  # tf.saved_model.save. tf.functions attached to objects and loaded back with
  # tf.saved_model.load can take/return nests; this might offer a better
  # approach to the one taken here.

  @tf.function
  def target_poly(*args, **kwargs):
    result = target(*args, **kwargs)
    result_dict, result_type, result_binding = (
        tensorflow_utils.get_tf2_result_dict_and_binding(result))
    assert not type_and_binding_slot
    # A "side channel" python output.
    type_and_binding_slot.append((result_type, result_binding))
    return result_dict

  # Triggers tracing so that type_and_binding_slot is filled.
  cc_fn = target_poly.get_concrete_function(*arg_typespecs, **kwarg_typespecs)
  assert len(type_and_binding_slot) == 1
  result_type, result_binding = type_and_binding_slot[0]

  # N.B. Note that cc_fn does *not* accept the same args and kwargs as the
  # Python target_poly; instead, it must be called with **kwargs based on the
  # unique names embedded in the TensorSpecs inside arg_typespecs and
  # kwarg_typespecs. The (preliminary) parameter_binding tracks the mapping
  # between these tensor names and the components of the (possibly nested) TFF
  # input type. When cc_fn is serialized, concrete tensors for each input are
  # introduced, and the call finalize_binding(parameter_binding,
  # sigs['serving_default'].inputs) updates the bindings to reference these
  # concrete tensors.

  # Associate vars with unique names and explicitly attach to the Checkpoint:
  var_dict = {
      'var{:02d}'.format(i): v for i, v in enumerate(cc_fn.graph.variables)
  }
  saveable = tf.train.Checkpoint(fn=target_poly, **var_dict)

  # TODO(b/122081673): All we really need is the meta graph def, we could
  # probably just load that directly, e.g., using parse_saved_model from
  # tensorflow/python/saved_model/loader_impl.py, but I'm not sure we want to
  # depend on that presumably non-public symbol. Perhaps TF can expose a way
  # to just get the MetaGraphDef directly without saving to a tempfile? This
  # looks like a small change to v2.saved_model.save().
  #
  # The directory is created before entering `try`. If `mkdtemp` itself fails,
  # the `finally` clause must not reference an unbound `outdir`, which would
  # mask the original error with a NameError.
  outdir = tempfile.mkdtemp('savedmodel')
  try:
    tf.saved_model.save(saveable, outdir, signatures=cc_fn)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
      mgd = tf.compat.v1.saved_model.load(
          sess, tags=[tf.saved_model.SERVING], export_dir=outdir)
  finally:
    shutil.rmtree(outdir)
  sigs = mgd.signature_def

  # TODO(b/123102455): Figure out how to support the init_op. The meta graph def
  # contains sigs['__saved_model_init_op'].outputs['__saved_model_init_op']. It
  # probably won't do what we want, because it will want to read from
  # Checkpoints, not just run Variable initializers (?). The right solution may
  # be to grab the target_poly.get_initialization_function(), and save a sig for
  # that.

  # Now, traverse the signature from the MetaGraphDef to find the actual
  # tensor names and write them into the bindings.
  finalize_binding(parameter_binding, sigs['serving_default'].inputs)
  finalize_binding(result_binding, sigs['serving_default'].outputs)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(mgd.graph_def),
          parameter=parameter_binding,
          result=result_binding)), annotated_type
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2
  serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero
      parameters if `parameter_type` is `None`, or exactly one parameter if
      it's not `None`. The nested structure of this parameter must correspond
      to the structure of the 'parameter_type'. In the future, we may support
      targets with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `computation_types.Type`, or something that's convertible
      to it by `computation_types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.

  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  signature = function_utils.get_signature(target)

  with tf.Graph().as_default() as graph:
    args = []
    if parameter_type is not None:
      if len(signature.parameters) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, found {!r}.'
            .format(signature.parameters))
      parameter_name = next(iter(signature.parameters))
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
      args.append(parameter_value)
    else:
      if signature.parameters:
        raise ValueError(
            'Expected the target to declare no parameters, found {!r}.'.format(
                signature.parameters))
      parameter_binding = None
    context = tf_computation_context.TensorFlowComputationContext(graph)
    with context_stack.install(context):
      result = target(*args)

      # TODO(b/122081673): This needs to change for TF 2.0. We may also
      # want to allow the person creating a tff.tf_computation to specify
      # a different initializer; e.g., if it is known that certain
      # variables will be assigned immediately to arguments of the function,
      # then it is wasteful to initialize them before this.
      #
      # The following is a bit of a work around: the collections below may
      # contain variables more than once, hence we throw into a set. TFF needs
      # to ensure all variables are initialized, but not all variables are
      # always in the collections we expect. tff.learning._KerasModel tries to
      # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
      # VARS_FOR_TFF_TO_INITIALIZE for now.
      all_variables = set(tf.compat.v1.global_variables() +
                          tf.compat.v1.local_variables() +
                          tf.compat.v1.get_collection(
                              graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE))
      if all_variables:
        # Set iteration order depends on object hashes and varies between runs.
        # Sort by name so the init op's name and its grouped initializers are
        # deterministic across serializations.
        sorted_variables = sorted(all_variables, key=lambda v: v.name)
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in sorted_variables])
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(len(sorted_variables))
        with tf.control_dependencies(context.init_ops):
          # Before running the main new init op, run any initializers for sub-
          # computations from context.init_ops. Variables from import_graph_def
          # will not make it into the global collections, and so will not be
          # initialized without this code path.
          init_op_name = tf.group(
              tf.compat.v1.initializers.variables(sorted_variables, name=name),
              *tf.compat.v1.get_collection(
                  tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)).name
      # NOTE(review): TABLE_INITIALIZERS are only grouped in the branch above,
      # i.e. when variables exist. A graph with tables but no variables gets no
      # table initialization from the branches below. Confirm this is intended.
      elif context.init_ops:
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  # WARNING: we do not really want to be modifying the graph here if we can
  # avoid it. This is purely to work around performance issues uncovered with
  # the non-standard usage of Tensorflow and have been discussed with the
  # Tensorflow core team before being added.
  clean_graph_def = _clean_graph_def(graph.as_graph_def())
  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(clean_graph_def),
          parameter=parameter_binding,
          result=result_binding,
          initialize_op=init_op_name)), annotated_type
def _parameters(fn):
  """Returns a view over the parameters of `fn`'s signature, in order."""
  sig = function_utils.get_signature(fn)
  return sig.parameters.values()
class FunctionUtilsTest(test.TestCase, parameterized.TestCase):
  """Tests for the signature, argument-packing and wrapping helpers.

  All helpers under test live in `function_utils`.
  """

  def test_is_defun(self):
    self.assertTrue(function_utils.is_defun(tf.function(lambda x: None)))
    fn = tf.function(lambda x: None, (tf.TensorSpec(None, tf.int32),))
    self.assertTrue(function_utils.is_defun(fn))
    self.assertFalse(function_utils.is_defun(lambda x: None))
    self.assertFalse(function_utils.is_defun(None))

  def test_get_defun_argspec_with_typed_non_eager_defun(self):
    # In a tf.function with a defined input signature, **kwargs or default
    # values are not allowed, but *args are, and the input signature may overlap
    # with *args.
    fn = tf.function(lambda x, y, *z: None, (
        tf.TensorSpec(None, tf.int32),
        tf.TensorSpec(None, tf.bool),
        tf.TensorSpec(None, tf.float32),
        tf.TensorSpec(None, tf.float32),
    ))
    self.assertEqual(
        collections.OrderedDict(function_utils.get_signature(fn).parameters),
        collections.OrderedDict(
            x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
        ))

  def test_get_defun_argspec_with_untyped_non_eager_defun(self):
    # In a tf.function with no input signature, the same restrictions as in a
    # typed eager function apply.
    fn = tf.function(lambda x, y, *z: None)
    self.assertEqual(
        collections.OrderedDict(function_utils.get_signature(fn).parameters),
        collections.OrderedDict(
            x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
        ))

  # pyformat: disable
  @parameterized.parameters(
      itertools.product(
          # Values of 'fn' to test.
          [lambda: None,
           lambda a: None,
           lambda a, b: None,
           lambda *a: None,
           lambda **a: None,
           lambda *a, **b: None,
           lambda a, *b: None,
           lambda a, **b: None,
           lambda a, b, **c: None,
           lambda a, b=10: None,
           lambda a, b=10, c=20: None,
           lambda a, b=10, *c: None,
           lambda a, b=10, **c: None,
           lambda a, b=10, *c, **d: None,
           lambda a, b, c=10, *d: None,
           lambda a=10, b=20, c=30, **d: None],
          # Values of 'args' to test.
          [[], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4]],
          # Values of 'kwargs' to test.
          [{}, {'b': 100}, {'name': 'foo'}, {'b': 100, 'name': 'foo'}]))
  # pyformat: enable
  def test_get_callargs_for_signature(self, fn, args, kwargs):
    # `signature` is the object under test. `inspect.signature` is only the
    # reference oracle and must stay in a separate variable; otherwise the test
    # would compare `inspect.signature` against itself.
    signature = function_utils.get_signature(fn)
    expected_error = None
    try:
      expected_callargs = inspect.signature(fn).bind(*args, **kwargs).arguments
    except TypeError as e:
      expected_error = e
      expected_callargs = None

    result_callargs = None
    if expected_error is None:
      try:
        result_callargs = signature.bind(*args, **kwargs).arguments
        self.assertEqual(result_callargs, expected_callargs)
      except (TypeError, AssertionError) as test_err:
        raise AssertionError(
            'With signature `{!s}`, args {!s}, kwargs {!s}, expected bound '
            'args {!s} and error {!s}, tested function returned {!s} and the '
            'test has failed with message: {!s}'.format(
                signature, args, kwargs, expected_callargs, expected_error,
                result_callargs, test_err)) from test_err
    else:
      # The reference rejected this call, so the tested signature must too.
      with self.assertRaises(TypeError):
        _ = signature.bind(*args, **kwargs)

  # pyformat: disable
  # pylint: disable=g-complex-comprehension
  @parameterized.parameters(
      (function_utils.get_signature(params[0]),) + params[1:]
      for params in [
          (lambda a: None, [tf.int32], {}),
          (lambda a, b=True: None, [tf.int32, tf.bool], {}),
          (lambda a, b=True: None, [tf.int32], {'b': tf.bool}),
          (lambda a, b=True: None, [tf.bool], {'b': tf.bool}),
          (lambda a=10, b=True: None, [tf.int32], {'b': tf.bool}),
      ])
  # pylint: enable=g-complex-comprehension
  # pyformat: enable
  def test_is_signature_compatible_with_types_true(self, signature, args,
                                                   kwargs):
    self.assertTrue(
        function_utils.is_signature_compatible_with_types(
            signature, *[computation_types.to_type(a) for a in args],
            **{k: computation_types.to_type(v) for k, v in kwargs.items()}))

  # pyformat: disable
  # pylint: disable=g-complex-comprehension
  @parameterized.parameters(
      (function_utils.get_signature(params[0]),) + params[1:]
      for params in [
          (lambda a=True: None, [tf.int32], {}),
          (lambda a=10, b=True: None, [tf.bool], {'b': tf.bool}),
      ])
  # pylint: enable=g-complex-comprehension
  # pyformat: enable
  def test_is_signature_compatible_with_types_false(self, signature, args,
                                                    kwargs):
    self.assertFalse(
        function_utils.is_signature_compatible_with_types(
            signature, *[computation_types.to_type(a) for a in args],
            **{k: computation_types.to_type(v) for k, v in kwargs.items()}))

  # pyformat: disable
  @parameterized.parameters(
      (tf.int32, False),
      ([tf.int32, tf.int32], True),
      ([tf.int32, ('b', tf.int32)], True),
      ([('a', tf.int32), ('b', tf.int32)], True),
      ([('a', tf.int32), tf.int32], False),
      (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), True),
      (anonymous_tuple.AnonymousTuple([('a', 1), (None, 2)]), False))
  # pyformat: enable
  def test_is_argument_tuple(self, arg, expected_result):
    self.assertEqual(function_utils.is_argument_tuple(arg), expected_result)

  # pyformat: disable
  @parameterized.parameters(
      (anonymous_tuple.AnonymousTuple([(None, 1)]), [1], {}),
      (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), [1], {'a': 2}))
  # pyformat: enable
  def test_unpack_args_from_anonymous_tuple(self, tuple_with_args,
                                            expected_args, expected_kwargs):
    self.assertEqual(
        function_utils.unpack_args_from_tuple(tuple_with_args),
        (expected_args, expected_kwargs))

  # pyformat: disable
  @parameterized.parameters(
      ([tf.int32], [tf.int32], {}),
      ([('a', tf.int32)], [], {'a': tf.int32}),
      ([tf.int32, tf.bool], [tf.int32, tf.bool], {}),
      ([tf.int32, ('b', tf.bool)], [tf.int32], {'b': tf.bool}),
      ([('a', tf.int32), ('b', tf.bool)], [], {'a': tf.int32, 'b': tf.bool}))
  # pyformat: enable
  def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                       expected_kwargs):
    args, kwargs = function_utils.unpack_args_from_tuple(tuple_with_args)
    self.assertEqual(len(args), len(expected_args))
    # Lengths are asserted equal above, so zip() pairs every element.
    for arg, expected_arg in zip(args, expected_args):
      self.assertTrue(
          type_utils.are_equivalent_types(
              arg, computation_types.to_type(expected_arg)))
    self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys()))
    for k, v in kwargs.items():
      # Normalize the expected raw dtype the same way as the positional case.
      self.assertTrue(
          type_utils.are_equivalent_types(
              computation_types.to_type(v),
              computation_types.to_type(expected_kwargs[k])))

  def test_pack_args_into_anonymous_tuple_without_type_spec(self):
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple([1], {'a': 10}),
        anonymous_tuple.AnonymousTuple([(None, 1), ('a', 10)]))
    # Keyword arguments may be packed in either order.
    self.assertIn(
        function_utils.pack_args_into_anonymous_tuple([1, 2], {
            'a': 10,
            'b': 20
        }), [
            anonymous_tuple.AnonymousTuple([
                (None, 1),
                (None, 2),
                ('a', 10),
                ('b', 20),
            ]),
            anonymous_tuple.AnonymousTuple([
                (None, 1),
                (None, 2),
                ('b', 20),
                ('a', 10),
            ])
        ])
    self.assertIn(
        function_utils.pack_args_into_anonymous_tuple([], {
            'a': 10,
            'b': 20
        }), [
            anonymous_tuple.AnonymousTuple([('a', 10), ('b', 20)]),
            anonymous_tuple.AnonymousTuple([('b', 20), ('a', 10)])
        ])
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple([1], {}),
        anonymous_tuple.AnonymousTuple([(None, 1)]))

  # pyformat: disable
  @parameterized.parameters(
      ([1], {}, [tf.int32], [(None, 1)]),
      ([1, True], {}, [tf.int32, tf.bool], [(None, 1), (None, True)]),
      ([1, True], {}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([1], {'y': True}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([], {'x': 1, 'y': True}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([], collections.OrderedDict([('y', True), ('x', 1)]),
       [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]))
  # pyformat: enable
  def test_pack_args_into_anonymous_tuple_with_type_spec_expect_success(
      self, args, kwargs, type_spec, elements):
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple(
            args, kwargs, type_spec, NoopIngestContextForTest()),
        anonymous_tuple.AnonymousTuple(elements))

  # pyformat: disable
  @parameterized.parameters(
      ([1], {}, [(tf.bool)]),
      ([], {'x': 1, 'y': True}, [(tf.int32), (tf.bool)]))
  # pyformat: enable
  def test_pack_args_into_anonymous_tuple_with_type_spec_expect_failure(
      self, args, kwargs, type_spec):
    with self.assertRaises(TypeError):
      function_utils.pack_args_into_anonymous_tuple(
          args, kwargs, type_spec, NoopIngestContextForTest())

  # pyformat: disable
  @parameterized.parameters(
      (None, [], {}, 'None'),
      (tf.int32, [1], {}, '1'),
      ([tf.int32, tf.bool], [1, True], {}, '<1,True>'),
      ([('x', tf.int32), ('y', tf.bool)], [1, True], {}, '<x=1,y=True>'),
      ([('x', tf.int32), ('y', tf.bool)], [1], {'y': True}, '<x=1,y=True>'),
      ([tf.int32, tf.bool],
       [anonymous_tuple.AnonymousTuple([(None, 1), (None, True)])], {},
       '<1,True>'))
  # pyformat: enable
  def test_pack_args(self, parameter_type, args, kwargs,
                     expected_value_string):
    self.assertEqual(
        str(
            function_utils.pack_args(parameter_type, args, kwargs,
                                     NoopIngestContextForTest())),
        expected_value_string)

  # pyformat: disable
  @parameterized.parameters(
      (1, lambda: 10, None, None, None, 10),
      (2, lambda x=1: x + 10, None, None, None, 11),
      (3, lambda x=1: x + 10, tf.int32, None, 20, 30),
      (4, lambda x, y: x + y, [tf.int32, tf.int32], None,
       anonymous_tuple.AnonymousTuple([('x', 5), ('y', 6)]), 11),
      (5, lambda *args: str(args), [tf.int32, tf.int32], True,
       anonymous_tuple.AnonymousTuple([(None, 5), (None, 6)]), '(5, 6)'),
      (6, lambda *args: str(args), [('x', tf.int32), ('y', tf.int32)], False,
       anonymous_tuple.AnonymousTuple([('x', 5), ('y', 6)]),
       '(AnonymousTuple([(\'x\', 5), (\'y\', 6)]),)'),
      (7, lambda x: str(x),  # pylint: disable=unnecessary-lambda
       [tf.int32], None, anonymous_tuple.AnonymousTuple([(None, 10)]),
       '[10]'))
  # pyformat: enable
  def test_wrap_as_zero_or_one_arg_callable(self, unused_index, fn,
                                            parameter_type, unpack, arg,
                                            expected_result):
    wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        fn, parameter_type, unpack)
    # A wrapped no-parameter function is called with no arguments.
    actual_result = wrapped_fn(arg) if parameter_type else wrapped_fn()
    self.assertEqual(actual_result, expected_result)

  def test_polymorphic_function(self):

    class ContextForTest(context_base.Context):

      def ingest(self, val, type_spec):
        return val

      def invoke(self, comp, arg):
        return 'name={},type={},arg={}'.format(
            comp.name, str(comp.type_signature.parameter), str(arg))

    class TestFunction(function_utils.ConcreteFunction):

      def __init__(self, name, parameter_type):
        self._name = name
        super().__init__(
            computation_types.FunctionType(parameter_type, tf.string),
            context_stack_impl.context_stack)

      @property
      def name(self):
        return self._name

    class TestFunctionFactory:
      """Creates a new `TestFunction` named by a running counter per call."""

      def __init__(self):
        self._count = 0

      def __call__(self, parameter_type):
        self._count = self._count + 1
        return TestFunction(str(self._count), parameter_type)

    with context_stack_impl.context_stack.install(ContextForTest()):
      fn = function_utils.PolymorphicFunction(TestFunctionFactory())
      self.assertEqual(fn(10), 'name=1,type=<int32>,arg=<10>')
      self.assertEqual(
          fn(20, x=True), 'name=2,type=<int32,x=bool>,arg=<20,x=True>')
      self.assertEqual(fn(True), 'name=3,type=<bool>,arg=<True>')
      self.assertEqual(
          fn(30, x=40), 'name=4,type=<int32,x=int32>,arg=<30,x=40>')
      # Repeated argument types must reuse the previously created concrete
      # functions (same names), not ask the factory for new ones.
      self.assertEqual(fn(50), 'name=1,type=<int32>,arg=<50>')
      self.assertEqual(
          fn(0, x=False), 'name=2,type=<int32,x=bool>,arg=<0,x=False>')
      self.assertEqual(fn(False), 'name=3,type=<bool>,arg=<False>')
      self.assertEqual(
          fn(60, x=70), 'name=4,type=<int32,x=int32>,arg=<60,x=70>')

  def test_concrete_function(self):

    class ContextForTest(context_base.Context):

      def ingest(self, val, type_spec):
        return val

      def invoke(self, comp, arg):
        return comp.invoke_fn(arg)

    class TestFunction(function_utils.ConcreteFunction):

      def __init__(self, type_signature, invoke_fn):
        super().__init__(type_signature, context_stack_impl.context_stack)
        self._invoke_fn = invoke_fn

      def invoke_fn(self, arg):
        return self._invoke_fn(arg)

    with context_stack_impl.context_stack.install(ContextForTest()):
      fn = TestFunction(
          computation_types.FunctionType(tf.int32, tf.bool), lambda x: x > 10)
      self.assertEqual(fn(5), False)
      self.assertEqual(fn(15), True)

      fn = TestFunction(
          computation_types.FunctionType([('x', tf.int32), ('y', tf.int32)],
                                         tf.bool), lambda arg: arg.x > arg.y)
      self.assertEqual(fn(5, 10), False)
      self.assertEqual(fn(10, 5), True)
      self.assertEqual(fn(y=10, x=5), False)
      self.assertEqual(fn(y=5, x=10), True)
      self.assertEqual(fn(10, y=5), True)
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  The target is called exactly once while a fresh `tf.Graph` is the default
  graph and a `TensorFlowComputationContext` for that graph is installed on
  `context_stack`. Variables created during that call are recorded. A single
  initialization op is added to the graph when there is anything to
  initialize. It covers those variables, the graph's table initializers, and
  any init ops collected from sub-computations in the context.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero
      parameters if `parameter_type` is `None`, or exactly one parameter if
      it's not `None`. The nested structure of this parameter must correspond
      to the structure of the 'parameter_type'. In the future, we may support
      targets with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. When
      not `None`, it must be an instance of `computation_types.Type`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.

  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if parameter_type is not None:
    py_typecheck.check_type(parameter_type, computation_types.Type)
  signature = function_utils.get_signature(target)

  with tf.Graph().as_default() as graph:
    if parameter_type is not None:
      if len(signature.parameters) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, found {!r}.'
            .format(signature.parameters))
      parameter_name = next(iter(signature.parameters))
      # Materialize the single parameter in `graph` according to
      # `parameter_type`. The returned binding is stored in the
      # `pb.TensorFlow` proto built below.
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
    else:
      if signature.parameters:
        raise ValueError(
            'Expected the target to declare no parameters, found {!r}.'
            .format(signature.parameters))
      parameter_value = None
      parameter_binding = None

    context = tensorflow_computation_context.TensorFlowComputationContext(
        graph)
    with context_stack.install(context):
      # Only variables created while `target` runs are captured in
      # `all_variables`.
      with variable_utils.record_variable_creation_scope() as all_variables:
        if parameter_value is not None:
          result = target(parameter_value)
        else:
          result = target()

      initializer_ops = []
      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in all_variables])
        if len(name) > 50:
          # Fall back to a count-based name once the joined name gets too long.
          name = 'init_op_for_{}_variables'.format(len(all_variables))
        initializer_ops.append(
            tf.compat.v1.initializers.variables(all_variables, name=name))
      # Table initializers are read from the graph collection rather than from
      # the recorded variables.
      initializer_ops.extend(
          tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
      if initializer_ops:
        # Before running the main new init op, run any initializers for sub-
        # computations from context.init_ops. Variables from import_graph_def
        # will not make it into the global collections, and so will not be
        # initialized without this code path.
        with tf.compat.v1.control_dependencies(context.init_ops):
          init_op_name = tf.group(
              *initializer_ops, name='grouped_initializers').name
      elif context.init_ops:
        # Nothing created here needs initializing, but sub-computations do.
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        # No initialization is needed at all; the proto carries no init op.
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(parameter_type, result_type)

  # WARNING: we do not really want to be modifying the graph here if we can
  # avoid it. This is purely to work around performance issues uncovered with
  # the non-standard usage of TensorFlow and have been discussed with the
  # TensorFlow core team before being added.
  clean_graph_def = _clean_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(clean_graph_def),
      parameter=parameter_binding,
      result=result_binding,
      initialize_op=init_op_name)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow), type_signature