def hvp_applied_columnwise(self, f, p, mat):
    """Apply the Hessian of ``f`` w.r.t. ``p`` to every row of ``mat``.

    Row ``i`` of the result is the (flattened) Hessian-vector product of
    ``f`` w.r.t. ``p`` with ``mat[i, :]``.

    Args:
        f: Tensor whose Hessian w.r.t. ``p`` is applied (presumably a scalar
            loss — confirm with callers).
        p: Tensor w.r.t. which ``f`` is differentiated.
        mat: Matrix whose rows each hold ``p.numel()`` entries.

    Returns:
        Tensor of shape ``[mat.size(0), p.numel()]``.
    """
    h_rows = []
    for row in mat:
        # grad_outputs must match p's shape; the flat row only does so when
        # p is 1-D, so reshape it (reshape also copes with non-contiguous mat).
        hvp_row = hessian_vector_product(f, [p], row.reshape(p.shape))[0]
        h_rows.append(hvp_row.reshape(1, -1))
    return torch.cat(h_rows, dim=0)
def autograd_hessian_columns(loss, params, concat=False):
    """Return an iterator of the Hessian columns computed via ``torch.autograd``.

    The ``d``-th yielded item is the Hessian-vector product with the ``d``-th
    unit vector of the flattened, concatenated parameter space.

    Args:
        loss (torch.Tensor): Loss whose Hessian is investigated.
        params ([torch.Tensor]): Tensors holding the network's parameters. Any
            iterable is accepted (e.g. ``model.parameters()``); it is
            materialized into a list once.
        concat (bool): If ``True``, flatten and concatenate the columns over all
            parameters.

    Yields:
        Tensor or [Tensor]: The Hessian column. It is one flat tensor if
        ``concat`` is ``True``. Otherwise it is the per-parameter result of
        ``hessian_vector_product``.
    """
    # ``params`` is traversed once to count entries and again for every column.
    # A one-shot iterator (e.g. ``model.parameters()``) would be exhausted after
    # the first pass, so materialize it up front.
    params = list(params)
    D = sum(p.numel() for p in params)
    device = loss.device

    for d in range(D):
        # d-th unit vector, converted to parameter-shaped pieces
        e_d = torch.zeros(D, device=device)
        e_d[d] = 1.0
        e_d_list = vector_to_parameter_list(e_d, params)

        hessian_e_d = hessian_vector_product(loss, params, e_d_list)

        if concat:
            hessian_e_d = torch.cat([tensor.flatten() for tensor in hessian_e_d])

        yield hessian_e_d
def autograd_hessian(loss, x):
    """Return the Hessian matrix of `loss` w.r.t. `x`.

    Arguments:
        loss (torch.Tensor): A scalar-valued tensor.
        x (torch.Tensor): Tensor used in the computation graph of `loss`.

    Shapes:
        loss: `[1,]`
        x: `[A, B, C, ...]`

    Returns:
        torch.Tensor: Hessian tensor of `loss` w.r.t. `x`. The Hessian has shape
            `[A, B, C, ..., A, B, C, ...]`. It lives on `x`'s device and has
            `x`'s dtype.
    """
    assert loss.numel() == 1

    vectorized_shape = (x.numel(), x.numel())
    final_shape = (*x.shape, *x.shape)

    # Allocate on x's device and in x's dtype. A default CPU/float32 buffer
    # makes the unit vectors (grad_outputs) mismatch a CUDA x, and it silently
    # truncates float64 Hessian columns written into it.
    factory_kwargs = {"device": x.device, "dtype": x.dtype}
    hessian_vec_x = torch.zeros(vectorized_shape, **factory_kwargs)

    num_cols = hessian_vec_x.shape[1]
    for column_idx in range(num_cols):
        # H e_j gives the j-th column of the (flattened) Hessian
        unit = torch.zeros(num_cols, **factory_kwargs)
        unit[column_idx] = 1.0
        unit = unit.view_as(x)
        column = hessian_vector_product(loss, [x], [unit])[0].reshape(-1)
        hessian_vec_x[:, column_idx] = column

    return hessian_vec_x.reshape(final_shape)
