Beispiel #1
0
    def _test(to_save, obj, name):
        # Mock save handler; ``remove`` is mocked explicitly so its call
        # count can be asserted independently.
        mock_saver = MagicMock()
        mock_saver.remove = MagicMock()

        ckpt = Checkpoint(
            to_save,
            save_handler=mock_saver,
            score_name="loss",
            score_function=lambda e: e.state.score,
        )

        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=1, iteration=1, score=-0.77)

        # First save: filename carries the formatted score.
        ckpt(engine)
        assert mock_saver.call_count == 1
        fname_first = "{}_loss=-0.7700.pth".format(name)
        mock_saver.assert_called_with(obj, fname_first)

        # Better score triggers a new save and removal of the older file.
        engine.state.epoch = 12
        engine.state.iteration = 1234
        engine.state.score = -0.76

        ckpt(engine)
        assert mock_saver.call_count == 2
        fname_second = "{}_loss=-0.7600.pth".format(name)
        mock_saver.assert_called_with(obj, fname_second)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with(fname_first)
        assert ckpt.last_checkpoint == fname_second
Beispiel #2
0
    def _test(to_save, obj, name, score_name=None):
        mock_saver = MagicMock()
        mock_saver.remove = MagicMock()

        ckpt = Checkpoint(
            to_save,
            save_handler=mock_saver,
            score_name=score_name,
            score_function=lambda e: e.state.epoch,
        )

        # Fragment expected inside filenames: empty when unnamed,
        # "<score_name>=" otherwise.
        score_name = "" if score_name is None else score_name + "="

        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=1, iteration=1)

        ckpt(engine)
        assert mock_saver.call_count == 1
        fname_1 = "{}_{}1.pth".format(name, score_name)
        mock_saver.assert_called_with(obj, fname_1)

        engine.state.epoch = 12
        engine.state.iteration = 1234

        ckpt(engine)
        assert mock_saver.call_count == 2
        fname_12 = "{}_{}12.pth".format(name, score_name)
        mock_saver.assert_called_with(obj, fname_12)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with(fname_1)
        assert ckpt.last_checkpoint == fname_12
Beispiel #3
0
def test_mlflow_bad_metric_name_handling(dirname):
    """Metric names mlflow rejects (e.g. containing ':' or '%') trigger a
    UserWarning, while names with spaces are logged under a tag-prefixed key.

    Fix: the original looped with ``for _, v in enumerate(...)`` and discarded
    the index — iterate the values directly.
    """
    import mlflow

    true_values = [123.0, 23.4, 333.4]
    with MLflowLogger(str(dirname / "mlruns")) as mlflow_logger:

        active_run = mlflow.active_run()

        handler = OutputHandler(tag="training", metric_names="all")
        engine = Engine(lambda e, b: None)
        engine.state = State(metrics={"metric:0 in %": 123.0, "metric 0": 1000.0})

        with pytest.warns(UserWarning, match=r"MLflowLogger output_handler encountered an invalid metric name"):

            engine.state.epoch = 1
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

            for v in true_values:
                engine.state.epoch += 1
                engine.state.metrics["metric 0"] = v
                handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=str(dirname / "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "training metric 0")

    # One entry per handler call: the initial value plus the three updates.
    for t, s in zip([1000.0] + true_values, stored_values):
        assert t == s.value
Beispiel #4
0
    def _test(filename_prefix, to_save, obj, name):
        mock_saver = MagicMock()
        mock_saver.remove = MagicMock()

        ckpt = Checkpoint(
            to_save,
            save_handler=mock_saver,
            filename_prefix=filename_prefix,
            global_step_transform=lambda e, _: e.state.epoch,
        )

        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=1, iteration=1)

        ckpt(engine)
        assert mock_saver.call_count == 1

        # A non-empty prefix is joined to the basename with an underscore.
        if filename_prefix:
            filename_prefix += "_"

        fname_1 = "{}{}_1.pth".format(filename_prefix, name)
        mock_saver.assert_called_with(obj, fname_1)

        engine.state.epoch = 12
        engine.state.iteration = 1234
        ckpt(engine)
        assert mock_saver.call_count == 2
        fname_12 = "{}{}_12.pth".format(filename_prefix, name)
        mock_saver.assert_called_with(obj, fname_12)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with(fname_1)
        assert ckpt.last_checkpoint == fname_12
