Example #1
    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm,
                                          aggregate_norm_fn)

        # detect overflow and adjust loss scale
        if self.scaler is not None:
            overflow = DynamicLossScaler.has_overflow(grad_norm)
            prev_scale = self.scaler.loss_scale
            self.scaler.update_scale(overflow)
            if overflow:
                if self.scaler.loss_scale <= self.min_loss_scale:
                    # Use FloatingPointError as an uncommon error that parent
                    # functions can safely catch to stop training.
                    self.scaler.loss_scale = prev_scale
                    raise FloatingPointError((
                        'Minimum loss scale reached ({}). Your loss is probably exploding. '
                        'Try lowering the learning rate, using gradient clipping or '
                        'increasing the batch size.').format(
                            self.min_loss_scale))
                raise OverflowError('setting loss scale to: ' +
                                    str(self.scaler.loss_scale))

        return grad_norm
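The OverflowError above signals that the FP16 gradients overflowed and the scaler has already lowered the loss scale, while FloatingPointError signals that the minimum loss scale was reached and training should stop. Below is a minimal sketch of how a caller might react to both; train_step, optimizer, compute_loss and batch are hypothetical names used only for illustration, not part of the example above.

def train_step(optimizer, compute_loss, batch, max_norm=0.25):
    # hypothetical caller for the clip_grad_norm shown above
    loss = compute_loss(batch)
    optimizer.backward(loss)
    try:
        grad_norm = optimizer.clip_grad_norm(max_norm)
        optimizer.step()
    except OverflowError:
        # gradients overflowed: the loss scale was already reduced, so
        # drop this update and let the next batch retry at the new scale
        optimizer.zero_grad()
        grad_norm = None
    except FloatingPointError:
        # minimum loss scale reached: propagate so the training loop stops,
        # as the comment in the example intends
        raise
    return grad_norm

Catching OverflowError at this level keeps a single overflowing batch from aborting the run, while FloatingPointError is left to terminate training.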
Example #2
    def test_clip_grad_norm_(self):
        # a parameter with no gradient: the returned norm is a tensor equal to 0
        params = torch.nn.Parameter(torch.zeros(5)).requires_grad_(False)
        grad_norm = utils.clip_grad_norm_(params, 1.0)
        self.assertTrue(torch.is_tensor(grad_norm))
        self.assertEqual(grad_norm, 0.0)

        # three parameters whose gradients are all 2.0: the returned norm equals
        # the norm of the concatenated 15-element gradient vector
        params = [torch.nn.Parameter(torch.zeros(5)) for i in range(3)]
        for p in params:
            p.grad = torch.full((5,), fill_value=2.0)
        grad_norm = utils.clip_grad_norm_(params, 1.0)
        exp_grad_norm = torch.full((15,), fill_value=2.0).norm()
        self.assertTrue(torch.is_tensor(grad_norm))
        self.assertEqual(grad_norm, exp_grad_norm)

        # clipping happens in place, so a second call sees gradients whose norm
        # has already been scaled down to max_norm (1.0)
        grad_norm = utils.clip_grad_norm_(params, 1.0)
        self.assertAlmostEqual(grad_norm, torch.tensor(1.0))
Example #3
    def _all_reduce_and_rescale(self, grad_denom):
        # undo effect of dynamic loss scaling on gradients
        grad_denom *= self.scaler.loss_scale

        if self.args.distributed_world_size > 1:
            # flatten grads into a single buffer
            flat_grads = self._flat_grads = self._get_flat_grads(
                self._flat_grads)

            # scale gradients to avoid overflow in all-reduce
            flat_grads.div_(self.args.distributed_world_size)
            grad_denom /= self.args.distributed_world_size

            # all-reduce flat grads
            torch.distributed.all_reduce(flat_grads)

            # copy grads back to FP32
            self.fp32_params.grad.data.copy_(flat_grads)
        else:
            # single worker: copy grads directly to FP32
            self._get_flat_grads(out=self.fp32_params.grad.data)

        # rescale and clip grads
        self.fp32_params.grad.data.div_(grad_denom)
        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data,
                                          self.args.clip_norm)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            raise OverflowError('setting loss scale to: ' +
                                str(self.scaler.loss_scale))

        return grad_norm
Example #4
    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None, mode='total'):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm, aggregate_norm_fn, mode)

        # detect overflow and adjust loss scale
        if self.scaler is not None:
            self.scaler.check_overflow(grad_norm)

        return grad_norm
Example #5
    def _all_reduce_and_rescale(self, grad_denom):
        # flatten grads into a single buffer and all-reduce
        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
        if self.args.distributed_world_size > 1:
            torch.distributed.all_reduce(flat_grads)

        # rescale and clip gradients
        flat_grads.div_(grad_denom)
        grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)

        # copy grads back into model parameters
        self._set_flat_grads(flat_grads)

        return grad_norm
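Examples #5 through #7 rely on helper methods _get_flat_grads and _set_flat_grads that are not shown here. A rough sketch of the underlying pattern, assuming the helpers simply concatenate per-parameter gradients into one buffer and copy them back afterwards, could look like the following; the function name and the manual norm clipping are illustrative stand-ins for the helpers and for utils.clip_grad_norm_, not their actual implementations.

import torch
import torch.distributed as dist

def all_reduce_and_rescale(params, grad_denom, clip_norm):
    # flatten all gradients into one contiguous buffer so a single
    # all-reduce call can sum them across workers
    flat_grads = torch.cat([p.grad.reshape(-1) for p in params])

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flat_grads)  # sum gradients across workers

    # rescale, e.g. by the total number of sentences or tokens in the batch
    flat_grads.div_(grad_denom)

    # clip by global norm (simplified stand-in for utils.clip_grad_norm_)
    grad_norm = flat_grads.norm()
    if clip_norm > 0 and grad_norm > clip_norm:
        flat_grads.mul_(clip_norm / (grad_norm + 1e-6))

    # copy the reduced, rescaled gradients back into the parameters
    offset = 0
    for p in params:
        numel = p.grad.numel()
        p.grad.copy_(flat_grads[offset:offset + numel].view_as(p.grad))
        offset += numel
    return grad_norm

Flattening into one buffer means a single all-reduce call instead of one per parameter, which is the main point of the pattern in these examples.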
Example #6
    def _all_reduce_and_rescale(self, grad_denom):
        # flatten grads into a single buffer and all-reduce
        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
        if self.args.distributed_world_size > 1:
            torch.distributed.all_reduce(flat_grads)

        # rescale and clip gradients
        # flat_grads.div_(grad_denom)  # disabled: for a correct RL gradient update, this rescaling is moved into the loss calculation
        grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)

        # copy grads back into model parameters
        self._set_flat_grads(flat_grads)

        return grad_norm
Example #7
    def _all_reduce_and_rescale(self, grad_denom):
        # flatten grads into a single buffer and all-reduce
        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
        if self.args.distributed_world_size > 1:
            torch.distributed.all_reduce(flat_grads)

        # rescale and clip gradients
        flat_grads.div_(grad_denom)
        grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)

        # copy grads back into model parameters
        self._set_flat_grads(flat_grads)

        return grad_norm
Example #8
    def clip_grad_norm(self, max_norm):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.args.min_loss_scale:
                raise Exception((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.').format(
                        self.args.min_loss_scale))
            raise OverflowError('setting loss scale to: ' +
                                str(self.scaler.loss_scale))
        return grad_norm
Example #9
    def clip_grad_norm(self, max_norm):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.min_loss_scale:
                # Use FloatingPointError as an uncommon error that parent
                # functions can safely catch to stop training.
                raise FloatingPointError((
                    "Minimum loss scale reached ({}). Your loss is probably exploding. "
                    "Try lowering the learning rate, using gradient clipping or "
                    "increasing the batch size.").format(self.min_loss_scale))
            raise OverflowError("setting loss scale to: " +
                                str(self.scaler.loss_scale))
        return grad_norm
Example #10
    def _all_reduce_and_rescale(self, grad_denom):
        # undo effect of dynamic loss scaling on gradients
        grad_denom *= self.scaler.loss_scale

        if self.args.distributed_world_size > 1:
            # flatten grads into a single buffer
            flat_grads = self._flat_grads = self._get_flat_grads(
                self._flat_grads)

            # scale gradients to avoid overflow in all-reduce
            flat_grads.div_(self.args.distributed_world_size)
            grad_denom /= self.args.distributed_world_size

            # all-reduce flat grads
            torch.distributed.all_reduce(flat_grads)

            # copy grads back to FP32
            self.fp32_params.grad.data.copy_(flat_grads)
        else:
            # single worker: copy grads directly to FP32
            self._get_flat_grads(out=self.fp32_params.grad.data)

        # rescale and clip grads
        self.fp32_params.grad.data.div_(grad_denom)
        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data,
                                          self.args.clip_norm)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.args.min_loss_scale:
                raise Exception((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.').format(
                        self.args.min_loss_scale))
            raise OverflowError('setting loss scale to: ' +
                                str(self.scaler.loss_scale))

        return grad_norm
Example #11
    def _all_reduce_and_rescale(self, grad_denom):
        # undo effect of dynamic loss scaling on gradients
        grad_denom *= self.scaler.loss_scale

        if self.args.distributed_world_size > 1:
            # flatten grads into a single buffer
            flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)

            # scale gradients to avoid overflow in all-reduce
            flat_grads.div_(self.args.distributed_world_size)
            grad_denom /= self.args.distributed_world_size

            # all-reduce flat grads
            torch.distributed.all_reduce(flat_grads)

            # copy grads back to FP32
            self.fp32_params.grad.data.copy_(flat_grads)
        else:
            # single worker: copy grads directly to FP32
            self._get_flat_grads(out=self.fp32_params.grad.data)

        # rescale and clip grads
        self.fp32_params.grad.data.div_(grad_denom)
        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.args.min_loss_scale:
                raise Exception((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.'
                ).format(self.args.min_loss_scale))
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

        return grad_norm
Example #12
    def clip_grad_norm(self, max_norm, loss):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.args.min_loss_scale:
                print ("**********************************fp32 params data", self.fp32_params.data.norm(), flush=True)
                print ("**********************************fp32 abs max",  self.fp32_params.data.abs().max(), flush=True)
                print ("**********************************fp32 params grad data", self.fp32_params.grad.data.norm(),flush=True)
                print ("**********************************grad norm ", grad_norm)
                print ("*******************!!!!!!!!!!!!!", loss)
                # Use FloatingPointError as an uncommon error that parent
                # functions can safely catch to stop training.
                raise FloatingPointError((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.'
                ).format(self.args.min_loss_scale))
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
        return grad_norm
Example #13
    def clip_grad_norm(self, max_norm):
        """Clips gradient norm."""
        return utils.clip_grad_norm_(self.params, max_norm)
Example #14
    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)
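Several of the signatures above accept an aggregate_norm_fn argument. Assuming it receives the locally computed total norm and returns a norm aggregated across workers, as in model-parallel setups where each worker only holds a shard of the parameters, a minimal sketch might look like this; aggregate_model_parallel_norm is a hypothetical name, not part of the examples above.

import torch
import torch.distributed as dist

def aggregate_model_parallel_norm(local_norm: torch.Tensor) -> torch.Tensor:
    # combine per-shard L2 norms into a global norm: sqrt(sum of squares)
    squared = local_norm ** 2
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(squared)  # sum the squared norms across shards
    return squared ** 0.5

# hypothetical usage with the signatures from the examples above:
# grad_norm = optimizer.clip_grad_norm(max_norm,
#                                      aggregate_norm_fn=aggregate_model_parallel_norm)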