def test_trainer_fit_4(data):
    """Early stopping interrupts both epoch-wise and batch-wise training loops.

    ``data`` is a fixture: ``data[0]`` is a model, ``data[1]`` is the
    ``(x_train, y_train)`` pair.
    """
    model = deepcopy(data[0])
    trainer = Trainer(
        model=model,
        optimizer=Adam(),
        loss_func=MSELoss(),
        clip_grad=ClipValue(0.1),
        lr_scheduler=ReduceLROnPlateau(),
        epochs=10,
    )

    # Epoch-wise iteration: request a stop during the 3rd epoch; the
    # generator should terminate after that epoch completes.
    expected_epoch = 1
    for step_info in trainer(*data[1]):
        assert isinstance(step_info, dict)
        assert step_info['i_epoch'] == expected_epoch
        if expected_epoch == 3:
            trainer.early_stop('stop')
        expected_epoch += 1
    assert trainer.total_epochs == 3
    assert trainer._early_stopping == (True, 'stop')

    trainer.reset()

    # Batch-wise iteration over a DataLoader: stop during the 3rd batch.
    train_set = DataLoader(TensorDataset(*data[1]))
    expected_batch = 1
    for step_info in trainer(training_dataset=train_set):
        assert isinstance(step_info, dict)
        assert step_info['i_batch'] == expected_batch
        if expected_batch == 3:
            trainer.early_stop('stop!!!')
        expected_batch += 1
    assert trainer.total_iterations == 3
    assert trainer._early_stopping == (True, 'stop!!!')
def test_trainer_fit_1(data):
    """fit() accumulates iteration/epoch counters, reset() zeroes them,
    and invalid input combinations raise RuntimeError.

    ``data`` is a fixture: ``data[0]`` is a model, ``data[1]`` is the
    ``(x_train, y_train)`` pair.
    """
    model = deepcopy(data[0])
    trainer = Trainer(model=model, optimizer=Adam(), loss_func=MSELoss())

    # Default run: 200 epochs, one iteration per epoch.
    trainer.fit(*data[1])
    assert trainer.total_iterations == 200
    assert trainer.total_epochs == 200

    # Counters accumulate across successive fit() calls.
    trainer.fit(*data[1], epochs=20)
    assert trainer.total_iterations == 220
    assert trainer.total_epochs == 220

    # reset() zeroes both counters.
    trainer.reset()
    assert trainer.total_iterations == 0
    assert trainer.total_epochs == 0

    trainer.fit(*data[1], epochs=20)
    assert trainer.total_iterations == 20
    assert trainer.total_epochs == 20

    # Training history is exposed as a DataFrame and as a namedtuple.
    assert isinstance(trainer.training_info, pd.DataFrame)
    assert 'i_epoch' in trainer.training_info.columns
    results = trainer.to_namedtuple()
    assert isinstance(results, trainer.results_tuple)

    # Mutually exclusive / incomplete input specifications must be rejected.
    train_set = DataLoader(TensorDataset(*data[1]))
    with pytest.raises(
            RuntimeError,
            match='parameter <training_dataset> is exclusive of <x_train> and <y_train>'):
        trainer.fit(*data[1], training_dataset=train_set)
    with pytest.raises(RuntimeError, match='missing parameter <x_train> or <y_train>'):
        trainer.fit(data[1][0])
def test_trainer_fit_3(data):
    """Checkpoints are created on demand, retrievable by index or name,
    survive a partial reset, and can be renamed via a callable.

    ``data`` is a fixture: ``data[0]`` is a model, ``data[1]`` is the
    ``(x_train, y_train)`` pair.
    """
    model = deepcopy(data[0])
    trainer = Trainer(model=model, optimizer=Adam(), loss_func=MSELoss(), epochs=5)

    # Without checkpoint=True no checkpoints are recorded.
    trainer.fit(*data[1])
    assert len(trainer.checkpoints.keys()) == 0

    trainer.reset()
    assert trainer.total_iterations == 0
    assert trainer.total_epochs == 0
    assert len(trainer.get_checkpoint()) == 0

    # One checkpoint per epoch; lookup works by int index or by name.
    trainer.fit(*data[1], checkpoint=True)
    assert len(trainer.get_checkpoint()) == 5
    assert isinstance(trainer.get_checkpoint(2), trainer.checkpoint_tuple)
    assert isinstance(trainer.get_checkpoint('cp_2'), trainer.checkpoint_tuple)
    with pytest.raises(TypeError, match='parameter <cp> must be str or int'):
        trainer.get_checkpoint([])

    # Resetting to a checkpoint while keeping the store leaves all 5 intact.
    trainer.reset(to=3, remove_checkpoints=False)
    assert len(trainer.get_checkpoint()) == 5
    assert isinstance(trainer.get_checkpoint(2), trainer.checkpoint_tuple)
    assert isinstance(trainer.get_checkpoint('cp_2'), trainer.checkpoint_tuple)

    # Default reset-to-checkpoint clears counters and the checkpoint store.
    trainer.reset(to='cp_3')
    assert trainer.total_iterations == 0
    assert trainer.total_epochs == 0
    assert len(trainer.get_checkpoint()) == 0
    with pytest.raises(
            TypeError,
            match='parameter <to> must be torch.nnModule, int, or str'):
        trainer.reset(to=[])

    # TODO: need a real testing
    trainer.fit(*data[1], checkpoint=True)
    trainer.predict(*data[1], checkpoint=3)

    # A callable checkpoint hook controls whether/how each epoch is named.
    trainer.reset()
    trainer.fit(*data[1], checkpoint=lambda i: (True, f'new:{i}'))
    assert len(trainer.get_checkpoint()) == 5
    assert trainer.get_checkpoint() == [f'new:{i + 1}' for i in range(5)]