Example no. 1
def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
    model = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
    # The plugin handles the backward pass across processes. Currently not supported for DDP + pipe parallel.
    model.require_backward_grad_sync = False
    return model
Example no. 2
def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
    ddp_plugin = RPCPlugin(
        process_group=mpu.get_data_parallel_group()
    ).configure_ddp(model, device_ids)
    # The plugin handles the backward pass across processes. Currently not supported for DDP + pipe parallel.
    ddp_plugin.PREPARE_FOR_BACKWARDS = False
    return ddp_plugin
Example no. 3
def _sync_balance_to_all_parallel_groups(self, main_rank=0):
    """
    Ensures that we sync the balance to all main processes, so that the balance is the same per replica.

    Args:
        main_rank: The rank with the balance we'd like to replicate.
    """
    self.balance = torch.tensor(self.balance, dtype=torch.int, device='cuda')
    # Ensure we sync to all processes within the main data parallel group.
    # We use the data parallel group, as all main processes are found within the same group.
    torch_distrib.broadcast(self.balance, src=main_rank, group=mpu.get_data_parallel_group())
    self.balance = self.balance.cpu()
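The broadcast above relies on standard torch.distributed semantics: every rank overwrites its local tensor with the value held by src. Below is a minimal, self-contained sketch of that pattern (not the original Lightning code; it assumes a two-process gloo job launched with torchrun, and the example balance values are made up):

import torch
import torch.distributed as torch_distrib

def sync_balance(balance, main_rank=0):
    # Every rank calls broadcast; only main_rank's values survive.
    t = torch.tensor(balance, dtype=torch.int)
    torch_distrib.broadcast(t, src=main_rank)
    return t.tolist()

if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 sync_balance_sketch.py
    torch_distrib.init_process_group(backend="gloo")
    rank = torch_distrib.get_rank()
    local_balance = [2, 2] if rank == 0 else [1, 3]
    print(rank, sync_balance(local_balance))  # every rank prints [2, 2]
    torch_distrib.destroy_process_group()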
Example no. 4
    def configure_ddp(self):
        if self.main_rpc_process:
            self.pre_configure_ddp()

            self._model = DistributedDataParallel(
                LightningDistributedModule(self.model),
                device_ids=self.determine_ddp_device_ids(),
                process_group=mpu.get_data_parallel_group(),
                **self._ddp_kwargs,
            )
            # The plugin handles the backward pass across processes. Currently not supported for DDP + pipe parallel.
            self._model.require_backward_grad_sync = False
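Setting require_backward_grad_sync = False has the same per-iteration effect as DDP's public no_sync() context manager: gradients accumulate locally and the all-reduce is skipped. A hedged sketch of that equivalent, assuming an already-wrapped DDP model and an initialized process group (the function and argument names here are illustrative, not part of the source):

import torch
from torch.nn.parallel import DistributedDataParallel

def backward_without_grad_sync(ddp_model: DistributedDataParallel, batch: torch.Tensor) -> torch.Tensor:
    # Inside no_sync(), DDP skips the gradient all-reduce for this backward pass,
    # which is the same synchronization that require_backward_grad_sync = False disables.
    with ddp_model.no_sync():
        loss = ddp_model(batch).sum()
        loss.backward()
    return loss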
Example no. 5
def data_parallel_group(self):
    """Return the torch.distributed process group used for data parallel communication."""
    return mpu.get_data_parallel_group()
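For context, mpu.get_data_parallel_group() is expected to return a torch.distributed process group built once during setup. Below is a hedged sketch of how such a group is commonly constructed with torch.distributed.new_group; the Megatron-style rank layout is an assumption for illustration, not taken from the source:

import torch.distributed as torch_distrib

def build_data_parallel_group(world_size: int, pipe_parallel_size: int):
    # Assumed layout: ranks holding the same pipeline stage across replicas form
    # one data parallel group. new_group must be called collectively by every
    # process with identical rank lists, hence the loop runs over all stages.
    data_parallel_group = None
    for stage in range(pipe_parallel_size):
        ranks = list(range(stage, world_size, pipe_parallel_size))
        group = torch_distrib.new_group(ranks=ranks)
        if torch_distrib.get_rank() in ranks:
            data_parallel_group = group
    return data_parallel_group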