def testContinueFromUnmanaged(self):
  directory = self.get_temp_dir()
  prefix = os.path.join(directory, "unusual_prefix")
  checkpoint = util.Checkpoint()
  first_path = checkpoint.save(prefix)
  second_path = checkpoint.save(prefix)
  del checkpoint
  checkpoint = util.Checkpoint()
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
  self.assertEqual(2, self.evaluate(checkpoint.save_counter))
  third_path = manager.save()
  self.assertEqual([third_path], manager.checkpoints)
  fourth_path = manager.save()
  self.assertEqual([third_path, fourth_path], manager.checkpoints)
  fifth_path = manager.save()
  self.assertEqual([fourth_path, fifth_path], manager.checkpoints)
  self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
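A minimal public-API sketch of the same "continue from unmanaged" pattern, assuming TF 2.x and a hypothetical directory; `tf.train.Checkpoint` and `tf.train.CheckpointManager` are the exported counterparts of the internals exercised above:

import os
import tensorflow as tf

directory = "/tmp/unmanaged_demo"  # hypothetical path
ckpt = tf.train.Checkpoint()
ckpt.save(os.path.join(directory, "unusual_prefix"))  # unmanaged save

manager = tf.train.CheckpointManager(ckpt, directory, max_to_keep=2)
ckpt.restore(manager.latest_checkpoint)  # picks up the unmanaged file
manager.save()  # later saves are managed; only they count toward max_to_keep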
def testDeletion(self):
  checkpoint = util.Checkpoint()
  manager = checkpoint_management.CheckpointManager(
      checkpoint, self.get_temp_dir(), max_to_keep=3)
  first_path = manager.save()
  second_path = manager.save()
  third_path = manager.save()
  fourth_path = manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
def testRemoveCheckpoint(self):
  for sharded in (False, True):
    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
      with self.session(graph=ops_lib.Graph()) as sess:
        unused_v = variables.Variable(1.0, name="v")
        variables.global_variables_initializer().run()
        saver = saver_module.Saver(sharded=sharded, write_version=version)
        path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
        ckpt_prefix = saver.save(sess, path)
        self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
        checkpoint_management.remove_checkpoint(ckpt_prefix, version)
        self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
def testClockReset(self, mock_time):
  directory = self.get_temp_dir()
  mock_time.time.return_value = 10000.
  checkpoint = util.Checkpoint()
  first_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
  first_path = first_manager.save()
  mock_time.time.return_value += 3600.
  second_path = first_manager.save()
  mock_time.time.return_value += 3600.
  third_path = first_manager.save()
  self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertEqual([third_path], first_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(13600., state.last_preserved_timestamp)
  # Set the clock back in time.
  mock_time.time.return_value = 5000.
  del first_manager
  with test.mock.patch.object(logging, "warning") as mock_log:
    second_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=1)
    self.assertRegexpMatches(
        str(mock_log.call_args),
        "behind the last preserved checkpoint timestamp")
  # We should err on the side of keeping checkpoints around when we're not
  # sure whether they were preserved or not due to clock funkiness.
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  # We know about the existing checkpoints, but they'll never be deleted and
  # so won't go in the CheckpointState proto on save.
  self.assertEqual(third_path, second_manager.latest_checkpoint)
  self.assertEqual([], second_manager.checkpoints)
  mock_time.time.return_value += 10.
  fourth_path = second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertEqual(fourth_path, second_manager.latest_checkpoint)
  self.assertEqual([fourth_path], second_manager.checkpoints)
  mock_time.time.return_value += 10.
  fifth_path = second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertEqual([fifth_path], second_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(5000., state.last_preserved_timestamp)
  self.assertEqual([5020.], state.all_model_checkpoint_timestamps)
def testCheckpointInterval(self):
  v = variables.Variable(1.0)
  step_counter = variables.Variable(0)
  self.evaluate([v.initializer, step_counter.initializer])
  checkpoint = util.Checkpoint(v=v)
  manager = checkpoint_management.CheckpointManager(
      checkpoint,
      self.get_temp_dir(),
      max_to_keep=None,
      step_counter=step_counter,
      checkpoint_interval=2)
  # step_counter: 0, save an initial checkpoint.
  path = manager.save(check_interval=True)
  self.assertTrue(checkpoint_management.checkpoint_exists(path))
  # step_counter: 1, no checkpoint saved.
  self.evaluate(step_counter.assign_add(1))
  path = manager.save(check_interval=True)
  self.assertIsNone(path)
  # step_counter: 2, checkpoint saved.
  self.evaluate(step_counter.assign_add(1))
  path = manager.save(check_interval=True)
  self.assertTrue(checkpoint_management.checkpoint_exists(path))
  # No checkpoint saved when calling `save` with the same step counter.
  path = manager.save(check_interval=True)
  self.assertIsNone(path)
  # step_counter: 3, no checkpoint saved.
  self.evaluate(step_counter.assign_add(1))
  path = manager.save(check_interval=True)
  self.assertIsNone(path)
  # Always save the checkpoint.
  path = manager.save(check_interval=False)
  self.assertTrue(checkpoint_management.checkpoint_exists(path))
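The interval logic exercised above maps onto a training loop like the following sketch (TF 2.x public API; the loop body and directory are hypothetical):

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, "/tmp/interval_demo", max_to_keep=3,
    step_counter=step, checkpoint_interval=2)

for _ in range(6):
  step.assign_add(1)  # stand-in for a real training step
  # Writes a checkpoint only when `step` has advanced by at least
  # `checkpoint_interval` since the last save; returns None otherwise.
  path = manager.save(check_interval=True)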
def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
  """Waits for a checkpoint (or any glob match) to appear.

  Args:
    pattern: A string; a checkpoint prefix or a file glob.
    timeout_secs: How long to wait, in seconds.
    for_checkpoint: Whether `pattern` names a checkpoint (as opposed to a
      plain file glob).
  """
  end_time = time.time() + timeout_secs
  while time.time() < end_time:
    if for_checkpoint:
      if checkpoint_management.checkpoint_exists(pattern):
        return
    else:
      if len(gfile.Glob(pattern)) >= 1:
        return
    time.sleep(0.05)
  self.assertFalse(True, "Glob never matched any file: %s" % pattern)
def deploy_checkpoint(input_meta_graph_def, input_checkpoint, q_config):
  if not checkpoint_management.checkpoint_exists(input_checkpoint):
    raise ValueError("Input checkpoint '" + input_checkpoint +
                     "' does not exist.")
  if gfile.IsDirectory(input_checkpoint):
    input_checkpoint = checkpoint_management.latest_checkpoint(
        input_checkpoint)
  if not os.path.exists(q_config.output_dir):
    os.makedirs(q_config.output_dir)

  quantize_eval_graph_def = None
  if input_meta_graph_def:
    quantize_eval_graph_def = input_meta_graph_def.graph_def
  else:
    raise ValueError("You need to provide a `MetaGraphDef` for deploy.")

  q_config.output_nodes = get_quantized_nodes(quantize_eval_graph_def,
                                              q_config.output_nodes)
  saver = saver_lib.import_meta_graph(input_meta_graph_def, clear_devices=True)
  with Session() as sess:
    saver.restore(sess, input_checkpoint)
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, quantize_eval_graph_def, q_config.output_nodes)

  if not os.path.exists(os.path.join(q_config.output_dir, "deploy")):
    os.makedirs(os.path.join(q_config.output_dir, "deploy"))
  quantize_deploy_graph_def = CreateQuantizeDeployGraphDef(
      frozen_graph_def, q_config)
  save_pb_file(quantize_deploy_graph_def,
               os.path.join(q_config.output_dir, "deploy/deploy_model.pb"))
  print("INFO: Quantize deploy graph is generated in: {}".format(
      os.path.join(q_config.output_dir, "deploy")))
  return
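A minimal call sketch, assuming the meta graph lives next to the checkpoint as "model.ckpt.meta" and that `q_config` is an existing config object exposing at least the `output_dir` and `output_nodes` attributes used above (all concrete names hypothetical):

from tensorflow.python.framework import meta_graph

# Hypothetical paths; q_config comes from the surrounding quantization tool.
input_meta_graph_def = meta_graph.read_meta_graph_file("model.ckpt.meta")
deploy_checkpoint(input_meta_graph_def, "model.ckpt", q_config)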
def patched_restore(self, sess, save_path, options=None):  # type: ignore
  """Restores previously saved variables.

  This method runs the ops added by the constructor for restoring variables.
  It requires a session in which the graph was launched. The variables to
  restore do not have to have been initialized, as restoring is itself a way
  to initialize variables.

  The `save_path` argument is typically a value previously returned from a
  `save()` call, or a call to `latest_checkpoint()`.

  Args:
    sess: A `Session` to use to restore the parameters. None in eager mode.
    save_path: Path where parameters were previously saved.
    options: Optional `RunOptions` forwarded to `Session.run` for the restore
      op (graph mode only).

  Raises:
    ValueError: If save_path is None or not a valid checkpoint.
  """
  if self._is_empty:
    return
  if save_path is None:
    raise ValueError("Can't load save_path when it is None.")

  checkpoint_prefix = compat.as_text(save_path)
  if not checkpoint_management.checkpoint_exists(checkpoint_prefix):
    raise ValueError("The passed save_path is not a valid checkpoint: " +
                     checkpoint_prefix)

  logging.info("Restoring parameters from %s", checkpoint_prefix)
  try:
    if context.executing_eagerly():
      self._build_eager(save_path, build_save=False, build_restore=True)
    else:
      sess.run(
          self.saver_def.restore_op_name,
          {self.saver_def.filename_tensor_name: save_path},
          options=options,
      )
  except errors.NotFoundError as err:
    # There are three common conditions that might cause this error:
    # 0. The file is missing. We ignore here, as this is checked above.
    # 1. This is an object-based checkpoint trying name-based loading.
    # 2. The graph has been altered and a variable or other name is missing.

    # 1. The checkpoint would not be loaded successfully as is. Try to parse
    # it as an object-based checkpoint.
    try:
      names_to_keys = object_graph_key_mapping(save_path)
    except errors.NotFoundError:
      # 2. This is not an object-based checkpoint, which likely means there
      # is a graph mismatch. Re-raise the original error with a helpful
      # message (b/110263146).
      raise _wrap_restore_error_with_msg(
          err, "a Variable name or other graph key that is missing")

    # This is an object-based checkpoint. We'll print a warning and then do
    # the restore.
    logging.warning(
        "Restoring an object-based checkpoint using a name-based saver. This "
        "may be somewhat fragile, and will re-build the Saver. Instead, "
        "consider loading object-based checkpoints using "
        "tf.train.Checkpoint().")
    self._object_restore_saver = saver_from_object_based_checkpoint(
        checkpoint_path=save_path,
        var_list=self._var_list,
        builder=self._builder,
        names_to_keys=names_to_keys,
        cached_saver=self._object_restore_saver,
    )
    self._object_restore_saver.restore(
        sess=sess, save_path=save_path, options=options)
  except errors.InvalidArgumentError as err:
    # There is a mismatch between the graph and the checkpoint being loaded.
    # We add a more reasonable error message here to help users (b/110263146).
    raise _wrap_restore_error_with_msg(
        err, "a mismatch between the current graph and the graph")
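Given the explicit `self` parameter and the `# type: ignore`, this override is presumably installed by monkey-patching; a sketch of how that might look, with the target class an assumption rather than something this snippet confirms:

# Hypothetical installation: replace the name-based Saver's restore method
# so every restore gains the `options` parameter and the fallback logic above.
saver_module.Saver.restore = patched_restore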
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_denylist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run
      before freezing.
    variable_names_whitelist: The set of variable names to convert (optional,
      by default, all variables are converted).
    variable_names_denylist: The set of variable names to omit converting to
      constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
      and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
      load, in string format (optional).
    checkpoint_version: TensorFlow variable file format (saver_pb2.SaverDef.V1
      or saver_pb2.SaverDef.V2).

  Returns:
    Location of the output_graph_def.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError("Input checkpoint '" + input_checkpoint +
                     "' doesn't exist!")

  if not output_node_names:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")

  # Remove all the explicit device specifications for this node. This helps
  # to make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          raise ValueError(
              "Models containing partition variables cannot be converted "
              "from checkpoint files. Please pass in a SavedModel using "
              "the flag --input_saved_model_dir.")
        # Models that have been frozen previously do not contain Variables.
        elif _has_no_variables(sess):
          raise ValueError(
              "No variables were found in this model. It is likely the model "
              "was frozen previously. You cannot freeze a graph twice.")
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_denylist = (
        variable_names_denylist.replace(" ", "").split(",")
        if variable_names_denylist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_denylist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_denylist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
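A minimal usage sketch, assuming a binary-serialized `GraphDef` at "graph.pb", a checkpoint prefix "model.ckpt", and an output node named "softmax" (all three names hypothetical):

from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile

graph_def = graph_pb2.GraphDef()
with gfile.GFile("graph.pb", "rb") as f:
  graph_def.ParseFromString(f.read())

frozen = freeze_graph_with_def_protos(
    input_graph_def=graph_def,
    input_saver_def=None,
    input_checkpoint="model.ckpt",
    output_node_names="softmax",
    restore_op_name="",  # unused
    filename_tensor_name="",  # unused
    output_graph="frozen.pb",
    clear_devices=True,
    initializer_nodes="")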
def testKeepAll(self):
  checkpoint = util.Checkpoint()
  directory = os.path.join(
      self.get_temp_dir(),
      # Avoid sharing directories between eager and graph
      # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
      str(context.executing_eagerly()))
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=None)
  first_path = manager.save()
  second_path = manager.save()
  third_path = manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
  self.assertEqual(third_path, manager.latest_checkpoint)
  self.assertEqual([first_path, second_path, third_path],
                   manager.checkpoints)
  del manager
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=None)
  fourth_path = manager.save()
  self.assertEqual([first_path, second_path, third_path, fourth_path],
                   manager.checkpoints)
  del manager
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=3)
  self.assertEqual([first_path, second_path, third_path, fourth_path],
                   manager.checkpoints)
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
  fifth_path = manager.save()
  self.assertEqual([third_path, fourth_path, fifth_path],
                   manager.checkpoints)
  self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
  self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
def testSaveRestoreState(self, mock_time):
  directory = self.get_temp_dir()
  mock_time.time.return_value = 3.
  checkpoint = util.Checkpoint()
  first_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  first_time = 10000.
  first_name = os.path.join(directory, "ckpt-1")
  mock_time.time.return_value = first_time
  first_manager.save()
  state = checkpoint_management.get_checkpoint_state(directory)
  second_time = first_time + 3610.
  second_name = os.path.join(directory, "ckpt-2")
  mock_time.time.return_value = second_time
  first_manager.save()
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual([first_time, second_time],
                   state.all_model_checkpoint_timestamps)
  self.assertEqual([first_name, second_name], first_manager.checkpoints)
  self.assertEqual(second_name, first_manager.latest_checkpoint)
  del first_manager

  second_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2,
      keep_checkpoint_every_n_hours=1.5)
  self.assertEqual([first_name, second_name], second_manager.checkpoints)
  self.assertEqual(second_name, second_manager.latest_checkpoint)
  third_name = os.path.join(directory, "ckpt-3")
  third_time = second_time + 3600. * 0.2
  mock_time.time.return_value = third_time
  second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
  self.assertEqual([second_name, third_name], second_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(first_time, state.last_preserved_timestamp)
  fourth_time = third_time + 3600. * 0.5
  mock_time.time.return_value = fourth_time
  fourth_name = os.path.join(directory, "ckpt-4")
  second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
  self.assertEqual([third_name, fourth_name], second_manager.checkpoints)
  fifth_time = fourth_time + 3600. * 0.5
  mock_time.time.return_value = fifth_time
  fifth_name = os.path.join(directory, "ckpt-5")
  second_manager.save()
  self.assertEqual([fourth_name, fifth_name], second_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(first_time, state.last_preserved_timestamp)
  del second_manager

  third_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2,
      keep_checkpoint_every_n_hours=1.5)
  self.assertEqual(fifth_name, third_manager.latest_checkpoint)
  mock_time.time.return_value += 10.
  third_manager.save()
  sixth_name = os.path.join(directory, "ckpt-6")
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(fourth_time, state.last_preserved_timestamp)
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
  self.assertEqual([fifth_name, sixth_name], third_manager.checkpoints)
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps
  # to make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run
      before freezing.
    variable_names_whitelist: The set of variable names to convert (optional,
      by default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting to
      constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
      and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
      load, in string format (optional).
    checkpoint_version: TensorFlow variable file format (saver_pb2.SaverDef.V1
      or saver_pb2.SaverDef.V2).

  Returns:
    Location of the output_graph_def.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps
  # to make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        # Models that have been frozen previously do not contain Variables.
        elif _has_no_variables(sess):
          print("No variables were found in this model. It is likely the "
                "model was frozen previously. You cannot freeze a graph "
                "twice.")
          return 0
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def