def __init__(
    self,
    *,
    train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
    train_loop_config: Optional[Dict] = None,
    tensorflow_config: Optional[TensorflowConfig] = None,
    scaling_config: Optional[ScalingConfig] = None,
    run_config: Optional[RunConfig] = None,
    datasets: Optional[Dict[str, GenDataset]] = None,
    preprocessor: Optional[Preprocessor] = None,
    resume_from_checkpoint: Optional[Checkpoint] = None,
):
    """Create a trainer whose distributed backend is TensorFlow.

    Every argument is handed unchanged to the parent trainer's constructor,
    except ``tensorflow_config``, which is passed on as ``backend_config``.

    Args:
        train_loop_per_worker: Training function; per its type it takes either
            no arguments or a single dict.
        train_loop_config: Optional dict forwarded to the parent trainer
            (presumably handed to ``train_loop_per_worker`` — confirm).
        tensorflow_config: Backend configuration. When omitted (or falsy), a
            default ``TensorflowConfig()`` is used.
        scaling_config: Forwarded to the parent trainer.
        run_config: Forwarded to the parent trainer.
        datasets: Mapping of dataset name to dataset, forwarded to the parent.
        preprocessor: Forwarded to the parent trainer.
        resume_from_checkpoint: Forwarded to the parent trainer.
    """
    # Fall back to a default-constructed backend config when none was given.
    backend_config = tensorflow_config or TensorflowConfig()

    super(TensorflowTrainer, self).__init__(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        backend_config=backend_config,
        scaling_config=scaling_config,
        run_config=run_config,
        datasets=datasets,
        preprocessor=preprocessor,
        resume_from_checkpoint=resume_from_checkpoint,
    )
def test_worker_kill(ray_start_2_cpus, backend):
    """Trainer keeps working after a worker is killed during training.

    A ``KillCallback`` kills a worker partway through ``trainer.run``. As the
    run-by-run comments below describe, the failed attempt restarts training
    from the beginning, and the callback's ``counter`` reflects the reports
    seen. The scenario is repeated with a different ``fail_on`` value. A final
    trivial run checks that the trainer is still usable afterwards.

    Args:
        ray_start_2_cpus: Fixture, presumably starting a 2-CPU Ray instance.
        backend: One of ``"test"``, ``"torch"``, ``"tf"``, ``"horovod"``.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """
    if backend == "test":
        test_config = TestConfig()
    elif backend == "torch":
        test_config = TorchConfig()
    elif backend == "tf":
        test_config = TensorflowConfig()
    elif backend == "horovod":
        test_config = HorovodConfig()
    else:
        # Without this branch an unknown backend would surface later as an
        # UnboundLocalError on `test_config`, hiding the real cause.
        raise ValueError(f"Unsupported backend: {backend!r}")

    trainer = Trainer(test_config, num_workers=2)

    def train_func():
        for i in range(2):
            train.report(loss=1, iter=i)

    trainer.start()
    kill_callback = KillCallback(fail_on=0, trainer=trainer)
    trainer.run(train_func, callbacks=[kill_callback])
    # Run 1: iter=0, counter=1, Successful
    # Run 2: iter=1, counter=1, Unsuccessful, starts training from beginning
    # Run 3: iter=0, counter=2, Successful
    # Run 4: iter=1, counter=3, Successful
    assert kill_callback.counter == 3

    trainer.shutdown()
    trainer.start()

    kill_callback = KillCallback(fail_on=1, trainer=trainer)
    trainer.run(train_func, callbacks=[kill_callback])
    # Run 1: iter=0, counter=1, Successful
    # Run 2: iter=1, counter=2, Successful
    # Run 3: None, counter=2, Unsuccessful, starts training from beginning.
    # Run 4: iter=0, counter=3, Successful
    # Run 5: iter=1, counter=4, Successful
    assert kill_callback.counter == 4

    # A distinct name avoids redefining (shadowing) `train_func` above.
    def trivial_train_func():
        return 1

    # Make sure Trainer is usable even after failure handling.
    trainer.run(trivial_train_func)

    # Release the workers started by the second `trainer.start()`.
    trainer.shutdown()
def test_tensorflow_start(ray_start_2_cpus):
    """Every worker gets a TF_CONFIG with the same cluster and a distinct index.

    Starts a ``BackendExecutor`` with a default ``TensorflowConfig``. Each
    worker reads back its ``TF_CONFIG`` environment variable. The test then
    checks that all workers report the same ``cluster.worker`` list and that
    their ``task.index`` values are pairwise distinct.
    """
    worker_count = 2
    executor = BackendExecutor(TensorflowConfig(), num_workers=worker_count)
    executor.start()

    def read_tf_config():
        # Imports are local so the function is self-contained when shipped to
        # the workers (presumably serialized and run remotely — confirm).
        import json
        import os

        return json.loads(os.environ["TF_CONFIG"])

    executor.start_training(read_tf_config)
    results = executor.finish_training()
    assert len(results) == worker_count

    # All workers must agree on the cluster membership list.
    cluster_workers = [res["cluster"]["worker"] for res in results]
    reference = cluster_workers[0]
    assert all(entry == reference for entry in cluster_workers)

    # Each worker must have its own task index.
    task_indexes = {res["task"]["index"] for res in results}
    assert len(task_indexes) == worker_count