def test_average():
    """Exercise ``Average`` with scalar, 2-D and 3-D data fed row by row.

    Calling ``compute`` on a fresh metric must raise ``NotComputableError``.
    For every other case the metric's result is compared with the mean of the
    generated data over its leading dimension (for the 3-D case, after
    flattening all leading dimensions into one).
    """

    def _feed_and_compute(metric, items):
        # Push every item through ``update`` in order, then read the result.
        for item in items:
            metric.update(item)
        return metric.compute()

    # A metric that has seen no data cannot produce a value.
    with pytest.raises(NotComputableError):
        untouched = Average()
        untouched.compute()

    # Scalar case: each element of a 1-D tensor is passed as a Python number.
    scalar_metric = Average()
    scalars = torch.rand(100) + torch.randint(0, 10, size=(100,)).float()
    scalar_result = _feed_and_compute(scalar_metric, (s.item() for s in scalars))
    assert scalar_result.item() == pytest.approx(scalars.mean().item())

    # 2-D case: each update receives one row of shape (10,).
    row_metric = Average()
    rows = torch.rand(100, 10) + torch.randint(0, 10, size=(100, 10)).float()
    row_result = _feed_and_compute(row_metric, rows)
    assert row_result.numpy() == pytest.approx(rows.mean(dim=0).numpy())

    # 3-D case: each update receives a (16, 10) slice; the reference mean is
    # taken over all 8 * 16 rows of width 10.
    slab_metric = Average()
    slabs = torch.rand(8, 16, 10) + torch.randint(0, 10, size=(8, 16, 10)).float()
    slab_result = _feed_and_compute(slab_metric, slabs)
    assert slab_result.numpy() == pytest.approx(
        slabs.reshape(-1, 10).mean(dim=0).numpy())
def _test_distrib_average(device):
    """Check ``Average`` on ``device`` in a (possibly) distributed setting.

    Each process generates its own random float64 data and feeds it to the
    metric one leading-dimension slice at a time. The metric's result is
    compared against a reference built from ``idist.all_reduce`` of the data,
    whose mean is divided by the world size. (This presumes ``all_reduce``
    sums across processes, as the division implies.)

    Covers the same input shapes as ``test_average``: 1-D (scalar updates),
    2-D (row updates) and 3-D (slab updates, reduced over the first axis of
    each slab).

    Args:
        device: device passed to ``Average`` and to which test tensors are moved.
    """
    with pytest.raises(NotComputableError):
        v = Average(device=device)
        v.compute()

    world_size = idist.get_world_size()

    # 1-D input: each update receives a 0-d tensor.
    mean_var = Average(device=device)
    y_true = torch.rand(100, dtype=torch.float64) + torch.randint(
        0, 10, size=(100,)).double()
    y_true = y_true.to(device)
    for y in y_true:
        mean_var.update(y)
    m = mean_var.compute()
    y_true = idist.all_reduce(y_true)
    assert m.item() == pytest.approx(y_true.mean().item() / world_size)

    # 2-D input: each update receives a row of shape (10,).
    mean_var = Average(device=device)
    y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(
        0, 10, size=(100, 10)).double()
    y_true = y_true.to(device)
    for y in y_true:
        mean_var.update(y)
    m = mean_var.compute()
    y_true = idist.all_reduce(y_true)
    np.testing.assert_almost_equal(
        m.cpu().numpy(),
        y_true.mean(dim=0).cpu().numpy() / world_size,
        decimal=5,
    )

    # 3-D input: each update receives a (16, 10) slice. This mirrors the 3-D
    # case in test_average, where the reference is the mean over all rows of
    # width 10.
    mean_var = Average(device=device)
    y_true = torch.rand(8, 16, 10, dtype=torch.float64) + torch.randint(
        0, 10, size=(8, 16, 10)).double()
    y_true = y_true.to(device)
    for y in y_true:
        mean_var.update(y)
    m = mean_var.compute()
    y_true = idist.all_reduce(y_true)
    np.testing.assert_almost_equal(
        m.cpu().numpy(),
        y_true.reshape(-1, 10).mean(dim=0).cpu().numpy() / world_size,
        decimal=5,
    )