def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
    model = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
    # The plugin handles backwards across processes. Currently not supported for DDP + pipe parallel
    model.require_backward_grad_sync = False
    return model
def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
    ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
    # The plugin handles backwards across processes. Currently not supported for DDP + pipe parallel
    ddp_plugin.PREPARE_FOR_BACKWARDS = False
    return ddp_plugin
def _sync_balance_to_all_parallel_groups(self, main_rank=0):
    """
    Ensures that we sync the balance to all main processes, so that the balance is the same per replica.

    Args:
        main_rank: The rank with the balance we'd like to replicate.
    """
    self.balance = torch.tensor(self.balance, dtype=torch.int, device='cuda')
    # Ensure we sync to all processes within the main data parallel group.
    # We use the data parallel group as all main processes are found within the same group.
    torch_distrib.broadcast(self.balance, src=main_rank, group=mpu.get_data_parallel_group())
    self.balance = self.balance.cpu()
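# A minimal runnable sketch of the broadcast pattern used above, assuming a single-node
# "gloo" process group spun up purely to show the semantics; the real plugin broadcasts
# CUDA tensors within mpu.get_data_parallel_group(). The helper name demo_balance_broadcast
# is hypothetical.
import os

import torch
import torch.distributed as torch_distrib


def demo_balance_broadcast():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch_distrib.init_process_group("gloo", rank=0, world_size=1)
    balance = torch.tensor([2, 2], dtype=torch.int)
    # With world_size=1 this is a no-op; with several ranks every process would leave this
    # call holding the balance tensor owned by src, so all replicas split layers identically.
    torch_distrib.broadcast(balance, src=0)
    torch_distrib.destroy_process_group()
    return balance.tolist()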
def configure_ddp(self):
    if self.main_rpc_process:
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=self.determine_ddp_device_ids(),
            process_group=mpu.get_data_parallel_group(),
            **self._ddp_kwargs,
        )
        # The plugin handles backwards across processes. Currently not supported for DDP + pipe parallel
        self._model.require_backward_grad_sync = False
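# A hedged sketch of why the DDP wrapper above is given the data-parallel process group:
# gradient all-reduce then happens only between replicas of the same pipe stage, not across
# stages. The helper name wrap_stage_for_data_parallel and its arguments are illustrative,
# not part of the plugin.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def wrap_stage_for_data_parallel(stage_module: nn.Module, data_parallel_group, device_id: int):
    # Only ranks inside `data_parallel_group` participate in syncing this module's gradients;
    # ranks holding other pipe stages are unaffected by this wrapper.
    return DistributedDataParallel(
        stage_module.to(device_id),
        device_ids=[device_id],
        process_group=data_parallel_group,
    )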
def data_parallel_group(self):
    return mpu.get_data_parallel_group()
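# A hedged sketch of how a data-parallel group like the one returned above could be created
# with raw torch.distributed, assuming 4 ranks arranged as 2 data-parallel groups of 2.
# mpu normally owns this bookkeeping; the rank layout and helper name are purely illustrative.
import torch.distributed as torch_distrib


def build_data_parallel_group():
    # new_group must be called by every rank in the same order, even for groups the rank
    # does not belong to; each rank then keeps the handle of the group it is a member of.
    groups = [torch_distrib.new_group(ranks=[0, 2]), torch_distrib.new_group(ranks=[1, 3])]
    return groups[torch_distrib.get_rank() % 2]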