def test_save(self, valid_tree):
        """Check that ``Model.save`` round-trips and enforces overwrite rules.

        The model is first saved to a fresh directory and reloaded for
        comparison. The test then expects ``save`` to raise ``OSError`` when
        re-targeting an already-written directory after the model's id, or
        its training end epoch, has changed -- with or without
        ``force=True``. When only the initialisation's training end epoch
        changed, an unforced save is still refused but ``force=True``
        succeeds, and the forced result must reload equal to the model.
        """
        first_target = valid_tree / 'nowhere'
        second_target = valid_tree / 'nowhere_now'
        third_target = valid_tree / 'nowhere_again'
        # Keyword sets for an unforced and a forced save; unpacking the empty
        # dict keeps that call identical to a bare ``save(path)``.
        unforced, forced = {}, {'force': True}

        model = Model.load(valid_tree)
        model.save(first_target)
        assert _model_equal(model, Model.load(first_target))

        # Changed id: refused in both modes.
        model.id = 'new_id'
        for save_kwargs in (unforced, forced):
            with pytest.raises(OSError):
                model.save(first_target, **save_kwargs)

        # Changed training end epoch: refused in both modes.
        model.training._end_epoch = 0
        model.save(second_target)
        model.training._end_epoch = 10
        for save_kwargs in (unforced, forced):
            with pytest.raises(OSError):
                model.save(second_target, **save_kwargs)

        # Changed initialisation end epoch: refused only without force.
        model.initialisation.training._end_epoch = 0
        model.save(third_target)
        model.initialisation.training._end_epoch = 10
        with pytest.raises(OSError):
            model.save(third_target)
        model.save(third_target, force=True)
        assert _model_equal(model, Model.load(third_target))
    def test_checkpoint(self, valid_tree):
        """Register checkpoints one at a time and then in a single call.

        Each scenario starts from a freshly built model with an empty
        ``checkpoint_collection``. Afterwards, the collection must hold every
        added checkpoint, looked up by its name.
        """
        checkpoints_dir = valid_tree / 'data' / 'checkpoints'

        def new_model():
            # Same constructor arguments for both scenarios.
            return Model('my-producer', 'py_pa', '1.0.0', 'my-model',
                         'my-model-id', valid_tree / 'my_config_file.yaml',
                         {'my_param': 'my_value'})

        def new_checkpoints():
            # Fresh ``(first, sixth)`` checkpoint pair.
            return (Checkpoint('first', checkpoints_dir / '1.weight', 1),
                    Checkpoint('sixth', checkpoints_dir / '6.weight', 6))

        # Scenario 1: one ``add_checkpoint`` call per checkpoint.
        model = new_model()
        assert len(model.checkpoint_collection) == 0
        first, sixth = new_checkpoints()
        model.add_checkpoint(first)
        assert len(model.checkpoint_collection) == 1
        assert model.checkpoint_collection['first'] == first
        model.add_checkpoint(sixth)
        assert len(model.checkpoint_collection) == 2
        assert model.checkpoint_collection['sixth'] == sixth

        # Scenario 2: both checkpoints passed to a single variadic call.
        model = new_model()
        assert len(model.checkpoint_collection) == 0
        first, sixth = new_checkpoints()
        model.add_checkpoint(first, sixth)
        assert len(model.checkpoint_collection) == 2
        assert model.checkpoint_collection['first'] == first
        assert model.checkpoint_collection['sixth'] == sixth
    def test_save_extra_on_invalid(self, valid_tree_extra, invalid_tree):
        """Overwriting an invalid tree is refused unless ``force=True``.

        After the forced save, reloading the target must give a model equal
        to the saved one.
        """
        source = Model.load(valid_tree_extra)

        with pytest.raises(OSError):
            source.save(invalid_tree)

        source.save(invalid_tree, force=True)
        reloaded = Model.load(invalid_tree)
        assert _model_equal(source, reloaded)
    def test_initialisation(self, valid_tree):
        """``register_initialisation`` attaches a ``Model`` as the start point.

        Before registration, ``initialisation`` is ``None``. Afterwards it is
        a ``Model`` whose ``checkpoint`` equals the ``checkpoint_reference``
        that was passed in.
        """
        config_file = valid_tree / 'my_config_file.yaml'
        model = Model('my-producer', 'py_pa', '1.0.0', 'my-model',
                      'my-model-id', config_file, {'my_param': 'my_value'})
        assert model.initialisation is None

        model.register_initialisation(valid_tree, checkpoint_reference=6)
        assert isinstance(model.initialisation, Model)
        assert model.initialisation.checkpoint == 6
    def test_save_valid_on_invalid_strict(self, valid_tree,
                                          invalid_tree_strict):
        """A strictly-invalid target also needs ``force=True`` to be overwritten.

        The forced save must then reload equal to the original model.
        """
        source = Model.load(valid_tree)

        with pytest.raises(OSError):
            source.save(invalid_tree_strict)

        source.save(invalid_tree_strict, force=True)
        reloaded = Model.load(invalid_tree_strict)
        assert _model_equal(source, reloaded)
    def test_creation(self, valid_tree):
        """A freshly constructed ``Model`` exposes its metadata and empty state.

        Checks:

        - ``name`` and ``id`` are passed through unchanged.
        - ``producer`` strictly equals a ``Producer`` built from the same
          arguments.
        - The defaults hold: an empty ``CheckpointCollection``, a pending
          ``Training``, and ``None`` for ``initialisation``, ``path`` and
          ``checkpoint``.
        """
        config_file = valid_tree / 'my_config_file.yaml'
        model = Model('my-producer', 'py_pa', '1.0.0', 'my-model',
                      'my-model-id', config_file, {'my_param': 'my_value'})

        # Identity and provenance.
        assert model.name == 'my-model'
        assert model.id == 'my-model-id'
        expected_producer = Producer('my-producer', 'py_pa', '1.0.0',
                                     config_file)
        assert expected_producer.strict_equals(model.producer)

        # Default state of a brand-new model.
        assert isinstance(model.checkpoint_collection, CheckpointCollection)
        assert len(model.checkpoint_collection) == 0
        assert isinstance(model.training, Training)
        assert model.training.is_pending
        assert model.initialisation is None
        assert model.path is None
        assert model.checkpoint is None
    def test_save_extra_on_valid(self, valid_tree, valid_tree_extra):
        """Saving over an existing valid tree works without ``force``.

        Reloading the target must give a model equal to the saved one.
        """
        source = Model.load(valid_tree_extra)
        source.save(valid_tree)
        reloaded = Model.load(valid_tree)
        assert _model_equal(source, reloaded)
    def test_save_valid_on_empty(self, valid_tree, valid_tree_empty):
        """Saving into an empty tree works without ``force``.

        Reloading the target must give a model equal to the saved one.
        """
        source = Model.load(valid_tree)
        source.save(valid_tree_empty)
        reloaded = Model.load(valid_tree_empty)
        assert _model_equal(source, reloaded)
    def test_load_extra(self, valid_tree_extra):
        """``Model.load`` reconstructs everything in ``valid_tree_extra``.

        The expected model is built by hand:

        - two checkpoints and a finished training record;
        - an initialisation model with its own two checkpoints, its own
          finished training record, and a nested initialisation checkpoint.

        The loaded model must compare equal to it under ``_model_equal``.
        """
        data_dir = valid_tree_extra / 'data'
        checkpoints_dir = data_dir / 'checkpoints'
        init_dir = data_dir / 'initialisation'
        # NOTE(review): these look like MD5 digests -- of empty input and of
        # b'0' respectively; presumably they match the fixture weight files.
        empty_hash = 'd41d8cd98f00b204e9800998ecf8427e'
        zero_hash = 'cfcd208495d565ef66e7dff9f98764da'
        # Identical training values for the model and its initialisation;
        # each one still gets its own ``Training`` instance below.
        finished_training = {
            'status': 'finished',
            'start_epoch': 0,
            'start_time': 0,
            'latest_epoch': 10,
            'latest_time': 10,
            'end_epoch': 10,
            'end_time': 10,
        }

        expected = Model('a_producer_name', 'py_pa', '1.0.0', 'my_model',
                         'some_id', valid_tree_extra / 'my_config_file.yaml',
                         {})
        expected._initialisation = Model(
            'a_producer_name', 'py_pa', '1.0.0', 'my_init_model',
            'some_init_id', init_dir / 'my_config_file.yaml', {})
        expected.initialisation._initialisation = Checkpoint(
            'my_init_file',
            init_dir / 'data' / 'initialisation' / 'init.weight',
            hash=empty_hash)

        expected.add_checkpoint(
            Checkpoint(1, checkpoints_dir / '1.weight', 3, hash=empty_hash),
            Checkpoint(6, checkpoints_dir / '6.weight', 10, hash=zero_hash))
        # The initialisation's checkpoints are given without a file path.
        expected.initialisation.add_checkpoint(
            Checkpoint(6, epoch=10, hash=zero_hash),
            Checkpoint(7, epoch=9, hash=zero_hash))

        expected.training = Training(**finished_training)
        expected.initialisation.training = Training(**finished_training)

        assert _model_equal(Model.load(valid_tree_extra), expected)
    def test_training(self, valid_tree):
        """Walk ``Model.training`` through start, epochs and both end states.

        The same epoch sequence is replayed on a fresh model twice: once for
        a successful run and once for a failed run.

        - ``register_training_start(5)`` sets both ``start_epoch`` and
          ``latest_epoch`` to 5.
        - A bare ``register_epoch()`` advances ``latest_epoch`` by one.
        - ``register_epoch(15)`` jumps it to 15.
        - ``register_training_end(success=...)`` then marks the run finished
          or failed; ``end_epoch`` must equal the last registered epoch (16).
        """
        config_file = valid_tree / 'my_config_file.yaml'
        # (positional args passed to register_epoch, expected latest_epoch)
        epoch_steps = (((), 6), ((15,), 15), ((), 16))
        # (success flag for register_training_end, Training status property)
        outcomes = ((True, 'is_finished'), (False, 'is_failed'))

        for success, status_flag in outcomes:
            model = Model('my-producer', 'py_pa', '1.0.0', 'my-model',
                          'my-model-id', config_file,
                          {'my_param': 'my_value'})
            assert model.training.is_pending

            model.register_training_start(5)
            assert model.training.is_running
            assert model.training.start_epoch == 5
            assert model.training.latest_epoch == 5

            for epoch_args, expected_latest in epoch_steps:
                model.register_epoch(*epoch_args)
                assert model.training.latest_epoch == expected_latest

            model.register_training_end(success=success)
            assert getattr(model.training, status_flag)
            assert model.training.end_epoch == 16