def testStrippedOpListNestedFunctions(self):
  """Checks the stripped op list through two levels of function nesting.

  `f1` calls `f0`, which calls `math_ops.square`. Merely defining the two
  functions must not add any ops to the stripped op list; calling `f1` on a
  constant must yield exactly `Const` and `Square`.
  """
  with self.cached_session():
    # Square two levels deep
    @function.Defun(dtypes.int32)
    def f0(x):
      return math_ops.square(x)

    @function.Defun(dtypes.int32)
    def f1(x):
      return f0(x)

    # At this point we've defined two functions but haven't called them, so
    # there should be no used ops.
    op_list = meta_graph.stripped_op_list_for_graph(
        ops.get_default_graph().as_graph_def())
    self.assertEqual(len(op_list.op), 0)

    # If we call the function on a constant, there should be two ops
    _ = f1(constant_op.constant(7))
    op_list = meta_graph.stripped_op_list_for_graph(
        ops.get_default_graph().as_graph_def())
    self.assertEqual(["Const", "Square"], [op.name for op in op_list.op])
def testStrippedOpListNestedFunctions(self):
  """Verifies op stripping for a function that wraps another function.

  Before any call the default graph must report no used ops; after calling
  the outer function on a constant it must report `Const` and `Square`.
  """

  def used_ops():
    # Stripped op list computed from the current default graph.
    current_graph_def = ops.get_default_graph().as_graph_def()
    return meta_graph.stripped_op_list_for_graph(current_graph_def).op

  with self.cached_session():
    # Inner function: squares its input.
    @function.Defun(dtypes.int32)
    def f0(x):
      return math_ops.square(x)

    # Outer function: delegates to the inner one (two levels deep).
    @function.Defun(dtypes.int32)
    def f1(x):
      return f0(x)

    # Both functions are defined but neither has been called yet, so nothing
    # should count as used.
    self.assertEqual(len(used_ops()), 0)

    # Calling the outer function on a constant should bring in two ops.
    _ = f1(constant_op.constant(7))
    used_op_names = [op.name for op in used_ops()]
    self.assertEqual(["Const", "Square"], used_op_names)
def testStrippedOpListRecursiveFunctions(self):
  """Verifies op stripping on a hand-built, mutually recursive library.

  The function module doesn't support recursive functions, so the GraphDef
  is assembled directly: A calls B, and B calls Const and A.
  """
  graph_def = graph_pb2.GraphDef()
  # (function name, ops of its body nodes), in library order.
  library_spec = (
      ("A", ("B",)),
      ("B", ("Const", "A")),
  )
  for function_name, body_ops in library_spec:
    function_def = graph_def.library.function.add()
    function_def.signature.name = function_name
    for body_op in body_ops:
      function_def.node_def.add().op = body_op

  # The graph itself uses A.
  graph_def.node.add().op = "A"

  # Only Const should remain in the stripped op list.
  stripped = meta_graph.stripped_op_list_for_graph(graph_def)
  stripped_names = [op.name for op in stripped.op]
  self.assertEqual(["Const"], stripped_names)
def testStrippedOpListPartitionedCalls(self):
  """Verifies op stripping when functions are referenced via `f` attrs.

  The graph invokes A through a PartitionedCall node; A invokes B through a
  StatefulPartitionedCall node; B contains a Const node and a node whose op
  is A. The expected op list holds Const and both call ops.
  """
  graph_def = graph_pb2.GraphDef()
  library_functions = graph_def.library.function
  fn_a = library_functions.add()
  fn_b = library_functions.add()
  fn_a.signature.name = "A"
  fn_b.signature.name = "B"

  # A's only node calls B through its "f" function attribute.
  call_b = fn_a.node_def.add()
  call_b.op = "StatefulPartitionedCall"
  call_b.attr["f"].func.name = "B"

  # B holds a Const node and a node whose op is A.
  for body_op in ("Const", "A"):
    fn_b.node_def.add().op = body_op

  # The top-level graph reaches A via PartitionedCall.
  call_a = graph_def.node.add()
  call_a.op = "PartitionedCall"
  call_a.attr["f"].func.name = "A"

  stripped = meta_graph.stripped_op_list_for_graph(graph_def)
  stripped_names = [op.name for op in stripped.op]
  self.assertSameElements(
      ["Const", "PartitionedCall", "StatefulPartitionedCall"],
      stripped_names)
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
                         namespace_whitelist):
  """Populates `meta_graph_def` with a graph that calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    namespace_whitelist: List of strings containing whitelisted op namespaces.

  Returns:
    A tuple of (_AssetInfo, Graph) containing the captured assets and
    exported Graph generated from tracing the saveable_view.
  """
  # Objects are listed from the eager context so that Optimizers hand back
  # the right Graph-dependent variables.
  initializer_functions = _trace_resource_initializers(saveable_view.nodes)
  exported_graph = ops.Graph()
  initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    for initializer_function in initializer_functions:
      # Asset initializers for any captured resource must run before the
      # resource initializer that captures it.
      candidate_initializers = (
          asset_info.asset_initializers_by_resource.get(captured, None)
          for captured in initializer_function.graph.external_captures)
      asset_dependencies = [
          initializer for initializer in candidate_initializers
          if initializer is not None]
      with ops.control_dependencies(asset_dependencies):
        initializer_call = _call_function_with_mapped_captures(
            initializer_function, [], resource_map)
      initializer_ops.append(initializer_call)
    initializer_ops.extend(asset_info.asset_initializers_by_resource.values())
    with ops.control_dependencies(initializer_ops):
      init_op = control_flow_ops.no_op()
    # The same op goes into both the main_op collection and the init_op
    # signature. The collection exists for compatibility with older loader
    # APIs; only one will be executed.
    main_op_names = (
        meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value)
    main_op_names.append(init_op.name)
    init_signature_entry = (
        meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY])
    init_signature_entry.CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint gathers variables again. The gathering
  # has to happen from the eager context so Optimizers save the right set of
  # variables, while the save/restore operations themselves must land in the
  # exported graph (hence the `to_graph` argument).
  frozen_saveables = saveable_view.checkpoint_view.frozen_saveable_objects(
      object_map=object_map, to_graph=exported_graph)
  saver = functional_saver.MultiDeviceSaver(frozen_saveables)

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    meta_graph_def.saver_def.CopyFrom(saver.to_proto())
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  _verify_ops(graph_def, namespace_whitelist)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_info = meta_graph_def.meta_info_def
  meta_info.tags.append(tag_constants.SERVING)
  meta_info.tensorflow_version = versions.__version__
  meta_info.tensorflow_git_version = versions.__git_version__
  # Default attributes are currently always stripped.
  meta_info.stripped_default_attrs = True
  meta_info.stripped_op_list.MergeFrom(
      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info, exported_graph
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.

  Returns:
    A tuple of (`asset_info`, `exported_graph`): the asset info returned by
    `saveable_view.map_resources()` and the `ops.Graph` that the initializers,
    signatures, concrete functions and saver ops were added to.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = saveable_view.nodes
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    for resource_initializer_function in resource_initializer_functions:
      # Collect the asset initializers for resources this initializer
      # captures, so they can be made control dependencies of its call.
      asset_dependencies = []
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(
                resource_initializer_function, [], resource_map))
    resource_initializer_ops.extend(
        asset_info.asset_initializers_by_resource.values())
    # A single no-op that depends on every initializer serves as the init op.
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  saver = functional_saver.MultiDeviceSaver(
      saveable_view.checkpoint_view.frozen_saveable_objects(
          object_map=object_map, to_graph=exported_graph))

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)
  # We currently always strip default attributes.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
  meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info, exported_graph