def test_max_failures(ray_start_2_cpus):
    """Exiting workers must surface a RuntimeError and be counted as failures.

    The training function exits immediately; forcing the final results is
    expected to raise, after which the executor's private failure counter
    should read 3.
    """

    def train_func():
        # Exit from inside the training function to simulate a dying worker.
        import sys
        sys.exit(0)

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    result_iter = trainer.run_iterator(train_func)

    with pytest.raises(RuntimeError):
        result_iter.get_final_results(force=True)

    # Checked after the raise: the counter lives on the iterator's executor.
    executor = result_iter._executor
    assert executor._num_failures == 3
def test_no_exhaust(ray_start_2_cpus, tmp_path):
    """Check that training completes even if reported results go unread.

    The training function reports twice and returns 2; the intermediate
    reports are never iterated over, and the forced final results must
    still contain each worker's return value.
    """

    def train_func():
        for _ in range(2):
            train.report(loss=1)
        return 2

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    result_iter = trainer.run_iterator(train_func)

    # force=True asks for the final results without consuming the reports.
    final_results = result_iter.get_final_results(force=True)
    assert final_results == [2, 2]
def test_run_iterator_error(ray_start_2_cpus):
    """An error raised by the training function propagates via the iterator.

    Advancing the iterator must re-raise the NotImplementedError; afterwards
    the iterator has no final results and reports itself as finished.
    """

    def fail_train():
        raise NotImplementedError

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    failing_iter = trainer.run_iterator(fail_train)

    with pytest.raises(NotImplementedError):
        next(failing_iter)

    # Post-failure state: nothing to return, and iteration is over.
    final_results = failing_iter.get_final_results()
    assert final_results is None
    assert failing_iter.is_finished()
def test_run_iterator_returns(ray_start_2_cpus):
    """Check final-results access before and after forcing completion.

    Before anything is consumed, get_final_results() yields None. With
    force=True it yields one return value per worker, and the iterator is
    then exhausted.
    """

    def train_func():
        for i in range(3):
            train.report(index=i)
        return 1

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    result_iter = trainer.run_iterator(train_func)

    # Nothing has been consumed yet, so no final results are available.
    unforced = result_iter.get_final_results()
    assert unforced is None

    forced = result_iter.get_final_results(force=True)
    assert forced == [1, 1]

    # Forcing the final results leaves nothing more to iterate.
    with pytest.raises(StopIteration):
        next(result_iter)
def main(num_workers=1, use_gpu=False):
    """Run training through a Trainer iterator and log losses to MLflow.

    Each intermediate result yielded by the iterator is indexed at 0 (the
    first worker's report) and its ``"loss"`` entry is logged as an MLflow
    metric. The final results are printed once iteration finishes.

    Args:
        num_workers: Forwarded to ``Trainer`` as ``num_workers``.
        use_gpu: Forwarded to ``Trainer`` as ``use_gpu``.
    """
    mlflow.set_experiment("train_torch_fashion_mnist")

    trainer = Trainer(
        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
    trainer.start()
    try:
        iterator = trainer.run_iterator(
            train_func=train_func,
            config={
                "lr": 1e-3,
                "batch_size": 64,
                "epochs": 4
            })

        for intermediate_result in iterator:
            first_worker_result = intermediate_result[0]
            mlflow.log_metric("loss", first_worker_result["loss"])

        print("Full losses for rank 0 worker: ",
              iterator.get_final_results())
    finally:
        # Always release the training workers, including when training or
        # metric logging raises, so repeated calls do not leak resources.
        trainer.shutdown()
def test_run_iterator(ray_start_2_cpus):
    """Iterating yields one batch of worker reports per train.report call.

    Every worker's report in the k-th batch must carry index k. After three
    batches the iterator is finished, exposes each worker's return value,
    and raises StopIteration on further advancement.
    """

    def train_func():
        for i in range(3):
            train.report(index=i)
        return 1

    trainer = Trainer(TestConfig(), num_workers=2)
    trainer.start()
    result_iter = trainer.run_iterator(train_func)

    rounds_seen = 0
    for round_index, worker_reports in enumerate(result_iter):
        assert all(
            report["index"] == round_index for report in worker_reports)
        rounds_seen = round_index + 1

    assert rounds_seen == 3
    assert result_iter.is_finished()
    assert result_iter.get_final_results() == [1, 1]

    with pytest.raises(StopIteration):
        next(result_iter)