def test_get_defun_argspec_with_untyped_non_eager_defun(self):
  """Checks the argspec reported for a non-eager defun with no signature."""
  # In a non-eager defun with no input signature, the same restrictions as in
  # a typed defun apply.
  untyped_defun = function.Defun()(lambda x, y, *z: None)
  actual_argspec = function_utils.get_argspec(untyped_defun)
  expected_argspec = inspect.ArgSpec(
      args=['x', 'y'], varargs='z', keywords=None, defaults=None)
  self.assertEqual(actual_argspec, expected_argspec)
def test_get_defun_argspec_with_typed_non_eager_defun(self):
  """Checks the argspec reported for a non-eager defun with a signature."""
  # In a non-eager defun with a defined input signature, **kwargs or default
  # values are not allowed, but *args are, and the input signature may
  # overlap with *args.
  typed_decorator = function.Defun(tf.int32, tf.bool, tf.float32, tf.float32)
  typed_defun = typed_decorator(lambda x, y, *z: None)
  actual_argspec = function_utils.get_argspec(typed_defun)
  expected_argspec = inspect.ArgSpec(
      args=['x', 'y'], varargs='z', keywords=None, defaults=None)
  self.assertEqual(actual_argspec, expected_argspec)
def serialize_tf2_as_tf_computation(target, parameter_type, unpack=None):
  """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function or `tf.function`, with arguments
      matching the 'parameter_type'.
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    unpack: Whether to always unpack the parameter_type. Necessary for support
      of polymorphic tf2_computations.

  Returns:
    A tuple `(computation, annotated_type)`, where `computation` is the
    constructed `pb.Computation` instance with the `pb.TensorFlow` variant set,
    and `annotated_type` is the `computation_types.FunctionType` built from
    `parameter_type` and the result type discovered while tracing `target`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  py_typecheck.check_callable(target)
  parameter_type = computation_types.to_type(parameter_type)
  argspec = function_utils.get_argspec(target)
  if argspec.args and parameter_type is None:
    raise ValueError(
        'Expected the target to declare no parameters, found {}.'.format(
            repr(argspec.args)))

  # In the codepath for TF V1 based serialization (tff.tf_computation),
  # we get the "wrapped" function to serialize. Here, target is the
  # raw function to be wrapped; however, we still need to know if
  # the parameter_type should be unpacked into multiple args and kwargs
  # in order to construct the TensorSpecs to be passed in the call
  # to get_concrete_fn below.
  unpack = function_utils.infer_unpack_needed(target, parameter_type, unpack)
  arg_typespecs, kwarg_typespecs, parameter_binding = (
      graph_utils.get_tf_typespec_and_binding(
          parameter_type, arg_names=argspec.args, unpack=unpack))

  # Pseudo-global to be appended to once when target_poly below is traced.
  type_and_binding_slot = []

  # N.B. To serialize a tf.function or eager python code,
  # the return type must be a flat list, tuple, or dict. However, the
  # tff.tf_computation must be able to handle structured inputs and outputs.
  # Thus, we intercept the result of calling the original target fn, introspect
  # its structure to create a result_type and bindings, and then return a
  # flat dict output. It is this new "unpacked" tf.function that we will
  # serialize using tf.saved_model.save.
  #
  # TODO(b/117428091): The return type limitation is primarily a limitation of
  # SignatureDefs and therefore of the signatures argument to
  # tf.saved_model.save. tf.functions attached to objects and loaded back with
  # tf.saved_model.load can take/return nests; this might offer a better
  # approach to the one taken here.

  @tf.function(autograph=False)
  def target_poly(*args, **kwargs):
    result = target(*args, **kwargs)
    result_dict, result_type, result_binding = (
        graph_utils.get_tf2_result_dict_and_binding(result))
    assert not type_and_binding_slot
    # A "side channel" python output.
    type_and_binding_slot.append((result_type, result_binding))
    return result_dict

  # Triggers tracing so that type_and_binding_slot is filled.
  cc_fn = target_poly.get_concrete_function(*arg_typespecs, **kwarg_typespecs)
  assert len(type_and_binding_slot) == 1
  result_type, result_binding = type_and_binding_slot[0]

  # N.B. Note that cc_fn does *not* accept the same args and kwargs as the
  # Python target_poly; instead, it must be called with **kwargs based on the
  # unique names embedded in the TensorSpecs inside arg_typespecs and
  # kwarg_typespecs. The (preliminary) parameter_binding tracks the mapping
  # between these tensor names and the components of the (possibly nested) TFF
  # input type. When cc_fn is serialized, concrete tensors for each input are
  # introduced, and the call finalize_binding(parameter_binding,
  # sigs['serving_default'].inputs) updates the bindings to reference these
  # concrete tensors.

  # Associate vars with unique names and explicitly attach to the Checkpoint:
  var_dict = {
      'var{:02d}'.format(i): v for i, v in enumerate(cc_fn.graph.variables)
  }
  saveable = tf.train.Checkpoint(fn=target_poly, **var_dict)

  # TODO(b/122081673): All we really need is the meta graph def, we could
  # probably just load that directly, e.g., using parse_saved_model from
  # tensorflow/python/saved_model/loader_impl.py, but I'm not sure we want to
  # depend on that presumably non-public symbol. Perhaps TF can expose a way
  # to just get the MetaGraphDef directly without saving to a tempfile? This
  # looks like a small change to v2.saved_model.save().
  #
  # The temporary directory is created *before* entering the try block so that
  # the cleanup in `finally` never references an unbound `outdir` (which would
  # raise NameError and mask the original error if mkdtemp itself failed).
  outdir = tempfile.mkdtemp('savedmodel')
  try:
    tf.saved_model.save(saveable, outdir, signatures=cc_fn)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
      mgd = tf.saved_model.loader.load(
          sess, tags=[tf.saved_model.tag_constants.SERVING], export_dir=outdir)
  finally:
    shutil.rmtree(outdir)
  sigs = mgd.signature_def

  # TODO(b/123102455): Figure out how to support the init_op. The meta graph def
  # contains sigs['__saved_model_init_op'].outputs['__saved_model_init_op']. It
  # probably won't do what we want, because it will want to read from
  # Checkpoints, not just run Variable initializers (?). The right solution may
  # be to grab the target_poly.get_initialization_function(), and save a sig for
  # that.

  # Now, traverse the signature from the MetaGraphDef to find the actual
  # tensor names and write them into the bindings.
  finalize_binding(parameter_binding, sigs['serving_default'].inputs)
  finalize_binding(result_binding, sigs['serving_default'].outputs)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  computation = pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(mgd.graph_def),
          parameter=parameter_binding,
          result=result_binding))
  return computation, annotated_type
def _wrap(fn, parameter_type, wrapper_fn):
  """Wrap a given `fn` with a given `parameter_type` using `wrapper_fn`.

  This method does not handle the multiple modes of usage as wrapper/decorator,
  as those are handled by ComputationWrapper below. It focuses on the simple
  case with a function/defun (always present) and either a valid parameter type
  or an indication that there's no parameter (None).

  The only ambiguity left to resolve is whether `fn` should be immediately
  wrapped, or treated as a polymorphic callable to be wrapped upon invocation
  based on actual parameter types. The determination is based on the presence
  or absence of parameters in the declaration of `fn`. In order to be
  treated as a concrete no-argument computation, `fn` shouldn't declare any
  arguments (even with default values).

  The `wrapper_fn` must accept three arguments, and an optional fourth kwarg
  `name`:

  * `target_fn`, the Python function to be wrapped, accepting possibly
    *args and **kwargs.

  * Either None for a no-parameter computation, or the type of the
    computation's parameter (an instance of `computation_types.Type`) if the
    computation has one.

  * `unpack`, an argument which will be passed on to
    `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping
    `target_fn`. See that function for details.

  * Optional `name`, the name of the function that is being wrapped (only for
    debugging purposes).

  Args:
    fn: The function or defun to wrap as a computation.
    parameter_type: The parameter type accepted by the computation, or None
      if there is no parameter.
    wrapper_fn: The Python callable that performs actual wrapping. The object
      to be returned by this function should be an instance of a
      `ConcreteFunction`.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types. The returned function still may accept multiple
    arguments (it has not yet had
    `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

  Raises:
    TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
      constructs something that isn't a ConcreteFunction.
  """
  # Some callables (e.g. functools.partial objects) have no __name__.
  fn_name = getattr(fn, '__name__', None)
  argspec = function_utils.get_argspec(fn)
  parameter_type = computation_types.to_type(parameter_type)
  # Compare against None explicitly: a legitimate but empty type (such as an
  # empty named tuple type) may be falsy, and must not be mistaken for the
  # "no parameter" marker.
  if parameter_type is None:
    if argspec.args or argspec.varargs or argspec.keywords:
      # There is no TFF type specification, and the function/defun declares
      # parameters. Create a polymorphic template.
      def _wrap_polymorphic(wrapper_fn, fn, parameter_type, name=fn_name):
        return wrapper_fn(fn, parameter_type, unpack=True, name=name)

      polymorphic_fn = function_utils.PolymorphicFunction(
          lambda pt: _wrap_polymorphic(wrapper_fn, fn, pt))

      # When applying a decorator, the __doc__ attribute with the documentation
      # in triple-quotes is not automatically transferred from the function on
      # which it was applied to the wrapped object, so we must transfer it here
      # explicitly.
      polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
      return polymorphic_fn

  # Either we have a concrete parameter type, or this is no-arg function.
  concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
  py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                          'value returned by the wrapper')
  if not type_utils.are_equivalent_types(concrete_fn.type_signature.parameter,
                                         parameter_type):
    raise TypeError(
        'Expected a concrete function that takes parameter {}, got one '
        'that takes {}.'.format(
            str(parameter_type), str(concrete_fn.type_signature.parameter)))

  # When applying a decorator, the __doc__ attribute with the documentation
  # in triple-quotes is not automatically transferred from the function on
  # which it was applied to the wrapped object, so we must transfer it here
  # explicitly.
  concrete_fn.__doc__ = getattr(fn, '__doc__', None)
  return concrete_fn