class DDPPlugin(ParallelPlugin): """ Plugin for multi-process single-device training on one or multiple nodes. The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes. """ distributed_backend = "ddp" def __init__( self, parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: ClusterEnvironment = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) self.interactive_ddp_procs = [] if num_nodes is not None: rank_zero_deprecation( "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6." " Notice that it will be overridden by the trainer setting.") self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_deprecation( "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6." " Notice that it will be overridden by the trainer setting.") self._sync_batchnorm = sync_batchnorm or False self.dist = LightningDistributed() self.num_processes = len( self.parallel_devices) if self.parallel_devices is not None else 0 self._ddp_kwargs = kwargs self._has_spawned_children = False self._task_idx = None self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper self._pids: Optional[List[int]] = None self._sync_dir: Optional[str] = None self.set_world_ranks() @property def is_distributed(self) -> bool: return True @property def root_device(self) -> torch.device: return self.parallel_devices[self.local_rank] @property def num_nodes(self) -> int: return self._num_nodes @num_nodes.setter def num_nodes(self, num_nodes: int) -> None: # note that the world ranks depend on num_nodes; when resetting num_nodes, the world ranks need to be reset as well self._num_nodes = num_nodes self.set_world_ranks() @property def sync_batchnorm(self) -> bool: return self._sync_batchnorm @sync_batchnorm.setter def sync_batchnorm(self, sync_batchnorm: bool) -> None: self._sync_batchnorm = sync_batchnorm @property def task_idx(self) -> Optional[int]: rank_zero_deprecation( f'`{self.__class__.__name__}.task_idx` is deprecated in v1.4 and will be removed in v1.6. 
Use ' f'`{self.__class__.__name__}.local_rank` instead.') return self._task_idx @task_idx.setter def task_idx(self, task_idx: int) -> None: self._task_idx = task_idx @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs @property def _is_single_process_single_device(self) -> bool: return True def setup_environment(self) -> None: # start the other scripts if not self.cluster_environment.creates_children() and os.environ.get( "PL_IN_DDP_SUBPROCESS", "0") != "1": self._call_children_scripts() # set the task idx self.task_idx = self.cluster_environment.local_rank() self.setup_distributed() def _call_children_scripts(self): # bookkeeping of spawned processes assert self.local_rank == 0 self._check_can_spawn_children() self._has_spawned_children = True # DDP Environment variables os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) # allow the user to pass the node rank os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) # create a temporary directory used to synchronize processes on deadlock. os.environ["PL_DDP_SYNC_TMPDIR"] = self._sync_dir = tempfile.mkdtemp() # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` # See https://docs.python.org/3/reference/import.html#main-spec if __main__.__spec__ is None: # pragma: no-cover # Script called as `python a/b/c.py` # when user is using hydra find the absolute path path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path # pull out the commands used to run the script and resolve the abs file path command = sys.argv try: full_path = path_lib(command[0]) except Exception: full_path = os.path.abspath(command[0]) command[0] = full_path # use the same python interpreter and actually running command = [sys.executable] + command else: # Script called as `python -m a.b.c` command = [sys.executable, "-m", __main__.__spec__.name ] + sys.argv[1:] # the visible devices tell us how many GPUs we want to use. # when the trainer script was called the device has already been scoped by the time # code reaches this point. 
so, to call the scripts, we need to leave cuda visible devices alone # but forward the GPUs selected via environment variables if self.parallel_devices is None: raise MisconfigurationException( "you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)" ) os.environ["PL_IN_DDP_SUBPROCESS"] = "1" os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}" self.interactive_ddp_procs = [] for local_rank in range(1, self.num_processes): env_copy = os.environ.copy() env_copy["LOCAL_RANK"] = f"{local_rank}" if self.lightning_module.logger is not None: # spawned processes must reference the same log dir, prevent auto-increment version env_copy["PL_EXP_VERSION"] = str( self.lightning_module.logger.version) # remove env var if global seed not set if os.environ.get( "PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: del env_copy["PL_GLOBAL_SEED"] # start process # if hydra is available and initialized, make sure to set the cwd correctly cwd: Optional[str] = None if _HYDRA_AVAILABLE: if HydraConfig.initialized(): cwd = get_original_cwd() os_cwd = f'"{os.getcwd()}"' command += [ f'hydra.run.dir={os_cwd}', f'hydra.job.name=train_ddp_process_{local_rank}' ] proc = subprocess.Popen(command, env=env_copy, cwd=cwd) self.interactive_ddp_procs.append(proc) # starting all processes at once can cause issues # with dataloaders delay between 1-10 seconds delay = np.random.uniform(1, 5, 1)[0] sleep(delay) def setup_distributed(self): reset_seed() # determine which process we are and world size self.set_world_ranks() # set warning rank rank_zero_only.rank = self.global_rank # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table self.init_ddp_connection() # set the ranks and devices self.dist.rank = self.global_rank self.dist.device = self.root_device def _check_can_spawn_children(self): if self._has_spawned_children: raise RuntimeError( "You tried to run `.fit` or `.test` multiple times in the same script." " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." ) def set_world_ranks(self) -> None: if self.cluster_environment is None: return self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) rank_zero_only.rank = self.cluster_environment.global_rank() def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. 
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( "find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False): rank_zero_warn( "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` " "to properly work with DDP.") self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device: register_ddp_comm_hook( model=self._model, ddp_comm_state=self._ddp_comm_state, ddp_comm_hook=self._ddp_comm_hook, ddp_comm_wrapper=self._ddp_comm_wrapper, ) def configure_ddp(self): self.pre_configure_ddp() self._model = DistributedDataParallel( LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) self._register_ddp_hooks() def determine_ddp_device_ids(self): if self.root_device.type == "cpu": return None return [self.root_device.index] def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None: global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank( ) world_size = world_size if world_size is not None else self.cluster_environment.world_size( ) os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) if torch.distributed.is_available( ) and not torch.distributed.is_initialized(): log.info( f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}" ) torch.distributed.init_process_group( self.torch_distributed_backend, rank=global_rank, world_size=world_size) # on rank=0 let everyone know training is starting rank_zero_info( f"{'-' * 100}\n" f"distributed_backend={self.torch_distributed_backend}\n" f"All DDP processes registered. Starting ddp with {self.world_size} processes\n" f"{'-' * 100}\n") def pre_dispatch(self): # move the model to the correct device self.model_to_device() if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) self.configure_ddp() # share ddp pids to all processes self._share_information_to_prevent_deadlock() def post_dispatch(self) -> None: self.cluster_environment.teardown() def barrier(self, *args, **kwargs) -> None: if not distributed_available(): return if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend( ) == "nccl": torch.distributed.barrier( device_ids=self.determine_ddp_device_ids()) else: torch.distributed.barrier() def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) def model_to_device(self): self.model.to(self.root_device) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: """ Reduces a tensor from several distributed processes to one aggregated tensor. 
Args: tensor: the tensor to sync and reduce group: the process group to gather results from. Defaults to all processes (world) reduce_op: the reduction operation. Defaults to 'mean'/'avg'. Can also be a string 'sum' to calculate the sum during reduction. Return: the reduced value, except when the input was not a tensor, in which case the output remains unchanged """ if isinstance(tensor, torch.Tensor): tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) def validation_step(self, *args, **kwargs): return self.model(*args, **kwargs) def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True @classmethod def register_plugins(cls, plugin_registry: Dict) -> None: plugin_registry.register( "ddp_find_unused_parameters_false", cls, description="DDP Plugin with `find_unused_parameters` as False", find_unused_parameters=False) def _share_information_to_prevent_deadlock(self): self._share_pids() # remove `PL_DDP_SYNC_TMPDIR` from os.environ self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None) def _share_pids(self): """ Make all DDP processes aware of all processes' PIDs. """ self.barrier() pids = self.all_gather( torch.tensor(os.getpid(), device=self.root_device)) pids = pids.cpu().numpy().tolist() self._pids = pids if isinstance(pids, list) else [pids] def reconciliate_processes(self, trace: str): if self.world_size < 2: return sync_dir = self._sync_dir # save a file locally. torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl")) # sleep for a short time time.sleep(3) # return if all processes wrote a file in the `sync_dir`. # todo (tchaton) Add support for non-shared file-systems, where this check will fail. if len(os.listdir(sync_dir)) == self.world_size: return for pid in self._pids: if pid != os.getpid(): os.kill(pid, signal.SIGKILL) shutil.rmtree(sync_dir) raise DeadlockDetectedException( f"Deadlock detected from rank: {self.global_rank} \n {trace}")
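# --- Illustration (not part of the plugin above) ---------------------------------------
# A minimal sketch of the launch pattern that `_call_children_scripts` implements: the
# rank-0 process re-runs the same script once per remaining local rank and gives each
# child its own LOCAL_RANK environment variable. The helper name `launch_children` is
# illustrative only and is not part of this module.
#
#     import os
#     import subprocess
#     import sys
#
#     def launch_children(num_devices: int) -> list:
#         command = [sys.executable] + sys.argv  # re-run the exact same training script
#         procs = []
#         for local_rank in range(1, num_devices):  # rank 0 is the current process
#             env = os.environ.copy()
#             env["LOCAL_RANK"] = str(local_rank)  # each child reads its rank from the env
#             procs.append(subprocess.Popen(command, env=env))
#         return procs
# ----------------------------------------------------------------------------------------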
class DDPSpawnAccelerator(Accelerator): def __init__(self, trainer, nprocs: int, cluster_environment: Optional[ClusterEnvironment] = None, ddp_plugin: Optional[DDPPlugin] = None): """ Runs training using DDP, spawning processes via mp.spawn (manual launch, not cluster launch). Example:: # default trainer = Trainer(accelerator=DDPSpawnAccelerator()) """ super().__init__(trainer, cluster_environment, ddp_plugin) self.mp_queue = None self.nprocs = nprocs self.dist = LightningDistributed() self.nickname = 'ddp' def setup(self, model): os.environ['MASTER_PORT'] = os.environ.get( 'MASTER_PORT', str(find_free_network_port())) # pass in a state queue smp = mp.get_context('spawn') self.mp_queue = smp.SimpleQueue() self.trainer.model = model def train(self): model = self.trainer.model # train in child processes mp.spawn(self.ddp_train, nprocs=self.nprocs, args=( self.mp_queue, model, )) # restore main state with best weights best_path = self.mp_queue.get() results = self.mp_queue.get() last_path = self.mp_queue.get() # recover the weights of the processes trained in the children self.__recover_child_process_weights(model, best_path, last_path) return results def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0): """ Entry point for ddp Args: process_idx: index of the spawned process mp_queue: multiprocessing queue used to return results to the main process model: the LightningModule to train """ seed = os.environ.get("PL_GLOBAL_SEED") if seed is not None: seed_everything(int(seed)) # offset the process id if requested process_idx = process_idx + proc_offset # show the progress bar only on global rank 0 if (self.trainer.node_rank != 0 or process_idx != 0 ) and self.trainer.progress_bar_callback is not None: self.trainer.progress_bar_callback.disable() # determine which process we are and world size self.set_world_ranks(process_idx) # set warning rank rank_zero_only.rank = self.trainer.global_rank # Initialize cuda device self.init_device(process_idx, is_master) # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer self.init_ddp_connection(self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks) if isinstance(self.ddp_plugin, RPCPlugin): if not self.ddp_plugin.is_main_rpc_process: self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() if self.ddp_plugin.return_after_exit_rpc_process: return else: self.ddp_plugin.on_main_rpc_connection(self.trainer) # call setup after the ddp process has connected self.trainer.call_setup_hook(model) # on global rank 0, let everyone know training is starting if self.trainer.is_global_zero and not torch.distributed.is_initialized( ): log.info('-' * 100) log.info(f'distributed_backend={self.trainer.distributed_backend}') log.info( f'All DDP processes registered. 
Starting ddp with {self.trainer.world_size} processes' ) log.info('-' * 100) # call sync_bn before .cuda(), configure_apex and configure_ddp if self.trainer.sync_batchnorm: model = self.configure_sync_batchnorm(model) # move the model to the correct device self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well self.setup_optimizers(model) self.ddp_plugin.on_after_setup_optimizers(self.trainer) # 16-bit model = self.trainer.precision_connector.connect(model) # device ids change depending on the DDP setup device_ids = self.get_device_ids() # allow user to configure ddp model = self.configure_ddp(model, device_ids) self.trainer.setup_trainer(model) # train or test results = self.train_or_test() # get original model model = self.trainer.get_model() # persist info in ddp_spawn self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) # clean up memory torch.cuda.empty_cache() def set_world_ranks(self, process_idx): self.trainer.local_rank = process_idx self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes def init_device(self, process_idx, is_master): gpu_idx = self.trainer.data_parallel_device_ids[ self.trainer.local_rank] self.trainer.root_gpu = gpu_idx torch.cuda.set_device(self.trainer.root_gpu) def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def get_device_ids(self): device_ids = [self.trainer.root_gpu] return device_ids def training_step(self, args): return self._step(args) def validation_step(self, args): return self._step(args) def test_step(self, args): return self._step(args) def _step(self, args): args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) if self.trainer.amp_backend == AMPType.NATIVE: with torch.cuda.amp.autocast(): output = self.trainer.model(*args) else: output = self.trainer.model(*args) return output def barrier(self, name: Optional[str] = None): if torch_distrib.is_initialized(): torch_distrib.barrier() def early_stopping_should_stop(self, pl_module): stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) torch_distrib.barrier() should_stop = stop == self.trainer.world_size return should_stop def broadcast(self, obj, src=0): return self.dist.broadcast(obj) def __recover_child_process_weights(self, model, best_path, last_path): # transfer back the best path to the trainer if self.trainer.checkpoint_callback: self.trainer.checkpoint_callback.best_model_path = best_path # todo, pass also best score # load last weights if last_path is not None and not self.trainer.testing: ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt) self.trainer.model = model def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): best_model_path = None if self.trainer.checkpoint_callback is not None: best_model_path = self.trainer.checkpoint_callback.best_model_path if self.trainer.global_rank == 0 and mp_queue is not None: rank_zero_warn('cleaning up ddp environment...') # todo, pass complete checkpoint as state dictionary mp_queue.put(best_model_path) mp_queue.put(results) # save the last weights last_path = None if not self.trainer.testing and best_model_path is not None and len( best_model_path) > 0: last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) atomic_save(model.state_dict(), last_path) mp_queue.put(last_path) def configure_ddp(self, 
model: LightningModule, device_ids: List[int]) -> DistributedDataParallel: self.ddp_plugin.device_ids = device_ids model = self.ddp_plugin.configure_ddp(model, device_ids) return model def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: """ Add global batchnorm for a model spread across multiple GPUs and nodes. Override to synchronize batchnorm between specific process groups instead of the whole world or use a different sync_bn like `apex`'s version. Args: model: pointer to current :class:`LightningModule`. Return: LightningModule with batchnorm layers synchronized between process groups """ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model, process_group=None) return model def sync_tensor( self, tensor: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): """ Function to gather a tensor from several distributed processes Args: tensor: tensor of shape (batch, ...) group: the process group to gather results from. Defaults to all processes (world) sync_grads: flag that allows users to synchronize gradients for all_gather op Return: A tensor of shape (world_size, batch, ...) """ return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes, rank=self.trainer.global_rank) if self.ddp_plugin is not None: distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs( distributed_sampler_kwargs) return distributed_sampler_kwargs @property def require_distributed_sampler(self): return True
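# --- Illustration (not part of the accelerator above) ----------------------------------
# A self-contained sketch of the `mp.spawn` pattern that `train()`/`ddp_train()` rely on:
# the parent launches `nprocs` workers, each worker joins the process group and reports
# back through a SimpleQueue. The names `worker` and `WORLD_SIZE` are illustrative, not
# part of this module.
#
#     import os
#     import torch.distributed as dist
#     import torch.multiprocessing as mp
#
#     WORLD_SIZE = 2
#
#     def worker(process_idx: int, mp_queue) -> None:
#         os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
#         os.environ.setdefault("MASTER_PORT", "29500")
#         dist.init_process_group("gloo", rank=process_idx, world_size=WORLD_SIZE)
#         # ... the training loop would run here ...
#         mp_queue.put(f"rank {process_idx} done")
#         dist.destroy_process_group()
#
#     if __name__ == "__main__":
#         queue = mp.get_context("spawn").SimpleQueue()
#         mp.spawn(worker, nprocs=WORLD_SIZE, args=(queue,))
#         print(queue.get(), queue.get())
# ----------------------------------------------------------------------------------------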
class DDPPlugin(ParallelPlugin): """Plugin for multi-process single-device training on one or multiple nodes. The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes. """ distributed_backend = "ddp" def __init__( self, parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, model_averaging_period: Optional[int] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: super().__init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, ) self.interactive_ddp_procs = [] if num_nodes is not None: rank_zero_deprecation( "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6." " Notice that it will be overridden by the trainer setting.") self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_deprecation( "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6." " Notice that it will be overridden by the trainer setting.") self._sync_batchnorm = sync_batchnorm or False self.dist = LightningDistributed() self.num_processes = len( self.parallel_devices) if self.parallel_devices is not None else 0 self._ddp_kwargs = kwargs self._task_idx = None self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper self._model_averaging_period = model_averaging_period self._pids: Optional[List[int]] = None self._sync_dir: Optional[str] = None self.set_world_ranks() @property def is_distributed(self) -> bool: return True @property def root_device(self) -> torch.device: return self.parallel_devices[self.local_rank] @property def num_nodes(self) -> int: return self._num_nodes @num_nodes.setter def num_nodes(self, num_nodes: int) -> None: # note that the world ranks depend on num_nodes; when resetting num_nodes, the world ranks need to be reset as well self._num_nodes = num_nodes self.set_world_ranks() @property def sync_batchnorm(self) -> bool: return self._sync_batchnorm @sync_batchnorm.setter def sync_batchnorm(self, sync_batchnorm: bool) -> None: self._sync_batchnorm = sync_batchnorm @property def task_idx(self) -> Optional[int]: rank_zero_deprecation( f"`{self.__class__.__name__}.task_idx` is deprecated in v1.4 and will be removed in v1.6. 
Use " f"`{self.__class__.__name__}.local_rank` instead.") return self._task_idx @task_idx.setter def task_idx(self, task_idx: int) -> None: self._task_idx = task_idx @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs @property def _is_single_process_single_device(self) -> bool: return True def setup_environment(self) -> None: # start the other scripts if not self.cluster_environment.creates_children(): self._call_children_scripts() # set the task idx self.task_idx = self.cluster_environment.local_rank() self.setup_distributed() def _call_children_scripts(self): # bookkeeping of spawned processes self._check_can_spawn_children() # DDP Environment variables os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) # allow the user to pass the node rank os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` # See https://docs.python.org/3/reference/import.html#main-spec if __main__.__spec__ is None: # pragma: no-cover # Script called as `python a/b/c.py` # when user is using hydra find the absolute path path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path # pull out the commands used to run the script and resolve the abs file path command = sys.argv try: full_path = path_lib(command[0]) except Exception: full_path = os.path.abspath(command[0]) command[0] = full_path # use the same python interpreter and actually running command = [sys.executable] + command else: # Script called as `python -m a.b.c` command = [sys.executable, "-m", __main__.__spec__.name ] + sys.argv[1:] # the visible devices tell us how many GPUs we want to use. # when the trainer script was called the device has already been scoped by the time # code reaches this point. 
so, to call the scripts, we need to leave cuda visible devices alone # but forward the GPUs selected via environment variables if self.parallel_devices is None: raise MisconfigurationException( "you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)" ) os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}" self.interactive_ddp_procs = [] for local_rank in range(1, self.num_processes): env_copy = os.environ.copy() env_copy["LOCAL_RANK"] = f"{local_rank}" # remove env var if global seed not set if os.environ.get( "PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: del env_copy["PL_GLOBAL_SEED"] # start process # if hydra is available and initialized, make sure to set the cwd correctly cwd: Optional[str] = None if _HYDRA_AVAILABLE: if HydraConfig.initialized(): cwd = get_original_cwd() os_cwd = f'"{os.getcwd()}"' command += [ f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}" ] proc = subprocess.Popen(command, env=env_copy, cwd=cwd) self.interactive_ddp_procs.append(proc) # starting all processes at once can cause issues # with dataloaders delay between 1-10 seconds delay = np.random.uniform(1, 5, 1)[0] sleep(delay) def setup_distributed(self): reset_seed() # determine which process we are and world size self.set_world_ranks() # set warning rank rank_zero_only.rank = self.global_rank # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table init_ddp_connection(self.cluster_environment, self.torch_distributed_backend) # set the ranks and devices self.dist.rank = self.global_rank self.dist.device = self.root_device def _check_can_spawn_children(self): if self.local_rank != 0: raise RuntimeError( "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user," " 2) `ClusterEnvironment.creates_children()` incorrectly implemented." ) def set_world_ranks(self) -> None: if self.cluster_environment is None: return self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) rank_zero_only.rank = self.cluster_environment.global_rank() def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( "find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization if (_TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get("find_unused_parameters", False)): rank_zero_warn( "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` " "to properly work with DDP.") self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: # In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode # Since 1.9, DDP communication hooks can work on all backends. 
if _TORCH_GREATER_EQUAL_1_9 or (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device): register_ddp_comm_hook( model=self._model, ddp_comm_state=self._ddp_comm_state, ddp_comm_hook=self._ddp_comm_hook, ddp_comm_wrapper=self._ddp_comm_wrapper, ) if (_TORCH_GREATER_EQUAL_1_10 and isinstance( self._ddp_comm_state, post_localSGD.PostLocalSGDState) and self.lightning_module.trainer.state.fn == TrainerFn.FITTING): self._reinit_optimizers_with_post_localSGD( self._ddp_comm_state.start_localSGD_iter) def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): optimizers = self.lightning_module.trainer.optimizers if self._model_averaging_period is None: raise ValueError( "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin." ) averager = averagers.PeriodicModelAverager( period=self._model_averaging_period, warmup_steps=warmup_steps) for x, optimizer in enumerate(optimizers): if isinstance(optimizer, LightningOptimizer): optimizer = optimizer._optimizer if (isinstance(optimizer, DistributedOptimizer) or isinstance(optimizer, ZeroRedundancyOptimizer) or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))): raise ValueError( f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer." ) if isinstance(optimizer, PostLocalSGDOptimizer): continue optim_class = type(optimizer) post_localSGD_optimizer = PostLocalSGDOptimizer( params=optimizer.param_groups, optimizer_class=optim_class, averager=averager, **optimizer.defaults, ) optimizers[x] = post_localSGD_optimizer del optimizer trainer = self.lightning_module.trainer trainer.optimizers = optimizers trainer.convert_to_lightning_optimizers() def configure_ddp(self) -> None: self.pre_configure_ddp() self._model = DistributedDataParallel( LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) self._register_ddp_hooks() def determine_ddp_device_ids(self): if self.root_device.type == "cpu": return None return [self.root_device.index] def pre_dispatch(self): # share ddp pids to all processes self._share_information_to_prevent_deadlock() # move the model to the correct device self.model_to_device() if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) # skip wrapping the model if we are not fitting as no gradients need to be exchanged trainer_fn = self.lightning_module.trainer.state.fn if trainer_fn == TrainerFn.FITTING: self.configure_ddp() def post_dispatch(self, trainer: "pl.Trainer") -> None: self.cluster_environment.teardown() def barrier(self, *args, **kwargs) -> None: if not distributed_available(): return if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend( ) == "nccl": torch.distributed.barrier( device_ids=self.determine_ddp_device_ids()) else: torch.distributed.barrier() def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward.""" if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) def model_to_device(self): self.model.to(self.root_device) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: tensor: the tensor to sync and reduce group: the process group to gather results from. 
Defaults to all processes (world) reduce_op: the reduction operation. Defaults to 'mean'/'avg'. Can also be a string 'sum' to calculate the sum during reduction. Return: the reduced value, except when the input was not a tensor, in which case the output remains unchanged """ if isinstance(tensor, torch.Tensor): tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor def training_step(self, *args, **kwargs) -> Optional[Any]: return self.model(*args, **kwargs) def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: if isinstance(self.model, DistributedDataParallel): # used when calling `trainer.fit` return self.model(*args, **kwargs) else: # used when calling `trainer.validate` return self.lightning_module.validation_step(*args, **kwargs) def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: return self.lightning_module.test_step(*args, **kwargs) def predict_step(self, *args, **kwargs) -> Any: return self.lightning_module.predict_step(*args, **kwargs) def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True @classmethod def register_plugins(cls, plugin_registry: Dict) -> None: plugin_registry.register( "ddp_find_unused_parameters_false", cls, description="DDP Plugin with `find_unused_parameters` as False", find_unused_parameters=False, ) def _share_information_to_prevent_deadlock(self): self._share_pids() # there should be a unique sync_dir per node. if self.local_rank == 0: # create a temporary directory used to synchronize processes on deadlock. self._sync_dir = tempfile.mkdtemp() sync_dirs = [] global_node_rank_zero = 0 for _ in range(self.num_nodes): sync_dirs.append( self.broadcast(self._sync_dir, global_node_rank_zero)) global_node_rank_zero += self.world_size // self.num_nodes self._sync_dir = sync_dirs[self.node_rank] def _share_pids(self): """Make all DDP processes aware of all processes' PIDs.""" self.barrier() pids = self.all_gather( torch.tensor(os.getpid(), device=self.root_device)) pids = pids.cpu().numpy().tolist() self._pids = pids if isinstance(pids, list) else [pids] def reconciliate_processes(self, trace: str): if self.world_size < 2: return sync_dir = self._sync_dir if not sync_dir: rank_zero_warn( "Error handling mechanism for deadlock detection is uninitialized. Skipping check." ) return # The cluster may be configured to periodically purge the `/tmp` # directory, in which case `sync_dir` may not exist anymore at this # point. Idempotently create it to ensure its existence. Path(sync_dir).mkdir(parents=True, exist_ok=True) # save a file locally. torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl")) # sleep for a short time time.sleep(3) # return if all processes wrote a file in the `sync_dir`. # todo (tchaton) Add support for non-shared file-systems, where this check will fail. if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes): return for pid in self._pids: if pid != os.getpid(): os.kill(pid, signal.SIGKILL) shutil.rmtree(sync_dir) raise DeadlockDetectedException( f"Deadlock detected from rank: {self.global_rank} \n {trace}") def teardown(self) -> None: if isinstance(self.model, DistributedDataParallel): self.model = self.lightning_module if self.on_gpu: # GPU teardown self.lightning_module.cpu() # clean up memory torch.cuda.empty_cache()
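# --- Usage sketch for the communication-hook arguments of the plugin above --------------
# How `ddp_comm_state`, `ddp_comm_hook` and `model_averaging_period` are typically filled
# in. This assumes PyTorch >= 1.9 for `default_hooks` and >= 1.10 for the post-localSGD
# pieces; the Trainer flags mirror the pytorch_lightning API this plugin belongs to, but
# treat the exact values as illustrative.
#
#     import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default_hooks
#     import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
#     from pytorch_lightning import Trainer
#     from pytorch_lightning.plugins import DDPPlugin
#
#     # fp16 gradient compression only
#     trainer = Trainer(gpus=2, plugins=[DDPPlugin(ddp_comm_hook=default_hooks.fp16_compress_hook)])
#
#     # post-localSGD: `model_averaging_period` is required so that
#     # `_reinit_optimizers_with_post_localSGD` can wrap the optimizers
#     trainer = Trainer(
#         gpus=2,
#         plugins=[
#             DDPPlugin(
#                 ddp_comm_state=post_localSGD.PostLocalSGDState(
#                     process_group=None,
#                     subgroup=None,
#                     start_localSGD_iter=8,
#                 ),
#                 ddp_comm_hook=post_localSGD.post_localSGD_hook,
#                 model_averaging_period=4,
#             )
#         ],
#     )
# ----------------------------------------------------------------------------------------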
class DDPPlugin(ParallelPlugin): """ Plugin for multi-process single-device training on one or multiple nodes. The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes. """ distributed_backend = "ddp" def __init__( self, parallel_devices: Optional[List[torch.device]] = None, num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, sync_batchnorm: bool = False, **kwargs: Union[Any, Dict[str, Any]], ) -> None: super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) self.interactive_ddp_procs = [] self.num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm self.dist = LightningDistributed() self._ddp_kwargs = kwargs self._has_spawned_children = False self.task_idx = None self.node_rank = 0 self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices @property def root_device(self): return self.parallel_devices[self.local_rank] @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs def setup_environment(self): # start the other scripts if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": self._call_children_scripts() # set the task idx self.task_idx = self.cluster_environment.local_rank() self.setup_distributed() def _call_children_scripts(self): # bookkeeping of spawned processes assert self.global_rank == 0 self._check_can_spawn_children() self._has_spawned_children = True # DDP Environment variables os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) # allow the user to pass the node rank os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) # when user is using hydra find the absolute path path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path # pull out the commands used to run the script and resolve the abs file path command = sys.argv try: full_path = path_lib(command[0]) except Exception: full_path = os.path.abspath(command[0]) command[0] = full_path # use the same python interpreter and actually running command = [sys.executable] + command # the visible devices tell us how many GPUs we want to use. # when the trainer script was called the device has already been scoped by the time # code reaches this point. 
so, to call the scripts, we need to leave cuda visible devices alone # but forward the GPUs selected via environment variables if self.parallel_devices is None: raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)") os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices]) os.environ["PL_IN_DDP_SUBPROCESS"] = "1" if self.lightning_module.logger is not None: os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version) num_gpus = len(self.parallel_devices) os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}" self.interactive_ddp_procs = [] for local_rank in range(1, self.num_processes): env_copy = os.environ.copy() env_copy["LOCAL_RANK"] = f"{local_rank}" # remove env var if global seed not set if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: del env_copy["PL_GLOBAL_SEED"] # start process # if hydra is available and initialized, make sure to set the cwd correctly cwd: Optional[str] = None if _HYDRA_AVAILABLE: if HydraConfig.initialized(): cwd = get_original_cwd() os_cwd = f'"{os.getcwd()}"' command += [f'hydra.run.dir={os_cwd}', f'hydra.job.name=train_ddp_process_{local_rank}'] proc = subprocess.Popen(command, env=env_copy, cwd=cwd) self.interactive_ddp_procs.append(proc) # starting all processes at once can cause issues # with dataloaders delay between 1-10 seconds delay = np.random.uniform(1, 5, 1)[0] sleep(delay) def setup_distributed(self): # TODO: check if needed seed = os.environ.get("PL_GLOBAL_SEED") if seed is not None: seed_everything(int(seed)) # determine which process we are and world size self.set_world_ranks() # set warning rank rank_zero_only.rank = self.global_rank # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table self.init_ddp_connection(self.global_rank, self.world_size) # on world_size=0 let everyone know training is starting if self.is_global_zero and not torch.distributed.is_initialized(): log.info("-" * 100) log.info(f"distributed_backend={self.distributed_backend}") log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") log.info("-" * 100) # set the ranks and devices self.dist.rank = self.global_rank self.dist.device = self.root_device def _check_can_spawn_children(self): if self._has_spawned_children: raise RuntimeError( "You tried to run `.fit` or `.test` multiple times in the same script." " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." ) def set_world_ranks(self): self.local_rank = self.task_idx self.node_rank = self.cluster_environment.node_rank() self.global_rank = self.node_rank * self.num_processes + self.local_rank self.world_size = self.num_nodes * self.num_processes def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. 
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False ): rank_zero_warn( "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` " "to properly work with DDP." ) self._ddp_kwargs["find_unused_parameters"] = True def configure_ddp(self): self.pre_configure_ddp() self._model = DistributedDataParallel( LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) def determine_ddp_device_ids(self): if self.root_device.type == "cpu": return None return [self.root_device.index] def init_ddp_connection(self, global_rank: int, world_size: int) -> None: os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) if not torch.distributed.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) def pre_dispatch(self): if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) # move the model to the correct device self.model_to_device() self.configure_ddp() self.barrier() def post_dispatch(self): if "WORLD_SIZE" in os.environ: del os.environ["WORLD_SIZE"] def barrier(self, *args, **kwargs): if torch_distrib.is_initialized(): torch_distrib.barrier() def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: prepare_for_backward(self.model, closure_loss) def model_to_device(self): if self.root_device.type == "cuda": torch.cuda.set_device(self.root_device) self.model.to(self.root_device) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ Reduces a tensor from several distributed processes to one aggregated tensor. Args: tensor: the tensor to sync and reduce group: the process group to gather results from. Defaults to all processes (world) reduce_op: the reduction operation. Defaults to 'mean'/'avg'. Can also be a string 'sum' to calculate the sum during reduction. Return: reduced value, except when the input was not a tensor the output remains is unchanged """ if isinstance(tensor, torch.Tensor): tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) return tensor def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) def validation_step(self, *args, **kwargs): return self.model(*args, **kwargs) def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True
class DDPPlugin(ParallelPlugin): """ Plugin for multi-process single-device training on one or multiple nodes. The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes. """ distributed_backend = "ddp" def __init__( self, parallel_devices, num_nodes=1, cluster_environment: ClusterEnvironment = None, sync_batchnorm=False, **kwargs: Dict[str, Any], ) -> None: super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) self.interactive_ddp_procs = [] self.num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm self.dist = LightningDistributed() self._ddp_kwargs = kwargs self._has_spawned_children = False self.task_idx = None self.node_rank = 0 self.num_processes = len(parallel_devices) @property def root_device(self): return self.parallel_devices[self.local_rank] @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs def setup(self, model): self._model = model # start the other scripts # TODO: make sure this works, in torchelastic we should not launch child processes! if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": self._call_children_scripts() # set the task idx self.task_idx = self.cluster_environment.local_rank() def _call_children_scripts(self): # bookkeeping of spawned processes assert self.global_rank == 0 self._check_can_spawn_children() self._has_spawned_children = True # DDP Environment variables os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1") os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port())) # allow the user to pass the node rank node_rank = "0" node_rank = os.environ.get("NODE_RANK", node_rank) node_rank = os.environ.get("GROUP_RANK", node_rank) os.environ["NODE_RANK"] = node_rank os.environ["LOCAL_RANK"] = "0" # when user is using hydra find the absolute path path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path # pull out the commands used to run the script and resolve the abs file path command = sys.argv try: full_path = path_lib(command[0]) except Exception as e: full_path = os.path.abspath(command[0]) command[0] = full_path # use the same python interpreter and actually running command = [sys.executable] + command # the visible devices tell us how many GPUs we want to use. # when the trainer script was called the device has already been scoped by the time # code reaches this point. 
so, to call the scripts, we need to leave cuda visible devices alone # but forward the GPUs selected via environment variables if self.parallel_devices is None: raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)") os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices]) os.environ["PL_IN_DDP_SUBPROCESS"] = "1" if self.lightning_module.logger is not None: os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version) num_gpus = len(self.parallel_devices) os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}" self.interactive_ddp_procs = [] for local_rank in range(1, self.num_processes): env_copy = os.environ.copy() env_copy["LOCAL_RANK"] = f"{local_rank}" # remove env var if global seed not set if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: del env_copy["PL_GLOBAL_SEED"] # start process # if hydra is available and initialized, make sure to set the cwd correctly cwd: Optional[str] = None if _HYDRA_AVAILABLE: if HydraConfig.initialized(): cwd = get_original_cwd() proc = subprocess.Popen(command, env=env_copy, cwd=cwd) self.interactive_ddp_procs.append(proc) # starting all processes at once can cause issues # with dataloaders delay between 1-10 seconds delay = np.random.uniform(1, 5, 1)[0] sleep(delay) def _check_can_spawn_children(self): if self._has_spawned_children: raise RuntimeError( "You tried to run `.fit` or `.test` multiple times in the same script." " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." ) def set_world_ranks(self): self.local_rank = self.task_idx self.node_rank = self.cluster_environment.node_rank() self.global_rank = self.node_rank * self.num_processes + self.local_rank self.world_size = self.num_nodes * self.num_processes def configure_ddp(self): self._model = DistributedDataParallel( LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) def determine_ddp_device_ids(self): if self.root_device.type == "cpu": return None return [self.root_device.index] def init_ddp_connection(self, global_rank: int, world_size: int) -> None: # TODO: From where to get cluster environment? os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) torch_backend = "nccl" if self.on_gpu else "gloo" if not torch.distributed.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) def pre_training(self): # TODO: check if needed seed = os.environ.get("PL_GLOBAL_SEED") if seed is not None: seed_everything(int(seed)) # determine which process we are and world size self.set_world_ranks() # set warning rank rank_zero_only.rank = self.global_rank # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table self.init_ddp_connection(self.global_rank, self.world_size) # TODO: we moved it to the trainer.fit after calling pre_training # ... 
need to double check that it is the correct place # self.trainer.call_setup_hook(self.model) # on global rank 0, let everyone know training is starting if self.is_global_zero and not torch.distributed.is_initialized(): log.info("-" * 100) log.info(f"distributed_backend={self.distributed_backend}") log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") log.info("-" * 100) # set the ranks and devices self.dist.rank = self.global_rank self.dist.device = self.root_device if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) # move the model to the correct device self.model_to_device() self.configure_ddp() self.barrier() def post_training(self): if "WORLD_SIZE" in os.environ: del os.environ["WORLD_SIZE"] def barrier(self, *args, **kwargs): if torch_distrib.is_initialized(): torch_distrib.barrier() def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) def model_to_device(self): if self.root_device.type == "cuda": torch.cuda.set_device(self.root_device) self.model.to(self.root_device) def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if isinstance(output, torch.Tensor): output = sync_ddp_if_available(output, group, reduce_op) return output def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) def validation_step(self, *args, **kwargs): return self.model(*args, **kwargs) def test_step(self, *args, **kwargs): return self.model(*args, **kwargs)
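# --- Illustration of the rendezvous performed by `init_ddp_connection` above ------------
# A stripped-down version of the environment-variable handshake: every rank exports the
# same MASTER_ADDR/MASTER_PORT and then calls `init_process_group` with its own rank. The
# backend choice mirrors the plugin ("nccl" on GPU, "gloo" otherwise); the address and
# port values are illustrative.
#
#     import os
#     import torch
#     import torch.distributed as torch_distrib
#
#     def init_ddp_connection(global_rank: int, world_size: int) -> None:
#         os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
#         os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
#         backend = "nccl" if torch.cuda.is_available() else "gloo"
#         if not torch_distrib.is_initialized():
#             torch_distrib.init_process_group(backend, rank=global_rank, world_size=world_size)
# ----------------------------------------------------------------------------------------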