def test_extensions_manager_state_dict_future_ppe_version():
    """Loading a snapshot tagged with a newer PPE version emits a warning."""
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    passed_iteration = 11

    manager = training.ExtensionsManager(
        {'model_name': _StateDictModel(state_dict=model_state_dict)},
        {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )
    # Advance the manager a few iterations so its state is non-trivial.
    for _ in range(passed_iteration):
        with manager.run_iteration():
            pass

    new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
    new_optimizer = _StateDictObj(
        state_dict_to_be_loaded=optimizer_state_dict)
    manager_2 = training.ExtensionsManager(
        {'model_name': new_model},
        {'optimizer_name': new_optimizer},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )

    state_dict = manager.state_dict()
    # Pretend the snapshot was produced by a future release.
    state_dict['ppe_version'] = '23.0.0'
    with pytest.warns(UserWarning, match='version'):
        manager_2.load_state_dict(state_dict)
def test_model_transformations_in_state_dict():
    """``transform_model`` is applied both when saving and loading state."""
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4

    model = Wrapper(_StateDictModel(state_dict=model_state_dict))
    manager = training.ExtensionsManager(
        model,
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        transform_model=lambda n, x: x.wrapper_module(),
    )
    state_dict = manager.state_dict()
    # Saving must have gone through the wrapper's accessor.
    assert model.accessed

    new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
    new_optimizer = _StateDictObj(
        state_dict_to_be_loaded=optimizer_state_dict)
    new_manager = training.ExtensionsManager(
        Wrapper(new_model),
        new_optimizer,
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        transform_model=lambda n, x: x.wrapper_module(),
    )
    new_manager.load_state_dict(state_dict)
    # After loading, the registered model is the unwrapped inner module.
    assert isinstance(new_manager.models['main'], _StateDictModel)
def test_extensions_manager_state_dict():
    """state_dict captures iteration, model, optimizer and extension state,
    and ``load_state_dict`` restores each of them into a fresh manager.

    Bug fix: the last assertion previously re-checked
    ``new_optimizer.called_load_state_dict`` a second time instead of
    verifying that the *extension* state was loaded; it now checks
    ``new_extension``.
    """
    model_state_dict = object()
    optimizer_state_dict = object()
    extension_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    passed_iteration = 11

    manager = training.ExtensionsManager(
        {'model_name': _StateDictModel(state_dict=model_state_dict)},
        {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )
    manager.extend(
        _StateDictExtension(state_dict=extension_state_dict),
        name='extension_name')

    for it in range(passed_iteration):
        with manager.run_iteration():
            pass

    state_dict = manager.state_dict()
    assert state_dict == {
        '_start_iteration': passed_iteration,
        'models': {
            'model_name': model_state_dict
        },
        'optimizers': {
            'optimizer_name': optimizer_state_dict
        },
        'extensions': {
            'extension_name': {
                'extension': extension_state_dict,
                'trigger': {
                    '_previous_iteration': passed_iteration,
                    '_previous_epoch_detail':
                        passed_iteration / iters_per_epoch
                },
            }
        },
    }

    new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
    new_optimizer = _StateDictObj(
        state_dict_to_be_loaded=optimizer_state_dict)
    new_extension = _StateDictExtension(
        state_dict_to_be_loaded=extension_state_dict)
    new_manager = training.ExtensionsManager(
        {'model_name': new_model},
        {'optimizer_name': new_optimizer},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )
    new_manager.extend(new_extension, name='extension_name')
    new_manager.load_state_dict(state_dict)
    assert new_model.called_load_state_dict == 1
    assert new_optimizer.called_load_state_dict == 1
    # Fixed: previously this duplicated the optimizer check.
    assert new_extension.called_load_state_dict == 1
def test_trigger(iters_per_epoch, schedule, expected, finished, resume):
    """A ManualScheduleTrigger fires exactly at the scheduled points."""
    trainer = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)
    trigger = triggers.ManualScheduleTrigger(*schedule)
    _test_trigger(trainer, trigger, expected, finished)
def test_extensions_manager_extensions():
    """Extensions run in priority order; ``call_before_training`` and
    ``initialize`` hooks fire only around the very first iteration."""
    model = nn.Module()
    optimizer = object()
    max_epochs = 5
    iters_per_epoch = 4
    manager = training.ExtensionsManager(
        {'model_name': model},
        {'optimizer_name': optimizer},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )

    call_record = []
    init_record = []
    dummy5 = _DummyExtension(5, call_record, init_record)
    exts = [
        _DummyExtension(0, call_record, init_record),
        _DummyExtensionInitialize(1, call_record, init_record),
        _DummyExtension(2, call_record, init_record),
        _DummyExtensionInitialize(3, call_record, init_record),
        _DummyExtensionInitialize(4, call_record, init_record),
        # A bare callable is also a valid extension.
        lambda manager: dummy5(manager),
    ]
    manager.extend(exts[0], 'ext0', priority=2, call_before_training=True)
    manager.extend(exts[1], 'ext1', priority=1, call_before_training=False)
    manager.extend(exts[2], 'ext2', priority=3, call_before_training=False)
    manager.extend(exts[3], 'ext3', priority=0, call_before_training=True)
    manager.extend(exts[4], 'ext4', priority=4, call_before_training=True)
    manager.extend(exts[5], 'ext5', priority=-1, call_before_training=True)

    for idx in range(5):
        assert manager.get_extension('ext{}'.format(idx)) is exts[idx]
    with pytest.raises(ValueError):
        manager.get_extension('ext10')

    for it in range(max_epochs * iters_per_epoch):
        call_record.clear()
        init_record.clear()
        with manager.run_iteration():
            assert manager.iteration == it
            if it == 0:
                # Before training: only call_before_training=True extensions,
                # highest priority first; initialize on the Initialize ones.
                assert call_record == [4, 0, 3, 5]
                assert init_record == [4, 1, 3]
            else:
                assert call_record == []
                assert init_record == []
            call_record.clear()
            init_record.clear()
        # After each iteration every extension runs, ordered by priority.
        assert call_record == [4, 2, 0, 1, 3, 5]
        assert init_record == []
def test_model_transformations(path):
    """Snapshot-time ``transform_models`` and autoload-time
    ``autoload_transform_models`` are both applied."""
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4

    model = Wrapper(_StateDictModel(state_dict=model_state_dict))
    manager = training.ExtensionsManager(
        model,
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        out_dir=path,
    )
    snapshot = extensions.snapshot(
        filename='test',
        transform_models=lambda n, x: x.wrapper_module())
    snapshot(manager)
    assert model.accessed

    # Verify that autoload applies the transformation
    to_load = torch.load(os.path.join(path, 'test'))
    trainer = get_trainer(out_dir=path, state_to_load=to_load)
    snapshot = extensions.snapshot(
        filename='test',
        autoload=True,
        autoload_transform_models=lambda n, x: Wrapper(x))
    snapshot.initialize(trainer)
    assert isinstance(trainer.models['main'], Wrapper)
def test_extensions_manager_with_plain_model_and_optimizer():
    """Passing a bare model/optimizer registers them under 'main'."""
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    manager = training.ExtensionsManager(
        _StateDictModel(state_dict=model_state_dict),
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )
    state_dict = manager.state_dict()
    assert state_dict == {
        '_start_execution': 0,
        '_start_iteration': 0,
        'models': {
            'main': model_state_dict
        },
        'optimizers': {
            'main': optimizer_state_dict
        },
        'extensions': {}
    }
def get_trainer_with_mock_updater(path):
    """Build an ExtensionsManager around a distributed model for tests."""
    epochs = 10  # FIXME
    model = _create_distributed_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    return training.ExtensionsManager(
        {'main': model},
        {'main': optimizer},
        epochs,
        iters_per_epoch=1,
        out_dir=path,
    )
def test_trigger(
        trigger_type, trigger_args, iters_per_epoch, accuracies, expected,
        resume):
    """A comparison trigger on 'main/accuracy' fires at the expected steps."""
    key = 'main/accuracy'
    manager = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)
    trigger = trigger_type(key, *trigger_args)
    _test_trigger(manager, trigger, key, accuracies, expected)
