def _test_ddp_cat(rank, worldsize):
    """List states reduced with ``torch.cat`` should be concatenated across processes."""
    setup_ddp(rank, worldsize)
    metric = DummyMetric()
    metric._reductions = {"foo": torch.cat}
    metric.foo = [tensor([1])]
    metric._sync_dist()
    # with two processes each contributing [1], the synced state is [1, 1]
    assert torch.all(torch.eq(metric.foo, tensor([1, 1])))
def _test_ddp_gather_uneven_tensors(rank, worldsize):
    """``gather_all_tensors`` must handle 1-d tensors whose length differs per rank."""
    setup_ddp(rank, worldsize)
    local_tensor = torch.ones(rank)  # rank 0 contributes an empty tensor
    gathered = gather_all_tensors(local_tensor)
    assert len(gathered) == worldsize
    for source_rank, chunk in enumerate(gathered):
        # each rank's contribution keeps its own length and all-ones content
        assert len(chunk) == source_rank
        assert (chunk == torch.ones_like(chunk)).all()
def _test_ddp_sum(rank, worldsize):
    """Scalar states reduced with ``torch.sum`` should add up across processes."""
    setup_ddp(rank, worldsize)
    metric = DummyMetric()
    metric._reductions = {"foo": torch.sum}
    metric.foo = tensor(1)
    metric._sync_dist()
    # every process contributes 1, so the synced value equals the world size
    assert metric.foo == worldsize
def _test_ddp_compositional_tensor(rank, worldsize):
    """A compositional metric built from clones should sync correctly under DDP."""
    setup_ddp(rank, worldsize)
    base = DummyMetricSum()
    base._reductions = {"x": torch.sum}
    composed = base.clone() + base.clone()
    composed.update(tensor(1))
    # each clone contributes 1 per process: 2 clones * worldsize processes
    assert composed.compute() == 2 * worldsize
def _test_ddp_gather_uneven_tensors_multidim(rank, worldsize):
    """``gather_all_tensors`` must handle multi-dim tensors with per-rank shapes."""
    setup_ddp(rank, worldsize)
    local_tensor = torch.ones(rank + 1, 2 - rank)  # shape differs in both dims per rank
    gathered = gather_all_tensors(local_tensor)
    assert len(gathered) == worldsize
    for source_rank, chunk in enumerate(gathered):
        # shape and all-ones content of each rank's contribution are preserved
        assert chunk.shape == (source_rank + 1, 2 - source_rank)
        assert (chunk == torch.ones_like(chunk)).all()
def _test_ddp_sum_cat(rank, worldsize):
    """Mixed reductions (``torch.cat`` and ``torch.sum``) should both sync correctly."""
    setup_ddp(rank, worldsize)
    metric = DummyMetric()
    metric._reductions = {"foo": torch.cat, "bar": torch.sum}
    metric.foo = [torch.tensor([1])]
    metric.bar = torch.tensor(1)
    metric._sync_dist()
    # list state is concatenated; scalar state is summed over processes
    assert torch.all(torch.eq(metric.foo, torch.tensor([1, 1])))
    assert metric.bar == worldsize
def _test_non_contiguous_tensors(rank, worldsize):
    """Updating a list state with a non-contiguous tensor must not raise."""
    setup_ddp(rank, worldsize)

    class DummyCatMetric(Metric):
        def __init__(self):
            super().__init__()
            self.add_state("x", default=[], dist_reduce_fx=None)

        def update(self, x):
            self.x.append(x)

        def compute(self):
            return torch.cat(self.x, dim=0).sum()

    metric = DummyCatMetric()
    # slicing a single column out of a 2-d tensor yields a non-contiguous view
    metric.update(torch.randn(10, 5)[:, 0])
