Example #1
import pytest
import torch
from torch.nn import MSELoss
# `Trainer`, `ClipNorm`, and the deferred optimizer/scheduler wrappers
# (`Adam`, `SGD`, `ExponentialLR`) are assumed to come from the project
# under test; its import path is not shown in this listing.


def test_trainer_2(data):
    # A bare Trainer refuses to fit until a model is attached.
    trainer = Trainer()
    with pytest.raises(RuntimeError, match='no model for training'):
        trainer.fit(*data[1])

    # The model setter validates its type; a dict is rejected.
    # (The match string, including 'a instance', mirrors the library's
    # own message verbatim, so it is left untouched here.)
    with pytest.raises(
            TypeError,
            match='parameter `m` must be a instance of <torch.nn.modules>'):
        trainer.model = {}

    trainer.model = data[0]
    assert isinstance(trainer.model, torch.nn.Module)
    # A model alone is not enough: the loss function is still missing.
    with pytest.raises(RuntimeError, match='no loss function for training'):
        trainer.fit(*data[1])

    trainer.loss_func = MSELoss()
    assert trainer.loss_type == 'train_mse_loss'
    assert trainer.loss_func.__class__ == MSELoss
    # Loss is set, but an optimizer is still required before fitting.
    with pytest.raises(RuntimeError, match='no optimizer for training'):
        trainer.fit(*data[1])

    # `Adam()` is a deferred wrapper; assigning it lets the Trainer
    # materialize a real torch.optim.Adam bound to the model's parameters.
    trainer.optimizer = Adam()
    assert isinstance(trainer.optimizer, torch.optim.Adam)
    assert isinstance(trainer._optimizer_state, dict)
    assert isinstance(trainer._init_states, dict)

    trainer.lr_scheduler = ExponentialLR(gamma=0.99)
    assert isinstance(trainer.lr_scheduler,
                      torch.optim.lr_scheduler.ExponentialLR)
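
Both tests take a pytest fixture named `data` that this listing does not show. Judging from how the tests use it, `data[0]` must be a `torch.nn.Module` and `data[1]` must unpack into `trainer.fit(...)`. A minimal hypothetical sketch of such a fixture, with assumed shapes:

import pytest
import torch


@pytest.fixture
def data():
    # Hypothetical reconstruction; the real fixture is not shown here.
    model = torch.nn.Linear(10, 1)  # data[0]: the model under test
    x = torch.randn(100, 10)
    y = torch.randn(100, 1)
    return model, (x, y)  # data[1] unpacks into trainer.fit(x, y)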
Example #2
def test_trainer_3(data):
    # (Same assumed imports as in Example #1.)
    # Unlike test_trainer_2, everything is configured at construction time.
    model = data[0]
    trainer = Trainer(model=model, optimizer=Adam(), loss_func=MSELoss())
    assert isinstance(trainer.model, torch.nn.Module)
    assert isinstance(trainer.optimizer, torch.optim.Adam)
    assert isinstance(trainer._optimizer_state, dict)
    assert isinstance(trainer._init_states, dict)
    assert trainer.clip_grad is None
    assert trainer.lr_scheduler is None

    trainer.lr_scheduler = ExponentialLR(gamma=0.1)
    assert isinstance(trainer.lr_scheduler,
                      torch.optim.lr_scheduler.ExponentialLR)

    # Swapping the deferred wrapper re-materializes the real optimizer.
    trainer.optimizer = SGD()
    assert isinstance(trainer.optimizer, torch.optim.SGD)

    # Gradient clipping is optional and attached via a ClipNorm wrapper.
    trainer.clip_grad = ClipNorm(max_norm=0.4)
    assert isinstance(trainer.clip_grad, ClipNorm)
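
For context, a sketch of the workflow these two tests exercise end to end, using only the API surface shown above (the constructor keywords, the deferred Adam()/ExponentialLR()/ClipNorm() wrappers, and fit taking the unpacked data pair). The tensor shapes and the exact fit signature are assumptions:

model = torch.nn.Linear(10, 1)
trainer = Trainer(model=model, optimizer=Adam(), loss_func=MSELoss())
trainer.lr_scheduler = ExponentialLR(gamma=0.99)
trainer.clip_grad = ClipNorm(max_norm=0.4)

x, y = torch.randn(100, 10), torch.randn(100, 1)
trainer.fit(x, y)  # assumed to mirror trainer.fit(*data[1]) in the tests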