def test_mismatch_checkpoint_report(ray_start_2_cpus):
    """Mixing checkpoint and report calls across workers is rejected.

    With two workers, rank 0 calls ``train.save_checkpoint`` while every
    other rank calls ``train.report``. Fetching the next results from the
    executor is expected to raise ``RuntimeError``.
    """

    def train_func():
        rank = train.world_rank()
        if rank != 0:
            train.report(iter=0)
        else:
            # Only rank 0 checkpoints, so the workers disagree on the
            # kind of result produced in this step.
            train.save_checkpoint(epoch=0)

    executor = BackendExecutor(TestConfig(), num_workers=2)
    executor.start()
    executor.start_training(train_func)
    with pytest.raises(RuntimeError):
        executor.get_next_results()
def test_train_failure(ray_start_2_cpus):
    """The executor rejects out-of-order lifecycle calls.

    Before any training function has been started, ``get_next_results``,
    ``pause_reporting`` and ``finish_training`` must each raise
    ``TrainBackendError``. After ``start_training`` has been called once, a
    second ``start_training`` call must also raise ``TrainBackendError``.
    ``finish_training`` then returns the first function's result from each
    of the two workers.
    """
    executor = BackendExecutor(TestConfig(), num_workers=2)
    executor.start()

    # None of these calls are valid before start_training; check each in turn.
    premature_calls = (
        executor.get_next_results,
        executor.pause_reporting,
        executor.finish_training,
    )
    for call in premature_calls:
        with pytest.raises(TrainBackendError):
            call()

    executor.start_training(lambda: 1)
    # A second training function cannot be launched before the first is finished.
    with pytest.raises(TrainBackendError):
        executor.start_training(lambda: 2)

    # One result per worker, all from the first (accepted) training function.
    assert executor.finish_training() == [1, 1]