def test_persist_trainer_round_trip(data):
    """Fit a trainer with a ``Persist`` extension and reload it via ``Trainer.load``.

    Renamed from ``test_persist_1``: that name is defined a second time in
    this module, and the later ``def`` rebinds it, so this test was never
    collected by pytest.

    The ``data`` fixture supplies the model as ``data[0]`` and a tuple in
    ``data[1]`` that is unpacked twice into ``fit`` (presumably once as the
    training data and once as the validation data -- confirm against the
    fixture).
    """
    model = deepcopy(data[0])
    trainer = Trainer(
        model=model,
        optimizer=Adam(lr=0.1),
        loss_func=MSELoss(),
        epochs=200,
    )
    trainer.extend(TensorConverter(), Persist('model_dir'))
    trainer.fit(*data[1], *data[1])

    # The extension is retrievable by key and exposes its checker.
    persist = trainer['persist']
    checker = persist._checker
    assert isinstance(persist, Persist)
    assert isinstance(checker.model, torch.nn.Module)
    assert isinstance(checker.describe, dict)
    assert isinstance(checker.files, list)
    assert set(checker.files) == {
        'model',
        'init_state',
        'model_structure',
        'describe',
        'training_info',
        'final_state',
    }

    # Loading from the checker alone restores the model and training info;
    # everything else is expected to be unset.
    restored = Trainer.load(checker)
    assert isinstance(restored.training_info, pd.DataFrame)
    assert isinstance(restored.model, torch.nn.Module)
    assert isinstance(restored._training_info, list)
    assert restored.optimizer is None
    assert restored.lr_scheduler is None
    assert restored.x_val is None
    assert restored.y_val is None
    assert restored.validate_dataset is None
    assert restored._optimizer_state is None
    assert restored.total_epochs == 0
    assert restored.total_iterations == 0
    assert restored.loss_type is None
    assert restored.loss_func is None

    # Loading from the path with explicit components wires them into the trainer.
    configured = Trainer.load(
        from_=checker.path,
        optimizer=Adam(),
        loss_func=MSELoss(),
        lr_scheduler=ExponentialLR(gamma=0.99),
        clip_grad=ClipValue(clip_value=0.1),
    )
    assert isinstance(configured._scheduler, ExponentialLR)
    assert isinstance(configured._optim, Adam)
    assert isinstance(configured.clip_grad, ClipValue)
    assert isinstance(configured.loss_func, MSELoss)
def test_persist_save_checkpoints(data):
    """Check which checkpoint files ``Persist.on_checkpoint`` writes.

    Three configurations are exercised: ``only_best_states=False`` with two
    checkpoints, ``only_best_states=True`` with the same two checkpoints, and
    ``only_best_states=True`` with only ``cp_2``.
    """

    class _Trainer(BaseRunner):

        def __init__(self):
            super().__init__()
            self.model = SequentialLinear(50, 2)

        def predict(self, x_, y_):  # noqa
            return x_, y_

    def _checkpoint_file(model_dir, file_name):
        # Resolved on every call, exactly like the inline expressions it replaces.
        return Path('.').resolve() / model_dir / 'checkpoints' / file_name

    first_cp = Trainer.checkpoint_tuple(
        id='cp_1',
        iterations=111,
        model_state=SequentialLinear(50, 2).state_dict(),
    )
    second_cp = Trainer.checkpoint_tuple(
        id='cp_2',
        iterations=111,
        model_state=SequentialLinear(50, 2).state_dict(),
    )

    # save checkpoint
    persist = Persist('test_model_1', increment=False, only_best_states=False)
    persist.before_proc(trainer=_Trainer())
    persist.on_checkpoint(first_cp, trainer=_Trainer())
    persist.on_checkpoint(second_cp, trainer=_Trainer())
    assert _checkpoint_file('test_model_1', 'cp_1.pth.s').exists()
    assert _checkpoint_file('test_model_1', 'cp_2.pth.s').exists()

    # reduced save checkpoint
    persist = Persist('test_model_2', increment=False, only_best_states=True)
    persist.before_proc(trainer=_Trainer())
    persist.on_checkpoint(first_cp, trainer=_Trainer())
    persist.on_checkpoint(second_cp, trainer=_Trainer())
    assert _checkpoint_file('test_model_2', 'cp.pth.s').exists()
    assert not _checkpoint_file('test_model_2', 'cp_1.pth.s').exists()
    assert not _checkpoint_file('test_model_2', 'cp_2.pth.s').exists()

    # no checkpoint will be saved
    persist = Persist('test_model_3', increment=False, only_best_states=True)
    persist.before_proc(trainer=_Trainer())
    persist.on_checkpoint(second_cp, trainer=_Trainer())
    assert not _checkpoint_file('test_model_3', 'cp.pth.s').exists()
    assert not _checkpoint_file('test_model_3', 'cp_2.pth.s').exists()
def test_persist_1(data):
    """Check the ``Persist.path`` property and the files ``before_proc`` writes.

    Covers the access/reset guards on ``path``, the default path, an explicit
    path, and the ``@1`` suffix produced with ``increment=True``.
    """

    class _Trainer(BaseRunner):

        def __init__(self):
            super().__init__()
            self.model = SequentialLinear(50, 2)

        def predict(self, x_, y_):  # noqa
            return x_, y_

    def _cwd():
        # Resolved on every call, exactly like the inline expressions it replaces.
        return Path('.').resolve()

    # Files expected inside the model directory once ``before_proc`` has run.
    initial_files = (
        'describe.pkl.z',
        'init_state.pth.s',
        'model.pth.m',
        'model_structure.pkl.z',
    )

    # Default path: unreadable before ``before_proc``, read-only afterwards.
    persist = Persist()
    with pytest.raises(ValueError, match='can not access property `path` before training'):
        persist.path
    persist.before_proc(trainer=_Trainer())
    assert persist.path == str(_cwd() / Path(os.getcwd()).name)
    with pytest.raises(ValueError, match='can not reset property `path` after training'):
        persist.path = 'aa'

    # Explicit path.
    persist = Persist('test_model')
    persist.before_proc(trainer=_Trainer())
    assert persist.path == str(_cwd() / 'test_model')
    for file_name in initial_files:
        assert (_cwd() / 'test_model' / file_name).exists()

    # Same name with ``increment=True`` lands in a sibling ``@1`` directory.
    persist = Persist('test_model', increment=True)
    persist.before_proc(trainer=_Trainer())
    assert persist.path == str(_cwd() / 'test_model@1')
    for file_name in initial_files:
        assert (_cwd() / 'test_model@1' / file_name).exists()