Example #1
def _test_ddp_cat(rank, worldsize):
    setup_ddp(rank, worldsize)
    dummy = DummyMetric()
    dummy._reductions = {"foo": torch.cat}
    dummy.foo = [torch.tensor([1])]
    # The "cat" reduction concatenates the per-rank lists, so for
    # worldsize == 2 the synced state becomes tensor([1, 1]).
    dummy._sync_dist()

    assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
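These snippets are per-rank test functions from torchmetrics' DDP test suite and assume its helpers: setup_ddp joins the process group, DummyMetric/DummyMetricSum are minimal Metric subclasses, _reductions maps each state name to its reduction function, and _sync_dist() performs the all-gather/reduce that sync() normally triggers. A minimal sketch of a driver for such functions; the backend, address, and port are illustrative assumptions, not the suite's exact values:

import os

import torch.distributed as dist
import torch.multiprocessing as mp

def setup_ddp(rank, world_size):
    # Illustrative stand-in for the test suite's helper: joins a gloo
    # process group on localhost (the address/port are placeholders).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def run_ddp_test(fn, world_size=2):
    # Hypothetical driver: spawns one process per rank; each process
    # receives (rank, world_size), matching the signatures above.
    mp.spawn(fn, args=(world_size,), nprocs=world_size)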
Example #2
def _test_ddp_gather_uneven_tensors(rank, worldsize):
    setup_ddp(rank, worldsize)
    # Each rank contributes a tensor of a different length (rank elements),
    # so the gather has to cope with mismatched shapes.
    tensor = torch.ones(rank)
    result = gather_all_tensors(tensor)
    assert len(result) == worldsize
    for idx in range(worldsize):
        assert len(result[idx]) == idx
        assert (result[idx] == torch.ones_like(result[idx])).all()
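gather_all_tensors has to cope with tensors whose shapes differ across ranks. A simplified sketch of the usual pad/all_gather/trim approach for the 1D case (gather_uneven is a hypothetical name; this is not torchmetrics' exact implementation):

import torch
import torch.distributed as dist

def gather_uneven(result):
    # Assumes an initialized process group and a 1D tensor per rank.
    world_size = dist.get_world_size()
    local_size = torch.tensor([result.shape[0]])
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size)          # exchange lengths first
    max_size = int(torch.stack(sizes).max())
    padded = torch.zeros(max_size, dtype=result.dtype)
    padded[: result.shape[0]] = result          # pad to a common shape
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)
    return [g[: int(s)] for g, s in zip(gathered, sizes)]  # trim padding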
Example #3
def _test_ddp_sum(rank, worldsize):
    setup_ddp(rank, worldsize)
    dummy = DummyMetric()
    dummy._reductions = {"foo": torch.sum}
    dummy.foo = torch.tensor(1)
    # The "sum" reduction adds the per-rank values: 1 per rank -> worldsize.
    dummy._sync_dist()

    assert dummy.foo == worldsize
Example #4
def _test_ddp_compositional_tensor(rank, worldsize):
    setup_ddp(rank, worldsize)
    dummy = DummyMetricSum()
    dummy._reductions = {"x": torch.sum}
    # Adding two clones builds a compositional metric: update() is forwarded
    # to both children and compute() adds their synced results.
    dummy = dummy.clone() + dummy.clone()
    dummy.update(torch.tensor(1))
    val = dummy.compute()
    assert val == 2 * worldsize
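To see where 2 * worldsize comes from: arithmetic on metrics builds a compositional metric whose update() is forwarded to both clones and whose compute() adds their synced results. A single-process sketch, with SumMetric as a hypothetical stand-in for DummyMetricSum (assumes torchmetrics is installed):

import torch
from torchmetrics import Metric

class SumMetric(Metric):
    # Hypothetical minimal metric mirroring DummyMetricSum.
    def __init__(self):
        super().__init__()
        self.add_state("x", default=torch.tensor(0), dist_reduce_fx=torch.sum)

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x

m = SumMetric()
composed = m.clone() + m.clone()
composed.update(torch.tensor(1))  # forwarded to both children
assert composed.compute() == 2    # under DDP each child also syncs -> 2 * worldsize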
Example #5
def _test_ddp_gather_uneven_tensors_multidim(rank, worldsize):
    setup_ddp(rank, worldsize)
    # Both dimensions vary with the rank, so no two ranks share a shape.
    tensor = torch.ones(rank + 1, 2 - rank)
    result = gather_all_tensors(tensor)
    assert len(result) == worldsize
    for idx in range(worldsize):
        val = result[idx]
        assert val.shape == (idx + 1, 2 - idx)
        assert (val == torch.ones_like(val)).all()
Example #6
def _test_ddp_sum_cat(rank, worldsize):
    setup_ddp(rank, worldsize)
    dummy = DummyMetric()
    # Two states with different reductions: "foo" is concatenated across
    # processes while "bar" is summed.
    dummy._reductions = {"foo": torch.cat, "bar": torch.sum}
    dummy.foo = [torch.tensor([1])]
    dummy.bar = torch.tensor(1)
    dummy._sync_dist()

    assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
    assert dummy.bar == worldsize
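Outside the tests' pokes at _reductions, the same per-state reductions are declared through the public add_state API. A minimal sketch (SumAndCat is a hypothetical name; assumes a recent torchmetrics):

import torch
from torchmetrics import Metric
from torchmetrics.utilities.data import dim_zero_cat

class SumAndCat(Metric):
    def __init__(self):
        super().__init__()
        # One state reduced with sum, one list state concatenated across
        # processes, matching "bar" and "foo" above.
        self.add_state("bar", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("foo", default=[], dist_reduce_fx="cat")

    def update(self, x):
        self.bar += x.sum()
        self.foo.append(x)

    def compute(self):
        # dim_zero_cat handles both the local list and the synced tensor.
        return self.bar, dim_zero_cat(self.foo)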
Example #7
def _test_non_contiguous_tensors(rank, worldsize):
    setup_ddp(rank, worldsize)

    class DummyCatMetric(Metric):
        def __init__(self):
            super().__init__()
            self.add_state("x", default=[], dist_reduce_fx=None)

        def update(self, x):
            self.x.append(x)

        def compute(self):
            x = torch.cat(self.x, dim=0)
            return x.sum()

    metric = DummyCatMetric()
    # A column slice of a 2D tensor is a non-contiguous view; updating with
    # it must not break the metric's state handling.
    metric.update(torch.randn(10, 5)[:, 0])
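The slice torch.randn(10, 5)[:, 0] is deliberately a non-contiguous view, which is exactly what this test exercises: distributed collectives generally require contiguous memory, so the metric has to make the copy itself. A quick check of the property:

import torch

x = torch.randn(10, 5)[:, 0]           # column view with stride (5,)
assert not x.is_contiguous()
assert x.contiguous().is_contiguous()  # .contiguous() makes a dense copy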
Example #8
def _test_state_dict_is_synced(rank, worldsize, tmpdir):
    setup_ddp(rank, worldsize)

    class DummyCatMetric(Metric):
        def __init__(self):
            super().__init__()
            self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum)
            self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum)

        def update(self, x):
            self.x += x
            self.c += 1

        def compute(self):
            return self.x // self.c

    metric = DummyCatMetric()
    metric.persistent(True)

    steps = 5
    for i in range(steps):
        metric(i)
        state_dict = metric.state_dict()

        # Local accumulation is 0 + 1 + ... + i; the state_dict holds the
        # synced value, i.e. the per-rank sum times worldsize.
        expected_sum = i * (i + 1) / 2
        assert state_dict["x"] == expected_sum * worldsize
        assert metric.x == expected_sum
        assert metric.c == (i + 1)
        assert state_dict["c"] == metric.c * worldsize

    def reload_state_dict(state_dict, expected_x, expected_c):
        metric = DummyCatMetric()
        metric.load_state_dict(state_dict)
        assert metric.x == expected_x
        assert metric.c == expected_c

    # With GLOBAL_RANK set, only global rank 0 restores the saved values;
    # the remaining ranks fall back to the state defaults (0).
    with mock.patch.dict(os.environ, {"GLOBAL_RANK": str(rank)}):
        reload_state_dict(deepcopy(state_dict), 20 if not rank else 0,
                          10 if not rank else 0)

    reload_state_dict(deepcopy(state_dict), 20, 10)
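Tracing the final numbers: after steps = 5, each process has accumulated x = 0 + 1 + 2 + 3 + 4 = 10 and c = 5, so the synced state_dict holds x = 20 and c = 10 for worldsize = 2. The two reloads then show both restore modes the assertions encode: with GLOBAL_RANK patched, only global rank 0 gets the saved values back while the other ranks restore the defaults; without it, every process restores the full synced values.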
Example #9
def _test_state_dict_is_synced(rank, worldsize, tmpdir):
    setup_ddp(rank, worldsize)

    class DummyCatMetric(Metric):
        def __init__(self):
            super().__init__()
            self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum)
            self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum)

        def update(self, x):
            self.x += x
            self.c += 1

        def compute(self):
            return self.x // self.c

        def __repr__(self):
            return f"DummyCatMetric(x={self.x}, c={self.c})"

    metric = DummyCatMetric()
    metric.persistent(True)

    def verify_metric(metric, i, world_size):
        # world_size == 1 checks the local (unsynced) values,
        # world_size == 2 the reduced values after a sync.
        state_dict = metric.state_dict()
        exp_sum = i * (i + 1) / 2
        assert state_dict["x"] == exp_sum * world_size
        assert metric.x == exp_sum * world_size
        assert metric.c == (i + 1) * world_size
        assert state_dict["c"] == metric.c

    steps = 5
    for i in range(steps):
        if metric._is_synced:
            # Calling a metric that is still synced is an error; it has to
            # be unsynced first.
            with pytest.raises(
                    TorchMetricsUserError,
                    match="The Metric shouldn't be synced when performing"):
                metric(i)

            metric.unsync()

        metric(i)
        verify_metric(metric, i, 1)

        metric.sync()
        assert metric._is_synced

        with pytest.raises(TorchMetricsUserError,
                           match="The Metric has already been synced."):
            metric.sync()

        verify_metric(metric, i, 2)

        metric.unsync()
        assert not metric._is_synced

        with pytest.raises(TorchMetricsUserError,
                           match="The Metric has already been un-synced."):
            metric.unsync()

        with metric.sync_context():
            assert metric._is_synced
            verify_metric(metric, i, 2)

        # should_unsync=False keeps the metric synced after the context exits.
        with metric.sync_context(should_unsync=False):
            assert metric._is_synced
            verify_metric(metric, i, 2)

        assert metric._is_synced

        metric.unsync()
        assert not metric._is_synced

        metric.sync()
        # unsync() relies on the internal cache that sync() created;
        # removing it must raise.
        cache = metric._cache
        metric._cache = None

        with pytest.raises(
                TorchMetricsUserError,
                match="The internal cache should exist to unsync the Metric."):
            metric.unsync()

        metric._cache = cache

    def reload_state_dict(state_dict, expected_x, expected_c):
        metric = DummyCatMetric()
        metric.load_state_dict(state_dict)
        assert metric.x == expected_x
        assert metric.c == expected_c

    # Synced: both ranks' contributions are visible (x == 20, c == 10).
    reload_state_dict(deepcopy(metric.state_dict()), 20, 10)

    metric.unsync()
    # Unsynced: only the local accumulation remains (x == 10, c == 5).
    reload_state_dict(deepcopy(metric.state_dict()), 10, 5)

    metric.sync()

    filepath = os.path.join(tmpdir, f"weights-{rank}.pt")

    torch.save(metric.state_dict(), filepath)

    metric.unsync()
    # sync_context() temporarily syncs the states, so the checkpoint written
    # here also contains the reduced values.
    with metric.sync_context():
        torch.save(metric.state_dict(), filepath)
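The closing lines illustrate the practical pattern: sync_context() syncs the states on entry and restores the unsynced cache on exit, which makes it the natural window for calling state_dict() when checkpointing. A minimal sketch of that pattern as a helper (save_synced_checkpoint is a hypothetical name):

import torch

def save_synced_checkpoint(metric, filepath):
    # Inside sync_context() the state_dict holds the reduced (global)
    # values; on exit the metric is unsynced again, so local accumulation
    # can continue unaffected.
    with metric.sync_context():
        torch.save(metric.state_dict(), filepath)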