Beispiel #5
0
    def _test(to_save, obj, name):
        # Spec'd mock: Checkpoint treats BaseSaveHandler instances specially
        # and also passes a metadata dict.
        mock_saver = MagicMock(spec=BaseSaveHandler)

        ckpt = Checkpoint(
            to_save,
            save_handler=mock_saver,
            score_name="loss",
            score_function=lambda e: e.state.score,
        )

        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=1, iteration=1, score=-0.77)

        ckpt(engine)
        assert mock_saver.call_count == 1

        metadata = {"basename": name, "score_name": "loss", "priority": -0.77}
        fname_1 = "{}_loss=-0.7700.pt".format(name)
        mock_saver.assert_called_with(obj, fname_1, metadata)

        engine.state.epoch = 12
        engine.state.iteration = 1234
        engine.state.score = -0.76

        ckpt(engine)
        assert mock_saver.call_count == 2
        metadata["priority"] = -0.76
        fname_2 = "{}_loss=-0.7600.pt".format(name)
        mock_saver.assert_called_with(obj, fname_2, metadata)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with(fname_1)
        assert ckpt.last_checkpoint == fname_2
Beispiel #6
0
def _test_neptune_saver_integration(device):
    """NeptuneSaver uploads every checkpoint and deletes superseded artifacts;
    the logger (and the assertions) exist only on rank 0."""
    model = torch.nn.Module().to(device)
    to_save_serializable = {"model": model}

    mock_logger = None
    if idist.get_rank() == 0:
        mock_logger = MagicMock(spec=NeptuneLogger)
        mock_logger.log_artifact = MagicMock()
        mock_logger.delete_artifacts = MagicMock()

    saver = NeptuneSaver(mock_logger)

    checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=saver, n_saved=1)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    checkpoint(engine)
    engine.state.iteration = 1
    checkpoint(engine)

    if idist.get_rank() == 0:
        # Two saves; with n_saved=1 the first artifact is evicted once.
        assert mock_logger.log_artifact.call_count == 2
        assert mock_logger.delete_artifacts.call_count == 1
Beispiel #7
0
    def _test(to_save, obj, name):
        mock_saver = MagicMock(spec=BaseSaveHandler)

        ckpt = Checkpoint(to_save, save_handler=mock_saver)
        # No save has happened yet.
        assert ckpt.last_checkpoint is None

        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=0, iteration=0)

        ckpt(engine)
        assert mock_saver.call_count == 1

        # Without a score function, priority defaults to the iteration.
        metadata = {"basename": name, "score_name": None, "priority": 0}
        mock_saver.assert_called_with(obj, "{}_0.pt".format(name), metadata)

        engine.state.epoch = 12
        engine.state.iteration = 1234
        ckpt(engine)
        assert mock_saver.call_count == 2
        metadata["priority"] = 1234
        mock_saver.assert_called_with(obj, "{}_1234.pt".format(name), metadata)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with("{}_0.pt".format(name))
        assert ckpt.last_checkpoint == "{}_1234.pt".format(name)
Beispiel #8
0
def test_load_checkpoint_with_different_num_classes(dirname):
    """Loading a checkpoint into a mismatched module raises unless
    ``strict=False`` is forwarded to ``load_state_dict``."""
    model = DummyPretrainedModel()

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    handler(trainer, {"model": model})

    loaded_checkpoint = torch.load(handler.last_checkpoint)

    to_load_single_object = {"pretrained_features": model.features}

    # Strict loading fails: saved keys do not match the sub-module.
    with pytest.raises(RuntimeError):
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint)

    # Non-strict loading succeeds; extra kwargs are tolerated (warned about).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        Checkpoint.load_objects(
            to_load_single_object, loaded_checkpoint, strict=False, blah="blah"
        )

    loaded_weights = to_load_single_object["pretrained_features"].state_dict()["weight"]

    assert torch.all(model.state_dict()["features.weight"].eq(loaded_weights))
