Example #1
 # Wraps each plain optimizer in fairscale's OSS ("ZeRO"-style) optimizer so
 # that optimizer state is sharded across the data-parallel ranks.
 def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
     for x, optimizer in enumerate(optimizers):
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                 is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
                 # For multi-node training, compressing the model shards in fp16 before broadcasting
                 # improves performance. When using PyTorch AMP, it will not degrade
                 # the model performance.
                 zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
             optimizers[x] = zero_optimizer  # swap the OSS wrapper in, in place
             del optimizer  # drop the local reference to the replaced optimizer
     return optimizers
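
For context, the OSS constructor used above re-creates the original optimizer inside the wrapper: it receives the original optimizer's param_groups, its class, and its hyperparameter defaults. A minimal standalone sketch of that call, assuming fairscale is installed and using a throwaway single-process gloo group (OSS requires an initialized torch.distributed process group):

 import os
 import torch
 import torch.distributed as dist
 from fairscale.optim import OSS

 # A single-process gloo group is enough for OSS to be constructed locally.
 os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
 os.environ.setdefault("MASTER_PORT", "29500")
 dist.init_process_group("gloo", rank=0, world_size=1)

 model = torch.nn.Linear(4, 2)
 base = torch.optim.Adam(model.parameters(), lr=1e-3)

 # Mirrors the call in the example: re-use the original param groups and
 # hyperparameter defaults, and let OSS shard the optimizer state.
 sharded = OSS(params=base.param_groups, optim=type(base), **base.defaults)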
Example #2
 # Older variant of the same re-initialization: it pulls the optimizers off the
 # trainer, rebuilds them as OSS wrappers, and registers them back on the trainer.
 def _reinit_optimizers_with_oss(self):
     optimizers = self.lightning_module.trainer.optimizers
     for x, optimizer in enumerate(optimizers):
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                 precision = self.lightning_module.trainer.precision
                 is_fp16 = precision in ("mixed", 16)
                 # For multi-node training, compressing the model shards in fp16 before broadcasting
                 # improves performance. When using PyTorch AMP, it will not degrade
                 # the model performance.
                 zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
             optimizers[x] = zero_optimizer
             del optimizer
     trainer = self.lightning_module.trainer
     # Hand the (possibly re-wrapped) optimizers back to the trainer and
     # rebuild its LightningOptimizer wrappers around them.
     trainer.optimizers = optimizers
     trainer.convert_to_lightning_optimizers()
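
The fp16-broadcast branch is where the two examples diverge: Example #1 checks PrecisionType enum members, while this older variant compares against the raw "mixed"/16 values the trainer used to expose. Either way the decision reduces to a small predicate, sketched here with a hypothetical helper name (not part of Lightning):

 def should_broadcast_fp16(precision, num_nodes):
     # Compress shards to fp16 only when the run already uses half or mixed
     # precision and the broadcast crosses node boundaries, where the
     # bandwidth saving matters; under PyTorch AMP this does not degrade
     # model performance.
     return precision in ("mixed", 16) and num_nodes > 1

 should_broadcast_fp16("mixed", 4)   # True: multi-node mixed precision
 should_broadcast_fp16(16, 1)        # False: single node, nothing to save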