コード例 #1
0
def test_compute_mean_std():
    """Check that two ``Average`` metrics recover per-channel mean and std.

    A deterministic float64 ramp is laid out as ``n * b`` images of shape
    ``(c, h, w)`` and fed to an ``Engine`` in ``n`` batches of ``b`` images.
    The process function emits per-image channel means of ``x`` and ``x**2``.
    Averaging those and applying ``std = sqrt(E[x^2] - E[x]^2)`` must match
    the NumPy reference statistics.
    """
    n, b, c = 8, 12, 3
    h = w = 64
    total = n * b * h * w * c

    # Offset the ramp by 75% of its length so the values straddle zero.
    true_data = np.arange(0, total, dtype="float64").reshape(
        n * b, c, h, w) - (total * 0.75)

    # Reference statistics: move channels last, flatten all other axes.
    per_channel = true_data.transpose((0, 2, 3, 1)).reshape(-1, c)
    mean = per_channel.mean(axis=0)
    std = per_channel.std(axis=0)

    # Engine.run iterates over the leading axis: n batches of (b, c, h, w).
    batches = torch.from_numpy(true_data).reshape(n, b, c, h, w)

    def compute_mean_std(engine, batch):
        num_images, num_channels = batch.shape[:2]
        pixels = batch.reshape(num_images, num_channels,
                               -1).to(dtype=torch.float64)
        return {
            "mean": pixels.mean(dim=-1),
            "mean^2": (pixels**2).mean(dim=-1),
        }

    compute_engine = Engine(compute_mean_std)
    first_moment = Average(output_transform=lambda output: output["mean"])
    second_moment = Average(output_transform=lambda output: output["mean^2"])
    first_moment.attach(compute_engine, "mean")
    second_moment.attach(compute_engine, "mean2")

    metrics = compute_engine.run(batches).metrics
    # Variance from raw moments: E[x^2] - E[x]^2.
    metrics["std"] = torch.sqrt(metrics["mean2"] - metrics["mean"]**2)

    np.testing.assert_almost_equal(metrics["mean"].numpy(), mean, decimal=7)
    np.testing.assert_almost_equal(metrics["std"].numpy(), std, decimal=5)
コード例 #2
0
def test_average():
    """Check ``Average`` on scalar, row and slice inputs.

    ``compute()`` on a fresh metric must raise ``NotComputableError``. Each
    following case feeds the metric one element of ``y_true`` at a time and
    compares the result with the mean over the leading axis computed directly
    by torch.
    """
    # Build the metric outside ``pytest.raises`` so that only ``compute()``
    # (not the constructor) can satisfy the expected-exception check.
    v = Average()
    with pytest.raises(NotComputableError):
        v.compute()

    # Case 1: each update receives a Python float.
    mean_var = Average()
    y_true = torch.rand(100) + torch.randint(0, 10, size=(100, )).float()

    for y in y_true:
        mean_var.update(y.item())

    m = mean_var.compute()
    assert m.item() == pytest.approx(y_true.mean().item())

    # Case 2: each update receives one row of shape (10,); the expected
    # result is the column-wise mean over all 100 rows.
    mean_var = Average()
    y_true = torch.rand(100, 10) + torch.randint(0, 10, size=(100, 10)).float()
    for y in y_true:
        mean_var.update(y)

    m = mean_var.compute()
    assert m.numpy() == pytest.approx(y_true.mean(dim=0).numpy())

    # Case 3: each update receives a (16, 10) slice; the expected result is
    # the column-wise mean over all 8 * 16 rows.
    mean_var = Average()
    y_true = torch.rand(8, 16, 10) + torch.randint(0, 10,
                                                   size=(8, 16, 10)).float()
    for y in y_true:
        mean_var.update(y)

    m = mean_var.compute()
    assert m.numpy() == pytest.approx(
        y_true.reshape(-1, 10).mean(dim=0).numpy())
コード例 #3
0
def _test_distrib_average(device):
    """Check ``Average`` results in a distributed run on ``device``.

    ``compute()`` on a fresh metric must raise ``NotComputableError``. Each
    rank then averages its own random float64 data. The expected value is
    built by combining the data across ranks with ``idist.all_reduce`` and
    dividing the resulting mean by ``idist.get_world_size()``.

    NOTE(review): the expected-value formula assumes ``idist.all_reduce``
    uses its default SUM reduction — confirm against the ignite docs.
    """
    # Build the metric outside ``pytest.raises`` so that only ``compute()``
    # (not the constructor) can satisfy the expected-exception check.
    v = Average(device=device)
    with pytest.raises(NotComputableError):
        v.compute()

    # Case 1: each update receives a 0-d tensor.
    mean_var = Average(device=device)
    y_true = torch.rand(100, dtype=torch.float64) + torch.randint(
        0, 10, size=(100, )).double()
    y_true = y_true.to(device)

    for y in y_true:
        mean_var.update(y)

    m = mean_var.compute()

    # Summed-over-ranks data: mean / world_size == mean over all ranks' data.
    y_true = idist.all_reduce(y_true)
    assert m.item() == pytest.approx(y_true.mean().item() /
                                     idist.get_world_size())

    # Case 2: each update receives one row of shape (10,).
    mean_var = Average(device=device)
    y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(
        0, 10, size=(100, 10)).double()
    y_true = y_true.to(device)

    for y in y_true:
        mean_var.update(y)

    m = mean_var.compute()

    y_true = idist.all_reduce(y_true)
    np.testing.assert_almost_equal(m.cpu().numpy(),
                                   y_true.mean(dim=0).cpu().numpy() /
                                   idist.get_world_size(),
                                   decimal=5)