Example #1
    def test_returns_canonical_form_from_tff_learning_structure(self):
        it = test_utils.construct_example_training_comp()
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        self.assertIsInstance(cf, canonical_form.CanonicalForm)
        self.assertEqual(it.initialize.type_signature,
                         new_it.initialize.type_signature)
        # Note that the `next` type_signatures need not be equal, since we may
        # have appended an empty tuple as client side-channel outputs if none
        # existed.
        self.assertEqual(it.next.type_signature.parameter,
                         new_it.next.type_signature.parameter)
        self.assertEqual(it.next.type_signature.result[0],
                         new_it.next.type_signature.result[0])
        self.assertEqual(it.next.type_signature.result[1],
                         new_it.next.type_signature.result[1])

        state1 = it.initialize()
        state2 = new_it.initialize()

        sample_batch = collections.OrderedDict(x=np.array([[1., 1.]],
                                                          dtype=np.float32),
                                               y=np.array([[0]],
                                                          dtype=np.int32))
        client_data = [sample_batch]

        round_1 = it.next(state1, [client_data])
        state = round_1[0]
        state_names = anonymous_tuple.name_list(state)
        state_arrays = anonymous_tuple.flatten(state)
        metrics = round_1[1]
        metrics_names = [x[0] for x in anonymous_tuple.iter_elements(metrics)]
        metrics_arrays = anonymous_tuple.flatten(metrics)

        alt_round_1 = new_it.next(state2, [client_data])
        alt_state = alt_round_1[0]
        alt_state_names = anonymous_tuple.name_list(alt_state)
        alt_state_arrays = anonymous_tuple.flatten(alt_state)
        alt_metrics = alt_round_1[1]
        alt_metrics_names = [
            x[0] for x in anonymous_tuple.iter_elements(alt_metrics)
        ]
        alt_metrics_arrays = anonymous_tuple.flatten(alt_metrics)

        self.assertEmpty(state.delta_aggregate_state)
        self.assertEmpty(state.model_broadcast_state)
        self.assertAllEqual(state_names, alt_state_names)
        self.assertAllEqual(metrics_names, alt_metrics_names)
        self.assertAllClose(state_arrays, alt_state_arrays)
        self.assertAllClose(metrics_arrays[:2], alt_metrics_arrays[:2])
        # Final metric is execution time
        self.assertAlmostEqual(metrics_arrays[2],
                               alt_metrics_arrays[2],
                               delta=1e-3)
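
A minimal standalone sketch of the `anonymous_tuple` helpers the assertions
above rely on (assuming the same TFF `anonymous_tuple` module):

# Sketch only: illustrates name_list/flatten/iter_elements on a small tuple.
t = anonymous_tuple.AnonymousTuple([('a', 1), ('b', 2)])
assert anonymous_tuple.name_list(t) == ['a', 'b']
assert anonymous_tuple.flatten(t) == [1, 2]
assert [name for name, _ in anonymous_tuple.iter_elements(t)] == ['a', 'b']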
Example #2
    def test_canonical_form_with_learning_structure_does_not_change_execution_of_iterative_process(
            self):
        ip_1 = construct_example_training_comp()
        cf = tff.backends.mapreduce.get_canonical_form_for_iterative_process(
            ip_1)
        ip_2 = tff.backends.mapreduce.get_iterative_process_for_canonical_form(
            cf)

        self.assertEqual(ip_1.initialize.type_signature,
                         ip_2.initialize.type_signature)
        # The `next` functions' type_signatures may not be equal, since we may
        # have appended an empty tuple as client side-channel outputs if none
        # existed.
        self.assertEqual(ip_1.next.type_signature.parameter,
                         ip_2.next.type_signature.parameter)
        self.assertEqual(ip_1.next.type_signature.result[0],
                         ip_2.next.type_signature.result[0])
        self.assertEqual(ip_1.next.type_signature.result[1],
                         ip_2.next.type_signature.result[1])

        sample_batch = collections.OrderedDict(
            x=np.array([[1., 1.]], dtype=np.float32),
            y=np.array([[0]], dtype=np.int32),
        )
        client_data = [sample_batch]
        state_1 = ip_1.initialize()
        server_state_1, server_output_1 = ip_1.next(state_1, [client_data])
        server_state_1_names = anonymous_tuple.name_list(server_state_1)
        server_state_1_arrays = anonymous_tuple.flatten(server_state_1)
        server_output_1_names = [
            x[0] for x in anonymous_tuple.iter_elements(server_output_1)
        ]
        server_output_1_arrays = anonymous_tuple.flatten(server_output_1)
        state_2 = ip_2.initialize()
        server_state_2, server_output_2, _ = ip_2.next(state_2, [client_data])
        server_state_2_names = anonymous_tuple.name_list(server_state_2)
        server_state_2_arrays = anonymous_tuple.flatten(server_state_2)
        server_output_2_names = [
            x[0] for x in anonymous_tuple.iter_elements(server_output_2)
        ]
        server_output_2_arrays = anonymous_tuple.flatten(server_output_2)

        self.assertEmpty(server_state_1.delta_aggregate_state)
        self.assertEmpty(server_state_1.model_broadcast_state)
        self.assertAllEqual(server_state_1_names, server_state_2_names)
        self.assertAllEqual(server_output_1_names, server_output_2_names)
        self.assertAllClose(server_state_1_arrays, server_state_2_arrays)
        self.assertAllClose(server_output_1_arrays[:2],
                            server_output_2_arrays[:2])

        execution_time_1 = server_output_1_arrays[2]
        execution_time_2 = server_output_2_arrays[2]

        self.assertAlmostEqual(execution_time_1, execution_time_2, delta=1e-3)
Example #3
def clip_by_global_norm(delta, clip_norm):
    # TODO(b/123092620): Replace anonymous_tuple with tf.nest.
    delta = anonymous_tuple.from_container(delta)
    clipped, global_norm = tf.clip_by_global_norm(
        anonymous_tuple.flatten(delta), clip_norm)
    return anonymous_tuple.pack_sequence_as(delta, clipped), global_norm
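
A minimal usage sketch for `clip_by_global_norm` (hypothetical values; assumes
eager mode and the `anonymous_tuple` module above):

# Sketch only: a two-tensor delta whose global norm is 5.0.
delta = collections.OrderedDict(w=tf.constant([3.0, 4.0]), b=tf.constant([0.0]))
clipped_delta, global_norm = clip_by_global_norm(delta, clip_norm=1.0)
# global_norm == 5.0; every tensor in clipped_delta is scaled by 1.0 / 5.0.
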
def assign(target, source):
    """Creates an op that assigns `target` from `source`.

  This utility function provides the exact same behavior as
  `tf.Variable.assign`, but it generalizes to a wider class of objects,
  including ordinary variables as well as various types of nested structures.

  Args:
    target: A nested structure composed of variables embedded in containers that
      are compatible with `tf.nest`, or instances of
      `anonymous_tuple.AnonymousTuple`.
    source: A nested structure composed of tensors, matching that of `target`.

  Returns:
    A single op that represents the assignment.

  Raises:
    TypeError: If types mismatch.
  """
    # TODO(b/113112108): Extend this to containers of mixed types.
    if isinstance(target, anonymous_tuple.AnonymousTuple):
        return tf.group(*anonymous_tuple.flatten(
            anonymous_tuple.map_structure(lambda a, b: a.assign(b), target,
                                          source)))
    else:
        return tf.group(*tf.nest.flatten(
            tf.nest.map_structure(lambda a, b: a.assign(b), target, source)))
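
A minimal usage sketch for `assign` (hypothetical structure; this exercises the
`tf.nest` branch):

# Sketch only: assigning a dict of tensors into a matching dict of variables.
target = {'w': tf.Variable([1.0, 2.0]), 'b': tf.Variable(0.0)}
source = {'w': tf.constant([3.0, 4.0]), 'b': tf.constant(1.0)}
assign_op = assign(target, source)  # in graph mode, a single grouped op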
Example #5
    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue, b/144127474, where variables
        # created by tf.import_graph_def(...) inside tf.wrap_function(...) are
        # not destroyed. So get all the variables from `wrapped_fn` and destroy
        # them manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)
Example #6
    def test_canonical_form_with_learning_structure_does_not_change_execution_of_iterative_process(
            self):
        ip_1 = construct_example_training_comp()
        cf = tff.backends.mapreduce.get_canonical_form_for_iterative_process(
            ip_1)
        ip_2 = tff.backends.mapreduce.get_iterative_process_for_canonical_form(
            cf)

        ip_1.initialize.type_signature.check_equivalent_to(
            ip_2.initialize.type_signature)
        # The `next` functions' type_signatures may not be equal, since we may
        # have appended an empty tuple as client side-channel outputs if none
        # existed.
        ip_1.next.type_signature.parameter.check_equivalent_to(
            ip_2.next.type_signature.parameter)
        ip_1.next.type_signature.result.check_equivalent_to(
            ip_2.next.type_signature.result)

        sample_batch = collections.OrderedDict(
            x=np.array([[1., 1.]], dtype=np.float32),
            y=np.array([[0]], dtype=np.int32),
        )
        client_data = [sample_batch]
        state_1 = ip_1.initialize()
        server_state_1, server_output_1 = ip_1.next(state_1, [client_data])
        server_state_1 = anonymous_tuple.from_container(server_state_1,
                                                        recursive=True)
        server_output_1 = anonymous_tuple.from_container(server_output_1,
                                                         recursive=True)
        server_state_1_arrays = anonymous_tuple.flatten(server_state_1)
        server_output_1_arrays = anonymous_tuple.flatten(server_output_1)
        state_2 = ip_2.initialize()
        server_state_2, server_output_2 = ip_2.next(state_2, [client_data])
        server_state_2_arrays = anonymous_tuple.flatten(server_state_2)
        server_output_2_arrays = anonymous_tuple.flatten(server_output_2)

        self.assertEmpty(server_state_1.delta_aggregate_state)
        self.assertEmpty(server_state_1.model_broadcast_state)
        # Note that we cannot simply use assertEqual because the values may differ
        # due to floating point issues.
        self.assertTrue(
            anonymous_tuple.is_same_structure(server_state_1, server_state_2))
        self.assertTrue(
            anonymous_tuple.is_same_structure(server_output_1,
                                              server_output_2))
        self.assertAllClose(server_state_1_arrays, server_state_2_arrays)
        self.assertAllClose(server_output_1_arrays[:2],
                            server_output_2_arrays[:2])
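
A minimal standalone sketch of `anonymous_tuple.is_same_structure`, which the
test uses in place of `assertEqual` (assuming the same TFF module):

# Sketch only: structures match on names and nesting, not on values.
a = anonymous_tuple.AnonymousTuple([('x', 1.0)])
b = anonymous_tuple.AnonymousTuple([('x', 2.0)])
assert anonymous_tuple.is_same_structure(a, b)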
Example #7
def fetch_value_in_session(sess, value):
    """Fetches `value` in `session`.

  Args:
    sess: The session in which to perform the fetch (as a single run).
    value: A Python object of a form analogous to that constructed by the
      function `assemble_result_from_graph`, made of tensors and anonymous
      tuples, or a `tf.data.Dataset`.

  Returns:
    A Python object with structure similar to `value`, but with tensors
    replaced with their values, and data sets replaced with lists of their
    elements, all fetched with a single call to `sess.run()`.

  Raises:
    ValueError: If `value` is not a `tf.data.Dataset` or not a structure of
      tensors and anonymous tuples.
  """
    py_typecheck.check_type(sess, tf.Session)
    # TODO(b/113123634): Investigate handling `list`s and `tuple`s of
    # `tf.data.Dataset`s and what the API would look like to support this.
    if isinstance(value, DATASET_REPRESENTATION_TYPES):
        with sess.graph.as_default():
            iterator = tf.compat.v1.data.make_one_shot_iterator(value)
            next_element = iterator.get_next()
        elements = []
        while True:
            try:
                elements.append(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                break
        return elements
    else:
        flattened_value = anonymous_tuple.flatten(value)
        dataset_results = {}
        flat_tensors = []
        for idx, v in enumerate(flattened_value):
            if isinstance(v, DATASET_REPRESENTATION_TYPES):
                dataset_results[idx] = fetch_value_in_session(sess, v)
            elif tf.is_tensor(v):
                flat_tensors.append(v)
            else:
                raise ValueError('Unsupported value type {}.'.format(str(v)))
        flat_computed_tensors = sess.run(flat_tensors)
        flattened_results = _interleave_dataset_results_and_tensors(
            dataset_results, flat_computed_tensors)

        def _to_unicode(v):
            if six.PY3 and isinstance(v, bytes):
                return v.decode('utf-8')
            return v

        if tf.is_tensor(value) and value.dtype == tf.string:
            flattened_results = [
                _to_unicode(result) for result in flattened_results
            ]
        return anonymous_tuple.pack_sequence_as(value, flattened_results)
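
A minimal usage sketch for `fetch_value_in_session` (hypothetical values;
assumes TF1-style graph/session semantics, matching the `tf.Session` check
above):

# Sketch only: fetching an AnonymousTuple of constants in one session run.
with tf.Graph().as_default() as graph:
    value = anonymous_tuple.AnonymousTuple([('a', tf.constant(1.0)),
                                            ('b', tf.constant(2.0))])
    with tf.compat.v1.Session(graph=graph) as sess:
        result = fetch_value_in_session(sess, value)  # <a=1.0,b=2.0>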
Example #8
    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)
        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    def test_flatten_and_pack_sequence_as(self):
        x = anonymous_tuple.AnonymousTuple([
            ('a', 10),
            ('b',
             anonymous_tuple.AnonymousTuple([
                 ('x', anonymous_tuple.AnonymousTuple([('p', 40)])),
                 ('y', 30),
                 ('z', anonymous_tuple.AnonymousTuple([('q', 50), ('r', 60)])),
             ])),
            ('c', 20),
        ])
        y = anonymous_tuple.flatten(x)
        self.assertEqual(y, [10, 40, 30, 50, 60, 20])
        z = anonymous_tuple.pack_sequence_as(x, y)
        self.assertEqual(str(z), '<a=10,b=<x=<p=40>,y=30,z=<q=50,r=60>>,c=20>')
Example #10
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional device name.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
    # it deals exclusively with eager mode. Incubate here, and potentially move
    # there, once stable.

    if device is not None:
        raise NotImplementedError(
            'Unable to embed TF code on a specific device.')

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_utils.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    str(type_spec), str(comp_type)))
    else:
        type_spec = comp_type
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    if param_type is not None:
        input_tensor_names = graph_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = graph_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                str(len(input_tensor_names)), str(len(args))))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        init_names = [init_op] if init_op else []
        returned_elements = tf.import_graph_def(
            graph_merge.uniquify_shared_names(graph_def),
            input_map=dict(zip(input_tensor_names, args)),
            return_elements=output_tensor_names + init_names)
        if init_names:
            with tf.control_dependencies([returned_elements[-1]]):
                return [tf.identity(x) for x in returned_elements[0:-1]]
        else:
            return returned_elements

    signature = []
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                signature.append(tf.TensorSpec(spec.shape, spec.dtype))
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                signature.append(tf.TensorSpec([], tf.variant))
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_utils.type_to_tf_structure(spec.element)

            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    str(len(param_fns)), str(len(arg_parts))))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)
        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)
    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
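
The `SequenceType` handling above round-trips datasets through variant tensors.
A minimal standalone sketch of that mechanism (assumes TF 2.x eager mode):

# Sketch only: dataset -> variant tensor -> dataset.
ds = tf.data.Dataset.range(3)
variant = tf.data.experimental.to_variant(ds)
structure = tf.TensorSpec([], tf.int64)  # element structure of `ds`
restored = tf.data.experimental.from_variant(variant, structure)
assert [x.numpy() for x in restored] == [0, 1, 2]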
Example #11
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional device name.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_utils.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                len(input_tensor_names), len(args)))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def),
                input_map=dict(list(zip(input_tensor_names, args))),
                return_elements=output_tensor_names)

        if device is not None:
            with tf.device(device):
                return _import_fn()
        else:
            return _import_fn()

    signature = []
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                signature.append(tf.TensorSpec(spec.shape, spec.dtype))
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                signature.append(tf.TensorSpec([], tf.variant))
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_utils.type_to_tf_structure(spec.element)

            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue, b/144127474, where variables
        # created by tf.import_graph_def(...) inside tf.wrap_function(...) are
        # not destroyed. So get all the variables from `wrapped_fn` and destroy
        # them manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    if device is not None:
        old_fn_to_return = fn_to_return

        # pylint: disable=function-redefined
        def fn_to_return(x):
            with tf.device(device):
                return old_fn_to_return(x)

        # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
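
A minimal standalone sketch of `tf.compat.v1.wrap_function`, which the code
above uses to stage the imported graph (hypothetical function; assumes TF 2.x):

# Sketch only: wrapping a one-argument function with an explicit signature.
def _double(x):
    return x * 2.0

wrapped = tf.compat.v1.wrap_function(
    _double, signature=[tf.TensorSpec([], tf.float32)])
assert wrapped(tf.constant(3.0)).numpy() == 6.0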
Example #12
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional `tf.config.LogicalDevice`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_spec.is_equivalent_to(comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type
    # TODO(b/155198591): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    must_pin_function_to_cpu = type_analysis.contains_types(
        type_spec.result, computation_types.SequenceType)
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap():
        """No-arg function to import graph def.

    We pass a no-arg function to `tf.compat.v1.wrap_function` to avoid
    the leftover placeholders that can result from binding arguments to the
    imported graphdef via `input_map`. The correct signature will be added to
    this function later, via the `prune` call below.

    Returns:
      Result of importing graphdef backing `comp`.
    """
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def), name='')

        if must_pin_function_to_cpu:
            with tf.device('cpu'):
                return _import_fn()
        elif device is not None:
            with tf.device(device.name):
                return _import_fn()
        else:
            return _import_fn()

    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_noarg_fn = tf.compat.v1.wrap_function(function_to_wrap,
                                                  signature=[])
    import_graph = wrapped_noarg_fn.graph
    try:
        wrapped_fn = wrapped_noarg_fn.prune(
            feeds=tf.nest.map_structure(import_graph.as_graph_element,
                                        input_tensor_names),
            fetches=tf.nest.map_structure(import_graph.as_graph_element,
                                          output_tensor_names),
        )
    except KeyError as e:
        raise TypeError(
            'Caught exception trying to prune graph `{g}` with '
            'feeds {feeds} and fetches {fetches}. This indicates that these '
            'names may not refer to tensors in the graph.\nException: {e}'.
            format(g=import_graph,
                   feeds=input_tensor_names,
                   fetches=output_tensor_names,
                   e=e))

    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_conversions.type_to_tf_structure(spec.element)

            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue, b/144127474, where variables
        # created by tf.import_graph_def(...) inside tf.wrap_function(...) are
        # not destroyed. So get all the variables from `wrapped_fn` and destroy
        # them manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)
    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)

    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
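
A minimal standalone sketch of the `prune` pattern above, which attaches feeds
and fetches to a no-arg wrapped function (hypothetical tensor names; assumes
TF 2.x):

# Sketch only: build a graph with a placeholder, then prune to get a callable.
def _build():
    x = tf.compat.v1.placeholder(tf.float32, name='x')
    tf.identity(x * 2.0, name='y')

wrapped = tf.compat.v1.wrap_function(_build, signature=[])
pruned = wrapped.prune(
    feeds=wrapped.graph.get_tensor_by_name('x:0'),
    fetches=wrapped.graph.get_tensor_by_name('y:0'))
assert pruned(tf.constant(3.0)).numpy() == 6.0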
Example #13
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional `tf.config.LogicalDevice`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_spec.is_equivalent_to(comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          type_spec, comp_type))
  else:
    type_spec = comp_type
  # TODO(b/155198591): Currently, TF will raise on any function returning a
  # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
  # this gating when we can.
  must_pin_function_to_cpu = type_analysis.contains(type_spec.result,
                                                    lambda t: t.is_sequence())
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  if type_spec.is_function():
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  wrapped_fn = _get_wrapped_function_from_comp(comp, must_pin_function_to_cpu,
                                               param_type, device)

  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if spec.is_tensor():
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        param_fns.append(tf.data.experimental.to_variant)

  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if spec.is_tensor():
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_conversions.type_to_tf_structure(spec.element)

      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            len(param_fns), len(arg_parts)))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)

    # There is a tf.wrap_function(...) issue, b/144127474, where variables
    # created by tf.import_graph_def(...) inside tf.wrap_function(...) are not
    # destroyed. So get all the variables from `wrapped_fn` and destroy them
    # manually.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    resources = []
    for op in wrapped_fn.graph.get_operations():
      if op.type == 'VarHandleOp':
        resources += op.outputs
    if resources:
      for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
        tf.raw_ops.DestroyResourceOp(resource=resource)

    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(arg, p, w)

  # pylint: disable=function-redefined
  if must_pin_function_to_cpu:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device('cpu'):
        return old_fn_to_return(x)
  elif device is not None:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device(device.name):
        return old_fn_to_return(x)

  # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
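
The device-pinning redefinitions above follow a simple wrapper pattern; a
generic sketch (hypothetical helper, not part of the original module):

# Sketch only: wrap a one-argument callable so it runs under a fixed device.
def _pin_to_device(fn, device_name):
  def wrapped(x):
    with tf.device(device_name):
      return fn(x)
  return wrapped

# e.g. fn_to_return = _pin_to_device(fn_to_return, 'cpu')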