def test_get_signature_with_class_property(self):
        """Reading a property yields a plain value, which has no signature.

        `C().x` evaluates to the integer 99, so `get_signature` on it is
        expected to raise `TypeError`.
        """
        class C:
            @property
            def x(self):
                return 99

        property_value = C().x
        with self.assertRaises(TypeError):
            function_utils.get_signature(property_value)
    def test_get_callargs_for_signature(self, fn, args, kwargs):
        """`function_utils.get_signature` binds arguments the way `inspect` does.

        The reference binding is computed from `inspect.signature(fn)` directly.
        The signature under test comes from `function_utils` and is kept in its
        own variable; reusing one name for both would make the test compare
        `inspect` with itself.
        """
        signature = function_utils.get_signature(fn)
        expected_error = None
        try:
            expected_callargs = inspect.signature(fn).bind(
                *args, **kwargs).arguments
        except TypeError as e:
            expected_error = e
            expected_callargs = None

        result_callargs = None
        if expected_error is None:
            try:
                # Record the result so the failure message below can show it.
                result_callargs = signature.bind(*args, **kwargs).arguments
                self.assertEqual(result_callargs, expected_callargs)
            except (TypeError, AssertionError) as test_err:
                raise AssertionError(
                    'With signature `{!s}`, args {!s}, kwargs {!s}, expected bound '
                    'args {!s} and error {!s}, tested function returned {!s} and the '
                    'test has failed with message: {!s}'.format(
                        signature, args, kwargs, expected_callargs,
                        expected_error, result_callargs, test_err)) from test_err
        else:
            # `inspect` rejected these arguments, so the tested signature must too.
            with self.assertRaises(TypeError):
                signature.bind(*args, **kwargs)
 def test_get_defun_argspec_with_untyped_non_eager_defun(self):
   """An untyped `tf.function` reports the wrapped lambda's parameters."""
   # With no input signature, the same restrictions as in a typed eager
   # function apply.
   fn = tf.function(lambda x, y, *z: None)
   positional = inspect.Parameter.POSITIONAL_OR_KEYWORD
   expected = collections.OrderedDict([
       ('x', inspect.Parameter('x', positional)),
       ('y', inspect.Parameter('y', positional)),
       ('z', inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL)),
   ])
   actual = collections.OrderedDict(
       function_utils.get_signature(fn).parameters)
   self.assertEqual(actual, expected)
    def test_as_wrapper_with_classmethod(self):
        """The signature of a classmethod accessed on the class omits `cls`."""
        class C:
            @classmethod
            def foo(cls, x):
                return x * 2

        expected = collections.OrderedDict([
            ('x', inspect.Parameter('x',
                                    inspect.Parameter.POSITIONAL_OR_KEYWORD)),
        ])
        self.assertEqual(
            function_utils.get_signature(C.foo).parameters, expected)
    def test_get_signature_with_class_instance_method(self):
        """The signature of a bound instance method omits `self`."""
        class C:
            def __init__(self, x):
                self._x = x

            def foo(self, y):
                return self._x * y

        bound_method = C(5).foo
        expected = collections.OrderedDict([
            ('y', inspect.Parameter('y',
                                    inspect.Parameter.POSITIONAL_OR_KEYWORD)),
        ])
        self.assertEqual(
            function_utils.get_signature(bound_method).parameters, expected)
 def test_get_defun_argspec_with_typed_non_eager_defun(self):
   """A typed `tf.function` reports the wrapped lambda's parameters."""
   # With a defined input signature, **kwargs or default values are not
   # allowed, but *args are, and the input signature may overlap with *args
   # (four specs here cover `x`, `y` and two elements of `z`).
   input_signature = (
       tf.TensorSpec(None, tf.int32),
       tf.TensorSpec(None, tf.bool),
       tf.TensorSpec(None, tf.float32),
       tf.TensorSpec(None, tf.float32),
   )
   fn = tf.function(lambda x, y, *z: None, input_signature)
   positional = inspect.Parameter.POSITIONAL_OR_KEYWORD
   expected = collections.OrderedDict([
       ('x', inspect.Parameter('x', positional)),
       ('y', inspect.Parameter('y', positional)),
       ('z', inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL)),
   ])
   actual = collections.OrderedDict(
       function_utils.get_signature(fn).parameters)
   self.assertEqual(actual, expected)
