def test_no_numpy(mock_no_numpy):
    """Check that the FID metric and ``fid_score`` refuse to run without numpy.

    ``mock_no_numpy`` is presumably a fixture that makes numpy appear
    unavailable (TODO confirm against its definition). Under it, both the
    class constructor and the functional API must raise ``RuntimeError``
    with their respective messages.
    """
    class_error_pattern = r"This module requires numpy to be installed."
    function_error_pattern = r"fid_score requires numpy to be installed."

    # Constructing the metric object must fail.
    with pytest.raises(RuntimeError, match=class_error_pattern):
        FID()

    # The functional form must fail too; the argument values are irrelevant.
    with pytest.raises(RuntimeError, match=function_error_pattern):
        fid_score(0, 0, 0, 0)
def test_compute_fid_sqrtm():
    """Exercise the matrix-square-root edge cases of ``fid_score``.

    1. ``[[-1, 1], [1, 1]]`` has eigenvalues +/-sqrt(2), so it is indefinite.
       Paired with the identity, ``fid_score`` is expected to raise
       ``ValueError`` mentioning an "Imaginary component".
    2. A covariance filled with the largest finite float64 value is expected
       to overflow, so the resulting score must be infinite.
    """
    f64 = torch.float64
    mu1 = torch.tensor([0, 0])
    mu2 = torch.tensor([0, 0])

    # Case 1: indefinite covariance -> complex square root -> ValueError.
    indefinite_cov = torch.tensor([[-1, 1], [1, 1]], dtype=f64)
    identity_cov = torch.tensor([[1, 0], [0, 1]], dtype=f64)
    with pytest.raises(ValueError, match=r"Imaginary component "):
        fid_score(mu1, mu2, indefinite_cov, identity_cov)

    # Case 2: entries at float64 max overflow; the score must be inf.
    overflowing_cov = torch.ones((2, 2), dtype=f64) * torch.finfo(f64).max
    triangular_cov = torch.tensor([[1, 0.5], [0, 0.5]], dtype=f64)
    score = fid_score(mu1, mu2, overflowing_cov, triangular_cov)
    # Wrap in a tensor because fid_score may return a plain Python float.
    assert torch.isinf(torch.tensor(score))
def test_fid_function():
    """Check that ``fid_score`` matches pytorch_fid's reference implementation.

    Two random sample sets are summarised by their column means and
    covariances. Both implementations score them, and the results must agree
    to a relative tolerance of 1e-5.
    """

    def _moments(samples):
        # Column-wise mean, plus the covariance with columns as variables
        # (rowvar=False), cast to a float64 tensor.
        mean = samples.mean(axis=0)
        covariance = torch.tensor(cov(samples, rowvar=False), dtype=torch.float64)
        return mean, covariance

    train_samples, test_samples = torch.rand(10, 10), torch.rand(10, 10)
    mu1, sigma1 = _moments(train_samples)
    mu2, sigma2 = _moments(test_samples)

    # Note the differing argument orders: fid_score takes (mu1, mu2, sigma1,
    # sigma2), while the reference takes (mu1, sigma1, mu2, sigma2).
    actual = fid_score(mu1, mu2, sigma1, sigma2)
    expected = pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    assert pytest.approx(actual, rel=1e-5) == expected