def lstsq(A, b, use_scipy=False):
    """Solve the least-squares problem ``A x = b``.

    Args:
        A: Coefficient array. A 3-D ``A`` is treated as a stack of
            independent systems indexed along its first axis.
        b: Right-hand side. For a 3-D ``A``, column ``b[:, i]`` is paired
            with ``A[i]``.
        use_scipy: When true, delegate to ``_lstsq`` and return the first
            element of its result. Otherwise solve the normal equations
            ``A.T A x = A.T b`` through ``direct_solve``.

    Returns:
        The solution. For a 3-D ``A`` this is an array of shape
        ``(A.shape[-1], A.shape[0])`` holding one solution per column.
    """
    if use_scipy:
        return _lstsq(A, b)[0]

    if A.ndim != 3:
        # Single system: solve the normal equations directly.
        return direct_solve(A.T.dot(A), A.T.dot(b))

    # Stacked systems: one normal-equation solve per slice of A.
    solutions = np.empty((A.shape[-1], A.shape[0]))
    for idx, a_i in enumerate(A):
        a_i_t = a_i.T
        solutions[:, idx] = direct_solve(a_i_t.dot(a_i), a_i_t.dot(b[:, idx]))
    return solutions
def fit(self, A, B):
    """Solve ``A X = B`` for ``X`` and keep the solver output on the instance.

    The four values returned by ``_lstsq(A, B)`` are stored, in order, as
    ``X_``, ``residual_``, ``rank_`` and ``svs_``.

    Args:
        A: Left-hand side of the system.
        B: Right-hand side of the system.
    """
    solution, residual, rank, singular_values = _lstsq(A, B)
    self.X_ = solution
    self.residual_ = residual
    self.rank_ = rank
    self.svs_ = singular_values
def compute_update(self, oks, grad, out=None):
    r"""
    Solves the SR flow equation for the parameter update ẋ.

    The SR update is computed by solving the linear equation
       Sẋ = f
    where S is the covariance matrix of the partial derivatives
    O_i(v_j) = ∂/∂x_i log Ψ(v_j) and f is a generalized force (the loss
    gradient).

    Args:
        oks: The matrix of log-derivatives, O_i(v_j). It is centered IN
            PLACE (its mean over axis 0 is subtracted), so the caller's
            array is modified.
        grad: The vector of forces f. It is handed to
            `_apply_preconditioning` in the non-iterative path
            (NOTE(review): that helper may rescale it in place — confirm).
        out: Output array for the update ẋ. When None, a new complex128
            array of length `grad.shape[0]` is allocated.

    Returns:
        The array `out` holding the update. With real parameters only the
        real part is written and the imaginary part is zeroed.

    Raises:
        ValueError: If `has_complex_parameters` is None. This is checked
            before `oks` is modified.
        RuntimeError: If the iterative solver returns a negative `info`.
    """
    # Validate first, so a misconfigured object fails without having
    # already centered (mutated) the caller's `oks` array.
    if self.has_complex_parameters is None:
        raise ValueError(
            "has_complex_parameters not set: this SR object is not properly initialized."
        )

    # Center the log-derivatives in place.
    oks -= _mean(oks, axis=0)

    # presumably the sample count reduced over all nodes — verify against
    # the definition of _sum_inplace
    n_samp = _sum_inplace(_np.atleast_1d(oks.shape[0]))
    n_par = grad.shape[0]

    if out is None:
        out = _np.zeros(n_par, dtype=_np.complex128)

    if self._has_complex_parameters:
        if self._use_iterative:
            op = self._linear_operator(oks, n_samp)

            if self._x0 is None:
                self._x0 = _np.zeros(n_par, dtype=_np.complex128)

            out[:], info = self._sparse_solver(
                op,
                grad,
                x0=self._x0,
                tol=self.sparse_tol,
                maxiter=self.sparse_maxiter,
            )
            # NOTE(review): only negative `info` is treated as failure. If
            # _sparse_solver follows the SciPy convention, info > 0 means
            # "not converged within maxiter" — confirm this is intended.
            if info < 0:
                raise RuntimeError("SR sparse solver did not converge.")

            # Warm start for the next call; note this aliases `out`.
            self._x0 = out
        else:
            # The third positional argument of matmul is the output
            # buffer, so an existing self._S is reused when it is not None.
            self._S = _np.matmul(oks.conj().T, oks, self._S)
            self._S = _sum_inplace(self._S)
            self._S /= float(n_samp)

            self._apply_preconditioning(grad)

            if self._lsq_solver == "Cholesky":
                c, low = _cho_factor(self._S, check_finite=False)
                out[:] = _cho_solve((c, low), grad)
            else:
                out[:], _, self._last_rank, _ = _lstsq(
                    self._S,
                    grad,
                    cond=self._svd_threshold,
                    lapack_driver=self._lapack_driver,
                )

            self._revert_preconditioning(out)
    else:
        if self._use_iterative:
            op = self._linear_operator(oks, n_samp)

            if self._x0 is None:
                self._x0 = _np.zeros(n_par)

            out[:].real, info = self._sparse_solver(
                op,
                grad.real,
                x0=self._x0,
                tol=self.sparse_tol,
                maxiter=self.sparse_maxiter,
            )
            # NOTE(review): same `info` convention caveat as the complex
            # branch above.
            if info < 0:
                raise RuntimeError("SR sparse solver did not converge.")

            # Warm start for the next call; this is a view into `out`.
            self._x0 = out.real
        else:
            self._S = _np.matmul(oks.conj().T, oks, self._S)
            # NOTE(review): unlike the complex branch, S is not passed
            # through _sum_inplace here although n_samp is — confirm
            # whether that asymmetry is intentional.
            self._S /= float(n_samp)

            self._apply_preconditioning(grad)

            if self._lsq_solver == "Cholesky":
                c, low = _cho_factor(self._S, check_finite=False)
                out[:].real = _cho_solve((c, low), grad)
            else:
                out[:].real, _, self._last_rank, _ = _lstsq(
                    self._S.real,
                    grad.real,
                    cond=self._svd_threshold,
                    lapack_driver=self._lapack_driver,
                )

            self._revert_preconditioning(out.real)

        # Real parameters: the update carries no imaginary component.
        out.imag.fill(0.0)

    if _n_nodes > 1:
        self._comm.bcast(out, root=0)
        self._comm.barrier()

    return out