Esempio n. 1
0
def test_format_keys(
    input_key: Union[str, Iterable[str], Dict[str, str]],
    target_key: Union[str, Iterable[str], Dict[str, str]],
    keys: Dict[str, str],
) -> None:
    """Verify how BatchMetricCallback normalizes its key arguments.

    Args:
        input_key: input key specification handed to the callback.
        target_key: target key specification handed to the callback.
        keys: mapping the callback is expected to store in ``_keys``.
    """
    # The AccuracyMetric instance is built first (argument evaluation),
    # then wrapped by the callback under test.
    callback = dl.BatchMetricCallback(
        metric=AccuracyMetric(),
        input_key=input_key,
        target_key=target_key,
    )
    expected_keys = keys
    # Inspects a private attribute: this test is coupled to the
    # callback's internal key representation.
    assert callback._keys == expected_keys
Esempio n. 2
0
def test_classification_pipeline():
    """
    Smoke-test a classification training loop with batch-level accuracy.

    A ``SupervisedRunner`` is trained for three epochs on a small random
    dataset, with ``BatchMetricCallback`` wrapping ``AccuracyMetric``
    (an ``ICallbackBatchMetric``). The test passes when the
    ``"accuracy01"`` metric is present both in the runner's per-batch
    metrics and in its per-loader metrics.
    """
    # Random features and integer class labels in [0, NUM_CLASSES - 1].
    features = torch.rand(NUM_SAMPLES, NUM_FEATURES)
    targets = (torch.rand(NUM_SAMPLES) * NUM_CLASSES).long()
    data_loader = DataLoader(
        TensorDataset(features, targets), batch_size=64, num_workers=1
    )
    # The same loader serves as both the training and the validation split.
    loaders = OrderedDict([("train", data_loader), ("valid", data_loader)])

    model = DummyModel(num_features=NUM_FEATURES, num_classes=NUM_CLASSES)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    runner = dl.SupervisedRunner(
        input_key="features", output_key="logits", target_key="targets"
    )
    with TemporaryDirectory() as tmp_logdir:
        # Computes accuracy on the runner's "logits" output vs "targets".
        accuracy_callback = dl.BatchMetricCallback(
            metric=AccuracyMetric(num_classes=NUM_CLASSES),
            input_key="logits",
            target_key="targets",
        )
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            logdir=tmp_logdir,
            num_epochs=3,
            verbose=False,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            callbacks=OrderedDict([("classification", accuracy_callback)]),
        )
        # The metric must surface at both aggregation levels.
        for metrics in (runner.batch_metrics, runner.loader_metrics):
            assert "accuracy01" in metrics