예제 #1
0
def test_no_numpy(mock_no_numpy):
    """Both FID entry points must raise RuntimeError when numpy is unavailable.

    ``mock_no_numpy`` is a fixture defined elsewhere; presumably it hides
    numpy from the import machinery for the duration of this test.
    Each case is (callable, positional args, regex the error must match).
    """
    cases = (
        (FID, (), r"This module requires numpy to be installed."),
        (fid_score, (0, 0, 0, 0), r"fid_score requires numpy to be installed."),
    )
    for target, call_args, expected_message in cases:
        with pytest.raises(RuntimeError, match=expected_message):
            target(*call_args)
예제 #2
0
def test_compute_fid_sqrtm():
    """Exercise the two failure paths around the matrix square root in fid_score.

    1. ``[[-1, 1], [1, 1]]`` is indefinite (eigenvalues +/- sqrt(2)), so the
       square root of its product with the identity is complex; fid_score is
       expected to reject that with a ValueError about the imaginary component.
    2. A covariance filled with the largest finite float64 overflows, and the
       resulting score is expected to be infinite.
    """
    mu1, mu2 = torch.tensor([0, 0]), torch.tensor([0, 0])

    identity = torch.tensor([[1, 0], [0, 1]], dtype=torch.float64)
    indefinite = torch.tensor([[-1, 1], [1, 1]], dtype=torch.float64)
    with pytest.raises(ValueError, match=r"Imaginary component "):
        fid_score(mu1, mu2, indefinite, identity)

    largest = torch.finfo(torch.float64).max
    saturated = torch.full((2, 2), largest, dtype=torch.float64)
    skewed = torch.tensor([[1, 0.5], [0, 0.5]], dtype=torch.float64)

    score = fid_score(mu1, mu2, saturated, skewed)
    assert torch.isinf(torch.tensor(score))
예제 #3
0
def test_fid_function():
    """fid_score must agree with the reference pytorch-fid implementation.

    Mean vectors and covariance matrices are computed from two random
    10x10 sample batches. The samples come from a locally seeded generator,
    so any mismatch is reproducible. The global torch RNG state is left
    untouched.
    """
    gen = torch.Generator().manual_seed(0)
    train_samples = torch.rand(10, 10, generator=gen)
    test_samples = torch.rand(10, 10, generator=gen)

    # NOTE(review): ``cov`` is imported elsewhere (presumably numpy.cov).
    # torch.as_tensor accepts either an ndarray or a tensor, and does not
    # emit the copy-construct warning that torch.tensor gives for tensors.
    mu1 = train_samples.mean(dim=0)
    sigma1 = torch.as_tensor(cov(train_samples, rowvar=False), dtype=torch.float64)
    mu2 = test_samples.mean(dim=0)
    sigma2 = torch.as_tensor(cov(test_samples, rowvar=False), dtype=torch.float64)

    expected = pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    # Put approx on the reference value so the relative tolerance is
    # measured against the trusted implementation.
    assert fid_score(mu1, mu2, sigma1, sigma2) == pytest.approx(expected, rel=1e-5)