Example #1
0
def test_max_failures(ray_start_2_cpus):
    """Check that a training function which always exits ends in a RuntimeError.

    Every worker calls ``sys.exit`` immediately. Forcing the final results is
    expected to raise ``RuntimeError``, and the executor is expected to have
    recorded exactly three failures by then.
    """
    def train():
        # Imported inside the function body, presumably so the import happens
        # wherever the function is executed (worker side) — confirm.
        import sys
        sys.exit(0)

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    run_iter = trainer.run_iterator(train)

    with pytest.raises(RuntimeError):
        run_iter.get_final_results(force=True)

    # NOTE(review): 3 presumably reflects the executor's retry limit — verify
    # against the Trainer/executor configuration.
    assert run_iter._executor._num_failures == 3
Example #2
0
def test_run_iterator_error(ray_start_2_cpus):
    """Check that an exception raised by the training function surfaces on ``next``.

    After the error propagates, the iterator is expected to report itself as
    finished and to have no final results.
    """
    def fail_train():
        raise NotImplementedError

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    result_iter = trainer.run_iterator(fail_train)

    # The worker-side exception type is expected to be re-raised here.
    with pytest.raises(NotImplementedError):
        next(result_iter)

    assert result_iter.get_final_results() is None
    assert result_iter.is_finished()
Example #3
0
def test_run_iterator_returns(ray_start_2_cpus):
    """Check final-result access on an iterator whose steps were never consumed.

    Without ``force`` the final results are expected to be ``None``. With
    ``force=True`` they are expected to be each worker's return value. After
    that, the iterator is expected to be exhausted.
    """
    def train_func():
        for step in range(3):
            sgd.report(index=step)
        return 1

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    result_iter = trainer.run_iterator(train_func)

    # No reported step has been consumed yet, so nothing final is available.
    assert result_iter.get_final_results() is None
    # One return value per worker (num_workers=2).
    assert result_iter.get_final_results(force=True) == [1, 1]

    with pytest.raises(StopIteration):
        next(result_iter)
Example #4
0
def main(num_workers=1, use_gpu=False):
    """Run the Fashion-MNIST torch example and log the rank-0 loss to MLflow.

    Args:
        num_workers: Number of training workers; forwarded to ``Trainer``.
        use_gpu: Forwarded to ``Trainer`` as its ``use_gpu`` option.
    """
    mlflow.set_experiment("sgd_torch_fashion_mnist")

    trainer = Trainer(backend="torch",
                      num_workers=num_workers,
                      use_gpu=use_gpu)
    trainer.start()
    iterator = trainer.run_iterator(train_func=train_func,
                                    config={
                                        "lr": 1e-3,
                                        "batch_size": 64,
                                        "epochs": 4
                                    })

    # Each item is indexed per worker; only the first (rank 0) worker's loss
    # is logged. An explicit step keeps successive values distinct in MLflow,
    # which otherwise records every call at step 0.
    for step, intermediate_result in enumerate(iterator):
        first_worker_result = intermediate_result[0]
        mlflow.log_metric("loss", first_worker_result["loss"], step=step)

    # Final results hold one entry per worker; index 0 matches the
    # "rank 0 worker" label and the per-step logging above.
    print("Full losses for rank 0 worker: ", iterator.get_final_results()[0])
Example #5
0
def test_run_iterator(ray_start_2_cpus):
    """Check that iterating yields each reported step, then the final results.

    ``train_func`` reports ``index`` 0, 1, 2 and returns 1. Every iteration
    step must carry the current index in all of its per-worker results. After
    three steps the iterator must be finished, expose ``[1, 1]`` (one return
    value per worker) as its final results, and raise ``StopIteration``.
    """
    config = TestConfig()

    def train_func():
        for i in range(3):
            sgd.report(index=i)
        return 1

    trainer = Trainer(config, num_workers=2)
    trainer.start()
    iterator = trainer.run_iterator(train_func)

    count = 0
    for results in iterator:
        # all(...) is required: asserting a bare generator expression is
        # always truthy and would never check the reported indices.
        assert all(value["index"] == count for value in results)
        count += 1

    assert count == 3
    assert iterator.is_finished()
    assert iterator.get_final_results() == [1, 1]

    with pytest.raises(StopIteration):
        next(iterator)