Ejemplo n.º 1
0
def test_trainer_fit_4(data):
    """Exercise ``Trainer.early_stop`` while iterating over training steps.

    The trainer is iterated twice: first with the tensors in ``data[1]``
    (each yielded dict is expected to carry a 1-based ``i_epoch``), then,
    after ``reset()``, with a ``DataLoader`` (each yielded dict is expected
    to carry a 1-based ``i_batch``). In both runs an early stop is requested
    on the third step; the test asserts that the matching counter ends at 3
    and that the stop flag and message are recorded on the trainer.
    """
    samples = data[1]
    trainer = Trainer(
        model=deepcopy(data[0]),  # train a copy of the fixture's model
        optimizer=Adam(),
        loss_func=MSELoss(),
        clip_grad=ClipValue(0.1),
        lr_scheduler=ReduceLROnPlateau(),
        epochs=10,
    )

    # Run 1: tensors passed positionally, progress reported per epoch.
    for step, info in enumerate(trainer(*samples), start=1):
        assert isinstance(info, dict)
        assert info['i_epoch'] == step
        if step == 3:
            trainer.early_stop('stop')

    assert trainer.total_epochs == 3
    assert trainer._early_stopping == (True, 'stop')

    # Run 2: same trainer after a reset, fed through a DataLoader,
    # progress reported per batch.
    trainer.reset()
    loader = DataLoader(TensorDataset(*samples))
    for step, info in enumerate(trainer(training_dataset=loader), start=1):
        assert isinstance(info, dict)
        assert info['i_batch'] == step
        if step == 3:
            trainer.early_stop('stop!!!')

    assert trainer.total_iterations == 3
    assert trainer._early_stopping == (True, 'stop!!!')
Ejemplo n.º 2
0
def test_trainer_fit_1(data):
    """Check the counters, result accessors and input validation of ``fit``.

    Steps, in order:

    * ``fit`` without an ``epochs`` argument is expected to leave both
      ``total_iterations`` and ``total_epochs`` at 200;
    * a further ``fit(..., epochs=20)`` is expected to add to those totals;
    * ``reset()`` is expected to zero them;
    * ``training_info`` must be a ``DataFrame`` with an ``i_epoch`` column
      and ``to_namedtuple()`` must return a ``results_tuple``;
    * passing tensors together with ``training_dataset``, or passing only
      one tensor, must raise ``RuntimeError`` with the expected message.
    """
    samples = data[1]
    trainer = Trainer(
        model=deepcopy(data[0]),  # train a copy of the fixture's model
        optimizer=Adam(),
        loss_func=MSELoss(),
    )

    def assert_totals(expected):
        # Both counters are expected to move in lockstep in this test.
        assert trainer.total_iterations == expected
        assert trainer.total_epochs == expected

    trainer.fit(*samples)
    assert_totals(200)

    # A second fit continues counting from the previous totals.
    trainer.fit(*samples, epochs=20)
    assert_totals(220)

    trainer.reset()
    assert_totals(0)

    trainer.fit(*samples, epochs=20)
    assert_totals(20)

    history = trainer.training_info
    assert isinstance(history, pd.DataFrame)
    assert 'i_epoch' in history.columns

    assert isinstance(trainer.to_namedtuple(), trainer.results_tuple)

    # Tensors and a DataLoader must not be supplied at the same time.
    loader = DataLoader(TensorDataset(*samples))
    with pytest.raises(
            RuntimeError,
            match='parameter <training_dataset> is exclusive of <x_train> and <y_train>'):
        trainer.fit(*samples, training_dataset=loader)

    # Supplying only the first tensor is rejected as well.
    with pytest.raises(
            RuntimeError,
            match='missing parameter <x_train> or <y_train>'):
        trainer.fit(samples[0])
Ejemplo n.º 3
0
def test_trainer_fit_3(data):
    """Check checkpoint creation, lookup and reset behaviour of the trainer.

    With ``epochs=5`` the test expects:

    * no checkpoints after a plain ``fit``;
    * five checkpoints after ``fit(..., checkpoint=True)``, retrievable by
      integer (``2``) or by name (``'cp_2'``) as ``checkpoint_tuple``;
    * ``TypeError`` from ``get_checkpoint`` / ``reset`` for unsupported
      argument types;
    * checkpoints kept by ``reset(to=3, remove_checkpoints=False)`` and
      cleared (with counters zeroed) by ``reset(to='cp_3')``;
    * checkpoint names taken from a callable passed as ``checkpoint``.
    """
    samples = data[1]
    trainer = Trainer(
        model=deepcopy(data[0]),  # train a copy of the fixture's model
        optimizer=Adam(),
        loss_func=MSELoss(),
        epochs=5,
    )

    def assert_cleared():
        # Counters back at zero and no checkpoint left.
        assert trainer.total_iterations == 0
        assert trainer.total_epochs == 0
        assert len(trainer.get_checkpoint()) == 0

    def assert_five_checkpoints():
        # One checkpoint per epoch, addressable by index or by name.
        assert len(trainer.get_checkpoint()) == 5
        assert isinstance(trainer.get_checkpoint(2), trainer.checkpoint_tuple)
        assert isinstance(trainer.get_checkpoint('cp_2'), trainer.checkpoint_tuple)

    # Without ``checkpoint`` nothing is stored.
    trainer.fit(*samples)
    assert len(trainer.checkpoints.keys()) == 0

    trainer.reset()
    assert_cleared()

    trainer.fit(*samples, checkpoint=True)
    assert_five_checkpoints()

    with pytest.raises(TypeError, match='parameter <cp> must be str or int'):
        trainer.get_checkpoint([])

    # Resetting to a checkpoint while asking to keep the stored checkpoints.
    trainer.reset(to=3, remove_checkpoints=False)
    assert_five_checkpoints()

    # Resetting to a named checkpoint with default options.
    trainer.reset(to='cp_3')
    assert_cleared()

    with pytest.raises(
            TypeError,
            match='parameter <to> must be torch.nnModule, int, or str'):
        trainer.reset(to=[])

    # TODO: this only checks that predicting from a checkpoint runs;
    # real assertions on the prediction are still needed.
    trainer.fit(*samples, checkpoint=True)
    trainer.predict(*samples, checkpoint=3)

    # A callable ``checkpoint`` supplies the checkpoint names.
    trainer.reset()
    trainer.fit(*samples, checkpoint=lambda i: (True, f'new:{i}'))
    expected_names = [f'new:{i + 1}' for i in range(5)]
    assert len(trainer.get_checkpoint()) == 5
    assert trainer.get_checkpoint() == expected_names