def test_returns_canonical_form_from_tff_learning_structure(self):
  """Round-trips a learning process through `CanonicalForm` and compares runs.

  The original and the reconstructed iterative process are each run for one
  round on the same client data. Their type signatures, server state and
  metrics are then compared.
  """
  original_process = test_utils.construct_example_training_comp()
  canonical = canonical_form_utils.get_canonical_form_for_iterative_process(
      original_process)
  rebuilt_process = (
      canonical_form_utils.get_iterative_process_for_canonical_form(canonical))
  self.assertIsInstance(canonical, canonical_form.CanonicalForm)
  self.assertEqual(original_process.initialize.type_signature,
                   rebuilt_process.initialize.type_signature)
  # The full `next` signatures need not be equal: an empty tuple may have been
  # appended as the client side-channel output if none existed. Hence only the
  # parameter and the first two result components are compared.
  original_next_type = original_process.next.type_signature
  rebuilt_next_type = rebuilt_process.next.type_signature
  self.assertEqual(original_next_type.parameter, rebuilt_next_type.parameter)
  for position in (0, 1):
    self.assertEqual(original_next_type.result[position],
                     rebuilt_next_type.result[position])

  original_initial_state = original_process.initialize()
  rebuilt_initial_state = rebuilt_process.initialize()
  batch = collections.OrderedDict(
      x=np.array([[1., 1.]], dtype=np.float32),
      y=np.array([[0]], dtype=np.int32))
  client_data = [batch]

  def summarize(round_output):
    """Splits one round's output into state plus names/leaves of each part."""
    server_state = round_output[0]
    round_metrics = round_output[1]
    return (
        server_state,
        anonymous_tuple.name_list(server_state),
        anonymous_tuple.flatten(server_state),
        [entry[0] for entry in anonymous_tuple.iter_elements(round_metrics)],
        anonymous_tuple.flatten(round_metrics),
    )

  (state, state_names, state_arrays, metrics_names,
   metrics_arrays) = summarize(
       original_process.next(original_initial_state, [client_data]))
  (_, alt_state_names, alt_state_arrays, alt_metrics_names,
   alt_metrics_arrays) = summarize(
       rebuilt_process.next(rebuilt_initial_state, [client_data]))

  self.assertEmpty(state.delta_aggregate_state)
  self.assertEmpty(state.model_broadcast_state)
  self.assertAllEqual(state_names, alt_state_names)
  self.assertAllEqual(metrics_names, alt_metrics_names)
  self.assertAllClose(state_arrays, alt_state_arrays)
  self.assertAllClose(metrics_arrays[:2], alt_metrics_arrays[:2])
  # The final metric is execution time, so only approximate agreement is
  # checked.
  # NOTE(review): two separate runs agreeing on wall-clock time within 1e-3
  # may be flaky — confirm this tolerance is intended.
  self.assertAlmostEqual(metrics_arrays[2], alt_metrics_arrays[2], delta=1e-3)
