def hbp_with_curv(
    self,
    curv_type,
    loss_hessian_strategy=LossHessianStrategy.AVERAGE,
    backprop_strategy=BackpropStrategy.BATCH_AVERAGE,
    ea_strategy=ExpectationApproximation.BOTEV_MARTENS,
):
    """Backpropagate the loss under an ``HBP`` extension and collect matrices.

    The four arguments are forwarded unchanged, by keyword, to
    ``new_ext.HBP``. With that extension active, ``self.loss().backward()``
    is run; afterwards the ``hbp`` attribute of every parameter in
    ``self.model.parameters()`` is passed through ``matrix_from_kron_facs``.

    Returns:
        A list with one ``matrix_from_kron_facs`` result per model
        parameter, in the order yielded by ``self.model.parameters()``.
    """
    extension = new_ext.HBP(
        curv_type=curv_type,
        loss_hessian_strategy=loss_hessian_strategy,
        backprop_strategy=backprop_strategy,
        ea_strategy=ea_strategy,
    )

    with backpack(extension):
        self.loss().backward()
        # NOTE(review): ``p.hbp`` presumably holds Kronecker factors written
        # by the extension during the backward pass -- confirm against HBP.
        matrices = [
            matrix_from_kron_facs(param.hbp)
            for param in self.model.parameters()
        ]

    return matrices
def test_interface_hbp_conv():
    """Run ``interface_test`` on a default ``HBP`` extension with ``use_conv=True``."""
    hbp_extension = new_ext.HBP()
    interface_test(hbp_extension, use_conv=True)
def test_interface_hbp():
    """Run ``interface_test`` on a default-constructed ``HBP`` extension."""
    hbp_extension = new_ext.HBP()
    interface_test(hbp_extension)