def train(self): """Main training loop. Trigger remote training via ``train_remote`` on each worker. If using with Ray Tune, create a communication queue to revieve intermediate results, and process those results. Finally retrieve the training results from the rank 0 worker and return.""" trainer = self.trainer trainer_ref = ray.put(self.trainer) self.trainer = None queue = None if TUNE_INSTALLED and is_session_enabled(): # Create communication queue and send to all the workers. queue = Queue(actor_options={"num_cpus": 0}) result_futures = self.executor.run_async(self.train_remote, args=[trainer_ref, queue]) results = process_results(result_futures, queue) results, state_dict, best_path = results[0] self.trainer = trainer self.trainer.model.load_state_dict(state_dict) if self.trainer.checkpoint_callback: self.trainer.checkpoint_callback.best_model_path = best_path return results
def start_training(self, trainer): """Main training loop. Sets up the torch.distributed process group for each training worker. Then trigger remote training via ``train_remote`` on each worker. If using with Ray Tune, create a communication queue to revieve intermediate results, and process those results. Finally retrieve the training results from the rank 0 worker and return.""" # Get rank 0 worker address and port for DDP connection. os.environ["MASTER_ADDR"] = ray.get( self.workers[0].get_node_ip.remote()) os.environ["MASTER_PORT"] = str( ray.get(self.workers[0].execute.remote(find_free_port))) # Set environment variables for remote workers. keys = [ "PL_GLOBAL_SEED", "PL_TORCH_DISTRIBUTED_BACKEND", "MASTER_ADDR", "MASTER_PORT" ] values = [os.getenv(k) for k in keys] ray.get([w.set_env_vars.remote(keys, values) for w in self.workers]) self.global_to_local = self.get_local_ranks() model = self._model model_ref = ray.put(model) # Don't pickle the model when training remotely. self._model = None queue = None if TUNE_INSTALLED and is_session_enabled(): # Create communication queue and send to all the workers. queue = Queue(actor_options={"num_cpus": 0}) futures = [ self.workers[i].execute.remote(self.train_remote, model_ref, i, queue) for i in range(self.num_workers) ] results = process_results(futures, queue) # Get the results, checkpoint path, and model weights from worker 0. results, best_path, state_dict = results[0] # Set the state for PTL using the output from remote training. self._results = results self._model = model self._model.load_state_dict(state_dict) if self.lightning_module.trainer.checkpoint_callback: self.lightning_module.trainer.checkpoint_callback\ .best_model_path = best_path if queue: # Shutdown the queue. queue.shutdown() return results
def train(self): """Main training loop. Sets up the torch.distributed process group for each training worker. Then trigger remote training via ``train_remote`` on each worker. If using with Ray Tune, create a communication queue to revieve intermediate results, and process those results. Finally retrieve the training results from the rank 0 worker and return.""" if "PL_GLOBAL_SEED" in os.environ: seed = os.environ["PL_GLOBAL_SEED"] ray.get([ w.set_env_var.remote("PL_GLOBAL_SEED", seed) for w in self.workers ]) # Get the rank 0 address for DDP connection. self.ddp_address = ray.get( self.workers[0].execute.remote(setup_address)) self.global_to_local = self.get_local_ranks() trainer = self.trainer assert trainer is not None trainer_ref = ray.put(trainer) # Don't pickle self.trainer when training remotely. self.trainer = None queue = None if TUNE_INSTALLED and is_session_enabled(): # Create communication queue and send to all the workers. queue = Queue(actor_options={"num_cpus": 0}) futures = [ self.workers[i].execute.remote(self.train_remote, trainer_ref, i, queue) for i in range(self.num_workers) ] results = process_results(futures, queue) results, best_path, state_dict = results[0] self.trainer = trainer self.trainer.model.load_state_dict(state_dict) if self.trainer.checkpoint_callback: self.trainer.checkpoint_callback.best_model_path = best_path if queue: # Shutdown the queue. queue.shutdown() return results
def execution_loop(self, trainer, tune_enabled: bool = True):
    """Main execution loop for training, testing, & prediction.

    Sets up the torch.distributed process group for each worker. Then
    trigger remote training/testing/eval via ``execute_remote`` on each
    worker. If using with Ray Tune, create a communication queue to
    retrieve intermediate results, and process those results. Finally
    retrieve the training results from the rank 0 worker and return.
    """
    # Sets environment variables for all workers.
    self._setup_env_vars()

    self.global_to_local = self.get_local_ranks()

    model = self._model
    model_ref = ray.put(model)
    # Don't pickle the model when training remotely.
    self._model = None

    queue = None
    if tune_enabled and TUNE_INSTALLED and is_session_enabled():
        # Create communication queue and send to all the workers.
        queue = Queue(actor_options={"num_cpus": 0})

    futures = [
        self.workers[i].execute.remote(
            self.execute_remote, model_ref, i, queue)
        for i in range(self.num_workers)
    ]

    results = process_results(futures, queue)

    # Get the results, checkpoint path, and model weights from worker 0.
    results, best_path, state_stream = results[0]
    state_dict = load_state_stream(state_stream, to_gpu=self.use_gpu)

    # Set the state for PTL using the output from remote training.
    self._results = results
    self._model = model
    self._model.load_state_dict(state_dict)
    if self.lightning_module.trainer.checkpoint_callback:
        self.lightning_module.trainer.checkpoint_callback \
            .best_model_path = best_path

    if queue:
        # Shutdown the queue.
        queue.shutdown()

    return results
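
# ``load_state_stream`` converts the byte stream shipped back from worker 0
# into a state dict on the right device. A minimal sketch, assuming the stream
# is a ``torch.save``-serialized buffer; the device-mapping detail is an
# illustrative guess, not the project's exact logic.
import io

import torch


def load_state_stream(state_stream: bytes, to_gpu: bool) -> dict:
    """Deserialize a checkpoint byte stream, optionally mapping tensors to GPU."""
    buffer = io.BytesIO(state_stream)
    map_location = "cuda" if to_gpu and torch.cuda.is_available() else "cpu"
    return torch.load(buffer, map_location=map_location)
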
def start_training(self, trainer): """Main training loop. Trigger remote training via ``train_remote`` on each worker. If using with Ray Tune, create a communication queue to retrieve intermediate results, and process those results. Finally retrieve the training results from the rank 0 worker and return.""" model = self._model model_ref = ray.put(model) # Don't pickle the model when training remotely. self._model = None queue = None if TUNE_INSTALLED and is_session_enabled(): # Create communication queue and send to all the workers. queue = Queue(actor_options={"num_cpus": 0}) result_futures = self.executor.run_remote( self.train_remote, args=[model_ref, queue]) results = process_results(result_futures, queue) results, state_dict, best_path = results[0] self._results = results self._model = model self._model.load_state_dict(state_dict) self._model.trainer.accelerator.training_type_plugin = self if self.lightning_module.trainer.checkpoint_callback: self.lightning_module.trainer.checkpoint_callback \ .best_model_path = best_path if queue: # Shutdown the queue. queue.shutdown() return results