def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
    """Swap each non-OSS optimizer in ``optimizers`` for a fairscale ``OSS`` instance.

    Each entry is first unwrapped from ``LightningOptimizer`` (via its
    ``_optimizer`` attribute) when applicable. If the underlying optimizer is
    not already ``OSS``, a new ``OSS`` is built from that optimizer's
    ``param_groups``, its class and its ``defaults``, and stored back at the
    same index. Entries whose underlying optimizer is already ``OSS`` are left
    exactly as they were (including any ``LightningOptimizer`` wrapper).

    Args:
        optimizers: The optimizers to convert. The list is mutated in place.

    Returns:
        The same ``optimizers`` list object that was passed in.
    """
    for idx, opt in enumerate(optimizers):
        opt = opt._optimizer if isinstance(opt, LightningOptimizer) else opt
        if isinstance(opt, OSS):
            # Already sharded: keep the original entry untouched.
            continue

        sharded = OSS(params=opt.param_groups, optim=type(opt), **opt.defaults)
        if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
            # In multi-node training, compressing the model shards to fp16 before
            # broadcasting improves performance; with PyTorch AMP this does not
            # degrade model performance.
            half_precision = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
            sharded.broadcast_fp16 = half_precision and self.num_nodes > 1

        optimizers[idx] = sharded
        # Drop the local reference to the replaced optimizer.
        del opt
    return optimizers
def _reinit_optimizers_with_oss(self): optimizers = self.lightning_module.trainer.optimizers for x, optimizer in enumerate(optimizers): if isinstance(optimizer, LightningOptimizer): optimizer = optimizer._optimizer if not isinstance(optimizer, OSS): optim_class = type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: precision = self.lightning_module.trainer.precision is_fp16 = precision in ("mixed", 16) # For multi-node training, compressing the model shards in fp16 before broadcasting # improves performance. When using PyTorch AMP, it will not degrade # the model performance. zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1 optimizers[x] = zero_optimizer del optimizer trainer = self.lightning_module.trainer trainer.optimizers = optimizers trainer.convert_to_lightning_optimizers()