def test_canonical_form_with_learning_structure_does_not_change_execution_of_iterative_process(
    self):
  """Checks that a CanonicalForm round trip preserves one training round."""
  original = construct_example_training_comp()
  mapreduce = tff.backends.mapreduce
  canonical = mapreduce.get_canonical_form_for_iterative_process(original)
  rebuilt = mapreduce.get_iterative_process_for_canonical_form(canonical)

  self.assertEqual(original.initialize.type_signature,
                   rebuilt.initialize.type_signature)
  # The complete `next` signatures may differ, because an empty tuple may have
  # been appended as the client side-channel output if none existed; only the
  # parameter and the first two result components are compared.
  original_next_type = original.next.type_signature
  rebuilt_next_type = rebuilt.next.type_signature
  self.assertEqual(original_next_type.parameter, rebuilt_next_type.parameter)
  self.assertEqual(original_next_type.result[0], rebuilt_next_type.result[0])
  self.assertEqual(original_next_type.result[1], rebuilt_next_type.result[1])

  batch = collections.OrderedDict(
      x=np.array([[1., 1.]], dtype=np.float32),
      y=np.array([[0]], dtype=np.int32),
  )
  client_data = [batch]

  def element_names(struct):
    """Returns the (possibly `None`) names of the top-level elements."""
    return [entry[0] for entry in anonymous_tuple.iter_elements(struct)]

  state_1 = original.initialize()
  server_state_1, server_output_1 = original.next(state_1, [client_data])
  state_names_1 = anonymous_tuple.name_list(server_state_1)
  state_arrays_1 = anonymous_tuple.flatten(server_state_1)
  output_names_1 = element_names(server_output_1)
  output_arrays_1 = anonymous_tuple.flatten(server_output_1)

  state_2 = rebuilt.initialize()
  # The rebuilt `next` returns a third value (presumably the side-channel
  # output mentioned above), which is ignored here.
  server_state_2, server_output_2, _ = rebuilt.next(state_2, [client_data])
  state_names_2 = anonymous_tuple.name_list(server_state_2)
  state_arrays_2 = anonymous_tuple.flatten(server_state_2)
  output_names_2 = element_names(server_output_2)
  output_arrays_2 = anonymous_tuple.flatten(server_output_2)

  self.assertEmpty(server_state_1.delta_aggregate_state)
  self.assertEmpty(server_state_1.model_broadcast_state)
  self.assertAllEqual(state_names_1, state_names_2)
  self.assertAllEqual(output_names_1, output_names_2)
  self.assertAllClose(state_arrays_1, state_arrays_2)
  self.assertAllClose(output_arrays_1[:2], output_arrays_2[:2])
  # The third output element is the execution time; compare with a tolerance.
  # NOTE(review): wall-clock times from two runs agreeing within 1e-3 may be
  # flaky — confirm the tolerance.
  execution_time_1 = output_arrays_1[2]
  execution_time_2 = output_arrays_2[2]
  self.assertAlmostEqual(execution_time_1, execution_time_2, delta=1e-3)
def clip_by_global_norm(delta, clip_norm):
  """Clips all leaves of `delta` jointly by their global norm.

  Args:
    delta: A structure accepted by `anonymous_tuple.from_container`.
    clip_norm: The clipping threshold forwarded to `tf.clip_by_global_norm`.

  Returns:
    A pair `(clipped, global_norm)`. `clipped` holds the clipped leaves,
    repacked with `anonymous_tuple.pack_sequence_as` into the structure of
    `anonymous_tuple.from_container(delta)`. `global_norm` is the norm
    reported by `tf.clip_by_global_norm`.
  """
  # TODO(b/123092620): Replace anonymous_tuple with tf.nest.
  structure = anonymous_tuple.from_container(delta)
  leaves = anonymous_tuple.flatten(structure)
  clipped_leaves, global_norm = tf.clip_by_global_norm(leaves, clip_norm)
  clipped = anonymous_tuple.pack_sequence_as(structure, clipped_leaves)
  return clipped, global_norm
def assign(target, source):
  """Creates an op that assigns `target` from `source`.

  This utility function provides the exact same behavior as
  `tf.Variable.assign`, but it generalizes to a wider class of objects,
  including ordinary variables as well as various types of nested structures.

  Args:
    target: A nested structure composed of variables embedded in containers
      that are compatible with `tf.nest`, or instances of
      `anonymous_tuple.AnonymousTuple`.
    source: A nested structure composed of tensors, matching that of `target`.

  Returns:
    A single op that represents the assignment.

  Raises:
    TypeError: If types mismatch.
  """
  # TODO(b/113112108): Extend this to containers of mixed types.
  # Choose the structure utilities based on the top-level container of
  # `target`; both branches do the same per-leaf `assign` and group the ops.
  if isinstance(target, anonymous_tuple.AnonymousTuple):
    map_structure = anonymous_tuple.map_structure
    flatten = anonymous_tuple.flatten
  else:
    map_structure = tf.nest.map_structure
    flatten = tf.nest.flatten
  assign_ops = map_structure(lambda var, value: var.assign(value), target,
                             source)
  return tf.group(*flatten(assign_ops))
