Exemplo n.º 1
0
def test_to_worker_group(ray_start_2_cpus):
    """Verify that ``Trainer.to_worker_group`` builds one actor per worker.

    Each actor is constructed with the same keyword argument
    (``starting=2``). After that, each actor keeps its own counter, so
    incrementing one actor must not affect the other.
    """
    trainer = Trainer(TestConfig(), num_workers=2)

    class Incrementer:
        """Minimal stateful class: a counter that can be read and bumped."""

        def __init__(self, starting=0):
            self.count = starting

        def increment(self):
            self.count += 1

        def get_count(self):
            return self.count

    def current_counts():
        # Read every actor's counter, in worker order.
        return ray.get([actor.get_count.remote() for actor in group])

    group = trainer.to_worker_group(Incrementer, starting=2)
    assert current_counts() == [2, 2]

    # Bump every actor once; all counters advance together.
    ray.get([actor.increment.remote() for actor in group])
    assert current_counts() == [3, 3]

    # Bump one actor at a time; only the targeted counter should move.
    for index, expected in ((0, [4, 3]), (1, [4, 4])):
        ray.get(group[index].increment.remote())
        assert current_counts() == expected
Exemplo n.º 2
0
def test_horovod_torch_mnist_stateful(ray_start_4_cpus):
    """Drive a stateful Horovod training class through ``to_worker_group``.

    Two workers each run ``train(epoch=...)`` once per epoch. For every
    worker, the value returned in the last epoch must be strictly smaller
    than the value returned in the first epoch. Presumably ``train``
    returns a loss; confirm against ``HorovodTrainClass``.

    The trainer is always shut down, even when creating the worker group
    or a remote ``train`` call raises, so the worker actors are not leaked.
    """
    num_workers = 2
    num_epochs = 2
    trainer = Trainer("horovod", num_workers)
    try:
        workers = trainer.to_worker_group(
            HorovodTrainClass,
            config={
                "num_epochs": num_epochs,
                "lr": 1e-3,
            },
        )
        # results[epoch][worker] is what that worker's train() returned
        # for that epoch. Epochs run sequentially; workers run in parallel.
        results = [
            ray.get([w.train.remote(epoch=epoch) for w in workers])
            for epoch in range(num_epochs)
        ]
    finally:
        # Release the workers even if setup or training failed above.
        trainer.shutdown()

    assert len(results) == num_epochs
    for i in range(num_workers):
        assert results[num_epochs - 1][i] < results[0][i]