def test_validator_1(data):
    """Check which checkpoints a ``Validator`` extension leaves behind.

    Two independent 20-epoch runs are made on a fresh copy of the model,
    each fitting on ``data[1]`` and validating on the same ``data[1]``:

    * ``trace_order=1, warming_up=0``: a single ``'mae'`` checkpoint is
      expected.
    * ``trace_order=5, warming_up=50``: no checkpoint is expected, presumably
      because the warm-up (50) is longer than the 20 epochs trained.
    """

    def run(trace_order, warming_up):
        # A deep copy keeps each run independent of the fixture's model.
        trainer = Trainer(
            model=deepcopy(data[0]),
            optimizer=Adam(lr=0.1),
            loss_func=MSELoss(),
            epochs=20,
        )
        trainer.extend(
            TensorConverter(),
            Validator(
                'regress',
                early_stop=30,
                trace_order=trace_order,
                warming_up=warming_up,
                mae=0,
            ),
        )
        # The same split serves as both training and validation data.
        trainer.fit(*data[1], *data[1])
        return trainer.get_checkpoint()

    assert run(trace_order=1, warming_up=0) == ['mae']
    assert run(trace_order=5, warming_up=50) == []
def test_trainer_fit_3(data):
    """Exercise checkpoint creation, lookup and reset on ``Trainer``.

    Covers:

    * a plain ``fit`` keeps no checkpoints;
    * ``reset()`` zeroes the iteration/epoch counters and drops checkpoints;
    * ``fit(checkpoint=True)`` over 5 epochs yields 5 checkpoints, each
      retrievable by int index or by ``'cp_<n>'`` name;
    * ``reset(to=..., remove_checkpoints=False)`` keeps the checkpoints,
      while ``reset(to=...)`` with the default leaves none;
    * non-str/int lookups and invalid ``reset`` targets raise ``TypeError``;
    * a callable ``checkpoint`` returning ``(flag, name)`` controls naming.
    """

    def assert_lookups_work():
        # A checkpoint must be reachable both by position and by name.
        assert isinstance(trainer.get_checkpoint(2), trainer.checkpoint_tuple)
        assert isinstance(trainer.get_checkpoint('cp_2'), trainer.checkpoint_tuple)

    model = deepcopy(data[0])
    trainer = Trainer(model=model, optimizer=Adam(), loss_func=MSELoss(), epochs=5)

    # Without checkpoint=True nothing is stored.
    trainer.fit(*data[1])
    assert len(trainer.checkpoints.keys()) == 0

    trainer.reset()
    assert trainer.total_iterations == 0
    assert trainer.total_epochs == 0
    assert len(trainer.get_checkpoint()) == 0

    trainer.fit(*data[1], checkpoint=True)
    assert len(trainer.get_checkpoint()) == 5
    assert_lookups_work()
    with pytest.raises(TypeError, match='parameter <cp> must be str or int'):
        trainer.get_checkpoint([])

    # Rolling back while keeping the stored checkpoints.
    trainer.reset(to=3, remove_checkpoints=False)
    assert len(trainer.get_checkpoint()) == 5
    assert_lookups_work()

    # Rolling back with the default remove_checkpoints leaves nothing behind.
    trainer.reset(to='cp_3')
    assert trainer.total_iterations == 0
    assert trainer.total_epochs == 0
    assert len(trainer.get_checkpoint()) == 0
    with pytest.raises(
            TypeError, match='parameter <to> must be torch.nnModule, int, or str'):
        trainer.reset(to=[])

    # TODO: need a real test -- predict(checkpoint=3) is only smoke-tested
    # here; its result is not asserted.
    trainer.fit(*data[1], checkpoint=True)
    trainer.predict(*data[1], checkpoint=3)

    # A callable checkpoint returns (should_save, name); the expected names
    # below imply it receives 1-based epoch numbers.
    trainer.reset()
    trainer.fit(*data[1], checkpoint=lambda i: (True, f'new:{i}'))
    assert len(trainer.get_checkpoint()) == 5
    assert trainer.get_checkpoint() == [f'new:{i}' for i in range(1, 6)]