def _get_checkpoint_filename(ckpt_dir_or_file):
    """Resolves a checkpoint directory or file path to a checkpoint filename.

    Args:
      ckpt_dir_or_file: A checkpoint directory or a specific checkpoint file,
        given either as a string or as an `os.PathLike`.

    Returns:
      For a directory, whatever `checkpoint_management.latest_checkpoint`
      reports for it (possibly `None` when nothing is found). Otherwise, the
      input itself, converted with `os.fspath` if it was path-like.
    """
    path = (os.fspath(ckpt_dir_or_file)
            if isinstance(ckpt_dir_or_file, os.PathLike) else ckpt_dir_or_file)
    if not gfile.IsDirectory(path):
        return path
    return checkpoint_management.latest_checkpoint(path)
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
    """Polls `checkpoint_dir` until a checkpoint other than `last_checkpoint` appears.

    Args:
      checkpoint_dir: The directory in which checkpoints are saved.
      last_checkpoint: The last checkpoint path used, or `None` if no
        checkpoint has been seen yet.
      seconds_to_sleep: Seconds to sleep between successive polls.
      timeout: Maximum number of seconds to wait. `None` means wait forever.

    Returns:
      The new checkpoint path, or `None` if the timeout was reached first.
    """
    logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
    deadline = None if timeout is None else time.time() + timeout
    while True:
        candidate = checkpoint_management.latest_checkpoint(checkpoint_dir)
        if candidate is not None and candidate != last_checkpoint:
            logging.info("Found new checkpoint at %s", candidate)
            return candidate
        # Give up rather than sleep past the deadline.
        if deadline is not None and time.time() + seconds_to_sleep > deadline:
            return None
        time.sleep(seconds_to_sleep)
