Code Example #1
File: ddp.py  Project: nunenuh/pytorch-lightning
    def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
        # averagers, DistributedOptimizer, PostLocalSGDOptimizer and ZeroRedundancyOptimizer
        # come from torch.distributed (torch >= 1.10), OSS from fairscale, and
        # LightningOptimizer from pytorch_lightning; all are imported elsewhere in ddp.py.
        optimizers = self.lightning_module.trainer.optimizers
        if self._model_averaging_period is None:
            raise ValueError(
                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
            )
        # The averager periodically averages model parameters across ranks once the
        # warmup phase has finished.
        averager = averagers.PeriodicModelAverager(
            period=self._model_averaging_period, warmup_steps=warmup_steps)
        for x, optimizer in enumerate(optimizers):
            # Unwrap LightningOptimizer to get the underlying torch optimizer.
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer

            # Optimizers that already distribute or shard their state cannot be re-wrapped.
            if (isinstance(optimizer, DistributedOptimizer)
                    or isinstance(optimizer, ZeroRedundancyOptimizer)
                    or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))):
                raise ValueError(
                    f"Cannot wrap a distributed optimizer of type {type(optimizer).__name__} by PostLocalSGDOptimizer."
                )

            # Already wrapped; nothing to do.
            if isinstance(optimizer, PostLocalSGDOptimizer):
                continue

            # Re-create the optimizer wrapped in a PostLocalSGDOptimizer, which runs the
            # local optimizer step and then lets the averager synchronize parameters.
            optim_class = type(optimizer)
            post_localSGD_optimizer = PostLocalSGDOptimizer(
                params=optimizer.param_groups,
                optimizer_class=optim_class,
                averager=averager,
                **optimizer.defaults,
            )
            optimizers[x] = post_localSGD_optimizer
            del optimizer
        trainer = self.lightning_module.trainer
        trainer.optimizers = optimizers
        trainer.convert_to_lightning_optimizers()
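
The method above is reached when training is configured with the post-localSGD DDP communication hook. Below is a hedged configuration sketch, not taken from the repository: the keyword arguments ddp_comm_state, ddp_comm_hook and model_averaging_period on DDPPlugin and the Trainer strategy argument are assumptions about the Lightning release this plugin targets (PyTorch >= 1.10 is required for the post_localSGD hook), and warmup_steps passed to the method is expected to correspond to the hook state's start_localSGD_iter.

# Hypothetical configuration sketch (assumed keyword names, see note above).
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

ddp = DDPPlugin(
    ddp_comm_state=post_localSGD.PostLocalSGDState(
        process_group=None,      # use the default process group
        subgroup=None,           # average across all ranks
        start_localSGD_iter=8,   # switch from all-reduce DDP to local SGD after 8 steps
    ),
    ddp_comm_hook=post_localSGD.post_localSGD_hook,
    model_averaging_period=4,    # becomes self._model_averaging_period checked above
)

trainer = Trainer(gpus=4, strategy=ddp)  # older releases: Trainer(gpus=4, plugins=[ddp])
# trainer.fit(model)  # model: your LightningModule
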
Code Example #2
    def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
        log.detail(
            f"{self.__class__.__name__}: reinitializing optimizers with post localSGD"
        )
        optimizers = self.lightning_module.trainer.optimizers
        if self._model_averaging_period is None:
            raise ValueError(
                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
            )
        if _TORCH_GREATER_EQUAL_1_10:
            # Deferred imports: these APIs only exist in torch >= 1.10, and
            # DistributedOptimizer is not available on Windows.
            if not _IS_WINDOWS:
                from torch.distributed.optim import DistributedOptimizer
            import torch.distributed.algorithms.model_averaging.averagers as averagers
            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer

        averager = averagers.PeriodicModelAverager(
            period=self._model_averaging_period, warmup_steps=warmup_steps)
        for x, optimizer in enumerate(optimizers):
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer

            is_distributed_optimizer = isinstance(
                optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
            if (is_distributed_optimizer
                    or isinstance(optimizer, ZeroRedundancyOptimizer)
                    or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))):
                raise ValueError(
                    f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
                )

            if isinstance(optimizer, PostLocalSGDOptimizer):
                continue

            optim_class = type(optimizer)
            post_localSGD_optimizer = PostLocalSGDOptimizer(
                params=optimizer.param_groups,
                optimizer_class=optim_class,
                averager=averager,
                **optimizer.defaults,
            )
            optimizers[x] = post_localSGD_optimizer
            del optimizer
        self.optimizers = optimizers
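
To make the wrapping performed in the loop concrete, here is a minimal standalone sketch, not Lightning code: it assumes a single-process gloo group for illustration and the PyTorch 1.10-style PostLocalSGDOptimizer constructor used above (params/optimizer_class/averager/defaults); newer PyTorch releases construct it from an already-built optimizer instead.

# Minimal sketch of wrapping a plain optimizer the way the strategy does above.
import os

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
from torch.distributed.optim import PostLocalSGDOptimizer

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)  # single process, illustration only

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Average parameters every 4 steps once 16 warmup steps have passed.
averager = averagers.PeriodicModelAverager(period=4, warmup_steps=16)

post_localSGD_optimizer = PostLocalSGDOptimizer(
    params=optimizer.param_groups,
    optimizer_class=type(optimizer),
    averager=averager,
    **optimizer.defaults,
)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
post_localSGD_optimizer.step()  # local SGD step; the averager kicks in after warmup

dist.destroy_process_group()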