def __getstate__(self):
    # unwrap optimizer
    self.optimizers = [
        opt._optimizer if is_lightning_optimizer(opt) else opt
        for opt in self.optimizers
    ]
    return self.__dict__
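# Why the unwrap matters: LightningOptimizer wraps the user's optimizer and keeps
# a reference back to the trainer, which does not pickle cleanly, so the wrapper
# is stripped before serialization. A minimal sketch of the helper used above
# (assumption: the real implementation lives in pytorch_lightning.core.optimizer
# and the wrapped optimizer is stored on the `_optimizer` attribute):
def is_lightning_optimizer(optimizer) -> bool:
    from pytorch_lightning.core.optimizer import LightningOptimizer

    return isinstance(optimizer, LightningOptimizer)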
def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
    if is_lightning_optimizer(optimizer):
        optimizer = optimizer._optimizer
    if isinstance(optimizer, OSS):
        optimizer.consolidate_state_dict()
    return self._optim_state_dict(optimizer)
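# Usage sketch for the consolidation step above: OSS shards optimizer state
# across ranks, so the full state dict must be gathered before it can be
# checkpointed. A minimal standalone example, assuming a torch.distributed
# process group is already initialized (e.g. via init_process_group):
import torch
from fairscale.optim import OSS

model = torch.nn.Linear(4, 4)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
optimizer.consolidate_state_dict()   # all-gather the shards (onto rank 0 by default)
full_state = optimizer.state_dict()  # complete state, meaningful on rank 0 only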
def _reinit_with_fairscale_oss(self, trainer):
    optimizers = trainer.optimizers
    for x, optimizer in enumerate(optimizers):
        if is_lightning_optimizer(optimizer):
            optimizer = optimizer._optimizer
        if not isinstance(optimizer, OSS):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            optimizers[x] = zero_optimizer
            del optimizer
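# Sketch of the re-wrap pattern used above: rebuild an existing optimizer as a
# sharded OSS instance with the same optimizer class (`type(optimizer)`) and the
# same hyperparameters (`optimizer.defaults`). Standalone illustration only;
# like the snippet above, it assumes a distributed process group is available:
import torch
from fairscale.optim import OSS

base = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
sharded = OSS(params=base.param_groups, optim=type(base), **base.defaults)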
def _reinit_optimizers_with_oss(self):
    optimizers = self.lightning_module.trainer.optimizers
    for x, optimizer in enumerate(optimizers):
        if is_lightning_optimizer(optimizer):
            optimizer = optimizer._optimizer
        if not isinstance(optimizer, OSS):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            optimizers[x] = zero_optimizer
            del optimizer
    trainer = self.lightning_module.trainer
    trainer.optimizers = optimizers
    trainer.convert_to_lightning_optimizers()
def _reinit_optimizers_with_oss(self):
    optimizers = self.lightning_module.trainer.optimizers
    for x, optimizer in enumerate(optimizers):
        if is_lightning_optimizer(optimizer):
            optimizer = optimizer._optimizer
        if not isinstance(optimizer, OSS):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                is_fp16 = self.lightning_module.trainer.precision == 16
                # For multi-node training, compressing the model shards in fp16 before broadcasting
                # improves performance. When using PyTorch AMP, it will not degrade
                # the model performance.
                zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
            optimizers[x] = zero_optimizer
            del optimizer
    trainer = self.lightning_module.trainer
    trainer.optimizers = optimizers
    trainer.convert_to_lightning_optimizers()
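# Sketch of the guard flag used above: it gates the fp16-broadcast feature on
# the installed fairscale version. The exact minimum version shown here is an
# assumption for illustration, not necessarily the bound pytorch-lightning uses:
import fairscale
from packaging.version import Version

_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = Version(fairscale.__version__) >= Version("0.3.3")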
def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
    if is_lightning_optimizer(optimizer):
        optimizer = optimizer._optimizer
    optimizer.consolidate_state_dict()
    return self._optim_state_dict(optimizer)
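# Note on this variant: unlike the earlier optimizer_state, it assumes the
# optimizer passed in is always an OSS instance, so it consolidates
# unconditionally instead of guarding with isinstance(optimizer, OSS).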