def _fn_to_return(arg, param_fns, wrapped_fn):
  """Invokes `wrapped_fn` on `arg`, frees its variables, and repacks results.

  `result_fns` and `result_type` are free variables that the enclosing scope
  is expected to provide.

  Args:
    arg: The argument structure, or `None` for a no-argument call.
    param_fns: One conversion callable per flattened element of `arg`.
    wrapped_fn: The wrapped TensorFlow function to invoke.

  Returns:
    The converted results packed into the structure of `result_type`.

  Raises:
    RuntimeError: If the number of flattened arguments differs from the number
      of `param_fns`.
  """
  if arg is None:
    converted_args = []
  else:
    flat_args = anonymous_tuple.flatten(arg)
    if len(flat_args) != len(param_fns):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          len(param_fns), len(flat_args)))
    converted_args = [
        to_param(flat_arg) for flat_arg, to_param in zip(flat_args, param_fns)
    ]
  raw_results = wrapped_fn(*converted_args)

  # Workaround for b/144127474: variables created by tf.import_graph_def(...)
  # inside tf.wrap_function(...) are not destroyed automatically, so every
  # `VarHandleOp` output in the wrapped graph is fetched and destroyed here.
  # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
  # is fixed.
  variable_handles = [
      handle
      for op in wrapped_fn.graph.get_operations()
      if op.type == 'VarHandleOp'
      for handle in op.outputs
  ]
  if variable_handles:
    fetch_handles = wrapped_fn.prune(feeds={}, fetches=variable_handles)
    for handle in fetch_handles():
      tf.raw_ops.DestroyResourceOp(resource=handle)

  converted_results = [
      from_result(raw) for raw, from_result in zip(raw_results, result_fns)
  ]
  return anonymous_tuple.pack_sequence_as(result_type, converted_results)
def test_canonical_form_with_learning_structure_does_not_change_execution_of_iterative_process(
    self):
  """Checks that a CanonicalForm round trip preserves one training round."""
  original = construct_example_training_comp()
  mapreduce = tff.backends.mapreduce
  canonical = mapreduce.get_canonical_form_for_iterative_process(original)
  rebuilt = mapreduce.get_iterative_process_for_canonical_form(canonical)

  original.initialize.type_signature.check_equivalent_to(
      rebuilt.initialize.type_signature)
  # Only type equivalence (not equality) is required for the `next`
  # parameter and result.
  # NOTE(review): the earlier comment here said an empty side-channel output
  # might be appended, but the full result is checked and `next` is unpacked
  # into two values below — confirm that remark is stale.
  original_next_type = original.next.type_signature
  rebuilt_next_type = rebuilt.next.type_signature
  original_next_type.parameter.check_equivalent_to(rebuilt_next_type.parameter)
  original_next_type.result.check_equivalent_to(rebuilt_next_type.result)

  batch = collections.OrderedDict(
      x=np.array([[1., 1.]], dtype=np.float32),
      y=np.array([[0]], dtype=np.int32),
  )
  client_data = [batch]

  state_1 = original.initialize()
  server_state_1, server_output_1 = original.next(state_1, [client_data])
  # Normalize the original process's outputs to anonymous tuples so they can
  # be compared structurally against the rebuilt process's outputs.
  server_state_1, server_output_1 = (
      anonymous_tuple.from_container(server_state_1, recursive=True),
      anonymous_tuple.from_container(server_output_1, recursive=True),
  )
  state_arrays_1 = anonymous_tuple.flatten(server_state_1)
  output_arrays_1 = anonymous_tuple.flatten(server_output_1)

  state_2 = rebuilt.initialize()
  server_state_2, server_output_2 = rebuilt.next(state_2, [client_data])
  state_arrays_2 = anonymous_tuple.flatten(server_state_2)
  output_arrays_2 = anonymous_tuple.flatten(server_output_2)

  self.assertEmpty(server_state_1.delta_aggregate_state)
  self.assertEmpty(server_state_1.model_broadcast_state)
  # Structures must match exactly; leaf values are compared with a tolerance
  # because floating point results may differ slightly.
  self.assertTrue(
      anonymous_tuple.is_same_structure(server_state_1, server_state_2))
  self.assertTrue(
      anonymous_tuple.is_same_structure(server_output_1, server_output_2))
  self.assertAllClose(state_arrays_1, state_arrays_2)
  self.assertAllClose(output_arrays_1[:2], output_arrays_2[:2])
