# Code example 1
def autograd_diag_hessian(loss, params, concat=False):
    """Compute the Hessian diagonal via ``torch.autograd``.

    Args:
        loss (torch.Tensor): Loss whose Hessian is investigated.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        concat (bool): If ``True``, return the diagonal as one flattened
            vector; otherwise split it back into per-parameter tensors.

    Returns:
        torch.Tensor or [torch.Tensor]: Hessian diagonal — a flat vector if
        ``concat`` is ``True``, else a list of parameter-shaped tensors.
    """
    D = sum(p.numel() for p in params)
    hessian_diag = torch.zeros(D, device=loss.device)

    # Entry d of flattened column d is the d-th diagonal element.
    for d, column_d in enumerate(autograd_hessian_columns(loss, params, concat=True)):
        hessian_diag[d] = column_d[d]

    if not concat:
        hessian_diag = vector_to_parameter_list(hessian_diag, params)

    return hessian_diag
# Code example 2
def autograd_hessian_columns(loss, params, concat=False):
    """Yield the Hessian columns one at a time via ``torch.autograd``.

    Each column is obtained as a Hessian-vector product with a unit vector.

    Args:
        loss (torch.Tensor): Loss whose Hessian is investigated.
        params ([torch.Tensor]): List of torch.Tensors holding the network's
            parameters.
        concat (bool): If ``True``, flatten and concatenate the columns over all
            parameters.

    Yields:
        Tensor: Tensor of Hessian columns.
    """
    num_params = sum(p.numel() for p in params)
    dev = loss.device

    for col_idx in range(num_params):
        # Unit vector selecting column `col_idx`, in per-parameter shapes.
        unit = torch.zeros(num_params, device=dev)
        unit[col_idx] = 1.0
        unit_list = vector_to_parameter_list(unit, params)

        column = hessian_vector_product(loss, params, unit_list)
        if concat:
            column = torch.cat([block.flatten() for block in column])

        yield column
# Code example 3
 def extract_ith_element_of_diag_ggn(i, p, loss, output):
     """Return entry ``i`` of the GGN diagonal block for parameter ``p``.

     NOTE(review): reads ``self`` from the enclosing scope — this is a
     closure defined inside a method; the enclosing context is not visible
     in this excerpt.
     """
     # Unit vector e_i in the flattened space of p.
     basis = torch.zeros(p.numel()).to(self.problem.device)
     basis[i] = 1.0
     basis_blocks = vector_to_parameter_list(basis, [p])
     # The i-th flattened entry of GGN @ e_i is the i-th diagonal element.
     ggn_blocks = ggn_vector_product_from_plist(loss, output, [p], basis_blocks)
     flat = torch.cat([block.detach().view(-1) for block in ggn_blocks])
     return flat[i]
