Exemplo n.º 1
0
    def test_save_checkpoint_keep_interval_timedelta(self, checkpoint_type):
        """Verifies checkpoint retention driven by `keep_interval_timedelta`.

        A mocked clock advances one hour per step. Steps 0..9000 are offered
        to a manager configured with `save_interval_steps=1000`,
        `max_to_keep=2` and a two-hour keep interval. The test expects the
        checkpoints for steps 0, 2000, 4000, 6000, 8000 and 9000 to remain on
        disk and to be recorded in the checkpoints history file.
        """
        config_name = 'test.test_module.ConfigName'
        root_dir = os.path.join(FLAGS.test_tmpdir, 'test4',
                                str(checkpoint_type), 'checkpoints')
        tf.io.gfile.makedirs(root_dir)
        now = datetime.datetime.now()
        epoch = datetime.datetime.fromtimestamp(0)

        def _pin_clock(fake_datetime, when):
            # While `datetime.datetime` is patched, utcnow() reports `when`
            # and fromtimestamp() reports the (local) epoch datetime.
            fake_datetime.utcnow.return_value = when
            fake_datetime.fromtimestamp.return_value = epoch

        with mock.patch('datetime.datetime', autospec=True) as fake_dt:
            _pin_clock(fake_dt, now)
            manager = checkpoint_managers.CheckpointManager(
                config_name=config_name,
                root_dir=root_dir,
                checkpoint_type=checkpoint_type,
                save_interval_steps=1000,
                max_to_keep=2,
                keep_interval_timedelta=datetime.timedelta(hours=2))

        steps = list(range(0, 10_000, 1_000))
        step_datetimes = []
        for step in steps:
            with mock.patch('datetime.datetime', autospec=True) as fake_dt:
                _pin_clock(fake_dt, now)
                if manager.should_save(step):
                    _create_dummy_checkpoint(root_dir, step, checkpoint_type)
                    manager.save_metadata(step)
                step_datetimes.append(now)
                # Every step happens one (fake) hour after the previous one.
                now += datetime.timedelta(hours=1)

        saved_steps = [0, 2000, 4000, 6000, 8000, 9000]
        # Pick the fake timestamp that was current at each surviving step.
        saved_checkpoint_datetimes = [
            step_datetimes[steps.index(saved)] for saved in saved_steps
        ]
        checkpoint_glob = os.path.join(root_dir, f'{CHECKPOINT_PREFIX}*')
        filenames = list(
            map(os.path.basename, tf.io.gfile.glob(checkpoint_glob)))
        self.assertSameElements(
            filenames, _base_checkpoint_filenames(saved_steps,
                                                  checkpoint_type))

        checkpoints_filename = os.path.join(
            root_dir, checkpoint_managers.CHECKPOINT_BASENAME)
        expected_proto = _create_reference_checkpoint_history(
            config_name, root_dir, checkpoint_type, saved_steps,
            saved_checkpoint_datetimes)
        self.assertCheckpointsFileProto(checkpoints_filename, expected_proto)
Exemplo n.º 2
0
def _create_checkpoint_manager(
        model_name: str, task_p: InstantiableParams, job_log_dir: str,
        checkpoint_type: CheckpointType, todelete_subdir: Optional[str]
) -> checkpoint_managers.CheckpointManager:
    """Builds the checkpoint manager for a training job.

    The retention settings (max checkpoints to keep, save interval in steps,
    and the keep-interval duration) are read from `task_p.train`. The
    checkpoint directory is derived from `job_log_dir` via `_checkpoint_dir`.

    Args:
      model_name: Passed to the manager as `config_name`.
      task_p: Task params; only its `train` sub-params are read here.
      job_log_dir: Job log directory the checkpoint directory is derived from.
      checkpoint_type: Checkpoint format forwarded to the manager.
      todelete_subdir: Forwarded unchanged to the manager (may be None).

    Returns:
      A configured `checkpoint_managers.CheckpointManager`.
    """
    root = _checkpoint_dir(job_log_dir)
    train = task_p.train
    return checkpoint_managers.CheckpointManager(
        config_name=model_name,
        root_dir=root,
        checkpoint_type=checkpoint_type,
        max_to_keep=train.save_max_to_keep,
        save_interval_steps=train.save_interval_steps,
        # The configured duration is converted by `_parse_duration` before
        # being handed to the manager.
        keep_interval_timedelta=_parse_duration(
            train.save_keep_interval_duration),
        todelete_subdir=todelete_subdir)