def fetch_value_in_session(sess, value):
  """Fetches `value` in `sess`.

  Args:
    sess: A `tf.compat.v1.Session` in which to perform the fetch.
    value: A Python object of a form analogous to that constructed by the
      function `assemble_result_from_graph`, made of tensors and anonymous
      tuples, or a `tf.data.Dataset`.

  Returns:
    A Python object with structure similar to `value`, but with tensors
    replaced with their values, and data sets replaced with lists of their
    elements. All tensor leaves of `value` are fetched together with a single
    `sess.run()` call. Each data set is drained with one `sess.run()` call per
    element.

  Raises:
    ValueError: If `value` is not a `tf.data.Dataset` or not a structure of
      tensors and anonymous tuples.
  """
  # `tf.Session` only exists in TF 1.x. The `tf.compat.v1` alias, already used
  # below for `make_one_shot_iterator`, names the same class there and keeps
  # working under TF 2.x.
  py_typecheck.check_type(sess, tf.compat.v1.Session)
  # TODO(b/113123634): Investigate handling `list`s and `tuple`s of
  # `tf.data.Dataset`s and what the API would look like to support this.
  if isinstance(value, DATASET_REPRESENTATION_TYPES):
    with sess.graph.as_default():
      iterator = tf.compat.v1.data.make_one_shot_iterator(value)
      next_element = iterator.get_next()
    elements = []
    while True:
      try:
        elements.append(sess.run(next_element))
      except tf.errors.OutOfRangeError:
        break
    return elements

  flattened_value = anonymous_tuple.flatten(value)
  # Data sets are fetched recursively and keyed by their flat position. All
  # tensors are gathered so they can be fetched in one `sess.run()`.
  dataset_results = {}
  flat_tensors = []
  for idx, v in enumerate(flattened_value):
    if isinstance(v, DATASET_REPRESENTATION_TYPES):
      dataset_results[idx] = fetch_value_in_session(sess, v)
    elif tf.is_tensor(v):
      flat_tensors.append(v)
    else:
      raise ValueError('Unsupported value type {}.'.format(str(v)))
  flat_computed_tensors = sess.run(flat_tensors)
  flattened_results = _interleave_dataset_results_and_tensors(
      dataset_results, flat_computed_tensors)

  def _to_unicode(v):
    if six.PY3 and isinstance(v, bytes):
      return v.decode('utf-8')
    return v

  # Only a top-level string tensor is decoded to `str`.
  # NOTE(review): string tensors nested inside a structure are returned as
  # `bytes` — confirm this asymmetry is intended.
  if tf.is_tensor(value) and value.dtype == tf.string:
    flattened_results = [_to_unicode(result) for result in flattened_results]
  return anonymous_tuple.pack_sequence_as(value, flattened_results)
def _fn_to_return(arg, param_fns, wrapped_fn):
  """Invokes `wrapped_fn` on the flattened `arg` and repacks its results.

  `result_fns` and `result_type` are free variables that the enclosing scope
  is expected to provide.

  Args:
    arg: The argument structure, or `None` for a no-argument call.
    param_fns: One conversion callable per flattened element of `arg`.
    wrapped_fn: The wrapped TensorFlow function to invoke.

  Returns:
    The converted results packed into the structure of `result_type`.

  Raises:
    RuntimeError: If the number of flattened arguments differs from the number
      of `param_fns`.
  """
  if arg is None:
    converted_args = []
  else:
    flat_args = anonymous_tuple.flatten(arg)
    if len(flat_args) != len(param_fns):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          len(param_fns), len(flat_args)))
    converted_args = [
        to_param(flat_arg) for flat_arg, to_param in zip(flat_args, param_fns)
    ]
  raw_results = wrapped_fn(*converted_args)
  converted_results = [
      from_result(raw) for raw, from_result in zip(raw_results, result_fns)
  ]
  return anonymous_tuple.pack_sequence_as(result_type, converted_results)
