def stripped_op_list_for_graph(graph_def):
  """Build the `OpList` of registered `OpDef`s for the ops a graph uses.

  The result is what goes into the `stripped_op_list` field of `MetaGraphDef`
  and similar protos. A producer can hand it to a consumer, which can then
  call the C++ function `RemoveNewDefaultAttrsFromGraphDef` to improve
  forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` holding the registered `OpDef` of every op used by the graph,
    ordered by op name. Whitelisted internal ops that are not registered are
    left out.

  Raises:
    ValueError: If the graph uses an op that is neither registered nor
      whitelisted.
  """
  # Python counterpart of StrippedOpListForGraph in C++. The Python op
  # registry can differ from the C++ one, so the duplication cannot be removed
  # with swig (at least naively).
  # TODO(irving): Support taking graphs directly.
  used_ops = ops_used_by_graph_def(graph_def)
  registered_ops = op_def_registry.get_registered_ops()

  # Internal ops used by functions are not registered, so they are whitelisted.
  # TODO(irving): Do something better here.
  op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")

  # Phase 1: every used op must be known (registered or whitelisted).
  for op_name in used_ops:
    if op_name in registered_ops or op_name in op_whitelist:
      continue
    raise ValueError(
        "Op %s is used by the graph, but is not registered" % op_name)

  # Phase 2: emit the registered OpDefs in sorted order.
  stripped_defs = []
  for op_name in sorted(used_ops):
    if op_name in registered_ops:
      stripped_defs.append(registered_ops[op_name])
  return op_def_pb2.OpList(op=stripped_defs)
def _add_op_node(op, func):
  """Append a `FunctionDef.Node` describing `op` to `func`.

  The node's return names, argument names, control dependencies and attrs are
  derived from the op's signature (`OpDef`) and its node-def attrs.

  Args:
    op: The op to convert; its `type`, `name`, `inputs`, `outputs` and
      `control_inputs` are read.
    func: The function def whose `node` list receives the new node. It is also
      forwarded to the array/list helper functions.
  """
  fn_node = function_pb2.FunctionDef.Node()
  fn_node.op = op.type

  # pylint: disable=protected-access
  # An op carrying its own `_sig` supplies the signature directly; any other
  # op is looked up by type in the op registry.
  op_def = (op._sig if hasattr(op, "_sig")
            else op_def_registry.get_registered_ops()[op.type])
  # pylint: enable=protected-access
  attrs = _get_node_def_attr(op)

  # Return values: one entry per output arg of the signature.
  if op_def.output_arg:
    next_out = 0
    for out_def in op_def.output_arg:
      if out_def.number_attr:
        # Homogeneous array output spanning `count` consecutive outputs.
        elem_dtype = out_def.type or attrs[out_def.type_attr].type
        count = attrs[out_def.number_attr].i
        fn_node.ret.append(
            _add_output_array(op, next_out, next_out + count, elem_dtype,
                              func))
      elif out_def.type_list_attr:
        # Heterogeneous list output; its length is the length of the type list.
        dtype_lst = attrs[out_def.type_list_attr].list.type
        count = len(dtype_lst)
        fn_node.ret.append(
            _add_output_list(op, next_out, next_out + count, dtype_lst, func))
      else:
        # Plain single-tensor output.
        count = 1
        fn_node.ret.append(
            _make_argname_from_tensor_name(op.outputs[next_out].name))
      next_out += count
  else:
    # No declared outputs: the op's own name is used as the return name.
    fn_node.ret.append(_make_argname_from_tensor_name(op.name))

  # Arguments: one entry per input arg, except type-list inputs, which are
  # expanded into one argument name per underlying tensor.
  next_in = 0
  for in_def in op_def.input_arg:
    if in_def.number_attr:
      elem_dtype = in_def.type or attrs[in_def.type_attr].type
      count = attrs[in_def.number_attr].i
      fn_node.arg.append(
          _add_input_array(op, next_in, next_in + count, elem_dtype, func))
    elif in_def.type_list_attr:
      count = len(attrs[in_def.type_list_attr].list.type)
      fn_node.arg.extend([
          _make_argname_from_tensor_name(op.inputs[pos].name)
          for pos in range(next_in, next_in + count)
      ])
    else:
      count = 1
      fn_node.arg.append(
          _make_argname_from_tensor_name(op.inputs[next_in].name))
    next_in += count

  # Control dependencies.
  fn_node.dep.extend(
      [_make_argname_from_tensor_name(dep.name) for dep in op.control_inputs])

  # Copy every node-def attr onto the function node.
  for attr_name, attr_value in _get_node_def_attr(op).items():
    fn_node.attr[attr_name].CopyFrom(attr_value)

  func.node.extend([fn_node])