def test_spmd_model_checkpointing(self):
    """Checkpoint save/restore round-trips the weights of an SPMD model."""

    class LinearModel(module.Module):

        def __init__(self, w):
            super(LinearModel, self).__init__()
            self.w = variables.Variable(w)

        def __call__(self, x):
            return math_ops.matmul(x, self.w)

        def change_weights_op(self, w_new):
            return self.w.assign(w_new)

    batch_size, num_feature_in, num_feature_out = 32, 16, 8
    weight_shape = (num_feature_in, num_feature_out)
    # Random tensors are created in the same order as before: saved weights,
    # replacement weights, then the input batch.
    saved_weights = random_ops.random_uniform(
        weight_shape, dtype=dtypes.float32)
    replacement_weights = random_ops.random_uniform(
        weight_shape, dtype=dtypes.float32)
    inputs = random_ops.random_uniform(
        (batch_size, num_feature_in), dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
        model = LinearModel(saved_weights)

    ckpt_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
    ckpt = util.Checkpoint(model=model)

    @def_function.function
    def step_fn(batch):
        # Shard the input across logical devices with partition spec [1, 2].
        batch = strategy.experimental_split_to_logical_devices(batch, [1, 2])
        return model(batch)

    def check_output(weights):
        # The SUM-reduced result should be num_replicas copies of
        # inputs @ weights.
        per_replica = strategy.run(step_fn, args=(inputs,))
        self.assertAllClose(
            math_ops.matmul(inputs, weights) * num_replicas,
            self.evaluate(strategy.reduce("SUM", per_replica, axis=None)),
            rtol=5e-3,
            atol=5e-3)

    with self.cached_session() as sess:
        self.evaluate(variables.global_variables_initializer())
        ckpt.save(file_prefix=ckpt_prefix)

        # Overwrite the weights and confirm the model sees the new values.
        self.evaluate(model.change_weights_op(replacement_weights))
        check_output(replacement_weights)

        # Restore the checkpoint and confirm the saved weights are back.
        status = ckpt.restore(
            checkpoint_management.latest_checkpoint(ckpt_dir))
        status.run_restore_ops(sess)  # must run restore op in non-eager mode.
        status.assert_consumed()
        status.assert_existing_objects_matched()
        check_output(saved_weights)
def _restore_or_save_initial_ckpt(self, session):
    """Restores the latest checkpoint, or saves an initial one if none exists.

    Ideally this would run in `after_create_session`, but there is currently
    no way to enforce an order of running the `SessionRunHooks`. The
    `_DatasetInitializerHook` may therefore run *after* this hook, which is
    troublesome because:
      1. If a checkpoint exists and this hook restores it, the initializer
         hook will override it.
      2. If no checkpoint exists, this hook would try to save an
         uninitialized iterator, which results in an exception.

    As a temporary fix, this hook and `_DatasetInitializerHook` follow an
    implicit contract:
      1. `_DatasetInitializerHook` initializes the iterator in its call to
         `after_create_session`.
      2. This hook saves the iterator on the first call to `before_run()`,
         which is guaranteed to happen after `after_create_session()` of all
         hooks has run.

    Args:
      session: Session to restore into, or to read the global step from and
        save with when no checkpoint exists yet.
    """
    # pylint: disable=protected-access
    saver_hook = self._checkpoint_saver_hook
    latest_checkpoint_path = checkpoint_management.latest_checkpoint(
        saver_hook._checkpoint_dir, latest_filename=self._latest_filename)

    if not latest_checkpoint_path:
        # No checkpoint yet: save the state at step "global_step".
        # Note: the GraphDef / MetaGraphDef are not saved here.
        step = session.run(saver_hook._global_step_tensor)
        saver_hook._save(session, step)
        saver_hook._timer.update_last_triggered_step(step)
        return

    saver_hook._get_saver().restore(session, latest_checkpoint_path)
    # pylint: enable=protected-access
def testLatestCheckpointFSpathDirectory(self):
    """`latest_checkpoint` accepts a `pathlib.Path` checkpoint directory."""
    ckpt_dir = pathlib.Path(self.get_temp_dir())
    manager = checkpoint_management.CheckpointManager(
        util.Checkpoint(), ckpt_dir, max_to_keep=2, checkpoint_name="ckpt_name")
    manager.save()
    found = checkpoint_management.latest_checkpoint(ckpt_dir)
    self.assertEqual(str(ckpt_dir / "ckpt_name-1"), found)
def _read_vars(self, model_dir):
    """Restores the newest checkpoint in `model_dir` and evaluates 'my_vars'.

    The checkpoint's MetaGraph is imported into a fresh graph, its variables
    are restored, and the tensors in the 'my_vars' graph collection are
    evaluated.

    Args:
      model_dir: Directory containing the checkpoint(s) to read.

    Returns:
      The evaluated values of the 'my_vars' collection. Per the original
      docstring, this is expected to be (global_step, latest_feature).
    """
    with ops.Graph().as_default() as g:
        ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
        # Fail with a readable message instead of an opaque TypeError from
        # `None + '.meta'` when no checkpoint has been written.
        self.assertIsNotNone(ckpt_path, 'No checkpoint found in %s' % model_dir)
        meta_filename = ckpt_path + '.meta'
        saver_lib.import_meta_graph(meta_filename)
        saver = saver_lib.Saver()
        with self.session(graph=g) as sess:
            saver.restore(sess, ckpt_path)
            return sess.run(ops.get_collection('my_vars'))
def test_paritioned_model_checkpointing(self):
    """Checkpoint round-trip for a model split across two logical devices."""
    # NOTE(review): "paritioned" is a typo, kept so the test ID is unchanged.

    class PartitionedModel(module.Module):

        def __init__(self, v, w):
            super(PartitionedModel, self).__init__()
            assert distribution_strategy_context.has_strategy()
            current_strategy = distribution_strategy_context.get_strategy()
            with current_strategy.extended.experimental_logical_device(0):
                self.v = variables.Variable(v)
            with current_strategy.extended.experimental_logical_device(1):
                self.w = variables.Variable(w)

        def __call__(self, x):
            ctx = distribution_strategy_context.get_replica_context()
            with ctx.experimental_logical_device(0):
                scaled = self.v * x
            with ctx.experimental_logical_device(1):
                product = self.w * scaled
            return product

        def change_weights_op(self, v_new, w_new):
            return control_flow_ops.group(
                [self.v.assign(v_new), self.w.assign(w_new)])

    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
        model = PartitionedModel(2., 3.)

    ckpt_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
    ckpt = util.Checkpoint(model=model)

    def summed_output():
        # Runs model(5.0) == w * (v * 5.0) and SUM-reduces over replicas.
        per_replica = strategy.run(def_function.function(model), args=(5.0,))
        return self.evaluate(strategy.reduce("SUM", per_replica, axis=None))

    with self.cached_session() as sess:
        self.evaluate(variables.global_variables_initializer())
        ckpt.save(file_prefix=ckpt_prefix)

        # v=1, w=4  ->  4 * (1 * 5) = 20 per replica.
        self.evaluate(model.change_weights_op(1., 4.))
        self.assertEqual(20. * num_replicas, summed_output())

        # Restoring brings back v=2, w=3  ->  3 * (2 * 5) = 30 per replica.
        status = ckpt.restore(
            checkpoint_management.latest_checkpoint(ckpt_dir))
        status.run_restore_ops(sess)  # must run restore op in non-eager mode.
        status.assert_consumed()
        status.assert_existing_objects_matched()
        self.assertEqual(30. * num_replicas, summed_output())
def testRestoreInReconstructedIteratorInitializable(self):
    """The iterator position survives repeated restore/read/save cycles."""
    ckpt_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
    dataset = dataset_ops.Dataset.range(10)
    iterator = iter(dataset)
    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
    expected = 0
    for _ in range(5):
        # Presumably there is no checkpoint on the first pass, so this just
        # initializes. Later passes restore the previous save.
        latest = checkpoint_management.latest_checkpoint(ckpt_dir)
        checkpoint.restore(latest).initialize_or_restore()
        for _ in range(2):
            self.assertEqual(expected, self.evaluate(iterator.get_next()))
            expected += 1
        checkpoint.save(file_prefix=ckpt_prefix)
def testNameCollision(self):
    """A save prefix named "checkpoint" collides with the state file name."""
    # Work in a clean temporary directory and chdir into it for the test.
    with self.tempDir() as tempdir, self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train"
        os.mkdir(traindir)
        # Collides with the default name of the checkpoint state file.
        filepath = os.path.join(traindir, "checkpoint")
        with self.cached_session() as sess:
            unused_a = variables.Variable(0.0)  # So that Saver saves something.
            self.evaluate(variables.global_variables_initializer())

            saver = saver_module.Saver(sharded=False)
            # Should fail: the bare prefix is exactly the state file's name.
            with self.assertRaisesRegex(ValueError, "collides with"):
                saver.save(sess, filepath)
            # Succeeds: the file will be named "checkpoint-<step>".
            saver.save(sess, filepath, global_step=1)
            self.assertIsNotNone(
                checkpoint_management.latest_checkpoint(traindir))

            # Sharded saves also succeed: first "checkpoint-<i>-of-<n>",
            # then "checkpoint-<step>-<i>-of-<n>".
            for extra_kwargs in ({}, {"global_step": 1}):
                saver = saver_module.Saver(sharded=True)
                saver.save(sess, filepath, **extra_kwargs)
                self.assertIsNotNone(
                    checkpoint_management.latest_checkpoint(traindir))
def testRelativePath(self):
    """Checkpoints saved under a relative directory can be found and restored."""
    # Work in a clean temporary directory and chdir into it for the test.
    with self.tempDir() as tempdir, self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train"
        os.mkdir(traindir)
        filepath = os.path.join(traindir, "snapshot")

        with self.cached_session() as sess:
            counter = variables.Variable(0.0)
            increment = counter.assign_add(1.0)
            saver = saver_module.Saver({"v0": counter})
            # Record a short history: the variable is 0, 1, 2 at steps 0, 1, 2.
            self.evaluate(variables.global_variables_initializer())
            saver.save(sess, filepath, global_step=0)
            for step in (1, 2):
                self.evaluate(increment)
                saver.save(sess, filepath, global_step=step)

        with self.cached_session() as sess:
            # A new variable with a different initial value, saved as "v0".
            restored = variables.Variable(-1.0)
            saver = saver_module.Saver({"v0": restored})
            self.evaluate(variables.global_variables_initializer())
            # Find the most recent checkpoint via the relative directory.
            latest = checkpoint_management.latest_checkpoint(traindir)
            self.assertIsNotNone(latest)
            saver.restore(sess, latest)
            self.assertEqual(restored.eval(), 2.0)
def testCheckpointExists(self):
    """checkpoint_exists is False before a save and True for saved prefixes."""
    write_versions = (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1)
    cases = [(sharded, version)
             for sharded in (False, True)
             for version in write_versions]
    for sharded, version in cases:
        with self.session(graph=ops_lib.Graph()) as sess:
            unused_v = variables.Variable(1.0, name="v")
            self.evaluate(variables.global_variables_initializer())
            saver = saver_module.Saver(sharded=sharded, write_version=version)
            path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
            # Nothing has been saved to `path` yet.
            self.assertFalse(checkpoint_management.checkpoint_exists(path))
            saved_prefix = saver.save(sess, path)
            self.assertTrue(
                checkpoint_management.checkpoint_exists(saved_prefix))
            newest_prefix = checkpoint_management.latest_checkpoint(
                self._base_dir)
            self.assertTrue(
                checkpoint_management.checkpoint_exists(newest_prefix))
def testCustomNumbering(self):
    """CheckpointManager.save uses an explicit `checkpoint_number` as suffix."""
    directory = self.get_temp_dir()
    step = variables.Variable(0, dtype=dtypes.int64)
    checkpoint = util.Checkpoint(step=step)
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    self.evaluate(step.initializer)

    # Number checkpoints from a variable that advances by 2: 0, 2, 4, 6, 8.
    for expected_number in range(0, 10, 2):
        path = manager.save(checkpoint_number=step)
        expected_suffix = "-%d" % (expected_number,)
        if not path.endswith(expected_suffix):
            self.fail("%s should have suffix %s" % (path, expected_suffix))
        self.evaluate(step.assign_add(2))
    self.assertEqual(5, self.evaluate(checkpoint.save_counter))

    # Plain Python integers are accepted as well.
    last_path = manager.save(checkpoint_number=32)
    self.assertIn("-32", last_path)
    self.assertEqual(last_path, manager.latest_checkpoint)
    self.assertEqual(last_path,
                     checkpoint_management.latest_checkpoint(directory))
    state = checkpoint_management.get_checkpoint_state(directory)
    # max_to_keep=2: only the last loop save and the "-32" save remain.
    self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
def _latest_ckpt(self):
    """Returns `latest_checkpoint` for this test's temporary directory."""
    temp_dir = self.get_temp_dir()
    return checkpoint_management.latest_checkpoint(temp_dir)