Ejemplo n.º 1
0
def test_single_model():
    """Train a tiny model, resume the training and check all artifacts.

    The test works in two phases that share one temporary storage directory:

    1. A trainer is built from a config with a ``(2, 'epoch')`` stop trigger
       and trained with ``resume=False``. Afterwards the order of the trigger
       calls, the tags in the tfevents file, the checkpoint directory
       (including the ``ckpt_latest.pth`` / ``ckpt_best_loss.pth`` symlinks)
       and the fraction of unchanged model weights are checked.
    2. A second trainer is built from the same config with the stop trigger
       raised to ``(4, 'epoch')`` and trained with ``resume=True``. The
       trigger calls, the new tfevents file and the grown set of checkpoints
       are checked.

    Raises:
        AssertionError: If an observed value differs from its reference.
        ValueError: If the storage directory contains an unexpected entry.
        Exception: If the storage directory is not empty before training.
    """
    if sys.platform.startswith('win'):
        pytest.skip(
            'this doctest does not work on Windows, '
            'training is not possible on Windows due to symlinks being unavailable'
        )

    def wrap_triggers(trainer):
        """Wrap each trigger in each hook of ``trainer`` with TriggerMock.

        Returns:
            The list that is handed to every TriggerMock as second argument.
        """
        log_list = []
        for hook in trainer.hooks:
            # Iterate over a snapshot, the attributes are reassigned below.
            for k, v in list(hook.__dict__.items()):
                if isinstance(v, pt.train.trigger.Trigger):
                    hook.__dict__[k] = TriggerMock(v, log_list)
        return log_list

    def fail_with_diff(expected_lines, actual_lines):
        """Raise an AssertionError whose message is an ndiff of the lines."""
        import difflib
        raise AssertionError('\n' + ('\n'.join(
            difflib.ndiff(expected_lines, actual_lines))))

    def check_hook_calls(log_list, hook_calls_ref):
        """Print the logged trigger calls and compare them to the reference."""
        hook_calls = '\n'.join(log_list)

        print('#' * 80)
        print(hook_calls)
        print('#' * 80)

        if hook_calls != hook_calls_ref:
            fail_with_diff(
                hook_calls_ref.splitlines(),
                hook_calls.splitlines(),
            )

    def format_counts(counts):
        """Return one ``'tag': count`` line per entry, sorted by tag."""
        return [f'{k!r}: {v!r}' for k, v in sorted(counts.items())]

    tr_dataset, dt_dataset = get_dataset()
    tr_dataset = tr_dataset[:2]
    dt_dataset = dt_dataset[:2]

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)

        config = pt.Trainer.get_config(updates=pb.utils.nested.deflatten(
            {
                'model.factory': Model,
                'storage_dir': str(tmp_dir),
                'stop_trigger': (2, 'epoch'),
                'summary_trigger': (3, 'iteration'),
                'checkpoint_trigger': (2, 'iteration')
            }))

        t = pt.Trainer.from_config(config)
        pre_state_dict = copy.deepcopy(t.state_dict())

        files_before = tuple(tmp_dir.glob('*'))
        if files_before:
            # no event file
            raise Exception(files_before)

        t.register_validation_hook(validation_iterator=dt_dataset,
                                   max_checkpoints=None)

        log_list = wrap_triggers(t)
        t.train(train_dataset=tr_dataset, resume=False)

        # One line per trigger call in call order. The line format is
        # produced by TriggerMock (defined elsewhere).
        hook_calls_ref = textwrap.dedent('''
        I:0, E: 0, True, SummaryHook.pre_step
        I:0, E: 0, True, BackOffValidationHook.pre_step
        I:0, E: 0, True, CheckpointHook.pre_step
        I:0, E: 0, False, StopTrainingHook.pre_step
        I:1, E: 0, False, SummaryHook.pre_step
        I:1, E: 0, False, BackOffValidationHook.pre_step
        I:1, E: 0, False, CheckpointHook.pre_step
        I:1, E: 0, False, StopTrainingHook.pre_step
        I:2, E: 1, False, SummaryHook.pre_step
        I:2, E: 1, True, BackOffValidationHook.pre_step
        I:2, E: 1, True, CheckpointHook.pre_step
        I:2, E: 1, False, StopTrainingHook.pre_step
        I:3, E: 1, True, SummaryHook.pre_step
        I:3, E: 1, False, BackOffValidationHook.pre_step
        I:3, E: 1, False, CheckpointHook.pre_step
        I:3, E: 1, False, StopTrainingHook.pre_step
        I:4, E: 2, False, SummaryHook.pre_step
        I:4, E: 2, True, BackOffValidationHook.pre_step
        I:4, E: 2, True, CheckpointHook.pre_step
        I:4, E: 2, True, StopTrainingHook.pre_step
        ''').strip()
        check_hook_calls(log_list, hook_calls_ref)

        # Remember the event file(s) of the first run, they are skipped when
        # the storage dir is inspected again after the resumed run.
        old_event_files = []

        files_after = tuple(tmp_dir.glob('*'))
        assert len(files_after) == 2, files_after
        for file in sorted(files_after):
            if 'tfevents' in file.name:
                old_event_files.append(file)
                events = list(load_events_as_dict(file))

                tags = []
                time_per_iteration = []

                relative_timings = collections.defaultdict(list)
                relative_timing_keys = {
                    'training_timings/time_rel_data_loading',
                    'training_timings/time_rel_to_device',
                    'training_timings/time_rel_forward',
                    'training_timings/time_rel_review',
                    'training_timings/time_rel_backward',
                    'training_timings/time_rel_optimize',
                }
                for event in events:
                    if 'summary' in event.keys():
                        # Each summary event has to carry exactly one value.
                        value, = event['summary']['value']
                        tags.append(value['tag'])
                        if value['tag'] in relative_timing_keys:
                            relative_timings[value['tag']].append(
                                value['simple_value'])
                        elif value[
                                'tag'] == 'training_timings/time_per_iteration':
                            time_per_iteration.append(value['simple_value'])

                c = dict(collections.Counter(tags))
                # Training summary is written two times (at iteration 3 when
                #   summary_trigger triggers and when training stops and
                #   summary_hook is closed).
                # Validation summary is written when checkpoint_trigger
                #   triggers, hence 3 times.
                # non_validation_time can only be measured between
                #   validations:
                #   => # of non_val_time == # of val_time - 1
                expect = {
                    'training/grad_norm': 2,
                    'training/grad_norm_': 2,
                    'training/loss': 2,
                    'training/lr/param_group_0': 2,
                    'training_timings/time_per_iteration': 2,
                    'training_timings/time_rel_to_device': 2,
                    'training_timings/time_rel_forward': 2,
                    'training_timings/time_rel_review': 2,
                    'training_timings/time_rel_backward': 2,
                    'training_timings/time_rel_optimize': 2,
                    'training_timings/time_rel_data_loading': 2,
                    'validation/loss': 3,
                    'validation_timings/time_per_iteration': 3,
                    'validation_timings/time_rel_to_device': 3,
                    'validation_timings/time_rel_forward': 3,
                    'validation_timings/time_rel_review': 3,
                    'validation_timings/time_rel_data_loading': 3,
                    'validation_timings/non_validation_time': 2,
                    'validation_timings/validation_time': 3,
                }
                pprint(c)
                if c != expect:
                    fail_with_diff(format_counts(expect), format_counts(c))
                # The counts in `expect` sum up to 45, i.e. exactly one
                # event in the file has no 'summary' entry.
                assert len(events) == 46, (len(events), events)

                assert relative_timing_keys == set(
                    relative_timings.keys()), (relative_timing_keys,
                                               relative_timings)

                for k, v in relative_timings.items():
                    assert len(v) > 0, (k, v, relative_timings)

                # The relative timings should sum up to one,
                # but this model is really cheap.
                # e.g. 0.00108 and 0.000604 per iteration.
                # This may cause the mismatch.
                # Allow a calculation error of 25%.
                # ToDo: Get this work with less than 1% error.
                relative_times = np.array(list(
                    relative_timings.values())).sum(axis=0)
                within_bounds = (
                    np.all(relative_times > 0.75)
                    and np.all(relative_times <= 1)
                )
                if not within_bounds:
                    raise AssertionError(
                        pretty((relative_times, time_per_iteration,
                                dict(relative_timings))))

            elif file.name == 'checkpoints':
                checkpoints_files = tuple(file.glob('*'))
                assert len(checkpoints_files) == 5, checkpoints_files
                checkpoints_files_name = [f.name for f in checkpoints_files]
                expect = {
                    'ckpt_0.pth', 'ckpt_2.pth', 'ckpt_4.pth',
                    'ckpt_best_loss.pth', 'ckpt_latest.pth'
                }
                assert expect == set(checkpoints_files_name), (
                    expect, checkpoints_files_name)
                ckpt_ranking = torch.load(
                    str(file / 'ckpt_latest.pth'
                        ))['hooks']['BackOffValidationHook']['ckpt_ranking']
                assert ckpt_ranking[0][1] > 0, ckpt_ranking
                # Only the checkpoint names and their order are compared,
                # the second entry of each pair is replaced by a dummy.
                ckpt_ranking = [(ckpt[0], -1) for ckpt in ckpt_ranking]
                expect = [(f'ckpt_{i}.pth', -1) for i in [0, 2, 4]]
                assert ckpt_ranking == expect, (ckpt_ranking, expect)

                for symlink in [
                        file / 'ckpt_latest.pth',
                        file / 'ckpt_best_loss.pth',
                ]:
                    assert symlink.is_symlink(), symlink

                    # A target without directory part can only be resolved
                    # relative to the directory of the symlink.
                    target = os.readlink(str(symlink))
                    if '/' in target:
                        raise AssertionError(
                            f'The symlink {symlink} contains a "/".\n'
                            f'Expected that the symlink has a relative target,\n'
                            f'but the target is: {target}')
            else:
                raise ValueError(file)

        post_state_dict = copy.deepcopy(t.state_dict())
        assert pre_state_dict.keys() == post_state_dict.keys()

        # Per parameter: fraction of entries that are exactly equal before
        # and after the training.
        equal_amount = {
            key: (pt.utils.to_numpy(parameter_pre) == pt.utils.to_numpy(
                post_state_dict['model'][key])).mean()
            for key, parameter_pre in pre_state_dict['model'].items()
        }

        # ToDo: why are so many weights unchanged? Maybe the zeros in the image?
        assert equal_amount == {'l.bias': 0.0, 'l.weight': 0.6900510204081632}

        import time
        # tfevents use unixtime as unique indicator. Sleep 2 seconds to ensure
        # new value
        time.sleep(2)

        # Second phase: same storage dir, later stop trigger, resume=True.
        config['stop_trigger'] = (4, 'epoch')
        t = pt.Trainer.from_config(config)
        t.register_validation_hook(validation_iterator=dt_dataset,
                                   max_checkpoints=None)
        log_list = wrap_triggers(t)
        t.train(train_dataset=tr_dataset, resume=True)

        hook_calls_ref = textwrap.dedent('''
        I:4, E: 2, False, SummaryHook.pre_step
        I:4, E: 2, False, BackOffValidationHook.pre_step
        I:4, E: 2, False, CheckpointHook.pre_step
        I:4, E: 2, False, StopTrainingHook.pre_step
        I:5, E: 2, False, SummaryHook.pre_step
        I:5, E: 2, False, BackOffValidationHook.pre_step
        I:5, E: 2, False, CheckpointHook.pre_step
        I:5, E: 2, False, StopTrainingHook.pre_step
        I:6, E: 3, True, SummaryHook.pre_step
        I:6, E: 3, True, BackOffValidationHook.pre_step
        I:6, E: 3, True, CheckpointHook.pre_step
        I:6, E: 3, False, StopTrainingHook.pre_step
        I:7, E: 3, False, SummaryHook.pre_step
        I:7, E: 3, False, BackOffValidationHook.pre_step
        I:7, E: 3, False, CheckpointHook.pre_step
        I:7, E: 3, False, StopTrainingHook.pre_step
        I:8, E: 4, False, SummaryHook.pre_step
        I:8, E: 4, True, BackOffValidationHook.pre_step
        I:8, E: 4, True, CheckpointHook.pre_step
        I:8, E: 4, True, StopTrainingHook.pre_step
        ''').strip()
        check_hook_calls(log_list, hook_calls_ref)

        files_after = tuple(tmp_dir.glob('*'))
        assert len(files_after) == 3, files_after
        for file in sorted(files_after):
            if 'tfevents' in file.name:
                if file in old_event_files:
                    # Already verified after the first run.
                    continue

                events = list(load_events_as_dict(file))

                tags = []
                for event in events:
                    if 'summary' in event.keys():
                        value, = event['summary']['value']
                        tags.append(value['tag'])

                c = dict(collections.Counter(tags))
                # The counts in `expect` sum up to 37, i.e. exactly one
                # event in the file has no 'summary' entry.
                assert len(events) == 38, (len(events), events)
                # non_validation_time can only be measured between
                # validations:
                #  => # of non_val_time == # of val_time - 1
                expect = {
                    'training/grad_norm': 2,
                    'training/grad_norm_': 2,
                    'training/loss': 2,
                    'training/lr/param_group_0': 2,
                    'training_timings/time_per_iteration': 2,
                    'training_timings/time_rel_to_device': 2,
                    'training_timings/time_rel_forward': 2,
                    'training_timings/time_rel_review': 2,
                    'training_timings/time_rel_backward': 2,
                    'training_timings/time_rel_optimize': 2,
                    'training_timings/time_rel_data_loading': 2,
                    'validation/loss': 2,
                    'validation_timings/time_per_iteration': 2,
                    'validation_timings/time_rel_to_device': 2,
                    'validation_timings/time_rel_forward': 2,
                    'validation_timings/time_rel_review': 2,
                    'validation_timings/time_rel_data_loading': 2,
                    'validation_timings/non_validation_time': 1,
                    'validation_timings/validation_time': 2,
                }
                if c != expect:
                    fail_with_diff(format_counts(expect), format_counts(c))
            elif file.name == 'checkpoints':
                checkpoints_files = tuple(file.glob('*'))
                assert len(checkpoints_files) == 7, checkpoints_files
                checkpoints_files_name = [f.name for f in checkpoints_files]
                expect = {
                    *[f'ckpt_{i}.pth' for i in [0, 2, 4, 6, 8]],
                    'ckpt_best_loss.pth', 'ckpt_latest.pth'
                }
                assert expect == set(checkpoints_files_name), (
                    expect, checkpoints_files_name)
            else:
                raise ValueError(file)
