def brute_force_hessian(layer_idx, parallel_idx, which):
    """Compute Hessian of loss w.r.t. one parameter of a parallel layer.

    Args:
        layer_idx: Position of the layer among the children of the module
            returned by ``example_sequence_parallel()``. Values above 2 are
            rejected.
        parallel_idx: Index handed to ``_get_parallel_module`` of the chosen
            layer to pick one of its parallel modules.
        which: Name of the parameter, either ``"weight"`` or ``"bias"``.

    Returns:
        The result of ``exact_hessian`` for the loss and the selected
        parameter.

    Raises:
        ValueError: If ``layer_idx`` is larger than 2, or if ``which`` is
            neither ``"weight"`` nor ``"bias"``.
    """
    if layer_idx > 2:
        raise ValueError("Works only for indices from 0 to 2")
    # Reject unknown parameter names before building the model and running
    # the forward pass, so that a bad argument fails fast.
    if which not in ("weight", "bias"):
        raise ValueError(
            "Unknown parameter {!r}: expected 'weight' or 'bias'".format(which)
        )

    parallel = example_sequence_parallel()
    _, loss = forward(parallel, random_input())

    layer = list(parallel.children())[layer_idx]
    module = layer._get_parallel_module(parallel_idx)
    # ``which`` was validated above, so this resolves to .weight or .bias.
    return exact_hessian(loss, [getattr(module, which)])
def brute_force_parameter_hessian(which):
    """Compute Hessians of the loss w.r.t. the weights or the biases.

    Args:
        which: Name of the parameter, either ``"weight"`` or ``"bias"``.

    Returns:
        A list with one ``exact_hessian`` result per entry of
        ``layer.parallel_children()``, taken w.r.t. that child's ``weight``
        or ``bias``.

    Raises:
        ValueError: If ``which`` is neither ``"weight"`` nor ``"bias"``.
    """
    # Fail loudly on an unknown parameter name instead of silently
    # returning None, and do so before any model is built.
    if which not in ("weight", "bias"):
        raise ValueError(
            "Unknown parameter {!r}: expected 'weight' or 'bias'".format(which)
        )

    layer = example_linear_parallel()
    _, loss = forward(layer, random_input())
    return [
        exact_hessian(loss, [getattr(child, which)])
        for child in layer.parallel_children()
    ]
def _torch_input_hvp(self):
    """Build a Hessian-vector product routine for the torch layer.

    Returns:
        The ``matmul`` attribute of the Hessian of the layer output w.r.t.
        the first input, after detaching it and moving it to
        ``self.DEVICE``.
    """
    torch_layer = self._create_torch_layer()
    first, second = self._create_input()
    # Only the first input is differentiated.
    first.requires_grad = True
    output = torch_layer(first, second)
    hessian = exact_hessian(output, [first])
    return hessian.detach().to(self.DEVICE).matmul
def _torch_parameter_hvp(self):
    """Yield the block-wise parameter Hessians of the torch layer.

    One item is produced per entry of ``layer.parameters()``: the result of
    ``exact_hessian`` for the loss and that parameter, detached and moved to
    ``self.DEVICE``.
    """
    torch_layer = self._create_torch_layer()
    inp = self._create_input()
    inp.requires_grad = True
    loss = self._loss_fn(torch_layer(inp))
    for param in torch_layer.parameters():
        block = exact_hessian(loss, [param])
        yield block.detach().to(self.DEVICE)
def _cvp_after_hessian_backward(self):
    """Return the CVP layer after backpropagating the loss Hessian.

    Returns:
        The layer from ``self._create_cvp_layer()`` once ``loss.backward()``
        and ``layer.backward_hessian`` have been run on it.
    """
    cvp_layer = self._create_cvp_layer()
    inp = self._create_input()
    inp.requires_grad = True
    output = cvp_layer(inp)
    loss = self._loss_fn(output)
    # required for nonlinear layers (need to save backprop quantities)
    output_hessian = exact_hessian(loss, [output])
    output_hessian = output_hessian.detach().to(self.DEVICE)
    # The Hessian above is computed before the backward pass is triggered.
    loss.backward()
    cvp_layer.backward_hessian(output_hessian.matmul)
    return cvp_layer
def _cvp_input_hvp(self):
    """Build the input Hessian-vector product routine of the CVP layer.

    Returns:
        Whatever ``layer.backward_hessian`` returns when called with the
        loss Hessian product and ``compute_input_hessian=True``.
    """
    cvp_layer = self._create_cvp_layer()
    inp = self._create_input()
    inp.requires_grad = True
    output = cvp_layer(inp)
    loss = self._loss_fn(output)
    # required for nonlinear layers (need to save backprop quantities)
    output_hessian = exact_hessian(loss, [output])
    output_hessian = output_hessian.detach().to(self.DEVICE)
    # The Hessian above is computed before the backward pass is triggered.
    loss.backward()
    return cvp_layer.backward_hessian(
        output_hessian.matmul, compute_input_hessian=True
    )
def brute_force_input_hessian():
    """Brute-force the Hessian of the loss with respect to the input.

    Returns:
        The result of ``exact_hessian`` for the loss obtained by forwarding
        a ``random_input()`` through ``example_sequence_parallel()``.
    """
    model = example_sequence_parallel()
    x = random_input()
    # forward returns a pair; only its second entry (the loss) is needed.
    loss = forward(model, x)[1]
    return exact_hessian(loss, [x])
def layer_input_hessian():
    """Return the Hessian of the loss with respect to the layer input.

    Returns:
        The result of ``exact_hessian`` for the loss and input provided by
        ``layer_with_input_output_and_loss()``.
    """
    # The helper yields four values; the layer and its output are not used.
    _, x, _, loss = layer_with_input_output_and_loss()
    return exact_hessian(loss, [x])
def brute_force_input_hessian():
    """Brute-force the Hessian of the loss w.r.t. the input.

    Returns:
        The result of ``exact_hessian`` for the loss obtained by forwarding
        a ``random_input()`` through ``example_linear_parallel()``.
    """
    model = example_linear_parallel()
    x = random_input()
    # forward returns a pair; only its second entry (the loss) is needed.
    loss = forward(model, x)[1]
    return exact_hessian(loss, [x])