def matrices_from_kronecker_curvature(self, extension_cls, savefield):
    """Backpropagate the loss under a BackPACK extension and densify its output.

    Args:
        extension_cls: Extension class. It is instantiated without arguments
            and activated with ``backpack(...)`` for the backward pass.
        savefield: Name of the attribute on each parameter where the
            extension stores its Kronecker factors.

    Returns:
        list: One matrix per parameter of ``self.model`` (in ``parameters()``
        order), each built from that parameter's factors by ``kfacs_to_mat``.
    """
    with backpack(extension_cls()):
        self.loss().backward()
        # Read the stored factors while the extension context is still active,
        # exactly as the backward pass left them.
        matrices = [
            kfacs_to_mat(getattr(param, savefield))
            for param in self.model.parameters()
        ]
    return matrices
Example #2
0
    def test_kfacs_to_mat(self):
        """Compare ``bp_utils.kfacs_to_mat`` with the scipy-based reference.

        For ``self.RUNS`` repetitions, factors from ``self.make_random_kfacs()``
        are converted to a matrix by both implementations, and the results
        must agree according to ``self.allclose``.
        """
        for _run in range(self.RUNS):
            kfacs = self.make_random_kfacs()
            # Arguments are evaluated left to right: BackPACK first, then scipy.
            assert self.allclose(
                bp_utils.kfacs_to_mat(kfacs),
                self.scipy_kfacs_to_mat(kfacs),
            )
Example #3
0
def test_kfac_should_approx_ggn_montecarlo(problem: ExtensionsTestProblem):
    """Check that the Monte-Carlo K-FAC matches the GGN for batch_size = 1.

    The K-FAC is estimated from ``mc_samples`` Monte-Carlo samples, so the
    agreement is only approximate. It should hold for linear layers in the
    limit of infinitely many samples; the tolerances below absorb the
    remaining sampling error.

    Args:
        problem: Test case. It is set up before the comparison and always
            torn down afterwards, even if the comparison fails.
    """
    problem.set_up()
    # Teardown runs in ``finally`` so that a failed assertion, or an exception
    # while computing either side, does not leave the problem set up for
    # subsequent tests.
    try:
        autograd_res = AutogradExtensions(problem).ggn_blocks()

        mc_samples = 300000
        backpack_kfac = BackpackExtensions(problem).kfac_chunk(mc_samples)
        backpack_res = [kfacs_to_mat(kfac) for kfac in backpack_kfac]

        check_sizes_and_values(autograd_res, backpack_res, atol=5e-3, rtol=5e-3)
    finally:
        problem.tear_down()
 def hbp_with_curv(
     self,
     curv_type,
     loss_hessian_strategy=LossHessianStrategy.SUM,
     backprop_strategy=BackpropStrategy.BATCH_AVERAGE,
     ea_strategy=ExpectationApproximation.BOTEV_MARTENS,
 ):
     """Backpropagate the loss with the HBP extension and densify its output.

     All arguments are forwarded unchanged as keyword arguments to
     ``new_ext.HBP``; see that class for their meaning.

     Returns:
         list: One matrix per parameter of ``self.model`` (in ``parameters()``
         order), built by ``kfacs_to_mat`` from the factors the extension
         stores on ``param.hbp``.
     """
     hbp_extension = new_ext.HBP(
         curv_type=curv_type,
         loss_hessian_strategy=loss_hessian_strategy,
         backprop_strategy=backprop_strategy,
         ea_strategy=ea_strategy,
     )
     with backpack(hbp_extension):
         self.loss().backward()
         curvatures = [
             kfacs_to_mat(param.hbp) for param in self.model.parameters()
         ]
     return curvatures
Example #5
0
 def set_up():
     """Build Kronecker factors, their dense matrix, and a tensor of that matrix.

     Returns:
         tuple: ``(factors, kfac, tensor)`` where ``factors`` comes from
         ``self.make_random_kfacs()``, ``kfac`` is
         ``bp_utils.kfacs_to_mat(factors)`` and ``tensor`` is
         ``make_tensor(kfac)``.
     """
     # NOTE(review): ``self`` is not a parameter of this function, so it is
     # presumably captured from an enclosing method's scope. Confirm where
     # ``set_up`` is defined before moving it.
     random_factors = self.make_random_kfacs()
     dense = bp_utils.kfacs_to_mat(random_factors)
     return random_factors, dense, make_tensor(dense)