Пример #1
0
def test_persist_1(data):
    """Exercise Persist's ``path`` property and the files written by ``before_proc``.

    Covers three cases:
      * ``Persist()`` with no name: ``path`` is unreadable before
        ``before_proc`` and cannot be reassigned afterwards.
      * ``Persist('test_model')``: the path is resolved under the cwd and the
        initial artifact files exist on disk.
      * ``Persist('test_model', increment=True)`` on an already-used name:
        the path gets an ``@1`` suffix.
    """

    class _Trainer(BaseRunner):
        def __init__(self):
            super().__init__()
            self.model = SequentialLinear(50, 2)

        def predict(self, x_, y_):  # noqa
            return x_, y_

    root = Path('.').resolve()

    # Reading `path` before a trainer has been attached must fail.
    persist = Persist()
    with pytest.raises(ValueError,
                       match='can not access property `path` before training'):
        persist.path

    # With no explicit name, the path is <cwd>/<basename of cwd> ...
    persist.before_proc(trainer=_Trainer())
    assert persist.path == str(root / Path(os.getcwd()).name)
    # ... and it can no longer be reassigned once before_proc has run.
    with pytest.raises(ValueError,
                       match='can not reset property `path` after training'):
        persist.path = 'aa'

    # An explicit name is placed under the cwd, and before_proc writes these artifacts.
    persist = Persist('test_model')
    persist.before_proc(trainer=_Trainer())
    model_dir = root / 'test_model'
    assert persist.path == str(model_dir)
    for artifact in ('describe.pkl.z',
                     'init_state.pth.s',
                     'model.pth.m',
                     'model_structure.pkl.z'):
        assert (model_dir / artifact).exists()

    # Re-using the name 'test_model' with increment=True yields an '@1' suffix.
    persist = Persist('test_model', increment=True)
    persist.before_proc(trainer=_Trainer())
    assert persist.path == str(root / 'test_model@1')
Пример #2
0
def test_persist_save_checkpoints(data):
    """Check which checkpoint files ``Persist.on_checkpoint`` writes.

    Three scenarios, each in its own model directory:
      * ``only_best_states=False``: one file per checkpoint id
        (``cp_1.pth.s`` and ``cp_2.pth.s``).
      * ``only_best_states=True`` fed cp_1 then cp_2: a single ``cp.pth.s``
        is written and no per-id files exist.
      * ``only_best_states=True`` fed only cp_2: nothing is written.
    """

    class _Trainer(BaseRunner):
        def __init__(self):
            super().__init__()
            self.model = SequentialLinear(50, 2)

        def predict(self, x_, y_):  # noqa
            return x_, y_

    def make_checkpoint(cp_id):
        # Both checkpoints share iterations=111; they differ only in id and
        # in their freshly constructed model weights.
        return Trainer.checkpoint_tuple(
            id=cp_id,
            iterations=111,
            model_state=SequentialLinear(50, 2).state_dict(),
        )

    # Left-to-right evaluation builds cp_1 before cp_2.
    cp_1, cp_2 = make_checkpoint('cp_1'), make_checkpoint('cp_2')
    root = Path('.').resolve()

    # only_best_states=False: every checkpoint is stored under its own id.
    persist = Persist('test_model_1', increment=False, only_best_states=False)
    persist.before_proc(trainer=_Trainer())
    for checkpoint in (cp_1, cp_2):
        persist.on_checkpoint(checkpoint)
    ckpt_dir = root / 'test_model_1' / 'checkpoints'
    assert (ckpt_dir / 'cp_1.pth.s').exists()
    assert (ckpt_dir / 'cp_2.pth.s').exists()

    # only_best_states=True: a single 'cp.pth.s' instead of per-id files.
    persist = Persist('test_model_2', increment=False, only_best_states=True)
    persist.before_proc(trainer=_Trainer())
    for checkpoint in (cp_1, cp_2):
        persist.on_checkpoint(checkpoint)
    ckpt_dir = root / 'test_model_2' / 'checkpoints'
    assert (ckpt_dir / 'cp.pth.s').exists()
    assert not (ckpt_dir / 'cp_1.pth.s').exists()
    assert not (ckpt_dir / 'cp_2.pth.s').exists()

    # only_best_states=True with cp_2 alone: no checkpoint file is expected.
    # NOTE(review): why cp_2 by itself is never persisted depends on
    # Persist's best-state selection logic, which is not visible here.
    persist = Persist('test_model_3', increment=False, only_best_states=True)
    persist.before_proc(trainer=_Trainer())
    persist.on_checkpoint(cp_2)
    ckpt_dir = root / 'test_model_3' / 'checkpoints'
    assert not (ckpt_dir / 'cp.pth.s').exists()
    assert not (ckpt_dir / 'cp_2.pth.s').exists()