Code Example #1
def make_dataset_from_variant_tensor(variant_tensor, type_spec):
  """Constructs a `tf.data.Dataset` from a variant tensor and type spec.

  Args:
    variant_tensor: The variant tensor that represents the dataset.
    type_spec: The type spec of elements of the data set, either an instance of
      `types.Type` or something convertible to it.

  Returns:
    A corresponding instance of `tf.data.Dataset`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  if not tf.is_tensor(variant_tensor):
    raise TypeError(
        'Expected `variant_tensor` to be a tensor, found {}.'.format(
            py_typecheck.type_string(type(variant_tensor))))
  if variant_tensor.dtype != tf.variant:
    raise TypeError(
        'Expected `variant_tensor` to be of a variant type, found {}.'.format(
            variant_tensor.dtype))
  return tf.data.experimental.from_variant(
      variant_tensor,
      structure=(type_conversions.type_to_tf_structure(
          computation_types.to_type(type_spec))))
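
A minimal usage sketch for the helper above, assuming the TFF modules it imports are available: `tf.data.experimental.to_variant` produces exactly the kind of variant tensor that this helper converts back into a dataset.

import tensorflow as tf

ds = tf.data.Dataset.range(5)  # elements are scalar int64 tensors
variant = tf.data.experimental.to_variant(ds)
# Round-trip: the type spec describes the element type (here, int64).
restored = make_dataset_from_variant_tensor(variant, tf.int64)
print(list(restored.as_numpy_iterator()))  # [0, 1, 2, 3, 4]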
Code Example #2
File: serialization.py, project: tensorflow/federated
def _deserialize_type_spec(serialize_type_variable, python_container=None):
    """Deserialize a `tff.Type` protocol buffer into a python class instance."""
    type_spec = type_serialization.deserialize_type(
        computation_pb2.Type.FromString(
            serialize_type_variable.read_value().numpy()))
    if type_spec.is_struct() and python_container is not None:
        type_spec = computation_types.StructWithPythonType(
            structure.iter_elements(type_spec), python_container)
    return type_conversions.type_to_tf_structure(type_spec)
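
For context, a sketch of the serialize side that `_deserialize_type_spec` expects: the `tff.Type` proto is stored as serialized bytes in a `tf.string` variable. The module paths and the `serialize_type` call below are assumptions that vary across TFF versions.

import collections
import tensorflow as tf
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_serialization

type_spec = computation_types.to_type(
    collections.OrderedDict(a=tf.TensorSpec(shape=(), dtype=tf.int32)))
# Store the serialized proto bytes in a variable, as the reader above assumes.
serialize_type_variable = tf.Variable(
    type_serialization.serialize_type(type_spec).SerializeToString())
# _deserialize_type_spec(serialize_type_variable, collections.OrderedDict)
# would now recover the matching tf.data structure.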
Code Example #3
def test_without_names(self):
    expected_structure = (
        tf.TensorSpec(shape=(), dtype=tf.bool),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    )
    type_spec = computation_types.to_type(expected_structure)
    tf_structure = type_conversions.type_to_tf_structure(type_spec)
    with tf.Graph().as_default():
        ds = tf.data.experimental.from_variant(
            tf.compat.v1.placeholder(tf.variant, shape=[]),
            structure=tf_structure)
        actual_structure = ds.element_spec
        self.assertEqual(expected_structure, actual_structure)
Code Example #4
def test_with_names(self):
    expected_structure = collections.OrderedDict([
        ('a', tf.TensorSpec(shape=(), dtype=tf.bool)),
        ('b', collections.OrderedDict([
            ('c', tf.TensorSpec(shape=(), dtype=tf.float32)),
            ('d', tf.TensorSpec(shape=(20,), dtype=tf.int32)),
        ])),
    ])
    type_spec = computation_types.to_type(expected_structure)
    tf_structure = type_conversions.type_to_tf_structure(type_spec)
    with tf.Graph().as_default():
        ds = tf.data.experimental.from_variant(
            tf.compat.v1.placeholder(tf.variant, shape=[]),
            structure=tf_structure)
        actual_structure = ds.element_spec
        self.assertEqual(expected_structure, actual_structure)
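
The invariant the two tests above check can also be seen in eager mode without any TFF machinery; a minimal sketch in plain TF:

import collections
import tensorflow as tf

ds = tf.data.Dataset.from_tensors(
    collections.OrderedDict(
        a=tf.constant(True),
        b=collections.OrderedDict(
            c=tf.constant(1.0),
            d=tf.zeros([20], tf.int32))))
variant = tf.data.experimental.to_variant(ds)
restored = tf.data.experimental.from_variant(variant, structure=ds.element_spec)
assert restored.element_spec == ds.element_spec  # structure survives the trip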
Code Example #5
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional `tf.config.LogicalDevice`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_spec.is_equivalent_to(comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type
    # TODO(b/156302055): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    must_pin_function_to_cpu = type_analysis.contains(
        type_spec.result, lambda t: t.is_sequence())
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        unexpected_building_block = building_blocks.ComputationBuildingBlock.from_proto(
            comp)
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            unexpected_building_block))

    if type_spec.is_function():
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    wrapped_fn = _get_wrapped_function_from_comp(comp,
                                                 must_pin_function_to_cpu,
                                                 param_type, device)
    param_fns = []
    if param_type is not None:
        for spec in structure.flatten(type_spec.parameter):
            if spec.is_tensor():
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    result_fns = []
    for spec in structure.flatten(result_type):
        if spec.is_tensor():
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            tf_structure = type_conversions.type_to_tf_structure(spec.element)

            def fn(x, tf_structure=tf_structure):
                return tf.data.experimental.from_variant(x, tf_structure)

            result_fns.append(fn)

    ops = wrapped_fn.graph.get_operations()

    eager_cleanup_ops = []
    destroy_before_invocation = []
    for op in ops:
        if op.type == 'HashTableV2':
            eager_cleanup_ops += op.outputs
    if eager_cleanup_ops:
        for resource in wrapped_fn.prune(feeds={},
                                         fetches=eager_cleanup_ops)():
            destroy_before_invocation.append(resource)

    lazy_cleanup_ops = []
    destroy_after_invocation = []
    for op in ops:
        if op.type == 'VarHandleOp':
            lazy_cleanup_ops += op.outputs
    if lazy_cleanup_ops:
        for resource in wrapped_fn.prune(feeds={}, fetches=lazy_cleanup_ops)():
            destroy_after_invocation.append(resource)

    def fn_to_return(arg,
                     param_fns=tuple(param_fns),
                     result_fns=tuple(result_fns),
                     result_type=result_type,
                     wrapped_fn=wrapped_fn,
                     destroy_before=tuple(destroy_before_invocation),
                     destroy_after=tuple(destroy_after_invocation)):
        # This double-function pattern works around Python's late binding,
        # forcing the variables to bind eagerly.
        return _call_embedded_tf(arg=arg,
                                 param_fns=param_fns,
                                 result_fns=result_fns,
                                 result_type=result_type,
                                 wrapped_fn=wrapped_fn,
                                 destroy_before_invocation=destroy_before,
                                 destroy_after_invocation=destroy_after)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)
    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)

    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
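
The default-argument trick in `fn_to_return` above is the standard workaround for Python's late-binding closures; a short standalone illustration:

fns = [lambda: i for i in range(3)]
print([f() for f in fns])  # [2, 2, 2]: every closure sees the final i

fns = [lambda i=i: i for i in range(3)]
print([f() for f in fns])  # [0, 1, 2]: default values bind at definition time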
Code Example #6
def test_with_no_elements(self):
    with self.assertRaises(ValueError):
        type_conversions.type_to_tf_structure(
            computation_types.StructType([]))
Code Example #7
def test_with_inconsistently_named_elements(self):
    with self.assertRaises(ValueError):
        type_conversions.type_to_tf_structure(
            computation_types.StructType([('a', tf.int32), tf.bool]))
Code Example #8
def test_with_sequence_type(self):
    with self.assertRaises(ValueError):
        type_conversions.type_to_tf_structure(
            computation_types.SequenceType(tf.int32))
Code Example #9
def test_with_none(self):
    with self.assertRaises(TypeError):
        type_conversions.type_to_tf_structure(None)
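
Taken together, the four tests above pin down the error surface of `type_to_tf_structure`: empty structs, inconsistently named structs, and top-level sequences raise ValueError, while a non-type raises TypeError. A condensed sketch, with imports as in the examples above:

for bad_type in (
        computation_types.StructType([]),                          # no elements
        computation_types.StructType([('a', tf.int32), tf.bool]),  # mixed naming
        computation_types.SequenceType(tf.int32)):                 # top-level sequence
    try:
        type_conversions.type_to_tf_structure(bad_type)
    except ValueError as e:
        print('rejected:', e)
# Passing something that is not a `tff.Type` at all raises TypeError instead.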
Code Example #10
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional device name.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_analysis.are_equivalent_types(type_spec, comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          type_spec, comp_type))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          len(input_tensor_names), len(args)))
    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    if init_op:
      graph_def = tensorflow_utils.add_control_deps_for_init_op(
          graph_def, init_op)

    def _import_fn():
      return tf.import_graph_def(
          graph_merge.uniquify_shared_names(graph_def),
          input_map=dict(list(zip(input_tensor_names, args))),
          return_elements=output_tensor_names)

    if device is not None:
      with tf.device(device):
        return _import_fn()
    else:
      return _import_fn()

  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_conversions.type_to_tf_structure(spec.element)

      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            len(param_fns), len(arg_parts)))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)

    # There is a tf.wrap_function(...) issue (b/144127474) where variables
    # created by tf.import_graph_def(...) inside tf.wrap_function(...) are not
    # destroyed, so fetch all the variables from `wrapped_fn` and destroy them
    # manually.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    resources = []
    for op in wrapped_fn.graph.get_operations():
      if op.type == 'VarHandleOp':
        resources += op.outputs
    if resources:
      for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
        tf.raw_ops.DestroyResourceOp(resource=resource)

    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(arg, p, w)

  if device is not None:
    old_fn_to_return = fn_to_return

    # pylint: disable=function-redefined
    def fn_to_return(x):
      with tf.device(device):
        return old_fn_to_return(x)

    # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
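
The mechanics of `tf.compat.v1.wrap_function` with an explicit signature, which the example above builds on, can be demonstrated standalone; a minimal sketch in plain TF:

import tensorflow as tf

def double(x):
    return x * 2.0

wrapped = tf.compat.v1.wrap_function(
    double, signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
print(wrapped(tf.constant(3.0)).numpy())  # 6.0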
Code Example #11
File: eager_tf_executor.py, project: skyejy/federated
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional `tf.config.LogicalDevice`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_spec.is_equivalent_to(comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type
    # TODO(b/156302055): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    must_pin_function_to_cpu = type_analysis.contains(
        type_spec.result, lambda t: t.is_sequence())
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        unexpected_building_block = building_blocks.ComputationBuildingBlock.from_proto(
            comp)
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            unexpected_building_block))

    if type_spec.is_function():
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    wrapped_fn = _get_wrapped_function_from_comp(comp,
                                                 must_pin_function_to_cpu,
                                                 param_type, device)
    param_fns = []
    if param_type is not None:
        for spec in structure.flatten(type_spec.parameter):
            if spec.is_tensor():
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    result_fns = []
    for spec in structure.flatten(result_type):
        if spec.is_tensor():
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            tf_structure = type_conversions.type_to_tf_structure(spec.element)

            def fn(x, tf_structure=tf_structure):
                return tf.data.experimental.from_variant(x, tf_structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring

        # TODO(b/166479382): This cleanup-before-invocation pattern is a workaround
        # to square the circle of TF data expecting to lazily reference this
        # resource on iteration, as well as usages that expect to reinitialize a
        # table with new data. Revisit the semantics implied by this cleanup
        # pattern.
        eager_cleanup_resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'HashTableV2':
                eager_cleanup_resources += op.outputs
        if eager_cleanup_resources:
            for resource in wrapped_fn.prune(
                    feeds={}, fetches=eager_cleanup_resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        param_elements = []
        if arg is not None:
            arg_parts = structure.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue (b/144127474) where variables
        # created by tf.import_graph_def(...) inside tf.wrap_function(...) are
        # not destroyed, so fetch all the variables from `wrapped_fn` and
        # destroy them manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return structure.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)
    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)

    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
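
Both cleanup loops above bottom out in `tf.raw_ops.DestroyResourceOp`; a minimal demonstration of that primitive on an eager resource variable:

import tensorflow as tf

v = tf.Variable(1.0)
tf.raw_ops.DestroyResourceOp(resource=v.handle)
# Reading `v` now fails: the backing resource has been destroyed.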
Code Example #12
def check_round_trip(self, first_spec):
    first_type = computation_types.to_type(first_spec)
    round_trip_spec = type_conversions.type_to_tf_structure(first_type)
    round_trip_type = computation_types.to_type(round_trip_spec)
    self.assert_types_identical(first_type, round_trip_type)
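
What `check_round_trip` exercises, as a standalone sketch with imports as in the examples above: converting a `tff.Type` to a tf.data structure and back should yield an identical type.

import collections
import tensorflow as tf

first_type = computation_types.to_type(
    collections.OrderedDict(a=tf.TensorSpec(shape=(), dtype=tf.bool)))
round_trip_spec = type_conversions.type_to_tf_structure(first_type)
round_trip_type = computation_types.to_type(round_trip_spec)
print(first_type, round_trip_type)  # the two should print identically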
Code Example #13
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional `tf.config.LogicalDevice`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_spec.is_equivalent_to(comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type
    # TODO(b/155198591): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    must_pin_function_to_cpu = type_analysis.contains_types(
        type_spec.result, computation_types.SequenceType)
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap():
        """No-arg function to import graph def.

    We pass a no-arg function to `tf.compat.v1.wrap_function` to avoid
    the leftover placeholders that can result from binding arguments to the
    imported graphdef via `input_map`. The correct signature will be added to
    this function later, via the `prune` call below.

    Returns:
      Result of importing graphdef backing `comp`.
    """
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def), name='')

        if must_pin_function_to_cpu:
            with tf.device('cpu'):
                return _import_fn()
        elif device is not None:
            with tf.device(device.name):
                return _import_fn()
        else:
            return _import_fn()

    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_noarg_fn = tf.compat.v1.wrap_function(function_to_wrap,
                                                  signature=[])
    import_graph = wrapped_noarg_fn.graph
    try:
        wrapped_fn = wrapped_noarg_fn.prune(
            feeds=tf.nest.map_structure(import_graph.as_graph_element,
                                        input_tensor_names),
            fetches=tf.nest.map_structure(import_graph.as_graph_element,
                                          output_tensor_names),
        )
    except KeyError as e:
        raise TypeError(
            'Caught exception trying to prune graph `{g}` with '
            'feeds {feeds} and fetches {fetches}. This indicates that these '
            'names may not refer to tensors in the graph.\nException: {e}'.format(
                g=import_graph,
                feeds=input_tensor_names,
                fetches=output_tensor_names,
                e=e))

    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_conversions.type_to_tf_structure(spec.element)

            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue (b/144127474) where variables
        # created by tf.import_graph_def(...) inside tf.wrap_function(...) are
        # not destroyed, so fetch all the variables from `wrapped_fn` and
        # destroy them manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)
    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)

    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
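
The no-arg wrap plus `prune` pattern used above avoids the leftover placeholders that `input_map` binding can produce; a standalone sketch in plain TF (the tensor names here are illustrative):

import tensorflow as tf

def build():
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name='x')
    tf.identity(x * 2.0, name='y')

wrapped_noarg = tf.compat.v1.wrap_function(build, signature=[])
graph = wrapped_noarg.graph
doubled = wrapped_noarg.prune(
    feeds=graph.as_graph_element('x:0'),
    fetches=graph.as_graph_element('y:0'))
print(doubled(tf.constant(3.0)).numpy())  # 6.0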