class FunctionUtilsTest(test.TestCase, parameterized.TestCase):
    """Tests for signature inspection, argument packing and wrapping helpers."""

    def test_get_defun_argspec_with_typed_non_eager_defun(self):
        # In a tf.function with a defined input signature, **kwargs or default
        # values are not allowed, but *args are, and the input signature may overlap
        # with *args.
        fn = tf.function(lambda x, y, *z: None, (
            tf.TensorSpec(None, tf.int32),
            tf.TensorSpec(None, tf.bool),
            tf.TensorSpec(None, tf.float32),
            tf.TensorSpec(None, tf.float32),
        ))
        self.assertEqual(
            collections.OrderedDict(
                function_utils.get_signature(fn).parameters),
            collections.OrderedDict(
                x=inspect.Parameter('x',
                                    inspect.Parameter.POSITIONAL_OR_KEYWORD),
                y=inspect.Parameter('y',
                                    inspect.Parameter.POSITIONAL_OR_KEYWORD),
                z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
            ))

    def test_get_defun_argspec_with_untyped_non_eager_defun(self):
        # In a tf.function with no input signature, the same restrictions as in a
        # typed eager function apply.
        fn = tf.function(lambda x, y, *z: None)
        self.assertEqual(
            collections.OrderedDict(
                function_utils.get_signature(fn).parameters),
            collections.OrderedDict(
                x=inspect.Parameter('x',
                                    inspect.Parameter.POSITIONAL_OR_KEYWORD),
                y=inspect.Parameter('y',
                                    inspect.Parameter.POSITIONAL_OR_KEYWORD),
                z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
            ))

    def test_get_signature_with_class_instance_method(self):
        class C:
            def __init__(self, x):
                self._x = x

            def foo(self, y):
                return self._x * y

        c = C(5)
        signature = function_utils.get_signature(c.foo)
        self.assertEqual(
            signature.parameters,
            collections.OrderedDict(y=inspect.Parameter(
                'y', inspect.Parameter.POSITIONAL_OR_KEYWORD)))

    def test_get_signature_with_class_property(self):
        class C:
            @property
            def x(self):
                return 99

        c = C()
        with self.assertRaises(TypeError):
            function_utils.get_signature(c.x)

    def test_as_wrapper_with_classmethod(self):
        class C:
            @classmethod
            def foo(cls, x):
                return x * 2

        signature = function_utils.get_signature(C.foo)
        self.assertEqual(
            signature.parameters,
            collections.OrderedDict(x=inspect.Parameter(
                'x', inspect.Parameter.POSITIONAL_OR_KEYWORD)))

    # pyformat: disable
    @parameterized.parameters(
        itertools.product(
            # Values of 'fn' to test.
            [
                lambda: None, lambda a: None, lambda a, b: None,
                lambda *a: None, lambda **a: None, lambda *a, **b: None,
                lambda a, *b: None, lambda a, **b: None,
                lambda a, b, **c: None, lambda a, b=10: None,
                lambda a, b=10, c=20: None, lambda a, b=10, *c: None,
                lambda a, b=10, **c: None, lambda a, b=10, *c, **d: None,
                lambda a, b, c=10, *d: None, lambda a=10, b=20, c=30, **d: None
            ],
            # Values of 'args' to test.
            [[], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4]],
            # Values of 'kwargs' to test.
            [{}, {
                'b': 100
            }, {
                'name': 'foo'
            }, {
                'b': 100,
                'name': 'foo'
            }]))
    # pyformat: enable
    def test_get_callargs_for_signature(self, fn, args, kwargs):
        """`function_utils.get_signature` binds arguments the way `inspect` does.

        The reference binding is computed from `inspect.signature(fn)` directly.
        The signature under test comes from `function_utils` and is kept in its
        own variable; reusing one name for both would make the test compare
        `inspect` with itself.
        """
        signature = function_utils.get_signature(fn)
        expected_error = None
        try:
            expected_callargs = inspect.signature(fn).bind(
                *args, **kwargs).arguments
        except TypeError as e:
            expected_error = e
            expected_callargs = None

        result_callargs = None
        if expected_error is None:
            try:
                # Record the result so the failure message below can show it.
                result_callargs = signature.bind(*args, **kwargs).arguments
                self.assertEqual(result_callargs, expected_callargs)
            except (TypeError, AssertionError) as test_err:
                raise AssertionError(
                    'With signature `{!s}`, args {!s}, kwargs {!s}, expected bound '
                    'args {!s} and error {!s}, tested function returned {!s} and the '
                    'test has failed with message: {!s}'.format(
                        signature, args, kwargs, expected_callargs,
                        expected_error, result_callargs, test_err)) from test_err
        else:
            # `inspect` rejected these arguments, so the tested signature must too.
            with self.assertRaises(TypeError):
                signature.bind(*args, **kwargs)

    # pyformat: disable
    @parameterized.named_parameters(
        ('args_only', function_utils.get_signature(lambda a: None), [tf.int32],
         collections.OrderedDict()),
        ('args_and_kwargs_unnamed',
         function_utils.get_signature(lambda a, b=True: None),
         [tf.int32, tf.bool], collections.OrderedDict()),
        ('args_and_kwargs_named',
         function_utils.get_signature(lambda a, b=True: None), [tf.int32],
         collections.OrderedDict(b=tf.bool)),
        ('args_and_kwargs_default_int',
         function_utils.get_signature(lambda a=10, b=True: None), [tf.int32],
         collections.OrderedDict(b=tf.bool)),
    )
    # pyformat: enable
    def test_is_signature_compatible_with_types_true(self, signature, args,
                                                     kwargs):
        """Each case's argument types are compatible with `signature`.

        Every case supplies one list of positional types and one dict of
        keyword types, which are expanded at the call below. Declaring the
        test's own parameters as `*args, **kwargs` would instead forward the
        list and the dict as two positional arguments.
        """
        self.assertTrue(
            function_utils.is_signature_compatible_with_types(
                signature, *args, **kwargs))

    # pyformat: disable
    @parameterized.named_parameters(
        ('args_only', function_utils.get_signature(lambda a=True: None),
         [tf.int32], collections.OrderedDict()),
        ('args_and_kwargs',
         function_utils.get_signature(lambda a=10, b=True: None), [tf.bool],
         collections.OrderedDict(b=tf.bool)),
    )
    # pyformat: enable
    def test_is_signature_compatible_with_types_false(self, signature, args,
                                                      kwargs):
        """Each case's argument types conflict with a default in `signature`.

        `args` (a list) and `kwargs` (a dict) are expanded at the call below,
        mirroring `test_is_signature_compatible_with_types_true`.
        """
        self.assertFalse(
            function_utils.is_signature_compatible_with_types(
                signature, *args, **kwargs))

    # pyformat: disable
    @parameterized.parameters(
        (tf.int32, False), ([tf.int32, tf.int32], True),
        ([tf.int32, ('b', tf.int32)], True), ([('a', tf.int32),
                                               ('b', tf.int32)], True),
        ([('a', tf.int32), tf.int32], False),
        (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), True),
        (anonymous_tuple.AnonymousTuple([('a', 1), (None, 2)]), False))
    # pyformat: enable
    def test_is_argument_tuple(self, arg, expected_result):
        self.assertEqual(function_utils.is_argument_tuple(arg),
                         expected_result)

    # pyformat: disable
    @parameterized.parameters(
        (anonymous_tuple.AnonymousTuple([(None, 1)]), [1], {}),
        (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), [1], {
            'a': 2
        }))
    # pyformat: enable
    def test_unpack_args_from_anonymous_tuple(self, tuple_with_args,
                                              expected_args, expected_kwargs):
        self.assertEqual(
            function_utils.unpack_args_from_tuple(tuple_with_args),
            (expected_args, expected_kwargs))

    # pyformat: disable
    @parameterized.parameters(
        ([tf.int32], [tf.int32], {}), ([('a', tf.int32)], [], {
            'a': tf.int32
        }), ([tf.int32, tf.bool], [tf.int32, tf.bool], {}),
        ([tf.int32, ('b', tf.bool)], [tf.int32], {
            'b': tf.bool
        }), ([('a', tf.int32), ('b', tf.bool)], [], {
            'a': tf.int32,
            'b': tf.bool
        }))
    # pyformat: enable
    def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                         expected_kwargs):
        args, kwargs = function_utils.unpack_args_from_tuple(tuple_with_args)
        self.assertEqual(len(args), len(expected_args))
        # Lengths were checked above, so `zip` pairs every element.
        for arg, expected_arg in zip(args, expected_args):
            self.assertTrue(
                arg.is_equivalent_to(computation_types.to_type(expected_arg)))
        self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys()))
        for k, v in kwargs.items():
            self.assertTrue(
                v.is_equivalent_to(
                    computation_types.to_type(expected_kwargs[k])))

    def test_pack_args_into_anonymous_tuple_without_type_spec(self):
        self.assertEqual(
            function_utils.pack_args_into_anonymous_tuple([1], {'a': 10}),
            anonymous_tuple.AnonymousTuple([(None, 1), ('a', 10)]))
        self.assertIn(
            function_utils.pack_args_into_anonymous_tuple([1, 2], {
                'a': 10,
                'b': 20
            }), [
                anonymous_tuple.AnonymousTuple([
                    (None, 1),
                    (None, 2),
                    ('a', 10),
                    ('b', 20),
                ]),
                anonymous_tuple.AnonymousTuple([
                    (None, 1),
                    (None, 2),
                    ('b', 20),
                    ('a', 10),
                ])
            ])
        self.assertIn(
            function_utils.pack_args_into_anonymous_tuple([], {
                'a': 10,
                'b': 20
            }), [
                anonymous_tuple.AnonymousTuple([('a', 10), ('b', 20)]),
                anonymous_tuple.AnonymousTuple([('b', 20), ('a', 10)])
            ])
        self.assertEqual(
            function_utils.pack_args_into_anonymous_tuple([1], {}),
            anonymous_tuple.AnonymousTuple([(None, 1)]))

    # pyformat: disable
    @parameterized.parameters(
        ([1], {}, [tf.int32], [(None, 1)]),
        ([1, True], {}, [tf.int32, tf.bool], [(None, 1), (None, True)]),
        ([1, True], {}, [('x', tf.int32),
                         ('y', tf.bool)], [('x', 1), ('y', True)]), ([1], {
                             'y': True
                         }, [('x', tf.int32), ('y', tf.bool)], [('x', 1),
                                                                ('y', True)]),
        ([], {
            'x': 1,
            'y': True
        }, [('x', tf.int32), ('y', tf.bool)], [('x', 1), ('y', True)]),
        ([], collections.OrderedDict([('y', True), ('x', 1)]), [
            ('x', tf.int32), ('y', tf.bool)
        ], [('x', 1), ('y', True)]))
    # pyformat: enable
    def test_pack_args_into_anonymous_tuple_with_type_spec_expect_success(
            self, args, kwargs, type_spec, elements):
        self.assertEqual(
            function_utils.pack_args_into_anonymous_tuple(
                args, kwargs, type_spec, NoopIngestContextForTest()),
            anonymous_tuple.AnonymousTuple(elements))

    # pyformat: disable
    @parameterized.parameters(([1], {}, [(tf.bool)]), ([], {
        'x': 1,
        'y': True
    }, [(tf.int32), (tf.bool)]))
    # pyformat: enable
    def test_pack_args_into_anonymous_tuple_with_type_spec_expect_failure(
            self, args, kwargs, type_spec):
        with self.assertRaises(TypeError):
            function_utils.pack_args_into_anonymous_tuple(
                args, kwargs, type_spec, NoopIngestContextForTest())

    # pyformat: disable
    @parameterized.parameters(
        (None, [], {}, 'None'), (tf.int32, [1], {}, '1'),
        ([tf.int32, tf.bool], [1, True], {}, '<1,True>'),
        ([('x', tf.int32), ('y', tf.bool)], [1, True], {}, '<x=1,y=True>'),
        ([('x', tf.int32), ('y', tf.bool)], [1], {
            'y': True
        }, '<x=1,y=True>'),
        ([tf.int32, tf.bool],
         [anonymous_tuple.AnonymousTuple([(None, 1),
                                          (None, True)])], {}, '<1,True>'))
    # pyformat: enable
    def test_pack_args(self, parameter_type, args, kwargs,
                       expected_value_string):
        self.assertEqual(
            str(
                function_utils.pack_args(parameter_type, args, kwargs,
                                         NoopIngestContextForTest())),
            expected_value_string)

    # pyformat: disable
    @parameterized.parameters(
        (1, lambda: 10, None, None, None, 10),
        (2, lambda x=1: x + 10, None, None, None, 11),
        (3, lambda x=1: x + 10, tf.int32, None, 20, 30),
        (4, lambda x, y: x + y, [tf.int32, tf.int32], None,
         anonymous_tuple.AnonymousTuple([('x', 5), ('y', 6)]), 11),
        (5, lambda *args: str(args), [tf.int32, tf.int32], True,
         anonymous_tuple.AnonymousTuple([(None, 5), (None, 6)]), '(5, 6)'),
        (6, lambda *args: str(args), [
            ('x', tf.int32), ('y', tf.int32)
        ], False, anonymous_tuple.AnonymousTuple([
            ('x', 5), ('y', 6)
        ]), '(AnonymousTuple([(\'x\', 5), (\'y\', 6)]),)'),
        (
            7,
            lambda x: str(x),  # pylint: disable=unnecessary-lambda
            [tf.int32],
            None,
            anonymous_tuple.AnonymousTuple([(None, 10)]),
            '[10]'))
    # pyformat: enable
    def test_wrap_as_zero_or_one_arg_callable(self, unused_index, fn,
                                              parameter_type, unpack, arg,
                                              expected_result):
        wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
            fn, parameter_type, unpack)
        actual_result = wrapped_fn(arg) if parameter_type else wrapped_fn()
        self.assertEqual(actual_result, expected_result)
