def lstsq(A, b, use_scipy=False):
    """Solve the linear least-squares problem ``A x ≈ b``.

    Args:
        A: Coefficient matrix. Either 2-D, or 3-D holding a stack of matrices
            ``A[i]``. The batched (3-D) branch below exists only on the
            ``use_scipy=False`` path.
        b: Right-hand side. For 3-D ``A``, column ``b[:, i]`` is the
            right-hand side paired with ``A[i]``.
        use_scipy: If True, delegate to ``_lstsq`` and return its first
            output (the solution). Otherwise solve the normal equations
            ``A^H A x = A^H b`` with ``direct_solve``.

    Returns:
        The solution ``x``. For 3-D ``A`` this is a matrix whose column ``i``
        solves the ``i``-th system.
    """
    if use_scipy:
        return _lstsq(A, b)[0]

    if A.ndim == 3:
        # At least float64 (as before), but keep complex systems complex
        # instead of silently discarding the imaginary part on assignment.
        dtype = np.result_type(A, b, np.float64)
        ret_mat = np.empty((A.shape[-1], A.shape[0]), dtype=dtype)
        for i, A_i in enumerate(A):
            A_i_h = A_i.conj().T
            ret_mat[:, i] = direct_solve(A_i_h.dot(A_i), A_i_h.dot(b[:, i]))
        return ret_mat

    # Conjugate transpose: identical to ``A.T`` for real A, and the correct
    # normal-equation operator when A is complex.
    A_h = A.conj().T
    return direct_solve(A_h.dot(A), A_h.dot(b))
Ejemplo n.º 2
0
 def fit(self, A, B):
     """Solve for X: AX = B, in the least-squares sense.

     The four outputs of ``_lstsq`` are stored on the instance as ``X_``,
     ``residual_``, ``rank_`` and ``svs_``. These are presumably the
     solution, residuals, effective rank and singular values, assuming
     ``_lstsq`` is ``scipy.linalg.lstsq`` — confirm against the import.

     Args:
         A: Coefficient matrix.
         B: Right-hand side(s).

     Returns:
         self, so calls can be chained (e.g. ``model.fit(A, B).X_``).
     """
     self.X_, self.residual_, self.rank_, self.svs_ = _lstsq(A, B)
     return self
Ejemplo n.º 3
0
    def compute_update(self, oks, grad, out=None):
        r"""
        Solves the SR flow equation for the parameter update ẋ.

        The SR update is computed by solving the linear equation
           Sẋ = f
        where S is the covariance matrix of the partial derivatives
        O_i(v_j) = ∂/∂x_i log Ψ(v_j) and f is a generalized force (the loss
        gradient).

        When the parameters are real, the real system Re(S)ẋ = Re(f) is
        solved in every solver path and the imaginary part of ``out`` is
        set to zero.

        Args:
            oks: The matrix of log-derivatives O_i(v_j), one row per sample.
                It is centred **in place** (its mean over axis 0 is
                subtracted).
            grad: The vector of forces f. Its length sets the number of
                parameters.
            out: Optional output array for the update ẋ. If None, a new
                complex128 array of length ``grad.shape[0]`` is allocated.

        Returns:
            The array holding the update ẋ (``out`` itself if given).

        Raises:
            ValueError: If ``has_complex_parameters`` was never set.
            RuntimeError: If the iterative solver returns ``info < 0``.
        """
        # Validate before touching ``oks`` so that a misconfigured object
        # does not leave the caller's array modified.
        if self.has_complex_parameters is None:
            raise ValueError(
                "has_complex_parameters not set: this SR object is not properly initialized."
            )

        oks -= _mean(oks, axis=0)

        # Sample count summed through _sum_inplace (presumably across all
        # MPI processes — confirm).
        n_samp = _sum_inplace(_np.atleast_1d(oks.shape[0]))

        n_par = grad.shape[0]

        if out is None:
            out = _np.zeros(n_par, dtype=_np.complex128)

        if self._has_complex_parameters:
            if self._use_iterative:
                op = self._linear_operator(oks, n_samp)

                if self._x0 is None:
                    self._x0 = _np.zeros(n_par, dtype=_np.complex128)

                out[:], info = self._sparse_solver(
                    op,
                    grad,
                    x0=self._x0,
                    tol=self.sparse_tol,
                    maxiter=self.sparse_maxiter,
                )
                # NOTE(review): if _sparse_solver follows the SciPy
                # convention, info > 0 (maxiter reached without convergence)
                # is not treated as an error here.
                if info < 0:
                    raise RuntimeError("SR sparse solver did not converge.")

                # Copy so the warm start for the next call is not altered by
                # in-place changes the caller makes to the returned array.
                self._x0 = out.copy()
            else:
                self._S = _np.matmul(oks.conj().T, oks, self._S)
                self._S = _sum_inplace(self._S)
                self._S /= float(n_samp)

                self._apply_preconditioning(grad)

                if self._lsq_solver == "Cholesky":
                    c, low = _cho_factor(self._S, check_finite=False)
                    out[:] = _cho_solve((c, low), grad)
                else:
                    out[:], _, self._last_rank, _ = _lstsq(
                        self._S,
                        grad,
                        cond=self._svd_threshold,
                        lapack_driver=self._lapack_driver,
                    )

                self._revert_preconditioning(out)

        else:
            if self._use_iterative:
                op = self._linear_operator(oks, n_samp)

                if self._x0 is None:
                    self._x0 = _np.zeros(n_par)

                out[:].real, info = self._sparse_solver(
                    op,
                    grad.real,
                    x0=self._x0,
                    tol=self.sparse_tol,
                    maxiter=self.sparse_maxiter,
                )
                # NOTE(review): see the complex branch regarding info > 0.
                if info < 0:
                    raise RuntimeError("SR sparse solver did not converge.")

                # Copy: ``out.real`` is a view into the returned array.
                self._x0 = out.real.copy()
            else:
                self._S = _np.matmul(oks.conj().T, oks, self._S)
                # Combine per-process contributions exactly as the complex
                # branch does; n_samp above is already the summed count.
                self._S = _sum_inplace(self._S)
                self._S /= float(n_samp)

                self._apply_preconditioning(grad)

                # Real parameters: solve Re(S) ẋ = Re(f) in both solver paths.
                if self._lsq_solver == "Cholesky":
                    c, low = _cho_factor(self._S.real, check_finite=False)
                    out[:].real = _cho_solve((c, low), grad.real)
                else:
                    out[:].real, _, self._last_rank, _ = _lstsq(
                        self._S.real,
                        grad.real,
                        cond=self._svd_threshold,
                        lapack_driver=self._lapack_driver,
                    )

                self._revert_preconditioning(out.real)

            out.imag.fill(0.0)

        if _n_nodes > 1:
            # The object-based ``bcast`` returns the broadcast value rather
            # than filling its argument, so copy it back into ``out``
            # (assumes an mpi4py-style communicator — confirm).
            out[:] = self._comm.bcast(out, root=0)
            self._comm.barrier()

        return out