def _parse_input_saver_proto(input_saver, input_binary):
  """Parses the input TensorFlow Saver file into a SaverDef proto."""
  if not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_saver, mode) as f:
    saver_def = saver_pb2.SaverDef()
    if input_binary:
      saver_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), saver_def)
  return saver_def
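# Usage sketch (hedged): loading a text-format saver proto. The path below is
# hypothetical, and the caller must handle the -1 sentinel returned when the
# file is missing.
#
#   saver_def = _parse_input_saver_proto("/tmp/export/model.saver",
#                                        input_binary=False)
#   if saver_def == -1:
#     raise IOError("saver file not found")
#   print(saver_def.restore_op_name)  # e.g. "save/restore_all"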
def to_proto(self):
  """Serializes to a SaverDef referencing the current graph."""
  filename_tensor = array_ops.placeholder(
      shape=[], dtype=dtypes.string, name="saver_filename")
  # TODO(allenl): Add save and restore function names to the proto directly.
  signature = (tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),)
  # Autograph is off because of reference cycles which must be collected when
  # a function is created and destroyed (as in tf.saved_model.save). It's also
  # not necessary, so having it off may be slightly faster.
  #
  # TODO(b/121302372): We should be able to decorate save() and restore()
  # unconditionally.
  save_tensor = def_function.function(
      self.save, input_signature=signature, autograph=False)(filename_tensor)
  restore_op = def_function.function(
      self.restore, input_signature=signature,
      autograph=False)(filename_tensor).op
  return saver_pb2.SaverDef(
      filename_tensor_name=filename_tensor.name,
      save_tensor_name=save_tensor.name,
      restore_op_name=restore_op.name,
      version=saver_pb2.SaverDef.V2)
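# Usage sketch (hedged): the SaverDef returned above lets graph-based
# consumers (e.g. MetaGraph export) drive an object-based saver through named
# tensors. The `trackable_saver` instance below is hypothetical.
#
#   saver_def = trackable_saver.to_proto()
#   assert saver_def.version == saver_pb2.SaverDef.V2
#   meta_graph_def.saver_def.CopyFrom(saver_def)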
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes,
                 variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants."""

  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read().decode("utf-8"), input_graph_def)
  # Remove all the explicit device specifications from the nodes. This helps
  # to make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver:
      with gfile.FastGFile(input_saver, mode) as f:
        saver_def = saver_pb2.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = saver_lib.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
    if initializer_nodes:
      sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(",") if
                                variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
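# Usage sketch (hedged): freezing a text-format graph against a V2 checkpoint
# prefix. All paths and node names are hypothetical; "save/restore_all" and
# "save/Const:0" are the default names produced by a stock tf.train.Saver.
#
#   freeze_graph(input_graph="/tmp/model/graph.pbtxt",
#                input_saver="",
#                input_binary=False,
#                input_checkpoint="/tmp/model/model.ckpt",
#                output_node_names="softmax",
#                restore_op_name="save/restore_all",
#                filename_tensor_name="save/Const:0",
#                output_graph="/tmp/model/frozen_graph.pb",
#                clear_devices=True,
#                initializer_nodes="")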
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes,
                 variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Remove all the explicit device specifications from the nodes. This helps
  # to make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver:
      with gfile.FastGFile(input_saver, mode) as f:
        saver_def = saver_pb2.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = saver_lib.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)
    if initializer_nodes:
      sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(",") if
                                variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
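# Sketch (hedged): this updated version rebuilds a Saver from the checkpoint's
# variable-to-shape map instead of running named restore ops. The matching it
# performs can be previewed by listing the checkpoint's contents; the prefix
# below is hypothetical.
from tensorflow.python import pywrap_tensorflow  # TF 1.x internal module

def list_checkpoint_variables(checkpoint_prefix):
  """Prints each variable name and shape stored under the checkpoint prefix."""
  reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_prefix)
  for key, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(key, shape)

# list_checkpoint_variables("/tmp/model/model.ckpt")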
def __init__(self,
             var_list=None,
             reshape=False,
             sharded=False,
             max_to_keep=5,
             keep_checkpoint_every_n_hours=10000.0,
             name=None,
             restore_sequentially=False,
             saver_def=None,
             builder=None,
             defer_build=False,
             allow_empty=False,
             write_version=saver_pb2.SaverDef.V2,
             pad_step_number=False,
             save_relative_paths=False,
             filename=None):  # pylint: disable=too-many-arguments, too-many-locals
    """Saver for AutoDist.

    This saver saves the variables that map to the *original*,
    *user-declared*, *single-node* graph, instead of the transformed graph.
    Hence, the saved variables can be loaded either by the original user code
    (for resuming single-node training or inference), or by the
    AutoDist-distributed code (to resume distributed training).

    In contrast, the AutoDist saver saves the meta_graph_def that maps to the
    *AutoDist-transformed* graph (TODO).

    The AutoDist saver implements this by writing/loading the contents that
    map to the master_replica (defaults to replica_index=0) of the
    transformed graph.

    Refer to
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
    for a detailed explanation of the signature.
    """
    # TODO(Hao): support constructing saver after AutoDist graph transformation
    _autodist = autodist.autodist.get_default_autodist()
    if _autodist.is_built():
        raise ValueError(
            'Saver must be used before create_distributed_session().')
    # A saver appends the relevant save/restore ops to all variables in
    # var_list, i.e. one saver maps to all variables and encloses them under
    # a "saver" scope.
    super(Saver, self).__init__(
        var_list=var_list,
        reshape=reshape,
        sharded=sharded,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        name=name,
        restore_sequentially=restore_sequentially,
        saver_def=saver_def,
        builder=builder,
        defer_build=defer_build,
        allow_empty=allow_empty,
        write_version=write_version,
        pad_step_number=pad_step_number,
        save_relative_paths=save_relative_paths,
        filename=filename)
    # Note: TensorFlow does not add a user-declared saver to collections, so
    # we have to track it in the graph item's info.
    item = get_default_graph_item()
    new_saver_def = saver_pb2.SaverDef()
    new_saver_def.CopyFrom(self.to_proto())
    item.info.update_savers([new_saver_def], replace=False)
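# Usage sketch (hedged; `ad` is an AutoDist instance and `train_op` is a
# user-defined op; all names here are illustrative, not AutoDist's exact API):
#
#   saver = Saver()  # must be constructed BEFORE the distributed session
#   sess = ad.create_distributed_session()
#   sess.run(train_op)
#   saver.save(sess, "/tmp/ckpt/model")  # also loadable by single-node code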
def _build_internal(self,
                    names_to_saveables,
                    reshape=False,
                    sharded=False,
                    max_to_keep=5,
                    keep_checkpoint_every_n_hours=10000.0,
                    name=None,
                    restore_sequentially=False,
                    filename="model",
                    build_save=True,
                    build_restore=True):
  """build() with option to only perform save and restore."""
  if not context.executing_eagerly() and (not build_save or
                                          not build_restore):
    raise ValueError("save and restore operations need to be built together "
                     "when eager execution is not enabled.")

  saveables = saveable_object_util.validate_and_slice_inputs(
      names_to_saveables)
  if max_to_keep is None:
    max_to_keep = 0

  with ops.name_scope(name, "save",
                      [saveable.op for saveable in saveables]) as name:
    # Add a placeholder string tensor for the filename.
    filename_tensor = array_ops.placeholder_with_default(
        filename or "model", shape=(), name="filename")
    # Keep the name "Const" for backwards compatibility.
    filename_tensor = array_ops.placeholder_with_default(
        filename_tensor, shape=(), name="Const")

    # Add the save ops.
    if sharded:
      per_device = self._GroupByDevices(saveables)
      if build_save:
        op_list = []
        with tf.name_scope("Save_Weight_Update_Sharding"):
          grad_and_var_items = util.get_all_grad_item()
          for item in grad_and_var_items:
            if item.var in names_to_saveables:
              rank_id = item.root_rank_id
              if rank_id >= 0:
                with tf.get_default_graph().control_dependencies(op_list):
                  out_var = hccl_ops.broadcast([item.var], rank_id, 2,
                                               rank_id)
                op_list.append(out_var[0].op)
        if len(op_list) > 0:
          with tf.get_default_graph().control_dependencies(op_list):
            save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
        else:
          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
      if build_restore:
        restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                restore_sequentially, reshape)
    else:
      if build_save:
        op_list = []
        with tf.name_scope("Save_Weight_Update_Sharding"):
          grad_and_var_items = util.get_all_grad_item()
          for item in grad_and_var_items:
            if item.var in names_to_saveables:
              rank_id = item.root_rank_id
              if rank_id >= 0:
                with tf.get_default_graph().control_dependencies(op_list):
                  out_var = hccl_ops.broadcast([item.var], rank_id, 2,
                                               rank_id)
                op_list.append(out_var[0].op)
        if len(op_list) > 0:
          with tf.get_default_graph().control_dependencies(op_list):
            save_tensor = self._AddSaveOps(filename_tensor, saveables)
        else:
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
      if build_restore:
        restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                         restore_sequentially, reshape)

    # In the following use case, it's possible to have restore_ops be called
    # something else:
    # - Build inference graph and export a meta_graph.
    # - Import the inference meta_graph.
    # - Extend the inference graph to a train graph.
    # - Export a new meta_graph.
    # Now the second restore_op will be called "restore_all_1".
    # As such, comment out the assert for now until we know whether supporting
    # such a usage model makes sense.
    #
    # assert restore_op.name.endswith("restore_all"), restore_op.name

    if context.executing_eagerly():
      # Store the tensor values to the tensor names.
      save_tensor_name = save_tensor.numpy() if build_save else ""
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.numpy(),
          save_tensor_name=save_tensor_name,
          restore_op_name="",
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)
    else:
      graph = ops.get_default_graph()
      # Do some sanity checking on collections containing
      # PartitionedVariables. If a saved collection has a PartitionedVariable,
      # the GraphDef needs to include concat ops to get the value (or there'll
      # be a lookup error on load).
      check_collection_list = graph.get_all_collection_keys()
      for collection_type in check_collection_list:
        for element in graph.get_collection(collection_type):
          if isinstance(element, variables.PartitionedVariable):
            try:
              graph.get_operation_by_name(element.name)
            except KeyError:
              # Create a concat op for this PartitionedVariable. The user may
              # not need it, but we'll try looking it up on MetaGraph restore
              # since it's in a collection.
              element.as_tensor()
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.name,
          save_tensor_name=save_tensor.name,
          restore_op_name=restore_op.name,
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)
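# Usage sketch (hedged): _build_internal() is normally reached via
# Saver.build(). The Save_Weight_Update_Sharding scope broadcasts each sharded
# weight from its root rank before the save ops run, so the checkpointed copy
# is consistent across workers. Illustrative flow (names hypothetical):
#
#   saver = Saver(var_list=tf.trainable_variables())  # build() invokes
#   print(saver.saver_def.save_tensor_name)           # _build_internal()
#   print(saver.saver_def.restore_op_name)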