def _wrap(fn, parameter_type, wrapper_fn):
    """Wraps a possibly-polymorphic `fn` in `wrapper_fn`.

    If `parameter_type` is `None` and `fn` takes any arguments (even with
    default values), `fn` is inferred to be polymorphic and won't be passed to
    `wrapper_fn` until invocation time (when concrete parameter types are
    available).

    `wrapper_fn` must accept three positional arguments and one defaulted
    argument `name`:

    * `target_fn`, the Python function to be wrapped.

    * `parameter_type`, the optional type of the computation's
      parameter (an instance of `computation_types.Type`).

    * `unpack`, an argument which will be passed on to
      `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping
      `target_fn`. See that function for details.

    * Optional `name`, the name of the function that is being wrapped (only
      for debugging purposes).

    Args:
      fn: The function or defun to wrap as a computation.
      parameter_type: Optional type of any arguments to `fn`.
      wrapper_fn: The Python callable that performs actual wrapping. The object
        to be returned by this function should be an instance of a
        `ConcreteFunction`.

    Returns:
      Either the result of wrapping (an object that represents the
      computation), or a polymorphic callable that performs wrapping upon
      invocation based on argument types. The returned function still may
      accept multiple arguments (it has not yet had
      `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

    Raises:
      TypeError: if the arguments are of the wrong types, if `wrapper_fn`
        constructs something that isn't a ConcreteFunction, or if the concrete
        function declares a parameter type not equivalent to `parameter_type`.
    """
    # Not every callable exposes `__name__`; the name is only informational,
    # so fall back to `None`.
    fn_name = getattr(fn, '__name__', None)
    signature = function_utils.get_signature(fn)
    parameter_type = computation_types.to_type(parameter_type)
    if parameter_type is None and signature.parameters:
        # There is no TFF type specification, and the function/defun declares
        # parameters. Create a polymorphic template.
        def _wrap_polymorphic(parameter_type: computation_types.Type,
                              unpack: Optional[bool]):
            return wrapper_fn(fn, parameter_type, unpack=unpack, name=fn_name)

        polymorphic_fn = function_utils.PolymorphicFunction(_wrap_polymorphic)

        # When applying a decorator, the __doc__ attribute with the documentation
        # in triple-quotes is not automatically transferred from the function on
        # which it was applied to the wrapped object, so we must transfer it here
        # explicitly.
        polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
        return polymorphic_fn

    # Either we have a concrete parameter type, or this is a no-arg function.
    # `name` is forwarded here as well, matching the polymorphic path above.
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None, name=fn_name)
    py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                            'value returned by the wrapper')
    if (concrete_fn.type_signature.parameter is not None
            and not concrete_fn.type_signature.parameter.is_equivalent_to(
                parameter_type)):
        raise TypeError(
            'Expected a concrete function that takes parameter {}, got one '
            'that takes {}.'.format(str(parameter_type),
                                    str(concrete_fn.type_signature.parameter)))
    # When applying a decorator, the __doc__ attribute with the documentation
    # in triple-quotes is not automatically transferred from the function on
    # which it was applied to the wrapped object, so transfer it explicitly.
    concrete_fn.__doc__ = getattr(fn, '__doc__', None)
    return concrete_fn
