def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
    # queue used to send results and state back to the main process
    self.mp_queue = mp_queue

    reset_seed()

    self.set_world_ranks(process_idx)

    # set warning rank
    rank_zero_only.rank = self.global_rank

    # set up server using proc 0's ip address
    # try to init for 20 times at max in case ports are taken
    # where to store ip_table
    init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size)

    # TODO: we moved it to the trainer.fit after calling pre_dispatch
    # ... need to double check that it is the correct place
    # self.trainer.call_setup_hook(self.model)

    # set the ranks and devices
    self.dist.rank = self.global_rank
    self.dist.device = self.root_device

    # move the model to the correct device
    self.model_to_device()

    if self.sync_batchnorm:
        self.model = self.configure_sync_batchnorm(self.model)

    # skip wrapping the model if we are not fitting as no gradients need to be exchanged
    trainer_fn = self.lightning_module.trainer.state.fn
    if trainer_fn == TrainerFn.FITTING:
        self.configure_ddp()

    # wait until every process has finished setup before running the stage
    self.barrier()

    results = trainer.run_stage()

    # persist info in ddp_spawn
    self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)

    # ensure that spawned processes go through teardown before joining
    trainer._call_teardown_hook()
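
# ``new_process`` above is the per-subprocess entry point: a spawn-based plugin
# typically launches it once per device via ``torch.multiprocessing.spawn``, which
# prepends the process index to the forwarded arguments. Below is a minimal,
# self-contained sketch of such a launch (hypothetical helper name and arguments,
# not the plugin's actual start method):
import torch.multiprocessing as mp

def _sketch_spawn_launch(plugin, trainer, num_processes: int) -> None:
    # queue the subprocesses use to ship results and state back to the main process
    mp_queue = mp.get_context("spawn").SimpleQueue()
    # one subprocess per device; each runs plugin.new_process(process_idx, trainer, mp_queue)
    mp.spawn(plugin.new_process, args=(trainer, mp_queue), nprocs=num_processes)
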
def setup_distributed(self):
    reset_seed()

    # determine which process we are and world size
    self.set_world_ranks()

    # set warning rank
    rank_zero_only.rank = self.global_rank

    # set up server using proc 0's ip address
    # try to init for 20 times at max in case ports are taken
    # where to store ip_table
    init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)

    # set the ranks and devices
    self.dist.rank = self.global_rank
    self.dist.device = self.root_device
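
# In both methods, ``init_ddp_connection`` boils down to publishing the rendezvous
# endpoint taken from the cluster environment and initializing the default process
# group. A minimal sketch of that step (assumed behaviour with hypothetical
# arguments, not the upstream utility itself):
import os

import torch.distributed

def _sketch_init_ddp_connection(master_addr: str, master_port: int, backend: str, global_rank: int, world_size: int) -> None:
    # expose the rendezvous endpoint so the default "env://" init method can find it
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    # create the default process group once per rank
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend, rank=global_rank, world_size=world_size)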