def test_trainer_2(data):
    """Check that a bare ``Trainer`` refuses to fit until it is fully configured.

    Starting from ``Trainer()`` with nothing attached, the model, loss function
    and optimizer are supplied one at a time. Before each one is supplied,
    ``fit`` must raise ``RuntimeError`` naming the missing component. The test
    also checks that:

    - assigning a non-module as the model raises ``TypeError``;
    - the optimizer wrapper is exposed as a real ``torch.optim.Adam``;
    - an ``ExponentialLR`` scheduler wrapper is exposed as the torch scheduler.

    Args:
        data: fixture whose item ``0`` is the model and whose item ``1`` is the
            sequence of positional arguments passed to ``Trainer.fit``.
    """
    model, fit_args = data[0], data[1]
    tr = Trainer()

    def assert_fit_fails(reason):
        # ``fit`` must refuse to run while the component named by ``reason`` is unset.
        with pytest.raises(RuntimeError, match=reason):
            tr.fit(*fit_args)

    # Step 1: nothing attached yet.
    assert_fit_fails('no model for training')

    # Step 2: a plain dict is rejected as the model; a torch module is accepted.
    with pytest.raises(
            TypeError,
            match='parameter `m` must be a instance of <torch.nn.modules>'):
        tr.model = {}
    tr.model = model
    assert isinstance(tr.model, torch.nn.Module)

    # Step 3: model present, loss function still missing.
    assert_fit_fails('no loss function for training')
    tr.loss_func = MSELoss()
    assert tr.loss_type == 'train_mse_loss'
    assert type(tr.loss_func) is MSELoss

    # Step 4: loss present, optimizer still missing.
    assert_fit_fails('no optimizer for training')
    tr.optimizer = Adam()
    assert isinstance(tr.optimizer, torch.optim.Adam)
    # Attaching an optimizer populates the trainer's private state snapshots.
    assert isinstance(tr._optimizer_state, dict)
    assert isinstance(tr._init_states, dict)

    # Step 5: the scheduler wrapper is exposed as torch's ExponentialLR.
    tr.lr_scheduler = ExponentialLR(gamma=0.99)
    assert isinstance(tr.lr_scheduler, torch.optim.lr_scheduler.ExponentialLR)
def test_trainer_3(data):
    """Check a ``Trainer`` configured entirely through its constructor.

    The model, an ``Adam`` optimizer wrapper and an ``MSELoss`` are passed to
    ``Trainer(...)``. The test verifies that:

    - they are exposed as torch objects;
    - gradient clipping and the LR scheduler are unset by default;
    - the scheduler, the optimizer and the gradient clipper can be
      (re)assigned afterwards.

    Args:
        data: fixture whose item ``0`` is the model under test.
    """
    tr = Trainer(model=data[0], optimizer=Adam(), loss_func=MSELoss())

    # Constructor arguments are exposed as torch objects right away.
    assert isinstance(tr.model, torch.nn.Module)
    assert isinstance(tr.optimizer, torch.optim.Adam)
    assert isinstance(tr._optimizer_state, dict)
    assert isinstance(tr._init_states, dict)

    # Optional components were not supplied, so they default to None.
    assert tr.clip_grad is None
    assert tr.lr_scheduler is None

    # Components can be attached or swapped after construction.
    tr.lr_scheduler = ExponentialLR(gamma=0.1)
    assert isinstance(tr.lr_scheduler, torch.optim.lr_scheduler.ExponentialLR)

    tr.optimizer = SGD()
    assert isinstance(tr.optimizer, torch.optim.SGD)

    tr.clip_grad = ClipNorm(max_norm=0.4)
    assert isinstance(tr.clip_grad, ClipNorm)