Beispiel #1
0
    def backpropagate(self, ext, module, g_inp, g_out, backproped):
        """Validate the loss Hessian for the requested curvature and evaluate it.

        First asks ``Curvature.check_loss_hessian`` whether the loss Hessian's
        PSD property (``self.derivatives.hessian_is_psd()``) is acceptable for
        the extension's curvature type. Then builds the Hessian function for
        ``ext`` via ``self.make_loss_hessian_func`` and applies it.

        Args:
            ext: Extension providing ``get_curv_type()``.
            module: The loss module being processed.
            g_inp: Gradients w.r.t. the module inputs.
            g_out: Gradients w.r.t. the module outputs.
            backproped: Not used here; kept for interface compatibility.

        Returns:
            Whatever the function returned by ``make_loss_hessian_func`` returns.
        """
        is_psd = self.derivatives.hessian_is_psd()
        Curvature.check_loss_hessian(is_psd, curv_type=ext.get_curv_type())

        return self.make_loss_hessian_func(ext)(module, g_inp, g_out)
Beispiel #2
0
    def backpropagate(self, ext, module, g_inp, g_out, backproped):
        """Validate the loss Hessian, then evaluate it with the chosen strategy.

        The extension's loss-Hessian strategy
        (``ext.get_loss_hessian_strategy()``) selects a function from
        ``self.LOSS_HESSIAN_GETTERS``. That function is applied to
        ``(module, g_inp, g_out)``.

        Args:
            ext: Extension providing ``get_curv_type()`` and
                ``get_loss_hessian_strategy()``.
            module: The loss module being processed.
            g_inp: Gradients w.r.t. the module inputs.
            g_out: Gradients w.r.t. the module outputs.
            backproped: Not used here; kept for interface compatibility.

        Returns:
            The result of the selected getter.

        Raises:
            KeyError: if the strategy has no entry in ``LOSS_HESSIAN_GETTERS``
                (plain dict-style lookup, assuming it is a mapping).
        """
        is_psd = self.derivatives.hessian_is_psd()
        Curvature.check_loss_hessian(is_psd, curv_type=ext.get_curv_type())

        strategy = ext.get_loss_hessian_strategy()
        getter = self.LOSS_HESSIAN_GETTERS[strategy]
        return getter(module, g_inp, g_out)
Beispiel #3
0
    def backpropagate_batch_average(self, ext, module, g_inp, g_out, H):
        """Backpropagate the batch-averaged curvature ``H`` through the module.

        The Jacobian-sandwich term comes from
        ``self.derivatives.ea_jac_t_mat_jac_prod``. The module's second-order
        effects are then adjusted by ``Curvature.modify_residual`` for the
        extension's curvature type. If that yields something other than
        ``None``, it is added to the diagonal via ``self.add_diag_to_mat``.

        Args:
            ext: Extension providing ``get_curv_type()``.
            module: The module being processed.
            g_inp: Gradients w.r.t. the module inputs.
            g_out: Gradients w.r.t. the module outputs.
            H: Curvature matrix coming from the layer above.

        Returns:
            The backpropagated curvature, with the residual diagonal added
            when one applies.
        """
        curv_mat = self.derivatives.ea_jac_t_mat_jac_prod(module, g_inp, g_out, H)

        effects = self.second_order_module_effects(module, g_inp, g_out)
        modified = Curvature.modify_residual(effects, ext.get_curv_type())

        # ``None`` means no residual contribution for this curvature type.
        if modified is None:
            return curv_mat
        return self.add_diag_to_mat(modified, curv_mat)
Beispiel #4
0
    def _second_order_module_effects(self, module, ext, g_inp, g_out):
        """Return the module's diagonal residual term, or ``None`` if unneeded.

        ``None`` is returned in two cases:

        - the module Hessian vanishes (``hessian_is_zero()``);
        - the curvature type does not require a residual
          (``Curvature.require_residual``).

        Otherwise the diagonal from ``self.derivatives.hessian_diagonal`` is
        returned.

        Raises:
            NotImplementedError: if a residual is required but the module
                Hessian is not diagonal.
        """
        # Short-circuit keeps the original order: the curvature type is only
        # queried when the Hessian does not vanish.
        if self.derivatives.hessian_is_zero() or not Curvature.require_residual(
            ext.get_curv_type()
        ):
            return None

        if self.derivatives.hessian_is_diagonal():
            return self.derivatives.hessian_diagonal(module, g_inp, g_out)

        raise NotImplementedError(
            "Residual terms are only supported for elementwise functions")
