Example #1
 def __getstate__(self):
     # unwrap optimizer
     self.optimizers = [
         opt._optimizer if is_lightning_optimizer(opt) else opt
         for opt in self.optimizers
     ]
     return self.__dict__
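This unwrap-before-pickle step matters because the LightningOptimizer wrapper typically keeps a reference back to the trainer, which does not pickle cleanly. Below is a minimal, self-contained sketch of the same pattern; the _Wrapper and Plugin classes are hypothetical stand-ins for LightningOptimizer and the plugin, not Lightning's actual classes.

import pickle


class _Wrapper:
    # Hypothetical stand-in for LightningOptimizer: wraps a raw optimizer object.
    def __init__(self, optimizer):
        self._optimizer = optimizer


class Plugin:
    def __init__(self, optimizers):
        self.optimizers = optimizers

    def __getstate__(self):
        # Replace wrappers with the raw objects so the instance pickles cleanly.
        self.optimizers = [
            opt._optimizer if isinstance(opt, _Wrapper) else opt for opt in self.optimizers
        ]
        return self.__dict__


plugin = Plugin([_Wrapper("sgd"), "adam"])
print(pickle.loads(pickle.dumps(plugin)).optimizers)  # ['sgd', 'adam']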
Example #2
    def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
        if is_lightning_optimizer(optimizer):
            optimizer = optimizer._optimizer

        if isinstance(optimizer, OSS):
            # OSS shards optimizer state; gather the full state before returning it
            optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)
Example #3
 def _reinit_with_fairscale_oss(self, trainer):
     optimizers = trainer.optimizers
     for x, optimizer in enumerate(optimizers):
         if is_lightning_optimizer(optimizer):
             optimizer = optimizer.optimizer
         if not isinstance(optimizer, OSS):
             # re-create the optimizer as a fairscale OSS wrapper so that its
             # state is sharded across the distributed ranks
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups,
                                  optim=optim_class,
                                  **optimizer.defaults)
             optimizers[x] = zero_optimizer
             del optimizer
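The wrap-in-OSS pattern above can be reproduced outside Lightning. A minimal sketch, assuming fairscale is installed and a torch.distributed process group has already been initialised (the model and learning rate are placeholders):

import torch
from fairscale.optim import OSS

model = torch.nn.Linear(32, 4)
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Re-create the optimizer as an OSS wrapper: same param groups, same optimizer
# class, same defaults, but its state is sharded across the ranks.
sharded_optimizer = OSS(
    params=base_optimizer.param_groups,
    optim=type(base_optimizer),
    **base_optimizer.defaults,
)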
Example #4
 def _reinit_optimizers_with_oss(self):
     optimizers = self.lightning_module.trainer.optimizers
     for x, optimizer in enumerate(optimizers):
         if is_lightning_optimizer(optimizer):
             optimizer = optimizer._optimizer
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             optimizers[x] = zero_optimizer
             del optimizer
     trainer = self.lightning_module.trainer
     trainer.optimizers = optimizers
     trainer.convert_to_lightning_optimizers()
Example #5
 def _reinit_optimizers_with_oss(self):
     optimizers = self.lightning_module.trainer.optimizers
     for x, optimizer in enumerate(optimizers):
         if is_lightning_optimizer(optimizer):
             optimizer = optimizer._optimizer
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups,
                                  optim=optim_class,
                                  **optimizer.defaults)
             if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                 is_fp16 = self.lightning_module.trainer.precision == 16
                 # For multi-node training, compressing the model shards in fp16 before broadcasting
                 # improves performance. When using PyTorch AMP, it will not degrade
                 # the model performance.
                 zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
             optimizers[x] = zero_optimizer
             del optimizer
     trainer = self.lightning_module.trainer
     trainer.optimizers = optimizers
     trainer.convert_to_lightning_optimizers()
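_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE is a guard flag defined inside Lightning; roughly, it checks that the installed fairscale version supports the broadcast_fp16 attribute on OSS. A hedged sketch of how such a flag could be derived; the 0.3.3 cut-off is an assumption, not Lightning's exact definition:

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

try:
    # Assumed cut-off: fp16 shard broadcasting landed around fairscale 0.3.x;
    # the exact version Lightning checks against may differ.
    _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = Version(version("fairscale")) >= Version("0.3.3")
except PackageNotFoundError:
    _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = False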
Example #6
 def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
     if is_lightning_optimizer(optimizer):
         optimizer = optimizer._optimizer
     optimizer.consolidate_state_dict()
     return self._optim_state_dict(optimizer)
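Examples #2 and #6 rely on the fact that each rank only holds its own shard of the optimizer state. A minimal sketch of the consolidate-then-dump flow, assuming fairscale is installed and a distributed process group is already initialised:

import torch.distributed as dist
from fairscale.optim import OSS


def full_optimizer_state(optimizer: OSS) -> dict:
    # Gather every rank's shard onto rank 0 (the default recipient) so that
    # state_dict() can return the complete, checkpoint-ready state there.
    optimizer.consolidate_state_dict(recipient_rank=0)
    return optimizer.state_dict() if dist.get_rank() == 0 else {}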