def _compute(self, global_step, params, batch_loss): """Evaluate the MeanGSNR. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Returns: float: Mean GSNR of the current iteration. """ losses = get_individual_losses(global_step) individual_gradients_flat = autograd_individual_gradients(losses, params, concat=True) grad_squared = torch.cat([p.grad.flatten() for p in params])**2 N_axis = 0 second_moment_flat = (individual_gradients_flat**2).mean(N_axis) gsnr = grad_squared / (second_moment_flat - grad_squared + self._epsilon) return gsnr.mean().item()
def _compute(self, global_step, params, batch_loss): """Evaluate the norm test. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Returns: foat: Result of the norm test. """ losses = get_individual_losses(global_step) individual_gradients_flat = autograd_individual_gradients(losses, params, concat=True) D_axis = 1 individual_l2_norms_squared = ( individual_gradients_flat**2).sum(D_axis) grad = torch.cat([p.grad.flatten() for p in params]) grad_norm = grad.norm() projections = torch.einsum("ni,i->n", individual_gradients_flat, grad) batch_size = get_batch_size(global_step) return ((1 / (batch_size * (batch_size - 1)) * (individual_l2_norms_squared / grad_norm**2 - (projections**2) / grad_norm**4).sum()).sqrt().item())
def _compute(self, global_step, params, batch_loss): """Evaluate the CABS criterion. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Returns: float: Evaluated CABS criterion. Raises: ValueError: If the optimizer differs from SGD with default arguments. """ optimizer = get_optimizer(global_step) if not ComputeStep.is_sgd_default_kwargs(optimizer): raise ValueError("This criterion only supports zero-momentum SGD.") losses = get_individual_losses(global_step) batch_axis = 0 trace_variance = autograd_diagonal_variance( losses, params, concat=True, unbiased=False).sum(batch_axis) lr = self.get_lr(optimizer) return (lr * trace_variance / batch_loss).item()
def _compute(self, global_step, params, batch_loss): """Evaluate the norm test. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Returns: float: Result of the norm test. """ losses = get_individual_losses(global_step) individual_gradients_flat = autograd_individual_gradients( losses, params, concat=True ) sum_of_squares = (individual_gradients_flat ** 2).sum() grad_norm = torch.cat([p.grad.flatten() for p in params]).norm() batch_size = get_batch_size(global_step) return ( ( 1 / (batch_size * (batch_size - 1)) * (sum_of_squares / grad_norm ** 2 - batch_size) ) .sqrt() .item() )
def _compute(self, global_step, params, batch_loss): """Evaluate the TIC approximating the Hessian by its diagonal. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Raises: NotImplementedError: If curvature is not ``diag_h``. Returns: float: TIC when approximation the Hessian with its diagonal. """ losses = get_individual_losses(global_step) individual_gradients_flat = autograd_individual_gradients(losses, params, concat=True) if self._curvature == "diag_h": diag_curvature_flat = autograd_diag_hessian(batch_loss, params, concat=True) else: raise NotImplementedError("Only Hessian diagonal is implemented") N_axis = 0 second_moment_flat = (individual_gradients_flat**2).mean(N_axis) return (second_moment_flat / (diag_curvature_flat + self._epsilon)).sum().item()
def _compute(self, global_step, params, batch_loss): """Evaluate the TIC approximation proposed in thomas2020interplay. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Raises: NotImplementedError: If curvature is not ``diag_h``. Returns: float: TIC using a trace approximation. """ losses = get_individual_losses(global_step) individual_gradients_flat = autograd_individual_gradients(losses, params, concat=True) if self._curvature == "diag_h": curvature_trace = autograd_diag_hessian(batch_loss, params, concat=True).sum() else: raise NotImplementedError("Only Hessian trace is implemented") N_axis = 0 mean_squared_l2_norm = ( individual_gradients_flat**2).mean(N_axis).sum() return (mean_squared_l2_norm / (curvature_trace + self._epsilon)).item()
def _fetch_values(self, params, batch_loss, pos, global_step): """Fetch values for quadratic fit. Return as dictionary. The entry "search_dir" is only initialized if ``pos`` is ``"start"``. Args: params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. pos (str): Whether we are at the start or end of an iteration. One of ``start`` or ``end``. global_step (int): The current iteration number. Raises: ValueError: If pos is not one of ``start`` or ``end``. Returns: dict: Holding the parameters, (variance of) loss and slope. """ info = {} if pos in ["start", "end"]: # 0ᵗʰ order info info["f"] = batch_loss.item() losses = get_individual_losses(global_step) info["var_f"] = losses.var().item() # temporary information required to compute quantities used in fit info["params"] = {id(p): p.data.clone().detach() for p in params} info["grad"] = {id(p): p.grad.data.clone().detach() for p in params} batch_grad = autograd_individual_gradients(losses, params) info["batch_grad"] = { id(p): bg.data.clone().detach() for p, bg in zip(params, batch_grad) } else: raise ValueError(f"Invalid position '{pos}'. Expect {self._positions}.") # compute all quantities used in fit if pos == "end": start_params, _ = self._get_info("params", end=False) end_params = info["params"] search_dir = [ end_params[key] - start_params[key] for key in start_params.keys() ] for info_dict in [self._start_info, info]: grad = [info_dict["grad"][key] for key in start_params.keys()] batch_grad = [ info_dict["batch_grad"][key] for key in start_params.keys() ] # 1ˢᵗ order info info_dict["df"] = _projected_gradient(grad, search_dir) info_dict["var_df"] = _exact_variance(batch_grad, search_dir) return info
def _fetch_values(self, params, batch_loss, pos, global_step): """Fetch values for quadratic fit. Return as dictionary. The entry "search_dir" is only initialized if ``pos`` is ``"start"``. """ info = {} if pos in ["start", "end"]: # 0ᵗʰ order info info["f"] = batch_loss.item() info["var_f"] = get_individual_losses(global_step).var().item() # temporary information required to compute quantities used in fit info["params"] = {id(p): p.data.clone().detach() for p in params} info["grad"] = { id(p): p.grad.data.clone().detach() for p in params } # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's computes ¹/ₙ ∇ℓᵢ, we have to rescale batch_size = get_batch_size(global_step) info["batch_grad"] = { id(p): batch_size * p.grad_batch.data.clone().detach() for p in params } else: raise ValueError( f"Invalid position '{pos}'. Expect {self._positions}.") # compute all quantities used in fit # TODO Restructure base class and move to other function if pos == "end": start_params, _ = self._get_info("params", end=False) end_params = info["params"] search_dir = [ end_params[key] - start_params[key] for key in start_params.keys() ] for info_dict in [self._start_info, info]: grad = [info_dict["grad"][key] for key in start_params.keys()] batch_grad = [ info_dict["batch_grad"][key] for key in start_params.keys() ] # 1ˢᵗ order info info_dict["df"] = _projected_gradient(grad, search_dir) info_dict["var_df"] = _exact_variance(batch_grad, search_dir) return info
def _fetch_values(self, params, batch_loss, pos, global_step): """Fetch values for quadratic fit. Return as dictionary. The entry "search_dir" is only initialized if ``pos`` is ``"start"``. """ info = {} info["params"] = {id(p): p.data.clone().detach() for p in params} # 0ᵗʰ order info info["f"] = batch_loss.item() info["var_f"] = get_individual_losses(global_step).var().item() # 1ˢᵗ order info info["df"], info["var_df"] = self._fetch_df_and_var_df(params, pos) return info
def _get_abs_max(self, global_step, params, batch_loss, range):
    """Compute the maximum absolute value of individual gradient elements.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.
        range (float, float): Current bin limits.

    Returns:
        float: Maximum absolute value of individual gradients.
    """
    individual_losses = get_individual_losses(global_step)
    individual_gradients = autograd_individual_gradients(
        individual_losses, params, concat=True
    )

    return individual_gradients.abs().max().item()

