def test_extract_tensor_names_from_binding_with_tuple_of_tensors(self):
  """Every tensor inside a named-tuple binding contributes its tensor name."""
  with tf.Graph().as_default() as graph:
    # Build the structure element by element; 'A' is created before 'B'.
    named_tensors = collections.OrderedDict()
    named_tensors['foo'] = tf.constant(10, name='A')
    named_tensors['bar'] = tf.constant(20, name='B')
    _, binding = graph_utils.capture_result_from_graph(named_tensors, graph)
  names = graph_utils.extract_tensor_names_from_binding(binding)
  self.assertEqual(str(sorted(names)), '[\'A:0\', \'B:0\']')
def _tensor_name_to_node_name(tensor_name):
  """Strips the trailing `:<output index>` from a tensor name, if present.

  Args:
    tensor_name: A string such as `'A:0'` (a tensor name) or `'foo'` (a name
      that carries no output-index suffix).

  Returns:
    The corresponding graph node name, e.g. `'A'` for `'A:0'`. A name that
    contains no `:` is returned unchanged instead of being collapsed to the
    empty string.
  """
  node_name, separator, _ = tensor_name.rpartition(':')
  return node_name if separator else tensor_name


def prune_tensorflow_proto(proto):
  """Extracts subgraph from `proto` preserving parameter, result and initialize.

  Args:
    proto: Instance of `pb.Computation` of the `tensorflow` variety whose
      `graphdef` attribute we wish to prune of extraneous ops.

  Returns:
    A transformed instance of `pb.Computation` of the `tensorflow` variety,
    whose `graphdef` attribute contains only ops which can reach the
    parameter or result bindings, or initialize op. The type, bindings and
    initialize op name are copied unchanged from `proto`.

  Raises:
    TypeError: If `proto` is not a `pb.Computation`, or is not of the
      `tensorflow` variety.
  """
  py_typecheck.check_type(proto, pb.Computation)
  computation_oneof = proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError(
        '`prune_tensorflow_proto` only accepts `Computation` '
        'protos of the \'tensorflow\' variety; you have passed '
        'one of variety {}.'.format(computation_oneof))
  tf_proto = proto.tensorflow

  # Bindings record tensor names (`node:index`); `extract_sub_graph` expects
  # node names, so the output index is dropped where present.
  if tf_proto.parameter.WhichOneof('binding'):
    parameter_names = [
        _tensor_name_to_node_name(name)
        for name in graph_utils.extract_tensor_names_from_binding(
            tf_proto.parameter)
    ]
  else:
    parameter_names = []
  return_names = [
      _tensor_name_to_node_name(name)
      for name in graph_utils.extract_tensor_names_from_binding(
          tf_proto.result)
  ]

  graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
  init_op_name = tf_proto.initialize_op
  names_to_preserve = parameter_names + return_names
  if init_op_name:
    # The initialize op is already a node name, so it is kept verbatim.
    names_to_preserve.append(init_op_name)
  subgraph_def = tf.compat.v1.graph_util.extract_sub_graph(
      graph_def, names_to_preserve)

  tf_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(subgraph_def),
      initialize_op=init_op_name,
      parameter=tf_proto.parameter,
      result=tf_proto.result)
  return pb.Computation(type=proto.type, tensorflow=tf_block)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Wraps a serialized TensorFlow computation as an eagerly callable function.

  Args:
    comp: An instance of `pb.Computation`, expected to be of the `tensorflow`
      variety.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      When supplied it must be equivalent to the type serialized in `comp`;
      when omitted, the serialized type is used.
    device: An optional device name. Only `None` is currently supported.

  Returns:
    A callable that executes the computation in eager mode. It accepts a
    single argument when the computation's type is a function type with a
    parameter, and no arguments otherwise.

  Raises:
    NotImplementedError: If `device` is not `None`.
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation, or its type does not match `type_spec`.
  """
  # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
  # it deals exclusively with eager mode. Incubate here, and potentially move
  # there, once stable.
  if device is not None:
    raise NotImplementedError(
        'Unable to embed TF code on a specific device.')

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is None:
    type_spec = comp_type
  elif not type_utils.are_equivalent_types(type_spec, comp_type):
    raise TypeError(
        'Expected a computation of type {}, got {}.'.format(
            str(type_spec), str(comp_type)))

  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  # A non-function type is treated as a parameterless computation whose
  # result has that type.
  is_function = isinstance(type_spec, computation_types.FunctionType)
  param_type = type_spec.parameter if is_function else None
  result_type = type_spec.result if is_function else type_spec
  takes_arg = param_type is not None

  tf_block = comp.tensorflow
  if takes_arg:
    input_tensor_names = graph_utils.extract_tensor_names_from_binding(
        tf_block.parameter)
  else:
    input_tensor_names = []
  output_tensor_names = graph_utils.extract_tensor_names_from_binding(
      tf_block.result)

  def import_into_wrapped_graph(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          str(len(input_tensor_names)), str(len(args))))
    graph_def = serialization_utils.unpack_graph_def(tf_block.graph_def)
    init_op = tf_block.initialize_op
    init_names = [init_op] if init_op else []
    imported = tf.import_graph_def(
        graph_merge.uniquify_shared_names(graph_def),
        input_map=dict(zip(input_tensor_names, args)),
        return_elements=output_tensor_names + init_names)
    if not init_names:
      return imported
    # The init op is the last returned element; every output is routed
    # through an identity that carries a control dependency on it.
    with tf.control_dependencies([imported[-1]]):
      return [tf.identity(t) for t in imported[:-1]]

  def identity(x):
    return x

  signature = []
  param_fns = []
  if takes_arg:
    for spec in anonymous_tuple.flatten(param_type):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(identity)
      else:
        # Datasets cross the wrapped-function boundary as scalar variants.
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)
  wrapped_fn = tf.compat.v1.wrap_function(import_into_wrapped_graph, signature)

  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(identity)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      element_structure = type_utils.type_to_tf_structure(spec.element)
      # Bind the structure as a default so each closure keeps its own value.
      result_fns.append(
          lambda x, s=element_structure: tf.data.experimental.from_variant(
              x, s))

  def invoke(arg):
    flat_inputs = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            str(len(param_fns)), str(len(arg_parts))))
      flat_inputs = [fn(part) for part, fn in zip(arg_parts, param_fns)]
    flat_outputs = wrapped_fn(*flat_inputs)
    return anonymous_tuple.pack_sequence_as(
        result_type,
        [fn(part) for part, fn in zip(flat_outputs, result_fns)])

  if takes_arg:
    return lambda arg: invoke(arg)  # pylint: disable=unnecessary-lambda
  return lambda: invoke(None)
def deserialize_and_call_tf_computation(computation_proto, arg, graph):
  """Deserializes a TF computation and inserts it into `graph`.

  This method performs an action that can be considered roughly the opposite
  of what `tensorflow_serialization.serialize_py_fn_as_tf_computation` does.
  At the moment, it simply imports the graph in the current context. A future
  implementation may rely on different mechanisms. The caller should not be
  concerned with the specifics of the implementation. At this point, the
  method is expected to only be used within the body of another TF
  computation (within an instance of
  `tf_computation_context.TensorFlowComputationContext` at the top of the
  stack), and potentially also in certain types of interpreted execution
  contexts (TBD).

  Args:
    computation_proto: An instance of `pb.Computation` with the `computation`
      one of equal to `tensorflow` to be deserialized and called.
    arg: The argument to invoke the computation with, or None if the
      computation does not specify a parameter type and does not expect one.
    graph: The graph to stamp into.

  Returns:
    A tuple (init_op, result) where:
       init_op: The `tf.Operation` that `tf.import_graph_def` returned for the
         computation's initialize op after importing it into `graph`, or
         `None` if the computation declares no initialize op.
       result: The results to be fetched from TensorFlow. Depending on the
         type of the result, this can be `tf.Tensor` or `tf.data.Dataset`
         instances, or a nested structure (such as an
         `anonymous_tuple.AnonymousTuple`).

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `computation_proto` is not a TensorFlow computation proto.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  computation_oneof = computation_proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise ValueError('Expected a TensorFlow computation, got {}.'.format(
        computation_oneof))
  py_typecheck.check_type(graph, tf.Graph)
  with graph.as_default():
    type_spec = type_serialization.deserialize_type(computation_proto.type)
    if type_spec.parameter is None:
      if arg is not None:
        raise TypeError(
            'The computation declared no parameters; encountered an unexpected '
            'argument {}.'.format(str(arg)))
      input_map = None
    elif arg is None:
      raise TypeError(
          'The computation declared a parameter of type {}, but the argument '
          'was not supplied.'.format(str(type_spec.parameter)))
    else:
      arg_type, arg_binding = graph_utils.capture_result_from_graph(
          arg, graph)
      if not type_utils.is_assignable_from(type_spec.parameter, arg_type):
        raise TypeError(
            'The computation declared a parameter of type {}, but the argument '
            'is of a mismatching type {}.'.format(
                str(type_spec.parameter), str(arg_type)))
      # Map each parameter tensor name in the serialized graph to the tensor
      # in `graph` that carries the corresponding part of `arg`.
      input_map = {
          k: graph.get_tensor_by_name(v)
          for k, v in six.iteritems(
              graph_utils.compute_map_from_bindings(
                  computation_proto.tensorflow.parameter, arg_binding))
      }
    # Copy before appending so the list handed back by the helper is never
    # mutated here.
    return_elements = list(
        graph_utils.extract_tensor_names_from_binding(
            computation_proto.tensorflow.result))
    orig_init_op_name = computation_proto.tensorflow.initialize_op
    if orig_init_op_name:
      return_elements.append(orig_init_op_name)
    # N. B. Unlike MetaGraphDef, the GraphDef alone contains no information
    # about collections, and hence, when we import a graph with Variables,
    # those Variables are not added to global collections, and hence
    # functions like tf.compat.v1.global_variables_initializers() will not
    # contain their initialization ops.
    output_tensors = tf.import_graph_def(
        serialization_utils.unpack_graph_def(
            computation_proto.tensorflow.graph_def),
        input_map,
        return_elements,
        # N. B. It is very important not to return any names from the original
        # computation_proto.tensorflow.graph_def, those names might or might
        # not be valid in the current graph. Using a different scope makes the
        # graph somewhat more readable, since _N style de-duplication of graph
        # node names is less likely to be needed.
        name='subcomputation')

    # Key the imported objects by their names in the serialized graph; the
    # init op (if any) is split off so only result tensors remain.
    output_map = dict(zip(return_elements, output_tensors))
    new_init_op = output_map.pop(orig_init_op_name, None)
    return (new_init_op,
            graph_utils.assemble_result_from_graph(
                type_spec.result, computation_proto.tensorflow.result,
                output_map))
def test_extract_tensor_names_from_binding_with_sequence(self):
  """A sequence binding contributes its iterator string handle name."""
  sequence_binding = pb.TensorFlow.SequenceBinding(
      iterator_string_handle_name='foo')
  names = graph_utils.extract_tensor_names_from_binding(
      pb.TensorFlow.Binding(sequence=sequence_binding))
  self.assertEqual(str(sorted(names)), '[\'foo\']')