def autograd_diag_hessian(loss, params, concat=False):
    """Compute the Hessian diagonal via ``torch.autograd``.

    Args:
        loss (torch.Tensor): Loss whose Hessian is investigated.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        concat (bool): If ``True``, flatten and concatenate the diagonal over
            all parameters.

    Returns:
        Tensor: Hessian diagonal.
    """
    num_params = sum(p.numel() for p in params)
    diag = torch.zeros(num_params, device=loss.device)

    # The d-th diagonal entry equals the d-th entry of the d-th Hessian column.
    columns = autograd_hessian_columns(loss, params, concat=True)
    for col_idx, column in enumerate(columns):
        diag[col_idx] = column[col_idx]

    # Keep the exact ``is False`` check of the original contract: only an
    # explicit False triggers conversion back to parameter-list format.
    return vector_to_parameter_list(diag, params) if concat is False else diag
def autograd_hessian_columns(loss, params, concat=False):
    """Return an iterator of the Hessian columns computed via ``torch.autograd``.

    Args:
        loss (torch.Tensor): Loss whose Hessian is investigated.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        concat (bool): If ``True``, flatten and concatenate the columns over
            all parameters.

    Yields:
        Tensor: Tensor of Hessian columns.
    """
    num_params = sum(p.numel() for p in params)
    dev = loss.device

    for col_idx in range(num_params):
        # Hessian-vector product with the col_idx-th unit vector extracts
        # the col_idx-th column.
        unit = torch.zeros(num_params, device=dev)
        unit[col_idx] = 1.0
        unit_blocks = vector_to_parameter_list(unit, params)

        column = hessian_vector_product(loss, params, unit_blocks)
        if concat:
            column = torch.cat([block.flatten() for block in column])

        yield column
def extract_ith_element_of_diag_ggn(i, p, loss, output):
    """Return the ``i``-th diagonal entry of the GGN block of parameter ``p``.

    NOTE(review): reads ``self.problem.device`` as a free variable — this
    function appears to be a closure defined inside a method; confirm the
    enclosing scope provides ``self``.
    """
    # Unit vector in the flattened parameter space of ``p``.
    unit = torch.zeros(p.numel()).to(self.problem.device)
    unit[i] = 1.0
    unit_list = vector_to_parameter_list(unit, [p])

    ggn_blocks = ggn_vector_product_from_plist(loss, output, [p], unit_list)
    flat = torch.cat([block.detach().view(-1) for block in ggn_blocks])
    return flat[i]
def _autograd_ggn_exact_columns(
    X: Tensor, y: Tensor, model: Module, loss_function: Module, idx: List[int] = None
) -> Iterator[Tuple[int, Tensor]]:
    """Yield exact generalized Gauss-Newton's columns computed with ``torch.autograd``.

    Args:
        X: Input to the model.
        y: Labels.
        model: The neural network.
        loss_function: Loss function module.
        idx: Indices of columns that are computed. Default value ``None``
            computes all columns.

    Yields:
        Tuple of column index and respective GGN column (flattened and
        concatenated).
    """
    params = [p for p in model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in params)

    outputs = model(X)
    loss = loss_function(outputs, y)

    # Default: every column of the (num_params x num_params) GGN.
    columns = list(range(num_params)) if idx is None else idx

    for col in columns:
        # GGN-vector product with the col-th unit vector yields column col.
        unit = zeros(num_params, device=loss.device, dtype=loss.dtype)
        unit[col] = 1.0
        unit_list = vector_to_parameter_list(unit, params)

        ggn_col_list = ggn_vector_product(loss, outputs, model, unit_list)
        yield col, parameters_to_vector(ggn_col_list)
def extract_ith_element_of_diag_h(i, p, df_dx):
    """Return the ``i``-th Hessian diagonal entry of parameter ``p`` via HVPs.

    NOTE(review): reads ``self.problem.device`` as a free variable — this
    function appears to be a closure defined inside a method; confirm the
    enclosing scope provides ``self``.
    """
    # Unit vector in the flattened parameter space of ``p``.
    unit = torch.zeros(p.numel()).to(self.problem.device)
    unit[i] = 1.0
    unit_list = vector_to_parameter_list(unit, [p])

    hess_blocks = hvp(df_dx, [p], unit_list)
    flat = torch.cat([block.detach().view(-1) for block in hess_blocks])
    return flat[i]