def _test_state_dict_is_synced(rank, worldsize, tmpdir):
    """Check that ``state_dict`` returns globally-synced values while the local
    (per-process) states remain unsynced, and that reloading honors GLOBAL_RANK.

    NOTE(review): a later function in this file reuses this exact name and
    shadows this definition at import time — confirm which version is intended.
    """
    setup_ddp(rank, worldsize)

    class DummyCatMetric(Metric):
        def __init__(self):
            super().__init__()
            self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum)
            self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum)

        def update(self, x):
            self.x += x
            self.c += 1

        def compute(self):
            return self.x // self.c

    metric = DummyCatMetric()
    metric.persistent(True)

    steps = 5
    for i in range(steps):
        metric(i)
        state_dict = metric.state_dict()
        # 0 + 1 + ... + i in closed form; renamed from `sum` to avoid
        # shadowing the builtin of the same name
        expected_sum = i * (i + 1) / 2
        # the exported state_dict carries the cross-process (synced) value ...
        assert state_dict["x"] == expected_sum * worldsize
        # ... while the live attributes stay local to this process
        assert metric.x == expected_sum
        assert metric.c == (i + 1)
        assert state_dict["c"] == metric.c * worldsize

    def reload_state_dict(state_dict, expected_x, expected_c):
        metric = DummyCatMetric()
        metric.load_state_dict(state_dict)
        assert metric.x == expected_x
        assert metric.c == expected_c

    # with GLOBAL_RANK set, only rank 0 restores the synced values
    with mock.patch.dict(os.environ, {"GLOBAL_RANK": str(rank)}):
        reload_state_dict(deepcopy(state_dict), 20 if not rank else 0, 10 if not rank else 0)

    # without the env override, every process restores the full synced values
    reload_state_dict(deepcopy(state_dict), 20, 10)
def _test_state_dict_is_synced(rank, worldsize, tmpdir):
    """Exercise the full sync/unsync state machine of a Metric under DDP.

    Verifies that ``sync``/``unsync`` toggle ``_is_synced`` correctly, that
    double-sync/double-unsync raise ``TorchMetricsUserError``, that
    ``sync_context`` restores (or keeps, with ``should_unsync=False``) the
    synced state, and that ``state_dict`` round-trips in both states.
    """
    setup_ddp(rank, worldsize)

    class DummyCatMetric(Metric):
        def __init__(self):
            super().__init__()
            # both states are summed across processes on sync
            self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum)
            self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum)

        def update(self, x):
            self.x += x
            self.c += 1

        def compute(self):
            return self.x // self.c

        def __repr__(self):
            return f"DummyCatMetric(x={self.x}, c={self.c})"

    metric = DummyCatMetric()
    metric.persistent(True)

    def verify_metric(metric, i, world_size):
        # `world_size` is 1 for local (unsynced) values, 2 when synced
        state_dict = metric.state_dict()
        exp_sum = i * (i + 1) / 2  # 0 + 1 + ... + i
        assert state_dict["x"] == exp_sum * world_size
        assert metric.x == exp_sum * world_size
        assert metric.c == (i + 1) * world_size
        assert state_dict["c"] == metric.c

    steps = 5
    for i in range(steps):
        if metric._is_synced:
            # updating while synced is forbidden; unsync before the real update
            with pytest.raises(
                    TorchMetricsUserError, match="The Metric shouldn't be synced when performing"):
                metric(i)

            metric.unsync()

        metric(i)

        verify_metric(metric, i, 1)

        metric.sync()
        assert metric._is_synced

        # syncing twice in a row is an error
        with pytest.raises(TorchMetricsUserError, match="The Metric has already been synced."):
            metric.sync()

        verify_metric(metric, i, 2)

        metric.unsync()
        assert not metric._is_synced

        # unsyncing twice in a row is an error
        with pytest.raises(TorchMetricsUserError, match="The Metric has already been un-synced."):
            metric.unsync()

        # default context: synced inside, automatically unsynced on exit
        with metric.sync_context():
            assert metric._is_synced
            verify_metric(metric, i, 2)

        # should_unsync=False leaves the metric synced after the context exits
        with metric.sync_context(should_unsync=False):
            assert metric._is_synced
            verify_metric(metric, i, 2)

        assert metric._is_synced

        metric.unsync()
        assert not metric._is_synced

        metric.sync()

        # unsync needs the internal cache of pre-sync values to restore from
        cache = metric._cache
        metric._cache = None

        with pytest.raises(
                TorchMetricsUserError, match="The internal cache should exist to unsync the Metric."):
            metric.unsync()

        metric._cache = cache

    def reload_state_dict(state_dict, expected_x, expected_c):
        metric = DummyCatMetric()
        metric.load_state_dict(state_dict)
        assert metric.x == expected_x
        assert metric.c == expected_c

    # still synced here: state_dict carries the cross-process totals
    reload_state_dict(deepcopy(metric.state_dict()), 20, 10)

    metric.unsync()
    # unsynced: state_dict carries only this process's local values
    reload_state_dict(deepcopy(metric.state_dict()), 10, 5)

    metric.sync()

    filepath = os.path.join(tmpdir, f'weights-{rank}.pt')

    torch.save(metric.state_dict(), filepath)

    metric.unsync()
    with metric.sync_context():
        torch.save(metric.state_dict(), filepath)