# Code example 4 — file: examples.py, project: f-dangel/backpack
def _autograd_ggn_exact_columns(
        X: Tensor,
        y: Tensor,
        model: Module,
        loss_function: Module,
        idx: List[int] = None) -> Iterator[Tuple[int, Tensor]]:
    """Yield exact generalized Gauss-Newton's columns computed with ``torch.autograd``.

    NOTE(review): the ``idx`` annotation should be ``Optional[List[int]]``;
    kept as-is because the file's import block is not visible here.

    Args:
        X: Input to the model.
        y: Labels.
        model: The neural network.
        loss_function: Loss function module.
        idx: Indices of columns that are computed. Default value ``None`` computes all
            columns.

    Yields:
        Tuple of column index and respective GGN column (flattened and concatenated).
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    num_cols = sum(p.numel() for p in trainable)

    outputs = model(X)
    loss = loss_function(outputs, y)

    if idx is None:
        idx = list(range(num_cols))

    for col in idx:
        # Unit vector for this column, converted to per-parameter shapes.
        unit = zeros(num_cols, device=loss.device, dtype=loss.dtype)
        unit[col] = 1.0
        unit_list = vector_to_parameter_list(unit, trainable)

        ggn_col_list = ggn_vector_product(loss, outputs, model, unit_list)

        yield col, parameters_to_vector(ggn_col_list)
# Code example 5
        def extract_ith_element_of_diag_h(i, p, df_dx):
            """Return entry ``i`` of the Hessian diagonal block for parameter ``p``."""
            # Unit vector e_i in the flattened space of p.
            unit = torch.zeros(p.numel()).to(self.problem.device)
            unit[i] = 1.0
            unit_blocks = vector_to_parameter_list(unit, [p])

            # Hessian-vector product; its i-th flat entry is the diagonal element.
            hessian_cols = hvp(df_dx, [p], unit_blocks)
            flat = torch.cat([t.detach().view(-1) for t in hessian_cols])
            return flat[i]
# Code example 6 — file: autograd.py, project: f-dangel/backpack
        def extract_ith_element_of_diag_h(i, p, df_dx):
            """Return entry ``i`` of the Hessian diagonal block for parameter ``p``."""
            # Unit vector e_i, shaped like p then flattened.
            unit = zeros_like(p).flatten()
            unit[i] = 1.0
            unit_blocks = vector_to_parameter_list(unit, [p])

            # R-operator gives the Hessian-vector product with e_i.
            hessian_cols = R_op(df_dx, [p], unit_blocks)
            return cat([t.flatten() for t in hessian_cols])[i]
# Code example 7 — file: autograd.py, project: f-dangel/backpack
    def _ggn_columns(self, loss: Tensor, output: Tensor) -> Iterator[Tensor]:
        """Yield the exact GGN's columns, one flattened vector at a time.

        Args:
            loss: Scalar loss with autograd history.
            output: Model output the GGN is evaluated at.

        Yields:
            Tensor: One flattened GGN column per trainable-parameter index.
        """
        params = list(self.problem.trainable_parameters())
        num_params = sum(p.numel() for p in params)
        model = self.problem.model

        for i in range(num_params):
            # GGN-vector product with i.th unit vector yields the i.th row
            # (= i.th column, since the GGN is symmetric).
            # Allocate directly on the target device instead of
            # zeros(...).to(...), which builds on CPU and then copies.
            e_i = zeros(num_params, device=self.problem.device)
            e_i[i] = 1.0

            # convert to model parameter shapes
            e_i_list = vector_to_parameter_list(e_i, params)
            ggn_i_list = ggn_vector_product(loss, output, model, e_i_list)

            yield parameters_to_vector(ggn_i_list)
# Code example 8
 def _preprocess(self, v_numpy):
     """Convert a NumPy vector to ``torch`` and split it into parameter blocks."""
     as_torch = torch.from_numpy(v_numpy).to(self.device)
     return vector_to_parameter_list(as_torch, self.params)
# Code example 9
# NOTE(review): this script excerpt relies on `model`, `v`, `GGNv`, `X`, `y`,
# `lossfunc`, and `B` being defined earlier — not visible in this excerpt.
loss.backward()

# 1) Per-parameter view: each GGN-vector product block has the same shape as
# the corresponding parameter (and its gradient).
for (name, param), vec, GGNvec in zip(model.named_parameters(), v, GGNv):
    print(name)
    print(".grad.shape:              ", param.grad.shape)
    # vector
    print("vector shape:             ", vec.shape)
    # Hessian-vector product
    print("GGN-vector product shape: ", GGNvec.shape)

print("# 2) Flattened vector | B =", B)

# Fresh forward pass so `loss` carries an autograd graph for the GGN product.
output = model(X)
loss = lossfunc(output, y)

num_params = sum(p.numel() for p in model.parameters())
v_flat = torch.randn(num_params)

# Split the flat random vector into parameter-shaped blocks, multiply by the
# GGN, then flatten the result back into one vector.
v = vector_to_parameter_list(v_flat, model.parameters())
GGNv = ggn_vector_product(loss, output, model, v)
GGNv_flat = parameters_to_vector(GGNv)

# has to be called afterwards, or with create_graph=True
loss.backward()

print("Model parameters:              ", num_params)
# flat random vector
print("flat vector shape:             ", v_flat.shape)
# flattened GGN-vector product
print("flat GGN-vector product shape: ", GGNv_flat.shape)
# Code example 10 — file: autograd.py, project: f-dangel/backpack
 def _get_diag_ggn(self, loss: Tensor, output: Tensor) -> List[Tensor]:
     """Extract the GGN diagonal from its columns, in parameter shapes.

     Args:
         loss: Scalar loss with autograd history.
         output: Model output the GGN is evaluated at.

     Returns:
         List of per-parameter tensors holding the GGN diagonal.
     """
     # Entry i of column i is the i-th diagonal element; `col[[i]]` keeps it
     # as a 1-element tensor so the pieces can be concatenated.
     entries = []
     for i, col in enumerate(self._ggn_columns(loss, output)):
         entries.append(col[[i]])
     diag_flat = cat(entries)
     params = list(self.problem.trainable_parameters())
     return vector_to_parameter_list(diag_flat, params)
# Code example 11
# NOTE(review): this script excerpt relies on `X`, `y`, and `B` being defined
# earlier — not visible in this excerpt.
print("# GGN matrix with automatic differentiation | B =", B)

# Small linear classifier on flattened 28x28 inputs, 10 classes.
model = Sequential(
    Flatten(),
    Linear(784, 10),
)
lossfunc = CrossEntropyLoss()

output = model(X)
loss = lossfunc(output, y)

# Build the full GGN matrix row by row via GGN-vector products with unit
# vectors.
num_params = sum(p.numel() for p in model.parameters())
ggn = torch.zeros(num_params, num_params)

for i in range(num_params):
    # GGN-vector product with i.th unit vector yields the i.th row
    # (= i.th column, since the GGN is symmetric)
    e_i = torch.zeros(num_params)
    e_i[i] = 1.0

    # convert to model parameter shapes
    e_i_list = vector_to_parameter_list(e_i, model.parameters())
    ggn_i_list = ggn_vector_product(loss, output, model, e_i_list)

    ggn_i = parameters_to_vector(ggn_i_list)
    ggn[i, :] = ggn_i

print("Model parameters: ", num_params)
print("GGN shape:        ", ggn.shape)
print("GGN:              ", ggn)