def hvp_applied_columnwise(self, f, p, mat):
    """Multiply the Hessian of ``f`` w.r.t. ``p`` onto each column of ``mat``.

    Each column ``mat[:, j]`` is viewed in ``p``'s shape before the product.

    Returns:
        Tensor whose ``j``-th column is the flattened Hessian-vector product
        with ``mat[:, j]``.
    """
    num_cols = mat.size(1)
    columns = [
        hessian_vector_product(f, [p], mat[:, j].view_as(p))[0].view(-1, 1)
        for j in range(num_cols)
    ]
    return torch.cat(columns, dim=1)
def _hessian(self, loss: Tensor, x: Tensor) -> Tensor:
    """Return the Hessian matrix of a scalar `loss` w.r.t. a tensor `x`.

    Args:
        loss: A scalar-valued tensor.
        x: Tensor used in the computation graph of `loss`.

    Shapes:
        loss: `[1,]`
        x: `[A, B, C, ...]`

    Returns:
        Hessian tensor of `loss` w.r.t. `x`. The Hessian has shape
            `[A, B, C, ..., A, B, C, ...]`. It has `x`'s dtype.
    """
    assert loss.numel() == 1

    vectorized_shape = (x.numel(), x.numel())
    final_shape = (*x.shape, *x.shape)

    # Allocate directly on the target device, which avoids a CPU allocation
    # plus a copy for every column. Use x's dtype so float64 Hessian columns
    # are not truncated to the default float32 when written into the buffer.
    hessian_vec_x = zeros(vectorized_shape, device=loss.device, dtype=x.dtype)

    num_cols = hessian_vec_x.shape[1]
    for column_idx in range(num_cols):
        # H e_j gives the j-th column of the (flattened) Hessian
        unit = zeros(num_cols, device=loss.device, dtype=x.dtype)
        unit[column_idx] = 1.0
        unit = unit.view_as(x)
        column = hessian_vector_product(loss, [x], [unit])[0].reshape(-1)
        hessian_vec_x[:, column_idx] = column

    return hessian_vec_x.reshape(final_shape)
def hutchinson_trace_autodiff_blockwise(V):
    """Hessian trace estimate using autodiff block HVPs.

    For each of ``V`` samples, every parameter of the global ``model`` gets its
    own probe vector from ``rademacher``. The quadratic form ``v^T H_pp v`` is
    computed with the block of the Hessian of the global ``loss`` belonging to
    that parameter alone. The average over samples is summed over parameters.
    """
    estimate = 0
    for _sample in range(V):
        for param in model.parameters():
            probe = rademacher(param.shape)
            # block HVP: Hessian restricted to this single parameter
            hvp = hessian_vector_product(loss, [param], [probe])[0]
            quad_form = torch.einsum("i,i->", probe.flatten(), hvp.flatten())
            estimate += quad_form / V
    return estimate
def hutchinson_trace_autodiff(V):
    """Hessian trace estimate using autodiff HVPs.

    Each of the ``V`` samples draws one probe per parameter of the global
    ``model`` via ``rademacher``. It then takes a single HVP of the global
    ``loss`` w.r.t. all parameters and accumulates ``v^T H v / V``.
    """
    estimate = 0
    for _sample in range(V):
        probes = [rademacher(param.shape) for param in model.parameters()]
        hvps = hessian_vector_product(loss, list(model.parameters()), probes)
        for probe, hvp in zip(probes, hvps):
            quad_form = torch.einsum("i,i->", probe.flatten(), hvp.flatten())
            estimate += quad_form / V
    return estimate
def ggn_vector_product_from_plist(
    loss: Tensor, output: Tensor, plist: List[Parameter], v: List[Tensor]
) -> Tuple[Tensor]:
    """Multiply a vector with a sub-block of the generalized Gauss-Newton/Fisher.

    The product is evaluated as ``J^T H J v``. Here ``J`` is the Jacobian of
    ``output`` w.r.t. ``plist`` and ``H`` is the Hessian of ``loss`` w.r.t.
    ``output``.

    Args:
        loss: Scalar tensor that represents the loss.
        output: Model output.
        plist: List of trainable parameters whose GGN block is used for
            multiplication.
        v: Vector specified as list of tensors matching the sizes of ``plist``.

    Returns:
        GGN-vector product in list format, i.e. as list that matches the sizes
        of ``plist``.
    """
    # J v: map the parameter-space vector into output space
    jac_v = R_op(output, plist, v)
    # H (J v): apply the loss Hessian w.r.t. the model output
    hess_jac_v = hessian_vector_product(loss, output, jac_v)
    # J^T (H J v): map back into parameter space
    return L_op(output, plist, hess_jac_v)
