def test_save_checkpoint_keep_interval_timedelta(self, checkpoint_type):
  config_name = 'test.test_module.ConfigName'
  root_dir = os.path.join(FLAGS.test_tmpdir, 'test4', str(checkpoint_type),
                          'checkpoints')
  tf.io.gfile.makedirs(root_dir)
  current_datetime = datetime.datetime.now()
  zero_datetime = datetime.datetime.fromtimestamp(0)
  # Freeze time so the manager sees deterministic creation timestamps.
  with mock.patch('datetime.datetime', autospec=True) as dt:
    dt.utcnow.return_value = current_datetime
    dt.fromtimestamp.return_value = zero_datetime
    checkpoint_manager = checkpoint_managers.CheckpointManager(
        config_name=config_name,
        root_dir=root_dir,
        checkpoint_type=checkpoint_type,
        save_interval_steps=1000,
        max_to_keep=2,
        keep_interval_timedelta=datetime.timedelta(hours=2))
  steps = list(range(0, 10000, 1000))
  checkpoint_datetimes = []
  for step in steps:
    with mock.patch('datetime.datetime', autospec=True) as dt:
      dt.utcnow.return_value = current_datetime
      dt.fromtimestamp.return_value = zero_datetime
      if checkpoint_manager.should_save(step):
        _create_dummy_checkpoint(root_dir, step, checkpoint_type)
        checkpoint_manager.save_metadata(step)
        checkpoint_datetimes.append(current_datetime)
      # Advance the mocked clock one hour per 1000-step save.
      current_datetime += datetime.timedelta(hours=1)
  # Every other checkpoint lands on the 2-hour keep interval and is retained
  # permanently; the two most recent (8000, 9000) survive via max_to_keep=2.
  saved_steps = [0, 2000, 4000, 6000, 8000, 9000]
  saved_checkpoint_datetimes = checkpoint_datetimes[::2] + [
      checkpoint_datetimes[-1]
  ]
  filenames = [
      os.path.basename(v) for v in tf.io.gfile.glob(
          os.path.join(root_dir, f'{CHECKPOINT_PREFIX}*'))
  ]
  self.assertSameElements(
      filenames, _base_checkpoint_filenames(saved_steps, checkpoint_type))
  checkpoints_filename = os.path.join(root_dir,
                                      checkpoint_managers.CHECKPOINT_BASENAME)
  expected_proto = _create_reference_checkpoint_history(
      config_name, root_dir, checkpoint_type, saved_steps,
      saved_checkpoint_datetimes)
  self.assertCheckpointsFileProto(checkpoints_filename, expected_proto)
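
# The expectation asserted above can be reproduced with a small standalone
# sketch of the retention policy under test. This is illustrative only, not
# the actual CheckpointManager implementation; _sketch_kept_steps and its
# arguments are hypothetical names. The policy: a checkpoint is kept
# permanently once at least keep_interval has elapsed since the last
# permanently kept one, and the max_to_keep most recent checkpoints survive
# regardless. (Relies on this module's existing `import datetime`.)
def _sketch_kept_steps(step_times, max_to_keep, keep_interval):
  kept_forever = []
  last_kept_time = None
  for step, saved_at in step_times:
    if last_kept_time is None or saved_at - last_kept_time >= keep_interval:
      kept_forever.append(step)  # Keep-interval reached: retain permanently.
      last_kept_time = saved_at
  recent = [step for step, _ in step_times[-max_to_keep:]]
  return sorted(set(kept_forever + recent))

# With saves an hour apart and a two-hour keep interval, this yields the same
# steps the test asserts: [0, 2000, 4000, 6000, 8000, 9000].
_base = datetime.datetime(2022, 1, 1)
assert _sketch_kept_steps(
    [(i * 1000, _base + datetime.timedelta(hours=i)) for i in range(10)],
    max_to_keep=2,
    keep_interval=datetime.timedelta(hours=2)) == [
        0, 2000, 4000, 6000, 8000, 9000]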
def _create_checkpoint_manager(
    model_name: str, task_p: InstantiableParams, job_log_dir: str,
    checkpoint_type: CheckpointType,
    todelete_subdir: Optional[str]) -> checkpoint_managers.CheckpointManager:
  """Creates a checkpoint manager configured from the task's train params."""
  checkpoint_dir = _checkpoint_dir(job_log_dir)
  train_p = task_p.train
  max_to_keep = train_p.save_max_to_keep
  save_interval_steps = train_p.save_interval_steps
  keep_interval_timedelta = _parse_duration(
      train_p.save_keep_interval_duration)
  return checkpoint_managers.CheckpointManager(
      config_name=model_name,
      root_dir=checkpoint_dir,
      checkpoint_type=checkpoint_type,
      max_to_keep=max_to_keep,
      save_interval_steps=save_interval_steps,
      keep_interval_timedelta=keep_interval_timedelta,
      todelete_subdir=todelete_subdir)
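
# _parse_duration is not shown in this excerpt. A minimal sketch of the kind
# of parser it refers to is below; this is a hypothetical stand-in, not the
# library's actual implementation, and _sketch_parse_duration plus its unit
# table are invented names. It turns a short duration string such as '12h'
# into a datetime.timedelta, passing None through unchanged.
import datetime
import re
from typing import Optional

_SKETCH_UNITS = {'s': 'seconds', 'm': 'minutes', 'h': 'hours', 'd': 'days'}


def _sketch_parse_duration(
    duration_str: Optional[str]) -> Optional[datetime.timedelta]:
  """Parses strings like '30m' or '12h' into a timedelta (sketch only)."""
  if not duration_str:
    return None
  match = re.fullmatch(r'(\d+)([smhd])', duration_str)
  if match is None:
    raise ValueError(f'Unparseable duration: {duration_str!r}')
  value, unit = int(match.group(1)), match.group(2)
  return datetime.timedelta(**{_SKETCH_UNITS[unit]: value})


assert _sketch_parse_duration('12h') == datetime.timedelta(hours=12)
assert _sketch_parse_duration(None) is None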