Beispiel #5
0
    def __make_nondiagonal_R_mat_prod(self, ext, module, g_inp, g_out, backproped):
        """Return the residual matrix-product function from ``self.derivatives``.

        Args:
            ext: Extension providing ``get_curv_type()``.
            module: The module being processed (also used in the error message).
            g_inp: Gradients w.r.t. the module inputs.
            g_out: Gradients w.r.t. the module outputs.
            backproped: Not used here; kept for interface compatibility.

        Returns:
            The result of ``self.derivatives.make_residual_mat_prod``.

        Raises:
            ValueError: if ``Curvature.is_pch`` holds for the curvature type,
                since this residual cannot be cast PSD.
        """
        curv_type = ext.get_curv_type()

        if Curvature.is_pch(curv_type):
            raise ValueError(
                "{} not supported for {}. Residual cannot be cast PSD.".format(
                    curv_type, module
                )
            )

        return self.derivatives.make_residual_mat_prod(module, g_inp, g_out)
Beispiel #6
0
    def __make_diagonal_R_mat_prod(self, ext, module, g_inp, g_out, backproped):
        """Build a function that multiplies matrices with the diagonal residual.

        The residual diagonal comes from ``self.derivatives.hessian_diagonal``.
        It is adjusted for the extension's curvature type by
        ``Curvature.modify_residual``. The adjusted value is captured in the
        closure ``R_mat_prod`` that is returned.

        Args:
            ext: Extension providing ``get_curv_type()``.
            module: The module being processed.
            g_inp: Gradients w.r.t. the module inputs.
            g_out: Gradients w.r.t. the module outputs.
            backproped: Not used here; kept for interface compatibility.

        Returns:
            The function produced by the decorated ``make_residual_mat_prod``.

        NOTE(review): if ``Curvature.modify_residual`` returns ``None`` for this
        curvature type, the returned function would fail inside ``einsum``.
        This presumably relies on callers only using it when a residual is
        required; confirm.
        """
        # TODO Refactor core: hessian_diagonal -> residual_diagonal
        R = self.derivatives.hessian_diagonal(module, g_inp, g_out)
        R_mod = Curvature.modify_residual(R, ext.get_curv_type())

        # The decorators are defined outside this view. Judging by their names,
        # they adapt vector inputs and check shapes of the product (verify).
        @R_mat_prod_accept_vectors
        @R_mat_prod_check_shapes
        def make_residual_mat_prod(self, module, g_inp, g_out):
            def R_mat_prod(mat):
                """Multiply with the residual: mat → [∑_{k} Hz_k(x) 𝛿z_k] mat.

                Second term of the module input Hessian backpropagation equation.
                """
                # "n...,vn...->vn...": each of the leading-``v`` slices of
                # ``mat`` is multiplied elementwise by ``R_mod``.
                return einsum("n...,vn...->vn...", (R_mod, mat))

            return R_mat_prod

        return make_residual_mat_prod(self, module, g_inp, g_out)
Beispiel #7
0
    def _require_residual(self, ext, module, g_inp, g_out, backproped):
        """Is the residual term required for multiply with the curvature?

        The residual is needed only when two conditions both hold:

        - the module Hessian does not vanish (``hessian_is_zero()`` is falsy);
        - the curvature type requires it (``Curvature.require_residual``).

        Both queries are always evaluated, in that order. Only ``ext`` is
        read; the other arguments are kept for interface compatibility.

        Returns:
            bool
        """
        hessian_vanishes = self.derivatives.hessian_is_zero()
        residual_neglected = not Curvature.require_residual(ext.get_curv_type())

        return not hessian_vanishes and not residual_neglected
Beispiel #8
0
    def backpropagate(self, ext, module, g_inp, g_out, backproped):
        """Validate the loss Hessian, then return its matrix-product operator.

        ``Curvature.check_loss_hessian`` checks the loss Hessian's PSD property
        against the extension's curvature type. The result of
        ``self.derivatives.hessian_matrix_product`` is then returned unchanged
        (presumably a callable that multiplies by the Hessian; confirm).

        Args:
            ext: Extension providing ``get_curv_type()``.
            module: The loss module being processed.
            g_inp: Gradients w.r.t. the module inputs.
            g_out: Gradients w.r.t. the module outputs.
            backproped: Not used here; kept for interface compatibility.
        """
        hessian_psd = self.derivatives.hessian_is_psd()
        Curvature.check_loss_hessian(hessian_psd, curv_type=ext.get_curv_type())

        return self.derivatives.hessian_matrix_product(module, g_inp, g_out)
Beispiel #9
0
 def _modify_residual(self, ext, residual):
     """Adjust ``residual`` for the extension's curvature type.

     Delegates to ``Curvature.modify_residual`` with ``ext.get_curv_type()``
     and returns its result unchanged.
     """
     curv_type = ext.get_curv_type()
     return Curvature.modify_residual(residual, curv_type)