Exemplo n.º 1
0
def brute_force_hessian(layer_idx, parallel_idx, which):
    """Compute the Hessian of the loss w.r.t. one parameter of a layer.

    A model is built with ``example_sequence_parallel()`` and evaluated on
    ``random_input()`` through ``forward``; the Hessian of the resulting loss
    is then computed with ``exact_hessian``.

    Args:
        layer_idx: Position of the layer among ``parallel.children()``.
            Values greater than 2 are rejected.
        parallel_idx: Index passed to ``_get_parallel_module`` of that layer.
        which: Name of the parameter, either ``"weight"`` or ``"bias"``.

    Returns:
        The result of ``exact_hessian`` for the loss and the selected
        parameter.

    Raises:
        ValueError: If ``layer_idx`` exceeds 2 or ``which`` is neither
            ``"weight"`` nor ``"bias"``.
    """
    if layer_idx > 2:
        raise ValueError("Works only for indices from 0 to 2")
    # Reject an unsupported parameter name before any model is built, and
    # tell the caller what was wrong instead of raising a bare ValueError.
    if which not in ("weight", "bias"):
        raise ValueError(
            "Unknown parameter {!r}: expected 'weight' or 'bias'".format(which)
        )
    parallel = example_sequence_parallel()
    _, loss = forward(parallel, random_input())
    layer = list(parallel.children())[layer_idx]
    module = layer._get_parallel_module(parallel_idx)
    # Both supported names are plain attributes of the selected module.
    return exact_hessian(loss, [getattr(module, which)])
Exemplo n.º 2
0
def brute_force_parameter_hessian(which):
    """Compute the Hessians of the loss w.r.t. one parameter of each child.

    A layer is built with ``example_linear_parallel()`` and evaluated on
    ``random_input()`` through ``forward``. For every module returned by
    ``layer.parallel_children()`` the Hessian of the loss with respect to
    the selected parameter is computed with ``exact_hessian``.

    Args:
        which: Name of the parameter, either ``"weight"`` or ``"bias"``.

    Returns:
        A list with one ``exact_hessian`` result per parallel child, in the
        order given by ``layer.parallel_children()``.

    Raises:
        ValueError: If ``which`` is neither ``"weight"`` nor ``"bias"``.
    """
    # An unsupported name used to fall through and return None silently;
    # fail loudly instead so a typo cannot masquerade as a result.
    if which not in ("weight", "bias"):
        raise ValueError(
            "Unknown parameter {!r}: expected 'weight' or 'bias'".format(which)
        )
    layer = example_linear_parallel()
    _, loss = forward(layer, random_input())
    return [
        exact_hessian(loss, [getattr(child, which)])
        for child in layer.parallel_children()
    ]
Exemplo n.º 3
0
 def _torch_input_hvp(self):
     """Return the ``matmul`` of the torch layer's input Hessian.

     The layer from ``self._create_torch_layer()`` is called with both
     values returned by ``self._create_input()``. The Hessian of its output
     with respect to the first of them is computed with ``exact_hessian``,
     detached and moved to ``self.DEVICE``; its bound ``matmul`` is returned.
     """
     torch_layer = self._create_torch_layer()
     first_arg, second_arg = self._create_input()
     # Only the first argument is differentiated.
     first_arg.requires_grad = True
     output = torch_layer(first_arg, second_arg)
     input_hessian = exact_hessian(output, [first_arg])
     input_hessian = input_hessian.detach().to(self.DEVICE)
     return input_hessian.matmul
Exemplo n.º 4
0
 def _torch_parameter_hvp(self):
     """Yield the loss Hessian for each parameter of the torch layer.

     The layer from ``self._create_torch_layer()`` is applied to the input
     from ``self._create_input()`` and the output is passed to
     ``self._loss_fn``. For every entry of ``layer.parameters()`` the result
     of ``exact_hessian`` is detached, moved to ``self.DEVICE`` and yielded.

     NOTE(review): the yielded items are the Hessian objects themselves, not
     their ``matmul`` -- confirm that this is what the consumers expect.
     """
     torch_layer = self._create_torch_layer()
     net_input = self._create_input()
     net_input.requires_grad = True
     loss = self._loss_fn(torch_layer(net_input))
     for param in torch_layer.parameters():
         param_hessian = exact_hessian(loss, [param])
         yield param_hessian.detach().to(self.DEVICE)
Exemplo n.º 5
0
 def _cvp_after_hessian_backward(self):
     """Return the CVP layer after a Hessian backward pass.

     The layer from ``self._create_cvp_layer()`` is applied to the input
     from ``self._create_input()`` and ``self._loss_fn`` is evaluated on the
     output. After ``loss.backward()``, ``layer.backward_hessian`` is called
     with the ``matmul`` of the loss Hessian w.r.t. the layer output.
     """
     cvp_layer = self._create_cvp_layer()
     net_input = self._create_input()
     net_input.requires_grad = True
     output = cvp_layer(net_input)
     loss = self._loss_fn(output)
     # The steps below are required for nonlinear layers, which need to
     # save backprop quantities.
     output_hessian = exact_hessian(loss, [output])
     output_hessian = output_hessian.detach().to(self.DEVICE)
     loss_hessian_vp = output_hessian.matmul
     loss.backward()
     cvp_layer.backward_hessian(loss_hessian_vp)
     return cvp_layer
Exemplo n.º 6
0
 def _cvp_input_hvp(self):
     """Return the input Hessian routine produced by the CVP layer.

     The layer from ``self._create_cvp_layer()`` is applied to the input
     from ``self._create_input()`` and ``self._loss_fn`` is evaluated on the
     output. After ``loss.backward()``, the value returned by
     ``layer.backward_hessian(..., compute_input_hessian=True)`` is handed
     back to the caller.
     """
     cvp_layer = self._create_cvp_layer()
     net_input = self._create_input()
     net_input.requires_grad = True
     output = cvp_layer(net_input)
     loss = self._loss_fn(output)
     # The steps below are required for nonlinear layers, which need to
     # save backprop quantities.
     output_hessian = exact_hessian(loss, [output])
     output_hessian = output_hessian.detach().to(self.DEVICE)
     loss_hessian_vp = output_hessian.matmul
     loss.backward()
     return cvp_layer.backward_hessian(
         loss_hessian_vp, compute_input_hessian=True
     )
Exemplo n.º 7
0
def brute_force_input_hessian():
    """Return the Hessian of the loss w.r.t. the input, by brute force.

    The model comes from ``example_sequence_parallel()`` and is evaluated on
    ``random_input()`` via ``forward``; ``exact_hessian`` does the rest.
    """
    model = example_sequence_parallel()
    net_input = random_input()
    _, loss = forward(model, net_input)
    input_hessian = exact_hessian(loss, [net_input])
    return input_hessian
Exemplo n.º 8
0
def layer_input_hessian():
    """Return the Hessian of the loss with respect to the layer input.

    Input and loss are taken from the second and fourth values returned by
    ``layer_with_input_output_and_loss()``; the other two are not used.
    """
    _, x, _, loss = layer_with_input_output_and_loss()
    return exact_hessian(loss, [x])
Exemplo n.º 9
0
def brute_force_input_hessian():
    """Return the Hessian of the loss with respect to the input.

    The model comes from ``example_linear_parallel()`` and is evaluated on
    ``random_input()`` via ``forward``; ``exact_hessian`` does the rest.
    """
    model = example_linear_parallel()
    net_input = random_input()
    _, loss = forward(model, net_input)
    input_hessian = exact_hessian(loss, [net_input])
    return input_hessian