Ejemplo n.º 2
0
def test_single_model():
    """Train a tiny model, resume the training and check all artifacts.

    The test works in two phases that share one temporary storage directory:

    1. A trainer is built from a config with a ``(2, 'epoch')`` stop trigger
       and trained with ``resume=False``. Afterwards the order of the trigger
       calls, the tags in the tfevents file, the checkpoint directory
       (``validation_state.json`` and the ``ckpt_latest.pth`` /
       ``ckpt_best_loss.pth`` symlinks) and the fraction of unchanged model
       weights are checked.
    2. A second trainer is built from the same config with the stop trigger
       raised to ``(4, 'epoch')`` and trained with ``resume=True``. The
       trigger calls, the new tfevents file and the grown set of checkpoints
       are checked.

    Raises:
        AssertionError: If an observed value differs from its reference.
        ValueError: If the storage directory contains an unexpected entry.
        Exception: If the storage directory is not empty before training.
    """

    def wrap_triggers(trainer):
        """Wrap each trigger in each hook of ``trainer`` with TriggerMock.

        Returns:
            The list that is handed to every TriggerMock as second argument.
        """
        log_list = []
        for hook in trainer.hooks:
            # Iterate over a snapshot, the attributes are reassigned below.
            for k, v in list(hook.__dict__.items()):
                if isinstance(v, pt.train.trigger.Trigger):
                    hook.__dict__[k] = TriggerMock(v, log_list)
        return log_list

    def check_hook_calls(log_list, hook_calls_ref):
        """Print the logged trigger calls and compare them to the reference.

        Raises:
            AssertionError: With an ndiff of reference and observed lines.
        """
        hook_calls = '\n'.join(log_list)

        print('#' * 80)
        print(hook_calls)
        print('#' * 80)

        if hook_calls != hook_calls_ref:
            import difflib
            raise AssertionError('\n' + ('\n'.join(
                difflib.ndiff(
                    hook_calls_ref.splitlines(),
                    hook_calls.splitlines(),
                ))))

    tr_dataset, dt_dataset = get_dataset()
    tr_dataset = tr_dataset[:2]
    dt_dataset = dt_dataset[:2]

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)

        config = pt.Trainer.get_config(updates=pb.utils.nested.deflatten(
            {
                'model.factory': Model,
                'storage_dir': str(tmp_dir),
                'stop_trigger': (2, 'epoch'),
                'summary_trigger': (3, 'iteration'),
                'checkpoint_trigger': (2, 'iteration')
            }))

        t = pt.Trainer.from_config(config)
        pre_state_dict = copy.deepcopy(t.state_dict())

        files_before = tuple(tmp_dir.glob('*'))
        if files_before:
            # no event file
            raise Exception(files_before)

        t.register_validation_hook(validation_iterator=dt_dataset,
                                   max_checkpoints=None)

        log_list = wrap_triggers(t)
        t.train(train_iterator=tr_dataset, resume=False)

        # One line per trigger call in call order. The line format is
        # produced by TriggerMock (defined elsewhere).
        hook_calls_ref = textwrap.dedent('''
        I:0, E: 0, True, SummaryHook.pre_step
        I:0, E: 0, True, CheckpointHook.pre_step
        I:0, E: 0, True, BackOffValidationHook.pre_step
        I:0, E: 0, False, StopTrainingHook.pre_step
        I:1, E: 0, False, SummaryHook.pre_step
        I:1, E: 0, False, CheckpointHook.pre_step
        I:1, E: 0, False, BackOffValidationHook.pre_step
        I:1, E: 0, False, StopTrainingHook.pre_step
        I:2, E: 1, False, SummaryHook.pre_step
        I:2, E: 1, True, CheckpointHook.pre_step
        I:2, E: 1, True, BackOffValidationHook.pre_step
        I:2, E: 1, False, StopTrainingHook.pre_step
        I:3, E: 1, True, SummaryHook.pre_step
        I:3, E: 1, False, CheckpointHook.pre_step
        I:3, E: 1, False, BackOffValidationHook.pre_step
        I:3, E: 1, False, StopTrainingHook.pre_step
        I:4, E: 2, False, SummaryHook.pre_step
        I:4, E: 2, True, CheckpointHook.pre_step
        I:4, E: 2, True, BackOffValidationHook.pre_step
        I:4, E: 2, True, StopTrainingHook.pre_step
        ''').strip()
        check_hook_calls(log_list, hook_calls_ref)

        # Remember the event file(s) of the first run, they are skipped when
        # the storage dir is inspected again after the resumed run.
        old_event_files = []

        files_after = tuple(tmp_dir.glob('*'))
        assert len(files_after) == 2, files_after
        for file in sorted(files_after):
            if 'tfevents' in file.name:
                old_event_files.append(file)
                events = list(load_events_as_dict(file))

                tags = []
                time_rel_data_loading = []
                time_rel_train_step = []
                for event in events:
                    if 'summary' in event.keys():
                        # Each summary event has to carry exactly one value.
                        value, = event['summary']['value']
                        tags.append(value['tag'])
                        if value[
                                'tag'] == 'training_timings/time_rel_data_loading':
                            time_rel_data_loading.append(value['simple_value'])
                        elif value['tag'] == 'training_timings/time_rel_step':
                            time_rel_train_step.append(value['simple_value'])

                c = dict(collections.Counter(tags))
                # Training summary is written two times (at iteration 3 when
                #   summary_trigger triggers and when training stops and
                #   summary_hook is closed).
                # Validation summary is written when checkpoint_trigger
                #   triggers, hence 3 times.
                # non_validation_time can only be measured between
                #   validations:
                #   => # of non_val_time == # of val_time - 1
                expect = {
                    'training/grad_norm': 2,
                    'training/grad_norm_': 2,
                    'training/loss': 2,
                    'training/lr/param_group_0': 2,
                    'training_timings/time_per_iteration': 2,
                    'training_timings/time_rel_to_device': 2,
                    'training_timings/time_rel_forward': 2,
                    'training_timings/time_rel_review': 2,
                    'training_timings/time_rel_backward': 2,
                    'training_timings/time_rel_data_loading': 2,
                    'training_timings/time_rel_step': 2,
                    'validation/loss': 3,
                    'validation/lr/param_group_0': 3,
                    'validation_timings/time_per_iteration': 3,
                    'validation_timings/time_rel_to_device': 3,
                    'validation_timings/time_rel_forward': 3,
                    'validation_timings/time_rel_review': 3,
                    'validation_timings/time_rel_backward': 3,
                    'validation_timings/time_rel_data_loading': 3,
                    'validation_timings/time_rel_step': 3,
                    'validation_timings/non_validation_time': 2,
                    'validation_timings/validation_time': 3,
                }
                pprint(c)
                assert c == expect, c
                # The counts in `expect` sum up to 54, i.e. exactly one
                # event in the file has no 'summary' entry.
                assert len(events) == 55, (len(events), events)

                assert len(time_rel_data_loading) > 0, (time_rel_data_loading,
                                                        time_rel_train_step)
                assert len(time_rel_train_step) > 0, (time_rel_data_loading,
                                                      time_rel_train_step)
                # Data loading and train step together have to account for
                # the complete iteration time.
                np.testing.assert_allclose(
                    np.add(time_rel_data_loading, time_rel_train_step),
                    1,
                    err_msg=f'({time_rel_data_loading}, {time_rel_train_step})')

            elif file.name == 'checkpoints':
                checkpoints_files = tuple(file.glob('*'))
                assert len(checkpoints_files) == 6, checkpoints_files
                checkpoints_files_name = [f.name for f in checkpoints_files]
                expect = {
                    'ckpt_0.pth', 'ckpt_2.pth', 'ckpt_4.pth',
                    'validation_state.json', 'ckpt_best_loss.pth',
                    'ckpt_latest.pth'
                }
                assert expect == set(checkpoints_files_name), (
                    expect, checkpoints_files_name)
                ckpt_ranking = pb.io.load_json(
                    file / 'validation_state.json')['ckpt_ranking']
                assert ckpt_ranking[0][1] > 0, ckpt_ranking
                # Only the checkpoint names and their order are compared,
                # the second entry of each pair is replaced by a dummy.
                for ckpt in ckpt_ranking:
                    ckpt[1] = -1
                expect = [[f'ckpt_{i}.pth', -1] for i in [0, 2, 4]]
                assert ckpt_ranking == expect, (ckpt_ranking, expect)

                # NOTE(review): these checks need a platform with symlink
                # support.
                for symlink in [
                        file / 'ckpt_latest.pth',
                        file / 'ckpt_best_loss.pth',
                ]:
                    assert symlink.is_symlink(), symlink

                    # A target without directory part can only be resolved
                    # relative to the directory of the symlink.
                    target = os.readlink(str(symlink))
                    if '/' in target:
                        raise AssertionError(
                            f'The symlink {symlink} contains a "/".\n'
                            f'Expected that the symlink has a relative target,\n'
                            f'but the target is: {target}')
            else:
                raise ValueError(file)

        post_state_dict = copy.deepcopy(t.state_dict())
        assert pre_state_dict.keys() == post_state_dict.keys()

        # Per parameter: fraction of entries that are exactly equal before
        # and after the training.
        equal_amount = {
            key: (pt.utils.to_numpy(parameter_pre) == pt.utils.to_numpy(
                post_state_dict['model'][key])).mean()
            for key, parameter_pre in pre_state_dict['model'].items()
        }

        # ToDo: why are so many weights unchanged? Maybe the zeros in the image?
        assert equal_amount == {'l.bias': 0.0, 'l.weight': 0.6900510204081632}

        import time
        # tfevents use unixtime as unique indicator. Sleep 2 seconds to ensure
        # new value
        time.sleep(2)

        # Second phase: same storage dir, later stop trigger, resume=True.
        config['stop_trigger'] = (4, 'epoch')
        t = pt.Trainer.from_config(config)
        t.register_validation_hook(validation_iterator=dt_dataset,
                                   max_checkpoints=None)
        log_list = wrap_triggers(t)
        t.train(train_iterator=tr_dataset, resume=True)

        hook_calls_ref = textwrap.dedent('''
        I:4, E: 2, False, SummaryHook.pre_step
        I:4, E: 2, False, CheckpointHook.pre_step
        I:4, E: 2, False, BackOffValidationHook.pre_step
        I:4, E: 2, False, StopTrainingHook.pre_step
        I:5, E: 2, False, SummaryHook.pre_step
        I:5, E: 2, False, CheckpointHook.pre_step
        I:5, E: 2, False, BackOffValidationHook.pre_step
        I:5, E: 2, False, StopTrainingHook.pre_step
        I:6, E: 3, True, SummaryHook.pre_step
        I:6, E: 3, True, CheckpointHook.pre_step
        I:6, E: 3, True, BackOffValidationHook.pre_step
        I:6, E: 3, False, StopTrainingHook.pre_step
        I:7, E: 3, False, SummaryHook.pre_step
        I:7, E: 3, False, CheckpointHook.pre_step
        I:7, E: 3, False, BackOffValidationHook.pre_step
        I:7, E: 3, False, StopTrainingHook.pre_step
        I:8, E: 4, False, SummaryHook.pre_step
        I:8, E: 4, True, CheckpointHook.pre_step
        I:8, E: 4, True, BackOffValidationHook.pre_step
        I:8, E: 4, True, StopTrainingHook.pre_step
        ''').strip()
        check_hook_calls(log_list, hook_calls_ref)

        files_after = tuple(tmp_dir.glob('*'))
        assert len(files_after) == 3, files_after
        for file in sorted(files_after):
            if 'tfevents' in file.name:
                if file in old_event_files:
                    # Already verified after the first run.
                    continue

                events = list(load_events_as_dict(file))

                tags = []
                for event in events:
                    if 'summary' in event.keys():
                        value, = event['summary']['value']
                        tags.append(value['tag'])

                c = dict(collections.Counter(tags))
                # The counts in `expect` sum up to 43, i.e. exactly one
                # event in the file has no 'summary' entry.
                assert len(events) == 44, (len(events), events)
                # non_validation_time can only be measured between
                # validations:
                #  => # of non_val_time == # of val_time - 1
                expect = {
                    'training/grad_norm': 2,
                    'training/grad_norm_': 2,
                    'training/loss': 2,
                    'training/lr/param_group_0': 2,
                    'training_timings/time_per_iteration': 2,
                    'training_timings/time_rel_to_device': 2,
                    'training_timings/time_rel_forward': 2,
                    'training_timings/time_rel_review': 2,
                    'training_timings/time_rel_backward': 2,
                    'training_timings/time_rel_data_loading': 2,
                    'training_timings/time_rel_step': 2,
                    'validation/loss': 2,
                    'validation/lr/param_group_0': 2,
                    'validation_timings/time_per_iteration': 2,
                    'validation_timings/time_rel_to_device': 2,
                    'validation_timings/time_rel_forward': 2,
                    'validation_timings/time_rel_review': 2,
                    'validation_timings/time_rel_backward': 2,
                    'validation_timings/time_rel_data_loading': 2,
                    'validation_timings/time_rel_step': 2,
                    'validation_timings/non_validation_time': 1,
                    'validation_timings/validation_time': 2,
                }
                assert c == expect, c
            elif file.name == 'checkpoints':
                checkpoints_files = tuple(file.glob('*'))
                assert len(checkpoints_files) == 8, checkpoints_files
                checkpoints_files_name = [f.name for f in checkpoints_files]
                expect = {
                    *[f'ckpt_{i}.pth' for i in [0, 2, 4, 6, 8]],
                    'validation_state.json', 'ckpt_best_loss.pth',
                    'ckpt_latest.pth'
                }
                assert expect == set(checkpoints_files_name), (
                    expect, checkpoints_files_name)
            else:
                raise ValueError(file)