Beispiel #1
0
    def _second_order_module_effects(self, module, ext, g_inp, g_out):
        if self.derivatives.hessian_is_zero():
            return None
        if not Curvature.require_residual(ext.get_curv_type()):
            return None

        if not self.derivatives.hessian_is_diagonal():
            raise NotImplementedError(
                "Residual terms are only supported for elementwise functions")

        return self.derivatives.hessian_diagonal(module, g_inp, g_out)
Beispiel #2
0
    def _require_residual(self, ext, module, g_inp, g_out, backproped):
        """Tell whether the residual term must be included for the curvature.

        Only ``ext`` is read here; ``module``, ``g_inp``, ``g_out`` and
        ``backproped`` are accepted but not used in this method.

        Returns:
            ``True`` only if ``self.derivatives.hessian_is_zero()`` is falsy
            and ``Curvature.require_residual(ext.get_curv_type())`` is truthy;
            ``False`` otherwise.
        """
        # Both checks are always evaluated, Hessian check first.
        hessian_vanishes = self.derivatives.hessian_is_zero()
        residual_neglected = not Curvature.require_residual(
            ext.get_curv_type())

        # De Morgan form of "not (vanishes or neglected)".
        return (not hessian_vanishes) and (not residual_neglected)