Beispiel #1
0
    def _test(metric_device):
        """Check FID attached to an engine against pytorch-fid on the full data.

        ``y_pred`` / ``y_true`` are drawn for all ranks at once; each rank feeds
        only its own contiguous shard (``n_iters * batch`` rows) to the engine,
        ``batch`` rows per iteration. The metric's result is then compared with
        the reference Frechet distance computed over the complete tensors.
        NOTE(review): this assumes every rank draws identical random tensors,
        presumably via a seed set by the enclosing test — confirm there.

        Args:
            metric_device: device passed to the ``FID`` metric.
        """
        n_iters = 60
        batch = 16
        shard_len = n_iters * batch
        n_features = 10

        total_rows = shard_len * idist.get_world_size()
        y_pred = torch.rand(total_rows, n_features)
        y_true = torch.rand(total_rows, n_features)

        def update(_, i):
            # Row window of iteration ``i`` inside this rank's shard.
            start = rank * shard_len + i * batch
            stop = start + batch
            return y_pred[start:stop, :], y_true[start:stop, :]

        engine = Engine(update)
        fid_metric = FID(
            num_features=n_features,
            feature_extractor=torch.nn.Identity(),
            device=metric_device,
        )
        fid_metric.attach(engine, "fid")
        engine.run(data=list(range(n_iters)), max_epochs=1)

        assert "fid" in engine.state.metrics

        # Reference score computed on CPU over the whole (all-rank) data.
        expected = pytorch_fid_score.calculate_frechet_distance(
            y_pred.mean(axis=0).to("cpu"),
            cov(y_pred.to("cpu"), rowvar=False),
            y_true.mean(axis=0).to("cpu"),
            cov(y_true.to("cpu"), rowvar=False),
        )
        assert pytest.approx(expected, rel=1e-5) == fid_metric.compute()
Beispiel #2
0
def test_compute_fid_from_features():
    """FID fed in two half-batches must equal pytorch-fid on all samples."""
    train_samples, test_samples = torch.rand(10, 10), torch.rand(10, 10)

    fid_scorer = FID(num_features=10, feature_extractor=torch.nn.Identity())
    # Feed the first five rows, then the remaining five.
    for half in (slice(None, 5), slice(5, None)):
        fid_scorer.update([train_samples[half], test_samples[half]])

    reference = pytorch_fid_score.calculate_frechet_distance(
        train_samples.mean(axis=0),
        cov(train_samples, rowvar=False),
        test_samples.mean(axis=0),
        cov(test_samples, rowvar=False),
    )
    assert pytest.approx(reference, rel=1e-5) == fid_scorer.compute()
Beispiel #3
0
def test_no_numpy(mock_no_numpy):
    """Both ``FID`` and ``fid_score`` must raise RuntimeError without numpy.

    ``mock_no_numpy`` is a fixture; presumably it hides numpy from the
    module under test (defined elsewhere — confirm).
    """
    expectations = (
        (lambda: FID(), r"This module requires numpy to be installed."),
        (lambda: fid_score(0, 0, 0, 0),
         r"fid_score requires numpy to be installed."),
    )
    for trigger, pattern in expectations:
        with pytest.raises(RuntimeError, match=pattern):
            trigger()
Beispiel #4
0
def test_statistics():
    """FID's running mean/covariance must match statistics of the full data.

    The metric is fed two half-batches. For both the train and the test
    side, the mean derived from its running totals and the covariance from
    ``_get_covariance`` are compared with ``samples.mean(axis=0)`` and
    ``cov(samples, rowvar=False)`` over all ten rows.
    """
    train_samples, test_samples = torch.rand(10, 10), torch.rand(10, 10)
    fid_scorer = FID(num_features=10, feature_extractor=torch.nn.Identity())
    fid_scorer.update([train_samples[:5], test_samples[:5]])
    fid_scorer.update([train_samples[5:], test_samples[5:]])

    cases = (
        (train_samples, fid_scorer._train_total, fid_scorer._train_sigma),
        (test_samples, fid_scorer._test_total, fid_scorer._test_sigma),
    )
    for samples, total, sigma in cases:
        # Cast references to float64 to match the metric's accumulators
        # (torch.isclose/allclose require equal dtypes).
        expected_mu = samples.mean(axis=0).double()
        expected_sigma = torch.tensor(cov(samples, rowvar=False)).double()

        fid_mu = total / fid_scorer._num_examples
        fid_sigma = fid_scorer._get_covariance(sigma, total)

        assert fid_mu.shape == expected_mu.shape
        assert torch.isclose(expected_mu, fid_mu).all()

        # Compare whole matrices. The previous row-wise zip() silently
        # ignored missing or extra rows, so check the shape explicitly.
        assert fid_sigma.shape == expected_sigma.shape
        assert torch.allclose(expected_sigma, fid_sigma, rtol=1e-04, atol=1e-04)
Beispiel #5
0
def test_wrong_inputs():
    """FID must reject bad constructor arguments and malformed update batches.

    Each case is an (exception type, message regex, thunk) triple. The thunk
    is only invoked inside ``pytest.raises``, in the listed order.
    """

    def identity_fid(num_features):
        # Most cases build the metric with an identity feature extractor.
        return FID(num_features=num_features,
                   feature_extractor=torch.nn.Identity())

    size_mismatch = (
        "Number of Training Features and Testing Features should be equal (torch.Size([9, 2]) != torch.Size([5, 2]))"
    )

    cases = [
        (ValueError, r"Argument num_features must be greater to zero",
         lambda: identity_fid(-1)),
        (ValueError,
         r"feature_extractor output must be a tensor of dim 2, got: 1",
         lambda: identity_fid(1).update(torch.Tensor([[], []]))),
        (ValueError, r"Batch size should be greater than one, got: 0",
         lambda: identity_fid(1).update(torch.rand(2, 0, 0))),
        (ValueError,
         r"num_features returned by feature_extractor should be 1, got: 0",
         lambda: identity_fid(1).update(torch.rand(2, 2, 0))),
        (ValueError, re.escape(size_mismatch),
         lambda: identity_fid(2).update((torch.rand(9, 2), torch.rand(5, 2)))),
        (TypeError,
         r"Argument feature_extractor must be of type torch.nn.Module",
         lambda: FID(num_features=1, feature_extractor=lambda x: x)),
        (ValueError,
         r"Argument num_features must be provided, if feature_extractor is specified.",
         lambda: FID(feature_extractor=torch.nn.Identity())),
    ]

    for exc_type, pattern, trigger in cases:
        with pytest.raises(exc_type, match=pattern):
            trigger()