def testStripDefaultAttrsInconsistentConsumerDefaults(self):
  """Loads a default-stripped SavedModel under altered consumer op defaults.

  A SavedModel containing a "Complex" op is exported with
  `strip_default_attrs=True`. The registered "Complex" `OpDef` is then edited
  in place twice (defaults removed, then "Tout" default changed) and loading
  is expected to fail each time. The registry entry is restored when the test
  finishes so the edits do not leak into other tests.
  """
  if ops._USE_C_API:  # pylint: disable=protected-access
    return  # TODO(skyewm): get this working

  export_dir = self._get_export_dir(
      "test_strip_default_attrs_no_consumer_defaults")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # Add a graph with two float32 variables and a Complex Op composing them
  # with strip_default_attrs enabled. This must remove the following
  # defaults for the "Complex" Op:
  #   o "T"    : float32.   (input type)
  #   o "Tout" : complex64. (output type)
  with session.Session(graph=ops.Graph()) as sess:
    real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
    imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
    math_ops.complex(real_num, imag_num, name="complex")
    sess.run(variables.global_variables_initializer())
    builder.add_meta_graph_and_variables(
        sess, ["foo"], strip_default_attrs=True)

  # Save the SavedModel to disk in text format.
  builder.save(as_text=True)

  # Snapshot the registered "Complex" OpDef before editing it.
  complex_op_def = op_def_registry.get_registered_ops()["Complex"]
  original_complex_op_def = op_def_pb2.OpDef()
  original_complex_op_def.CopyFrom(complex_op_def)
  # The registry entry is mutated in place below. Restore it when the test
  # ends (pass or fail) so later tests see the unmodified "Complex" OpDef.
  self.addCleanup(complex_op_def.CopyFrom, original_complex_op_def)

  # Update the Op registry to remove defaults for all attrs("T", "Tout") from
  # the "Complex" OpDef.
  for attr_def in complex_op_def.attr:
    attr_def.ClearField("default_value")

  # Loading the SavedModel via the loader must fail because the SavedModel
  # does not have any attr values for the "Complex" node and the current
  # op registry does not have any default values for the "Complex" op.
  sess = session.Session(graph=ops.Graph())
  with self.assertRaisesRegexp(
      ValueError,
      "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
    loader.load(sess, ["foo"], export_dir)

  # Update the Op registry to change the defaults for attr "Tout"
  # (complex64 -> complex128).
  complex_op_def.CopyFrom(original_complex_op_def)
  for attr_def in complex_op_def.attr:
    if attr_def.name == "Tout":
      attr_def.default_value.type = types_pb2.DT_COMPLEX128

  # Loading the SavedModel via the loader must set "Tout" attr_value for the
  # "Complex" node according to the latest defaults (complex128). This is
  # expected to fail the model import as there is no OpKernel registered to
  # handle attrs "T" (float32) and "Tout" (complex128).
  sess = session.Session(graph=ops.Graph())
  with self.assertRaisesRegexp(
      errors.InvalidArgumentError,
      ".*No OpKernel was registered to support Op \'Complex\' with these "
      "attrs..*"):
    loader.load(sess, ["foo"], export_dir)
def _is_array_type_input(op, i):
  """Report whether input `i` of the registered op `op` is an array input.

  Args:
    op: Key into the registered ops (the registry is keyed by op name).
    i: Index into the registered `OpDef`'s `input_arg`.

  Returns:
    False if `op` is not registered; otherwise True exactly when the selected
    input arg has a non-empty `number_attr`.

  Raises:
    TypeError: If `i` is not one of the valid input arg indices.
  """
  registered_ops = op_def_registry.get_registered_ops()
  if op not in registered_ops:
    return False

  input_args = registered_ops[op].input_arg
  if i not in xrange(len(input_args)):
    raise TypeError("Expected arg index "
                    "to be in [0, %d)" % len(input_args))
  return bool(input_args[i].number_attr)
def _is_array_type_input(op, i):
  """Report whether input `i` of the registered op `op` is an array input.

  Args:
    op: Key into the registered ops (the registry is keyed by op name).
    i: Index into the registered `OpDef`'s `input_arg`.

  Returns:
    False if `op` is not registered; otherwise True exactly when the selected
    input arg has a non-empty `number_attr`.

  Raises:
    TypeError: If `i` is not one of the valid input arg indices.
  """
  registered_ops = op_def_registry.get_registered_ops()
  if op not in registered_ops:
    return False

  input_args = registered_ops[op].input_arg
  if i not in xrange(len(input_args)):
    raise TypeError("Expected arg index "
                    "to be in [0, %d)" % len(input_args))
  return bool(input_args[i].number_attr)
def _register_function_ops(func_list):
  """Registers the signature of each function as a custom op.

  This is needed because our checkpoint is saved with ops that are not part
  of TensorFlow. For every function, its signature is stored in the op
  registry under the signature's name and an unknown-shape function is
  registered for that op name.

  Args:
    func_list: Iterable of function objects exposing
      `_create_definition_if_needed()` and `_definition.signature`.
  """
  registry = op_def_registry.get_registered_ops()
  for fn in func_list:
    # pylint: disable=W0212
    fn._create_definition_if_needed()
    signature = fn._definition.signature
    op_name = signature.name
    registry[op_name] = signature
    RegisterShape(op_name)(common_shapes.unknown_shape)
def _register_function_ops(func_list):
  """Registers the signature of each function as a custom op.

  This is needed because our checkpoint is saved with ops that are not part
  of TensorFlow. For every function, its signature is stored in the op
  registry under the signature's name and an unknown-shape function is
  registered for that op name.

  Args:
    func_list: Iterable of function objects exposing
      `_create_definition_if_needed()` and `_definition.signature`.
  """
  registry = op_def_registry.get_registered_ops()
  for fn in func_list:
    # pylint: disable=W0212
    fn._create_definition_if_needed()
    signature = fn._definition.signature
    op_name = signature.name
    registry[op_name] = signature
    RegisterShape(op_name)(common_shapes.unknown_shape)
def _add_op_node(op, func):
  """Append a `FunctionDef.Node` describing `op` to `func`.

  The node's return names, argument names, control dependencies and attrs are
  derived from the op's signature (`OpDef`) and its node-def attrs.

  Args:
    op: The op to convert; its `type`, `name`, `inputs`, `outputs` and
      `control_inputs` are read.
    func: The function def whose `node` list receives the new node. It is also
      forwarded to the array/list helper functions.
  """
  fn_node = function_pb2.FunctionDef.Node()
  fn_node.op = op.type

  # pylint: disable=protected-access
  # An op carrying its own `_sig` supplies the signature directly; any other
  # op is looked up by type in the op registry.
  op_def = (op._sig if hasattr(op, "_sig")
            else op_def_registry.get_registered_ops()[op.type])
  # pylint: enable=protected-access
  attrs = _get_node_def_attr(op)

  # Return values: one entry per output arg of the signature.
  if op_def.output_arg:
    next_out = 0
    for out_def in op_def.output_arg:
      if out_def.number_attr:
        # Homogeneous array output spanning `count` consecutive outputs.
        elem_dtype = out_def.type or attrs[out_def.type_attr].type
        count = attrs[out_def.number_attr].i
        fn_node.ret.append(
            _add_output_array(op, next_out, next_out + count, elem_dtype,
                              func))
      elif out_def.type_list_attr:
        # Heterogeneous list output; its length is the length of the type list.
        dtype_lst = attrs[out_def.type_list_attr].list.type
        count = len(dtype_lst)
        fn_node.ret.append(
            _add_output_list(op, next_out, next_out + count, dtype_lst, func))
      else:
        # Plain single-tensor output.
        count = 1
        fn_node.ret.append(
            _make_argname_from_tensor_name(op.outputs[next_out].name))
      next_out += count
  else:
    # No declared outputs: the op's own name is used as the return name.
    fn_node.ret.append(_make_argname_from_tensor_name(op.name))

  # Arguments: one entry per input arg, except type-list inputs, which are
  # expanded into one argument name per underlying tensor.
  next_in = 0
  for in_def in op_def.input_arg:
    if in_def.number_attr:
      elem_dtype = in_def.type or attrs[in_def.type_attr].type
      count = attrs[in_def.number_attr].i
      fn_node.arg.append(
          _add_input_array(op, next_in, next_in + count, elem_dtype, func))
    elif in_def.type_list_attr:
      count = len(attrs[in_def.type_list_attr].list.type)
      fn_node.arg.extend([
          _make_argname_from_tensor_name(op.inputs[pos].name)
          for pos in range(next_in, next_in + count)
      ])
    else:
      count = 1
      fn_node.arg.append(
          _make_argname_from_tensor_name(op.inputs[next_in].name))
    next_in += count

  # Control dependencies.
  fn_node.dep.extend(
      [_make_argname_from_tensor_name(dep.name) for dep in op.control_inputs])

  # Copy every node-def attr onto the function node.
  for attr_name, attr_value in _get_node_def_attr(op).items():
    fn_node.attr[attr_name].CopyFrom(attr_value)

  func.node.extend([fn_node])
def _stripped_op_list_for_graph(graph_def):
  """Returns the registered OpDefs of the ops used in `graph_def`.

  Both the top-level nodes and the nodes of every library function are
  scanned. Ops that are not in the registry are silently ignored.

  Args:
    graph_def: A `GraphDef` proto.

  Returns:
    An `OpList` with one `OpDef` per distinct registered op, sorted by op
    name.
  """
  registered_ops = op_def_registry.get_registered_ops()

  used_names = set(
      node.op for node in graph_def.node if node.op in registered_ops)
  for function_def in graph_def.library.function:
    used_names.update(
        node.op for node in function_def.node if node.op in registered_ops)

  op_defs = [registered_ops[name] for name in sorted(used_names)]
  return op_def_pb2.OpList(op=op_defs)
def sync():
  """Copies every OpDef reported by the C API into the Python op registry.

  The serialized op list obtained through `c_api.TF_GetAllOpList()` is
  parsed, and each contained `OpDef` is stored in the registered-ops mapping
  under its name, replacing any existing entry with that name.
  """
  all_ops_buffer = c_api.TF_GetAllOpList()
  all_ops = op_def_pb2.OpList()
  all_ops.ParseFromString(c_api.TF_GetBuffer(all_ops_buffer))

  registry = op_def_registry.get_registered_ops()
  for cpp_op_def in all_ops.op:
    # An OpList registered from a gen_*_ops.py does not have any
    # descriptions. Strip them here as well to satisfy validation in
    # register_op_list.
    _remove_non_deprecated_descriptions(cpp_op_def)
    registry[cpp_op_def.name] = cpp_op_def
def list_registered_stateful_ops_without_inputs():
  """Returns the names of registered stateful ops that take no inputs.

  This list is used to identify the ops to be included in the state-graph and
  that are subsequently fed into the apply-graphs.

  Returns:
    A set of strings: the name of every registered op whose `OpDef` has
    `is_stateful` set and an empty `input_arg`.
  """
  stateful_without_inputs = set()
  for op_name, op_def in op_def_registry.get_registered_ops().items():
    if not op_def.is_stateful:
      continue
    if op_def.input_arg:
      continue
    stateful_without_inputs.add(op_name)
  return stateful_without_inputs
def list_registered_stateful_ops_without_inputs():
  """Returns the names of registered stateful ops that take no inputs.

  This list is used to identify the ops to be included in the state-graph and
  that are subsequently fed into the apply-graphs.

  Returns:
    A set of strings: the name of every registered op whose `OpDef` has
    `is_stateful` set and an empty `input_arg`.
  """
  stateful_without_inputs = set()
  for op_name, op_def in op_def_registry.get_registered_ops().items():
    if not op_def.is_stateful:
      continue
    if op_def.input_arg:
      continue
    stateful_without_inputs.add(op_name)
  return stateful_without_inputs
def _get_ref_args(self, node):
  """Collect the tensor names of a node's ref-type outputs.

  The node's op is looked up in the op registry and each of its declared
  output args is checked for `is_ref`.

  Args:
    node: A `NodeDef`.

  Returns:
    A list of strs, one per ref-type output arg: `node.name` for output 0 and
    `"<node.name>:<index>"` for any later output. Empty when the op is not
    registered.
  """
  op_def = op_def_registry.get_registered_ops().get(node.op)
  if not op_def:
    return []
  return [
      node.name if index == 0 else ("%s:%d" % (node.name, index))
      for index, out_arg in enumerate(op_def.output_arg)
      if out_arg.is_ref
  ]
def _get_ref_args(self, node):
  """Collect the tensor names of a node's ref-type outputs.

  The node's op is looked up in the op registry and each of its declared
  output args is checked for `is_ref`.

  Args:
    node: A `NodeDef`.

  Returns:
    A list of strs, one per ref-type output arg: `node.name` for output 0 and
    `"<node.name>:<index>"` for any later output. Empty when the op is not
    registered.
  """
  op_def = op_def_registry.get_registered_ops().get(node.op)
  if not op_def:
    return []
  return [
      node.name if index == 0 else ("%s:%d" % (node.name, index))
      for index, out_arg in enumerate(op_def.output_arg)
      if out_arg.is_ref
  ]
def _strip_graph_default_valued_attrs(meta_graph_def):
  """Removes default-valued attrs from every node def of a MetaGraphDef.

  Nodes of the main graph and nodes inside library functions are both
  processed in place. Nodes whose op is a library function, or whose op is not
  registered, are left untouched. Afterwards
  `meta_info_def.stripped_default_attrs` is set to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer; modified in place.

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def

  # Names of ops that are really library functions.
  function_op_names = set()
  for function_def in graph_def.library.function:
    function_op_names.add(function_def.signature.name)

  # Get all registered ops.
  registered_ops = op_def_registry.get_registered_ops()

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    op_name = node_def.op
    if op_name in function_op_names or op_name not in registered_ops:
      return

    op_def = registered_ops[op_name]
    # Gather the names first so the attr map is not mutated while iterated.
    default_valued = [
        attr_name for attr_name, attr_value in node_def.attr.items()
        if _is_default_attr_value(op_def, attr_name, attr_value)
    ]
    for attr_name in default_valued:
      del node_def.attr[attr_name]

  # Process all NodeDef instances in graph_def.
  for node_def in graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
def _strip_graph_default_valued_attrs(meta_graph_def):
  """Removes default-valued attrs from every node def of a MetaGraphDef.

  Nodes of the main graph and nodes inside library functions are both
  processed in place. Nodes whose op is a library function, or whose op is not
  registered, are left untouched. Afterwards
  `meta_info_def.stripped_default_attrs` is set to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer; modified in place.

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def

  # Names of ops that are really library functions.
  function_op_names = set()
  for function_def in graph_def.library.function:
    function_op_names.add(function_def.signature.name)

  # Get all registered ops.
  registered_ops = op_def_registry.get_registered_ops()

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    op_name = node_def.op
    if op_name in function_op_names or op_name not in registered_ops:
      return

    op_def = registered_ops[op_name]
    # Gather the names first so the attr map is not mutated while iterated.
    default_valued = [
        attr_name for attr_name, attr_value in node_def.attr.items()
        if _is_default_attr_value(op_def, attr_name, attr_value)
    ]
    for attr_name in default_valued:
      del node_def.attr[attr_name]

  # Process all NodeDef instances in graph_def.
  for node_def in graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
def _create_op_def_library(op_protos):
  """Builds an `OpDefLibrary` from `OpDef` protos of already-registered ops.

  Every proto is first checked against the op registry. All of them are then
  placed in a single `OpList`, which is registered and added to a fresh
  library. Previously only the last proto reached the `OpList` (the loop
  variable leaked out of the validation loop), and an empty `op_protos`
  raised `NameError`.

  Args:
    op_protos: Iterable of `OpDef` protos.

  Returns:
    An `OpDefLibrary` to which every proto in `op_protos` has been added.

  Raises:
    LookupError: If a proto names an op that is not in the registry.
  """
  # Materialize once: the protos are walked for validation and again when
  # filling the OpList, which a one-shot iterator would not survive.
  op_protos = list(op_protos)

  registered_ops = _registry.get_registered_ops()
  for op_proto in op_protos:
    if op_proto.name not in registered_ops:
      raise LookupError("Op with name {0} not registered".format(
          op_proto.name))

  op_def_lib = _op_def_library.OpDefLibrary()
  ops_proto = _op_def_pb2.OpList()
  ops_proto.op.extend(op_protos)
  # Fails if the interfaces ("op schemas") don't match between the
  # previously registered op and this one.
  _registry.register_op_list(ops_proto)
  op_def_lib.add_op_list(ops_proto)
  return op_def_lib
def register_ops_if_needed(graph_ops):
  """Register graph ops absent in op_def_registry, if present in c++ registry.

  Ops already known to the Python registry are skipped. The remaining ones are
  looked up in the op list returned by the C API and, when found, registered
  in the Python registry.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    RuntimeError: if `graph_ops` contains ops that are not in either python
      or c++ registry.
  """
  python_op_names = set(op_def_registry.get_registered_ops().keys())
  missing_ops = graph_ops - python_op_names
  if not missing_ops:
    return

  # Pull the full op list through the C API and index it by op name.
  p_buffer = c_api.TF_GetAllOpList()
  cpp_op_list = op_def_pb2.OpList()
  cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
  cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

  missing_op_list = op_def_pb2.OpList()
  for missing_op in missing_ops:
    if missing_op in cpp_registry_ops:
      missing_op_list.op.extend([cpp_registry_ops[missing_op]])
      tf.logging.info(
          "Adding op %s from c++ registry to python registry.", missing_op)
    else:
      tf.logging.info(
          "Op %s is missing from both the python and C++ registry.",
          missing_op)

  op_def_registry.register_op_list(missing_op_list)

  # Note: Only raise the missing-op RuntimeError after trying to load ops.
  # This allows the test to exercise all the calls into TensorFlow
  # without having to write a C + python test.
  still_missing = missing_ops.difference(set(cpp_registry_ops.keys()))
  if still_missing:
    raise RuntimeError(
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry." % still_missing)
def register_ops_if_needed(graph_ops):
  """Register graph ops absent in op_def_registry, if present in c++ registry.

  Ops already known to the Python registry are skipped. The remaining ones are
  looked up in the op list returned by the C API and, when found, registered
  in the Python registry.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    RuntimeError: if `graph_ops` contains ops that are not in either python
      or c++ registry.
  """
  python_op_names = set(op_def_registry.get_registered_ops().keys())
  missing_ops = graph_ops - python_op_names
  if not missing_ops:
    return

  # Pull the full op list through the C API and index it by op name.
  p_buffer = c_api.TF_GetAllOpList()
  cpp_op_list = op_def_pb2.OpList()
  cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
  cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

  missing_op_list = op_def_pb2.OpList()
  for missing_op in missing_ops:
    if missing_op in cpp_registry_ops:
      missing_op_list.op.extend([cpp_registry_ops[missing_op]])
      tf.logging.info(
          "Adding op %s from c++ registry to python registry.", missing_op)
    else:
      tf.logging.info(
          "Op %s is missing from both the python and C++ registry.",
          missing_op)

  op_def_registry.register_op_list(missing_op_list)

  # Note: Only raise the missing-op RuntimeError after trying to load ops.
  # This allows the test to exercise all the calls into TensorFlow
  # without having to write a C + python test.
  still_missing = missing_ops.difference(set(cpp_registry_ops.keys()))
  if still_missing:
    raise RuntimeError(
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry." % still_missing)
def stripped_op_list_for_graph(graph_def):
  """Build the `OpList` of registered `OpDef`s for the ops a graph uses.

  The result is what goes into the `stripped_op_list` field of `MetaGraphDef`
  and similar protos. A producer can hand it to a consumer, which can then
  call the C++ function `RemoveNewDefaultAttrsFromGraphDef` to improve
  forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` holding the registered `OpDef` of every op used by the graph,
    ordered by op name. Whitelisted internal ops that are not registered are
    left out.

  Raises:
    ValueError: If the graph uses an op that is neither registered nor
      whitelisted.
  """
  # Python counterpart of StrippedOpListForGraph in C++. The Python op
  # registry can differ from the C++ one, so the duplication cannot be removed
  # with swig (at least naively).
  # TODO(irving): Support taking graphs directly.
  used_ops = ops_used_by_graph_def(graph_def)
  registered_ops = op_def_registry.get_registered_ops()

  # Internal ops used by functions are not registered, so they are whitelisted.
  # TODO(irving): Do something better here.
  op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")

  # Phase 1: every used op must be known (registered or whitelisted).
  for op_name in used_ops:
    if op_name in registered_ops or op_name in op_whitelist:
      continue
    raise ValueError(
        "Op %s is used by the graph, but is not registered" % op_name)

  # Phase 2: emit the registered OpDefs in sorted order.
  stripped_defs = []
  for op_name in sorted(used_ops):
    if op_name in registered_ops:
      stripped_defs.append(registered_ops[op_name])
  return op_def_pb2.OpList(op=stripped_defs)
def testStripDefaultAttrsInconsistentConsumerDefaults(self):
  """Loads a default-stripped SavedModel under altered consumer op defaults.

  A SavedModel containing a "Complex" op is exported with
  `strip_default_attrs=True`. The registered "Complex" `OpDef` is then edited
  in place twice (defaults removed, then "Tout" default changed) and loading
  is expected to fail each time. The registry entry is restored when the test
  finishes so the edits do not leak into other tests.
  """
  export_dir = os.path.join(
      test.get_temp_dir(), "test_strip_default_attrs_no_consumer_defaults")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # Add a graph with two float32 variables and a Complex Op composing them
  # with strip_default_attrs enabled. This must remove the following
  # defaults for the "Complex" Op:
  #   o "T"    : float32.   (input type)
  #   o "Tout" : complex64. (output type)
  with session.Session(graph=ops.Graph()) as sess:
    real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
    imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
    math_ops.complex(real_num, imag_num, name="complex")
    sess.run(variables.global_variables_initializer())
    builder.add_meta_graph_and_variables(
        sess, ["foo"], strip_default_attrs=True)

  # Save the SavedModel to disk in text format.
  builder.save(as_text=True)

  # Snapshot the registered "Complex" OpDef before editing it.
  complex_op_def = op_def_registry.get_registered_ops()["Complex"]
  original_complex_op_def = op_def_pb2.OpDef()
  original_complex_op_def.CopyFrom(complex_op_def)
  # The registry entry is mutated in place below. Restore it when the test
  # ends (pass or fail) so later tests see the unmodified "Complex" OpDef.
  self.addCleanup(complex_op_def.CopyFrom, original_complex_op_def)

  # Update the Op registry to remove defaults for all attrs("T", "Tout") from
  # the "Complex" OpDef.
  for attr_def in complex_op_def.attr:
    attr_def.ClearField("default_value")

  # Loading the SavedModel via the loader must fail because the SavedModel
  # does not have any attr values for the "Complex" node and the current
  # op registry does not have any default values for the "Complex" op.
  sess = session.Session(graph=ops.Graph())
  with self.assertRaisesRegexp(
      ValueError,
      "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
    loader.load(sess, ["foo"], export_dir)

  # Update the Op registry to change the defaults for attr "Tout"
  # (complex64 -> complex128).
  complex_op_def.CopyFrom(original_complex_op_def)
  for attr_def in complex_op_def.attr:
    if attr_def.name == "Tout":
      attr_def.default_value.type = types_pb2.DT_COMPLEX128

  # Loading the SavedModel via the loader must set "Tout" attr_value for the
  # "Complex" node according to the latest defaults (complex128). This is
  # expected to fail the model import as there is no OpKernel registered to
  # handle attrs "T" (float32) and "Tout" (complex128).
  sess = session.Session(graph=ops.Graph())
  with self.assertRaisesRegexp(
      errors.InvalidArgumentError,
      ".*No OpKernel was registered to support Op \'Complex\' with these "
      "attrs..*"):
    loader.load(sess, ["foo"], export_dir)
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  The import runs in two passes over `graph_def.node`: the first creates every
  operation without inputs (after filling in missing default attr values on
  the node), the second wires up data and control inputs, checks input types,
  runs shape inference and applies device functions.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph. Note that missing default attr values are written
      into its nodes during the import.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach: merge it into a fresh GraphDef and use that copy.
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict) and all(
        isinstance(k, compat.bytes_or_text_types) for k in input_map.keys())):
      raise TypeError(
          'input_map must be a dictionary mapping strings to '
          'Tensor objects.')
  if return_elements is not None:
    # Materialize as a tuple so it can be checked here and iterated again
    # when the results are collected.
    return_elements = tuple(return_elements)
    if not all(
        isinstance(x, compat.bytes_or_text_types) for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  # Keys of `input_map` actually consumed by some node input; used at the end
  # to report mappings that matched nothing.
  used_input_keys = set()

  # Maps node name -> the operation created for it in pass 1.
  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()
    g.graph_def_versions.CopyFrom(graph_def.versions)

    # Convert the replacement values to tensors under an '_inputs' name scope.
    with ops.name_scope('_inputs'):
      input_map = {
          k: ops.convert_to_tensor(v) for k, v in input_map.items()
      }

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      # NOTE(review): an op type missing from `op_dict` fails on this lookup
      # rather than with the ValueError documented above - confirm whether
      # callers rely on that.
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)

      output_types = _OutputTypes(node, op_dict)
      # Shapes and devices are deliberately not computed yet; both are handled
      # in pass 2 once the inputs exist.
      name_to_op[node.name] = g.create_op(node.op, [],
                                          output_types,
                                          name=node.name,
                                          attrs=node.attr,
                                          compute_shapes=False,
                                          compute_device=False)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def". The first character of the name is dropped to
          #     obtain the source op's name.
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node, 'Control input %r not found in graph_def.' %
                    (input_name, )))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(
                _InvalidNodeMessage(
                    node, 'More inputs specified (%r) than the op expects.' %
                    (input_name, )))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(
                input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(
                  source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node, 'Input tensor %r not found in graph_def.' %
                      (input_name, )))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(
                _InvalidNodeMessage(
                    node, 'Input tensor %r %s' % (input_name, te)))

      # After all inputs are attached, the op's input dtypes must match what
      # the OpDef says this node expects.
      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node, 'Input types mismatch (expected %r but got %r)' %
                (", ".join(
                    dtypes.as_dtype(x).name for x in input_types), ", ".join(
                        x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(
        input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]' %
          ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          # A name containing ':' refers to a tensor (an op output).
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(
                name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' %
                name)
        else:
          # Otherwise the name refers to an operation.
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' %
                name)
      return ret
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. The value passed in is
      ignored: it is unconditionally replaced by the globally registered ops.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  # The `op_dict` argument is deprecated; whatever the caller passed is
  # discarded here in favor of the global op registry.
  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  # Two import paths follow: when the default graph has a `_c_graph`, the
  # import is delegated to the C API; otherwise the ops are created one by one
  # in Python.
  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        # Drop the trailing '/' that the assert above checks for.
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [compat.as_str(s)
                                   for s in missing_unused_input_keys]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  else:
    # Python import path: ops are created directly in `g`.
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    # Keys of `input_map` that were actually substituted for a node input.
    used_input_keys = set()
    # Maps each node name in `graph_def` to the op created for it.
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        # Set any default attr values that aren't present.
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]
        # NOTE: the defaults are copied into the node protos of `graph_def`
        # itself, so `graph_def` is modified by this loop.
        for attr_def in op_def.attr:
          key = attr_def.name
          if attr_def.HasField('default_value'):
            value = node.attr[key]
            if value is None or value.WhichOneof('value') is None:
              node.attr[key].CopyFrom(attr_def.default_value)

        output_types = _OutputTypes(node, op_dict)
        # Inputs are left empty here and attached in pass 2; shape and device
        # computation is likewise turned off at creation time.
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False, compute_device=False,
            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                # Strip the 5-byte b'loc:@' prefix to get the target op name.
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError('Specified colocation to an op that '
                                   'does not exist during import: %s in %s' % (
                                       op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(compat.as_bytes(
                    'loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
                s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              # The first character is dropped to obtain the source op name.
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name,)))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(_InvalidNodeMessage(
                  node, 'More inputs specified (%r) than the op expects.'
                  % (input_name,)))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name,)))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(_InvalidNodeMessage(
                  node, 'Input tensor %r %s' % (input_name, te)))

        # Once every input is attached, the op's recorded input types must
        # equal the types computed from the node definition.
        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            # Negative dimension sizes are translated to None.
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else
                [dim.size if dim.size >= 0 else None for dim in dims.dim])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
                             'Stack', 'Barrier', 'BarrierReadySize',
                             'BarrierIncompleteSize', 'HashTable',
                             'MutableHashTable',
                             'MutableHashTableOfTensors', 'Mutex',
                             'CuckooTable', 'IndexTable',
                             'WholeFileReader', 'TextLineReader',
                             'FixedLengthRecordReader',
                             'TFRecordReader', 'IdentityReader',
                             'LMDBReader',
                             'RefSwitch', 'RefEnter', 'RefNextIteration',
                             'RefMerge', 'RefIdentity']:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          # The attr is removed from the imported op after it has been applied.
          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists.  We assume that if multiple ops
        # have devices, they refer to the same device.  Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        # True when `tensor_name` names an existing output of an imported op;
        # False when no imported op has that operation name.
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(name_to_op[operation_name].outputs)
        except KeyError:
          return False
      # Unused keys that still name a real output of an imported node are
      # tolerated; only the remaining ones are reported.
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(used_input_keys)
          if not _IsImportedNodeOutput(k)]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        for name in return_elements:
          name = compat.as_str(name)
          # Names containing ':' are resolved as tensors, all others as ops.
          if ':' in name:
            try:
              operation_name, output_index = _ParseTensorName(name)
              ret.append(name_to_op[operation_name].outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
          else:
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
        return ret
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. The value passed in is
      ignored: it is unconditionally replaced by the globally registered ops.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # The `op_dict` argument is deprecated; whatever the caller passed is
  # discarded here in favor of the global op registry.
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  # Two import paths follow: when the default graph has a `_c_graph`, the
  # import is delegated to the C API; otherwise the ops are created one by one
  # in Python.
  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        # Drop the trailing '/' that the assert above checks for.
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
    # call cannot occur between creating the TF_Operations in the
    # TF_GraphImportGraphDefWithResults call and mutating them in
    # _ProcessNewOps.
    with graph._lock:  # pylint: disable=protected-access
      with c_api_util.tf_buffer(
          graph_def.SerializeToString()) as serialized:
        try:
          with errors.raise_exception_on_not_ok_status() as status:
            results = c_api.TF_GraphImportGraphDefWithResults(
                graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
        except errors.InvalidArgumentError as e:
          # Convert to ValueError for backwards compatibility.
          raise ValueError(str(e))

      _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [
          compat.as_str(s) for s in missing_unused_input_keys
      ]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]' %
          ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  else:
    # Python import path: ops are created directly in `g`.
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    # Keys of `input_map` that were actually substituted for a node input.
    used_input_keys = set()
    # Maps each node name in `graph_def` to the op created for it.
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]

        output_types = _OutputTypes(node, op_dict)
        # Inputs are left empty here and attached in pass 2; shape and device
        # computation is likewise turned off at creation time.
        name_to_op[node.name] = g.create_op(node.op, [], output_types,
                                            name=node.name,
                                            attrs=node.attr,
                                            compute_shapes=False,
                                            compute_device=False,
                                            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                # Strip the 5-byte b'loc:@' prefix to get the target op name.
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError(
                      'Specified colocation to an op that '
                      'does not exist during import: %s in %s' %
                      (op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(
                    compat.as_bytes('loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(
                attr_value_pb2.AttrValue.ListValue(
                    s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              # The first character is dropped to obtain the source op name.
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.' %
                      (input_name, )))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'More inputs specified (%r) than the op expects.' %
                      (input_name, )))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(
                  input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(
                    source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.' %
                        (input_name, )))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(
                  _InvalidNodeMessage(
                      node, 'Input tensor %r %s' % (input_name, te)))

        # Once every input is attached, the op's recorded input types must
        # equal the types computed from the node definition.
        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)' %
                  (', '.join(
                      dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(
                       x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            # Negative dimension sizes are translated to None.
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else [
                    dim.size if dim.size >= 0 else None
                    for dim in dims.dim
                ])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in [
                  'RandomShuffleQueue', 'PaddingFIFOQueue', 'FIFOQueue',
                  'PriorityQueue', 'QueueSize', 'Stack', 'Barrier',
                  'BarrierReadySize', 'BarrierIncompleteSize', 'HashTable',
                  'MutableHashTable', 'MutableHashTableOfTensors', 'Mutex',
                  'CuckooTable', 'IndexTable', 'WholeFileReader',
                  'TextLineReader', 'FixedLengthRecordReader',
                  'TFRecordReader', 'IdentityReader', 'LMDBReader',
                  'RefSwitch', 'RefEnter', 'RefNextIteration', 'RefMerge',
                  'RefIdentity'
              ]:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          # The attr is removed from the imported op after it has been applied.
          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists.  We assume that if multiple ops
        # have devices, they refer to the same device.  Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(
                coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        # True when `tensor_name` names an existing output of an imported op;
        # False when no imported op has that operation name.
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(
              name_to_op[operation_name].outputs)
        except KeyError:
          return False

      # Unused keys that still name a real output of an imported node are
      # tolerated; only the remaining ones are reported.
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(
              used_input_keys)
          if not _IsImportedNodeOutput(k)
      ]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]' %
            ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        for name in return_elements:
          name = compat.as_str(name)
          # Names containing ':' are resolved as tensors, all others as ops.
          if ':' in name:
            try:
              operation_name, output_index = _ParseTensorName(
                  name)
              ret.append(name_to_op[operation_name].
                         outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' %
                  name)
          else:
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' %
                  name)
        return ret