def test_trigger(iters_per_epoch, interval, expected, resume):
    """IntervalTrigger fires (and reports may_fire) at the expected steps."""
    trainer = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)
    trigger = triggers.IntervalTrigger(*interval)
    for expected_fire in expected:
        with trainer.run_iteration():
            pass
        assert trigger.may_fire(
            trainer.iteration, iters_per_epoch) == expected_fire
        assert trigger(trainer) == expected_fire
def get_trainer_with_mock_updater(*, out_dir, state_to_load=None):
    """Build a manager with stub state-dict model/optimizer for tests.

    ``state_to_load`` is accepted for signature compatibility with callers
    but is not consumed here.
    """
    epochs = 10  # FIXME
    models = {'main': _StateDictModel(state_dict={})}
    optimizers = {'main': _StateDictObj(state_dict={})}
    return training.ExtensionsManager(
        models, optimizers, epochs, iters_per_epoch=10, out_dir=out_dir)
def test_get_trigger(iters_per_epoch, trigger_args, expected):
    """get_trigger builds a trigger that fires on the expected schedule."""
    trainer = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)
    trigger = trigger_util.get_trigger(trigger_args)
    # before the first iteration, trigger should be False
    for expected_fire in [False] + expected:
        with trainer.run_iteration():
            assert trigger(trainer) == expected_fire
