def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
    # DDP plugin variant. Relies on module-level imports: `averagers`
    # (torch.distributed.algorithms.model_averaging.averagers), `DistributedOptimizer`,
    # `ZeroRedundancyOptimizer`, `PostLocalSGDOptimizer` (torch.distributed.optim),
    # `LightningOptimizer`, `OSS` (fairscale), and `_FAIRSCALE_AVAILABLE`.
    optimizers = self.lightning_module.trainer.optimizers
    if self._model_averaging_period is None:
        raise ValueError(
            "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
        )
    averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
    for x, optimizer in enumerate(optimizers):
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer

        # Distributed/sharded optimizers cannot be wrapped by PostLocalSGDOptimizer.
        if (
            isinstance(optimizer, DistributedOptimizer)
            or isinstance(optimizer, ZeroRedundancyOptimizer)
            or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
        ):
            raise ValueError(
                f"Cannot wrap a distributed optimizer of type {optimizer.__class__.__name__} by PostLocalSGDOptimizer."
            )

        if isinstance(optimizer, PostLocalSGDOptimizer):
            continue

        # Rebuild the same optimizer type inside PostLocalSGDOptimizer, reusing its
        # param groups and defaults, and attach the periodic model averager.
        optim_class = type(optimizer)
        post_localSGD_optimizer = PostLocalSGDOptimizer(
            params=optimizer.param_groups,
            optimizer_class=optim_class,
            averager=averager,
            **optimizer.defaults,
        )
        optimizers[x] = post_localSGD_optimizer
        del optimizer

    trainer = self.lightning_module.trainer
    trainer.optimizers = optimizers
    trainer.convert_to_lightning_optimizers()
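
# --- Illustrative sketch, not part of the Lightning source above -------------------
# What the re-wrapping loop boils down to at the plain torch level: an existing optimizer
# is replaced by a PostLocalSGDOptimizer that reuses its param groups and defaults and
# drives a PeriodicModelAverager. The toy model, hyperparameters, and single-process
# "gloo" group are assumptions made only to keep the example runnable; the constructor
# form mirrors the (params, optimizer_class, averager, **defaults) call used above,
# which later torch releases replaced with passing an already-built optimizer.
import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
from torch.distributed.optim import PostLocalSGDOptimizer

# The averager needs a default process group; a single-process group suffices here.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = torch.nn.Linear(16, 16)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Average model parameters every 4 optimizer steps once 10 warmup steps have passed.
averager = averagers.PeriodicModelAverager(period=4, warmup_steps=10)

optimizer = PostLocalSGDOptimizer(
    params=base_optimizer.param_groups,    # reuse the existing param groups
    optimizer_class=type(base_optimizer),  # rebuild the same optimizer type inside the wrapper
    averager=averager,
    **base_optimizer.defaults,             # carry over lr, momentum, etc.
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()  # local SGD step; the averager syncs parameters when a period is due
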
def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
    # DDP strategy variant: imports happen lazily here, guarded by the torch version
    # and the platform (DistributedOptimizer is not available on Windows).
    log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
    optimizers = self.lightning_module.trainer.optimizers
    if self._model_averaging_period is None:
        raise ValueError(
            "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
        )
    if _TORCH_GREATER_EQUAL_1_10:
        if not _IS_WINDOWS:
            from torch.distributed.optim import DistributedOptimizer
        import torch.distributed.algorithms.model_averaging.averagers as averagers
        from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer

    averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
    for x, optimizer in enumerate(optimizers):
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer

        # Distributed/sharded optimizers cannot be wrapped by PostLocalSGDOptimizer.
        is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
        if (
            is_distributed_optimizer
            or isinstance(optimizer, ZeroRedundancyOptimizer)
            or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
        ):
            raise ValueError(
                f"Cannot wrap a distributed optimizer of type {optimizer.__class__.__name__} by PostLocalSGDOptimizer."
            )

        if isinstance(optimizer, PostLocalSGDOptimizer):
            continue

        # Rebuild the same optimizer type inside PostLocalSGDOptimizer, reusing its
        # param groups and defaults, and attach the periodic model averager.
        optim_class = type(optimizer)
        post_localSGD_optimizer = PostLocalSGDOptimizer(
            params=optimizer.param_groups,
            optimizer_class=optim_class,
            averager=averager,
            **optimizer.defaults,
        )
        optimizers[x] = post_localSGD_optimizer
        del optimizer

    self.optimizers = optimizers
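
# --- Illustrative sketch, not part of the Lightning source above -------------------
# How this code path is typically reached from user code: the DDP strategy is configured
# with torch's post-localSGD communication hook plus a model_averaging_period, which is
# exactly what the method above checks before wrapping the optimizers. The ToyModel,
# device count, and hook parameters below are assumptions for illustration, and the
# exact DDPStrategy arguments may vary with the Lightning version.
import torch
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies import DDPStrategy


class ToyModel(LightningModule):
    """Minimal stand-in LightningModule used only to make the example self-contained."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

trainer = Trainer(
    accelerator="gpu",  # assumes a multi-GPU machine
    devices=4,
    max_epochs=1,
    strategy=DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,  # switch to local SGD after 8 globally synced iterations
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,  # becomes self._model_averaging_period checked above
    ),
)
trainer.fit(ToyModel(), train_loader)
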