Example #1
 def __init__(
     self,
     parallel_devices: Optional[List[torch.device]] = None,
     num_nodes: Optional[int] = None,
     cluster_environment: ClusterEnvironment = None,
     sync_batchnorm: Optional[bool] = None,
     ddp_comm_state: Optional[object] = None,
     ddp_comm_hook: Optional[callable] = None,
     ddp_comm_wrapper: Optional[callable] = None,
     **kwargs: Union[Any, Dict[str, Any]],
 ) -> None:
     super().__init__(parallel_devices=parallel_devices,
                      cluster_environment=cluster_environment)
     self.interactive_ddp_procs = []
     if num_nodes is not None:
         rank_zero_deprecation(
             "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
             " Notice that it will be overriden by the trainer setting.")
     self._num_nodes = num_nodes or 1
     if sync_batchnorm is not None:
         rank_zero_deprecation(
             "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
             " Notice that it will be overriden by the trainer setting.")
     self._sync_batchnorm = sync_batchnorm or False
     self.dist = LightningDistributed()
     self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
     self._ddp_kwargs = kwargs
     self._has_spawned_children = False
     self.task_idx = None
     self._ddp_comm_state = ddp_comm_state
     self._ddp_comm_hook = ddp_comm_hook
     self._ddp_comm_wrapper = ddp_comm_wrapper
     self.set_world_ranks()
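
The `**kwargs` accepted above are stored in `_ddp_kwargs` and later forwarded verbatim to `torch.nn.parallel.DistributedDataParallel` when the model is wrapped (see `configure_ddp` in Example #6). A minimal sketch of that forwarding pattern, using a hypothetical wrapper class rather than the Lightning plugin itself:

    import torch
    from torch.nn.parallel import DistributedDataParallel


    class DDPKwargsForwarder:
        """Hypothetical wrapper illustrating how constructor kwargs reach DDP."""

        def __init__(self, **kwargs):
            # anything the caller passes (e.g. find_unused_parameters=False,
            # bucket_cap_mb=50) is kept and handed to DDP unchanged
            self._ddp_kwargs = kwargs

        def configure_ddp(self, model, device_ids):
            # assumes torch.distributed has already been initialized
            return DistributedDataParallel(model, device_ids=device_ids, **self._ddp_kwargs)
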
Example #2
 def __init__(
     self,
     parallel_devices: Optional[List[torch.device]] = None,
     num_nodes: int = 1,
     cluster_environment: ClusterEnvironment = None,
     sync_batchnorm: bool = False,
     ddp_comm_state: Optional[object] = None,
     ddp_comm_hook: Optional[callable] = None,
     ddp_comm_wrapper: Optional[callable] = None,
     **kwargs: Union[Any, Dict[str, Any]],
 ) -> None:
     super().__init__(parallel_devices=parallel_devices,
                      cluster_environment=cluster_environment)
     self.interactive_ddp_procs = []
     self.num_nodes = num_nodes
     self.sync_batchnorm = sync_batchnorm
     self.dist = LightningDistributed()
     self._ddp_kwargs = kwargs
     self._has_spawned_children = False
     self.task_idx = None
     self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
     self._ddp_comm_state = ddp_comm_state
     self._ddp_comm_hook = ddp_comm_hook
     self._ddp_comm_wrapper = ddp_comm_wrapper
     self.set_world_ranks()
Example #3
 def __init__(
     self,
     parallel_devices,
     num_nodes=1,
     cluster_environment: ClusterEnvironment = None,
     sync_batchnorm=False,
     **kwargs: Dict[str, Any],
 ) -> None:
     super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
     self.interactive_ddp_procs = []
     self.num_nodes = num_nodes
     self.sync_batchnorm = sync_batchnorm
     self.dist = LightningDistributed()
     self._ddp_kwargs = kwargs
     self._has_spawned_children = False
     self.task_idx = None
     self.node_rank = 0
     self.num_processes = len(parallel_devices)
Example #4
    def __init__(self,
                 trainer,
                 nprocs: int,
                 cluster_environment: Optional[ClusterEnvironment] = None,
                 ddp_plugin: Optional[DDPPlugin] = None):
        """
        Runs training using DDP with mp.spawn via manual launch (not cluster launch)

        Example::

            # default
            trainer = Trainer(accelerator=DDPSpawnAccelerator())

        """
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.mp_queue = None
        self.nprocs = nprocs
        self.dist = LightningDistributed()
        self.nickname = 'ddp'
Example #5
    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
        """
        Runs training using DP via manual launch (not an HPC cluster)

        Example::

            # default
            trainer = Trainer(accelerator=DataParallelAccelerator())

        """
        super().__init__(trainer, cluster_environment)
        self.model_autocast_original_forward = None
        self.dist = LightningDistributed()
        self.nickname = 'dp'
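
The accelerator above wraps single-process data parallelism; in plain PyTorch terms the underlying mechanism is `torch.nn.DataParallel`, which replicates the module across the visible GPUs and splits each batch among them. A standalone sketch of just that mechanism (the toy model and tensor sizes are made up):

    import torch
    from torch import nn

    model = nn.Linear(32, 4)
    if torch.cuda.device_count() > 1:
        # replicate the module across all visible GPUs and scatter each batch
        model = nn.DataParallel(model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    batch = torch.randn(8, 32, device=device)
    out = model(batch)  # outputs are gathered back on the first device
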
Example #6
class DDPPlugin(ParallelPlugin):
    """
    Plugin for multi-process single-device training on one or multiple nodes.

    The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`,
    where N is the number of devices (e.g. GPU) per node.
    It is very similar to how :mod:`torch.distributed.launch` launches processes.
    """

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices: Optional[List[torch.device]] = None,
        num_nodes: Optional[int] = None,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm: Optional[bool] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[callable] = None,
        ddp_comm_wrapper: Optional[callable] = None,
        **kwargs: Union[Any, Dict[str, Any]],
    ) -> None:
        super().__init__(parallel_devices=parallel_devices,
                         cluster_environment=cluster_environment)
        self.interactive_ddp_procs = []
        if num_nodes is not None:
            rank_zero_deprecation(
                "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
                " Notice that it will be overriden by the trainer setting.")
        self._num_nodes = num_nodes or 1
        if sync_batchnorm is not None:
            rank_zero_deprecation(
                "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
                " Notice that it will be overriden by the trainer setting.")
        self._sync_batchnorm = sync_batchnorm or False
        self.dist = LightningDistributed()
        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
        self._ddp_kwargs = kwargs
        self._has_spawned_children = False
        self._task_idx = None
        self._ddp_comm_state = ddp_comm_state
        self._ddp_comm_hook = ddp_comm_hook
        self._ddp_comm_wrapper = ddp_comm_wrapper
        self._pids: Optional[List[int]] = None
        self._sync_dir: Optional[str] = None
        self.set_world_ranks()

    @property
    def is_distributed(self) -> bool:
        return True

    @property
    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # note that the world ranks depend on num_nodes; when num_nodes is reset, the world ranks must be reset as well
        self._num_nodes = num_nodes
        self.set_world_ranks()

    @property
    def sync_batchnorm(self) -> bool:
        return self._sync_batchnorm

    @sync_batchnorm.setter
    def sync_batchnorm(self, sync_batchnorm: bool) -> None:
        self._sync_batchnorm = sync_batchnorm

    @property
    def task_idx(self) -> Optional[int]:
        rank_zero_deprecation(
            f'`{self.__class__.__name__}.task_idx` is deprecated in v1.4 and will be removed in v1.6. Use '
            f'`{self.__class__.__name__}.local_rank` instead.')
        return self._task_idx

    @task_idx.setter
    def task_idx(self, task_idx: int) -> None:
        self._task_idx = task_idx

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes *
                                                        self.num_processes),
                                          rank=self.global_rank)
        return distributed_sampler_kwargs

    @property
    def _is_single_process_single_device(self) -> bool:
        return True

    def setup_environment(self) -> None:
        # start the other scripts
        if not self.cluster_environment.creates_children() and os.environ.get(
                "PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

        self.setup_distributed()

    def _call_children_scripts(self):
        # bookkeeping of spawned processes
        assert self.local_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

        # allow the user to pass the node rank
        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

        # create a temporary directory used to synchronize processes on deadlock.
        os.environ["PL_DDP_SYNC_TMPDIR"] = self._sync_dir = tempfile.mkdtemp()

        # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
        # See https://docs.python.org/3/reference/import.html#main-spec
        if __main__.__spec__ is None:  # pragma: no-cover
            # Script called as `python a/b/c.py`
            # when the user is using hydra, find the absolute path
            path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

            # pull out the commands used to run the script and resolve the abs file path
            command = sys.argv
            try:
                full_path = path_lib(command[0])
            except Exception:
                full_path = os.path.abspath(command[0])

            command[0] = full_path
            # use the same python interpreter that is actually running
            command = [sys.executable] + command
        else:  # Script called as `python -m a.b.c`
            command = [sys.executable, "-m", __main__.__spec__.name
                       ] + sys.argv[1:]

        # the visible devices tell us how many GPUs we want to use.
        # by the time code reaches this point, the devices have already been scoped for the calling script,
        # so to call the child scripts we leave CUDA_VISIBLE_DEVICES alone and forward the selected GPUs
        # via environment variables
        if self.parallel_devices is None:
            raise MisconfigurationException(
                "you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)"
            )

        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            if self.lightning_module.logger is not None:
                # spawned processes must reference the same log dir, prevent auto-increment version
                env_copy["PL_EXP_VERSION"] = str(
                    self.lightning_module.logger.version)

            # remove env var if global seed not set
            if os.environ.get(
                    "PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
                    os_cwd = f'"{os.getcwd()}"'
                    command += [
                        f'hydra.run.dir={os_cwd}',
                        f'hydra.job.name=train_ddp_process_{local_rank}'
                    ]
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders,
            # so add a short (1-5 second) delay between launches
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def setup_distributed(self):
        reset_seed()

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up the server using proc 0's ip address (retried in case the port is already taken)
        self.init_ddp_connection()

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
                "You tried to run `.fit` or `.test` multiple times in the same script."
                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
            )

    def set_world_ranks(self) -> None:
        if self.cluster_environment is None:
            return
        self.cluster_environment.set_global_rank(self.node_rank *
                                                 self.num_processes +
                                                 self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes *
                                                self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()

    def pre_configure_ddp(self):
        # if unset, default `find_unused_parameters` to `True`
        # Many models require setting this parameter to True, as there are corner cases
        # where not all parameter backward hooks are fired by the autograd engine even if requires_grad is set to True.
        # This flag comes with a performance hit, so it is suggested to disable it where possible.
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True)
        # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization
        if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get(
                "find_unused_parameters", False):
            rank_zero_warn(
                "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
                "to properly work with DDP.")
            self._ddp_kwargs["find_unused_parameters"] = True

    def _register_ddp_hooks(self) -> None:
        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
        if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device:
            register_ddp_comm_hook(
                model=self._model,
                ddp_comm_state=self._ddp_comm_state,
                ddp_comm_hook=self._ddp_comm_hook,
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )
        self._register_ddp_hooks()

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def init_ddp_connection(self,
                            global_rank: Optional[int] = None,
                            world_size: Optional[int] = None) -> None:
        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        if torch.distributed.is_available() and not torch.distributed.is_initialized():
            log.info(
                f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
            )
            torch.distributed.init_process_group(
                self.torch_distributed_backend,
                rank=global_rank,
                world_size=world_size)

            # on rank=0 let everyone know training is starting
            rank_zero_info(
                f"{'-' * 100}\n"
                f"distributed_backend={self.torch_distributed_backend}\n"
                f"All DDP processes registered. Starting ddp with {self.world_size} processes\n"
                f"{'-' * 100}\n")

    def pre_dispatch(self):
        # move the model to the correct device
        self.model_to_device()

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        self.configure_ddp()

        # share ddp pids to all processes
        self._share_information_to_prevent_deadlock()

    def post_dispatch(self) -> None:
        self.cluster_environment.teardown()

    def barrier(self, *args, **kwargs) -> None:
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(
                device_ids=self.determine_ddp_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool,
                     optimizer: Optimizer, opt_idx: int):
        """Run before precision plugin executes backward"""
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)

    def model_to_device(self):
        self.model.to(self.root_device)

    def reduce(self,
               tensor,
               group: Optional[Any] = None,
               reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            the reduced value; if the input was not a tensor, it is returned unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def training_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def predict_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def post_training_step(self):
        if not self.lightning_module.automatic_optimization:
            self.model.require_backward_grad_sync = True

    @classmethod
    def register_plugins(cls, plugin_registry: Dict) -> None:
        plugin_registry.register(
            "ddp_find_unused_parameters_false",
            cls,
            description="DDP Plugin with `find_unused_parameters` as False",
            find_unused_parameters=False)

    def _share_information_to_prevent_deadlock(self):
        self._share_pids()

        # remove `PL_DDP_SYNC_TMPDIR` from os.environ
        self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None)

    def _share_pids(self):
        """
        Make all DDP processes aware of the pids of all other processes.
        """
        self.barrier()
        pids = self.all_gather(
            torch.tensor(os.getpid(), device=self.root_device))
        pids = pids.cpu().numpy().tolist()
        self._pids = pids if isinstance(pids, list) else [pids]

    def reconciliate_processes(self, trace: str):
        if self.world_size < 2:
            return

        sync_dir = self._sync_dir

        # save a file locally.
        torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl"))

        # sleep for a short time
        time.sleep(3)

        # return if all processes wrote a file in the `sync_dir`.
        # todo (tchaton): add support for non-shared file systems, where this check will fail.
        if len(os.listdir(sync_dir)) == self.world_size:
            return

        for pid in self._pids:
            if pid != os.getpid():
                os.kill(pid, signal.SIGKILL)
        shutil.rmtree(sync_dir)
        raise DeadlockDetectedException(
            f"DeadLock detected from rank: {self.global_rank} \n {trace}")
Example #7
class DDPSpawnAccelerator(Accelerator):
    def __init__(self,
                 trainer,
                 nprocs: int,
                 cluster_environment: Optional[ClusterEnvironment] = None,
                 ddp_plugin: Optional[DDPPlugin] = None):
        """
        Runs training using DDP with mp.spawn via manual launch (not cluster launch)

        Example::

            # default
            trainer = Trainer(accelerator=DDPSpawnAccelerator())

        """
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.mp_queue = None
        self.nprocs = nprocs
        self.dist = LightningDistributed()
        self.nickname = 'ddp'

    def setup(self, model):
        os.environ['MASTER_PORT'] = os.environ.get(
            'MASTER_PORT', str(find_free_network_port()))

        # pass in a state q
        smp = mp.get_context('spawn')
        self.mp_queue = smp.SimpleQueue()

        self.trainer.model = model

    def train(self):
        model = self.trainer.model

        # train in children process
        mp.spawn(self.ddp_train,
                 nprocs=self.nprocs,
                 args=(
                     self.mp_queue,
                     model,
                 ))

        # restore main state with best weights
        best_path = self.mp_queue.get()
        results = self.mp_queue.get()
        last_path = self.mp_queue.get()

        # recover the weights of the processes trained in the children
        self.__recover_child_process_weights(model, best_path, last_path)
        return results

    def ddp_train(self,
                  process_idx,
                  mp_queue,
                  model,
                  is_master=False,
                  proc_offset=0):
        """
        Entry point for ddp, run in each spawned process

        Args:
            process_idx: index of the current spawned process
            mp_queue: multiprocessing queue
            model: the model to train
        """
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # offset the process id if requested
        process_idx = process_idx + proc_offset

        # show the progress bar only on global rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0
            ) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # Initialize cuda device
        self.init_device(process_idx, is_master)

        # set up the server using proc 0's ip address (retried in case the port is already taken)
        model.trainer = self.trainer
        self.init_ddp_connection(self.trainer.global_rank,
                                 self.trainer.world_size,
                                 self.trainer.is_slurm_managing_tasks)

        if isinstance(self.ddp_plugin, RPCPlugin):
            if not self.ddp_plugin.is_main_rpc_process:
                self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
                self.ddp_plugin.exit_rpc_process()
                if self.ddp_plugin.return_after_exit_rpc_process:
                    return
            else:
                self.ddp_plugin.on_main_rpc_connection(self.trainer)

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on rank 0, let everyone know training is starting
        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(
                f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes'
            )
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = self.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

        self.ddp_plugin.on_after_setup_optimizers(self.trainer)

        # 16-bit
        model = self.trainer.precision_connector.connect(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = self.configure_ddp(model, device_ids)

        self.trainer.setup_trainer(model)

        # train or test
        results = self.train_or_test()

        # get original model
        model = self.trainer.get_model()

        # persist info in ddp_spawn
        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

        # clean up memory
        torch.cuda.empty_cache()

    def set_world_ranks(self, process_idx):
        self.trainer.local_rank = process_idx
        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

    def init_device(self, process_idx, is_master):
        gpu_idx = self.trainer.data_parallel_device_ids[
            self.trainer.local_rank]
        self.trainer.root_gpu = gpu_idx
        torch.cuda.set_device(self.trainer.root_gpu)

    def model_to_device(self, model):
        model.cuda(self.trainer.root_gpu)

    def get_device_ids(self):
        device_ids = [self.trainer.root_gpu]
        return device_ids

    def training_step(self, args):
        return self._step(args)

    def validation_step(self, args):
        return self._step(args)

    def test_step(self, args):
        return self._step(args)

    def _step(self, args):
        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(),
                                                 *args)
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

    def barrier(self, name: Optional[str] = None):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def early_stopping_should_stop(self, pl_module):
        stop = torch.tensor(int(self.trainer.should_stop),
                            device=pl_module.device)
        torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
        torch_distrib.barrier()
        should_stop = stop == self.trainer.world_size
        return should_stop

    def broadcast(self, obj, src=0):
        return self.dist.broadcast(obj)

    def __recover_child_process_weights(self, model, best_path, last_path):
        # transfer back the best path to the trainer
        if self.trainer.checkpoint_callback:
            self.trainer.checkpoint_callback.best_model_path = best_path
        # todo, pass also best score

        # load last weights
        if last_path is not None and not self.trainer.testing:
            ckpt = pl_load(last_path,
                           map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt)

        self.trainer.model = model

    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue,
                                                results):
        best_model_path = None
        if self.trainer.checkpoint_callback is not None:
            best_model_path = self.trainer.checkpoint_callback.best_model_path

        if self.trainer.global_rank == 0 and mp_queue is not None:
            rank_zero_warn('cleaning up ddp environment...')
            # todo, pass complete checkpoint as state dictionary
            mp_queue.put(best_model_path)
            mp_queue.put(results)

            # save the last weights
            last_path = None
            if not self.trainer.testing and best_model_path is not None and len(
                    best_model_path) > 0:
                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
                atomic_save(model.state_dict(), last_path)
            mp_queue.put(last_path)

    def configure_ddp(self, model: LightningModule,
                      device_ids: List[int]) -> DistributedDataParallel:
        self.ddp_plugin.device_ids = device_ids
        model = self.ddp_plugin.configure_ddp(model, device_ids)
        return model

    def configure_sync_batchnorm(self,
                                 model: LightningModule) -> LightningModule:
        """
        Add global batchnorm for a model spread across multiple GPUs and nodes.

        Override to synchronize batchnorm between specific process groups instead
        of the whole world or use a different sync_bn like `apex`'s version.

        Args:
            model: pointer to current :class:`LightningModule`.

        Return:
            LightningModule with batchnorm layers synchronized between process groups
        """
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
            model, process_group=None)

        return model

    def sync_tensor(
            self,
            tensor: Union[torch.Tensor],
            group: Optional[Any] = None,
            reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        return sync_ddp_if_available(tensor, group, reduce_op)

    def all_gather(self,
                   tensor: Union[torch.Tensor],
                   group: Optional[Any] = None,
                   sync_grads: bool = False):
        """
        Function to gather a tensor from several distributed processes

        Args:
            tensor: tensor of shape (batch, ...)
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for all_gather op

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        return all_gather_ddp_if_available(tensor,
                                           group=group,
                                           sync_grads=sync_grads)

    def get_reference_model(self, model) -> LightningModule:
        return self.ddp_plugin.get_model_from_plugin(model)

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.trainer.num_nodes *
                                          self.trainer.num_processes,
                                          rank=self.trainer.global_rank)
        if self.ddp_plugin is not None:
            distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(
                distributed_sampler_kwargs)
        return distributed_sampler_kwargs

    @property
    def require_distributed_sampler(self):
        return True
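
`DDPSpawnAccelerator.train` above relies on `torch.multiprocessing.spawn` plus a `SimpleQueue` to hand results back to the parent process. A standalone sketch of that spawn-and-queue pattern (the worker function and its payload are made up):

    import torch.multiprocessing as mp


    def _worker(process_idx, queue):
        # each spawned process does its share of the work ...
        result = f"result from process {process_idx}"
        # ... and only rank 0 reports back to the parent, mirroring
        # `transfer_distrib_spawn_state_on_fit_end`
        if process_idx == 0:
            queue.put(result)


    if __name__ == "__main__":
        smp = mp.get_context("spawn")
        queue = smp.SimpleQueue()
        mp.spawn(_worker, nprocs=2, args=(queue,))  # blocks until all workers exit
        print(queue.get())
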
Example #8
 def __init__(self, trainer):
     super().__init__(trainer)
     self.model_autocast_original_forward = None
     self.dist = LightningDistributed()
Example #9
class DDPPlugin(ParallelPlugin):
    """
    Plugin for multi-process single-device training on one or multiple nodes.

    The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`,
    where N is the number of devices (e.g. GPU) per node.
    It is very similar to how :mod:`torch.distributed.launch` launches processes.
    """

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices: Optional[List[torch.device]] = None,
        num_nodes: int = 1,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm: bool = False,
        **kwargs: Union[Any, Dict[str, Any]],
    ) -> None:
        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
        self.interactive_ddp_procs = []
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.dist = LightningDistributed()
        self._ddp_kwargs = kwargs
        self._has_spawned_children = False
        self.task_idx = None
        self.node_rank = 0
        self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup_environment(self):
        # start the other scripts
        if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

        self.setup_distributed()

    def _call_children_scripts(self):

        # bookkeeping of spawned processes
        assert self.global_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

        # allow the user to pass the node rank
        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

        # when the user is using hydra, find the absolute path
        path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

        # pull out the commands used to run the script and resolve the abs file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception:
            full_path = os.path.abspath(command[0])

        command[0] = full_path
        # use the same python interpreter that is actually running
        command = [sys.executable] + command

        # the visible devices tell us how many GPUs we want to use.
        # by the time code reaches this point, the devices have already been scoped for the calling script,
        # so to call the child scripts we leave CUDA_VISIBLE_DEVICES alone and forward the selected GPUs
        # via environment variables
        if self.parallel_devices is None:
            raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)")

        os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        if self.lightning_module.logger is not None:
            os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

        num_gpus = len(self.parallel_devices)
        os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
                    os_cwd = f'"{os.getcwd()}"'
                    command += [f'hydra.run.dir={os_cwd}', f'hydra.job.name=train_ddp_process_{local_rank}']
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders,
            # so add a short (1-5 second) delay between launches
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def setup_distributed(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up the server using proc 0's ip address (retried in case the port is already taken)
        self.init_ddp_connection(self.global_rank, self.world_size)

        # on rank 0, let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
                "You tried to run `.fit` or `.test` multiple times in the same script."
                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
            )

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes

    def pre_configure_ddp(self):
        # if unset, default `find_unused_parameters` to `True`
        # Many models require setting this parameter to True, as there are corner cases
        # where not all parameter backward hooks are fired by the autograd engine even if requires_grad is set to True.
        # This flag comes with a performance hit, so it is suggested to disable it where possible.
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
        # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization
        if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get(
            "find_unused_parameters", False
        ):
            rank_zero_warn(
                "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
                "to properly work with DDP."
            )
            self._ddp_kwargs["find_unused_parameters"] = True

    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

        if not torch.distributed.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

    def pre_dispatch(self):
        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # move the model to the correct device
        self.model_to_device()

        self.configure_ddp()

        self.barrier()

    def post_dispatch(self):
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]

    def barrier(self, *args, **kwargs):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
        """Run before precision plugin executes backward"""
        if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
            prepare_for_backward(self.model, closure_loss)

    def model_to_device(self):
        if self.root_device.type == "cuda":
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            the reduced value; if the input was not a tensor, it is returned unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
        return tensor

    def training_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def predict_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def post_training_step(self):
        if not self.lightning_module.automatic_optimization:
            self.model.require_backward_grad_sync = True
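
`init_ddp_connection` above is the standard environment-variable rendezvous: export `MASTER_ADDR`/`MASTER_PORT`, then call `init_process_group` with an explicit rank and world size. A minimal standalone sketch of just that handshake (the placeholder address, port, and backend choice are assumptions for a single-node CPU run):

    import os

    import torch.distributed as dist


    def init_connection(global_rank: int, world_size: int) -> None:
        # rank 0's address and port; every process must agree on these two values
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        if not dist.is_initialized():
            # "gloo" works on CPU; "nccl" is the usual choice for GPU training
            dist.init_process_group("gloo", rank=global_rank, world_size=world_size)
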
Example #10
class DDPPlugin(ParallelPlugin):
    """Plugin for multi-process single-device training on one or multiple nodes.

    The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of
    devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes.
    """

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices: Optional[List[torch.device]] = None,
        num_nodes: Optional[int] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        sync_batchnorm: Optional[bool] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[callable] = None,
        ddp_comm_wrapper: Optional[callable] = None,
        model_averaging_period: Optional[int] = None,
        **kwargs: Union[Any, Dict[str, Any]],
    ) -> None:
        super().__init__(
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
        )
        self.interactive_ddp_procs = []
        if num_nodes is not None:
            rank_zero_deprecation(
                "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
                " Notice that it will be overriden by the trainer setting.")
        self._num_nodes = num_nodes or 1
        if sync_batchnorm is not None:
            rank_zero_deprecation(
                "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
                " Notice that it will be overriden by the trainer setting.")
        self._sync_batchnorm = sync_batchnorm or False
        self.dist = LightningDistributed()
        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
        self._ddp_kwargs = kwargs
        self._task_idx = None
        self._ddp_comm_state = ddp_comm_state
        self._ddp_comm_hook = ddp_comm_hook
        self._ddp_comm_wrapper = ddp_comm_wrapper
        self._model_averaging_period = model_averaging_period
        self._pids: Optional[List[int]] = None
        self._sync_dir: Optional[str] = None
        self.set_world_ranks()

    @property
    def is_distributed(self) -> bool:
        return True

    @property
    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # note that the world ranks depend on num_nodes; when num_nodes is reset, the world ranks must be reset as well
        self._num_nodes = num_nodes
        self.set_world_ranks()

    @property
    def sync_batchnorm(self) -> bool:
        return self._sync_batchnorm

    @sync_batchnorm.setter
    def sync_batchnorm(self, sync_batchnorm: bool) -> None:
        self._sync_batchnorm = sync_batchnorm

    @property
    def task_idx(self) -> Optional[int]:
        rank_zero_deprecation(
            f"`{self.__class__.__name__}.task_idx` is deprecated in v1.4 and will be removed in v1.6. Use "
            f"`{self.__class__.__name__}.local_rank` instead.")
        return self._task_idx

    @task_idx.setter
    def task_idx(self, task_idx: int) -> None:
        self._task_idx = task_idx

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes *
                                                        self.num_processes),
                                          rank=self.global_rank)
        return distributed_sampler_kwargs

    @property
    def _is_single_process_single_device(self) -> bool:
        return True

    def setup_environment(self) -> None:
        # start the other scripts
        if not self.cluster_environment.creates_children():
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

        self.setup_distributed()

    def _call_children_scripts(self):
        # bookkeeping of spawned processes
        self._check_can_spawn_children()

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

        # allow the user to pass the node rank
        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

        # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
        # See https://docs.python.org/3/reference/import.html#main-spec
        if __main__.__spec__ is None:  # pragma: no-cover
            # Script called as `python a/b/c.py`
            # when the user is using hydra, find the absolute path
            path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

            # pull out the commands used to run the script and resolve the abs file path
            command = sys.argv
            try:
                full_path = path_lib(command[0])
            except Exception:
                full_path = os.path.abspath(command[0])

            command[0] = full_path
            # use the same python interpreter that is actually running
            command = [sys.executable] + command
        else:  # Script called as `python -m a.b.c`
            command = [sys.executable, "-m", __main__.__spec__.name
                       ] + sys.argv[1:]

        # the visible devices tell us how many GPUs we want to use.
        # by the time code reaches this point, the devices have already been scoped for the calling script,
        # so to call the child scripts we leave CUDA_VISIBLE_DEVICES alone and forward the selected GPUs
        # via environment variables
        if self.parallel_devices is None:
            raise MisconfigurationException(
                "you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)"
            )

        os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get(
                    "PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
                    os_cwd = f'"{os.getcwd()}"'
                    command += [
                        f"hydra.run.dir={os_cwd}",
                        f"hydra.job.name=train_ddp_process_{local_rank}"
                    ]
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders,
            # so add a short (1-5 second) delay between launches
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def setup_distributed(self):
        reset_seed()

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up the server using proc 0's ip address (retried in case the port is already taken)
        init_ddp_connection(self.cluster_environment,
                            self.torch_distributed_backend)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

    def _check_can_spawn_children(self):
        if self.local_rank != 0:
            raise RuntimeError(
                "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen."
                " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user,"
                " 2) `ClusterEnvironment.creates_children()` incorrectly implemented."
            )

    def set_world_ranks(self) -> None:
        if self.cluster_environment is None:
            return
        self.cluster_environment.set_global_rank(self.node_rank *
                                                 self.num_processes +
                                                 self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes *
                                                self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()

    def pre_configure_ddp(self):
        # if unset, default `find_unused_parameters` to `True`
        # Many models require setting this parameter to True, as there are corner cases
        # where not all parameter backward hooks are fired by the autograd engine even if requires_grad is set to True.
        # This flag comes with a performance hit, so it is suggested to disable it where possible.
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True)
        # todo: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization
        if (_TORCH_GREATER_EQUAL_1_7
                and not self.lightning_module.automatic_optimization
                and not self._ddp_kwargs.get("find_unused_parameters", False)):
            rank_zero_warn(
                "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
                "to properly work with DDP.")
            self._ddp_kwargs["find_unused_parameters"] = True

    def _register_ddp_hooks(self) -> None:
        # In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # Since 1.9, DDP communication hooks can work on all backends.
        if _TORCH_GREATER_EQUAL_1_9 or (_TORCH_GREATER_EQUAL_1_8
                                        and self.on_gpu and
                                        self._is_single_process_single_device):
            register_ddp_comm_hook(
                model=self._model,
                ddp_comm_state=self._ddp_comm_state,
                ddp_comm_hook=self._ddp_comm_hook,
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

            if (_TORCH_GREATER_EQUAL_1_10 and isinstance(
                    self._ddp_comm_state, post_localSGD.PostLocalSGDState)
                    and self.lightning_module.trainer.state.fn
                    == TrainerFn.FITTING):
                self._reinit_optimizers_with_post_localSGD(
                    self._ddp_comm_state.start_localSGD_iter)

    def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
        optimizers = self.lightning_module.trainer.optimizers
        if self._model_averaging_period is None:
            raise ValueError(
                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
            )
        averager = averagers.PeriodicModelAverager(
            period=self._model_averaging_period, warmup_steps=warmup_steps)
        for x, optimizer in enumerate(optimizers):
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer

            if (isinstance(optimizer, DistributedOptimizer)
                    or isinstance(optimizer, ZeroRedundancyOptimizer)
                    or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))):
                raise ValueError(
                    f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
                )

            if isinstance(optimizer, PostLocalSGDOptimizer):
                continue

            optim_class = type(optimizer)
            post_localSGD_optimizer = PostLocalSGDOptimizer(
                params=optimizer.param_groups,
                optimizer_class=optim_class,
                averager=averager,
                **optimizer.defaults,
            )
            optimizers[x] = post_localSGD_optimizer
            del optimizer
        trainer = self.lightning_module.trainer
        trainer.optimizers = optimizers
        trainer.convert_to_lightning_optimizers()

    def configure_ddp(self) -> None:
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs)
        self._register_ddp_hooks()

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def pre_dispatch(self):
        # share ddp pids to all processes
        self._share_information_to_prevent_deadlock()

        # move the model to the correct device
        self.model_to_device()

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
        trainer_fn = self.lightning_module.trainer.state.fn
        if trainer_fn == TrainerFn.FITTING:
            self.configure_ddp()

    def post_dispatch(self, trainer: "pl.Trainer") -> None:
        self.cluster_environment.teardown()

    def barrier(self, *args, **kwargs) -> None:
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        """Run before precision plugin executes backward."""
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)

    def model_to_device(self):
        self.model.to(self.root_device)

    def reduce(self,
               tensor,
               group: Optional[Any] = None,
               reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            the reduced value, except when the input was not a tensor, in which case the output remains unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor
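        # Usage sketch (illustrative, not part of the original plugin): from a LightningModule
        # one could aggregate a scalar across all ranks, assuming the standard trainer attribute:
        #   avg_loss = self.trainer.training_type_plugin.reduce(loss, reduce_op="mean")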

    def training_step(self, *args, **kwargs) -> Optional[Any]:
        return self.model(*args, **kwargs)

    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        if isinstance(self.model, DistributedDataParallel):
            # used when calling `trainer.fit`
            return self.model(*args, **kwargs)
        else:
            # used when calling `trainer.validate`
            return self.lightning_module.validation_step(*args, **kwargs)

    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        return self.lightning_module.test_step(*args, **kwargs)

    def predict_step(self, *args, **kwargs) -> Any:
        return self.lightning_module.predict_step(*args, **kwargs)

    def post_training_step(self):
        if not self.lightning_module.automatic_optimization:
            self.model.require_backward_grad_sync = True

    @classmethod
    def register_plugins(cls, plugin_registry: Dict) -> None:
        plugin_registry.register(
            "ddp_find_unused_parameters_false",
            cls,
            description="DDP Plugin with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )
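        # Once registered, the alias can be selected by name (sketch, assuming the plugin
        # registry support available in this Lightning version):
        #   Trainer(plugins="ddp_find_unused_parameters_false", gpus=2)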

    def _share_information_to_prevent_deadlock(self):
        self._share_pids()

        # there should be a unique sync_dir per node.
        if self.local_rank == 0:
            # create a temporary directory used to synchronize processes on deadlock.
            self._sync_dir = tempfile.mkdtemp()

        sync_dirs = []
        global_node_rank_zero = 0
        for _ in range(self.num_nodes):
            sync_dirs.append(
                self.broadcast(self._sync_dir, global_node_rank_zero))
            global_node_rank_zero += self.world_size // self.num_nodes

        self._sync_dir = sync_dirs[self.node_rank]
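        # Worked example (added for clarity): with 2 nodes x 4 processes (world_size=8), the loop
        # is intended to broadcast from global ranks 0 and 4, so each node keeps the sync_dir
        # created by its own local rank 0.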

    def _share_pids(self):
        """Make all DDP processes aware of all processes pids."""
        self.barrier()
        pids = self.all_gather(
            torch.tensor(os.getpid(), device=self.root_device))
        pids = pids.cpu().numpy().tolist()
        self._pids = pids if isinstance(pids, list) else [pids]

    def reconciliate_processes(self, trace: str):
        if self.world_size < 2:
            return

        sync_dir = self._sync_dir

        if not sync_dir:
            rank_zero_warn(
                "Error handling mechanism for deadlock detection is uninitialized. Skipping check."
            )
            return

        # The cluster may be configured to periodically purge the `/tmp`
        # directory, in which case `sync_dir` may not exist anymore at this
        # point. Idempotently create it to ensure its existence.
        Path(sync_dir).mkdir(parents=True, exist_ok=True)

        # save a file locally.
        torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl"))

        # sleep for a short time
        time.sleep(3)

        # return if all processes on this node wrote a file to the `sync_dir`.
        # todo (tchaton): add support for non-shared file systems, for which this check will fail.
        if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
            return

        for pid in self._pids:
            if pid != os.getpid():
                os.kill(pid, signal.SIGKILL)
        shutil.rmtree(sync_dir)
        raise DeadlockDetectedException(
            f"Deadlock detected from rank: {self.global_rank} \n {trace}")

    def teardown(self) -> None:
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
Example #11
0
class DDPPlugin(ParallelPlugin):
    """
    Plugin for multi-process single-device training on one or multiple nodes.

    The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`,
    where N is the number of devices (e.g. GPU) per node.
    It is very similar to how :mod:`torch.distributed.launch` launches processes.
    """

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices,
        num_nodes=1,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm=False,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
        self.interactive_ddp_procs = []
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.dist = LightningDistributed()
        self._ddp_kwargs = kwargs
        self._has_spawned_children = False
        self.task_idx = None
        self.node_rank = 0
        self.num_processes = len(parallel_devices)

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs
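        # Example (added for clarity): num_nodes=2 with 4 processes per node yields
        # num_replicas=8, and each process reads the shard identified by its global_rank (0..7).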

    def setup(self, model):
        self._model = model

        # start the other scripts
        # TODO: make sure this works, in torchelastic we should not launch child processes!
        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

    def _call_children_scripts(self):

        # bookkeeping of spawned processes
        assert self.global_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))

        # allow the user to pass the node rank
        node_rank = "0"
        node_rank = os.environ.get("NODE_RANK", node_rank)
        node_rank = os.environ.get("GROUP_RANK", node_rank)
        os.environ["NODE_RANK"] = node_rank
        os.environ["LOCAL_RANK"] = "0"

        # when the user is using Hydra, find the absolute path
        path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

        # pull out the commands used to run the script and resolve the abs file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception:
            full_path = os.path.abspath(command[0])

        command[0] = full_path
        # use the same Python interpreter that is actually running the script
        command = [sys.executable] + command

        # the visible devices tell us how many GPUs we want to use.
        # by the time code reaches this point, the devices have already been scoped
        # when the trainer script was called. so, to launch the child scripts, we leave
        # CUDA_VISIBLE_DEVICES alone but forward the selected GPUs via environment variables
        if self.parallel_devices is None:
            raise MisconfigurationException("you selected (distributed_backend = ddp) but did not set Trainer(gpus=?)")

        os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        if self.lightning_module.logger is not None:
            os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

        num_gpus = len(self.parallel_devices)
        os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders,
            # so stagger the start with a random delay of 1-5 seconds
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
                "You tried to run `.fit` or `.test` multiple times in the same script."
                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
            )

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes
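        # Worked example (added for clarity): num_nodes=2, num_processes=4, node_rank=1,
        # local_rank=2 -> global_rank = 1 * 4 + 2 = 6 and world_size = 2 * 4 = 8.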

    def configure_ddp(self):
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        # TODO: where should the cluster environment come from?
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
        torch_backend = "nccl" if self.on_gpu else "gloo"

        if not torch.distributed.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
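        # Note (added for clarity): `init_process_group` defaults to the env:// init method,
        # which is why MASTER_ADDR, MASTER_PORT and WORLD_SIZE are exported above.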

    def pre_training(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up server using proc 0's ip address
        # try to init at most 20 times in case the ports are taken
        # where to store ip_table
        self.init_ddp_connection(self.global_rank, self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_training
        #   ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)

        # on global rank 0, let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # move the model to the correct device
        self.model_to_device()

        self.configure_ddp()

        self.barrier()

    def post_training(self):
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]

    def barrier(self, *args, **kwargs):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def model_to_device(self):
        if self.root_device.type == "cuda":
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output

    def training_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)
Example #12
0
 def __init__(self, trainer, cluster_environment=None):
     super().__init__(trainer, cluster_environment)
     self.model_autocast_original_forward = None
     self.dist = LightningDistributed()
     self.nickname = 'dp'