def compute(self, global_step, params, batch_loss):
    """Evaluate the 1d histogram of individual gradient elements.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.
    """
    if self.is_active(global_step):
        edges = self._get_current_bin_edges()
        hist = sum(p.grad_batch_transforms["hist_1d"] for p in params)

        if self._check:
            batch_size = get_batch_size(global_step)
            num_params = sum(p.numel() for p in params)
            num_counts = hist.sum()
            assert batch_size * num_params == num_counts

        self.output[global_step]["hist_1d"] = hist.cpu().numpy().tolist()
        self.output[global_step]["edges"] = edges.cpu().numpy().tolist()

        if self._verbose:
            print(
                f"[Step {global_step}] BatchGradHistogram1d"
                + f" edges 0,...,4: {edges[:5]}"
            )
            print(
                f"[Step {global_step}] BatchGradHistogram1d"
                + f" counts 0,...,4: {hist[:5]}"
            )

    self._update_limits(global_step, params, batch_loss)
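# A minimal standalone sketch (not this class's implementation) of the count
# bookkeeping above: bin all entries of a batch of individual gradients into a
# fixed-range 1d histogram and verify that the total count equals
# batch_size * num_params. ``torch.histc`` silently drops values outside
# [min, max], so out-of-range entries must be clipped first for the
# assertion to hold.
import torch

batch_size, num_params = 8, 5
batch_grad = torch.randn(batch_size, num_params)  # stands in for per-sample gradients

xmin, xmax, bins = -2.0, 2.0, 10
clipped = batch_grad.clamp(xmin, xmax)
hist = torch.histc(clipped, bins=bins, min=xmin, max=xmax)
edges = torch.linspace(xmin, xmax, steps=bins + 1)

assert int(hist.sum()) == batch_size * num_params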
def _compute_gsnr_from_batch_grad(params):
    """Gradient signal-to-noise ratio.

    Implements Equation (25) in liu2020understanding, defined recursively
    via the prose between Equations (1) and (2).
    """
    batch_grad = self._fetch_batch_grad(params, aggregate=True)

    if self._use_double:
        batch_grad = batch_grad.double()

    # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale
    batch_size = get_batch_size(global_step)
    rescaled_batch_grad = batch_size * batch_grad

    grad_first_moment_squared = rescaled_batch_grad.mean(0) ** 2
    grad_second_moment = (rescaled_batch_grad ** 2).mean(0)
    grad_variance = grad_second_moment - grad_first_moment_squared

    if has_negative(grad_variance + self._epsilon):
        raise ValueError("Gradient variances from batch_grad are negative.")

    if has_zeros(grad_variance + self._epsilon):
        raise ValueError("Gradient variances + ε has zeros.")

    return grad_first_moment_squared / (grad_variance + self._epsilon)
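# Sketch under assumptions (synthetic tensors, no BackPACK): the GSNR above is
# the element-wise ratio mean(gᵢ)² / var(gᵢ) over individual gradients
# gᵢ = ∇ℓᵢ. With already-rescaled per-sample gradients it reduces to the two
# moments computed in the function; the biased (1/N) variance used there
# matches ``torch.var`` with ``unbiased=False``.
import torch

batch_size, num_params = 16, 4
individual_grads = torch.randn(batch_size, num_params)  # stands in for ∇ℓᵢ

first_moment_squared = individual_grads.mean(0) ** 2
second_moment = (individual_grads ** 2).mean(0)
variance = second_moment - first_moment_squared

gsnr = first_moment_squared / variance

assert torch.allclose(variance, individual_grads.var(0, unbiased=False))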
def _compute_individual(self, global_step, params, batch_loss):
    """Save histogram for each parameter to output."""
    for idx, p in enumerate(params):
        x_edges, y_edges = self._get_current_bin_edges()
        hist = p.grad_batch_transforms["hist_2d"]

        if self._check:
            batch_size = get_batch_size(global_step)
            num_params = p.numel()
            num_counts = hist.sum()
            assert batch_size * num_params == num_counts

        self.output[global_step][f"param_{idx}_hist_2d"] = (
            hist.cpu().numpy().tolist()
        )
        self.output[global_step][f"param_{idx}_x_edges"] = (
            x_edges.cpu().numpy().tolist()
        )
        self.output[global_step][f"param_{idx}_y_edges"] = (
            y_edges.cpu().numpy().tolist()
        )

        if self._verbose:
            print(
                f"[Step {global_step}] BatchGradHistogram2d param_{idx}"
                + f" x_edges 0,...,4: {x_edges[:5]}"
            )
            print(
                f"[Step {global_step}] BatchGradHistogram2d param_{idx}"
                + f" y_edges 0,...,4: {y_edges[:5]}"
            )
            print(
                f"[Step {global_step}] BatchGradHistogram2d param_{idx}"
                + f" counts [0,...,4][0,...,4]: {hist[:5, :5]}"
            )

    self.output[global_step]["param_groups"] = len(params)
def _compute(self, global_step, params, batch_loss):
    """Evaluate the norm test.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.

    Returns:
        float: Result of the norm test.
    """
    losses = get_individual_losses(global_step)
    individual_gradients_flat = autograd_individual_gradients(
        losses, params, concat=True
    )
    sum_of_squares = (individual_gradients_flat ** 2).sum()
    grad_norm = torch.cat([p.grad.flatten() for p in params]).norm()

    batch_size = get_batch_size(global_step)

    return (
        (
            1
            / (batch_size * (batch_size - 1))
            * (sum_of_squares / grad_norm ** 2 - batch_size)
        )
        .sqrt()
        .item()
    )
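# Hedged numeric sketch of the statistic above (synthetic per-sample gradients
# instead of ``autograd_individual_gradients``): since
# Σᵢ ||gᵢ||² - B·||ḡ||² = Σᵢ ||gᵢ - ḡ||², the norm test equals
# sqrt( Σᵢ ||gᵢ - ḡ||² / (B (B-1) ||ḡ||²) ), a relative measure of how far
# individual gradients scatter around the mean gradient.
import math

import torch

B, D = 32, 10
individual_grads = torch.randn(B, D, dtype=torch.float64)
mean_grad = individual_grads.mean(0)  # what p.grad holds for a mean-reduced loss
grad_norm = mean_grad.norm()

sum_of_squares = (individual_grads ** 2).sum()
norm_test = math.sqrt(
    ((sum_of_squares / grad_norm ** 2 - B) / (B * (B - 1))).item()
)

scatter = ((individual_grads - mean_grad) ** 2).sum()
norm_test_alt = math.sqrt((scatter / (B * (B - 1)) / grad_norm ** 2).item())

assert math.isclose(norm_test, norm_test_alt, rel_tol=1e-9)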
def _compute(self, global_step, params, batch_loss):
    """Evaluate the orthogonality test.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.

    Returns:
        float: Result of the orthogonality test.
    """
    losses = get_individual_losses(global_step)
    individual_gradients_flat = autograd_individual_gradients(
        losses, params, concat=True
    )
    D_axis = 1
    individual_l2_norms_squared = (individual_gradients_flat ** 2).sum(D_axis)

    grad = torch.cat([p.grad.flatten() for p in params])
    grad_norm = grad.norm()

    projections = torch.einsum("ni,i->n", individual_gradients_flat, grad)

    batch_size = get_batch_size(global_step)

    return (
        (
            1
            / (batch_size * (batch_size - 1))
            * (
                individual_l2_norms_squared / grad_norm ** 2
                - projections ** 2 / grad_norm ** 4
            ).sum()
        )
        .sqrt()
        .item()
    )
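# Hedged sketch (synthetic tensors): the summand above,
# ||gᵢ||²/||ḡ||² - (gᵢᵀḡ)²/||ḡ||⁴, is the squared norm of the component of gᵢ
# orthogonal to ḡ, divided by ||ḡ||² — hence "orthogonality test".
import torch

B, D = 32, 10
g = torch.randn(B, D, dtype=torch.float64)  # stands in for individual gradients
g_mean = g.mean(0)
g_norm2 = (g_mean ** 2).sum()

proj_coeff = g @ g_mean / g_norm2        # (gᵢᵀḡ)/||ḡ||², shape [B]
orth = g - proj_coeff[:, None] * g_mean  # component of gᵢ orthogonal to ḡ

lhs = (g ** 2).sum(1) / g_norm2 - (g @ g_mean) ** 2 / g_norm2 ** 2
rhs = (orth ** 2).sum(1) / g_norm2

assert torch.allclose(lhs, rhs)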
def hook(grad_batch):
    """Project the end point gradients onto the start point's update direction.

    Modifies ``self._cache``, creating entry ``update_dot_grad_batch_end``.

    Args:
        grad_batch (torch.Tensor): Result of BackPACK's ``BatchGrad``
            extension.

    Returns:
        dict: Empty dictionary.
    """
    param_id = id(grad_batch._param_weakref())
    search_dir = self.load_from_cache(start_step, "update_start")[param_id]

    # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale
    batch_size = get_batch_size(end_step)
    dot_products = batch_size * self.batched_dot_product(grad_batch, search_dir)

    # update or create cache entry for ``dᵀgᵢ``
    key = "update_dot_grad_batch_end"
    try:
        update_dot_dict = self.load_from_cache(end_step, key)
    except KeyError:
        update_dot_dict = {}
    finally:
        update_dot_dict[param_id] = dot_products
        self.save_to_cache(end_step, key, update_dot_dict, block_fn)

    return {}
def _fetch_values(self, params, batch_loss, pos, global_step):
    """Fetch values for quadratic fit. Return as dictionary.

    The entry "search_dir" is only initialized if ``pos`` is ``"start"``.
    """
    info = {}

    if pos in ["start", "end"]:
        # 0ᵗʰ order info
        info["f"] = batch_loss.item()
        info["var_f"] = get_individual_losses(global_step).var().item()

        # temporary information required to compute quantities used in fit
        info["params"] = {id(p): p.data.clone().detach() for p in params}
        info["grad"] = {id(p): p.grad.data.clone().detach() for p in params}

        # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale
        batch_size = get_batch_size(global_step)
        info["batch_grad"] = {
            id(p): batch_size * p.grad_batch.data.clone().detach() for p in params
        }
    else:
        raise ValueError(f"Invalid position '{pos}'. Expect {self._positions}.")

    # compute all quantities used in fit
    # TODO Restructure base class and move to other function
    if pos == "end":
        start_params, _ = self._get_info("params", end=False)
        end_params = info["params"]

        search_dir = [
            end_params[key] - start_params[key] for key in start_params.keys()
        ]

        for info_dict in [self._start_info, info]:
            grad = [info_dict["grad"][key] for key in start_params.keys()]
            batch_grad = [
                info_dict["batch_grad"][key] for key in start_params.keys()
            ]

            # 1ˢᵗ order info
            info_dict["df"] = _projected_gradient(grad, search_dir)
            info_dict["var_df"] = _exact_variance(batch_grad, search_dir)

    return info
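# The two helpers used above are sketched here under assumptions about their
# semantics (names suffixed ``_sketch``; they are not reproduced from the
# library): ``_projected_gradient`` is taken to project the gradient onto the
# normalized search direction, and ``_exact_variance`` to be the unbiased
# sample variance of the per-sample gradient projections onto that direction.
import torch


def _projected_gradient_sketch(grad, search_dir):
    """dᵀg / ||d|| for parameter-wise lists ``grad`` and ``search_dir``."""
    dot = sum((g * d).sum() for g, d in zip(grad, search_dir))
    norm = sum((d ** 2).sum() for d in search_dir).sqrt()
    return (dot / norm).item()


def _exact_variance_sketch(batch_grad, search_dir):
    """Unbiased sample variance of per-sample projections dᵀgᵢ / ||d||."""
    norm = sum((d ** 2).sum() for d in search_dir).sqrt()
    projections = sum(
        bg.flatten(start_dim=1) @ d.flatten()
        for bg, d in zip(batch_grad, search_dir)
    ) / norm
    return projections.var(unbiased=True).item()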
def _compute_gsnr(self, global_step, params, batch_loss):
    """Compute gradient signal-to-noise ratio.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of parameters considered in the
            computation.
        batch_loss (torch.Tensor): Mini-batch loss from current step.

    Returns:
        torch.Tensor: Element-wise GSNR of the current iteration.
    """
    grad_squared = self._fetch_grad(params, aggregate=True) ** 2
    sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
        params, aggregate=True
    )

    batch_size = get_batch_size(global_step)

    return grad_squared / (
        batch_size * sum_grad_squared - grad_squared + self._epsilon
    )
def _compute_tic_with_batch_grad(params):
    """TICTrace."""
    batch_grad = self._fetch_batch_grad(params, aggregate=True)
    curvature = self._fetch_diag_curvature(params, self._curvature, aggregate=True)

    if self._use_double:
        batch_grad = batch_grad.double()
        curvature = curvature.double()

    curv_trace_stable = curvature.sum() + self._epsilon

    if has_zeros(curv_trace_stable):
        raise ValueError("Curvature trace + ε has zeros.")
    if has_negative(curv_trace_stable):
        raise ValueError("Curvature trace + ε has negative entries.")

    batch_size = get_batch_size(global_step)

    return batch_size * (batch_grad ** 2).sum() / curv_trace_stable
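# Standalone numeric sketch of the two TIC flavors (synthetic gradients and
# curvature, no BackPACK): TICTrace divides by the summed curvature once,
# while TICDiag (the variant further below) weights each squared-gradient
# column by its inverse curvature entry. The einsum used there is just a
# weighted double sum.
import torch

B, D = 16, 6
batch_grad = torch.randn(B, D, dtype=torch.float64)   # stands in for p.grad_batch
curvature = torch.rand(D, dtype=torch.float64) + 0.1  # positive diagonal curvature

# TICTrace: one global trace in the denominator
tic_trace = B * (batch_grad ** 2).sum() / curvature.sum()

# TICDiag: per-coordinate curvature in the denominator
tic_diag = torch.einsum("j,nj->", 1 / curvature, B * batch_grad ** 2)
tic_diag_explicit = (B * batch_grad ** 2 / curvature).sum()

assert torch.allclose(tic_diag, tic_diag_explicit)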
def _compute_gsnr(self, global_step, params, batch_loss):
    """Compute gradient signal-to-noise ratio.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of parameters considered in the
            computation.
        batch_loss (torch.Tensor): Mini-batch loss from current step.

    Returns:
        torch.Tensor: Element-wise GSNR of the current iteration.
    """
    if self._use_double:
        grad_squared = self._fetch_grad(params, aggregate=True).double() ** 2
        sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True
        ).double()
    else:
        grad_squared = self._fetch_grad(params, aggregate=True) ** 2
        sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True
        )

    batch_size = get_batch_size(global_step)

    return grad_squared / (
        batch_size * sum_grad_squared - grad_squared + self._epsilon
    )
def _compute(self, global_step, params, batch_loss):
    """Compute the TICTrace using a trace approximation.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.

    Returns:
        float: TIC computed using a trace approximation.
    """
    sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
        params, aggregate=True
    )
    curvature = self._fetch_diag_curvature(params, self._curvature, aggregate=True)

    batch_size = get_batch_size(global_step)

    return (
        batch_size * sum_grad_squared.sum() / (curvature.sum() + self._epsilon)
    ).item()
def _compute(self, global_step, params, batch_loss):
    """Evaluate the early stopping criterion.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.

    Returns:
        float: Early stopping criterion.
    """
    grad_squared = torch.cat([p.grad.flatten() for p in params]) ** 2

    losses = get_individual_losses(global_step)
    diag_variance = autograd_diagonal_variance(losses, params, concat=True)

    B = get_batch_size(global_step)

    return 1 - B * (grad_squared / (diag_variance + self._epsilon)).mean().item()
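# A minimal sketch of how ``autograd_individual_gradients`` and the diagonal
# variance could be realized with plain autograd (the name and signature below
# are illustrative, not the library's): one backward pass per un-reduced loss,
# then an element-wise variance over the batch axis.
import torch


def individual_gradients_sketch(losses, params):
    """Return a [B, D] tensor of flattened per-sample gradients."""
    grads = []
    for loss in losses:
        g = torch.autograd.grad(loss, params, retain_graph=True)
        grads.append(torch.cat([gi.flatten() for gi in g]))
    return torch.stack(grads)


# tiny example: linear model with per-sample squared errors
w = torch.randn(3, requires_grad=True)
X, y = torch.randn(8, 3), torch.randn(8)
losses = (X @ w - y) ** 2  # individual losses, shape [8]

igrad = individual_gradients_sketch(losses, [w])
diag_variance = igrad.var(0, unbiased=True)  # element-wise gradient variance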
def _compute(self, global_step, params, batch_loss):
    """Compute the CABS rule. Return suggested batch size.

    Evaluates Equ. 22 of

    - Balles, L., Romero, J., & Hennig, P.,
      Coupling adaptive batch sizes with learning rates (2017).
    """
    B = get_batch_size(global_step)
    lr = self._lr

    grad_squared = self._fetch_grad(params, aggregate=True) ** 2

    # compensate BackPACK's 1/B scaling
    sgs_compensated = (
        B ** 2
        * self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True
        )
    )

    return lr * (sgs_compensated - B * grad_squared).sum() / (B * batch_loss.item())
def _compute(self, global_step, params, batch_loss):
    """Compute the EB early stopping criterion.

    Evaluates the left hand side of Equ. 7 in

    - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
      Early stopping without a validation set (2017).

    If this value exceeds 0, training should be stopped.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.

    Returns:
        float: Result of the early stopping criterion. Training should stop
            if it is larger than 0.

    Raises:
        ValueError: If the used optimizer differs from SGD with default
            parameters.
    """
    if not ComputeStep.is_sgd_default_kwargs(get_optimizer(global_step)):
        raise ValueError("This criterion only supports zero-momentum SGD.")

    B = get_batch_size(global_step)

    grad_squared = self._fetch_grad(params, aggregate=True) ** 2

    # compensate BackPACK's 1/B scaling
    sgs_compensated = (
        B ** 2
        * self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True
        )
    )

    diag_variance = (sgs_compensated - B * grad_squared) / (B - 1)

    snr = grad_squared / (diag_variance + self._epsilon)

    return 1 - B * snr.mean().item()
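# Numeric sketch of the 1/B compensation used above (synthetic data, no
# BackPACK): if ``sgs`` mimics BackPACK's SumGradSquared, i.e. Σₙ ((1/B)∇ℓₙ)²,
# then B²·sgs = Σₙ (∇ℓₙ)², and (B²·sgs - B·ḡ²)/(B-1) is exactly the unbiased
# element-wise sample variance of the individual gradients.
import torch

B, D = 16, 5
individual_grads = torch.randn(B, D, dtype=torch.float64)  # ∇ℓₙ

grad = individual_grads.mean(0)             # ḡ, what p.grad holds
sgs = ((individual_grads / B) ** 2).sum(0)  # BackPACK-style sum_grad_squared

sgs_compensated = B ** 2 * sgs
diag_variance = (sgs_compensated - B * grad ** 2) / (B - 1)

assert torch.allclose(diag_variance, individual_grads.var(0, unbiased=True))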
def _compute_projection_variance_from_batch_grad(params):
    """Compute variance of individual gradient projections on the gradient.

    The sample variance of projections is given by the equation in the line
    after Equation (2.6) in bollapragada2017adaptive
    (https://arxiv.org/pdf/1710.11258.pdf).
    """
    batch_grad = self._fetch_batch_grad(params, aggregate=True)
    batch_size = get_batch_size(global_step)
    grad = self._fetch_grad(params, aggregate=True)
    grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True)

    if self._use_double:
        batch_grad = batch_grad.double()
        grad = grad.double()
        grad_l2_squared = grad_l2_squared.double()

    projections = torch.einsum("ni,i->n", batch_size * batch_grad, grad)

    return (1 / (batch_size - 1)) * (
        (projections ** 2).sum() - batch_size * grad_l2_squared ** 2
    )
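# Numeric sketch (synthetic tensors): the returned quantity is the unbiased
# sample variance of the projections gᵢᵀḡ, because their sample mean is
# exactly ||ḡ||², so (Σ proj² - B·||ḡ||⁴)/(B-1) = var(proj).
import torch

B, D = 32, 7
individual_grads = torch.randn(B, D, dtype=torch.float64)  # B * batch_grad
grad = individual_grads.mean(0)
grad_l2_squared = (grad ** 2).sum()

projections = torch.einsum("ni,i->n", individual_grads, grad)
assert torch.allclose(projections.mean(), grad_l2_squared)

var_formula = (1 / (B - 1)) * (
    (projections ** 2).sum() - B * grad_l2_squared ** 2
)
assert torch.allclose(var_formula, projections.var(unbiased=True))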
def _compute(self, global_step, params, batch_loss):
    """Compute the TICTrace using a trace approximation.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of parameters considered in the
            computation.
        batch_loss (torch.Tensor): Mini-batch loss from current step.
    """
    sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
        params, aggregate=True
    )
    curvature = self._fetch_diag_curvature(params, self._curvature, aggregate=True)

    if self._use_double:
        sum_grad_squared = sum_grad_squared.double()
        curvature = curvature.double()

    batch_size = get_batch_size(global_step)

    return batch_size * sum_grad_squared.sum() / (curvature.sum() + self._epsilon)
def _compute_aggregated(self, global_step, params, batch_loss):
    """Aggregate histogram data over parameters and save to output."""
    x_edges, y_edges = self._get_current_bin_edges()
    hist = sum(p.grad_batch_transforms["hist_2d"] for p in params)

    if self._check:
        batch_size = get_batch_size(global_step)
        num_params = sum(p.numel() for p in params)
        num_counts = hist.sum()
        assert batch_size * num_params == num_counts

    self.output[global_step]["hist_2d"] = hist.cpu().numpy().tolist()
    self.output[global_step]["x_edges"] = x_edges.cpu().numpy().tolist()
    self.output[global_step]["y_edges"] = y_edges.cpu().numpy().tolist()

    if self._verbose:
        print(
            f"[Step {global_step}] BatchGradHistogram2d"
            + f" x_edges 0,...,4: {x_edges[:5]}"
        )
        print(
            f"[Step {global_step}] BatchGradHistogram2d"
            + f" y_edges 0,...,4: {y_edges[:5]}"
        )
        print(
            f"[Step {global_step}] BatchGradHistogram2d"
            + f" counts [0,...,4][0,...,4]: {hist[:5, :5]}"
        )
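# Standalone sketch of the 2d count bookkeeping (numpy stand-in; the choice of
# axes — per-sample gradient value vs. parameter value — is an assumption
# suggested by the x/y naming above): after clipping, each of the B·D
# (gradient, parameter) pairs lands in exactly one bin, so hist.sum() equals
# batch_size * num_params, matching the ``_check`` assertion.
import numpy as np

batch_size, num_params = 8, 5
batch_grad = np.random.randn(batch_size, num_params)
param = np.random.randn(num_params)

xmin, xmax, ymin, ymax = -2.0, 2.0, -2.0, 2.0
x = np.clip(batch_grad, xmin, xmax).flatten()
y = np.clip(np.tile(param, (batch_size, 1)), ymin, ymax).flatten()

hist, x_edges, y_edges = np.histogram2d(
    x, y, bins=(10, 8), range=((xmin, xmax), (ymin, ymax))
)
assert int(hist.sum()) == batch_size * num_params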
def _compute_tic_with_batch_grad(params):
    """TICDiag."""
    batch_grad = self._fetch_batch_grad(params, aggregate=True)
    curvature = self._fetch_diag_curvature(params, self._curvature, aggregate=True)

    if self._use_double:
        batch_grad = batch_grad.double()
        curvature = curvature.double()

    curv_stable = curvature + self._epsilon

    if has_zeros(curv_stable):
        raise ValueError("Diagonal curvature + ε has zeros.")
    if has_negative(curv_stable):
        raise ValueError("Diagonal curvature + ε has negative entries.")

    batch_size = get_batch_size(global_step)

    return torch.einsum("j,nj->", 1 / curv_stable, batch_size * batch_grad ** 2)
def _compute(self, global_step, params, batch_loss):
    """Compute the CABS rule. Return suggested batch size.

    Evaluates Equ. 22 of

    - Balles, L., Romero, J., & Hennig, P.,
      Coupling adaptive batch sizes with learning rates (2017).

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.

    Returns:
        float: Batch size suggested by CABS.

    Raises:
        ValueError: If the optimizer differs from SGD with default arguments.
    """
    optimizer = get_optimizer(global_step)
    if not ComputeStep.is_sgd_default_kwargs(optimizer):
        raise ValueError("This criterion only supports zero-momentum SGD.")

    B = get_batch_size(global_step)
    lr = self.get_lr(optimizer)

    grad_squared = self._fetch_grad(params, aggregate=True) ** 2

    # compensate BackPACK's 1/B scaling
    sgs_compensated = (
        B ** 2
        * self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True
        )
    )

    return (
        lr * (sgs_compensated - B * grad_squared).sum() / (B * batch_loss)
    ).item()
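# Worked numeric sketch of the CABS rule (synthetic per-sample gradients, no
# optimizer): Equ. 22 suggests b* ≈ lr · tr(Σ) / L. The compensated expression
# above equals lr · Σⱼ var_biased(gⱼ) / batch_loss, since element-wise
# Σₙ (∇ℓₙ)² - B·ḡ² = B · var_biased.
import torch

B, D, lr = 16, 5, 0.1
individual_grads = torch.randn(B, D, dtype=torch.float64)
batch_loss = torch.tensor(0.7, dtype=torch.float64)

grad_squared = individual_grads.mean(0) ** 2
sgs_compensated = (individual_grads ** 2).sum(0)  # = B² · BackPACK-style sgs

cabs = lr * (sgs_compensated - B * grad_squared).sum() / (B * batch_loss)

trace_of_variance = individual_grads.var(0, unbiased=False).sum()
cabs_alt = lr * trace_of_variance / batch_loss

assert torch.allclose(cabs, cabs_alt)
print(f"suggested batch size: {cabs.item():.1f}")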
def _save_1st_order_info(self, global_step, params, batch_loss, point, until):
    """Store information for projecting 1ˢᵗ-order info about the objective in cache.

    This is the go-to approach if the update step at the start point and the
    projections cannot be computed automatically through BackPACK.

    Modifies ``self._cache``, creating the fields ``params_*``, ``grad_*``,
    and ``grad_batch_*``. Parameters are required to compute the update step,
    the gradients are required to project them onto the step at a later stage.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.
        point (str): Description of point, ``'start'`` or ``'end'``.
        until (int): Iteration number until which deletion from cache is
            blocked.
    """
    block_fn = self._make_block_fn(global_step, until)

    params_dict = {id(p): p.data.clone().detach() for p in params}
    self.save_to_cache(global_step, f"params_{point}", params_dict, block_fn)

    grad_dict = {id(p): p.grad.data.clone().detach() for p in params}
    self.save_to_cache(global_step, f"grad_{point}", grad_dict, block_fn)

    # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale
    batch_size = get_batch_size(global_step)
    grad_batch_dict = {
        id(p): batch_size * p.grad_batch.data.clone().detach() for p in params
    }
    self.save_to_cache(
        global_step, f"grad_batch_{point}", grad_batch_dict, block_fn
    )
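# Runnable demonstration of the ¹/ₙ rescaling noted above, using BackPACK's
# public API (``extend`` + ``BatchGrad``; requires backpack-for-pytorch): for
# a mean-reduced loss, ``p.grad_batch[n]`` holds ¹/ₙ ∇ℓₙ, so rescaling by the
# batch size recovers the individual gradients, whose mean is ``p.grad``.
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

torch.manual_seed(0)
N = 4
model = extend(torch.nn.Linear(3, 2))
lossfunc = extend(torch.nn.MSELoss(reduction="mean"))

X, y = torch.randn(N, 3), torch.randn(N, 2)
loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

for p in model.parameters():
    individual = N * p.grad_batch  # rescaled per-sample gradients ∇ℓₙ
    assert torch.allclose(individual.mean(0), p.grad, atol=1e-6)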
def _compute(self, global_step, params, batch_loss):
    """Compute the criterion.

    Evaluates the left hand side of Equ. 7 in

    - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
      Early stopping without a validation set (2017).

    If this value exceeds 0, training should be stopped.
    """
    B = get_batch_size(global_step)

    grad_squared = self._fetch_grad(params, aggregate=True) ** 2

    # compensate BackPACK's 1/B scaling
    sgs_compensated = (
        B ** 2
        * self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True
        )
    )

    diag_variance = (sgs_compensated - B * grad_squared) / (B - 1)

    snr = grad_squared / (diag_variance + self._epsilon)

    return 1 - B * snr.mean()