Beispiel #9
0
def test_base_output_handler_setup_output_metrics():
    """_setup_output_metrics merges selected metrics with transformed output."""
    engine = Engine(lambda engine, batch: None)
    true_metrics = {"a": 0, "b": 1}
    engine.state = State(metrics=true_metrics)
    engine.state.output = 12345

    # Only metric_names
    oh = DummyOutputHandler("tag", metric_names=['a', 'b'], output_transform=None)
    assert oh._setup_output_metrics(engine=engine) == true_metrics

    # Unknown metric name warns and is skipped
    oh = DummyOutputHandler("tag", metric_names=['a', 'c'], output_transform=None)
    with pytest.warns(UserWarning):
        metrics = oh._setup_output_metrics(engine=engine)
    assert metrics == {"a": 0}

    # Only output, default key "output"
    oh = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: x)
    assert oh._setup_output_metrics(engine=engine) == {"output": engine.state.output}

    # Only output, renamed to "loss"
    oh = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: {"loss": x})
    assert oh._setup_output_metrics(engine=engine) == {"loss": engine.state.output}

    # Metrics and output together
    oh = DummyOutputHandler("tag", metric_names=['a', 'b'], output_transform=lambda x: {"loss": x})
    assert oh._setup_output_metrics(engine=engine) == {"a": 0, "b": 1, "loss": engine.state.output}
Beispiel #10
0
    def _test(to_save, obj, name, score_name=None):
        mock_saver = MagicMock(spec=BaseSaveHandler)

        ckpt = Checkpoint(
            to_save, save_handler=mock_saver, score_name=score_name, score_function=lambda e: e.state.epoch
        )

        # Filename fragment: empty when unnamed, "<score_name>=" otherwise.
        score_name = "" if score_name is None else score_name + "="

        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=1, iteration=1)

        ckpt(engine)
        assert mock_saver.call_count == 1

        # Metadata carries the bare score name (without the "="), or None.
        metadata = {"basename": name, "score_name": score_name[:-1] if score_name else None, "priority": 1}
        fname_1 = "{}_{}1.pt".format(name, score_name)
        mock_saver.assert_called_with(obj, fname_1, metadata)

        engine.state.epoch = 12
        engine.state.iteration = 1234

        ckpt(engine)
        assert mock_saver.call_count == 2
        metadata["priority"] = 12
        fname_12 = "{}_{}12.pt".format(name, score_name)
        mock_saver.assert_called_with(obj, fname_12, metadata)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with(fname_1)
        assert ckpt.last_checkpoint == fname_12
Beispiel #11
0
    def _test(filename_prefix, to_save, obj, name):
        mock_saver = MagicMock(spec=BaseSaveHandler)

        ckpt = Checkpoint(
            to_save,
            save_handler=mock_saver,
            filename_prefix=filename_prefix,
            global_step_transform=lambda e, _: e.state.epoch,
        )

        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=2, iteration=1)

        ckpt(engine)
        assert mock_saver.call_count == 1

        # A non-empty prefix joins the basename with an underscore.
        if filename_prefix:
            filename_prefix += "_"

        basename = "{}{}".format(filename_prefix, name)
        metadata = {"basename": basename, "score_name": None, "priority": 2}
        mock_saver.assert_called_with(obj, "{}_2.pt".format(basename), metadata)

        engine.state.epoch = 12
        engine.state.iteration = 1234
        ckpt(engine)
        assert mock_saver.call_count == 2
        metadata["priority"] = 12
        mock_saver.assert_called_with(obj, "{}_12.pt".format(basename), metadata)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with("{}_2.pt".format(basename))
        assert ckpt.last_checkpoint == "{}_12.pt".format(basename)
