def _test(metric_device):
    """Run FID through an ``Engine`` on per-rank shards and check it against a reference.

    Each rank builds the full ``y_pred`` / ``y_true`` tensors but feeds only its own
    contiguous slice of ``n_iters * s`` rows to the metric. The metric's result is
    compared with ``pytorch_fid_score.calculate_frechet_distance`` evaluated on the
    full tensors.

    Args:
        metric_device: device passed to ``FID(device=...)``.
    """
    n_iters = 60
    s = 16
    offset = n_iters * s
    n_features = 10

    # Derive the rank here instead of relying on a variable from an outer scope.
    rank = idist.get_rank()
    world_size = idist.get_world_size()

    # All ranks must draw identical data: the reference below is computed from the
    # full local tensors, while the metric only sees each rank's shard.
    torch.manual_seed(12)
    y_pred = torch.rand(offset * world_size, n_features)
    y_true = torch.rand(offset * world_size, n_features)

    def update(_, i):
        # Rows [start, start + s) of this rank's shard for iteration i.
        start = i * s + rank * offset
        end = start + s
        return y_pred[start:end, :], y_true[start:end, :]

    engine = Engine(update)
    m = FID(num_features=n_features, feature_extractor=torch.nn.Identity(), device=metric_device)
    m.attach(engine, "fid")
    engine.run(data=list(range(n_iters)), max_epochs=1)

    assert "fid" in engine.state.metrics

    y_pred_cpu, y_true_cpu = y_pred.to("cpu"), y_true.to("cpu")
    mu1, sigma1 = y_pred_cpu.mean(axis=0), cov(y_pred_cpu, rowvar=False)
    mu2, sigma2 = y_true_cpu.mean(axis=0), cov(y_true_cpu, rowvar=False)
    expected = pytest.approx(
        pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2), rel=1e-5
    )

    # Check both the value the engine stored and a direct compute() call.
    assert expected == engine.state.metrics["fid"]
    assert expected == m.compute()
def test_compute_fid_from_features():
    """FID accumulated over two ``update`` calls equals the reference distance of the full samples."""
    train_feats, test_feats = torch.rand(10, 10), torch.rand(10, 10)
    metric = FID(num_features=10, feature_extractor=torch.nn.Identity())

    # Feed the samples in two halves to exercise the running accumulation.
    for half in (slice(None, 5), slice(5, None)):
        metric.update([train_feats[half], test_feats[half]])

    reference = pytorch_fid_score.calculate_frechet_distance(
        train_feats.mean(axis=0),
        cov(train_feats, rowvar=False),
        test_feats.mean(axis=0),
        cov(test_feats, rowvar=False),
    )
    assert pytest.approx(reference, rel=1e-5) == metric.compute()
def test_no_numpy(mock_no_numpy):
    """With numpy unavailable (via the ``mock_no_numpy`` fixture), FID and fid_score raise RuntimeError."""
    expectations = (
        (r"This module requires numpy to be installed.", lambda: FID()),
        (r"fid_score requires numpy to be installed.", lambda: fid_score(0, 0, 0, 0)),
    )
    for pattern, trigger in expectations:
        with pytest.raises(RuntimeError, match=pattern):
            trigger()
def test_statistics():
    """FID's accumulated mean/covariance statistics match reference sample statistics.

    The samples are fed through an identity extractor in two halves. The metric's
    running mean (``_*_total / _num_examples``) and covariance
    (``_get_covariance``) are compared against ``mean(axis=0)`` and
    ``cov(..., rowvar=False)`` of the full inputs.
    """
    train_samples, test_samples = torch.rand(10, 10), torch.rand(10, 10)
    fid_scorer = FID(num_features=10, feature_extractor=torch.nn.Identity())
    fid_scorer.update([train_samples[:5], test_samples[:5]])
    fid_scorer.update([train_samples[5:], test_samples[5:]])

    def _check(samples, total, sigma):
        # Compare one side's (train or test) accumulated statistics with the reference.
        ref_mu = samples.mean(axis=0).double()
        ref_sigma = torch.tensor(cov(samples, rowvar=False)).double()
        fid_mu = total / fid_scorer._num_examples
        fid_sigma = fid_scorer._get_covariance(sigma, total)

        # Explicit shape checks: the previous row-wise zip() comparison silently
        # truncated, so a wrongly-shaped covariance could pass unnoticed.
        assert fid_mu.shape == ref_mu.shape
        assert torch.isclose(ref_mu, fid_mu).all()
        assert fid_sigma.shape == ref_sigma.shape
        assert torch.allclose(ref_sigma, fid_sigma, rtol=1e-04, atol=1e-04)

    _check(train_samples, fid_scorer._train_total, fid_scorer._train_sigma)
    _check(test_samples, fid_scorer._test_total, fid_scorer._test_sigma)
def test_wrong_inputs():
    """Invalid FID constructor arguments and malformed ``update`` inputs raise descriptive errors."""

    def identity_fid(num_features):
        # Fresh metric with a pass-through extractor, so update() inputs are the features.
        return FID(num_features=num_features, feature_extractor=torch.nn.Identity())

    with pytest.raises(ValueError, match=r"Argument num_features must be greater to zero"):
        identity_fid(-1)

    with pytest.raises(ValueError, match=r"feature_extractor output must be a tensor of dim 2, got: 1"):
        identity_fid(1).update(torch.Tensor([[], []]))

    with pytest.raises(ValueError, match=r"Batch size should be greater than one, got: 0"):
        identity_fid(1).update(torch.rand(2, 0, 0))

    with pytest.raises(ValueError, match=r"num_features returned by feature_extractor should be 1, got: 0"):
        identity_fid(1).update(torch.rand(2, 2, 0))

    # Literal message containing regex metacharacters, hence re.escape below.
    mismatch_msg = (
        "Number of Training Features and Testing Features should be equal "
        "(torch.Size([9, 2]) != torch.Size([5, 2]))"
    )
    with pytest.raises(ValueError, match=re.escape(mismatch_msg)):
        identity_fid(2).update((torch.rand(9, 2), torch.rand(5, 2)))

    with pytest.raises(TypeError, match=r"Argument feature_extractor must be of type torch.nn.Module"):
        FID(num_features=1, feature_extractor=lambda x: x)

    with pytest.raises(
        ValueError, match=r"Argument num_features must be provided, if feature_extractor is specified."
    ):
        FID(feature_extractor=torch.nn.Identity())