def test_to_worker_group(ray_start_2_cpus):
    """Check that ``Trainer.to_worker_group`` yields independently stateful actors.

    Two workers are built from a local ``Incrementer`` class with
    ``starting=2``. Incrementing every worker must advance all counters,
    while incrementing a single worker (by index) must advance only that
    worker's counter. The trainer is shut down afterwards even if an
    assertion fails, so the worker actors are not leaked into later tests.
    """
    config = TestConfig()
    trainer = Trainer(config, num_workers=2)

    class Incrementer:
        """Minimal stateful class used as the per-worker payload."""

        def __init__(self, starting=0):
            self.count = starting

        def increment(self):
            self.count += 1

        def get_count(self):
            return self.count

    # Extra kwargs to to_worker_group are forwarded to Incrementer.__init__.
    workers = trainer.to_worker_group(Incrementer, starting=2)

    def counts():
        # Fetch every worker's counter, preserving worker order.
        return ray.get([w.get_count.remote() for w in workers])

    try:
        assert counts() == [2, 2]

        # Broadcast an increment to every worker.
        ray.get([w.increment.remote() for w in workers])
        assert counts() == [3, 3]

        # Increment workers one at a time; only the targeted one changes.
        ray.get(workers[0].increment.remote())
        assert counts() == [4, 3]

        ray.get(workers[1].increment.remote())
        assert counts() == [4, 4]
    finally:
        # Release the worker actors started by to_worker_group.
        trainer.shutdown()
def test_horovod_torch_mnist_stateful(ray_start_2_cpus):
    """Train a stateful Horovod worker group epoch by epoch and check progress.

    A ``HorovodTrainClass`` worker group is created via
    ``Trainer.to_worker_group`` and ``train(epoch=...)`` is invoked on every
    worker once per epoch. For each worker, the value returned in the last
    epoch must be strictly smaller than the value from the first epoch
    (presumably the training loss — TODO confirm against HorovodTrainClass).
    The trainer is always shut down, even if training raises.
    """
    num_workers = 2
    num_epochs = 2
    trainer = Trainer("horovod", num_workers)
    workers = trainer.to_worker_group(
        HorovodTrainClass,
        config={"num_epochs": num_epochs, "lr": 1e-3},
    )

    # results[epoch][worker_index] -> value returned by that worker's train().
    results = []
    try:
        for epoch in range(num_epochs):
            results.append(
                ray.get([w.train.remote(epoch=epoch) for w in workers]))
    finally:
        # Shut down even on failure so the Horovod workers are not leaked.
        trainer.shutdown()

    assert len(results) == num_epochs
    first_epoch, last_epoch = results[0], results[-1]
    for i in range(num_workers):
        assert last_epoch[i] < first_epoch[i]