def test_compute_mean_std():
    """Check that two attached ``Average`` metrics recover per-channel mean and std.

    A deterministic float64 ramp, shifted so that most values are negative, is
    shaped into ``num_batches`` batches of ``(batch_size, channels, height, width)``
    tensors and fed through an ``Engine``. The process function reduces every
    image to its per-channel E[x] and E[x^2]; one ``Average`` metric accumulates
    each of those over the run, and the std is then derived as
    sqrt(E[x^2] - E[x]^2). Both are compared with NumPy statistics taken over
    every pixel of each channel.
    """
    num_batches = 8
    batch_size = 12
    channels = 3
    height = width = 64

    ramp = np.arange(0, num_batches * batch_size * height * width * channels, dtype="float64")
    # Subtract a constant offset so the data is not strictly positive.
    offset = num_batches * batch_size * channels * width * height * 0.75
    true_data = ramp.reshape(num_batches * batch_size, channels, height, width) - offset

    # Move channels last and flatten: one row per pixel, one column per channel.
    pixels = true_data.transpose((0, 2, 3, 1)).reshape(-1, channels)
    expected_mean = pixels.mean(axis=0)
    expected_std = pixels.std(axis=0)

    loader = torch.from_numpy(true_data).reshape(num_batches, batch_size, channels, height, width)

    def per_image_moments(engine, batch):
        # Collapse the spatial dims so each (image, channel) row holds all its pixels.
        n_imgs, n_chans = batch.shape[:2]
        flat = batch.reshape(n_imgs, n_chans, -1).to(dtype=torch.float64)
        return {
            "mean": flat.mean(dim=-1),
            "mean^2": (flat**2).mean(dim=-1),
        }

    engine = Engine(per_image_moments)
    first_moment = Average(output_transform=lambda out: out["mean"])
    second_moment = Average(output_transform=lambda out: out["mean^2"])
    first_moment.attach(engine, "mean")
    second_moment.attach(engine, "mean2")

    metrics = engine.run(loader).metrics
    metrics["std"] = torch.sqrt(metrics["mean2"] - metrics["mean"] ** 2)

    np.testing.assert_almost_equal(metrics["mean"].numpy(), expected_mean, decimal=7)
    np.testing.assert_almost_equal(metrics["std"].numpy(), expected_std, decimal=5)
def test_average():
    """Exercise ``Average`` on scalar, 1-D and 2-D updates.

    ``compute`` must raise ``NotComputableError`` before any update. Each later
    case iterates ``y_true`` along its first dimension, feeds every slice to a
    fresh ``Average``, and compares the result with a reference mean computed
    directly with torch.
    """
    with pytest.raises(NotComputableError):
        untouched = Average()
        untouched.compute()

    def _feed(values, as_python_scalar=False):
        # Build a fresh metric, push every slice along dim 0, and return compute().
        metric = Average()
        for piece in values:
            metric.update(piece.item() if as_python_scalar else piece)
        return metric.compute()

    # Case 1: plain Python floats.
    scalars = torch.rand(100) + torch.randint(0, 10, size=(100,)).float()
    result = _feed(scalars, as_python_scalar=True)
    assert result.item() == pytest.approx(scalars.mean().item())

    # Case 2: 1-D tensors of 10 features each.
    rows = torch.rand(100, 10) + torch.randint(0, 10, size=(100, 10)).float()
    result = _feed(rows)
    assert result.numpy() == pytest.approx(rows.mean(dim=0).numpy())

    # Case 3: (16, 10) tensors; the reference flattens all but the last dim.
    blocks = torch.rand(8, 16, 10) + torch.randint(0, 10, size=(8, 16, 10)).float()
    result = _feed(blocks)
    assert result.numpy() == pytest.approx(blocks.reshape(-1, 10).mean(dim=0).numpy())
def _test_distrib_average(device):
    """Check ``Average`` on ``device`` against an all-reduced reference mean.

    Each process draws its own random float64 data and feeds it slice by slice
    (along dim 0) to a fresh ``Average``. The reference is the mean of
    ``idist.all_reduce(data)`` divided by the world size; this presumes
    ``all_reduce`` sums element-wise across processes (its default reduction —
    confirm against the idist docs).

    Args:
        device: device handed to ``Average`` and to which the data is moved.
    """
    with pytest.raises(NotComputableError):
        untouched = Average(device=device)
        untouched.compute()

    def _run(shape):
        # Fresh metric first, then random data on `device`, fed along dim 0.
        metric = Average(device=device)
        values = torch.rand(*shape, dtype=torch.float64) + torch.randint(0, 10, size=shape).double()
        values = values.to(device)
        for piece in values:
            metric.update(piece)
        outcome = metric.compute()
        reduced = idist.all_reduce(values)
        return outcome, reduced

    # Scalar updates.
    outcome, reduced = _run((100,))
    assert outcome.item() == pytest.approx(reduced.mean().item() / idist.get_world_size())

    # 1-D updates of 10 features each.
    outcome, reduced = _run((100, 10))
    np.testing.assert_almost_equal(
        outcome.cpu().numpy(),
        reduced.mean(dim=0).cpu().numpy() / idist.get_world_size(),
        decimal=5,
    )