Exemplo n.º 1
0
def test_mismatch_checkpoint_report(ray_start_2_cpus):
    """Check that mixing checkpoint and report calls across workers fails.

    The rank-0 worker saves a checkpoint while every other worker reports
    metrics; fetching the next results from the executor is then expected
    to raise ``RuntimeError``.
    """

    def mismatched_worker_fn():
        # Non-zero ranks report; only rank 0 falls through to checkpointing.
        if train.world_rank() != 0:
            train.report(iter=0)
            return
        train.save_checkpoint(epoch=0)

    executor = BackendExecutor(TestConfig(), num_workers=2)
    executor.start()
    executor.start_training(mismatched_worker_fn)

    with pytest.raises(RuntimeError):
        executor.get_next_results()
Exemplo n.º 2
0
def test_train_failure(ray_start_2_cpus):
    """Check that out-of-order executor calls raise ``TrainBackendError``.

    Before any training function has been started, ``get_next_results``,
    ``pause_reporting`` and ``finish_training`` must each raise. Starting a
    second training function while one is already started must raise as
    well. Finally, ``finish_training`` is expected to return ``[1, 1]`` for
    the two workers (presumably one return value per worker).
    """
    executor = BackendExecutor(TestConfig(), num_workers=2)
    executor.start()

    # None of these are valid until start_training() has been called.
    premature_calls = (
        executor.get_next_results,
        executor.pause_reporting,
        executor.finish_training,
    )
    for premature_call in premature_calls:
        with pytest.raises(TrainBackendError):
            premature_call()

    executor.start_training(lambda: 1)

    # A second start while the first run has not been finished is rejected.
    with pytest.raises(TrainBackendError):
        executor.start_training(lambda: 2)

    final_outputs = executor.finish_training()
    assert final_outputs == [1, 1]