Beispiel #1
0
def _test_distrib_geom_average(device):
    """Exercise ``GeometricAverage`` in a distributed configuration.

    Random positive samples are generated locally and fed to the metric.
    The reference value is built from the element-wise logs combined
    across processes via ``idist.all_reduce``. It is then averaged over
    the local batch and divided by ``idist.get_world_size()``.

    Args:
        device: device handed to ``GeometricAverage`` and used to host the
            generated samples.
    """

    def _reference(samples):
        # all_reduce presumably sums the per-process logs (hence the
        # division by the world size below); mean over dim 0 averages the
        # local batch, giving the mean log over all processes' samples.
        combined_logs = idist.all_reduce(torch.log(samples))
        return torch.exp(combined_logs.mean(dim=0) / idist.get_world_size())

    # Asking for a result before any update must fail.
    with pytest.raises(NotComputableError):
        untouched = GeometricAverage(device=device)
        untouched.compute()

    # Case 1: scalar samples, fed one at a time.
    metric = GeometricAverage(device=device)
    scalars = (torch.rand(100, dtype=torch.float64) +
               torch.randint(0, 10, size=(100, )).double()).to(device)
    for sample in scalars:
        metric.update(sample)
    result = metric.compute()
    # The metric's compute() runs before the reference all_reduce, keeping
    # the collective-call order identical on every process.
    expected = _reference(scalars)
    assert result.item() == pytest.approx(expected.item())

    # Case 2: 10-feature vectors, one row per update.
    metric = GeometricAverage(device=device)
    vectors = (torch.rand(100, 10, dtype=torch.float64) +
               torch.randint(0, 10, size=(100, 10)).double()).to(device)
    for sample in vectors:
        metric.update(sample)
    result = metric.compute()
    expected = _reference(vectors)
    np.testing.assert_almost_equal(result.cpu().numpy(),
                                   expected.cpu().numpy(),
                                   decimal=5)
Beispiel #2
0
def test_geom_average():
    """Compare ``GeometricAverage`` with the reference ``_geom_mean`` helper.

    Covers three things:
    - the error raised when computing before any update;
    - scalar updates passed as Python floats;
    - per-feature (10 columns) updates with 1-D rows and with 2-D batches.
    """
    # compute() with no accumulated data must raise.
    with pytest.raises(NotComputableError):
        untouched = GeometricAverage()
        untouched.compute()

    # Scalars: each sample is handed over as a plain Python number.
    metric = GeometricAverage()
    values = torch.rand(100) + torch.randint(0, 10, size=(100, )).float()
    for value in values:
        metric.update(value.item())
    assert metric.compute().item() == pytest.approx(_geom_mean(values))

    # Ten features: first 100 single rows, then 8 batches of 16 rows.
    # Flattening all leading axes to (-1, 10) gives the reference layout;
    # for the (100, 10) case this reshape is a no-op.
    for shape in ((100, 10), (8, 16, 10)):
        metric = GeometricAverage()
        data = torch.rand(*shape) + torch.randint(0, 10, size=shape).float()
        for chunk in data:
            metric.update(chunk)
        np.testing.assert_almost_equal(metric.compute().numpy(),
                                       _geom_mean(data.reshape(-1, 10)),
                                       decimal=5)