def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
    if trainer.track_grad_norm == -1:
        return
    grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
    if grad_norm_dict:
        prev_fx = trainer.lightning_module._current_fx_name
        trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
        trainer.lightning_module.log_grad_norm(grad_norm_dict)
        trainer.lightning_module._current_fx_name = prev_fx
    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
        if trainer.track_grad_norm == -1:
            return

        kwargs = {}
        if len(trainer.loggers) == 1:
            kwargs["group_separator"] = trainer.loggers[0].group_separator

        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
        if grad_norm_dict:
            prev_fx = trainer.lightning_module._current_fx_name
            trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
            trainer.lightning_module.log_grad_norm(grad_norm_dict)
            trainer.lightning_module._current_fx_name = prev_fx
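# How the hook above is typically driven from user code. This is a minimal sketch
# assuming the 1.x-era `Trainer(track_grad_norm=...)` argument and the default
# `log_grad_norm` hook; `TinyModule` is a hypothetical module defined only for
# illustration, not part of the library.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)
    # -1 (the default) disables tracking; any positive p logs the per-parameter
    # p-norms and the total norm through `log_grad_norm` on every logging step.
    trainer = pl.Trainer(max_steps=8, log_every_n_steps=1, track_grad_norm=2)
    trainer.fit(TinyModule(), data)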
def test_grad_norm(norm_type, expected):
    """Test utility function for computing the p-norm of individual parameter groups and norm in total."""
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.param0 = nn.Parameter(torch.rand(3))
            self.param1 = nn.Parameter(torch.rand(2, 1))
            self.param0.grad = torch.tensor([-1.0, 2.0, -3.0])
            self.param1.grad = torch.tensor([[-4.0], [5.0]])
            # param without grad should not contribute to norm
            self.param2 = nn.Parameter(torch.rand(1))

    model = Model()
    norms = grad_norm(model, norm_type)
    expected = {k: round(v, 4) for k, v in expected.items()}
    assert norms == expected
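# A worked check of the fixture above, independent of Lightning: for norm_type=2,
# these are the per-group norms and the total norm that `expected` would contain.
# The key naming ("grad_2.0_norm" prefix, "/" separator) is an assumption inferred
# from how `group_separator` is used in the snippets above, not copied from real output.
import torch

param0_grad = torch.tensor([-1.0, 2.0, -3.0])
param1_grad = torch.tensor([[-4.0], [5.0]])

norm0 = param0_grad.norm(2).item()                   # sqrt(1 + 4 + 9)  ~= 3.7417
norm1 = param1_grad.norm(2).item()                   # sqrt(16 + 25)    ~= 6.4031
total = torch.tensor([norm0, norm1]).norm(2).item()  # sqrt(14 + 41)    ~= 7.4162
# param2 has no gradient, so it contributes nothing, matching the comment in the test.

expected_2_norm = {
    "grad_2.0_norm/param0": round(norm0, 4),
    "grad_2.0_norm/param1": round(norm1, 4),
    "grad_2.0_norm_total": round(total, 4),
}
print(expected_2_norm)  # {'grad_2.0_norm/param0': 3.7417, 'grad_2.0_norm/param1': 6.4031, 'grad_2.0_norm_total': 7.4162}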
    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
        """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

        Args:
            optimizer: the current optimizer
        """
        # track gradient norms
        grad_norm_dict = {}
        can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
        should_track = float(self.trainer.track_grad_norm) > 0
        if should_track and can_log:
            grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)

        # clip gradients
        self.trainer.accelerator.clip_gradients(
            optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
        )
        return grad_norm_dict
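# A rough sketch (not Lightning's implementation) of what the `clip_gradients`
# call above boils down to in plain PyTorch. `clip_gradients_sketch` and its
# signature are hypothetical; only the two torch.nn.utils helpers are real.
from typing import Iterable, Optional

import torch
from torch import nn


def clip_gradients_sketch(
    parameters: Iterable[nn.Parameter], clip_val: Optional[float], algorithm: str = "norm"
) -> None:
    if clip_val is None or clip_val <= 0:
        return  # clipping disabled, mirroring `gradient_clip_val=None` or 0
    if algorithm == "norm":
        # rescale all gradients together so their global 2-norm does not exceed clip_val
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=clip_val)
    elif algorithm == "value":
        # clamp every gradient element into [-clip_val, clip_val]
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
    else:
        raise ValueError(f"Unknown gradient_clip_algorithm: {algorithm!r}")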
def test_grad_norm_invalid_norm_type(norm_type):
    with pytest.raises(ValueError, match="`norm_type` must be a positive number or 'inf'"):
        grad_norm(Mock(), norm_type)
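# What the test above exercises: a minimal sketch of the guard at the top of
# `grad_norm`, reconstructed from the matched error message rather than copied
# from the real code. The test would typically be parametrized with non-positive
# values such as -1 and 0 (assumed here, not shown in the excerpt).
def _validate_norm_type(norm_type) -> float:
    norm_type = float(norm_type)  # the string 'inf' is accepted and becomes float('inf')
    if norm_type <= 0:
        raise ValueError(f"`norm_type` must be a positive number or 'inf', got {norm_type}")
    return norm_type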