def get_manager_model_optimizer():
    """Return a (manager, model, optimizer) triple with FailOnNonNumber."""
    epochs = 3
    model = torch.nn.Linear(1, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    manager = training.ExtensionsManager(
        {'main': model},
        {'main': optimizer},
        epochs,
        iters_per_epoch=4,
    )
    manager.extend(FailOnNonNumber())
    return manager, model, optimizer
def test_call_optimizers():
    """``step_optimizers`` steps the named optimizer after the iteration.

    With lr=1.0 and grad 2.0 on a parameter initialized to 1.0, SGD yields
    1.0 - 2.0 = -1.0.
    """
    m = torch.nn.Linear(5, 5)
    a = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.SGD(lr=1.0, params=[a])
    manager = training.ExtensionsManager(
        m,
        optimizer,
        1,
        iters_per_epoch=1,
    )
    with manager.run_iteration(step_optimizers=['main']):
        a.grad = torch.tensor([2.0])
    assert torch.equal(a.detach(), torch.tensor([-1.]))
def test_manager_status_info():
    """epoch and epoch_detail are derived from iteration / iters_per_epoch."""
    manager = training.ExtensionsManager(
        nn.Module(), object(), 10, iters_per_epoch=4)

    manager.iteration = 9
    assert manager.iteration == 9
    assert manager.epoch == 2           # 9 // 4
    assert manager.epoch_detail == 2.25  # 9 / 4

    manager.iteration = 15
    assert manager.epoch == 3           # 15 // 4
    assert manager.epoch_detail == 3.75  # 15 / 4
def test_resumed_trigger(iters_per_epoch, schedule, expected, finished,
                         resume):
    """A ManualScheduleTrigger restored from state_dict resumes correctly."""
    trainer = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)

    # Run the first part of the schedule with the original trigger.
    trigger = triggers.ManualScheduleTrigger(*schedule)
    _test_trigger(trainer, trigger, expected[:resume], finished[:resume])

    # Serialize, rebuild, and continue with the remainder.
    state = trigger.state_dict()
    new_trigger = triggers.ManualScheduleTrigger(*schedule)
    new_trigger.load_state_dict(state)
    _test_trigger(trainer, new_trigger, expected[resume:], finished[resume:])
def test_model_transformations():
    """``transform_model`` unwraps the registered model eagerly."""
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    model = Wrapper(_StateDictModel(state_dict=model_state_dict))
    manager = training.ExtensionsManager(
        model,
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        transform_model=lambda n, x: x.wrapper_module(),
    )
    # The stored model is the unwrapped module, obtained via the wrapper.
    assert not isinstance(manager.models['main'], Wrapper)
    assert model.accessed
def test_on_error():
    """snapshot_on_error writes nothing when no snapshot was triggered."""
    # Will fail when accessing the dummy optimizer
    trainer = training.ExtensionsManager(
        {}, {'main': object()}, 1, iters_per_epoch=1, out_dir='.')

    filename = 'myfile-deadbeef.dat'
    snapshot = extensions.snapshot_object(
        trainer, filename, snapshot_on_error=True)
    trainer.extend(snapshot)
    assert not os.path.exists(filename)

    with pytest.raises(AttributeError):
        with trainer.run_iteration():
            pass
    assert not os.path.exists(filename)
def test_resumed_trigger(
        trigger_type, trigger_args, iters_per_epoch, accuracies, expected,
        resume):
    """A comparison trigger restored from state_dict resumes correctly."""
    key = 'main/accuracy'
    manager = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)

    # Exercise the trigger up to the resume point.
    trigger = trigger_type(key, *trigger_args)
    _test_trigger(
        manager, trigger, key, accuracies[:resume], expected[:resume])

    # Round-trip through state_dict and finish the sequence.
    state = trigger.state_dict()
    new_trigger = trigger_type(key, *trigger_args)
    new_trigger.load_state_dict(state)
    _test_trigger(
        manager, new_trigger, key, accuracies[resume:], expected[resume:])