Beispiel #12
0
def test_checkpoint_last_checkpoint_on_score():
    """With ``n_saved=None`` nothing is evicted and ``last_checkpoint``
    reflects the most recent (here: best-scoring) save."""
    mock_saver = MagicMock()
    mock_saver.remove = MagicMock()
    to_save = {"model": DummyModel()}

    ckpt = Checkpoint(
        to_save,
        save_handler=mock_saver,
        n_saved=None,
        score_name="val_acc",
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    engine = Engine(lambda e, b: None)

    # Monotonically improving accuracy across ten calls.
    for step in range(10):
        engine.state = State(epoch=1, iteration=step, metrics={"val_acc": step * 0.1})
        ckpt(engine)

    assert mock_saver.call_count == 10
    assert ckpt.last_checkpoint == "{}_val_acc=0.9000.pth".format("model")
Beispiel #13
0
    def _test(ext, require_empty, archived):
        # Seed the directory with an unrelated file to exercise require_empty.
        previous_fname = os.path.join(dirname, "{}_{}_{}{}".format(_PREFIX, "obj", 1, ext))
        with open(previous_fname, "w") as f:
            f.write("test")

        h = ModelCheckpoint(
            dirname,
            _PREFIX,
            create_dir=True,
            require_empty=require_empty,
            archived=archived,
        )
        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=0, iteration=1)

        model = DummyModel()
        h(engine, {"model": model})

        fname = h.last_checkpoint
        ext = ".pth.tar" if archived else ".pth"
        assert isinstance(fname, str)
        expected_path = os.path.join(dirname, "{}_{}_{}{}".format(_PREFIX, "model", 1, ext))
        assert expected_path == fname
        assert os.path.exists(fname)
        # The pre-existing file must survive untouched.
        assert os.path.exists(previous_fname)
        assert torch.load(fname) == model.state_dict()
        os.remove(fname)
Beispiel #14
0
def test_best_k(dirname):
    """Only the ``n_saved`` best-scoring checkpoints remain on disk."""
    scores = iter([1.2, -2.0, 3.1, -4.0])

    def score_function(_):
        return next(scores)

    h = ModelCheckpoint(
        dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function
    )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    to_save = {"model": DummyModel()}
    for _ in range(4):
        h(engine, to_save)

    # The two positive scores win; the negative ones were evicted.
    expected = ["{}_{}_{:.4f}.pth".format(_PREFIX, "model", s) for s in [1.2, 3.1]]

    assert sorted(os.listdir(dirname)) == expected
def test_trains_disk_saver_integration_no_logger():
    """TrainsSaver without an explicit dirname warns about creating a
    temporary checkpoints directory and still rotates saves (n_saved=1)."""
    model = torch.nn.Module()
    to_save_serializable = {"model": model}

    with pytest.warns(
            UserWarning,
            match="TrainsSaver created a temporary checkpoints directory"):
        trains.Task.current_task = Mock(return_value=object())
        trains.binding.frameworks.WeightsFileHandler.create_output_model = MagicMock()
        trains_saver = TrainsSaver()
        checkpoint = Checkpoint(
            to_save=to_save_serializable, save_handler=trains_saver, n_saved=1
        )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    checkpoint(engine)
    engine.state.iteration = 1
    checkpoint(engine)

    if trains_saver._atomic:
        # Atomic mode routes saves through the (mocked) trains weights hook.
        assert trains.binding.frameworks.WeightsFileHandler.create_output_model.call_count == 2
    else:
        # Plain disk mode: only the newest file survives.
        saved_files = list(os.listdir(trains_saver.dirname))
        assert len(saved_files) == 1
        assert saved_files[0] == "model_1.pt"
Beispiel #16
0
def test_best_k_with_suffix(dirname):
    """Best-k retention with a named score suffix in the filename."""
    scores = [0.3456789, 0.1234, 0.4567, 0.134567]
    scores_iter = iter(scores)

    def score_function(engine):
        return next(scores_iter)

    h = ModelCheckpoint(
        dirname,
        _PREFIX,
        create_dir=False,
        n_saved=2,
        score_function=score_function,
        score_name="val_loss",
    )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    to_save = {"model": DummyModel()}
    for _ in range(4):
        engine.state.epoch += 1
        h(engine, to_save)

    # Epochs 1 and 3 produced the two highest scores.
    expected = [
        "{}_{}_val_loss={:.4}.pth".format(_PREFIX, "model", scores[e - 1])
        for e in [1, 3]
    ]

    assert sorted(os.listdir(dirname)) == expected
