コード例 #1
0
def _test_distrib_integration(device):
    """Check that MeanAbsoluteError reduces correctly across distributed ranks.

    Every rank builds the same global target/prediction tensors (their size
    depends only on the world size), then feeds only its own contiguous shard
    of ``n_iters`` batches of ``s`` elements through an ignite ``Engine``.
    The metric value stored in ``engine.state.metrics`` must match the MAE
    computed with NumPy over the full, un-sharded tensors.

    Args:
        device: device on which the target/prediction tensors are allocated.
    """
    import numpy as np
    from ignite.engine import Engine

    rank = idist.get_rank()
    world_size = idist.get_world_size()
    n_iters = 80
    s = 50
    offset = n_iters * s  # number of elements owned by each rank
    total = offset * world_size

    # Targets are 0..total-1 and predictions are constant ones.
    y_true = torch.arange(0, total, dtype=torch.float).to(device)
    y_preds = torch.ones(total, dtype=torch.float).to(device)

    # First element of this rank's shard.
    shard_start = offset * rank

    def update(_engine, batch_idx):
        # Return the (y_pred, y) pair for the batch_idx-th batch of this rank's shard.
        lo = shard_start + batch_idx * s
        hi = lo + s
        return y_preds[lo:hi], y_true[lo:hi]

    engine = Engine(update)
    MeanAbsoluteError().attach(engine, "mae")
    engine.run(data=list(range(n_iters)), max_epochs=1)

    metrics = engine.state.metrics
    assert "mae" in metrics
    res = metrics["mae"]

    # Reference value computed over the whole (all-ranks) data on the host.
    expected = np.mean(np.abs((y_true - y_preds).cpu().numpy()))
    assert pytest.approx(res) == expected
コード例 #2
0
    def _test(metric_device):
        """Attach MeanAbsoluteError on ``metric_device`` and check it against NumPy.

        Uses ``update``, ``n_iters``, ``y_true`` and ``y_preds`` closed over from
        the enclosing scope. It runs one epoch over ``n_iters`` batch indices and
        asserts that the engine-reported ``"mae"`` approximately equals the mean
        absolute difference of the full ``y_true``/``y_preds`` tensors.

        Args:
            metric_device: device passed to ``MeanAbsoluteError(device=...)``.
        """
        engine = Engine(update)
        MeanAbsoluteError(device=metric_device).attach(engine, "mae")
        engine.run(data=list(range(n_iters)), max_epochs=1)

        metrics = engine.state.metrics
        assert "mae" in metrics
        res = metrics["mae"]

        # Host-side reference over the full tensors.
        expected = np.mean(np.abs((y_true - y_preds).cpu().numpy()))
        assert pytest.approx(res) == expected