from typing import Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

# NOTE: AppState, clip_grad_norm_fp32, GradClipAlgorithmType and
# NLPNativeMixedPrecisionPlugin are assumed to be imported from NeMo /
# PyTorch Lightning; their import paths are not shown in the source.


def clip_gradients(
    self,
    optimizer: Optimizer,
    clip_val: Union[int, float],
    gradient_clip_algorithm: GradClipAlgorithmType,
    model: Optional[Module],
) -> None:
    """Override PTL gradient clipping.

    Model parallel models require gradient clipping from megatron-lm.
    """
    if clip_val is None:
        return

    clip_val = float(clip_val)
    if clip_val <= 0:
        return

    app_state = AppState()
    if app_state.model_parallel_size is not None:
        # Model parallel case: clip with the megatron-lm implementation.
        parameters = model.parameters()
        clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
    else:
        return super().clip_gradients(
            optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=model
        )
def configure_gradient_clipping(self, *args, **kwargs):
    """PTL hook to configure gradients.

    We use the gradient clipping implementation from megatron-lm.
    """
    clip_val = self.trainer.gradient_clip_val
    if clip_val is None:
        return

    clip_val = float(clip_val)
    if clip_val <= 0:
        return

    parameters = self.model.parameters()
    clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
def configure_gradient_clipping(self, *args, **kwargs):
    """PTL hook to configure gradients.

    We use the gradient clipping implementation from megatron-lm.
    """
    clip_val = self.trainer.gradient_clip_val
    if clip_val is None:
        return

    clip_val = float(clip_val)
    if clip_val <= 0:
        return

    if self.grad_clip_pl_default:
        # use the default PTL behavior
        return super().configure_gradient_clipping(*args, **kwargs)
    elif self.megatron_amp_o2:
        # grab the fp32 master parameters for gradient clipping
        parameters = self._optimizer.get_parameters()
    else:
        parameters = self._get_parameters()

    grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
    self.log('grad_norm', grad_norm, rank_zero_only=True)
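# A minimal sketch of the `_get_parameters` helper referenced above, assuming the
# module keeps its optimizer parameter groups in `self._optimizer_param_groups`;
# that attribute name is an assumption, not taken from the source.
def _get_parameters(self):
    """Collect all trainable parameters from the optimizer parameter groups."""
    params = []
    for param_group in self._optimizer_param_groups:  # assumed attribute
        params.extend(param_group['params'])
    return params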
def on_before_optimizer_step(self, optimizer, optimizer_idx):
    """PTL hook that is called after gradients are unscaled when using native amp.

    We use the gradient clipping implementation from megatron-lm.
    """
    clip_val = self.trainer.gradient_clip_val
    if clip_val is None:
        return

    clip_val = float(clip_val)
    if clip_val <= 0:
        return

    if isinstance(self.trainer.accelerator_connector.precision_plugin, NLPNativeMixedPrecisionPlugin):
        parameters = self.model.parameters()
        clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
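# A minimal usage sketch, not part of the source: the hooks above only clip when
# the PTL Trainer is configured with a positive gradient_clip_val; None or a
# value <= 0 makes each hook return early.
import pytorch_lightning as pl

trainer = pl.Trainer(gradient_clip_val=1.0)
# trainer.fit(model)  # `model` is a LightningModule implementing the hooks above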