def extract_ith_element_of_diag_h(i, p, df_dx):
    """Return the ``i``-th Hessian diagonal entry of ``p`` via an R-operator.

    Multiplies the Hessian onto the ``i``-th unit vector and reads off
    entry ``i`` of the resulting column.
    """
    # Unit vector with the same size (and device/dtype) as ``p``, flattened.
    unit = zeros_like(p).flatten()
    unit[i] = 1.0
    unit_list = vector_to_parameter_list(unit, [p])

    hess_blocks = R_op(df_dx, [p], unit_list)
    flat = cat([block.flatten() for block in hess_blocks])
    return flat[i]
def _ggn_columns(self, loss: Tensor, output: Tensor) -> Iterator[Tensor]:
    """Yield the flattened columns of the exact GGN one at a time."""
    params = list(self.problem.trainable_parameters())
    num_params = sum(p.numel() for p in params)
    model = self.problem.model

    for col in range(num_params):
        # GGN-vector product with the col-th unit vector yields the col-th row
        unit = zeros(num_params).to(self.problem.device)
        unit[col] = 1.0

        # convert to model parameter shapes
        unit_list = vector_to_parameter_list(unit, params)
        ggn_col_list = ggn_vector_product(loss, output, model, unit_list)

        yield parameters_to_vector(ggn_col_list)
def _preprocess(self, v_numpy):
    """Convert to `torch`, block into parameters."""
    as_torch = torch.from_numpy(v_numpy).to(self.device)
    return vector_to_parameter_list(as_torch, self.params)
# Script fragment: demonstrates GGN-vector products, first in parameter-list
# format (``v``/``GGNv`` defined earlier, outside this excerpt), then with a
# flattened vector. Assumes ``model``, ``X``, ``y``, ``lossfunc`` and ``B``
# are defined above — not visible here.
loss.backward()

# Per-parameter shapes: gradient, input vector, and GGN-vector product agree.
for (name, param), vec, GGNvec in zip(model.named_parameters(), v, GGNv):
    print(name)
    print(".grad.shape: ", param.grad.shape)
    # vector
    print("vector shape: ", vec.shape)
    # Hessian-vector product
    print("GGN-vector product shape: ", GGNvec.shape)

print("# 2) Flattened vector | B =", B)

# Fresh forward pass so a new autograd graph is available.
output = model(X)
loss = lossfunc(output, y)

num_params = sum(p.numel() for p in model.parameters())

# Random flat vector, blocked into parameter shapes for the product,
# then flattened back after.
v_flat = torch.randn(num_params)
v = vector_to_parameter_list(v_flat, model.parameters())
GGNv = ggn_vector_product(loss, output, model, v)
GGNv_flat = parameters_to_vector(GGNv)

# has to be called afterwards, or with create_graph=True
loss.backward()

print("Model parameters: ", num_params)
# vector
print("flat vector shape: ", v_flat.shape)
# flattened GGN-vector product
print("flat GGN-vector product shape: ", GGNv_flat.shape)
def _get_diag_ggn(self, loss: Tensor, output: Tensor) -> List[Tensor]:
    """Assemble the exact GGN diagonal in parameter-list format.

    Entry ``i`` of the diagonal is entry ``i`` of GGN column ``i``;
    ``col[[i]]`` keeps each entry as a 1d tensor so they concatenate.
    """
    entries = [
        col[[i]] for i, col in enumerate(self._ggn_columns(loss, output))
    ]
    flat_diag = cat(entries)

    params = list(self.problem.trainable_parameters())
    return vector_to_parameter_list(flat_diag, params)
# Script fragment: builds the full (num_params x num_params) GGN matrix of a
# small linear classifier, one matrix-vector product per unit vector.
# Assumes ``X``, ``y`` and ``B`` are defined above — not visible here.
print("# GGN matrix with automatic differentiation | B =", B)
model = Sequential(
    Flatten(),
    Linear(784, 10),
)
lossfunc = CrossEntropyLoss()
output = model(X)
loss = lossfunc(output, y)
num_params = sum(p.numel() for p in model.parameters())
ggn = torch.zeros(num_params, num_params)
for i in range(num_params):
    # GGN-vector product with i.th unit vector yields the i.th row
    # (the GGN is symmetric, so row i equals column i).
    e_i = torch.zeros(num_params)
    e_i[i] = 1.0
    # convert to model parameter shapes
    e_i_list = vector_to_parameter_list(e_i, model.parameters())
    ggn_i_list = ggn_vector_product(loss, output, model, e_i_list)
    ggn_i = parameters_to_vector(ggn_i_list)
    ggn[i, :] = ggn_i
print("Model parameters: ", num_params)
print("GGN shape: ", ggn.shape)
print("GGN: ", ggn)