Beispiel #17
0
    def _test(
        to_save,
        filename_prefix="",
        score_function=None,
        score_name=None,
        global_step_transform=None,
        filename_pattern=None,
    ):
        # Configure a checkpointer with the given naming options, trigger a
        # single save, and return the filename it produced.
        mock_saver = MagicMock(spec=BaseSaveHandler)

        ckpt = Checkpoint(
            to_save,
            save_handler=mock_saver,
            filename_prefix=filename_prefix,
            score_function=score_function,
            score_name=score_name,
            global_step_transform=global_step_transform,
            filename_pattern=filename_pattern,
        )

        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=12, iteration=203, score=0.9999)

        ckpt(engine)
        return ckpt.last_checkpoint
Beispiel #18
0
def test_base_output_handler_with_another_engine():
    """Passing the deprecated ``another_engine`` kwarg emits a DeprecationWarning."""
    engine = Engine(lambda engine, batch: None)
    true_metrics = {"a": 0, "b": 1}
    engine.state = State(metrics=true_metrics)
    engine.state.output = 12345

    with pytest.warns(DeprecationWarning, match="Use of another_engine is deprecated"):
        handler = DummyOutputHandler(
            "tag", metric_names=['a', 'b'], output_transform=None, another_engine=engine
        )
Beispiel #19
0
def test__setup_engine():
    """Resuming from iteration 10 seeds the engine's internal iteration list."""
    engine = Engine(lambda e, b: 1)
    engine.state = State(iteration=10, epoch=1, max_epochs=100, epoch_length=100)

    engine.state.dataloader = list(range(100))
    engine._setup_engine()
    assert len(engine._init_iter) == 1 and engine._init_iter[0] == 10
Beispiel #20
0
    def _test(to_save, obj, name):
        mock_saver = MagicMock(spec=BaseSaveHandler)

        trainer = Engine(lambda e, b: None)
        evaluator = Engine(lambda e, b: None)
        trainer.state = State(epoch=11, iteration=1)

        # Global step in the filename comes from the trainer; the score
        # comes from the evaluator's metrics.
        ckpt = Checkpoint(
            to_save,
            save_handler=mock_saver,
            global_step_transform=lambda _1, _2: trainer.state.epoch,
            score_name="val_acc",
            score_function=lambda e: e.state.metrics["val_acc"],
        )

        evaluator.state = State(epoch=1, iteration=1000, metrics={"val_acc": 0.77})

        ckpt(evaluator)
        assert mock_saver.call_count == 1

        metadata = {
            "basename": name,
            "score_name": "val_acc",
            "priority": 0.77
        }
        fname_1 = "{}_11_val_acc=0.7700.pt".format(name)
        mock_saver.assert_called_with(obj, fname_1, metadata)

        trainer.state.epoch = 12
        evaluator.state.metrics["val_acc"] = 0.78

        ckpt(evaluator)
        assert mock_saver.call_count == 2
        metadata["priority"] = 0.78
        fname_2 = "{}_12_val_acc=0.7800.pt".format(name)
        mock_saver.assert_called_with(obj, fname_2, metadata)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with(fname_1)
        assert ckpt.last_checkpoint == fname_2
Beispiel #21
0
def test_base_output_handler_setup_output_state_attrs():
    """State attributes are flattened into tag-prefixed entries; a tensor with
    more than one element expands into one entry per index."""
    engine = Engine(lambda engine, batch: None)
    engine.state = State(metrics={"a": 0, "b": 1})
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(5.499)
    engine.state.gamma = torch.tensor([2106.0, 6.0])
    engine.state.output = 12345

    attrs = ["alpha", "beta", "gamma"]
    expected_attrs = {
        "tag/alpha": 3.899,
        "tag/beta": torch.tensor(5.499),
        "tag/gamma/0": 2106.0,
        "tag/gamma/1": 6.0,
    }

    # Only state attributes
    oh = DummyOutputHandler(
        tag="tag", metric_names=None, output_transform=None, state_attributes=attrs
    )
    result = oh._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
    assert result == expected_attrs

    # Metrics and attributes
    oh = DummyOutputHandler(
        tag="tag", metric_names=["a", "b"], output_transform=None, state_attributes=attrs
    )
    result = oh._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
    assert result == dict(expected_attrs, **{"tag/a": 0, "tag/b": 1})

    # Metrics, attributes and transformed output
    oh = DummyOutputHandler(
        tag="tag",
        metric_names="all",
        output_transform=lambda x: {"loss": x},
        state_attributes=attrs,
    )
    result = oh._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
    assert result == dict(
        expected_attrs, **{"tag/a": 0, "tag/b": 1, "tag/loss": engine.state.output}
    )
