def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
    """Clips gradient norm and updates dynamic loss scaler."""
    self._sync_fp16_grads_to_fp32()
    grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm, aggregate_norm_fn)

    # detect overflow and adjust loss scale
    if self.scaler is not None:
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        prev_scale = self.scaler.loss_scale
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.min_loss_scale:
                # Use FloatingPointError as an uncommon error that parent
                # functions can safely catch to stop training.
                self.scaler.loss_scale = prev_scale
                raise FloatingPointError((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.'
                ).format(self.min_loss_scale))
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

    return grad_norm
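# The FP16 variants in this file call DynamicLossScaler.has_overflow() and
# scaler.update_scale() without showing them. Below is a minimal sketch of that
# contract, assuming the usual dynamic-loss-scaling scheme (shrink the scale on
# overflow, grow it after a window of clean steps); the class name, constants,
# and attribute names here are illustrative assumptions, not the library's
# exact implementation.
class DynamicLossScalerSketch:

    def __init__(self, init_scale=2.0 ** 15, scale_factor=2.0, scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._iter = 0
        self._last_overflow_iter = -1

    @staticmethod
    def has_overflow(grad_norm):
        # an inf or NaN gradient norm signals an FP16 overflow
        return grad_norm == float('inf') or grad_norm != grad_norm

    def update_scale(self, overflow):
        if overflow:
            # shrink the scale and remember when the overflow happened
            self.loss_scale /= self.scale_factor
            self._last_overflow_iter = self._iter
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            # no overflow for a full window: grow the scale again
            self.loss_scale *= self.scale_factor
        self._iter += 1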
def test_clip_grad_norm_(self):
    params = torch.nn.Parameter(torch.zeros(5)).requires_grad_(False)
    grad_norm = utils.clip_grad_norm_(params, 1.0)
    self.assertTrue(torch.is_tensor(grad_norm))
    self.assertEqual(grad_norm, 0.0)

    params = [torch.nn.Parameter(torch.zeros(5)) for i in range(3)]
    for p in params:
        p.grad = torch.full((5,), fill_value=2.0)
    grad_norm = utils.clip_grad_norm_(params, 1.0)
    exp_grad_norm = torch.full((15,), fill_value=2.0).norm()
    self.assertTrue(torch.is_tensor(grad_norm))
    self.assertEqual(grad_norm, exp_grad_norm)

    grad_norm = utils.clip_grad_norm_(params, 1.0)
    self.assertAlmostEqual(grad_norm, torch.tensor(1.0))
def _all_reduce_and_rescale(self, grad_denom):
    # undo effect of dynamic loss scaling on gradients
    grad_denom *= self.scaler.loss_scale

    if self.args.distributed_world_size > 1:
        # flatten grads into a single buffer
        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)

        # scale gradients to avoid overflow in all-reduce
        flat_grads.div_(self.args.distributed_world_size)
        grad_denom /= self.args.distributed_world_size

        # all-reduce flat grads
        torch.distributed.all_reduce(flat_grads)

        # copy grads back to FP32
        self.fp32_params.grad.data.copy_(flat_grads)
    else:
        # single worker: copy grads directly to FP32
        self._get_flat_grads(out=self.fp32_params.grad.data)

    # rescale and clip grads
    self.fp32_params.grad.data.div_(grad_denom)
    grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)

    # detect overflow and adjust loss scale
    overflow = DynamicLossScaler.has_overflow(grad_norm)
    self.scaler.update_scale(overflow)
    if overflow:
        raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

    return grad_norm
def clip_grad_norm(self, max_norm, aggregate_norm_fn=None, mode='total'):
    """Clips gradient norm and updates dynamic loss scaler."""
    self._sync_fp16_grads_to_fp32()
    grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm, aggregate_norm_fn, mode)

    # detect overflow and adjust loss scale
    if self.scaler is not None:
        self.scaler.check_overflow(grad_norm)

    return grad_norm
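# Several clip_grad_norm() variants above accept an aggregate_norm_fn. A common
# use is combining per-rank gradient norms in a model-parallel setup; the sketch
# below shows one plausible such callable. It assumes torch.distributed has been
# initialized and that the norm arrives as a tensor; the function name and the
# group argument are illustrative assumptions, not an established API.
import torch.distributed as dist

def aggregate_norm_across_ranks(total_norm, group=None):
    # combine per-rank L2 norms into a global norm: sqrt(sum of squared norms)
    squared = total_norm.clone() ** 2
    dist.all_reduce(squared, group=group)
    return squared.sqrt()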
def _all_reduce_and_rescale(self, grad_denom):
    # flatten grads into a single buffer and all-reduce
    flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
    if self.args.distributed_world_size > 1:
        torch.distributed.all_reduce(flat_grads)

    # rescale and clip gradients
    flat_grads.div_(grad_denom)
    grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)

    # copy grads back into model parameters
    self._set_flat_grads(flat_grads)

    return grad_norm
def _all_reduce_and_rescale(self, grad_denom):
    # flatten grads into a single buffer and all-reduce
    flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
    if self.args.distributed_world_size > 1:
        torch.distributed.all_reduce(flat_grads)

    # rescale and clip gradients
    # flat_grads.div_(grad_denom)  # disabled: for a correct RL gradient update,
    # this rescaling step is moved into the loss calculation instead
    grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)

    # copy grads back into model parameters
    self._set_flat_grads(flat_grads)

    return grad_norm
def clip_grad_norm(self, max_norm):
    """Clips gradient norm and updates dynamic loss scaler."""
    self._sync_fp16_grads_to_fp32()
    grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)

    # detect overflow and adjust loss scale
    overflow = DynamicLossScaler.has_overflow(grad_norm)
    self.scaler.update_scale(overflow)
    if overflow:
        if self.scaler.loss_scale <= self.args.min_loss_scale:
            raise Exception((
                'Minimum loss scale reached ({}). Your loss is probably exploding. '
                'Try lowering the learning rate, using gradient clipping or '
                'increasing the batch size.'
            ).format(self.args.min_loss_scale))
        raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

    return grad_norm
def clip_grad_norm(self, max_norm):
    """Clips gradient norm and updates dynamic loss scaler."""
    self._sync_fp16_grads_to_fp32()
    grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)

    # detect overflow and adjust loss scale
    overflow = DynamicLossScaler.has_overflow(grad_norm)
    self.scaler.update_scale(overflow)
    if overflow:
        if self.scaler.loss_scale <= self.min_loss_scale:
            # Use FloatingPointError as an uncommon error that parent
            # functions can safely catch to stop training.
            raise FloatingPointError((
                "Minimum loss scale reached ({}). Your loss is probably exploding. "
                "Try lowering the learning rate, using gradient clipping or "
                "increasing the batch size."
            ).format(self.min_loss_scale))
        raise OverflowError("setting loss scale to: " + str(self.scaler.loss_scale))

    return grad_norm
def _all_reduce_and_rescale(self, grad_denom):
    # undo effect of dynamic loss scaling on gradients
    grad_denom *= self.scaler.loss_scale

    if self.args.distributed_world_size > 1:
        # flatten grads into a single buffer
        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)

        # scale gradients to avoid overflow in all-reduce
        flat_grads.div_(self.args.distributed_world_size)
        grad_denom /= self.args.distributed_world_size

        # all-reduce flat grads
        torch.distributed.all_reduce(flat_grads)

        # copy grads back to FP32
        self.fp32_params.grad.data.copy_(flat_grads)
    else:
        # single worker: copy grads directly to FP32
        self._get_flat_grads(out=self.fp32_params.grad.data)

    # rescale and clip grads
    self.fp32_params.grad.data.div_(grad_denom)
    grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)

    # detect overflow and adjust loss scale
    overflow = DynamicLossScaler.has_overflow(grad_norm)
    self.scaler.update_scale(overflow)
    if overflow:
        if self.scaler.loss_scale <= self.args.min_loss_scale:
            raise Exception((
                'Minimum loss scale reached ({}). Your loss is probably exploding. '
                'Try lowering the learning rate, using gradient clipping or '
                'increasing the batch size.'
            ).format(self.args.min_loss_scale))
        raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

    return grad_norm
def _all_reduce_and_rescale(self, grad_denom):
    # undo effect of dynamic loss scaling on gradients
    grad_denom *= self.scaler.loss_scale

    if self.args.distributed_world_size > 1:
        # flatten grads into a single buffer
        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)

        # scale gradients to avoid overflow in all-reduce
        flat_grads.div_(self.args.distributed_world_size)
        grad_denom /= self.args.distributed_world_size

        # all-reduce flat grads
        torch.distributed.all_reduce(flat_grads)

        # copy grads back to FP32
        self.fp32_params.grad.data.copy_(flat_grads)
    else:
        # single worker: copy grads directly to FP32
        self._get_flat_grads(out=self.fp32_params.grad.data)

    # rescale and clip grads
    self.fp32_params.grad.data.div_(grad_denom)
    grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)

    # detect overflow and adjust loss scale
    overflow = DynamicLossScaler.has_overflow(grad_norm)
    self.scaler.update_scale(overflow)
    if overflow:
        if self.scaler.loss_scale <= self.args.min_loss_scale:
            raise Exception((
                'Minimum loss scale reached ({}). Your loss is probably exploding. '
                'Try lowering the learning rate, using gradient clipping or '
                'increasing the batch size.'
            ).format(self.args.min_loss_scale))
        raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

    return grad_norm
def clip_grad_norm(self, max_norm, loss):
    """Clips gradient norm and updates dynamic loss scaler."""
    self._sync_fp16_grads_to_fp32()
    grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)

    # detect overflow and adjust loss scale
    overflow = DynamicLossScaler.has_overflow(grad_norm)
    self.scaler.update_scale(overflow)
    if overflow:
        if self.scaler.loss_scale <= self.args.min_loss_scale:
            print("**********************************fp32 params data", self.fp32_params.data.norm(), flush=True)
            print("**********************************fp32 abs max", self.fp32_params.data.abs().max(), flush=True)
            print("**********************************fp32 params grad data", self.fp32_params.grad.data.norm(), flush=True)
            print("**********************************grad norm ", grad_norm)
            print("*******************!!!!!!!!!!!!!", loss)
            # Use FloatingPointError as an uncommon error that parent
            # functions can safely catch to stop training.
            raise FloatingPointError((
                'Minimum loss scale reached ({}). Your loss is probably exploding. '
                'Try lowering the learning rate, using gradient clipping or '
                'increasing the batch size.'
            ).format(self.args.min_loss_scale))
        raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

    return grad_norm
def clip_grad_norm(self, max_norm):
    """Clips gradient norm."""
    return utils.clip_grad_norm_(self.params, max_norm)
def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
    """Clips gradient norm."""
    return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)
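# Usage sketch for the clip_grad_norm() variants above: a caller typically
# treats OverflowError as "skip this update" (the loss scale has already been
# reduced) and lets FloatingPointError/Exception propagate to stop training.
# The optimizer method names used here (backward, step, zero_grad) are
# assumptions for illustration, not a specific trainer's API.
def train_step_sketch(optimizer, loss, max_norm):
    # backward pass under the current loss scale
    optimizer.backward(loss)
    try:
        # clip gradients; FP16 variants also update the dynamic loss scaler
        grad_norm = optimizer.clip_grad_norm(max_norm)
        optimizer.step()
    except OverflowError:
        # gradient overflow under the current loss scale: drop this step
        optimizer.zero_grad()
        grad_norm = None
    return grad_norm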