Example #1
def test_evaluation_loader_custom_model() -> None:
    """
    Test that evaluate_loader works with a custom model.
    """
    dataset = DummyDataset()
    model = nn.Linear(in_features=dataset.features_dim,
                      out_features=dataset.out_dim)
    loader = DataLoader(dataset=dataset, batch_size=1)
    callbacks = [
        dl.AccuracyCallback(input_key="logits",
                            target_key="targets",
                            topk=(1, ))
    ]
    runner = SupervisedRunner()

    runner.evaluate_loader(loader=loader, callbacks=callbacks, model=model)
Example #2
def test_evaluation_loader_metrics() -> None:
    """
    Test that the metrics returned by evaluate_loader match the runner's internal loader metrics.
    """
    dataset = DummyDataset()
    model = nn.Linear(in_features=dataset.features_dim,
                      out_features=dataset.out_dim)
    loader = DataLoader(dataset=dataset, batch_size=1)
    callbacks = [
        dl.AccuracyCallback(input_key="logits",
                            target_key="targets",
                            topk=(1, ))
    ]
    runner = SupervisedRunner()
    runner.train(
        loaders={
            "train": loader,
            "valid": loader
        },
        model=model,
        num_epochs=1,
        criterion=nn.BCEWithLogitsLoss(),
        callbacks=callbacks,
    )
    runner_internal_metrics = runner.loader_metrics
    evaluate_loader_metrics = runner.evaluate_loader(loader=loader,
                                                     callbacks=callbacks)
    assert (runner_internal_metrics["accuracy01"]
            == evaluate_loader_metrics["accuracy01"])
Example #3
def test_evaluation_loader_empty_model() -> None:
    """
    Test that an AssertionError is raised when no model is given.
    """
    with pytest.raises(AssertionError) as record:
        dataset = DummyDataset()
        loader = DataLoader(dataset=dataset, batch_size=1)
        callbacks = [
            dl.AccuracyCallback(input_key="logits",
                                target_key="targets",
                                topk=(1, ))
        ]
        runner = SupervisedRunner()
        runner.evaluate_loader(loader=loader, callbacks=callbacks, model=None)

    # pytest.raises already fails the test if no AssertionError is raised.
    if not record:
        pytest.fail("Expected an assertion error because no model was given!")
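The snippets above omit their imports and the DummyDataset fixture they rely on. Below is a minimal, self-contained sketch of the same evaluate_loader pattern: a TensorDataset with random multiclass data stands in for that fixture (an assumption), while the Catalyst calls mirror the examples above and assume the runner's default key mapping for (features, targets) tuple batches.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from catalyst import dl
from catalyst.dl import SupervisedRunner

# Random multiclass data standing in for the DummyDataset test fixture.
num_samples, num_features, num_classes = 64, 8, 4
features = torch.randn(num_samples, num_features)
targets = torch.randint(0, num_classes, size=(num_samples,))

loader = DataLoader(TensorDataset(features, targets), batch_size=8)
model = nn.Linear(in_features=num_features, out_features=num_classes)
callbacks = [
    dl.AccuracyCallback(input_key="logits", target_key="targets", topk=(1,)),
]

runner = SupervisedRunner()
# evaluate_loader runs a single evaluation pass over the loader and returns
# its metrics, e.g. metrics["accuracy01"] for the AccuracyCallback above.
metrics = runner.evaluate_loader(loader=loader, callbacks=callbacks, model=model)
print(metrics["accuracy01"])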