コード例 #1
0
def test_max_failures(ray_start_2_cpus):
    """Check that workers which keep exiting end the run with an error.

    The training function terminates its process via ``SystemExit(0)`` as
    soon as it is invoked. Forcing the final results is then expected to
    raise ``RuntimeError``, and the executor's failure counter must read 3.
    """

    def exit_immediately():
        # Equivalent to ``sys.exit(0)``, which raises ``SystemExit(0)``.
        raise SystemExit(0)

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    result_iter = trainer.run_iterator(exit_immediately)

    with pytest.raises(RuntimeError):
        result_iter.get_final_results(force=True)

    # NOTE(review): 3 presumably corresponds to the executor's failure
    # limit, which is configured outside this test -- confirm.
    failure_count = result_iter._executor._num_failures
    assert failure_count == 3
コード例 #2
0
def test_no_exhaust(ray_start_2_cpus, tmp_path):
    """Training can finish even if the reported results are never consumed.

    Each worker reports twice and returns 2. The iterator is never stepped;
    the final results are fetched directly with ``force=True`` and must be
    ``[2, 2]`` (one return value per worker).
    """

    def train_func():
        # Two reports that nothing reads before the forced final fetch.
        train.report(loss=1)
        train.report(loss=1)
        return 2

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()

    result_iter = trainer.run_iterator(train_func)
    final_results = result_iter.get_final_results(force=True)

    assert final_results == [2, 2]
コード例 #3
0
def test_run_iterator_error(ray_start_2_cpus):
    """An error raised inside the training function surfaces on ``next``.

    After the ``NotImplementedError`` has propagated, the iterator must
    report no final results (``None``) and must be marked as finished.
    """

    def fail_train():
        raise NotImplementedError

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    run_iter = trainer.run_iterator(fail_train)

    # The first step is expected to re-raise the training function's error.
    with pytest.raises(NotImplementedError):
        next(run_iter)

    final_results = run_iter.get_final_results()
    assert final_results is None
    assert run_iter.is_finished()
コード例 #4
0
def test_run_iterator_returns(ray_start_2_cpus):
    """Final results are only available when forced or once exhausted.

    Without ``force`` the iterator returns ``None`` because nothing has been
    consumed yet; with ``force=True`` it returns each worker's return value.
    Stepping the iterator afterwards must raise ``StopIteration``.
    """

    def train_func():
        for step in range(3):
            train.report(index=step)
        return 1

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    run_iter = trainer.run_iterator(train_func)

    # Nothing consumed so far: the unforced fetch yields no results.
    unforced = run_iter.get_final_results()
    assert unforced is None

    forced = run_iter.get_final_results(force=True)
    assert forced == [1, 1]

    with pytest.raises(StopIteration):
        next(run_iter)
コード例 #5
0
def main(num_workers=1, use_gpu=False):
    """Run the Fashion-MNIST training and log each reported loss to MLflow.

    Args:
        num_workers: Number of training workers passed to ``Trainer``.
        use_gpu: Passed to ``Trainer`` as ``use_gpu``.

    The trainer is always shut down before returning, including when the
    training run or the MLflow logging raises, so that the workers started
    by ``trainer.start()`` are not left running.
    """
    mlflow.set_experiment("train_torch_fashion_mnist")

    trainer = Trainer(
        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
    trainer.start()
    try:
        iterator = trainer.run_iterator(
            train_func=train_func,
            config={
                "lr": 1e-3,
                "batch_size": 64,
                "epochs": 4
            })

        for intermediate_result in iterator:
            # Each intermediate result is indexed per worker; only the first
            # worker's reported loss is logged.
            first_worker_result = intermediate_result[0]
            mlflow.log_metric("loss", first_worker_result["loss"])

        final_results = iterator.get_final_results()
        print("Full losses for rank 0 worker: ", final_results)
    finally:
        # Release the workers even if training or logging failed.
        trainer.shutdown()
コード例 #6
0
def test_run_iterator(ray_start_2_cpus):
    """Iterating yields one batch of per-worker reports per ``report`` call.

    Every worker reports ``index`` 0, 1, 2 and returns 1. Each iteration step
    must carry the matching index from all workers; after three steps the
    iterator is finished, exposes ``[1, 1]`` as final results, and raises
    ``StopIteration`` when stepped again.
    """

    def train_func():
        for step in range(3):
            train.report(index=step)
        return 1

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    run_iter = trainer.run_iterator(train_func)

    completed = 0
    for completed, worker_reports in enumerate(run_iter, start=1):
        # Steps are counted from 1, reported indices from 0.
        expected_index = completed - 1
        for report in worker_reports:
            assert report["index"] == expected_index

    assert completed == 3
    assert run_iter.is_finished()
    assert run_iter.get_final_results() == [1, 1]

    with pytest.raises(StopIteration):
        next(run_iter)