Beispiel #22
0
def test_checkpoint_load_objects_from_saved_file(dirname):
    """Objects saved via ModelCheckpoint round-trip through
    ``Checkpoint.load_objects`` — whether one or several objects were saved,
    and whether all or a subset are loaded back.

    Fix: the three cases were copy-pasted save/assert/load sequences; the
    shared steps are extracted into ``_save_and_load``.
    """

    def _get_single_obj_to_save():
        # One serializable object.
        return {"model": DummyModel()}

    def _get_multiple_objs_to_save():
        # Model plus optimizer plus scheduler, as typically checkpointed.
        model = DummyModel()
        optim = torch.optim.SGD(model.parameters(), lr=0.001)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)
        return {"model": model, "optimizer": optim, "lr_scheduler": lr_scheduler}

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    def _save_and_load(to_save, to_load, cleanup=True):
        # Save with a fresh handler, sanity-check the written file, then
        # restore ``to_load`` from it; optionally delete the file afterwards.
        handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
        handler(trainer, to_save)
        fname = handler.last_checkpoint
        assert isinstance(fname, str)
        assert os.path.join(dirname, _PREFIX) in fname
        assert os.path.exists(fname)
        loaded_objects = torch.load(fname)
        Checkpoint.load_objects(to_load, loaded_objects)
        if cleanup:
            os.remove(fname)

    # case: multiple objects saved and loaded
    to_save = _get_multiple_objs_to_save()
    _save_and_load(to_save, to_save)

    # case: saved multiple objects, loaded single object
    to_save = _get_multiple_objs_to_save()
    _save_and_load(to_save, {'model': to_save['model']})

    # case: single object (file intentionally left on disk, as before)
    to_save = _get_single_obj_to_save()
    _save_and_load(to_save, to_save, cleanup=False)
Beispiel #23
0
def test_checkpoint_score_function_wrong_output():
    """A score_function that returns a non-number raises ValueError."""
    to_save = {'model': DummyModel()}

    ckpt = Checkpoint(
        to_save, lambda x: x, score_function=lambda e: {"1": 1}, score_name="acc"
    )
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    with pytest.raises(ValueError, match=r"Output of score_function should be a number"):
        ckpt(engine)
Beispiel #24
0
def test__setup_engine():
    """_setup_engine caches the dataloader length."""
    engine = Engine(lambda e, b: 1)
    engine.state = State(iteration=10, epoch=1, max_epochs=100, epoch_length=100, seed=12)

    data = list(range(100))
    engine.state.dataloader = data
    engine._setup_engine()
    assert engine._dataloader_len == len(data)
Beispiel #25
0
def test_base_output_handler_setup_output_metrics():
    """_setup_output_metrics_state_attrs prefixes every key with the tag."""

    engine = Engine(lambda engine, batch: None)
    engine.state = State(metrics={"a": 0, "b": 1})
    engine.state.output = 12345

    def collect(oh):
        # Shared invocation shape for all cases below.
        return oh._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)

    # Only metric_names
    oh = DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=None)
    assert collect(oh) == {"tag/a": 0, "tag/b": 1}

    # Unknown metric name warns and is skipped
    oh = DummyOutputHandler("tag", metric_names=["a", "c"], output_transform=None)
    with pytest.warns(UserWarning):
        metrics = collect(oh)
    assert metrics == {"tag/a": 0}

    # Only output under the default "output" key
    oh = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: x)
    assert collect(oh) == {"tag/output": engine.state.output}

    # Only output under a custom "loss" key
    oh = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: {"loss": x})
    assert collect(oh) == {"tag/loss": engine.state.output}

    # Metrics and output together
    oh = DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    assert collect(oh) == {"tag/a": 0, "tag/b": 1, "tag/loss": engine.state.output}

    # metric_names="all" selects every metric
    oh = DummyOutputHandler("tag", metric_names="all", output_transform=None)
    assert collect(oh) == {"tag/a": 0, "tag/b": 1}
