def test_single_model():
    if sys.platform.startswith('win'):
        pytest.skip(
            'this test does not work on Windows, '
            'training is not possible on Windows due to symlinks being '
            'unavailable'
        )

    tr_dataset, dt_dataset = get_dataset()
    tr_dataset = tr_dataset[:2]
    dt_dataset = dt_dataset[:2]

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)
        config = pt.Trainer.get_config(updates=pb.utils.nested.deflatten({
            'model.factory': Model,
            'storage_dir': str(tmp_dir),
            'stop_trigger': (2, 'epoch'),
            'summary_trigger': (3, 'iteration'),
            'checkpoint_trigger': (2, 'iteration'),
        }))
        t = pt.Trainer.from_config(config)
        pre_state_dict = copy.deepcopy(t.state_dict())

        files_before = tuple(tmp_dir.glob('*'))
        if len(files_before) != 0:
            # There must be no event file before training starts.
            raise Exception(files_before)

        t.register_validation_hook(
            validation_iterator=dt_dataset, max_checkpoints=None)

        # Wrap each trigger in each hook with TriggerMock.
        log_list = []
        for hook in t.hooks:
            for k, v in list(hook.__dict__.items()):
                if isinstance(v, pt.train.trigger.Trigger):
                    hook.__dict__[k] = TriggerMock(v, log_list)

        t.train(train_dataset=tr_dataset, resume=False)

        hook_calls = '\n'.join(log_list)

        # The checkpoint trigger is evaluated two times per step
        # (once for validation, once for checkpointing).
        # With 2 training examples per epoch, stop_trigger (2, 'epoch')
        # stops at iteration 4, summary_trigger (3, 'iteration') fires at
        # iterations 0 and 3, and checkpoint_trigger (2, 'iteration') fires
        # at iterations 0, 2 and 4.
        hook_calls_ref = textwrap.dedent('''
        I:0, E: 0, True, SummaryHook.pre_step
        I:0, E: 0, True, BackOffValidationHook.pre_step
        I:0, E: 0, True, CheckpointHook.pre_step
        I:0, E: 0, False, StopTrainingHook.pre_step
        I:1, E: 0, False, SummaryHook.pre_step
        I:1, E: 0, False, BackOffValidationHook.pre_step
        I:1, E: 0, False, CheckpointHook.pre_step
        I:1, E: 0, False, StopTrainingHook.pre_step
        I:2, E: 1, False, SummaryHook.pre_step
        I:2, E: 1, True, BackOffValidationHook.pre_step
        I:2, E: 1, True, CheckpointHook.pre_step
        I:2, E: 1, False, StopTrainingHook.pre_step
        I:3, E: 1, True, SummaryHook.pre_step
        I:3, E: 1, False, BackOffValidationHook.pre_step
        I:3, E: 1, False, CheckpointHook.pre_step
        I:3, E: 1, False, StopTrainingHook.pre_step
        I:4, E: 2, False, SummaryHook.pre_step
        I:4, E: 2, True, BackOffValidationHook.pre_step
        I:4, E: 2, True, CheckpointHook.pre_step
        I:4, E: 2, True, StopTrainingHook.pre_step
        ''').strip()

        print('#' * 80)
        print(hook_calls)
        print('#' * 80)

        if hook_calls != hook_calls_ref:
            import difflib
            raise AssertionError('\n' + '\n'.join(difflib.ndiff(
                hook_calls_ref.splitlines(),
                hook_calls.splitlines(),
            )))

        old_event_files = []

        files_after = tuple(tmp_dir.glob('*'))
        assert len(files_after) == 2, files_after
        for file in sorted(files_after):
            if 'tfevents' in file.name:
                old_event_files.append(file)

                events = list(load_events_as_dict(file))

                tags = []
                # time_rel_data_loading = []
                # time_rel_train_step = []
                time_per_iteration = []
                relative_timings = collections.defaultdict(list)
                relative_timing_keys = {
                    'training_timings/time_rel_data_loading',
                    'training_timings/time_rel_to_device',
                    'training_timings/time_rel_forward',
                    'training_timings/time_rel_review',
                    'training_timings/time_rel_backward',
                    'training_timings/time_rel_optimize',
                }
                for event in events:
                    if 'summary' in event.keys():
                        value, = event['summary']['value']
                        tags.append(value['tag'])
                        if value['tag'] in relative_timing_keys:
                            relative_timings[value['tag']].append(
                                value['simple_value'])
                        elif value['tag'] == (
                                'training_timings/time_per_iteration'):
                            time_per_iteration.append(value['simple_value'])
                c = dict(collections.Counter(tags))

                # Training summary is written two times (at iteration 3 when
                # summary_trigger triggers and when training stops and
                # summary_hook is closed).
                # Validation summary is written when checkpoint_trigger
                # triggers, hence 3 times.
                # non_validation_time can only be measured between
                # validations => 2 values (one fewer than validation_time).
                expect = {
                    'training/grad_norm': 2,
                    'training/grad_norm_': 2,
                    'training/loss': 2,
                    'training/lr/param_group_0': 2,
                    'training_timings/time_per_iteration': 2,
                    'training_timings/time_rel_to_device': 2,
                    'training_timings/time_rel_forward': 2,
                    'training_timings/time_rel_review': 2,
                    'training_timings/time_rel_backward': 2,
                    'training_timings/time_rel_optimize': 2,
                    'training_timings/time_rel_data_loading': 2,
                    # 'training_timings/time_rel_step': 2,
                    'validation/loss': 3,
                    'validation_timings/time_per_iteration': 3,
                    'validation_timings/time_rel_to_device': 3,
                    'validation_timings/time_rel_forward': 3,
                    'validation_timings/time_rel_review': 3,
                    'validation_timings/time_rel_data_loading': 3,
                    # 'validation_timings/time_rel_step': 3,
                    'validation_timings/non_validation_time': 2,
                    'validation_timings/validation_time': 3,
                }
                pprint(c)
                if c != expect:
                    import difflib
                    raise AssertionError('\n' + '\n'.join(difflib.ndiff(
                        [f'{k!r}: {v!r}' for k, v in sorted(expect.items())],
                        [f'{k!r}: {v!r}' for k, v in sorted(c.items())],
                    )))

                assert len(events) == 46, (len(events), events)

                assert relative_timing_keys == set(relative_timings.keys()), (
                    relative_timing_keys, relative_timings)
                for k, v in relative_timings.items():
                    assert len(v) > 0, (k, v, relative_timings)

                # The relative timings should sum up to one, but this model
                # is really cheap (e.g. 0.00108 s and 0.000604 s per
                # iteration), which may cause a mismatch.
                # Allow a calculation error of 25%.
                # ToDo: Get this to work with less than 1% error.
                relative_times = np.array(
                    list(relative_timings.values())).sum(axis=0)
                if not np.all(relative_times > 0.75):
                    raise AssertionError(pretty((
                        relative_times, time_per_iteration,
                        dict(relative_timings))))
                if not np.all(relative_times <= 1):
                    raise AssertionError(pretty((
                        relative_times, time_per_iteration,
                        dict(relative_timings))))
            elif file.name == 'checkpoints':
                checkpoints_files = tuple(file.glob('*'))
                assert len(checkpoints_files) == 5, checkpoints_files
                checkpoints_files_name = [f.name for f in checkpoints_files]
                expect = {
                    'ckpt_0.pth',
                    'ckpt_2.pth',
                    'ckpt_4.pth',
                    'ckpt_best_loss.pth',
                    'ckpt_latest.pth',
                }
                assert expect == set(checkpoints_files_name), (
                    expect, checkpoints_files_name)

                ckpt_ranking = torch.load(
                    str(file / 'ckpt_latest.pth')
                )['hooks']['BackOffValidationHook']['ckpt_ranking']
                assert ckpt_ranking[0][1] > 0, ckpt_ranking
                # Replace the non-deterministic loss values before comparing
                # the ranking.
                for i, ckpt in enumerate(ckpt_ranking):
                    ckpt_ranking[i] = (ckpt[0], -1)
                expect = [(f'ckpt_{i}.pth', -1) for i in [0, 2, 4]]
                assert ckpt_ranking == expect, (ckpt_ranking, expect)

                for symlink in [
                        file / 'ckpt_latest.pth',
                        file / 'ckpt_best_loss.pth',
                ]:
                    assert symlink.is_symlink(), symlink
                    target = os.readlink(str(symlink))
                    if '/' in target:
                        raise AssertionError(
                            f'The symlink {symlink} contains a "/".\n'
                            f'Expected the symlink to have a relative '
                            f'target,\nbut the target is: {target}'
                        )
            else:
                raise ValueError(file)

        post_state_dict = copy.deepcopy(t.state_dict())
        assert pre_state_dict.keys() == post_state_dict.keys()

        # Fraction of parameter entries that are unchanged by training.
        equal_amount = {
            key: (
                pt.utils.to_numpy(parameter_pre)
                == pt.utils.to_numpy(post_state_dict['model'][key])
            ).mean()
            for key, parameter_pre in pre_state_dict['model'].items()
        }
        # ToDo: Why are so many weights unchanged? Maybe the zeros in the
        #       image?
        assert equal_amount == {'l.bias': 0.0, 'l.weight': 0.6900510204081632}

        import time
        # tfevents files use the unix time as a unique identifier.
        # Sleep 2 seconds to ensure a new value.
        time.sleep(2)

        # Resume training with a later stop trigger, i.e. train for two
        # more epochs (iterations 4 to 8).
        config['stop_trigger'] = (4, 'epoch')
        t = pt.Trainer.from_config(config)
        t.register_validation_hook(
            validation_iterator=dt_dataset, max_checkpoints=None)

        log_list = []
        for hook in t.hooks:
            for k, v in list(hook.__dict__.items()):
                if isinstance(v, pt.train.trigger.Trigger):
                    hook.__dict__[k] = TriggerMock(v, log_list)

        t.train(train_dataset=tr_dataset, resume=True)

        hook_calls = '\n'.join(log_list)

        hook_calls_ref = textwrap.dedent('''
        I:4, E: 2, False, SummaryHook.pre_step
        I:4, E: 2, False, BackOffValidationHook.pre_step
        I:4, E: 2, False, CheckpointHook.pre_step
        I:4, E: 2, False, StopTrainingHook.pre_step
        I:5, E: 2, False, SummaryHook.pre_step
        I:5, E: 2, False, BackOffValidationHook.pre_step
        I:5, E: 2, False, CheckpointHook.pre_step
        I:5, E: 2, False, StopTrainingHook.pre_step
        I:6, E: 3, True, SummaryHook.pre_step
        I:6, E: 3, True, BackOffValidationHook.pre_step
        I:6, E: 3, True, CheckpointHook.pre_step
        I:6, E: 3, False, StopTrainingHook.pre_step
        I:7, E: 3, False, SummaryHook.pre_step
        I:7, E: 3, False, BackOffValidationHook.pre_step
        I:7, E: 3, False, CheckpointHook.pre_step
        I:7, E: 3, False, StopTrainingHook.pre_step
        I:8, E: 4, False, SummaryHook.pre_step
        I:8, E: 4, True, BackOffValidationHook.pre_step
        I:8, E: 4, True, CheckpointHook.pre_step
        I:8, E: 4, True, StopTrainingHook.pre_step
        ''').strip()

        print('#' * 80)
        print(hook_calls)
        print('#' * 80)

        if hook_calls != hook_calls_ref:
            import difflib
            raise AssertionError('\n' + '\n'.join(difflib.ndiff(
                hook_calls_ref.splitlines(),
                hook_calls.splitlines(),
            )))

        files_after = tuple(tmp_dir.glob('*'))
        assert len(files_after) == 3, files_after
        for file in sorted(files_after):
            if 'tfevents' in file.name:
                if file in old_event_files:
                    continue

                events = list(load_events_as_dict(file))

                tags = []
                for event in events:
                    if 'summary' in event.keys():
                        value, = event['summary']['value']
                        tags.append(value['tag'])
                c = dict(collections.Counter(tags))

                assert len(events) == 38, (len(events), events)
                expect = {
                    'training/grad_norm': 2,
                    'training/grad_norm_': 2,
                    'training/loss': 2,
                    'training/lr/param_group_0': 2,
                    'training_timings/time_per_iteration': 2,
                    'training_timings/time_rel_to_device': 2,
                    'training_timings/time_rel_forward': 2,
                    'training_timings/time_rel_review': 2,
                    'training_timings/time_rel_backward': 2,
                    'training_timings/time_rel_optimize': 2,
                    'training_timings/time_rel_data_loading': 2,
                    # 'training_timings/time_rel_step': 2,
                    'validation/loss': 2,
                    # 'validation/lr/param_group_0': 2,
                    'validation_timings/time_per_iteration': 2,
                    'validation_timings/time_rel_to_device': 2,
                    'validation_timings/time_rel_forward': 2,
                    'validation_timings/time_rel_review': 2,
                    'validation_timings/time_rel_data_loading': 2,
                    # 'validation_timings/time_rel_step': 2,
                    # non_validation_time can only be measured between
                    # validations => one value fewer than validation_time.
                    'validation_timings/non_validation_time': 1,
                    'validation_timings/validation_time': 2,
                }
                if c != expect:
                    import difflib
                    raise AssertionError('\n' + '\n'.join(difflib.ndiff(
                        [f'{k!r}: {v!r}' for k, v in sorted(expect.items())],
                        [f'{k!r}: {v!r}' for k, v in sorted(c.items())],
                    )))
            elif file.name == 'checkpoints':
                checkpoints_files = tuple(file.glob('*'))
                assert len(checkpoints_files) == 7, checkpoints_files
                checkpoints_files_name = [f.name for f in checkpoints_files]
                expect = {
                    *[f'ckpt_{i}.pth' for i in [0, 2, 4, 6, 8]],
                    'ckpt_best_loss.pth',
                    'ckpt_latest.pth',
                }
                assert expect == set(checkpoints_files_name), (
                    expect, checkpoints_files_name)
            else:
                raise ValueError(file)
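

# For orientation, a minimal sketch of a trigger wrapper that could produce
# the "I:<iteration>, E: <epoch>, <decision>, <Hook>.<method>" lines asserted
# in the `hook_calls_ref` blocks above.  This is NOT the `TriggerMock` used
# by the test (that one is defined elsewhere in this module); deriving the
# hook and method name from the call stack is an assumption of this sketch.
import inspect


class _TriggerMockSketch:
    def __init__(self, trigger, log_list):
        self.trigger = trigger    # the wrapped pt.train.trigger.Trigger
        self.log_list = log_list  # shared list that collects the log lines

    def __call__(self, iteration, epoch):
        # Delegate the decision to the wrapped trigger ...
        decision = self.trigger(iteration, epoch)
        # ... and log which hook asked in which method and what was answered,
        # e.g. "I:0, E: 0, True, SummaryHook.pre_step".
        caller = inspect.stack()[1]
        hook = type(caller.frame.f_locals.get('self', object())).__name__
        self.log_list.append(
            f'I:{iteration}, E: {epoch}, {decision}, {hook}.{caller.function}'
        )
        return decision

# Such a wrapper is swapped in exactly as in the test body:
#     hook.__dict__[k] = _TriggerMockSketch(v, log_list)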
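

# `load_events_as_dict` is imported from elsewhere in this package.  For
# reference, a possible stand-in (an assumption, not the actual
# implementation) built on TensorFlow's event reader; it yields dicts with
# the 'summary' -> 'value' -> 'tag'/'simple_value' layout the assertions
# above rely on:
def _load_events_as_dict_sketch(path):
    import tensorflow as tf
    from google.protobuf.json_format import MessageToDict

    for event in tf.compat.v1.train.summary_iterator(str(path)):
        # preserving_proto_field_name keeps 'simple_value' instead of the
        # default camelCase 'simpleValue'.
        yield MessageToDict(event, preserving_proto_field_name=True)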