def test_get_callargs_for_argspec(self, fn, args, kwargs):
  """Compares `get_callargs_for_argspec` with `inspect.Signature.bind`.

  The reference bindings come from binding `args`/`kwargs` to the signature of
  `fn` and applying defaults. If that binding raises `TypeError`, the function
  under test must raise `TypeError` too; otherwise its result must equal the
  reference bindings.
  """
  spec = function_utils.get_argspec(fn)
  reference_error = None
  try:
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    reference_callargs = bound.arguments
  except TypeError as bind_error:
    reference_error = bind_error
    reference_callargs = None

  if reference_error is not None:
    # The reference rejected these arguments, so the helper must as well.
    with self.assertRaises(TypeError):
      function_utils.get_callargs_for_argspec(spec, *args, **kwargs)
    return

  actual_callargs = None
  try:
    actual_callargs = function_utils.get_callargs_for_argspec(
        spec, *args, **kwargs)
    self.assertEqual(actual_callargs, reference_callargs)
  except (TypeError, AssertionError) as failure:
    # Re-raise with the full context of this parameterization attached.
    raise AssertionError(
        'With argspec {!s}, args {!s}, kwargs {!s}, expected callargs {!s} '
        'and error {!s}, tested function returned {!s} and the test has '
        'failed with message: {!s}'.format(spec, args, kwargs,
                                           reference_callargs, reference_error,
                                           actual_callargs, failure))
 def test_get_defun_argspec_with_untyped_non_eager_defun(self):
   # A tf.function built without an input signature is expected to report the
   # same argspec as the wrapped lambda: two positional args plus *z.
   expected_argspec = function_utils.SimpleArgSpec(
       args=['x', 'y'], varargs='z', keywords=None, defaults=None)
   defun = tf.function(lambda x, y, *z: None)
   self.assertEqual(function_utils.get_argspec(defun), expected_argspec)
 def test_get_defun_argspec_with_typed_non_eager_defun(self):
   # With an input signature, **kwargs and default values are not allowed,
   # but *args are. Here the four specs cover x, y and two elements of *z.
   input_signature = (
       tf.TensorSpec(None, tf.int32),
       tf.TensorSpec(None, tf.bool),
       tf.TensorSpec(None, tf.float32),
       tf.TensorSpec(None, tf.float32),
   )
   defun = tf.function(lambda x, y, *z: None, input_signature)
   expected_argspec = function_utils.SimpleArgSpec(
       args=['x', 'y'], varargs='z', keywords=None, defaults=None)
   self.assertEqual(function_utils.get_argspec(defun), expected_argspec)
