Exemple #1
0
def test_validator_1(data):
    """Check which checkpoints the 'regress' Validator leaves after a 20-epoch fit.

    Two scenarios are run on fresh copies of the fixture model:
      * ``trace_order=1, warming_up=0``  -> checkpoints are expected to be ``['mae']``;
      * ``trace_order=5, warming_up=50`` -> no checkpoints are expected.
    """
    scenarios = (
        # (trace_order, warming_up, expected result of get_checkpoint())
        (1, 0, ['mae']),
        (5, 50, []),
    )
    for order, warm_up, expected in scenarios:
        trainer = Trainer(
            model=deepcopy(data[0]),
            optimizer=Adam(lr=0.1),
            loss_func=MSELoss(),
            epochs=20,
        )
        trainer.extend(
            TensorConverter(),
            Validator('regress', early_stop=30, trace_order=order, warming_up=warm_up, mae=0),
        )
        # The same (x, y) pair is passed twice; presumably the second pair is
        # used as the validation set — TODO confirm against Trainer.fit.
        trainer.fit(*data[1], *data[1])
        assert trainer.get_checkpoint() == expected
Exemple #2
0
def test_trainer_fit_3(data):
    """Exercise Trainer checkpointing: creation, lookup, reset and custom naming.

    Covers:
      * a plain ``fit`` stores no checkpoints;
      * ``reset()`` rewinds the iteration/epoch counters and clears checkpoints;
      * ``fit(..., checkpoint=True)`` over 5 epochs yields 5 checkpoints that can
        be looked up by int index or by ``'cp_<n>'`` name;
      * ``get_checkpoint`` and ``reset`` reject unsupported argument types;
      * ``reset(..., remove_checkpoints=False)`` keeps the stored checkpoints;
      * a callable ``checkpoint`` argument controls checkpoint naming.
    """
    model = deepcopy(data[0])
    trainer = Trainer(model=model,
                      optimizer=Adam(),
                      loss_func=MSELoss(),
                      epochs=5)

    # A plain fit (no ``checkpoint`` argument) must not store any checkpoint.
    trainer.fit(*data[1])
    assert len(trainer.checkpoints.keys()) == 0

    # reset() without a target rewinds the counters and leaves no checkpoints.
    trainer.reset()
    assert trainer.total_iterations == 0
    assert trainer.total_epochs == 0
    assert len(trainer.get_checkpoint()) == 0

    # checkpoint=True over 5 epochs is expected to produce 5 checkpoints,
    # retrievable both by int index and by 'cp_<n>' name.
    trainer.fit(*data[1], checkpoint=True)
    assert len(trainer.get_checkpoint()) == 5
    assert isinstance(trainer.get_checkpoint(2), trainer.checkpoint_tuple)
    assert isinstance(trainer.get_checkpoint('cp_2'), trainer.checkpoint_tuple)

    with pytest.raises(TypeError, match='parameter <cp> must be str or int'):
        trainer.get_checkpoint([])

    # Resetting to a checkpoint with remove_checkpoints=False keeps all of them.
    trainer.reset(to=3, remove_checkpoints=False)
    assert len(trainer.get_checkpoint()) == 5
    assert isinstance(trainer.get_checkpoint(2), trainer.checkpoint_tuple)
    assert isinstance(trainer.get_checkpoint('cp_2'), trainer.checkpoint_tuple)

    # Without remove_checkpoints=False, resetting by name clears the checkpoints
    # and rewinds the counters.
    trainer.reset(to='cp_3')
    assert trainer.total_iterations == 0
    assert trainer.total_epochs == 0
    assert len(trainer.get_checkpoint()) == 0

    with pytest.raises(
            TypeError,
            match='parameter <to> must be torch.nnModule, int, or str'):
        trainer.reset(to=[])

    # TODO: this only smoke-tests predict(..., checkpoint=3); add real assertions
    # on the predictions it returns.
    trainer.fit(*data[1], checkpoint=True)
    trainer.predict(*data[1], checkpoint=3)

    # A callable checkpoint returns (save?, name). The expected names show it is
    # called with 1..5, presumably the 1-based epoch number.
    trainer.reset()
    trainer.fit(*data[1], checkpoint=lambda i: (True, f'new:{i}'))
    assert len(trainer.get_checkpoint()) == 5
    assert trainer.get_checkpoint() == [f'new:{i}' for i in range(1, 6)]