# Example #1
0
def test_persist_1(data):
    """Train with a ``Persist`` extension, then reload the saved run via ``Trainer.load``.

    Checks three things:
    * which artefacts the persist checker records after fitting;
    * what ``Trainer.load(checker)`` restores: model and training history,
      while optimizer, scheduler, validation data and loss remain unset;
    * that ``Trainer.load(from_=path, ...)`` wires in the optimizer, scheduler,
      gradient clipping and loss that it is given.
    """
    # Work on a copy so the fixture's model is not mutated by training.
    net = deepcopy(data[0])
    train_set = data[1]

    runner = Trainer(
        model=net,
        optimizer=Adam(lr=0.1),
        loss_func=MSELoss(),
        epochs=200,
    )
    runner.extend(TensorConverter(), Persist('model_dir'))
    # The same (x, y) pair fills both the training and the validation slots.
    runner.fit(*train_set, *train_set)

    persist = runner['persist']
    checker = persist._checker
    assert isinstance(persist, Persist)
    assert isinstance(checker.model, torch.nn.Module)
    assert isinstance(checker.describe, dict)
    assert isinstance(checker.files, list)

    expected_files = {
        'model',
        'init_state',
        'model_structure',
        'describe',
        'training_info',
        'final_state',
    }
    assert set(checker.files) == expected_files

    # Reloading from the checker alone: only model + history come back.
    restored = Trainer.load(checker)
    assert isinstance(restored.training_info, pd.DataFrame)
    assert isinstance(restored.model, torch.nn.Module)
    assert isinstance(restored._training_info, list)
    for attr in ('optimizer', 'lr_scheduler', 'x_val', 'y_val',
                 'validate_dataset', '_optimizer_state'):
        assert getattr(restored, attr) is None, attr
    assert restored.total_epochs == 0
    assert restored.total_iterations == 0
    for attr in ('loss_type', 'loss_func'):
        assert getattr(restored, attr) is None, attr

    # Reloading from the on-disk path with explicit training components.
    resumed = Trainer.load(
        from_=checker.path,
        optimizer=Adam(),
        loss_func=MSELoss(),
        lr_scheduler=ExponentialLR(gamma=0.99),
        clip_grad=ClipValue(clip_value=0.1),
    )
    assert isinstance(resumed._scheduler, ExponentialLR)
    assert isinstance(resumed._optim, Adam)
    assert isinstance(resumed.clip_grad, ClipValue)
    assert isinstance(resumed.loss_func, MSELoss)
# Example #2
0
def test_persist_save_checkpoints(data):
    """Check which checkpoint files ``Persist.on_checkpoint`` leaves on disk.

    Three scenarios are covered: keeping every checkpoint, reducing to a single
    ``cp`` file with ``only_best_states=True``, and a run where nothing is saved.
    ``data`` is the shared fixture; it is requested but not used here.
    """

    class _Trainer(BaseRunner):

        def __init__(self):
            super().__init__()
            self.model = SequentialLinear(50, 2)

        def predict(self, x_, y_):  # noqa
            return x_, y_

    def make_checkpoint(cp_id):
        # Each checkpoint carries the weights of a freshly built network.
        return Trainer.checkpoint_tuple(
            id=cp_id,
            iterations=111,
            model_state=SequentialLinear(50, 2).state_dict(),
        )

    def run_persist(name, only_best, checkpoints):
        # Replay the hook sequence a training loop would drive (each hook gets
        # a fresh dummy runner), then return the expected checkpoint folder.
        persist = Persist(name, increment=False, only_best_states=only_best)
        persist.before_proc(trainer=_Trainer())
        for cp in checkpoints:
            persist.on_checkpoint(cp, trainer=_Trainer())
        return Path('.').resolve() / name / 'checkpoints'

    cp_1 = make_checkpoint('cp_1')
    cp_2 = make_checkpoint('cp_2')

    # save checkpoint: every checkpoint is kept under its own id
    saved = run_persist('test_model_1', False, (cp_1, cp_2))
    assert (saved / 'cp_1.pth.s').exists()
    assert (saved / 'cp_2.pth.s').exists()

    # reduced save checkpoint: a single ``cp`` file, no per-id files
    reduced = run_persist('test_model_2', True, (cp_1, cp_2))
    assert (reduced / 'cp.pth.s').exists()
    assert not (reduced / 'cp_1.pth.s').exists()
    assert not (reduced / 'cp_2.pth.s').exists()

    # no checkpoint will be saved
    skipped = run_persist('test_model_3', True, (cp_2,))
    assert not (skipped / 'cp.pth.s').exists()
    assert not (skipped / 'cp_2.pth.s').exists()
# Example #3
0
def test_persist_2(data):
    """Exercise ``Persist.path`` handling and the files written by ``before_proc``.

    Renamed from ``test_persist_1``: this module already defines a
    ``test_persist_1``. A second ``def`` with the same name silently rebinds it,
    so pytest never collected the earlier test.

    ``data`` is the shared fixture; it is requested but not used here.
    """
    import shutil

    class _Trainer(BaseRunner):

        def __init__(self):
            super().__init__()
            self.model = SequentialLinear(50, 2)

        def predict(self, x_, y_):  # noqa
            return x_, y_

    root = Path('.').resolve()
    expected_files = (
        'describe.pkl.z',
        'init_state.pth.s',
        'model.pth.m',
        'model_structure.pkl.z',
    )

    # The increment=True case below asserts the path ``test_model@1``. Suffixed
    # directories left by an earlier run would push the suffix higher (assuming
    # Persist skips existing ``@N`` names, which is TODO to confirm) and make
    # reruns fail. Remove them so the test is repeatable.
    for stale in root.glob('test_model@*'):
        shutil.rmtree(stale, ignore_errors=True)

    p = Persist()

    with pytest.raises(ValueError, match='can not access property `path` before training'):
        p.path

    p.before_proc(trainer=_Trainer())
    # Without an explicit name, the output directory is named after the cwd.
    assert p.path == str(root / Path(os.getcwd()).name)
    with pytest.raises(ValueError, match='can not reset property `path` after training'):
        p.path = 'aa'

    p = Persist('test_model')
    p.before_proc(trainer=_Trainer())
    assert p.path == str(root / 'test_model')
    for fname in expected_files:
        assert (root / 'test_model' / fname).exists(), fname

    # 'test_model' now exists, so increment=True must pick a suffixed name.
    p = Persist('test_model', increment=True)
    p.before_proc(trainer=_Trainer())
    assert p.path == str(root / 'test_model@1')
    for fname in expected_files:
        assert (root / 'test_model@1' / fname).exists(), fname