Example #4
0
def serialize_tf2_as_tf_computation(target, parameter_type, unpack=None):
  """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function or `tf.function`, with arguments
      matching the 'parameter_type'.
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    unpack: Whether to always unpack the parameter_type. Necessary for support
      of polymorphic tf2_computations.

  Returns:
    A tuple of (`pb.Computation`, `computation_types.FunctionType`). The
    computation has the `pb.TensorFlow` variant set. The function type is built
    from `parameter_type` and the result type discovered while tracing `target`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  py_typecheck.check_callable(target)
  parameter_type = computation_types.to_type(parameter_type)
  argspec = function_utils.get_argspec(target)
  if argspec.args and parameter_type is None:
    raise ValueError(
        'Expected the target to declare no parameters, found {!r}.'.format(
            argspec.args))

  # In the codepath for TF V1 based serialization (tff.tf_computation),
  # we get the "wrapped" function to serialize. Here, target is the
  # raw function to be wrapped; however, we still need to know if
  # the parameter_type should be unpacked into multiple args and kwargs
  # in order to construct the TensorSpecs to be passed in the call
  # to get_concrete_fn below.
  unpack = function_utils.infer_unpack_needed(target, parameter_type, unpack)
  arg_typespecs, kwarg_typespecs, parameter_binding = (
      tensorflow_utils.get_tf_typespec_and_binding(
          parameter_type, arg_names=argspec.args, unpack=unpack))

  # Pseudo-global to be appended to once when target_poly below is traced.
  type_and_binding_slot = []

  # N.B. To serialize a tf.function or eager python code,
  # the return type must be a flat list, tuple, or dict. However, the
  # tff.tf_computation must be able to handle structured inputs and outputs.
  # Thus, we intercept the result of calling the original target fn, introspect
  # its structure to create a result_type and bindings, and then return a
  # flat dict output. It is this new "unpacked" tf.function that we will
  # serialize using tf.saved_model.save.
  #
  # TODO(b/117428091): The return type limitation is primarily a limitation of
  # SignatureDefs and therefore of the signatures argument to
  # tf.saved_model.save. tf.functions attached to objects and loaded back with
  # tf.saved_model.load can take/return nests; this might offer a better
  # approach to the one taken here.

  @tf.function
  def target_poly(*args, **kwargs):
    result = target(*args, **kwargs)
    result_dict, result_type, result_binding = (
        tensorflow_utils.get_tf2_result_dict_and_binding(result))
    assert not type_and_binding_slot
    # A "side channel" python output.
    type_and_binding_slot.append((result_type, result_binding))
    return result_dict

  # Triggers tracing so that type_and_binding_slot is filled.
  cc_fn = target_poly.get_concrete_function(*arg_typespecs, **kwarg_typespecs)
  assert len(type_and_binding_slot) == 1
  result_type, result_binding = type_and_binding_slot[0]

  # N.B. Note that cc_fn does *not* accept the same args and kwargs as the
  # Python target_poly; instead, it must be called with **kwargs based on the
  # unique names embedded in the TensorSpecs inside arg_typespecs and
  # kwarg_typespecs. The (preliminary) parameter_binding tracks the mapping
  # between these tensor names and the components of the (possibly nested) TFF
  # input type. When cc_fn is serialized, concrete tensors for each input are
  # introduced, and the call finalize_binding(parameter_binding,
  # sigs['serving_default'].inputs) updates the bindings to reference these
  # concrete tensors.

  # Associate vars with unique names and explicitly attach to the Checkpoint:
  var_dict = {
      'var{:02d}'.format(i): v for i, v in enumerate(cc_fn.graph.variables)
  }
  saveable = tf.train.Checkpoint(fn=target_poly, **var_dict)

  # TODO(b/122081673): All we really need is the meta graph def, we could
  # probably just load that directly, e.g., using parse_saved_model from
  # tensorflow/python/saved_model/loader_impl.py, but I'm not sure we want to
  # depend on that presumably non-public symbol. Perhaps TF can expose a way
  # to just get the MetaGraphDef directly without saving to a tempfile? This
  # looks like a small change to v2.saved_model.save().
  #
  # The directory is created *before* entering the try block. If mkdtemp itself
  # failed inside the try, the `finally` clause would reference an unbound
  # `outdir` and mask the original error.
  outdir = tempfile.mkdtemp('savedmodel')
  try:
    tf.saved_model.save(saveable, outdir, signatures=cc_fn)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
      mgd = tf.compat.v1.saved_model.load(
          sess, tags=[tf.saved_model.SERVING], export_dir=outdir)
  finally:
    shutil.rmtree(outdir)
  sigs = mgd.signature_def

  # TODO(b/123102455): Figure out how to support the init_op. The meta graph def
  # contains sigs['__saved_model_init_op'].outputs['__saved_model_init_op']. It
  # probably won't do what we want, because it will want to read from
  # Checkpoints, not just run Variable initializers (?). The right solution may
  # be to grab the target_poly.get_initialization_function(), and save a sig for
  # that.

  # Now, traverse the signature from the MetaGraphDef to find the actual tensor
  # names and write them into the bindings.
  finalize_binding(parameter_binding, sigs['serving_default'].inputs)
  finalize_binding(result_binding, sigs['serving_default'].outputs)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(mgd.graph_def),
          parameter=parameter_binding,
          result=result_binding)), annotated_type
Example #5
0
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2
  serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero parameters
      if `parameter_type` is `None`, or exactly one parameter if it's not
      `None`. The nested structure of this parameter must correspond to the
      structure of the 'parameter_type'. In the future, we may support targets
      with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.

  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  argspec = function_utils.get_argspec(target)

  with tf.Graph().as_default() as graph:
    args = []
    if parameter_type is not None:
      if len(argspec.args) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, found {!r}.'
            .format(argspec.args))
      parameter_name = argspec.args[0]
      parameter_value, parameter_binding = (
          tensorflow_utils.stamp_parameter_in_graph(parameter_name,
                                                    parameter_type, graph))
      args.append(parameter_value)
    else:
      if argspec.args:
        raise ValueError(
            'Expected the target to declare no parameters, found {!r}.'.format(
                argspec.args))
      parameter_binding = None
    context = tf_computation_context.TensorFlowComputationContext(graph)
    with context_stack.install(context):
      result = target(*args)

      # TODO(b/122081673): This needs to change for TF 2.0. We may also
      # want to allow the person creating a tff.tf_computation to specify
      # a different initializer; e.g., if it is known that certain
      # variables will be assigned immediately to arguments of the function,
      # then it is wasteful to initialize them before this.
      #
      # The following is a bit of a work around: the collections below may
      # contain variables more than once, hence we throw into a set. TFF needs
      # to ensure all variables are initialized, but not all variables are
      # always in the collections we expect. tff.learning._KerasModel tries to
      # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
      # VARS_FOR_TFF_TO_INITIALIZE for now.
      all_variables = set(tf.compat.v1.global_variables() +
                          tf.compat.v1.local_variables() +
                          tf.compat.v1.get_collection(
                              graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE))
      if all_variables:
        # Set iteration order follows object hashes and may differ between
        # runs. Sorting by name keeps the init op's name and the order of its
        # inputs, and hence the serialized graph, deterministic.
        sorted_variables = sorted(all_variables, key=lambda v: v.name)
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            v.name.replace(':0', '') for v in sorted_variables)
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(len(sorted_variables))
        with tf.control_dependencies(context.init_ops):
          # Before running the main new init op, run any initializers for sub-
          # computations from context.init_ops. Variables from import_graph_def
          # will not make it into the global collections, and so will not be
          # initialized without this code path.
          init_op_name = tf.compat.v1.initializers.variables(
              sorted_variables, name=name).name
      elif context.init_ops:
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=parameter_binding,
          result=result_binding,
          initialize_op=init_op_name)), annotated_type
def _wrap(fn, parameter_type, wrapper_fn):
    """Wraps `fn` as a computation of `parameter_type` by way of `wrapper_fn`.

    The decorator/wrapper calling conventions are not handled here; that is
    done by `ComputationWrapper`. This helper only receives a function or defun
    `fn` (always present) and either a parameter type or `None`, which means
    "no parameter".

    What remains to decide is whether `fn` is wrapped right away, or turned
    into a polymorphic callable that is wrapped lazily once actual argument
    types are known at invocation time. If `parameter_type` is `None` and `fn`
    declares any parameters (positional, `*args` or `**kwargs`, even ones with
    default values), the polymorphic route is taken. Otherwise `fn` is wrapped
    immediately as a concrete computation.

    `wrapper_fn` is called with three arguments and an optional fourth keyword
    argument `name`:

    * `target_fn`, the Python function to be wrapped, possibly accepting
      *args and **kwargs.

    * Either `None` for a no-parameter computation, or the type of the
      computation's parameter (an instance of `computation_types.Type`).

    * `unpack`, forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`
      when wrapping `target_fn`; see that function for details.

    * Optional `name`, the name of the wrapped function (used for debugging
      only).

    Args:
      fn: The function or defun to wrap as a computation.
      parameter_type: The parameter type accepted by the computation, or `None`
        if there is no parameter.
      wrapper_fn: The Python callable that performs the actual wrapping. It
        must return an instance of `function_utils.ConcreteFunction`.

    Returns:
      Either the wrapped computation, or a `function_utils.PolymorphicFunction`
      that performs the wrapping upon invocation based on argument types. The
      returned callable may still accept multiple arguments
      (`function_utils.wrap_as_zero_or_one_arg_callable` has not yet been
      applied to it).

    Raises:
      TypeError: if the arguments are of the wrong types, if `wrapper_fn`
        constructs something that isn't a `ConcreteFunction`, or if the
        concrete function's parameter type is not equivalent to
        `parameter_type`.
    """
    fn_name = getattr(fn, '__name__', None)
    argspec = function_utils.get_argspec(fn)
    parameter_type = computation_types.to_type(parameter_type)

    if parameter_type is None and (argspec.args or argspec.varargs or
                                   argspec.keywords):
        # No TFF type was given, yet `fn` declares parameters: defer wrapping to
        # invocation time, when the argument types are known.
        def _wrap_for_type(arg_type):
            return wrapper_fn(fn, arg_type, unpack=True, name=fn_name)

        polymorphic_fn = function_utils.PolymorphicFunction(_wrap_for_type)
        # A decorator does not carry the wrapped function's triple-quoted
        # docstring over to the new object, so copy it across by hand.
        polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
        return polymorphic_fn

    # At this point either a concrete parameter type was supplied, or `fn`
    # takes no parameters at all.
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
    py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                            'value returned by the wrapper')
    actual_parameter_type = concrete_fn.type_signature.parameter
    if not type_utils.are_equivalent_types(actual_parameter_type,
                                           parameter_type):
        raise TypeError(
            'Expected a concrete function that takes parameter {}, got one '
            'that takes {}.'.format(str(parameter_type),
                                    str(actual_parameter_type)))
    # As above, carry the original docstring over to the wrapped object.
    concrete_fn.__doc__ = getattr(fn, '__doc__', None)
    return concrete_fn
class FuncUtilsTest(test.TestCase, parameterized.TestCase):
  """Tests for the helpers in `function_utils`."""

  def test_is_defun(self):
    self.assertTrue(function_utils.is_defun(tf.function(lambda x: None)))
    fn = tf.function(lambda x: None, (tf.TensorSpec(None, tf.int32),))
    self.assertTrue(function_utils.is_defun(fn))
    self.assertFalse(function_utils.is_defun(lambda x: None))
    self.assertFalse(function_utils.is_defun(None))

  def test_get_defun_argspec_with_typed_non_eager_defun(self):
    # In a non-eager function with a defined input signature, **kwargs or
    # default values are not allowed, but *args are, and the input signature may
    # overlap with *args.
    fn = tf.function(lambda x, y, *z: None, (
        tf.TensorSpec(None, tf.int32),
        tf.TensorSpec(None, tf.bool),
        tf.TensorSpec(None, tf.float32),
        tf.TensorSpec(None, tf.float32),
    ))
    self.assertEqual(
        function_utils.get_argspec(fn),
        function_utils.SimpleArgSpec(
            args=['x', 'y'], varargs='z', keywords=None, defaults=None))

  def test_get_defun_argspec_with_untyped_non_eager_defun(self):
    # In a non-eager function with no input signature, the same restrictions as
    # in a typed eager function apply.
    fn = tf.function(lambda x, y, *z: None)
    self.assertEqual(
        function_utils.get_argspec(fn),
        function_utils.SimpleArgSpec(
            args=['x', 'y'], varargs='z', keywords=None, defaults=None))

  # pyformat: disable
  @parameterized.parameters(
      itertools.product(
          # Values of 'fn' to test.
          [lambda: None,
           lambda a: None,
           lambda a, b: None,
           lambda *a: None,
           lambda **a: None,
           lambda *a, **b: None,
           lambda a, *b: None,
           lambda a, **b: None,
           lambda a, b, **c: None,
           lambda a, b=10: None,
           lambda a, b=10, c=20: None,
           lambda a, b=10, *c: None,
           lambda a, b=10, **c: None,
           lambda a, b=10, *c, **d: None,
           lambda a, b, c=10, *d: None,
           lambda a=10, b=20, c=30, **d: None],
          # Values of 'args' to test.
          [[], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4]],
          # Values of 'kwargs' to test.
          [{}, {'b': 100}, {'name': 'foo'}, {'b': 100, 'name': 'foo'}]))
  # pyformat: enable
  def test_get_callargs_for_argspec(self, fn, args, kwargs):
    """Checks `get_callargs_for_argspec` against `inspect.Signature.bind`."""
    argspec = function_utils.get_argspec(fn)
    expected_error = None
    try:
      signature = inspect.signature(fn)
      bound_arguments = signature.bind(*args, **kwargs)
      bound_arguments.apply_defaults()
      expected_callargs = bound_arguments.arguments
    except TypeError as e:
      expected_error = e
      expected_callargs = None

    result_callargs = None
    if expected_error is None:
      try:
        result_callargs = function_utils.get_callargs_for_argspec(
            argspec, *args, **kwargs)
        self.assertEqual(result_callargs, expected_callargs)
      except (TypeError, AssertionError) as test_err:
        raise AssertionError(
            'With argspec {!s}, args {!s}, kwargs {!s}, expected callargs {!s} '
            'and error {!s}, tested function returned {!s} and the test has '
            'failed with message: {!s}'.format(argspec, args, kwargs,
                                               expected_callargs,
                                               expected_error, result_callargs,
                                               test_err)) from test_err
    else:
      # The reference binding rejected these arguments, so the function under
      # test must reject them as well.
      with self.assertRaises(TypeError):
        function_utils.get_callargs_for_argspec(argspec, *args, **kwargs)

  # pyformat: disable
  # pylint: disable=g-complex-comprehension
  @parameterized.parameters(
      (function_utils.get_argspec(params[0]),) + params[1:]
      for params in [
          (lambda a: None, [tf.int32], {}),
          (lambda a, b=True: None, [tf.int32, tf.bool], {}),
          (lambda a, b=True: None, [tf.int32], {'b': tf.bool}),
          (lambda a, b=True: None, [tf.bool], {'b': tf.bool}),
          (lambda a=10, b=True: None, [tf.int32], {'b': tf.bool}),
      ]
  )
  # pylint: enable=g-complex-comprehension
  # pyformat: enable
  def test_is_argspec_compatible_with_types_true(self, argspec, args, kwargs):
    self.assertTrue(
        function_utils.is_argspec_compatible_with_types(
            argspec, *[computation_types.to_type(a) for a in args],
            **{k: computation_types.to_type(v) for k, v in kwargs.items()}))

  # pyformat: disable
  # pylint: disable=g-complex-comprehension
  @parameterized.parameters(
      (function_utils.get_argspec(params[0]),) + params[1:]
      for params in [
          (lambda a=True: None, [tf.int32], {}),
          (lambda a=10, b=True: None, [tf.bool], {'b': tf.bool}),
      ]
  )
  # pylint: enable=g-complex-comprehension
  # pyformat: enable
  def test_is_argspec_compatible_with_types_false(self, argspec, args, kwargs):
    self.assertFalse(
        function_utils.is_argspec_compatible_with_types(
            argspec, *[computation_types.to_type(a) for a in args],
            **{k: computation_types.to_type(v) for k, v in kwargs.items()}))

  # pyformat: disable
  @parameterized.parameters(
      (tf.int32, False),
      ([tf.int32, tf.int32], True),
      ([tf.int32, ('b', tf.int32)], True),
      ([('a', tf.int32), ('b', tf.int32)], True),
      ([('a', tf.int32), tf.int32], False),
      (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), True),
      (anonymous_tuple.AnonymousTuple([('a', 1), (None, 2)]), False))
  # pyformat: enable
  def test_is_argument_tuple(self, arg, expected_result):
    self.assertEqual(function_utils.is_argument_tuple(arg), expected_result)

  # pyformat: disable
  @parameterized.parameters(
      (anonymous_tuple.AnonymousTuple([(None, 1)]), [1], {}),
      (anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2)]), [1], {'a': 2}))
  # pyformat: enable
  def test_unpack_args_from_anonymous_tuple(self, tuple_with_args,
                                            expected_args, expected_kwargs):
    self.assertEqual(
        function_utils.unpack_args_from_tuple(tuple_with_args),
        (expected_args, expected_kwargs))

  # pyformat: disable
  @parameterized.parameters(
      ([tf.int32], [tf.int32], {}),
      ([('a', tf.int32)], [], {'a': tf.int32}),
      ([tf.int32, tf.bool], [tf.int32, tf.bool], {}),
      ([tf.int32, ('b', tf.bool)], [tf.int32], {'b': tf.bool}),
      ([('a', tf.int32), ('b', tf.bool)], [], {'a': tf.int32, 'b': tf.bool}))
  # pyformat: enable
  def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                       expected_kwargs):
    args, kwargs = function_utils.unpack_args_from_tuple(tuple_with_args)
    self.assertEqual(len(args), len(expected_args))
    for arg, expected_arg in zip(args, expected_args):
      self.assertTrue(
          type_utils.are_equivalent_types(
              arg, computation_types.to_type(expected_arg)))
    self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys()))
    for k, v in kwargs.items():
      # Normalize both sides to TFF types, mirroring the positional-arg check
      # above, so the comparison does not rely on implicit conversion of the
      # raw TF dtype in `expected_kwargs`.
      self.assertTrue(
          type_utils.are_equivalent_types(
              computation_types.to_type(v),
              computation_types.to_type(expected_kwargs[k])))

  def test_pack_args_into_anonymous_tuple_without_type_spec(self):
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple([1], {'a': 10}),
        anonymous_tuple.AnonymousTuple([(None, 1), ('a', 10)]))
    # Keyword arguments may be packed in either order.
    self.assertIn(
        function_utils.pack_args_into_anonymous_tuple([1, 2], {
            'a': 10,
            'b': 20
        }), [
            anonymous_tuple.AnonymousTuple([
                (None, 1),
                (None, 2),
                ('a', 10),
                ('b', 20),
            ]),
            anonymous_tuple.AnonymousTuple([
                (None, 1),
                (None, 2),
                ('b', 20),
                ('a', 10),
            ])
        ])
    self.assertIn(
        function_utils.pack_args_into_anonymous_tuple([], {
            'a': 10,
            'b': 20
        }), [
            anonymous_tuple.AnonymousTuple([('a', 10), ('b', 20)]),
            anonymous_tuple.AnonymousTuple([('b', 20), ('a', 10)])
        ])
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple([1], {}),
        anonymous_tuple.AnonymousTuple([(None, 1)]))

  # pyformat: disable
  @parameterized.parameters(
      ([1], {}, [tf.int32], [(None, 1)]),
      ([1, True], {}, [tf.int32, tf.bool], [(None, 1), (None, True)]),
      ([1, True], {}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([1], {'y': True}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([], {'x': 1, 'y': True}, [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]),
      ([], collections.OrderedDict([('y', True), ('x', 1)]),
       [('x', tf.int32), ('y', tf.bool)],
       [('x', 1), ('y', True)]))
  # pyformat: enable
  def test_pack_args_into_anonymous_tuple_with_type_spec_expect_success(
      self, args, kwargs, type_spec, elements):
    self.assertEqual(
        function_utils.pack_args_into_anonymous_tuple(
            args, kwargs, type_spec, NoopIngestContextForTest()),
        anonymous_tuple.AnonymousTuple(elements))

  # pyformat: disable
  @parameterized.parameters(
      ([1], {}, [(tf.bool)]),
      ([], {'x': 1, 'y': True}, [(tf.int32), (tf.bool)]))
  # pyformat: enable
  def test_pack_args_into_anonymous_tuple_with_type_spec_expect_failure(
      self, args, kwargs, type_spec):
    with self.assertRaises(TypeError):
      function_utils.pack_args_into_anonymous_tuple(args, kwargs, type_spec,
                                                    NoopIngestContextForTest())

  # pyformat: disable
  @parameterized.parameters(
      (None, [], {}, 'None'),
      (tf.int32, [1], {}, '1'),
      ([tf.int32, tf.bool], [1, True], {}, '<1,True>'),
      ([('x', tf.int32), ('y', tf.bool)], [1, True], {}, '<x=1,y=True>'),
      ([('x', tf.int32), ('y', tf.bool)], [1], {'y': True}, '<x=1,y=True>'),
      ([tf.int32, tf.bool],
       [anonymous_tuple.AnonymousTuple([(None, 1), (None, True)])], {},
       '<1,True>'))
  # pyformat: enable
  def test_pack_args(self, parameter_type, args, kwargs, expected_value_string):
    self.assertEqual(
        str(
            function_utils.pack_args(parameter_type, args, kwargs,
                                     NoopIngestContextForTest())),
        expected_value_string)

  # pyformat: disable
  @parameterized.parameters(
      (1, lambda: 10, None, None, None, 10),
      (2, lambda x=1: x + 10, None, None, None, 11),
      (3, lambda x=1: x + 10, tf.int32, None, 20, 30),
      (4, lambda x, y: x + y, [tf.int32, tf.int32], None,
       anonymous_tuple.AnonymousTuple([('x', 5), ('y', 6)]), 11),
      (5, lambda *args: str(args), [tf.int32, tf.int32], True,
       anonymous_tuple.AnonymousTuple([(None, 5), (None, 6)]), '(5, 6)'),
      (6, lambda *args: str(args), [('x', tf.int32), ('y', tf.int32)], False,
       anonymous_tuple.AnonymousTuple([('x', 5), ('y', 6)]),
       '(AnonymousTuple([(\'x\', 5), (\'y\', 6)]),)'),
      (7, lambda x: str(x),  # pylint: disable=unnecessary-lambda
       [tf.int32], None, anonymous_tuple.AnonymousTuple([(None, 10)]), '[10]'))
  # pyformat: enable
  def test_wrap_as_zero_or_one_arg_callable(self, unused_index, fn,
                                            parameter_type, unpack, arg,
                                            expected_result):
    wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        fn, parameter_type, unpack)
    actual_result = wrapped_fn(arg) if parameter_type else wrapped_fn()
    self.assertEqual(actual_result, expected_result)

  def test_polymorphic_function(self):

    class ContextForTest(context_base.Context):

      def ingest(self, val, type_spec):
        return val

      def invoke(self, comp, arg):
        return 'name={},type={},arg={}'.format(
            comp.name, str(comp.type_signature.parameter), str(arg))

    class TestFunction(function_utils.ConcreteFunction):

      def __init__(self, name, parameter_type):
        self._name = name
        super().__init__(
            computation_types.FunctionType(parameter_type, tf.string),
            context_stack_impl.context_stack)

      @property
      def name(self):
        return self._name

    class TestFunctionFactory(object):

      def __init__(self):
        self._count = 0

      def __call__(self, parameter_type):
        self._count = self._count + 1
        return TestFunction(str(self._count), parameter_type)

    with context_stack_impl.context_stack.install(ContextForTest()):
      fn = function_utils.PolymorphicFunction(TestFunctionFactory())
      self.assertEqual(fn(10), 'name=1,type=<int32>,arg=<10>')
      self.assertEqual(
          fn(20, x=True), 'name=2,type=<int32,x=bool>,arg=<20,x=True>')
      self.assertEqual(fn(True), 'name=3,type=<bool>,arg=<True>')
      self.assertEqual(
          fn(30, x=40), 'name=4,type=<int32,x=int32>,arg=<30,x=40>')
      # Calls with previously seen argument types reuse the earlier concrete
      # functions (same names), rather than asking the factory for new ones.
      self.assertEqual(fn(50), 'name=1,type=<int32>,arg=<50>')
      self.assertEqual(
          fn(0, x=False), 'name=2,type=<int32,x=bool>,arg=<0,x=False>')
      self.assertEqual(fn(False), 'name=3,type=<bool>,arg=<False>')
      self.assertEqual(
          fn(60, x=70), 'name=4,type=<int32,x=int32>,arg=<60,x=70>')

  def test_concrete_function(self):

    class ContextForTest(context_base.Context):

      def ingest(self, val, type_spec):
        return val

      def invoke(self, comp, arg):
        return comp.invoke_fn(arg)

    class TestFunction(function_utils.ConcreteFunction):

      def __init__(self, type_signature, invoke_fn):
        super().__init__(type_signature, context_stack_impl.context_stack)
        self._invoke_fn = invoke_fn

      def invoke_fn(self, arg):
        return self._invoke_fn(arg)

    with context_stack_impl.context_stack.install(ContextForTest()):
      fn = TestFunction(
          computation_types.FunctionType(tf.int32, tf.bool), lambda x: x > 10)
      self.assertEqual(fn(5), False)
      self.assertEqual(fn(15), True)

      fn = TestFunction(
          computation_types.FunctionType([('x', tf.int32), ('y', tf.int32)],
                                         tf.bool), lambda arg: arg.x > arg.y)
      self.assertEqual(fn(5, 10), False)
      self.assertEqual(fn(10, 5), True)
      self.assertEqual(fn(y=10, x=5), False)
      self.assertEqual(fn(y=5, x=10), True)
      self.assertEqual(fn(10, y=5), True)
def _wrap(fn, parameter_type, wrapper_fn):
    """Wraps a possibly-polymorphic `fn` in `wrapper_fn`.

    If `parameter_type` is `None` and `fn` takes any arguments (even with
    default values), `fn` is inferred to be polymorphic and won't be passed to
    `wrapper_fn` until invocation time (when concrete parameter types are
    available).

    `wrapper_fn` must accept three positional arguments and one defaulted
    argument `name`:

    * `target_fn`, the Python function to be wrapped.

    * `parameter_type`, the optional type of the computation's
      parameter (an instance of `computation_types.Type`).

    * `unpack`, an argument which will be passed on to
      `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping
      `target_fn`. See that function for details.

    * Optional `name`, the name of the function that is being wrapped (only for
      debugging purposes).

    Args:
      fn: The function or defun to wrap as a computation.
      parameter_type: Optional type of any arguments to `fn`.
      wrapper_fn: The Python callable that performs actual wrapping. The object
        to be returned by this function should be an instance of a
        `ConcreteFunction`.

    Returns:
      Either the result of wrapping (an object that represents the
      computation), or a polymorphic callable that performs wrapping upon
      invocation based on argument types. The returned function still may
      accept multiple arguments (it has not yet had
      `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

    Raises:
      TypeError: if the arguments are of the wrong types, if the `wrapper_fn`
        constructs something that isn't a ConcreteFunction, or if the
        constructed function's parameter type is not equivalent to
        `parameter_type`.
    """
    # Not every wrappable object has a `__name__`; the name is only used for
    # debugging, so fall back to `None` when it is missing.
    fn_name = getattr(fn, '__name__', None)
    argspec = function_utils.get_argspec(fn)
    parameter_type = computation_types.to_type(parameter_type)
    if parameter_type is None and (argspec.args or argspec.varargs
                                   or argspec.keywords):
        # There is no TFF type specification, and the function/defun declares
        # parameters. Create a polymorphic template that defers wrapping until
        # the concrete parameter type `pt` is known at invocation time.
        def _wrap_polymorphic(pt):
            return wrapper_fn(fn, pt, unpack=True, name=fn_name)

        polymorphic_fn = function_utils.PolymorphicFunction(_wrap_polymorphic)

        # When applying a decorator, the __doc__ attribute with the
        # documentation in triple-quotes is not automatically transferred from
        # the function on which it was applied to the wrapped object, so we
        # must transfer it here explicitly.
        polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
        return polymorphic_fn

    # Either we have a concrete parameter type, or this is a no-arg function.
    # Forward the debugging name here too, as the polymorphic path does and as
    # the documented `wrapper_fn` contract allows.
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None, name=fn_name)
    py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                            'value returned by the wrapper')
    if not type_utils.are_equivalent_types(
            concrete_fn.type_signature.parameter, parameter_type):
        raise TypeError(
            'Expected a concrete function that takes parameter {}, got one '
            'that takes {}.'.format(str(parameter_type),
                                    str(concrete_fn.type_signature.parameter)))
    # When applying a decorator, the __doc__ attribute with the documentation
    # in triple-quotes is not automatically transferred from the function on
    # which it was applied to the wrapped object, so we must transfer it here
    # explicitly.
    concrete_fn.__doc__ = getattr(fn, '__doc__', None)
    return concrete_fn