Example #1
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the norm test.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Result of the norm test.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)
        D_axis = 1
        individual_l2_norms_squared = (
            individual_gradients_flat**2).sum(D_axis)

        grad = torch.cat([p.grad.flatten() for p in params])
        grad_norm = grad.norm()

        projections = torch.einsum("ni,i->n", individual_gradients_flat, grad)

        batch_size = get_batch_size(global_step)

        return ((1 / (batch_size * (batch_size - 1)) *
                 (individual_l2_norms_squared / grad_norm**2 -
                  (projections**2) / grad_norm**4).sum()).sqrt().item())
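
The snippet above relies on helpers such as get_individual_losses and autograd_individual_gradients that are defined elsewhere. As a hedged, self-contained illustration, the same statistic can be reproduced from an explicit [N, D] tensor of per-sample gradients; the function name and the assumption that the mini-batch gradient equals the row mean of the individual gradients are illustrative choices, not part of the original.

import torch


def projection_statistic_sketch(individual_gradients_flat):
    """Sketch: the statistic from ``_compute`` above, given an ``[N, D]`` tensor
    of per-sample gradients (assumes a mean-reduced loss, so the mini-batch
    gradient is the row mean of the individual gradients)."""
    batch_size = individual_gradients_flat.shape[0]
    grad = individual_gradients_flat.mean(0)
    grad_norm = grad.norm()

    individual_l2_norms_squared = (individual_gradients_flat**2).sum(1)
    projections = torch.einsum("ni,i->n", individual_gradients_flat, grad)

    return ((1 / (batch_size * (batch_size - 1)) *
             (individual_l2_norms_squared / grad_norm**2 -
              (projections**2) / grad_norm**4).sum()).sqrt().item())


print(projection_statistic_sketch(torch.randn(8, 5)))  # toy per-sample gradients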
Example #2
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the MeanGSNR.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Mean GSNR of the current iteration.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)

        grad_squared = torch.cat([p.grad.flatten() for p in params])**2

        N_axis = 0
        second_moment_flat = (individual_gradients_flat**2).mean(N_axis)

        gsnr = grad_squared / (second_moment_flat - grad_squared +
                               self._epsilon)

        return gsnr.mean().item()
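
For context, here is a minimal self-contained sketch of the same GSNR computation on an explicit [N, D] tensor of per-sample gradients. The function name, the epsilon default, and the assumption that the mini-batch gradient is the row mean of the individual gradients are illustrative, not taken from the original.

import torch


def mean_gsnr_sketch(individual_gradients_flat, epsilon=1e-5):
    """Sketch: mean gradient signal-to-noise ratio over all parameter entries."""
    grad_squared = individual_gradients_flat.mean(0)**2  # squared mini-batch gradient
    second_moment_flat = (individual_gradients_flat**2).mean(0)  # per-entry second moment

    gsnr = grad_squared / (second_moment_flat - grad_squared + epsilon)

    return gsnr.mean().item()


print(mean_gsnr_sketch(torch.randn(8, 5)))  # toy per-sample gradients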
Example #3
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the norm test.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Result of the norm test.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(
            losses, params, concat=True
        )
        sum_of_squares = (individual_gradients_flat ** 2).sum()

        grad_norm = torch.cat([p.grad.flatten() for p in params]).norm()

        batch_size = get_batch_size(global_step)

        return (
            (
                1
                / (batch_size * (batch_size - 1))
                * (sum_of_squares / grad_norm ** 2 - batch_size)
            )
            .sqrt()
            .item()
        )
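
The same statistic can be written as a stand-alone function that only takes the [N, D] tensor of per-sample gradients. This is a hedged sketch; the function name and the mean-reduction assumption for the mini-batch gradient are illustrative.

import torch


def norm_test_sketch(individual_gradients_flat):
    """Sketch: norm test statistic from an ``[N, D]`` tensor of per-sample gradients."""
    batch_size = individual_gradients_flat.shape[0]
    grad_norm = individual_gradients_flat.mean(0).norm()
    sum_of_squares = (individual_gradients_flat**2).sum()

    return (
        (1 / (batch_size * (batch_size - 1)) *
         (sum_of_squares / grad_norm**2 - batch_size)).sqrt().item()
    )


print(norm_test_sketch(torch.randn(8, 5)))  # toy per-sample gradients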
Example #4
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the TIC approximating the Hessian by its diagonal.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Raises:
            NotImplementedError: If curvature is not ``diag_h``.

        Returns:
            float: TIC when approximating the Hessian by its diagonal.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)

        if self._curvature == "diag_h":
            diag_curvature_flat = autograd_diag_hessian(batch_loss,
                                                        params,
                                                        concat=True)
        else:
            raise NotImplementedError("Only Hessian diagonal is implemented")

        N_axis = 0
        second_moment_flat = (individual_gradients_flat**2).mean(N_axis)

        return (second_moment_flat /
                (diag_curvature_flat + self._epsilon)).sum().item()
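
Stripped of the class plumbing, this quantity is the ratio of the gradient second moment to the (damped) Hessian diagonal, summed over all entries. A hedged sketch with made-up inputs; the names and epsilon default are assumptions.

import torch


def tic_diag_sketch(individual_gradients_flat, diag_hessian_flat, epsilon=1e-5):
    """Sketch: TIC with an element-wise (diagonal) curvature approximation."""
    second_moment_flat = (individual_gradients_flat**2).mean(0)

    return (second_moment_flat / (diag_hessian_flat + epsilon)).sum().item()


print(tic_diag_sketch(torch.randn(8, 5), torch.rand(5)))  # toy gradients and curvature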
Example #5
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the TIC approximation proposed in thomas2020interplay.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Raises:
            NotImplementedError: If curvature is not ``diag_h``.

        Returns:
            float: TIC using a trace approximation.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)

        if self._curvature == "diag_h":
            curvature_trace = autograd_diag_hessian(batch_loss,
                                                    params,
                                                    concat=True).sum()
        else:
            raise NotImplementedError("Only Hessian trace is implemented")

        N_axis = 0
        mean_squared_l2_norm = (
            individual_gradients_flat**2).mean(N_axis).sum()

        return (mean_squared_l2_norm /
                (curvature_trace + self._epsilon)).item()
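
The trace variant collapses both numerator and denominator to scalars before dividing. A hedged sketch with illustrative inputs; the names and epsilon default are assumptions.

import torch


def tic_trace_sketch(individual_gradients_flat, hessian_trace, epsilon=1e-5):
    """Sketch: TIC with the Hessian trace as the curvature summary."""
    mean_squared_l2_norm = (individual_gradients_flat**2).mean(0).sum()

    return (mean_squared_l2_norm / (hessian_trace + epsilon)).item()


print(tic_trace_sketch(torch.randn(8, 5), torch.tensor(2.5)))  # toy inputs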
Example #6
    def _fetch_values(self, params, batch_loss, pos, global_step):
        """Fetch values for quadratic fit. Return as dictionary.

        The entry "search_dir" is only initialized if ``pos`` is ``"start"``.

        Args:
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
            pos (str): Whether we are at the start or end of an iteration.
                One of ``start`` or ``end``.
            global_step (int): The current iteration number.

        Raises:
            ValueError: If pos is not one of ``start`` or ``end``.

        Returns:
            dict: Holding the parameters, (variance of) loss and slope.
        """
        info = {}

        if pos in ["start", "end"]:
            # 0ᵗʰ order info
            info["f"] = batch_loss.item()
            losses = get_individual_losses(global_step)
            info["var_f"] = losses.var().item()

            # temporary information required to compute quantities used in fit
            info["params"] = {id(p): p.data.clone().detach() for p in params}
            info["grad"] = {id(p): p.grad.data.clone().detach() for p in params}
            batch_grad = autograd_individual_gradients(losses, params)
            info["batch_grad"] = {
                id(p): bg.data.clone().detach() for p, bg in zip(params, batch_grad)
            }

        else:
            raise ValueError(f"Invalid position '{pos}'. Expect {self._positions}.")

        # compute all quantities used in fit
        if pos == "end":
            start_params, _ = self._get_info("params", end=False)
            end_params = info["params"]

            search_dir = [
                end_params[key] - start_params[key] for key in start_params.keys()
            ]

            for info_dict in [self._start_info, info]:
                grad = [info_dict["grad"][key] for key in start_params.keys()]
                batch_grad = [
                    info_dict["batch_grad"][key] for key in start_params.keys()
                ]

                # 1ˢᵗ order info
                info_dict["df"] = _projected_gradient(grad, search_dir)
                info_dict["var_df"] = _exact_variance(batch_grad, search_dir)

        return info
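
The helpers _projected_gradient and _exact_variance are not shown above. As an assumption-labeled sketch (not the original implementation), they plausibly compute the directional derivative of the loss along the normalized search direction and the variance of the corresponding per-sample directional derivatives.

import torch


def projected_gradient_sketch(grad, search_dir):
    """Sketch (assumption): gradient projected onto the normalized search direction."""
    dot = sum((g * d).sum() for g, d in zip(grad, search_dir))
    norm = sum((d**2).sum() for d in search_dir).sqrt()

    return (dot / norm).item()


def exact_variance_sketch(batch_grad, search_dir):
    """Sketch (assumption): variance of per-sample directional derivatives."""
    norm = sum((d**2).sum() for d in search_dir).sqrt()
    per_sample = sum(
        torch.einsum("ni,i->n", bg.flatten(1), d.flatten())
        for bg, d in zip(batch_grad, search_dir)
    ) / norm

    return per_sample.var().item()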
Example #7
    def _get_abs_max(self, global_step, params, batch_loss, range):
        """Compute the maximum absolute value of individual gradient elements.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
            range (float, float): Current bin limits.

        Returns:
            float: Maximum absolute value of individual gradients.
        """
        individual_losses = get_individual_losses(global_step)
        individual_gradients = autograd_individual_gradients(individual_losses,
                                                             params,
                                                             concat=True)
        return individual_gradients.abs().max().item()
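
Note that the range argument (the current bin limits) is not used inside this snippet. A hedged guess at how a caller might combine the returned maximum with those limits to decide whether the histogram range needs widening; the helper name and threshold logic are assumptions.

def needs_wider_bins_sketch(abs_max, current_range):
    """Sketch (assumption): report whether gradient values would spill outside the bins."""
    lower, upper = current_range

    return abs_max > max(abs(lower), abs(upper))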
Example #8
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the individual gradient histogram.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            dict: Entry ``'hist'`` holds the histogram, entry ``'edges'`` holds
                the bin limits.
        """
        individual_losses = get_individual_losses(global_step)
        individual_gradients = autograd_individual_gradients(individual_losses,
                                                             params,
                                                             concat=True)
        hist, edges = self._compute_histogram(individual_gradients)

        return {"hist": hist.float(), "edges": edges}
Example #9
    def _compute(self, global_step, params, batch_loss):
        """Aggregate histogram data over parameters and save to output."""
        individual_losses = get_individual_losses(global_step)
        individual_gradients = autograd_individual_gradients(
            individual_losses, params)
        layerwise = [
            self._compute_histogram(p, igrad)
            for p, igrad in zip(params, individual_gradients)
        ]

        hist = sum(out[0] for out in layerwise)
        edges = layerwise[0][1]

        result = {"hist": hist, "edges": edges}

        if self._keep_individual:
            result["param_groups"] = len(params)

            for idx, (hist, edges) in enumerate(layerwise):
                result[f"param_{idx}"] = {"hist": hist, "edges": edges}

        return result
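
For completeness, a hedged sketch of how the returned dictionary might be consumed downstream; the keys follow the code above, while the helper name and everything else are illustrative.

def summarize_histograms(result):
    """Sketch: inspect the aggregated histogram and any per-parameter entries."""
    print("aggregated counts:", result["hist"].sum().item())

    for idx in range(result.get("param_groups", 0)):
        print(f"param_{idx} counts:", result[f"param_{idx}"]["hist"].sum().item())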