Code example #1
    def new_process(self, process_idx: int, trainer: "pl.Trainer",
                    mp_queue: SimpleQueue) -> None:
        # keep a handle to the queue used to send results back to the parent process
        self.mp_queue = mp_queue

        # reset the seed so every spawned process starts from the same RNG state
        reset_seed()

        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up the server using proc 0's ip address
        # retry init up to 20 times in case ports are already taken
        init_ddp_connection(self.cluster_environment,
                            self.torch_distributed_backend, self.global_rank,
                            self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_dispatch
        #   ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        # move the model to the correct device
        self.model_to_device()

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
        trainer_fn = self.lightning_module.trainer.state.fn
        if trainer_fn == TrainerFn.FITTING:
            self.configure_ddp()

        # wait until every process has finished setup before running the stage
        self.barrier()

        # run the actual stage (fit/validate/test/predict)
        results = trainer.run_stage()

        # persist info in ddp_spawn
        self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)

        # ensure that spawned processes go through teardown before joining
        trainer._call_teardown_hook()
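
For context, here is a minimal sketch of how a spawn-based plugin typically launches an entry point with this shape. The helper name start_spawning and the num_processes attribute are assumptions for illustration; the actual launch code in pytorch-lightning differs between versions:

import torch.multiprocessing as mp

def start_spawning(plugin, trainer):
    # a SimpleQueue created from the "spawn" context lets the spawned
    # rank-0 process hand results (e.g. the best model path) back to
    # the parent process after training finishes
    smp = mp.get_context("spawn")
    mp_queue = smp.SimpleQueue()

    # torch.multiprocessing.spawn starts nprocs workers and calls
    # fn(process_idx, *args) in each one, which matches the
    # new_process(process_idx, trainer, mp_queue) signature above
    mp.spawn(plugin.new_process, args=(trainer, mp_queue),
             nprocs=plugin.num_processes)  # num_processes is assumed here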
Code example #2
File: ddp.py Project: nunenuh/pytorch-lightning
    def setup_distributed(self):
        # reset the seed so every DDP process starts from the same RNG state
        reset_seed()

        # determine this process's rank and the world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up the server using proc 0's ip address
        # retry init up to 20 times in case ports are already taken
        init_ddp_connection(self.cluster_environment,
                            self.torch_distributed_backend)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device
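
Both examples delegate rendezvous to init_ddp_connection. As a rough sketch, such a helper boils down to publishing rank 0's address and port and creating the default process group. The sketch below is modeled on the four-argument call in example #1; the method names on cluster_environment follow pytorch-lightning's ClusterEnvironment interface of that era, but treat the exact signatures as assumptions:

import os
import torch.distributed as dist

def init_ddp_connection_sketch(cluster_environment, backend,
                               global_rank, world_size):
    # every process must agree on where rank 0 listens for the rendezvous;
    # master_address()/master_port() are assumed ClusterEnvironment methods
    os.environ["MASTER_ADDR"] = cluster_environment.master_address()
    os.environ["MASTER_PORT"] = str(cluster_environment.master_port())

    # create the default process group once per process; init_process_group
    # blocks until all world_size processes have joined
    if not dist.is_initialized():
        dist.init_process_group(backend, rank=global_rank,
                                world_size=world_size)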