def _test(metric_device):
        """Run the engine once and compare its "rmse" metric with a NumPy reference.

        ``metric_device`` is passed to ``RootMeanSquaredError``. All other inputs
        (``update``, ``n_iters``, ``offset``, ``device``, ``y_true``, ``tol``) are
        not defined in this function and must be supplied by the surrounding
        scope.
        """
        trainer = Engine(update)
        RootMeanSquaredError(device=metric_device).attach(trainer, "rmse")

        trainer.run(data=list(range(n_iters)), max_epochs=1)

        metrics = trainer.state.metrics
        assert "rmse" in metrics
        computed = metrics["rmse"]

        # For every index r below the world size: ``torch.ones(offset)`` scaled by r + 1.
        per_rank_preds = [
            (r + 1) * torch.ones(offset) for r in range(idist.get_world_size())
        ]
        # NOTE: the reference lives on the outer ``device``, not on ``metric_device``.
        all_preds = torch.stack(per_rank_preds).to(device).flatten()

        residuals = (y_true - all_preds).cpu().numpy()
        expected = np.sqrt(np.mean(np.square(residuals)))

        assert pytest.approx(computed, rel=tol) == expected
예제 #2
0
def _test_distrib_itegration(device):
    """Check the "rmse" metric of an Engine run against a NumPy reference.

    Each process feeds the engine predictions that are constantly ``rank + 1``
    together with its own ``offset``-long slice of ``y_true``. The value the
    metric reports is then compared with the RMSE computed over the targets
    and predictions of every rank, i.e. the test expects the metric to cover
    all processes.

    Args:
        device: device on which the tensors and the metric are placed.
    """
    import numpy as np
    import torch.distributed as dist

    from ignite.engine import Engine

    this_rank = dist.get_rank()
    num_iters = 100
    batch_size = 50
    per_rank_len = num_iters * batch_size

    total_len = per_rank_len * dist.get_world_size()
    y_true = torch.arange(0, total_len, dtype=torch.float).to(device)
    y_preds = (this_rank + 1) * torch.ones(per_rank_len, dtype=torch.float).to(device)

    def update(engine, i):
        # Batch ``i`` covers [lo, hi) of the local predictions; the matching
        # targets sit ``per_rank_len * this_rank`` further into ``y_true``.
        lo = i * batch_size
        hi = (i + 1) * batch_size
        shift = per_rank_len * this_rank
        return y_preds[lo:hi], y_true[lo + shift : hi + shift]

    trainer = Engine(update)
    RootMeanSquaredError(device=device).attach(trainer, "rmse")

    trainer.run(data=list(range(num_iters)), max_epochs=1)

    metrics = trainer.state.metrics
    assert "rmse" in metrics
    computed = metrics["rmse"]

    # Rebuild the predictions of every rank: rank r contributes a block of r + 1.
    per_rank_preds = [
        (r + 1) * torch.ones(per_rank_len) for r in range(dist.get_world_size())
    ]
    all_preds = torch.stack(per_rank_preds).to(device).flatten()

    residuals = (y_true - all_preds).cpu().numpy()
    expected = np.sqrt(np.mean(np.square(residuals)))

    assert pytest.approx(computed) == expected