def _test(metric_device):
    """Verify GeometricAverage on ``metric_device`` against an all-reduced reference.

    Three checks are made:
      * ``compute()`` before any update raises ``NotComputableError``;
      * a length-100 float64 tensor fed one scalar at a time;
      * a ``(100, 10)`` float64 tensor fed one row at a time.
    Samples are built as ``rand + randint(0, 10)`` so they are strictly
    positive and their logarithm is finite.

    NOTE(review): ``device`` is not a parameter of this function; it is
    resolved from an enclosing/global scope outside this chunk (presumably
    the ``device`` argument of an outer distributed-test function). It must
    expose ``.type`` — confirm against the caller.
    """
    with pytest.raises(NotComputableError):
        v = GeometricAverage(device=metric_device)
        v.compute()

    # XLA is checked with one fewer decimal place than other backends.
    decimal = 4 if device.type == "xla" else 5

    def _feed(metric, samples):
        # Update the metric with one element (scalar or row) at a time.
        for sample in samples:
            metric.update(sample)
        return metric.compute()

    def _expected(samples):
        # all_reduce combines the per-process logs; the mean over dim 0 is then
        # divided by the world size to average over every process's samples.
        summed_logs = idist.all_reduce(torch.log(samples))
        return torch.exp(summed_logs.mean(dim=0) / idist.get_world_size())

    # Case 1: scalar samples.
    metric = GeometricAverage(device=metric_device)
    samples = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double()
    samples = samples.to(device)
    result = _feed(metric, samples)
    np.testing.assert_almost_equal(result, _expected(samples).item(), decimal=decimal)

    # Case 2: 10-element row samples.
    metric = GeometricAverage(device=metric_device)
    samples = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double()
    samples = samples.to(device)
    result = _feed(metric, samples)
    np.testing.assert_almost_equal(
        result.cpu().numpy(), _expected(samples).cpu().numpy(), decimal=decimal
    )
def test_geom_average():
    """Check single-process GeometricAverage against the ``_geom_mean`` reference.

    Covers:
      * ``compute()`` before any update raises ``NotComputableError``;
      * 100 Python-float updates;
      * row-wise updates from a ``(100, 10)`` tensor and slab-wise updates
        from an ``(8, 16, 10)`` tensor.
    Samples are ``rand + randint(0, 10)`` and therefore strictly positive.
    """
    with pytest.raises(NotComputableError):
        v = GeometricAverage()
        v.compute()

    # Scalar case: every sample is passed as a plain Python number.
    metric = GeometricAverage()
    samples = torch.rand(100) + torch.randint(0, 10, size=(100,)).float()
    for sample in samples:
        metric.update(sample.item())
    result = metric.compute()
    assert result.item() == pytest.approx(_geom_mean(samples))

    # Tensor cases: iterate over the leading dimension of each shape.
    for shape in ((100, 10), (8, 16, 10)):
        metric = GeometricAverage()
        samples = torch.rand(*shape) + torch.randint(0, 10, size=shape).float()
        for chunk in samples:
            metric.update(chunk)
        result = metric.compute()
        # The reference expects 2-D input, so higher-rank samples are
        # flattened to rows of 10 before comparison.
        reference_input = samples if samples.dim() == 2 else samples.reshape(-1, 10)
        np.testing.assert_almost_equal(result.numpy(), _geom_mean(reference_input), decimal=5)
def _test_distrib_geom_average(device):
    """Check GeometricAverage under ``torch.distributed`` on ``device``.

    Each process draws its own strictly positive samples (``rand +
    randint(0, 10)``), feeds them to the metric, and compares the result with
    a reference built by summing log-samples across processes.

    Checks:
      * ``compute()`` before any update raises ``NotComputableError``;
      * a length-100 float64 tensor fed one scalar at a time;
      * a ``(100, 10)`` float64 tensor fed one row at a time.
    """
    import torch.distributed as dist

    with pytest.raises(NotComputableError):
        v = GeometricAverage(device=device)
        v.compute()

    def _feed(metric, samples):
        # Update the metric with one element (scalar or row) at a time.
        for sample in samples:
            metric.update(sample)
        return metric.compute()

    def _expected(samples):
        log_samples = torch.log(samples)
        # In-place reduction across processes (default op is SUM).
        dist.all_reduce(log_samples)
        # Each process contributes the same number of samples, so dividing by
        # the world size turns the per-position mean into a global mean log.
        return torch.exp(log_samples.mean(dim=0) / dist.get_world_size())

    # Case 1: scalar samples.
    metric = GeometricAverage(device=device)
    samples = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double()
    samples = samples.to(device)
    result = _feed(metric, samples)
    assert result.item() == pytest.approx(_expected(samples).item())

    # Case 2: 10-element row samples.
    metric = GeometricAverage(device=device)
    samples = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double()
    samples = samples.to(device)
    result = _feed(metric, samples)
    np.testing.assert_almost_equal(
        result.cpu().numpy(), _expected(samples).cpu().numpy(), decimal=5
    )