Beispiel #26
0
def test_checkpoint_save_handler_callable():
    """A plain callable (not a BaseSaveHandler) is accepted as save_handler."""
    def save_handler(c, f):
        assert f == "model_12.pt"

    ckpt = Checkpoint({"model": DummyModel()}, save_handler=save_handler)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=1, iteration=12)
    ckpt(engine)
Beispiel #27
0
    def _test(to_save, obj, name):
        mock_saver = MagicMock()
        mock_saver.remove = MagicMock()

        trainer = Engine(lambda e, b: None)
        evaluator = Engine(lambda e, b: None)
        trainer.state = State(epoch=11, iteration=1)

        # Filename step tracks the trainer epoch; score comes from the
        # evaluator's metrics.
        ckpt = Checkpoint(
            to_save,
            save_handler=mock_saver,
            global_step_transform=lambda _1, _2: trainer.state.epoch,
            score_name="val_acc",
            score_function=lambda e: e.state.metrics["val_acc"],
        )

        evaluator.state = State(epoch=1, iteration=1000, metrics={"val_acc": 0.77})

        ckpt(evaluator)
        assert mock_saver.call_count == 1
        fname_1 = "{}_11_val_acc=0.77.pth".format(name)
        mock_saver.assert_called_with(obj, fname_1)

        trainer.state.epoch = 12
        evaluator.state.metrics["val_acc"] = 0.78

        ckpt(evaluator)
        assert mock_saver.call_count == 2
        fname_2 = "{}_12_val_acc=0.78.pth".format(name)
        mock_saver.assert_called_with(obj, fname_2)
        assert mock_saver.remove.call_count == 1
        mock_saver.remove.assert_called_with(fname_1)
        assert ckpt.last_checkpoint == fname_2
Beispiel #28
0
def test_checkpoint_last_checkpoint():
    """With ``n_saved=None`` every call saves and last_checkpoint is newest."""
    mock_saver = MagicMock(spec=BaseSaveHandler)
    to_save = {"model": DummyModel()}

    ckpt = Checkpoint(to_save, save_handler=mock_saver, n_saved=None)

    engine = Engine(lambda e, b: None)
    for step in range(10):
        engine.state = State(epoch=1, iteration=step)
        ckpt(engine)

    assert mock_saver.call_count == 10
    assert ckpt.last_checkpoint == "{}_9.pt".format("model")
Beispiel #29
0
def test_checkpoint_last_checkpoint():
    """last_checkpoint is None before the first save, the filename after."""
    save_handler = MagicMock()
    # NOTE(review): assigning ``__call__`` on the instance does not affect
    # calling the mock (special methods resolve on the type) and MagicMock
    # is callable anyway; kept for byte-identical behavior.
    save_handler.__call__ = MagicMock()
    to_save = {'model': DummyModel()}

    ckpt = Checkpoint(to_save, save_handler=save_handler)
    assert ckpt.last_checkpoint is None

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    ckpt(engine)
    assert ckpt.last_checkpoint == "model_0.pth"
Beispiel #30
0
def test_valid_state_dict_save(dirname):
    """ModelCheckpoint rejects objects lacking a ``state_dict`` method and
    accepts those that have one."""
    model = DummyModel()
    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    # An int has no state_dict -> TypeError.
    with pytest.raises(TypeError, match=r"should have `state_dict` method"):
        h(engine, {"name": 42})
    # A module does -> must not raise.
    try:
        h(engine, {"name": model})
    except ValueError:
        pytest.fail("Unexpected ValueError")