def _compute(self, global_step, params, batch_loss): """Evaluate the individual gradient histogram. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Returns: dict: Entry ``'hist'`` holds the histogram, entry ``'edges'`` holds the bin limits. """ individual_losses = get_individual_losses(global_step) individual_gradients = autograd_individual_gradients(individual_losses, params, concat=True) hist, edges = self._compute_histogram(individual_gradients) return {"hist": hist.float(), "edges": edges}
def _compute(self, global_step, params, batch_loss): """Evaluate the early stopping criterion. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Returns: float: Early stopping criterion. """ grad_squared = torch.cat([p.grad.flatten() for p in params])**2 losses = get_individual_losses(global_step) diag_variance = autograd_diagonal_variance(losses, params, concat=True) B = get_batch_size(global_step) return 1 - B * (grad_squared / (diag_variance + self._epsilon)).mean().item()
def _save_0th_order_info(self, global_step, params, batch_loss, point, until):
    """Store 0ᵗʰ-order information about the objective in cache.

    Modifies ``self._cache``, creating entries ``f_*`` and ``var_f_*``.

    Args:
        global_step (int): The current iteration number.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        batch_loss (torch.Tensor): Mini-batch loss from current step.
        point (str): Description of point, ``'start'`` or ``'end'``.
        until (int): Iteration number until which deletion from cache is blocked.
    """
    block_fn = self._make_block_fn(global_step, until)

    f = batch_loss.item()
    self.save_to_cache(global_step, f"f_{point}", f, block_fn)

    var_f = get_individual_losses(global_step).var().item()
    self.save_to_cache(global_step, f"var_f_{point}", var_f, block_fn)

def _compute(self, global_step, params, batch_loss): """Aggregate histogram data over parameters and save to output.""" individual_losses = get_individual_losses(global_step) individual_gradients = autograd_individual_gradients( individual_losses, params) layerwise = [ self._compute_histogram(p, igrad) for p, igrad in zip(params, individual_gradients) ] hist = sum(out[0] for out in layerwise) edges = layerwise[0][1] result = {"hist": hist, "edges": edges} if self._keep_individual: result["param_groups"] = len(params) for idx, (hist, edges) in enumerate(layerwise): result[f"param_{idx}"] = {"hist": hist, "edges": edges} return result