def test_flatten_and_pack_sequence_as(self):
  """Flattening lists leaves depth-first; packing restores the structure."""
  make = anonymous_tuple.AnonymousTuple
  inner_x = make([('p', 40)])
  inner_z = make([('q', 50), ('r', 60)])
  middle = make([('x', inner_x), ('y', 30), ('z', inner_z)])
  nested = make([('a', 10), ('b', middle), ('c', 20)])

  leaves = anonymous_tuple.flatten(nested)
  self.assertEqual(leaves, [10, 40, 30, 50, 60, 20])

  repacked = anonymous_tuple.pack_sequence_as(nested, leaves)
  self.assertEqual(
      str(repacked), '<a=10,b=<x=<p=40>,y=30,z=<q=50,r=60>>,c=20>')
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If given, it must be equivalent to the type recorded in `comp`.
    device: An optional device name. Not supported by this implementation;
      passing anything other than `None` raises `NotImplementedError`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode. The one-argument form is returned when the
    computation's type is a function type with a parameter.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
    NotImplementedError: If `device` is not `None`.
  """
  # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
  # it deals exclusively with eager mode. Incubate here, and potentially move
  # there, once stable.
  if device is not None:
    raise NotImplementedError(
        'Unable to embed TF code on a specific device.')

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_utils.are_equivalent_types(type_spec, comp_type):
      raise TypeError(
          'Expected a computation of type {}, got {}.'.format(
              str(type_spec), str(comp_type)))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  # A non-functional type is treated as a no-argument computation whose result
  # has that type.
  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = graph_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = graph_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  # Imports the computation's graph, mapping `args` positionally onto the
  # parameter tensor names, and returns the result tensors.
  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          str(len(input_tensor_names)), str(len(args))))
    graph_def = serialization_utils.unpack_graph_def(
        comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    # When an initializer exists it is requested as the last returned element,
    # so the outputs can be made to depend on it below.
    init_names = [init_op] if init_op else []
    returned_elements = tf.import_graph_def(
        graph_merge.uniquify_shared_names(graph_def),
        input_map=dict(zip(input_tensor_names, args)),
        return_elements=output_tensor_names + init_names)
    if init_names:
      # Wrap each output in `tf.identity` under a control dependency on the
      # initializer, and drop the initializer itself from the returned list.
      with tf.control_dependencies([returned_elements[-1]]):
        return [tf.identity(x) for x in returned_elements[0:-1]]
    else:
      return returned_elements

  # Build the wrapped function's input signature and per-argument converters:
  # tensors pass through unchanged, sequences are passed as variant tensors.
  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  # Per-result converters: tensors pass through, variant tensors are turned
  # back into data sets with the element structure from the sequence type.
  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_utils.type_to_tf_structure(spec.element)

      # `structure` is bound as a default so each closure keeps its own value.
      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  # Flattens `arg`, converts each part, invokes `wrapped_fn`, converts each
  # result part and repacks the results into the shape of `result_type`.
  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            str(len(param_fns)), str(len(arg_parts))))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)
    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
      arg, p, w)

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If given, it must be equivalent to the type recorded in `comp`.
    device: An optional device name. When given, both the graph import and
      every call of the returned callable run under `tf.device(device)`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode. The one-argument form is returned when the
    computation's type is a function type with a parameter.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_utils.are_equivalent_types(type_spec, comp_type):
      raise TypeError(
          'Expected a computation of type {}, got {}.'.format(
              type_spec, comp_type))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  # A non-functional type is treated as a no-argument computation whose result
  # has that type.
  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  # Imports the computation's graph, mapping `args` positionally onto the
  # parameter tensor names, and returns the result tensors.
  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          len(input_tensor_names), len(args)))
    graph_def = serialization_utils.unpack_graph_def(
        comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    # If the computation has an initializer, rewrite the graph so the results
    # carry control dependencies on it (done by the helper below).
    if init_op:
      graph_def = tensorflow_utils.add_control_deps_for_init_op(
          graph_def, init_op)

    def _import_fn():
      return tf.import_graph_def(
          graph_merge.uniquify_shared_names(graph_def),
          input_map=dict(list(zip(input_tensor_names, args))),
          return_elements=output_tensor_names)

    if device is not None:
      with tf.device(device):
        return _import_fn()
    else:
      return _import_fn()

  # Build the wrapped function's input signature and per-argument converters:
  # tensors pass through unchanged, sequences are passed as variant tensors.
  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  # Per-result converters: tensors pass through, variant tensors are turned
  # back into data sets with the element structure from the sequence type.
  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_utils.type_to_tf_structure(spec.element)

      # `structure` is bound as a default so each closure keeps its own value.
      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  # Flattens `arg`, converts each part, invokes `wrapped_fn`, destroys the
  # graph's variables, then converts and repacks the results.
  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            len(param_fns), len(arg_parts)))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)

    # There is a tf.wrap_function(...) issue b/144127474 that variables created
    # from tf.import_graph_def(...) inside tf.wrap_function(...) is not
    # destroyed. So get all the variables from `wrapped_fn` and destroy
    # manually.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    resources = []
    for op in wrapped_fn.graph.get_operations():
      if op.type == 'VarHandleOp':
        resources += op.outputs
    if resources:
      for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
        tf.raw_ops.DestroyResourceOp(resource=resource)

    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
      arg, p, w)

  # When a device is requested, each invocation also runs under that device
  # scope.
  if device is not None:
    old_fn_to_return = fn_to_return

    # pylint: disable=function-redefined
    def fn_to_return(x):
      with tf.device(device):
        return old_fn_to_return(x)

    # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If given, it must be equivalent to the type recorded in `comp`.
    device: An optional `tf.config.LogicalDevice`. When given, and the result
      contains no sequence types, the graph is imported and every call runs
      under `tf.device(device.name)`. Results containing sequences are always
      pinned to CPU instead.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode. The one-argument form is returned when the
    computation's type is a function type with a parameter.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation. Also raised if the parameter/result tensor names
      recorded in `comp` cannot be found in the imported graph.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_spec.is_equivalent_to(comp_type):
      raise TypeError(
          'Expected a computation of type {}, got {}.'.format(
              type_spec, comp_type))
  else:
    type_spec = comp_type

  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    # A non-functional type is treated as a no-argument computation whose
    # result has that type.
    param_type = None
    result_type = type_spec

  # TODO(b/155198591): Currently, TF will raise on any function returning a
  # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
  # this gating when we can.
  # Inspect `result_type` rather than `type_spec.result`: the latter exists
  # only for `FunctionType`s, and would raise `AttributeError` for the
  # non-functional types accepted by the branch above.
  must_pin_function_to_cpu = type_analysis.contains_types(
      result_type, computation_types.SequenceType)

  if param_type is not None:
    input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  def function_to_wrap():
    """No-arg function to import graph def.

    We pass a no-arg function to `tf.compat.v1.wrap_function` to avoid the
    leftover placeholders that can result from binding arguments to the
    imported graphdef via `input_map`. The correct signature will be added to
    this function later, via the `prune` call below.

    Returns:
      Result of importing graphdef backing `comp`.
    """
    graph_def = serialization_utils.unpack_graph_def(
        comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    if init_op:
      graph_def = tensorflow_utils.add_control_deps_for_init_op(
          graph_def, init_op)

    def _import_fn():
      return tf.import_graph_def(
          graph_merge.uniquify_shared_names(graph_def), name='')

    if must_pin_function_to_cpu:
      with tf.device('cpu'):
        return _import_fn()
    elif device is not None:
      with tf.device(device.name):
        return _import_fn()
    else:
      return _import_fn()

  # Per-argument converters: tensors pass through unchanged, sequences are
  # passed as variant tensors.
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(param_type):
      if isinstance(spec, computation_types.TensorType):
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_noarg_fn = tf.compat.v1.wrap_function(
      function_to_wrap, signature=[])
  import_graph = wrapped_noarg_fn.graph
  try:
    # `prune` gives the wrapped function its real inputs (feeds) and outputs
    # (fetches), looked up by name in the imported graph.
    wrapped_fn = wrapped_noarg_fn.prune(
        feeds=tf.nest.map_structure(import_graph.as_graph_element,
                                    input_tensor_names),
        fetches=tf.nest.map_structure(import_graph.as_graph_element,
                                      output_tensor_names),
    )
  except KeyError as e:
    raise TypeError(
        'Caught exception trying to prune graph `{g}` with '
        'feeds {feeds} and fetches {fetches}. This indicates that these '
        'names may not refer to tensors in the graph. .\nException: {e}'.format(
            g=import_graph,
            feeds=input_tensor_names,
            fetches=output_tensor_names,
            e=e)) from e

  # Per-result converters: tensors pass through, variant tensors are turned
  # back into data sets with the element structure from the sequence type.
  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_conversions.type_to_tf_structure(spec.element)

      # `structure` is bound as a default so each closure keeps its own value.
      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  def _fn_to_return(arg, param_fns, wrapped_fn):
    """Converts `arg`, calls `wrapped_fn`, frees variables, repacks results."""
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            len(param_fns), len(arg_parts)))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)

    # There is a tf.wrap_function(...) issue b/144127474 that variables created
    # from tf.import_graph_def(...) inside tf.wrap_function(...) is not
    # destroyed. So get all the variables from `wrapped_fn` and destroy
    # manually.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    resources = []
    for op in wrapped_fn.graph.get_operations():
      if op.type == 'VarHandleOp':
        resources += op.outputs
    if resources:
      for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
        tf.raw_ops.DestroyResourceOp(resource=resource)

    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
      arg, p, w)

  # Calls run under the same device scope that was used for the graph import.
  # pylint: disable=function-redefined
  if must_pin_function_to_cpu:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device('cpu'):
        return old_fn_to_return(x)
  elif device is not None:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device(device.name):
        return old_fn_to_return(x)
  # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If given, it must be equivalent to the type recorded in `comp`.
    device: An optional `tf.config.LogicalDevice`. When given, and the result
      contains no sequence types, every call runs under
      `tf.device(device.name)`. Results containing sequences are always
      pinned to CPU instead.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode. The one-argument form is returned when the
    computation's type is a function type with a parameter.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_spec.is_equivalent_to(comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          type_spec, comp_type))
  else:
    type_spec = comp_type

  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  if type_spec.is_function():
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    # A non-functional type is treated as a no-argument computation whose
    # result has that type.
    param_type = None
    result_type = type_spec

  # TODO(b/155198591): Currently, TF will raise on any function returning a
  # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
  # this gating when we can.
  # Inspect `result_type` rather than `type_spec.result`: the latter exists
  # only for function types, and would raise `AttributeError` for the
  # non-functional types accepted by the branch above.
  must_pin_function_to_cpu = type_analysis.contains(result_type,
                                                    lambda t: t.is_sequence())

  wrapped_fn = _get_wrapped_function_from_comp(comp, must_pin_function_to_cpu,
                                               param_type, device)

  # Per-argument converters: tensors pass through unchanged, sequences are
  # passed as variant tensors.
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(param_type):
      if spec.is_tensor():
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        param_fns.append(tf.data.experimental.to_variant)

  # Per-result converters: tensors pass through, variant tensors are turned
  # back into data sets with the element structure from the sequence type.
  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if spec.is_tensor():
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_conversions.type_to_tf_structure(spec.element)

      # `structure` is bound as a default so each closure keeps its own value.
      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  def _fn_to_return(arg, param_fns, wrapped_fn):
    """Converts `arg`, calls `wrapped_fn`, frees variables, repacks results."""
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            len(param_fns), len(arg_parts)))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)

    # There is a tf.wrap_function(...) issue b/144127474 that variables created
    # from tf.import_graph_def(...) inside tf.wrap_function(...) is not
    # destroyed. So get all the variables from `wrapped_fn` and destroy
    # manually.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    resources = []
    for op in wrapped_fn.graph.get_operations():
      if op.type == 'VarHandleOp':
        resources += op.outputs
    if resources:
      for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
        tf.raw_ops.DestroyResourceOp(resource=resource)

    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(arg, p, w)

  # Each call runs under a CPU scope when the result contains sequences, or
  # under the requested device otherwise.
  # pylint: disable=function-redefined
  if must_pin_function_to_cpu:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device('cpu'):
        return old_fn_to_return(x)
  elif device is not None:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device(device.name):
        return old_fn_to_return(x)
  # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)