def serialize_tf2_as_tf_computation(target, parameter_type, unpack=None):
  """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function or `tf.function`, with arguments
      matching the 'parameter_type'.
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    unpack: Whether to always unpack the parameter_type. Necessary for support
      of polymorphic tf2_computations.

  Returns:
    A tuple of (`pb.Computation`, `computation_types.FunctionType`). The
    computation has the `pb.TensorFlow` variant set; the function type is
    built from `parameter_type` and the result type inferred while tracing
    `target`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  py_typecheck.check_callable(target)
  parameter_type = computation_types.to_type(parameter_type)
  signature = function_utils.get_signature(target)
  if signature.parameters and parameter_type is None:
    raise ValueError(
        'Expected the target to declare no parameters, found {!r}.'.format(
            signature.parameters))

  # In the codepath for TF V1 based serialization (tff.tf_computation),
  # we get the "wrapped" function to serialize. Here, target is the
  # raw function to be wrapped; however, we still need to know if
  # the parameter_type should be unpacked into multiple args and kwargs
  # in order to construct the TensorSpecs to be passed in the call
  # to get_concrete_fn below.
  unpack = function_utils.infer_unpack_needed(target, parameter_type, unpack)
  arg_typespecs, kwarg_typespecs, parameter_binding = (
      tensorflow_utils.get_tf_typespec_and_binding(
          parameter_type,
          arg_names=list(signature.parameters.keys()),
          unpack=unpack))

  # Pseudo-global to be appended to once when target_poly below is traced.
  type_and_binding_slot = []

  # N.B. To serialize a tf.function or eager python code,
  # the return type must be a flat list, tuple, or dict. However, the
  # tff.tf_computation must be able to handle structured inputs and outputs.
  # Thus, we intercept the result of calling the original target fn, introspect
  # its structure to create a result_type and bindings, and then return a
  # flat dict output. It is this new "unpacked" tf.function that we will
  # serialize using tf.saved_model.save.
  #
  # TODO(b/117428091): The return type limitation is primarily a limitation of
  # SignatureDefs  and therefore of the signatures argument to
  # tf.saved_model.save. tf.functions attached to objects and loaded back with
  # tf.saved_model.load can take/return nests; this might offer a better
  # approach to the one taken here.

  @tf.function
  def target_poly(*args, **kwargs):
    result = target(*args, **kwargs)
    result_dict, result_type, result_binding = (
        tensorflow_utils.get_tf2_result_dict_and_binding(result))
    assert not type_and_binding_slot
    # A "side channel" python output.
    type_and_binding_slot.append((result_type, result_binding))
    return result_dict

  # Triggers tracing so that type_and_binding_slot is filled.
  cc_fn = target_poly.get_concrete_function(*arg_typespecs, **kwarg_typespecs)
  assert len(type_and_binding_slot) == 1
  result_type, result_binding = type_and_binding_slot[0]

  # N.B. Note that cc_fn does *not* accept the same args and kwargs as the
  # Python target_poly; instead, it must be called with **kwargs based on the
  # unique names embedded in the TensorSpecs inside arg_typespecs and
  # kwarg_typespecs. The (preliminary) parameter_binding tracks the mapping
  # between these tensor names and the components of the (possibly nested) TFF
  # input type. When cc_fn is serialized, concrete tensors for each input are
  # introduced, and the call finalize_binding(parameter_binding,
  # sigs['serving_default'].inputs) updates the bindings to reference these
  # concrete tensors.

  # Associate vars with unique names and explicitly attach to the Checkpoint:
  var_dict = {
      'var{:02d}'.format(i): v for i, v in enumerate(cc_fn.graph.variables)
  }
  saveable = tf.train.Checkpoint(fn=target_poly, **var_dict)

  # TODO(b/122081673): All we really need is the  meta graph def, we could
  # probably just load that directly, e.g., using parse_saved_model from
  # tensorflow/python/saved_model/loader_impl.py, but I'm not sure we want to
  # depend on that presumably non-public symbol. Perhaps TF can expose a way
  # to just get the MetaGraphDef directly without saving to a tempfile? This
  # looks like a small change to v2.saved_model.save().
  #
  # The directory is created before entering `try` so that `finally` only runs
  # once `outdir` exists; otherwise a failing mkdtemp would surface as an
  # unbound-variable error from the cleanup instead of the real cause.
  outdir = tempfile.mkdtemp('savedmodel')
  try:
    tf.saved_model.save(saveable, outdir, signatures=cc_fn)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
      mgd = tf.compat.v1.saved_model.load(
          sess, tags=[tf.saved_model.SERVING], export_dir=outdir)
  finally:
    shutil.rmtree(outdir)
  sigs = mgd.signature_def

  # TODO(b/123102455): Figure out how to support the init_op. The meta graph def
  # contains sigs['__saved_model_init_op'].outputs['__saved_model_init_op']. It
  # probably won't do what we want, because it will want to read from
  # Checkpoints, not just run Variable initializers (?). The right solution may
  # be to grab the target_poly.get_initialization_function(), and save a sig for
  # that.

  # Now, traverse the signature from the MetaGraphDef to find the actual tensor
  # names and write them into the bindings.
  finalize_binding(parameter_binding, sigs['serving_default'].inputs)
  finalize_binding(result_binding, sigs['serving_default'].outputs)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(mgd.graph_def),
          parameter=parameter_binding,
          result=result_binding)), annotated_type
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2
  serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero parameters
      if `parameter_type` is `None`, or exactly one parameter if it's not
      `None`.  The nested structure of this parameter must correspond to the
      structure of the 'parameter_type'. In the future, we may support targets
      with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.

  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  signature = function_utils.get_signature(target)

  with tf.Graph().as_default() as graph:
    args = []
    if parameter_type is not None:
      if len(signature.parameters) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, found {!r}.'
            .format(signature.parameters))
      parameter_name = next(iter(signature.parameters))
      parameter_value, parameter_binding = (
          tensorflow_utils.stamp_parameter_in_graph(parameter_name,
                                                    parameter_type, graph))
      args.append(parameter_value)
    else:
      if signature.parameters:
        raise ValueError(
            'Expected the target to declare no parameters, found {!r}.'.format(
                signature.parameters))
      parameter_binding = None
    context = tf_computation_context.TensorFlowComputationContext(graph)
    with context_stack.install(context):
      result = target(*args)

      # TODO(b/122081673): This needs to change for TF 2.0. We may also
      # want to allow the person creating a tff.tf_computation to specify
      # a different initializer; e.g., if it is known that certain
      # variables will be assigned immediately to arguments of the function,
      # then it is wasteful to initialize them before this.
      #
      # The following is a bit of a work around: the collections below may
      # contain variables more than once, hence we throw into a set. TFF needs
      # to ensure all variables are initialized, but not all variables are
      # always in the collections we expect. tff.learning._KerasModel tries to
      # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
      # VARS_FOR_TFF_TO_INITIALIZE for now.
      all_variables = set(tf.compat.v1.global_variables() +
                          tf.compat.v1.local_variables() +
                          tf.compat.v1.get_collection(
                              graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE))
      if all_variables:
        # Set iteration order depends on object hashes, so sort by name to
        # make the generated op name and the initializer's inputs
        # deterministic across runs.
        ordered_variables = sorted(all_variables, key=lambda v: v.name)
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            v.name.replace(':0', '') for v in ordered_variables)
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(len(ordered_variables))
        with tf.control_dependencies(context.init_ops):
          # Before running the main new init op, run any initializers for sub-
          # computations from context.init_ops. Variables from import_graph_def
          # will not make it into the global collections, and so will not be
          # initialized without this code path.
          init_op_name = tf.group(
              tf.compat.v1.initializers.variables(ordered_variables, name=name),
              *tf.compat.v1.get_collection(
                  tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)).name
      elif context.init_ops:
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  # WARNING: we do not really want to be modifying the graph here if we can
  # avoid it. This is purely to work around performance issues uncovered with
  # the non-standard usage of Tensorflow and have been discussed with the
  # Tensorflow core team before being added.
  clean_graph_def = _clean_graph_def(graph.as_graph_def())

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(clean_graph_def),
          parameter=parameter_binding,
          result=result_binding,
          initialize_op=init_op_name)), annotated_type
def _parameters(fn):
  """Returns a view over the parameters declared by `fn`.

  The view is `.values()` of the `parameters` mapping of the signature
  returned by `function_utils.get_signature(fn)`. That signature is presumably
  an `inspect.Signature`; confirm against `function_utils`.
  """
  fn_signature = function_utils.get_signature(fn)
  declared = fn_signature.parameters
  return declared.values()
Example #12
0
class FunctionUtilsTest(test.TestCase, parameterized.TestCase):
    """Tests for the helpers in `function_utils`."""

    def test_is_defun(self):
        # Both untyped and typed tf.functions are defuns; a plain lambda and
        # None are not.
        self.assertTrue(function_utils.is_defun(tf.function(lambda x: None)))
        fn = tf.function(lambda x: None, (tf.TensorSpec(None, tf.int32), ))
        self.assertTrue(function_utils.is_defun(fn))
        self.assertFalse(function_utils.is_defun(lambda x: None))
        self.assertFalse(function_utils.is_defun(None))

    def _assert_x_y_star_z_parameters(self, fn):
        """Asserts that `function_utils` reports the signature `(x, y, *z)`."""
        self.assertEqual(
            collections.OrderedDict(
                function_utils.get_signature(fn).parameters),
            collections.OrderedDict(
                x=inspect.Parameter('x',
                                    inspect.Parameter.POSITIONAL_OR_KEYWORD),
                y=inspect.Parameter('y',
                                    inspect.Parameter.POSITIONAL_OR_KEYWORD),
                z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
            ))

    def test_get_defun_argspec_with_typed_non_eager_defun(self):
        # In a tf.function with a defined input signature, **kwargs or default
        # values are not allowed, but *args are, and the input signature may
        # overlap with *args.
        fn = tf.function(lambda x, y, *z: None, (
            tf.TensorSpec(None, tf.int32),
            tf.TensorSpec(None, tf.bool),
            tf.TensorSpec(None, tf.float32),
            tf.TensorSpec(None, tf.float32),
        ))
        self._assert_x_y_star_z_parameters(fn)

    def test_get_defun_argspec_with_untyped_non_eager_defun(self):
        # In a tf.function with no input signature, the same restrictions as
        # in a typed eager function apply.
        fn = tf.function(lambda x, y, *z: None)
        self._assert_x_y_star_z_parameters(fn)

    # pyformat: disable
    @parameterized.parameters(
        itertools.product(
            # Values of 'fn' to test.
            [
                lambda: None, lambda a: None, lambda a, b: None,
                lambda *a: None, lambda **a: None, lambda *a, **b: None,
                lambda a, *b: None, lambda a, **b: None,
                lambda a, b, **c: None, lambda a, b=10: None,
                lambda a, b=10, c=20: None, lambda a, b=10, *c: None,
                lambda a, b=10, **c: None, lambda a, b=10, *c, **d: None,
                lambda a, b, c=10, *d: None, lambda a=10, b=20, c=30, **d: None
            ],
            # Values of 'args' to test.
            [[], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4]],
            # Values of 'kwargs' to test.
            [{}, {'b': 100}, {'name': 'foo'}, {'b': 100, 'name': 'foo'}]))
    # pyformat: enable
    def test_get_callargs_for_signature(self, fn, args, kwargs):
        # The signature under test comes from `function_utils`; the reference
        # binding comes from `inspect` directly. They are kept under separate
        # names so the test really exercises `function_utils.get_signature`
        # and does not compare `inspect` with itself.
        signature = function_utils.get_signature(fn)
        expected_error = None
        try:
            expected_signature = inspect.signature(fn)
            expected_callargs = expected_signature.bind(*args,
                                                        **kwargs).arguments
        except TypeError as e:
            expected_error = e
            expected_callargs = None

        result_callargs = None
        if expected_error is None:
            try:
                result_callargs = signature.bind(*args, **kwargs).arguments
                self.assertEqual(result_callargs, expected_callargs)
            except (TypeError, AssertionError) as test_err:
                raise AssertionError(
                    'With signature `{!s}`, args {!s}, kwargs {!s}, expected bound '
                    'args {!s} and error {!s}, tested function returned {!s} and the '
                    'test has failed with message: {!s}'.format(
                        signature, args, kwargs, expected_callargs,
                        expected_error, result_callargs,
                        test_err)) from test_err
        else:
            # `inspect` rejected these arguments, so the signature under test
            # must reject them as well.
            with self.assertRaises(TypeError):
                _ = signature.bind(*args, **kwargs)

    # pyformat: disable
    # pylint: disable=g-complex-comprehension
    # Signatures are computed once, when the decorator is evaluated; each
    # case becomes (signature, positional arg types, keyword arg types).
    @parameterized.parameters(
        (function_utils.get_signature(params[0]), ) + params[1:]
        for params in [
            (lambda a: None, [tf.int32], {}),
            (lambda a, b=True: None, [tf.int32, tf.bool], {}),
            (lambda a, b=True: None, [tf.int32], {'b': tf.bool}),
            (lambda a, b=True: None, [tf.bool], {'b': tf.bool}),
            (lambda a=10, b=True: None, [tf.int32], {'b': tf.bool}),
        ])
    # pylint: enable=g-complex-comprehension
    # pyformat: enable
    def test_is_signature_compatible_with_types_true(self, signature, args,
                                                     kwargs):
        self.assertTrue(
            function_utils.is_signature_compatible_with_types(
                signature, *[computation_types.to_type(a) for a in args],
                **{k: computation_types.to_type(v)
                   for k, v in kwargs.items()}))

    # pyformat: disable
    # pylint: disable=g-complex-comprehension
    @parameterized.parameters(
        (function_utils.get_signature(params[0]), ) + params[1:]
        for params in [
            (lambda a=True: None, [tf.int32], {}),
            (lambda a=10, b=True: None, [tf.bool], {'b': tf.bool}),
        ])
    # pylint: enable=g-complex-comprehension
    # pyformat: enable
    def test_is_signature_compatible_with_types_false(self, signature, args,
                                                      kwargs):
        self.assertFalse(
            function_utils.is_signature_compatible_with_types(
                signature, *[computation_types.to_type(a) for a in args],
                **{k: computation_types.to_type(v)
                   for k, v in kwargs.items()}))

    # pyformat: disable
    @parameterized.parameters(
        (tf.int32, False),
        ([tf.int32, tf.int32], True),
        ([tf.int32, ('b', tf.int32)], True),
        ([('a', tf.int32), ('b', tf.int32)], True),
        ([('a', tf.int32), tf.int32], False),
        (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), True),
        (anonymous_tuple.AnonymousTuple([('a', 1), (None, 2)]), False))
    # pyformat: enable
    def test_is_argument_tuple(self, arg, expected_result):
        self.assertEqual(function_utils.is_argument_tuple(arg),
                         expected_result)

    # pyformat: disable
    @parameterized.parameters(
        (anonymous_tuple.AnonymousTuple([(None, 1)]), [1], {}),
        (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), [1],
         {'a': 2}))
    # pyformat: enable
    def test_unpack_args_from_anonymous_tuple(self, tuple_with_args,
                                              expected_args, expected_kwargs):
        self.assertEqual(
            function_utils.unpack_args_from_tuple(tuple_with_args),
            (expected_args, expected_kwargs))

    # pyformat: disable
    @parameterized.parameters(
        ([tf.int32], [tf.int32], {}),
        ([('a', tf.int32)], [], {'a': tf.int32}),
        ([tf.int32, tf.bool], [tf.int32, tf.bool], {}),
        ([tf.int32, ('b', tf.bool)], [tf.int32], {'b': tf.bool}),
        ([('a', tf.int32), ('b', tf.bool)], [], {'a': tf.int32,
                                                 'b': tf.bool}))
    # pyformat: enable
    def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                         expected_kwargs):
        args, kwargs = function_utils.unpack_args_from_tuple(tuple_with_args)
        self.assertEqual(len(args), len(expected_args))
        for arg, expected_arg in zip(args, expected_args):
            self.assertTrue(
                type_utils.are_equivalent_types(
                    arg, computation_types.to_type(expected_arg)))
        self.assertEqual(set(kwargs), set(expected_kwargs))
        for k, v in kwargs.items():
            # Normalize both sides to TFF types, so that the expected spec is
            # converted just as in the positional check above.
            self.assertTrue(
                type_utils.are_equivalent_types(
                    computation_types.to_type(v),
                    computation_types.to_type(expected_kwargs[k])))

    def test_pack_args_into_anonymous_tuple_without_type_spec(self):
        self.assertEqual(
            function_utils.pack_args_into_anonymous_tuple([1], {'a': 10}),
            anonymous_tuple.AnonymousTuple([(None, 1), ('a', 10)]))
        # Without a type spec, the relative order of keyword arguments is not
        # pinned down, so either order is accepted.
        self.assertIn(
            function_utils.pack_args_into_anonymous_tuple([1, 2], {
                'a': 10,
                'b': 20
            }), [
                anonymous_tuple.AnonymousTuple([
                    (None, 1),
                    (None, 2),
                    ('a', 10),
                    ('b', 20),
                ]),
                anonymous_tuple.AnonymousTuple([
                    (None, 1),
                    (None, 2),
                    ('b', 20),
                    ('a', 10),
                ])
            ])
        self.assertIn(
            function_utils.pack_args_into_anonymous_tuple([], {
                'a': 10,
                'b': 20
            }), [
                anonymous_tuple.AnonymousTuple([('a', 10), ('b', 20)]),
                anonymous_tuple.AnonymousTuple([('b', 20), ('a', 10)])
            ])
        self.assertEqual(
            function_utils.pack_args_into_anonymous_tuple([1], {}),
            anonymous_tuple.AnonymousTuple([(None, 1)]))

    # pyformat: disable
    @parameterized.parameters(
        ([1], {}, [tf.int32], [(None, 1)]),
        ([1, True], {}, [tf.int32, tf.bool], [(None, 1), (None, True)]),
        ([1, True], {}, [('x', tf.int32), ('y', tf.bool)],
         [('x', 1), ('y', True)]),
        ([1], {'y': True}, [('x', tf.int32), ('y', tf.bool)],
         [('x', 1), ('y', True)]),
        ([], {'x': 1, 'y': True}, [('x', tf.int32), ('y', tf.bool)],
         [('x', 1), ('y', True)]),
        ([], collections.OrderedDict([('y', True), ('x', 1)]),
         [('x', tf.int32), ('y', tf.bool)], [('x', 1), ('y', True)]))
    # pyformat: enable
    def test_pack_args_into_anonymous_tuple_with_type_spec_expect_success(
            self, args, kwargs, type_spec, elements):
        self.assertEqual(
            function_utils.pack_args_into_anonymous_tuple(
                args, kwargs, type_spec, NoopIngestContextForTest()),
            anonymous_tuple.AnonymousTuple(elements))

    # pyformat: disable
    @parameterized.parameters(
        ([1], {}, [tf.bool]),
        ([], {'x': 1, 'y': True}, [tf.int32, tf.bool]))
    # pyformat: enable
    def test_pack_args_into_anonymous_tuple_with_type_spec_expect_failure(
            self, args, kwargs, type_spec):
        with self.assertRaises(TypeError):
            function_utils.pack_args_into_anonymous_tuple(
                args, kwargs, type_spec, NoopIngestContextForTest())

    # pyformat: disable
    @parameterized.parameters(
        (None, [], {}, 'None'),
        (tf.int32, [1], {}, '1'),
        ([tf.int32, tf.bool], [1, True], {}, '<1,True>'),
        ([('x', tf.int32), ('y', tf.bool)], [1, True], {}, '<x=1,y=True>'),
        ([('x', tf.int32), ('y', tf.bool)], [1], {'y': True},
         '<x=1,y=True>'),
        ([tf.int32, tf.bool],
         [anonymous_tuple.AnonymousTuple([(None, 1), (None, True)])], {},
         '<1,True>'))
    # pyformat: enable
    def test_pack_args(self, parameter_type, args, kwargs,
                       expected_value_string):
        self.assertEqual(
            str(
                function_utils.pack_args(parameter_type, args, kwargs,
                                         NoopIngestContextForTest())),
            expected_value_string)

    # pyformat: disable
    @parameterized.parameters(
        (1, lambda: 10, None, None, None, 10),
        (2, lambda x=1: x + 10, None, None, None, 11),
        (3, lambda x=1: x + 10, tf.int32, None, 20, 30),
        (4, lambda x, y: x + y, [tf.int32, tf.int32], None,
         anonymous_tuple.AnonymousTuple([('x', 5), ('y', 6)]), 11),
        (5, lambda *args: str(args), [tf.int32, tf.int32], True,
         anonymous_tuple.AnonymousTuple([(None, 5), (None, 6)]), '(5, 6)'),
        (6, lambda *args: str(args), [
            ('x', tf.int32), ('y', tf.int32)
        ], False, anonymous_tuple.AnonymousTuple([
            ('x', 5), ('y', 6)
        ]), '(AnonymousTuple([(\'x\', 5), (\'y\', 6)]),)'),
        (
            7,
            lambda x: str(x),  # pylint: disable=unnecessary-lambda
            [tf.int32],
            None,
            anonymous_tuple.AnonymousTuple([(None, 10)]),
            '[10]'))
    # pyformat: enable
    def test_wrap_as_zero_or_one_arg_callable(self, unused_index, fn,
                                              parameter_type, unpack, arg,
                                              expected_result):
        wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
            fn, parameter_type, unpack)
        # A wrapped function takes a single argument iff a parameter type was
        # given, and no arguments otherwise.
        actual_result = wrapped_fn(arg) if parameter_type else wrapped_fn()
        self.assertEqual(actual_result, expected_result)

    def test_polymorphic_function(self):
        class ContextForTest(context_base.Context):
            def ingest(self, val, type_spec):
                return val

            def invoke(self, comp, arg):
                return 'name={},type={},arg={}'.format(
                    comp.name, str(comp.type_signature.parameter), str(arg))

        class TestFunction(function_utils.ConcreteFunction):
            def __init__(self, name, parameter_type):
                self._name = name
                super().__init__(
                    computation_types.FunctionType(parameter_type, tf.string),
                    context_stack_impl.context_stack)

            @property
            def name(self):
                return self._name

        class TestFunctionFactory(object):
            """Creates TestFunctions named '1', '2', ... in creation order."""

            def __init__(self):
                self._count = 0

            def __call__(self, parameter_type):
                self._count += 1
                return TestFunction(str(self._count), parameter_type)

        with context_stack_impl.context_stack.install(ContextForTest()):
            fn = function_utils.PolymorphicFunction(TestFunctionFactory())
            self.assertEqual(fn(10), 'name=1,type=<int32>,arg=<10>')
            self.assertEqual(fn(20, x=True),
                             'name=2,type=<int32,x=bool>,arg=<20,x=True>')
            self.assertEqual(fn(True), 'name=3,type=<bool>,arg=<True>')
            self.assertEqual(fn(30, x=40),
                             'name=4,type=<int32,x=int32>,arg=<30,x=40>')
            # Repeated argument types must reuse the function created for the
            # first call with that type (same name) instead of a new one.
            self.assertEqual(fn(50), 'name=1,type=<int32>,arg=<50>')
            self.assertEqual(fn(0, x=False),
                             'name=2,type=<int32,x=bool>,arg=<0,x=False>')
            self.assertEqual(fn(False), 'name=3,type=<bool>,arg=<False>')
            self.assertEqual(fn(60, x=70),
                             'name=4,type=<int32,x=int32>,arg=<60,x=70>')

    def test_concrete_function(self):
        class ContextForTest(context_base.Context):
            def ingest(self, val, type_spec):
                return val

            def invoke(self, comp, arg):
                return comp.invoke_fn(arg)

        class TestFunction(function_utils.ConcreteFunction):
            def __init__(self, type_signature, invoke_fn):
                super().__init__(type_signature,
                                 context_stack_impl.context_stack)
                self._invoke_fn = invoke_fn

            def invoke_fn(self, arg):
                return self._invoke_fn(arg)

        with context_stack_impl.context_stack.install(ContextForTest()):
            fn = TestFunction(
                computation_types.FunctionType(tf.int32, tf.bool),
                lambda x: x > 10)
            self.assertEqual(fn(5), False)
            self.assertEqual(fn(15), True)

            # With a named-tuple parameter type, positional and keyword call
            # arguments are packed into a single argument exposing .x and .y.
            fn = TestFunction(
                computation_types.FunctionType([('x', tf.int32),
                                                ('y', tf.int32)], tf.bool),
                lambda arg: arg.x > arg.y)
            self.assertEqual(fn(5, 10), False)
            self.assertEqual(fn(10, 5), True)
            self.assertEqual(fn(y=10, x=5), False)
            self.assertEqual(fn(y=5, x=10), True)
            self.assertEqual(fn(10, y=5), True)
Example #13
0
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
    """Traces a Python function into a fresh TF graph and serializes it.

    Args:
      target: The entity to serialize as a TensorFlow computation. Only plain
        Python functions (`types.FunctionType`) are accepted at present. When
        `parameter_type` is `None` the function must declare no parameters;
        otherwise it must declare exactly one parameter, whose nested
        structure must correspond to `parameter_type`. Other kinds of targets
        (e.g. non-eager and eager TF functions, or functions with multiple or
        keyword parameters) may be supported in the future.
      parameter_type: An instance of `computation_types.Type` describing the
        single parameter of `target`, or `None` if `target` declares no
        parameters.
      context_stack: The `context_stack_base.ContextStack` on which a
        `TensorFlowComputationContext` for the new graph is installed while
        `target` is invoked.

    Returns:
      A tuple `(computation, type_signature)`. `computation` is a
      `pb.Computation` with the `pb.TensorFlow` variant set (cleaned graph
      def, parameter/result bindings and initializer op name).
      `type_signature` is the `computation_types.FunctionType` from
      `parameter_type` to the captured result type, potentially including
      Python container annotations, for use by TensorFlow computation
      wrappers.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the number of parameters declared by `target` does not
        match whether `parameter_type` is given.
    """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.
    py_typecheck.check_type(target, types.FunctionType)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if parameter_type is not None:
        py_typecheck.check_type(parameter_type, computation_types.Type)
    declared = function_utils.get_signature(target).parameters

    with tf.Graph().as_default() as graph:
        if parameter_type is None:
            if declared:
                raise ValueError(
                    'Expected the target to declare no parameters, found {!r}.'
                    .format(declared))
            parameter_value = parameter_binding = None
        else:
            if len(declared) != 1:
                raise ValueError(
                    'Expected the target to declare exactly one parameter, found {!r}.'
                    .format(declared))
            [param_name] = declared
            parameter_value, parameter_binding = (
                tensorflow_utils.stamp_parameter_in_graph(
                    param_name, parameter_type, graph))

        tf_context = (
            tensorflow_computation_context.TensorFlowComputationContext(graph))
        with context_stack.install(tf_context):
            with variable_utils.record_variable_creation_scope(
            ) as created_variables:
                result = (target() if parameter_value is None else
                          target(parameter_value))

            ops_to_group = []
            if created_variables:
                # Name the init op after its variables for readability, but
                # fall back to a count once that exceeds 50 characters.
                init_name = 'init_op_for_' + '_'.join(
                    var.name.replace(':0', '') for var in created_variables)
                if len(init_name) > 50:
                    init_name = 'init_op_for_{}_variables'.format(
                        len(created_variables))
                ops_to_group.append(
                    tf.compat.v1.initializers.variables(created_variables,
                                                        name=init_name))
            ops_to_group += tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)

            if ops_to_group:
                # The grouped initializer depends on the sub-computation
                # initializers in tf_context.init_ops: variables brought in by
                # import_graph_def are absent from the global collections and
                # would otherwise never be initialized.
                with tf.compat.v1.control_dependencies(tf_context.init_ops):
                    init_op_name = tf.group(*ops_to_group,
                                            name='grouped_initializers').name
            elif tf_context.init_ops:
                init_op_name = tf.group(*tf_context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = (
            tensorflow_utils.capture_result_from_graph(result, graph))

    type_signature = computation_types.FunctionType(parameter_type,
                                                    result_type)

    # WARNING: we do not really want to be modifying the graph here if we can
    # avoid it. This is purely to work around performance issues uncovered with
    # the non-standard usage of Tensorflow and have been discussed with the
    # Tensorflow core team before being added.
    clean_graph_def = _clean_graph_def(graph.as_graph_def())
    tf_proto = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(clean_graph_def),
        parameter=parameter_binding,
        result=result_binding,
        initialize_op=init_op_name)
    serialized_type = type_serialization.serialize_type(type_signature)
    return pb.Computation(type=serialized_type,
                          tensorflow=tf_proto), type_signature