def _get_op_def(op): # pylint: disable=protected-access if hasattr(op, "_sig"): return getattr(op, "_sig") else: return op_def_registry.get_registered_ops()[op.type]
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Note that the nodes of `graph_def` are modified in place while importing:
  missing attrs are filled in with their defaults, and attrs that are only
  known to the producer (and hold the producer's default) are deleted.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor, contains duplicate node names, or uses
      an op type that is missing from `op_dict`).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    # Materialize once so that a one-shot iterable can be both validated here
    # and iterated again when the results are gathered.
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  # Canonical `input_map` keys that were actually consumed by some node input.
  used_input_keys = set()

  # Maps each node name in `graph_def` to the `Operation` created for it.
  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(g)
      # Register the function's signature so nodes that call it resolve through
      # `op_dict` like any other op type.
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Check to see if this op's name matches a previously seen op
      if node.name in name_to_op:
        raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      # Inputs, shapes and devices are all filled in by the second pass, so the
      # op is created with no inputs and with shape/device computation off.
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # Maps from a node to the ops it is colocated with, if colocation
    # is specified in the attributes.
    colocation_pairs = collections.defaultdict(list)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)
      apply_device_function = True

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              # Strip the 5-byte 'loc:@' prefix to recover the target op name.
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
              if op_to_bind_to != node.name:
                # Keep track of this mapping for a later phase.
                colocation_pairs[op].append(original_op)
                # Don't apply this op's device function,
                # the colocation constraint will ensure
                # the proper device gets assigned at runtime.
                apply_device_function = False

            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            # `input_name[1:]` drops the leading control-input marker.
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          # A negative serialized dimension size is mapped to None (unknown).
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'LMDBReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        # The attr is consumed here and removed from the imported op's
        # node_def once the shapes have been applied.
        del op.node_def.attr['_output_shapes']

      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      if apply_device_function:
        with _MaybeDevice(node.device):
          g._apply_device_functions(op)  # pylint: disable=protected-access

    # The following loop populates the device field of ops that are
    # colocated with another op.  This is implied by the colocation
    # attribute, but we propagate the device field for completeness.
    for op, coloc_op_list in colocation_pairs.items():
      coloc_device = None
      # Find any device in the list of colocated ops that have a
      # device, if it exists.  We assume that if multiple ops
      # have devices, they refer to the same device.  Otherwise, a
      # runtime error will occur since the colocation property
      # cannot be guaranteed.
      #
      # One possible improvement is to try to check for compatibility
      # of all devices in this list at import time here, which would
      # require implementing a compatibility function for device specs
      # in python.
      for coloc_op in coloc_op_list:
        if coloc_op.device:
          coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
          break
      if coloc_device:
        op._set_device(coloc_device)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        # A ':' marks a tensor name ("op:index"); anything else is treated as
        # an operation name.
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
def _get_op_def(op): return op.op_def or op_def_registry.get_registered_ops()[op.type]
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Note that the nodes of `graph_def` are modified in place while importing:
  missing attrs are filled in with their defaults, and attrs that are only
  known to the producer (and hold the producer's default) are deleted.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict) and all(
        isinstance(k, compat.bytes_or_text_types)
        for k in input_map.keys())):
      raise TypeError(
          'input_map must be a dictionary mapping strings to '
          'Tensor objects.')
  if return_elements is not None:
    # Materialize once so that a one-shot iterable can be both validated here
    # and iterated again when the results are gathered.
    return_elements = tuple(return_elements)
    if not all(
        isinstance(x, compat.bytes_or_text_types)
        for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  # Canonical `input_map` keys that were actually consumed by some node input.
  used_input_keys = set()

  # Maps each node name in `graph_def` to the `Operation` created for it.
  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()
    g.graph_def_versions.CopyFrom(graph_def.versions)

    with ops.name_scope('_inputs'):
      input_map = {
          k: ops.convert_to_tensor(v)
          for k, v in input_map.items()
      }

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      # NOTE(review): an op type missing from `op_dict` surfaces here as a
      # KeyError rather than the ValueError the docstring describes.
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value')
                  and node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      # Inputs, shapes and devices are all filled in by the second pass, so the
      # op is created with no inputs and with shape/device computation off.
      name_to_op[node.name] = g.create_op(node.op, [],
                                          output_types,
                                          name=node.name,
                                          attrs=node.attr,
                                          compute_shapes=False,
                                          compute_device=False,
                                          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              # Strip the 5-byte 'loc:@' prefix to recover the target op name.
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError(
                    'Specified colocation to an op that '
                    'does not exist during import: %s in %s' %
                    (op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(
                  compat.as_bytes('loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(
              attr_value_pb2.AttrValue.ListValue(s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            # `input_name[1:]` drops the leading control-input marker.
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node, 'Control input %r not found in graph_def.' %
                    (input_name, )))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'More inputs specified (%r) than the op expects.'
                    % (input_name, )))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(
                input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(
                  source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name, )))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(
                _InvalidNodeMessage(
                    node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node, 'Input types mismatch (expected %r but got %r)' %
                (', '.join(
                    dtypes.as_dtype(x).name for x in input_types),
                 ', '.join(
                     x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          # A negative serialized dimension size is mapped to None (unknown).
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else [
                  dim.size if dim.size >= 0 else None
                  for dim in dims.dim
              ])
          output.set_shape(output_shape)
        # The attr is consumed here and removed from the imported op's
        # node_def once the shapes have been applied.
        del op.node_def.attr['_output_shapes']

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(
        input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        # A ':' marks a tensor name ("op:index"); anything else is treated as
        # an operation name.
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(
                name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.'
                % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.'
                % name)
      return ret
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  Takes a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, adds its contents to the current default `Graph`, and
  hands back selected objects from it as @{tf.Tensor} and @{tf.Operation}
  objects. Use @{tf.Graph.as_graph_def} to produce a `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. The argument is ignored; the
      globally registered ops are always used.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # The caller-supplied `op_dict` is deliberately discarded.
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  target_graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Remember the unique prefix that name_scope generated, minus the trailing
    # slash. An empty scope yields an empty prefix.
    prefix = ''
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]

    # Any tensors created for the input map must live inside the name scope.
    input_map = _ConvertInputMapValues(name, input_map)

  # Keep a reference to the scoped wrapper for as long as `import_options` is
  # in use.
  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  import_options = scoped_options.options
  _PopulateTFImportGraphDefOptions(import_options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. Holding _mutation_lock ensures
  # a Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with target_graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        # pylint: disable=protected-access
        results = c_api_util.ScopedTFImportGraphDefResults(
            c_api.TF_GraphImportGraphDefWithResults(
                target_graph._c_graph, serialized, import_options))
        # pylint: enable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Create _DefinedFunctions for any imported functions.
    #
    # They are built directly from `graph_def` and then added to the graph.
    # Adding an existing function to a TF_Graph is a no-op, so this only has
    # the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
    # _USE_C_SHAPES is removed.
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      for imported_function in function._from_library(graph_def.library):
        imported_function.add_to_graph(target_graph)
      # pylint: enable=protected-access

    _ProcessNewOps(target_graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_keys:
    readable_keys = [compat.as_str(key) for key in missing_keys]
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]'
        % ', '.join(readable_keys))

  if return_elements is None:
    return None
  return _GatherReturnElements(return_elements, target_graph, results.results)
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, six.string_types) for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if (return_elements is not None and
      not (isinstance(return_elements, (list, tuple)) and
           all(isinstance(x, six.string_types) for x in return_elements))):
    raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  # Canonical `input_map` keys that were actually consumed by some node input.
  used_input_keys = set()

  # Maps each node name in `graph_def` to the `Operation` created for it.
  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      output_types = _OutputTypes(node, op_dict)
      with _MaybeDevice(node.device):
        # Inputs and shapes are filled in by the second pass.
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            # `input_name[1:]` drops the leading control-input marker.
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            # Format the exception itself rather than `te.message`: that
            # attribute does not exist on Python 3 exceptions, so reading it
            # would raise AttributeError and mask the real problem.
            raise ValueError(
                _InvalidNodeMessage(node, 'Input tensor %r %s'
                                    % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(types_lib.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        # A ':' marks a tensor name ("op:index"); anything else is treated as
        # an operation name.
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs. Control inputs (names
     starting with "^") name a node rather than a tensor, so they are copied
     through unchanged.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
    NotImplementedError: If a node's op is not in the op registry, or the op
      declares an attr of type `func` or `list(func)`.
    KeyError: If a non-control node input is not a key of the name mapping
      built from the input args and the node outputs.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of input_shapes must match the number of " +
                     "input_args. len(input_shapes): {} len(input_arg): {}".
                     format(len(input_shapes), len(fdef.signature.input_arg)))

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)

  # The registry lookup does not depend on the node, so fetch it once.
  registered_ops = op_def_registry.get_registered_ops()

  for node_def in fdef.node_def:
    op_def = registered_ops.get(node_def.op)
    if not op_def:
      # TODO(b/80470245): Support functions which refer other functions.
      raise NotImplementedError(
          "No op registered for {},".format(node_def.op) +
          " it may be a function. function_def_to_graph_def " +
          "currently does not support converting functions with " +
          "references to other graph functions.")

    for attr in op_def.attr:
      if attr.type in ("func", "list(func)"):
        # TODO(b/80470245): Support functions which refer other functions.
        raise NotImplementedError("Unsupported attr {} ".format(attr.name) +
                                  " with type {}".format(attr.type) +
                                  " in op {}. ".format(op_def.name) +
                                  "function_def_to_graph_def currently does " +
                                  "not support converting functions with " +
                                  "references to other graph functions.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    # Index-based loop: entries of the repeated field are overwritten in place.
    for i in range(len(node_def.input)):
      input_name = node_def.input[i]
      if input_name.startswith("^"):
        # Control inputs are never keys of the tensor-name map; looking them
        # up would raise KeyError. They keep the same spelling in a GraphDef.
        continue
      node_def.input[i] = nested_to_flat_tensor_name[input_name]

  return graph_def, nested_to_flat_tensor_name
from tensorflow.python.framework import constant_op from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top _REGISTERED_OPS = op_def_registry.get_registered_ops() def enable_jit_nonstateful(node_def): try: return not _REGISTERED_OPS[node_def.op].is_stateful except KeyError: raise ValueError("Unregistered op being created: %s" % node_def) class JITTest(test.TestCase): def compute(self, use_jit, compute_fn): random_seed.set_random_seed(1234) with self.test_session(graph=ops.Graph()) as sess: with jit.experimental_jit_scope(use_jit):
def _stripped_op_list_for_graph(graph_def):
  """Builds an `OpList` of the registered OpDefs for ops used in `graph_def`.

  Every distinct `node.op` value found in `graph_def.node` is looked up in the
  mapping returned by `op_def_registry.get_registered_ops()`. The resulting
  entries are ordered by op type name.

  Args:
    graph_def: An object with an iterable `node` field whose elements expose
      an `op` attribute.

  Returns:
    An `op_def_pb2.OpList` with one entry per distinct op type.

  Note: the registry is indexed with a plain subscript, so an op type that is
  missing from it is not handled here and the lookup error propagates.
  """
  op_registry = op_def_registry.get_registered_ops()

  # Collect each op type once, however many nodes use it.
  distinct_op_types = set()
  for node in graph_def.node:
    distinct_op_types.add(node.op)

  op_defs = []
  for op_type in sorted(distinct_op_types):
    op_defs.append(op_registry[op_type])
  return op_def_pb2.OpList(op=op_defs)
def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs. Control inputs (names
     starting with "^") name a node rather than a tensor, so they are copied
     through unchanged.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
    NotImplementedError: If a node's op is not in the op registry, or the op
      declares an attr of type `func` or `list(func)`.
    KeyError: If a non-control node input is not a key of the name mapping
      built from the input args and the node outputs.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of input_shapes must match the number of " +
                     "input_args. len(input_shapes): {} len(input_arg): {}".
                     format(len(input_shapes), len(fdef.signature.input_arg)))

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)

  # The registry lookup does not depend on the node, so fetch it once.
  registered_ops = op_def_registry.get_registered_ops()

  for node_def in fdef.node_def:
    op_def = registered_ops.get(node_def.op)
    if not op_def:
      # TODO(b/80470245): Support functions which refer other functions.
      raise NotImplementedError(
          "No op registered for {},".format(node_def.op) +
          " it may be a function. function_def_to_graph_def " +
          "currently does not support converting functions with " +
          "references to other graph functions.")

    for attr in op_def.attr:
      if attr.type in ("func", "list(func)"):
        # TODO(b/80470245): Support functions which refer other functions.
        raise NotImplementedError("Unsupported attr {} ".format(attr.name) +
                                  " with type {}".format(attr.type) +
                                  " in op {}. ".format(op_def.name) +
                                  "function_def_to_graph_def currently does " +
                                  "not support converting functions with " +
                                  "references to other graph functions.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    # Index-based loop: entries of the repeated field are overwritten in place.
    for i in range(len(node_def.input)):
      input_name = node_def.input[i]
      if input_name.startswith("^"):
        # Control inputs are never keys of the tensor-name map; looking them
        # up would raise KeyError. They keep the same spelling in a GraphDef.
        continue
      node_def.input[i] = nested_to_flat_tensor_name[input_name]

  return graph_def, nested_to_flat_tensor_name
return ('?', name, util.is_tensor) def Variable(name=None): return ('?', name, util.is_var) def Const(name=None): return ('?', name, util.is_const) def Placeholder(name=None): return ('?', name, util.is_placeholder) _op_names = op_def_registry.get_registered_ops().keys() util.import_ops_no_clobber(globals(), _op_names) # NOTE(mattjj): renamed in TF 1.0, but not registered as an op in 1.0.1 Unstack = util.make_op_pattern('Unpack') # pylint: disable=invalid-name ## convenient compound patterns # The op definitions are pulled in via the op_def_registry, which is # why we disable the undefined variable check for e.g. Rsqrt, Mul, etc. # Otherwise we would have to refer to them by name rather than object. # pylint: disable=undefined-variable def BatchNorm(in_pattern=Tensor('in'), scale_name='scale',
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Note that when `graph_def` already is a `GraphDef` it is used directly, and
  attrs that are unset but have a default in the op's `OpDef` are filled in on
  its nodes.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor, or to an op type that is missing from
      `op_dict`).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()
    g.graph_def_versions.CopyFrom(graph_def.versions)

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # An op type missing from `op_dict` means the graph is not importable;
      # report it as the documented ValueError instead of a bare KeyError.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      # Set any default attr values that aren't present.
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)

            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(dtypes.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        # A ':' marks a tensor name (resolved to an op output); anything else
        # is resolved to the op itself.
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
def _get_op_def(op):
  """Returns the op definition to use for `op`.

  Args:
    op: An object exposing `op_def` and `type` attributes.

  Returns:
    `op.op_def` when it is truthy; otherwise the entry stored under `op.type`
    in the mapping returned by `op_def_registry.get_registered_ops()`.
  """
  own_op_def = op.op_def
  if own_op_def:
    return own_op_def
  # No usable definition on the op itself: fall back to the registry.
  return op_def_registry.get_registered_ops()[op.type]
def _get_op_def(op):
  """Returns the op definition to use for `op`.

  Args:
    op: An object with a `type` attribute and, optionally, a `_sig` attribute.

  Returns:
    `op._sig` when the attribute exists; otherwise the entry stored under
    `op.type` in the mapping returned by
    `op_def_registry.get_registered_ops()`.
  """
  if not hasattr(op, "_sig"):
    # No signature attached to the op: fall back to the registry.
    return op_def_registry.get_registered_ops()[op.type]
  return op._sig  # pylint: disable=protected-access
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Note that when `graph_def` already is a `GraphDef` it is used directly and
  its nodes are modified: attrs that are unset but have a default in the
  consumer `OpDef` are filled in, and (with `producer_op_list`) attrs unknown
  to the consumer that hold the producer's default value are deleted.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor or to an op type missing from `op_dict`),
      or `input_map` holds non-Tensor values while the import name scope is
      empty.
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  # Keys of `input_map` that were actually consumed by some node input.
  used_input_keys = set()

  # Maps each node name in `graph_def` to the Operation created for it.
  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      # Make the function's signature resolvable like any other op type.
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      # Inputs are left empty here; they are attached in the second pass.
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              # Strip the 5-byte b'loc:@' prefix to get the original op name.
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            # input_name[1:] drops the leading control-input marker character.
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)

            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          # Unknown rank becomes a fully unknown shape; negative dimension
          # sizes become unknown (None) dimensions.
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        # The recorded shapes have been applied to the outputs above, so drop
        # the attr from the created op's node_def.
        del op.node_def.attr['_output_shapes']

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        # A ':' marks a tensor name (resolved to an op output); anything else
        # is resolved to the op itself.
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. Any value passed is ignored:
      it is overwritten with the globally registered ops on entry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor). An `InvalidArgumentError` raised by the
      C API import call is re-raised as `ValueError`.
  """
  # The caller-supplied `op_dict` is discarded; the global registry is used.
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      # Drop the trailing '/' checked by the assert above.
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        results = c_api.TF_GraphImportGraphDefWithResults(
            graph._c_graph, serialized, options)  # pylint: disable=protected-access
        # Re-wrap the raw results object; `.results` on the wrapper is what
        # the later C API calls in this function receive.
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
    # _USE_C_SHAPES is removed.
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    _ProcessNewOps(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    # Convert each key with compat.as_str so they can be joined into the
    # error message below.
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]' %
        ', '.join(missing_unused_input_keys))

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)
from tensorflow.python.framework import constant_op from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top _REGISTERED_OPS = op_def_registry.get_registered_ops() def enable_jit_nonstateful(node_def): try: return not _REGISTERED_OPS[node_def.op].is_stateful except KeyError: raise ValueError("Unregistered op being created: %s" % node_def) class JITTest(test.TestCase): def compute(self, use_jit, compute_fn): random_seed.set_random_seed(1234) with self.test_session(graph=ops.Graph()) as sess: with jit.experimental_jit_scope(use_jit):
def get(name):
  """Looks up `name` among the registered ops.

  Args:
    name: Key to look up in the mapping returned by
      `op_def_registry.get_registered_ops()`.

  Returns:
    The result of calling `.get(name)` on that mapping.
  """
  return op_def_registry.get_registered_ops().get(name)