Example #1
 def test_get_defun_argspec_with_untyped_non_eager_defun(self):
   # In a non-eager defun with no input signature, the same restrictions as in
   # a typed defun apply.
   self.assertEqual(
       function_utils.get_argspec(function.Defun()(lambda x, y, *z: None)),
       inspect.ArgSpec(
           args=['x', 'y'], varargs='z', keywords=None, defaults=None))
Example #2
 def test_get_defun_argspec_with_typed_non_eager_defun(self):
   # In a non-eager defun with a defined input signature, **kwargs or default
   # values are not allowed, but *args are, and the input signature may
   # overlap with *args.
   self.assertEqual(
       function_utils.get_argspec(
           function.Defun(tf.int32, tf.bool, tf.float32,
                          tf.float32)(lambda x, y, *z: None)),
       inspect.ArgSpec(
           args=['x', 'y'], varargs='z', keywords=None, defaults=None))
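
For contrast, a minimal sketch (not part of the test suite) of the same
introspection on a plain, undecorated callable. It assumes
`function_utils.get_argspec` matches standard `inspect` behavior for
ordinary Python functions, which is what the assertions above encode for
the `Defun`-wrapped cases:

import inspect

# Hypothetical baseline: the same lambda, introspected without any Defun
# wrapper. inspect.getargspec is the pre-Python-3.11 API whose
# inspect.ArgSpec results the tests above assert against.
argspec = inspect.getargspec(lambda x, y, *z: None)
assert argspec == inspect.ArgSpec(
    args=['x', 'y'], varargs='z', keywords=None, defaults=None)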
Example #3
def serialize_tf2_as_tf_computation(target, parameter_type, unpack=None):
    """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function or `tf.function`, with arguments
      matching the 'parameter_type'.
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    unpack: Whether to always unpack the parameter_type. Necessary for support
      of polymorphic tf2_computations.

  Returns:
    The constructed `pb.Computation` instance with the `pb.TensorFlow` variant
      set.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
    py_typecheck.check_callable(target)
    parameter_type = computation_types.to_type(parameter_type)
    argspec = function_utils.get_argspec(target)
    if argspec.args and parameter_type is None:
        raise ValueError(
            'Expected the target to declare no parameters, found {}.'.format(
                repr(argspec.args)))

    # In the codepath for TF V1 based serialization (tff.tf_computation),
    # we get the "wrapped" function to serialize. Here, target is the
    # raw function to be wrapped; however, we still need to know if
    # the parameter_type should be unpacked into multiple args and kwargs
    # in order to construct the TensorSpecs to be passed in the call
    # to get_concrete_fn below.
    unpack = function_utils.infer_unpack_needed(target, parameter_type, unpack)
    arg_typespecs, kwarg_typespecs, parameter_binding = (
        graph_utils.get_tf_typespec_and_binding(parameter_type,
                                                arg_names=argspec.args,
                                                unpack=unpack))

    # Pseudo-global to be appended to once when target_poly below is traced.
    type_and_binding_slot = []

    # N.B. To serialize a tf.function or eager python code,
    # the return type must be a flat list, tuple, or dict. However, the
    # tff.tf_computation must be able to handle structured inputs and outputs.
    # Thus, we intercept the result of calling the original target fn, introspect
    # its structure to create a result_type and bindings, and then return a
    # flat dict output. It is this new "unpacked" tf.function that we will
    # serialize using tf.saved_model.save.
    #
    # TODO(b/117428091): The return type limitation is primarily a limitation of
    # SignatureDefs and therefore of the signatures argument to
    # tf.saved_model.save. tf.functions attached to objects and loaded back with
    # tf.saved_model.load can take/return nests; this might offer a better
    # approach to the one taken here.

    @tf.function(autograph=False)
    def target_poly(*args, **kwargs):
        result = target(*args, **kwargs)
        result_dict, result_type, result_binding = (
            graph_utils.get_tf2_result_dict_and_binding(result))
        assert not type_and_binding_slot
        # A "side channel" python output.
        type_and_binding_slot.append((result_type, result_binding))
        return result_dict

    # Triggers tracing so that type_and_binding_slot is filled.
    cc_fn = target_poly.get_concrete_function(*arg_typespecs,
                                              **kwarg_typespecs)
    assert len(type_and_binding_slot) == 1
    result_type, result_binding = type_and_binding_slot[0]

    # N.B. cc_fn does *not* accept the same args and kwargs as the
    # Python target_poly; instead, it must be called with **kwargs based on the
    # unique names embedded in the TensorSpecs inside arg_typespecs and
    # kwarg_typespecs. The (preliminary) parameter_binding tracks the mapping
    # between these tensor names and the components of the (possibly nested) TFF
    # input type. When cc_fn is serialized, concrete tensors for each input are
    # introduced, and the call finalize_binding(parameter_binding,
    # sigs['serving_default'].inputs) updates the bindings to reference these
    # concrete tensors.
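    # For example (illustrative name): if arg_typespecs contains a
    # TensorSpec named 'x_0', cc_fn must be invoked as cc_fn(x_0=...), and
    # parameter_binding maps 'x_0' to the corresponding leaf of the
    # (possibly nested) TFF parameter type.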

    # Associate vars with unique names and explicitly attach to the Checkpoint:
    var_dict = {
        'var{:02d}'.format(i): v
        for i, v in enumerate(cc_fn.graph.variables)
    }
    saveable = tf.train.Checkpoint(fn=target_poly, **var_dict)

    # TODO(b/122081673): All we really need is the meta graph def, we could
    # probably just load that directly, e.g., using parse_saved_model from
    # tensorflow/python/saved_model/loader_impl.py, but I'm not sure we want to
    # depend on that presumably non-public symbol. Perhaps TF can expose a way
    # to just get the MetaGraphDef directly without saving to a tempfile? This
    # looks like a small change to v2.saved_model.save().
    # Create the temp dir before entering the try block so that `outdir` is
    # always bound when the finally clause runs.
    outdir = tempfile.mkdtemp('savedmodel')
    try:
        tf.saved_model.save(saveable, outdir, signatures=cc_fn)

        graph = tf.Graph()
        with tf.compat.v1.Session(graph=graph) as sess:
            mgd = tf.saved_model.loader.load(
                sess,
                tags=[tf.saved_model.tag_constants.SERVING],
                export_dir=outdir)
    finally:
        shutil.rmtree(outdir)
    sigs = mgd.signature_def

    # TODO(b/123102455): Figure out how to support the init_op. The meta graph def
    # contains sigs['__saved_model_init_op'].outputs['__saved_model_init_op']. It
    # probably won't do what we want, because it will want to read from
    # Checkpoints, not just run Variable initializers (?). The right solution may
    # be to grab the target_poly.get_initialization_function(), and save a sig for
    # that.

    # Now, traverse the signature from the MetaGraphDef to find the actual
    # tensor names and write them into the bindings.
    finalize_binding(parameter_binding, sigs['serving_default'].inputs)
    finalize_binding(result_binding, sigs['serving_default'].outputs)

    annotated_type = computation_types.FunctionType(parameter_type,
                                                    result_type)

    return pb.Computation(type=pb.Type(function=pb.FunctionType(
        parameter=type_serialization.serialize_type(parameter_type),
        result=type_serialization.serialize_type(result_type))),
                          tensorflow=pb.TensorFlow(
                              graph_def=serialization_utils.pack_graph_def(
                                  mgd.graph_def),
                              parameter=parameter_binding,
                              result=result_binding)), annotated_type
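
A minimal usage sketch, assuming the module's own imports (`tf`, `pb`,
`computation_types`, and friends) are in scope; `add_one` is an
illustrative target, not something from the codebase:

import tensorflow as tf

# Hypothetical target: a one-parameter tf.function whose argument matches
# the parameter_type passed to the serializer below.
@tf.function
def add_one(x):
    return x + 1.0

# Per the return statement above, the serializer yields a tuple of the
# pb.Computation proto and the annotated TFF function type.
comp_proto, annotated_type = serialize_tf2_as_tf_computation(
    add_one, tf.float32)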
Example #4
def _wrap(fn, parameter_type, wrapper_fn):
  """Wrap a given `fn` with a given `parameter_type` using `wrapper_fn`.

  This method does not handle the multiple modes of usage as wrapper/decorator,
  as those are handled by ComputationWrapper below. It focuses on the simple
  case with a function/defun (always present) and either a valid parameter type
  or an indication that there's no parameter (None).

  The only ambiguity left to resolve is whether `fn` should be immediately
  wrapped, or treated as a polymorphic callable to be wrapped upon invocation
  based on actual parameter types. The determination is based on the presence
  or absence of parameters in the declaration of `fn`. In order to be
  treated as a concrete no-argument computation, `fn` shouldn't declare any
  arguments (even with default values).

  The `wrapper_fn` must accept three arguments, and an optional fourth kwarg
  `name`:

  * `target_fn`, the Python function to be wrapped, possibly accepting
    `*args` and `**kwargs`.

  * Either None for a no-parameter computation, or the type of the computation's
    parameter (an instance of `computation_types.Type`) if the computation has
    one.

  * `unpack`, an argument which will be passed on to
    `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping `target_fn`.
    See that function for details.

  * Optional `name`, the name of the function that is being wrapped (only for
    debugging purposes).

  Args:
    fn: The function or defun to wrap as a computation.
    parameter_type: The parameter type accepted by the computation, or None if
      there is no parameter.
    wrapper_fn: The Python callable that performs actual wrapping. The object to
      be returned by this function should be an instance of a
      `ConcreteFunction`.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types. The returned function may still accept multiple
    arguments (it has not yet had
    `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

  Raises:
    TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
      constructs something that isn't a ConcreteFunction.
  """
  try:
    fn_name = fn.__name__
  except AttributeError:
    fn_name = None
  argspec = function_utils.get_argspec(fn)
  parameter_type = computation_types.to_type(parameter_type)
  if not parameter_type:
    if (argspec.args or argspec.varargs or argspec.keywords):
      # There is no TFF type specification, and the function/defun declares
      # parameters. Create a polymorphic template.
      def _wrap_polymorphic(wrapper_fn, fn, parameter_type, name=fn_name):
        return wrapper_fn(fn, parameter_type, unpack=True, name=name)

      polymorphic_fn = function_utils.PolymorphicFunction(
          lambda pt: _wrap_polymorphic(wrapper_fn, fn, pt))

      # When applying a decorator, the __doc__ attribute with the documentation
      # in triple-quotes is not automatically transferred from the function on
      # which it was applied to the wrapped object, so we must transfer it here
      # explicitly.
      polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
      return polymorphic_fn

  # Either we have a concrete parameter type, or this is a no-arg function.
  concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
  py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                          'value returned by the wrapper')
  if not type_utils.are_equivalent_types(concrete_fn.type_signature.parameter,
                                         parameter_type):
    raise TypeError(
        'Expected a concrete function that takes parameter {}, got one '
        'that takes {}.'.format(
            str(parameter_type), str(concrete_fn.type_signature.parameter)))
  # When applying a decorator, the __doc__ attribute with the documentation
  # in triple-quotes is not automatically transferred from the function on
  # which it was applied to the wrapped object, so we must transfer it here
  # explicitly.
  concrete_fn.__doc__ = getattr(fn, '__doc__', None)
  return concrete_fn
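
To isolate the dispatch that the docstring describes, a self-contained
sketch with no TFF dependencies; `_sketch_wrap` and its internals are
illustrative names, not part of the module:

import inspect

def _sketch_wrap(fn, parameter_type, wrapper_fn):
  # Simplified restatement of the dispatch above: a callable that declares
  # parameters but arrives without a TFF parameter type is wrapped lazily
  # (polymorphically); anything else is wrapped immediately.
  spec = inspect.getfullargspec(fn)
  declares_params = bool(spec.args or spec.varargs or spec.varkw)
  if parameter_type is None and declares_params:
    # Stand-in for function_utils.PolymorphicFunction: defer wrapping until
    # an actual parameter type is known at invocation time.
    return lambda actual_type: wrapper_fn(fn, actual_type, unpack=True)
  # Concrete case: a parameter type was given, or fn takes no arguments.
  return wrapper_fn(fn, parameter_type, unpack=None)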