def _batch_sqrt_solve(self, rhs):
    """Solve against the square-root factor `S = M + V D V^T` of this operator.

    Uses the Woodbury identity
        (M + V D V^T)^{-1}
          = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
          = M^{-1} - M^{-1} V C^{-1} V^T M^{-1},
    where `C` is the capacitance matrix. Only solves against `M`
    (`self._operator.solve`) and a Cholesky solve with the factor returned by
    `self._chol_capacitance` are needed.

    Args:
      rhs: Right-hand side, passed directly to `self._operator.solve`.

    Returns:
      `M^{-1} rhs - M^{-1} V C^{-1} V^T M^{-1} rhs`, i.e. `(M + V D V^T)^{-1} rhs`.
    """
    base_op = self._operator
    update_v = self._v
    cap_chol = self._chol_capacitance(batch_mode=True)

    # Batch vs. singleton mode is handled by the operators themselves, so no
    # override is needed here.
    base_inv_rhs = base_op.solve(rhs)  # M^{-1} rhs

    # V^T M^{-1} rhs, then C^{-1} applied to it via the capacitance's
    # Cholesky factor.
    projected = math_ops.batch_matmul(update_v, base_inv_rhs, adj_x=True)
    cap_inv_projected = linalg_ops.batch_cholesky_solve(cap_chol, projected)

    # M^{-1} V C^{-1} V^T M^{-1} rhs
    correction = base_op.solve(
        math_ops.batch_matmul(update_v, cap_inv_projected))

    # (M^{-1} - M^{-1} V C^{-1} V^T M^{-1}) rhs
    return base_inv_rhs - correction
def _batch_solve(self, rhs):
    """Solve using the Cholesky factor stored in `self._chol`.

    Args:
      rhs: Right-hand side passed to `linalg_ops.batch_cholesky_solve`.

    Returns:
      The result of `linalg_ops.batch_cholesky_solve(self._chol, rhs)`.
    """
    chol_factor = self._chol
    solution = linalg_ops.batch_cholesky_solve(chol_factor, rhs)
    return solution