def test_extensions_accessing_models_without_flag(priority):
    """An extension touching model state without ``needs_model_state`` set
    makes every iteration raise RuntimeError."""
    m = torch.nn.Linear(5, 5)
    a = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.SGD(lr=1.0, params=[a])

    extension = _DummyExtension(0, [], [], True)
    extension.name = 'Dummy'
    extension.needs_model_state = False
    extension.trigger = (1, 'iteration')
    if priority is not None:
        extension.priority = priority

    manager = training.ExtensionsManager(
        m, optimizer, 1, iters_per_epoch=5, extensions=[extension])
    while not manager.stop_trigger:
        with pytest.raises(RuntimeError):
            with manager.run_iteration():
                pass
def test_model_transformations(path):
    """The manager-level ``transform_model`` is applied when snapshotting."""
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    model = Wrapper(_StateDictModel(state_dict=model_state_dict))
    manager = training.ExtensionsManager(
        model,
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        out_dir=path,
        transform_model=lambda n, x: x.wrapper_module(),
    )
    snapshot = extensions.snapshot(filename='test')
    snapshot(manager)
    # Saving goes through the wrapper's accessor.
    assert model.accessed
def test_resumed_trigger(iters_per_epoch, interval, expected, resume):
    """An IntervalTrigger restored from state_dict resumes correctly."""
    trainer = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)

    # First part with the original trigger.
    trigger = triggers.IntervalTrigger(*interval)
    for expected_fire in expected[:resume]:
        with trainer.run_iteration():
            pass
        assert trigger(trainer) == expected_fire

    # Round-trip through state_dict and continue with a fresh trigger.
    state = trigger.state_dict()
    new_trigger = triggers.IntervalTrigger(*interval)
    new_trigger.load_state_dict(state)
    for expected_fire in expected[resume:]:
        with trainer.run_iteration():
            pass
        assert new_trigger(trainer) == expected_fire
def test_needs_state_this_iteration():
    """``needs_state_this_iteration`` is True only on iterations where an
    extension with ``needs_model_state`` will fire."""
    m = torch.nn.Linear(5, 5)
    a = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.SGD(lr=1.0, params=[a])

    extension = _DummyExtension(0, [], [], True)
    extension.name = 'Dummy'
    extension.needs_model_state = True
    extension.trigger = (50, 'iteration')

    manager = training.ExtensionsManager(
        m, optimizer, 1, iters_per_epoch=100, extensions=[extension])
    while not manager.stop_trigger:
        with manager.run_iteration():
            # iteration is always added 1 before calling
            # extensions
            if manager.iteration in (49, 99):
                assert manager.needs_state_this_iteration()
            else:
                assert not manager.needs_state_this_iteration()
def test_deferred_iteration():
    """Deferred iterations advance ``execution`` but not ``iteration``;
    ``complete_iteration`` later flushes the deferred extensions."""
    m = torch.nn.Linear(5, 5)
    a = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.SGD(lr=1.0, params=[a])

    call_record = []
    extension = _DummyExtension(0, call_record, [])
    extension.name = 'Dummy 0'
    extension.trigger = (1, 'iteration')
    extension.is_async = True
    extension2 = _DummyExtension(1, call_record, [])
    extension2.name = 'Dummy 1'
    extension2.trigger = (1, 'iteration')

    manager = training.ExtensionsManager(
        m, optimizer, 1, iters_per_epoch=100,
        extensions=[extension, extension2],
    )
    for _ in range(5):
        with manager.run_iteration() as iter_handler:
            # iteration is always added 1 before calling
            # extensions
            iter_handler.defer()
    with manager.run_iteration():
        pass
    # Five deferred runs + one normal: only one iteration counted, the
    # async extension ran every time, the sync one only once.
    assert manager.iteration == 1
    assert manager.execution == 6
    assert call_record == [0] * 6 + [1]

    for _ in range(5):
        with manager.complete_iteration():
            pass
    # Completing the deferred iterations runs the sync extension for each.
    assert manager.iteration == 6
    assert manager.execution == 6
    assert call_record == [0] * 6 + [1] * 6