Example #1
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the MeanGSNR.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Mean GSNR of the current iteration.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)

        grad_squared = torch.cat([p.grad.flatten() for p in params])**2

        N_axis = 0
        second_moment_flat = (individual_gradients_flat**2).mean(N_axis)

        gsnr = grad_squared / (second_moment_flat - grad_squared +
                               self._epsilon)

        return gsnr.mean().item()
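
For reference, this transcribes the quantity evaluated above, with \bar{g}_j the mini-batch gradient (``p.grad``), g_{nj} the individual gradients, N the batch size, D the number of parameters, and \epsilon the stabilizer ``self._epsilon``:

    \mathrm{GSNR}_j = \frac{\bar{g}_j^2}{\frac{1}{N}\sum_{n=1}^{N} g_{nj}^2 - \bar{g}_j^2 + \epsilon},
    \qquad
    \text{returned value} = \frac{1}{D} \sum_{j=1}^{D} \mathrm{GSNR}_j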
Example #2
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the norm test.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Result of the orthogonality test.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)
        D_axis = 1
        individual_l2_norms_squared = (
            individual_gradients_flat**2).sum(D_axis)

        grad = torch.cat([p.grad.flatten() for p in params])
        grad_norm = grad.norm()

        projections = torch.einsum("ni,i->n", individual_gradients_flat, grad)

        batch_size = get_batch_size(global_step)

        return ((1 / (batch_size * (batch_size - 1)) *
                 (individual_l2_norms_squared / grad_norm**2 -
                  (projections**2) / grad_norm**4).sum()).sqrt().item())
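
Written out, the returned statistic is (with g_n the individual gradients, \bar{g} the mini-batch gradient, and B the batch size):

    \sqrt{\frac{1}{B(B-1)} \sum_{n=1}^{B} \left( \frac{\lVert g_n \rVert^2}{\lVert \bar{g} \rVert^2} - \frac{(g_n^\top \bar{g})^2}{\lVert \bar{g} \rVert^4} \right)}

i.e. a sample estimate built from the squared components of the individual gradients orthogonal to the mini-batch gradient, which distinguishes it from the norm test in Example #4.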
Example #3
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the CABS criterion.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Evaluated CABS criterion.

        Raises:
            ValueError: If the optimizer differs from SGD with default arguments.
        """
        optimizer = get_optimizer(global_step)
        if not ComputeStep.is_sgd_default_kwargs(optimizer):
            raise ValueError("This criterion only supports zero-momentum SGD.")

        losses = get_individual_losses(global_step)
        batch_axis = 0
        trace_variance = autograd_diagonal_variance(
            losses, params, concat=True, unbiased=False).sum(batch_axis)
        lr = self.get_lr(optimizer)

        return (lr * trace_variance / batch_loss).item()
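
In symbols, the returned value transcribes to (with \eta the learning rate, \mathrm{Var}[g_{nj}] the biased per-parameter gradient variance, and L the mini-batch loss):

    \eta \, \frac{\sum_{j} \mathrm{Var}[g_{nj}]}{L}

i.e. the learning rate times the trace of the gradient variance, divided by the mini-batch loss.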
Example #4
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the norm test.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Result of the norm test.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(
            losses, params, concat=True
        )
        sum_of_squares = (individual_gradients_flat ** 2).sum()

        grad_norm = torch.cat([p.grad.flatten() for p in params]).norm()

        batch_size = get_batch_size(global_step)

        return (
            (
                1
                / (batch_size * (batch_size - 1))
                * (sum_of_squares / grad_norm ** 2 - batch_size)
            )
            .sqrt()
            .item()
        )
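
The returned statistic transcribes to (with g_n the individual gradients, \bar{g} the mini-batch gradient, and B the batch size):

    \sqrt{\frac{1}{B(B-1)} \left( \frac{\sum_{n=1}^{B} \lVert g_n \rVert^2}{\lVert \bar{g} \rVert^2} - B \right)}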
Example #5
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the TIC approximating the Hessian by its diagonal.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Raises:
            NotImplementedError: If curvature is not ``diag_h``.

        Returns:
            float: TIC when approximating the Hessian by its diagonal.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)

        if self._curvature == "diag_h":
            diag_curvature_flat = autograd_diag_hessian(batch_loss,
                                                        params,
                                                        concat=True)
        else:
            raise NotImplementedError("Only Hessian diagonal is implemented")

        N_axis = 0
        second_moment_flat = (individual_gradients_flat**2).mean(N_axis)

        return (second_moment_flat /
                (diag_curvature_flat + self._epsilon)).sum().item()
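
Transcribed, with m_j = \frac{1}{N}\sum_{n} g_{nj}^2 the per-parameter gradient second moment and H_{jj} the Hessian diagonal:

    \mathrm{TIC}_{\text{diag}} = \sum_{j} \frac{m_j}{H_{jj} + \epsilon}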
Example #6
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the TIC approximation proposed in thomas2020interplay.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Raises:
            NotImplementedError: If curvature is not ``diag_h``.

        Returns:
            float: TIC using a trace approximation.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)

        if self._curvature == "diag_h":
            curvature_trace = autograd_diag_hessian(batch_loss,
                                                    params,
                                                    concat=True).sum()
        else:
            raise NotImplementedError("Only Hessian trace is implemented")

        N_axis = 0
        mean_squared_l2_norm = (
            individual_gradients_flat**2).mean(N_axis).sum()

        return (mean_squared_l2_norm /
                (curvature_trace + self._epsilon)).item()
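
The trace variant replaces the element-wise ratios of the previous example by a single ratio of sums:

    \mathrm{TIC}_{\text{trace}} = \frac{\sum_{j} m_j}{\mathrm{Tr}(H) + \epsilon}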
Example #7
    def _fetch_values(self, params, batch_loss, pos, global_step):
        """Fetch values for quadratic fit. Return as dictionary.

        The entry "search_dir" is only initialized if ``pos`` is ``"start"``.

        Args:
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
            pos (str): Whether we are at the start or end of an iteration.
                One of ``start`` or ``end``.
            global_step (int): The current iteration number.

        Raises:
            ValueError: If pos is not one of ``start`` or ``end``.

        Returns:
            dict: Holding the parameters, (variance of) loss and slope.
        """
        info = {}

        if pos in ["start", "end"]:
            # 0ᵗʰ order info
            info["f"] = batch_loss.item()
            losses = get_individual_losses(global_step)
            info["var_f"] = losses.var().item()

            # temporary information required to compute quantities used in fit
            info["params"] = {id(p): p.data.clone().detach() for p in params}
            info["grad"] = {id(p): p.grad.data.clone().detach() for p in params}
            batch_grad = autograd_individual_gradients(losses, params)
            info["batch_grad"] = {
                id(p): bg.data.clone().detach() for p, bg in zip(params, batch_grad)
            }

        else:
            raise ValueError(f"Invalid position '{pos}'. Expect {self._positions}.")

        # compute all quantities used in fit
        if pos == "end":
            start_params, _ = self._get_info("params", end=False)
            end_params = info["params"]

            search_dir = [
                end_params[key] - start_params[key] for key in start_params.keys()
            ]

            for info_dict in [self._start_info, info]:
                grad = [info_dict["grad"][key] for key in start_params.keys()]
                batch_grad = [
                    info_dict["batch_grad"][key] for key in start_params.keys()
                ]

                # 1ˢᵗ order info
                info_dict["df"] = _projected_gradient(grad, search_dir)
                info_dict["var_df"] = _exact_variance(batch_grad, search_dir)

        return info
Example #8
    def _fetch_values(self, params, batch_loss, pos, global_step):
        """Fetch values for quadratic fit. Return as dictionary.

        The entry "search_dir" is only initialized if ``pos`` is ``"start"``.
        """
        info = {}

        if pos in ["start", "end"]:
            # 0ᵗʰ order info
            info["f"] = batch_loss.item()
            info["var_f"] = get_individual_losses(global_step).var().item()

            # temporary information required to compute quantities used in fit
            info["params"] = {id(p): p.data.clone().detach() for p in params}
            info["grad"] = {
                id(p): p.grad.data.clone().detach()
                for p in params
            }
            # L = ¹/ₙ ∑ᵢ ℓᵢ, but BackPACK computes ¹/ₙ ∇ℓᵢ, so we have to rescale
            batch_size = get_batch_size(global_step)
            info["batch_grad"] = {
                id(p): batch_size * p.grad_batch.data.clone().detach()
                for p in params
            }

        else:
            raise ValueError(
                f"Invalid position '{pos}'. Expect {self._positions}.")

        # compute all quantities used in fit
        # TODO Restructure base class and move this to another function
        if pos == "end":
            start_params, _ = self._get_info("params", end=False)
            end_params = info["params"]

            search_dir = [
                end_params[key] - start_params[key]
                for key in start_params.keys()
            ]

            for info_dict in [self._start_info, info]:
                grad = [info_dict["grad"][key] for key in start_params.keys()]
                batch_grad = [
                    info_dict["batch_grad"][key]
                    for key in start_params.keys()
                ]

                # 1ˢᵗ order info
                info_dict["df"] = _projected_gradient(grad, search_dir)
                info_dict["var_df"] = _exact_variance(batch_grad, search_dir)

        return info
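
The rescaling comment above relies on BackPACK populating ``p.grad_batch`` with the individual gradients of the mean-reduced loss. A minimal, self-contained sketch of how such per-sample gradients are typically obtained with BackPACK's ``BatchGrad`` extension; the toy model, loss, and data shapes are illustrative assumptions, not part of the original code:

import torch
from backpack import backpack, extend, extensions

model = extend(torch.nn.Linear(10, 1))                 # modules must be extended for BackPACK
lossfunc = extend(torch.nn.MSELoss(reduction="mean"))

X, y = torch.rand(8, 10), torch.rand(8, 1)
loss = lossfunc(model(X), y)

with backpack(extensions.BatchGrad()):
    loss.backward()

for p in model.parameters():
    # For a mean-reduced loss, p.grad_batch[i] holds ¹/ₙ ∇ℓᵢ, which is why the
    # code above multiplies by the batch size to recover ∇ℓᵢ.
    print(p.grad_batch.shape)                          # [N, *p.shape]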
Example #9
    def _fetch_values(self, params, batch_loss, pos, global_step):
        """Fetch values for quadratic fit. Return as dictionary.

        The entry "search_dir" is only initialized if ``pos`` is ``"start"``.
        """
        info = {}

        info["params"] = {id(p): p.data.clone().detach() for p in params}

        # 0ᵗʰ order info
        info["f"] = batch_loss.item()
        info["var_f"] = get_individual_losses(global_step).var().item()

        # 1ˢᵗ order info
        info["df"], info["var_df"] = self._fetch_df_and_var_df(params, pos)

        return info
Example #10
    def _get_abs_max(self, global_step, params, batch_loss, range):
        """Compute the maximum absolute value of individual gradient elements.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
            range (float, float): Current bin limits.

        Returns:
            float: Maximum absolute value of individual gradients.
        """
        individual_losses = get_individual_losses(global_step)
        individual_gradients = autograd_individual_gradients(individual_losses,
                                                             params,
                                                             concat=True)
        return individual_gradients.abs().max().item()
Example #11
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the individual gradient histogram.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            dict: Entry ``'hist'`` holds the histogram, entry ``'edges'`` holds
                the bin limits.
        """
        individual_losses = get_individual_losses(global_step)
        individual_gradients = autograd_individual_gradients(individual_losses,
                                                             params,
                                                             concat=True)
        hist, edges = self._compute_histogram(individual_gradients)

        return {"hist": hist.float(), "edges": edges}
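
``self._compute_histogram`` is not shown here. A minimal torch-only sketch of what binning the flattened individual gradients could look like; the helper name, bin count, and bin limits below are illustrative assumptions:

import torch

def compute_histogram(individual_gradients, bins=40, xmin=-1.0, xmax=1.0):
    """Bin all individual-gradient elements into equally sized bins over [xmin, xmax]."""
    hist = torch.histc(individual_gradients, bins=bins, min=xmin, max=xmax)
    edges = torch.linspace(xmin, xmax, steps=bins + 1)
    return hist, edges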
Example #12
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the early stopping criterion.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Early stopping criterion.
        """
        grad_squared = torch.cat([p.grad.flatten() for p in params])**2

        losses = get_individual_losses(global_step)
        diag_variance = autograd_diagonal_variance(losses, params, concat=True)

        B = get_batch_size(global_step)

        return 1 - B * (grad_squared /
                        (diag_variance + self._epsilon)).mean().item()
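
Transcribed, with B the batch size, D the number of parameters, \bar{g}_j the mini-batch gradient, and \mathrm{Var}[g_{nj}] the diagonal gradient variance:

    1 - \frac{B}{D} \sum_{j=1}^{D} \frac{\bar{g}_j^2}{\mathrm{Var}[g_{nj}] + \epsilon}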
Example #13
    def _save_0th_order_info(self, global_step, params, batch_loss, point,
                             until):
        """Store 0ᵗʰ-order information about the objective in cache.

        Modifies ``self._cache``, creating entries ``f_*`` and ``var_f_*``.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
            point (str): Description of point, ``'start'`` or ``'end'``.
            until (int): Iteration number until which deletion from cache is blocked.
        """
        block_fn = self._make_block_fn(global_step, until)

        f = batch_loss.item()
        self.save_to_cache(global_step, f"f_{point}", f, block_fn)

        var_f = get_individual_losses(global_step).var().item()
        self.save_to_cache(global_step, f"var_f_{point}", var_f, block_fn)
Example #14
    def _compute(self, global_step, params, batch_loss):
        """Aggregate histogram data over parameters and save to output."""
        individual_losses = get_individual_losses(global_step)
        individual_gradients = autograd_individual_gradients(
            individual_losses, params)
        layerwise = [
            self._compute_histogram(p, igrad)
            for p, igrad in zip(params, individual_gradients)
        ]

        hist = sum(out[0] for out in layerwise)
        edges = layerwise[0][1]

        result = {"hist": hist, "edges": edges}

        if self._keep_individual:
            result["param_groups"] = len(params)

            for idx, (hist, edges) in enumerate(layerwise):
                result[f"param_{idx}"] = {"hist": hist, "edges": edges}

        return result
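
The element-wise sum ``hist = sum(out[0] for out in layerwise)`` is only meaningful if ``self._compute_histogram`` uses the same bin edges for every parameter. A small torch-only illustration of that aggregation step (the bin limits and random data are illustrative assumptions; the binning itself is sketched after Example #11):

import torch

edges = torch.linspace(-1.0, 1.0, steps=11)            # shared bin limits across layers
layer_hists = [
    torch.histc(torch.randn(50), bins=10, min=-1.0, max=1.0) for _ in range(3)
]
hist = sum(layer_hists)                                # element-wise sum per bin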