Example #1
    def compute(self, global_step, params, batch_loss):
        """Evaluate the trace of the Hessian at the current point.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
        """
        if self.is_active(global_step):
            edges = self._get_current_bin_edges()
            hist = sum(p.grad_batch_transforms["hist_1d"] for p in params)

            if self._check:
                batch_size = get_batch_size(global_step)
                num_params = sum(p.numel() for p in params)
                num_counts = hist.sum()
                assert batch_size * num_params == num_counts

            self.output[global_step]["hist_1d"] = hist.cpu().numpy().tolist()
            self.output[global_step]["edges"] = edges.cpu().numpy().tolist()

            if self._verbose:
                print(f"[Step {global_step}] BatchGradHistogram1d" +
                      f" edges 0,...,4: {edges[:5]}")
                print(f"[Step {global_step}] BatchGradHistogram1d" +
                      f" counts 0,...,4: {hist[:5]}")

        self._update_limits(global_step, params, batch_loss)
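
A note on the assertion: a histogram taken over all individual gradient elements of a mini-batch must contain exactly batch_size · num_params counts, provided the values are clamped into the binning range. A minimal plain-torch sketch of this invariant (independent of cockpit's ``grad_batch_transforms``; shapes and bin limits are made up for illustration):

    import torch

    B, D = 8, 5                              # batch size, number of parameters
    indiv_grads = torch.randn(B, D)          # stand-in for per-sample gradients
    xmin, xmax, bins = -3.0, 3.0, 10
    hist = torch.histc(indiv_grads.clamp(xmin, xmax), bins=bins, min=xmin, max=xmax)
    assert int(hist.sum()) == B * D          # every entry lands in some bin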
Example #2
        def _compute_gsnr_from_batch_grad(params):
            """Gradient signal-to-noise ratio.

            Implements Equation (25) in liu2020understanding; the quantities it
            uses are defined in the prose between Equations (1) and (2).
            """
            batch_grad = self._fetch_batch_grad(params, aggregate=True)

            if self._use_double:
                batch_grad = batch_grad.double()

            batch_size = get_batch_size(global_step)

            rescaled_batch_grad = batch_size * batch_grad

            grad_first_moment_squared = (rescaled_batch_grad).mean(0)**2
            grad_second_moment = (rescaled_batch_grad**2).mean(0)
            grad_variance = grad_second_moment - grad_first_moment_squared

            if has_negative(grad_variance + self._epsilon):
                raise ValueError(
                    "Gradient variance + ε has negative entries.")

            if has_zeros(grad_variance + self._epsilon):
                raise ValueError("Gradient variance + ε has zeros.")

            return grad_first_moment_squared / (grad_variance + self._epsilon)
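
Written out, with g_{ij} = B · (BackPACK batch gradient) denoting the rescaled per-sample gradient of coordinate j, the function returns the element-wise estimate (a reconstruction of the code above; ε is ``self._epsilon``):

    \[
        \widehat{\mathrm{GSNR}}_j
        = \frac{\big(\tfrac{1}{B}\sum_{i=1}^{B} g_{ij}\big)^2}
               {\tfrac{1}{B}\sum_{i=1}^{B} g_{ij}^2
                - \big(\tfrac{1}{B}\sum_{i=1}^{B} g_{ij}\big)^2 + \varepsilon}
    \]

i.e. the squared first moment over the (biased) variance, stabilized by ε.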
Example #3
    def _compute_individual(self, global_step, params, batch_loss):
        """Save histogram for each parameter to output."""
        for idx, p in enumerate(params):
            x_edges, y_edges = self._get_current_bin_edges()

            hist = p.grad_batch_transforms["hist_2d"]

            if self._check:
                batch_size = get_batch_size(global_step)
                num_params = p.numel()
                num_counts = hist.sum()
                assert batch_size * num_params == num_counts

            self.output[global_step][f"param_{idx}_hist_2d"] = (
                hist.cpu().numpy().tolist())
            self.output[global_step][f"param_{idx}_x_edges"] = (
                x_edges.cpu().numpy().tolist())
            self.output[global_step][f"param_{idx}_y_edges"] = (
                y_edges.cpu().numpy().tolist())

            if self._verbose:
                print(
                    f"[Step {global_step}] BatchGradHistogram2d param_{idx}" +
                    f" x_edges 0,...,4: {x_edges[:5]}")
                print(
                    f"[Step {global_step}] BatchGradHistogram2d param_{idx}" +
                    f" y_edges 0,...,4: {y_edges[:5]}")
                print(
                    f"[Step {global_step}] BatchGradHistogram2d param_{idx}" +
                    f" counts [0,...,4][0,...,4]: {hist[:5,:5]}")
        self.output[global_step]["param_groups"] = len(params)
Example #4
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the norm test.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Result of the norm test.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(
            losses, params, concat=True
        )
        sum_of_squares = (individual_gradients_flat ** 2).sum()

        grad_norm = torch.cat([p.grad.flatten() for p in params]).norm()

        batch_size = get_batch_size(global_step)

        return (
            (
                1
                / (batch_size * (batch_size - 1))
                * (sum_of_squares / grad_norm ** 2 - batch_size)
            )
            .sqrt()
            .item()
        )
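
With g_i the per-sample gradients (from ``autograd_individual_gradients``) and g = ¹/B ∑_i g_i the mini-batch gradient, the returned value is the sample statistic of the norm test (a reconstruction of the code above):

    \[
        \sqrt{\frac{1}{B(B-1)}
              \left( \frac{\sum_{i=1}^{B} \lVert g_i \rVert^2}{\lVert g \rVert^2} - B \right)}
    \]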
Example #5
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the norm test.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Result of the orthogonality test.
        """
        losses = get_individual_losses(global_step)
        individual_gradients_flat = autograd_individual_gradients(losses,
                                                                  params,
                                                                  concat=True)
        D_axis = 1
        individual_l2_norms_squared = (
            individual_gradients_flat**2).sum(D_axis)

        grad = torch.cat([p.grad.flatten() for p in params])
        grad_norm = grad.norm()

        projections = torch.einsum("ni,i->n", individual_gradients_flat, grad)

        batch_size = get_batch_size(global_step)

        return ((1 / (batch_size * (batch_size - 1)) *
                 (individual_l2_norms_squared / grad_norm**2 -
                  (projections**2) / grad_norm**4).sum()).sqrt().item())
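
In the same notation, the statistic returned here is (one way to write the expression implemented above):

    \[
        \sqrt{\frac{1}{B(B-1)} \sum_{i=1}^{B}
              \left( \frac{\lVert g_i \rVert^2}{\lVert g \rVert^2}
                     - \frac{(g_i^\top g)^2}{\lVert g \rVert^4} \right)}
    \]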
Example #6
File: alpha.py Project: f-dangel/cockpit
        def hook(grad_batch):
            """Project the end point gradients onto the start point's update direction.

            Modifies ``self._cache``, creating entry ``update_dot_grad_batch_end``.

            Args:
                grad_batch (torch.Tensor): Result of BackPACK's ``BatchGrad``
                    extension.

            Returns:
                dict: Empty dictionary.
            """
            param_id = id(grad_batch._param_weakref())

            search_dir = self.load_from_cache(start_step,
                                              "update_start")[param_id]
            # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale
            batch_size = get_batch_size(end_step)
            dot_products = batch_size * self.batched_dot_product(
                grad_batch, search_dir)

            # update or create cache for ``dᵀgᵢ``
            key = "update_dot_grad_batch_end"
            try:
                update_dot_dict = self.load_from_cache(end_step, key)
            except KeyError:
                update_dot_dict = {}
            finally:
                update_dot_dict[param_id] = dot_products
                self.save_to_cache(end_step, key, update_dot_dict, block_fn)

            return {}
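
``self.batched_dot_product`` is cockpit-internal. A hypothetical stand-in that computes the per-sample dot products dᵀgᵢ, assuming ``grad_batch`` has shape [N, *p.shape] and ``search_dir`` the shape of p, could look like this (a sketch, not cockpit's API):

    import torch

    def batched_dot_product(grad_batch, search_dir):
        """Return one scalar dᵀgᵢ per sample in the batch."""
        return torch.einsum(
            "ni,i->n", grad_batch.flatten(start_dim=1), search_dir.flatten()
        )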
Example #7
    def _fetch_values(self, params, batch_loss, pos, global_step):
        """Fetch values for quadratic fit. Return as dictionary.

        The entry "search_dir" is only initialized if ``pos`` is ``"start"``.
        """
        info = {}

        if pos in ["start", "end"]:
            # 0ᵗʰ order info
            info["f"] = batch_loss.item()
            info["var_f"] = get_individual_losses(global_step).var().item()

            # temporary information required to compute quantities used in fit
            info["params"] = {id(p): p.data.clone().detach() for p in params}
            info["grad"] = {
                id(p): p.grad.data.clone().detach()
                for p in params
            }
            # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale
            batch_size = get_batch_size(global_step)
            info["batch_grad"] = {
                id(p): batch_size * p.grad_batch.data.clone().detach()
                for p in params
            }

        else:
            raise ValueError(
                f"Invalid position '{pos}'. Expected one of {self._positions}.")

        # compute all quantities used in fit
        # TODO Restructure base class and move to other function
        if pos == "end":
            start_params, _ = self._get_info("params", end=False)
            end_params = info["params"]

            search_dir = [
                end_params[key] - start_params[key]
                for key in start_params.keys()
            ]

            for info_dict in [self._start_info, info]:
                grad = [info_dict["grad"][key] for key in start_params.keys()]
                batch_grad = [
                    info_dict["batch_grad"][key]
                    for key in start_params.keys()
                ]

                # 1ˢᵗ order info
                info_dict["df"] = _projected_gradient(grad, search_dir)
                info_dict["var_df"] = _exact_variance(batch_grad, search_dir)

        return info
Example #8
    def _compute_gsnr(self, global_step, params, batch_loss):
        """Compute gradient signal-to-noise ratio.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of parameters considered in the computation.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            torch.Tensor: Element-wise GSNR of the current iteration.
        """
        grad_squared = self._fetch_grad(params, aggregate=True)**2
        sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True)

        batch_size = get_batch_size(global_step)

        return grad_squared / (batch_size * sum_grad_squared - grad_squared +
                               self._epsilon)
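
Assuming the fetched ``sum_grad_squared`` is ∑_i (∇_j ℓ_i / B)², i.e. BackPACK's ¹/B-scaled per-sample gradients squared and summed over the batch, the returned quantity reduces to the same mean-squared-over-variance estimate as the ``batch_grad`` variant above:

    \[
        \frac{g_j^2}{B \cdot \mathrm{sgs}_j - g_j^2 + \varepsilon}
        = \frac{\big(\tfrac{1}{B}\sum_i \nabla_j \ell_i\big)^2}
               {\tfrac{1}{B}\sum_i (\nabla_j \ell_i)^2
                - \big(\tfrac{1}{B}\sum_i \nabla_j \ell_i\big)^2 + \varepsilon}
    \]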
Example #9
File: tic.py Project: MeNicefellow/cockpit
        def _compute_tic_with_batch_grad(params):
            """TICTrace."""
            batch_grad = self._fetch_batch_grad(params, aggregate=True)
            curvature = self._fetch_diag_curvature(params,
                                                   self._curvature,
                                                   aggregate=True)

            if self._use_double:
                batch_grad = batch_grad.double()
                curvature = curvature.double()

            curv_trace_stable = curvature.sum() + self._epsilon
            if has_zeros(curv_trace_stable):
                raise ValueError("Curvature trace + ε has zeros.")
            if has_negative(curv_trace_stable):
                raise ValueError("Curvature trace + ε has negative entries.")

            batch_size = get_batch_size(global_step)

            return batch_size * (batch_grad**2).sum() / curv_trace_stable
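
With c_j the fetched diagonal curvature and ¹/B ∇ℓᵢ the BackPACK-scaled per-sample gradients stored in ``batch_grad``, the returned trace approximation reads:

    \[
        \mathrm{TICTrace}
        = \frac{B \sum_{i,j} \big(\tfrac{1}{B} \nabla_j \ell_i\big)^2}{\sum_j c_j + \varepsilon}
        = \frac{\tfrac{1}{B} \sum_i \lVert \nabla \ell_i \rVert^2}{\mathrm{Tr}(C) + \varepsilon}
    \]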
Example #10
    def _compute_gsnr(self, global_step, params, batch_loss):
        """Compute gradient signal-to-noise ratio.

        Args:
            params ([torch.Tensor]): List of parameters considered in the computation.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
        """
        if self._use_double:
            grad_squared = self._fetch_grad(params, aggregate=True).double()**2
            sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
                params, aggregate=True).double()
        else:
            grad_squared = self._fetch_grad(params, aggregate=True)**2
            sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
                params, aggregate=True)

        batch_size = get_batch_size(global_step)

        return grad_squared / (batch_size * sum_grad_squared - grad_squared +
                               self._epsilon)
Example #11
    def _compute(self, global_step, params, batch_loss):
        """Compute the TICTrace using a trace approximation.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: TIC computed using a trace approximation.
        """
        sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True
        )
        curvature = self._fetch_diag_curvature(params, self._curvature, aggregate=True)
        batch_size = get_batch_size(global_step)

        return (
            batch_size * sum_grad_squared.sum() / (curvature.sum() + self._epsilon)
        ).item()
Example #12
    def _compute(self, global_step, params, batch_loss):
        """Evaluate the early stopping criterion.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Early stopping criterion.
        """
        grad_squared = torch.cat([p.grad.flatten() for p in params])**2

        losses = get_individual_losses(global_step)
        diag_variance = autograd_diagonal_variance(losses, params, concat=True)

        B = get_batch_size(global_step)

        return 1 - B * (grad_squared /
                        (diag_variance + self._epsilon)).mean().item()
Example #13
    def _compute(self, global_step, params, batch_loss):
        """Compute the CABS rule. Return suggested batch size.

        Evaluates Equ. 22 of

        - Balles, L., Romero, J., & Hennig, P.,
          Coupling adaptive batch sizes with learning rates (2017).
        """
        B = get_batch_size(global_step)
        lr = self._lr

        grad_squared = self._fetch_grad(params, aggregate=True) ** 2
        # compensate BackPACK's 1/B scaling
        sgs_compensated = (
            B ** 2
            * self._fetch_sum_grad_squared_via_batch_grad_transforms(
                params, aggregate=True
            )
        )

        return (
            lr * (sgs_compensated - B * grad_squared).sum() / (B * batch_loss)
        ).item()
Example #14
    def _compute(self, global_step, params, batch_loss):
        """Compute the EB early stopping criterion.

        Evaluates the left hand side of Equ. 7 in

        - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
          Early stopping without a validation set (2017).

        If this value exceeds 0, training should be stopped.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Result of the early stopping criterion. Training should stop
                if it is larger than 0.

        Raises:
            ValueError: If the used optimizer differs from SGD with default parameters.
        """
        if not ComputeStep.is_sgd_default_kwargs(get_optimizer(global_step)):
            raise ValueError("This criterion only supports zero-momentum SGD.")

        B = get_batch_size(global_step)

        grad_squared = self._fetch_grad(params, aggregate=True)**2

        # compensate BackPACK's 1/B scaling
        sgs_compensated = (
            B**2 * self._fetch_sum_grad_squared_via_batch_grad_transforms(
                params, aggregate=True))

        diag_variance = (sgs_compensated - B * grad_squared) / (B - 1)

        snr = grad_squared / (diag_variance + self._epsilon)

        return 1 - B * snr.mean().item()
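
Step by step, using that sgs_j = ∑_i (∇_j ℓ_i / B)² and g_j is the mini-batch gradient: B² sgs_j − B g_j² = ∑_i (∇_j ℓ_i)² − B g_j² = ∑_i (∇_j ℓ_i − g_j)², so ``diag_variance`` is the unbiased sample variance of the per-sample gradients, and the returned criterion (over D parameter entries) is:

    \[
        \hat{\sigma}_j^2 = \frac{1}{B-1} \sum_{i=1}^{B} \big(\nabla_j \ell_i - g_j\big)^2,
        \qquad
        1 - \frac{B}{D} \sum_{j=1}^{D} \frac{g_j^2}{\hat{\sigma}_j^2 + \varepsilon}
    \]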
Example #15
        def _compute_projection_variance_from_batch_grad(params):
            """Compute variance of individual gradient projections on the gradient.

            The sample variance of the projections is given by the (unnumbered)
            equation after Equation (2.6) in bollapragada2017adaptive
            (https://arxiv.org/pdf/1710.11258.pdf).
            """
            batch_grad = self._fetch_batch_grad(params, aggregate=True)
            batch_size = get_batch_size(global_step)
            grad = self._fetch_grad(params, aggregate=True)
            grad_l2_squared = self._fetch_grad_l2_squared(params,
                                                          aggregate=True)

            if self._use_double:
                batch_grad = batch_grad.double()
                grad = grad.double()
                grad_l2_squared = grad_l2_squared.double()

            projections = torch.einsum("ni,i->n", batch_size * batch_grad,
                                       grad)

            return (1 / (batch_size - 1)) * (
                (projections**2).sum() - batch_size * grad_l2_squared**2)
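
Since the batch mean of the projections gᵢᵀg equals ‖g‖² (with gᵢ = B × the BackPACK batch gradient), the returned quantity is the usual unbiased sample variance of the projections:

    \[
        \frac{1}{B-1} \sum_{i=1}^{B} \big(g_i^\top g - \lVert g \rVert^2\big)^2
        = \frac{1}{B-1} \left( \sum_{i=1}^{B} (g_i^\top g)^2 - B \lVert g \rVert^4 \right)
    \]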
Example #16
File: tic.py Project: MeNicefellow/cockpit
    def _compute(self, global_step, params, batch_loss):
        """Compute the TICTrace using a trace approximation.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of parameters considered in the computation.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
        """
        sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True)
        curvature = self._fetch_diag_curvature(params,
                                               self._curvature,
                                               aggregate=True)

        if self._use_double:
            sum_grad_squared = sum_grad_squared.double()
            curvature = curvature.double()

        batch_size = get_batch_size(global_step)

        return batch_size * sum_grad_squared.sum() / (curvature.sum() +
                                                      self._epsilon)
Example #17
    def _compute_aggregated(self, global_step, params, batch_loss):
        """Aggregate histogram data over parameters and save to output."""
        x_edges, y_edges = self._get_current_bin_edges()
        hist = sum(p.grad_batch_transforms["hist_2d"] for p in params)

        if self._check:
            batch_size = get_batch_size(global_step)
            num_params = sum(p.numel() for p in params)
            num_counts = hist.sum()
            assert batch_size * num_params == num_counts

        self.output[global_step]["hist_2d"] = hist.cpu().numpy().tolist()
        self.output[global_step]["x_edges"] = x_edges.cpu().numpy().tolist()
        self.output[global_step]["y_edges"] = y_edges.cpu().numpy().tolist()

        if self._verbose:
            print(f"[Step {global_step}] BatchGradHistogram2d" +
                  f" x_edges 0,...,4: {x_edges[:5]}")
            print(f"[Step {global_step}] BatchGradHistogram2d" +
                  f" y_edges 0,...,4: {y_edges[:5]}")
            print(f"[Step {global_step}] BatchGradHistogram2d" +
                  f" counts [0,...,4][0,...,4]: {hist[:5,:5]}")
Example #18
File: tic.py Project: MeNicefellow/cockpit
        def _compute_tic_with_batch_grad(params):
            """TICDiag."""
            batch_grad = self._fetch_batch_grad(params, aggregate=True)
            curvature = self._fetch_diag_curvature(params,
                                                   self._curvature,
                                                   aggregate=True)

            if self._use_double:
                batch_grad = batch_grad.double()
                curvature = curvature.double()

            curv_stable = curvature + self._epsilon
            if has_zeros(curv_stable):
                raise ValueError("Diagonal curvature + ε has zeros.")
            if has_negative(curv_stable):
                raise ValueError(
                    "Diagonal curvature + ε has negative entries.")

            batch_size = get_batch_size(global_step)

            return torch.einsum("j,nj->", 1 / curv_stable,
                                batch_size * batch_grad**2)
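
The einsum evaluates the diagonal TIC, i.e. Tr(V C⁻¹) with both the gradient second moment V and the curvature C replaced by their diagonals (again using that ``batch_grad`` holds the ¹/B-scaled per-sample gradients):

    \[
        \mathrm{TICDiag}
        = \sum_{j} \frac{B \sum_{i} \big(\tfrac{1}{B}\nabla_j \ell_i\big)^2}{c_j + \varepsilon}
        = \sum_{j} \frac{\tfrac{1}{B}\sum_{i} (\nabla_j \ell_i)^2}{c_j + \varepsilon}
    \]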
Example #19
    def _compute(self, global_step, params, batch_loss):
        """Compute the CABS rule. Return suggested batch size.

        Evaluates Equ. 22 of

        - Balles, L., Romero, J., & Hennig, P.,
          Coupling adaptive batch sizes with learning rates (2017).

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Batch size suggested by CABS.

        Raises:
            ValueError: If the optimizer differs from SGD with default arguments.
        """
        optimizer = get_optimizer(global_step)
        if not ComputeStep.is_sgd_default_kwargs(optimizer):
            raise ValueError("This criterion only supports zero-momentum SGD.")

        B = get_batch_size(global_step)
        lr = self.get_lr(optimizer)

        grad_squared = self._fetch_grad(params, aggregate=True) ** 2
        # compensate BackPACK's 1/B scaling
        sgs_compensated = (
            B ** 2
            * self._fetch_sum_grad_squared_via_batch_grad_transforms(
                params, aggregate=True
            )
        )

        return (
            lr * (sgs_compensated - B * grad_squared).sum() / (B * batch_loss)
        ).item()
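
In the notation used above, sgs_compensated_j − B g_j² = ∑_i (∇_j ℓ_i − g_j)², so the suggested batch size has the form of the CABS rule, learning rate times estimated gradient-variance trace over the loss:

    \[
        B_{\text{new}}
        = \frac{\eta \sum_{j} \tfrac{1}{B} \sum_{i} \big(\nabla_j \ell_i - g_j\big)^2}{L(\theta)}
        \approx \frac{\eta \, \mathrm{Tr}(\Sigma)}{L(\theta)}
    \]

where η is the learning rate, Σ the per-sample gradient covariance, and L(θ) the mini-batch loss.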
Example #20
File: alpha.py Project: f-dangel/cockpit
    def _save_1st_order_info(self, global_step, params, batch_loss, point,
                             until):
        """Store information for projecting 1ˢᵗ-order info about the objective in cache.

        This is the go-to approach if the update step at the start point +projections
        cannot be computed automatically through BackPACK.

        Modifies ``self._cache``, creating the fields ``params_*``, ``grad_*``, and
        ``grad_batch_*``. Parameters are required to compute the update step, the
        gradients are required to project them onto the step at a later stage.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
            point (str): Description of point, ``'start'`` or ``'end'``.
            until (int): Iteration number until which deletion from cache is blocked.
        """
        block_fn = self._make_block_fn(global_step, until)

        params_dict = {id(p): p.data.clone().detach() for p in params}
        self.save_to_cache(global_step, f"params_{point}", params_dict,
                           block_fn)

        grad_dict = {id(p): p.grad.data.clone().detach() for p in params}
        self.save_to_cache(global_step, f"grad_{point}", grad_dict, block_fn)

        # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale
        batch_size = get_batch_size(global_step)
        grad_batch_dict = {
            id(p): batch_size * p.grad_batch.data.clone().detach()
            for p in params
        }
        self.save_to_cache(global_step, f"grad_batch_{point}", grad_batch_dict,
                           block_fn)
Example #21
    def _compute(self, global_step, params, batch_loss):
        """Compute the criterion.

        Evaluates the left hand side of Equ. 7 in

        - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
          Early stopping without a validation set (2017).

        If this value exceeds 0, training should be stopped.
        """
        B = get_batch_size(global_step)

        grad_squared = self._fetch_grad(params, aggregate=True)**2

        # compensate BackPACK's 1/B scaling
        sgs_compensated = (
            B**2 * self._fetch_sum_grad_squared_via_batch_grad_transforms(
                params, aggregate=True))

        diag_variance = (sgs_compensated - B * grad_squared) / (B - 1)

        snr = grad_squared / (diag_variance + self._epsilon)

        return 1 - B * snr.mean()