def hessian_vector_product(self, v_torch):
    """Multiply by the Hessian using autodiff in torch.

    The call is delegated to the module-level ``hessian_vector_product`` with
    the stored loss, parameters and (optionally precomputed) gradients.
    """
    loss, params, grads = self.loss, self.params, self.grad_params
    return hessian_vector_product(loss, params, v_torch, grad_params=grads)
def hvp(v):
    """Hessian-vector product of ``f`` w.r.t. ``params`` for a flat vector ``v``.

    ``v`` is split into parameter-shaped pieces and the HVP is taken. The
    per-parameter result is then flattened and concatenated again.
    """
    return flatten_and_concatenate(
        hessian_vector_product(f, params, vector_to_parameter_list(v, params))
    )
# Load a batch of data (``load_mnist_data`` and ``B`` are defined elsewhere).
X, y = load_mnist_data(B)

print("# Hessian-vector product and gradients with PyTorch | B =", B)

# Flatten the inputs, then one linear layer mapping 784 features to 10 outputs.
model = Sequential(
    Flatten(),
    Linear(784, 10),
)
lossfunc = CrossEntropyLoss()

print("# 1) Vector with shapes like parameters | B =", B)

loss = lossfunc(model(X), y)
# One random vector per parameter, each shaped like that parameter.
v = [torch.randn_like(p) for p in model.parameters()]

Hv = hessian_vector_product(loss, list(model.parameters()), v)

# ``loss.backward()`` must be called after the HVP, or with
# ``create_graph=True``. By default backward frees the autograd graph that
# the HVP still needs.
loss.backward()

for (name, param), vec, Hvec in zip(model.named_parameters(), v, Hv):
    print(name)
    print(".grad.shape: ", param.grad.shape)
    # vector
    print("vector shape: ", vec.shape)
    # Hessian-vector product
    print("Hessian-vector product shape: ", Hvec.shape)

print("# 2) Flattened vector | B =", B)

loss = lossfunc(model(X), y)
def hessian_mat_prod(self, mat: Tensor) -> Tensor:
    """Multiply the Hessian of the problem's output w.r.t. its input onto ``mat``.

    A forward pass is run with gradients enabled on the input. Then one HVP
    is taken per entry along the leading dimension of ``mat``, and the results
    are stacked along a new leading dimension.
    """
    inputs, outputs, _ = self.problem.forward_pass(input_requires_grad=True)
    products = []
    for vec in mat:
        products.append(hessian_vector_product(outputs, [inputs], [vec])[0])
    return stack(products)
print("# 1) Hessian matrix with automatic differentiation | B =", B)

loss = lossfunc(model(X), y)

num_params = sum(p.numel() for p in model.parameters())
hessian = torch.zeros(num_params, num_params)

start = time.time()
for i in range(num_params):
    # Hessian-vector product with the i-th unit vector yields the i-th column.
    # That column equals the i-th row, since the Hessian is symmetric.
    e_i = torch.zeros(num_params)
    e_i[i] = 1.0

    # convert to model parameter shapes
    e_i_list = vector_to_parameter_list(e_i, model.parameters())

    hessian_i_list = hessian_vector_product(loss, list(model.parameters()), e_i_list)
    # flatten the per-parameter result back into a single vector
    hessian_i = parameters_to_vector(hessian_i_list)

    hessian[i, :] = hessian_i
end = time.time()

print("Model parameters: ", num_params)
print("Hessian shape: ", hessian.shape)
print("Hessian: ", hessian)
print("Time [s]: ", end - start)

print("# 2) Hessian matrix with automatic differentiation (faster) | B =", B)
print("# Save one backpropagation for each HVP by recycling gradients")

loss = lossfunc(model(X), y)
# Keep the graph of the gradients (create_graph=True) so they can be
# differentiated again. Presumably the following HVPs reuse these gradients;
# the continuation of this section is outside the visible source.
loss.backward(create_graph=True)