Example #1
def Odagger_DeltaO_v(samples, params, v, forward_fn, vjp_fun=None):
    r"""
    compute \langle O^\dagger \Delta O \rangle v
    where \Delta O = O - \langle O \rangle

    optional: pass vjp_fun to be reused
    """

    # here the allreduce is deferred until after the dot product,
    # where only scalars instead of vectors have to be summed
    # the vjp_fun is returned so that it can be reused in OH_w below
    omean, vjp_fun = O_mean(
        samples,
        params,
        forward_fn,
        return_vjp_fun=True,
        vjp_fun=vjp_fun,
        allreduce=False,
    )
    omeanv = tree_dot(omean, v)  # omeanv = omean.dot(v); is a scalar
    omeanv = _sum_inplace(omeanv)  # MPI Allreduce w/ MPI_SUM

    # v_tilde is an array of size n_samples; each MPI rank has its own slice
    v_tilde = O_jvp(samples, params, v, forward_fn)
    # v_tilde -= omeanv (elementwise):
    v_tilde = v_tilde - omeanv
    # v_tilde /= total number of samples over all MPI ranks (elementwise):
    v_tilde = v_tilde * (1.0 / (samples.shape[0] * n_nodes))

    return OH_w(samples, params, v_tilde, forward_fn, vjp_fun=vjp_fun)
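For reference, here is a single-process NumPy sketch of the same quantity with a dense log-derivative matrix `oks` of shape (n_samples, n_parameters); the function above never materializes this matrix and works purely through jvp/vjp products, and the name `s_matvec_dense` is illustrative rather than part of the library.

import numpy as np

def s_matvec_dense(oks, v):
    n_samples = oks.shape[0]
    omean = oks.mean(axis=0)                   # <O>
    omeanv = omean @ v                         # <O>.v, a scalar
    v_tilde = (oks @ v - omeanv) / n_samples   # (O v - <O>.v) / N
    return oks.conj().T @ v_tilde              # O^dagger v_tilde = <O^dagger ΔO> v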
Example #2
            def matvec(v):
                # preallocated output buffers, reused between solver iterations
                v_tilde = self._v_tilde
                res = self._res_t

                # v_tilde = O v / n_samp  (local slice of samples on this rank)
                v_tilde = _np.matmul(oks, v, v_tilde) / float(n_samp)
                # res = O^dagger v_tilde (on the local samples)
                res = _np.matmul(v_tilde, oks_conj, res)
                # sum partial results over MPI ranks and add the diagonal shift
                res = _sum_inplace(res) + shift * v
                return res
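A matvec like this is typically wrapped in a linear operator and handed to an iterative solver (see `self._linear_operator` and `self._sparse_solver` in Example #4). A hedged single-process SciPy sketch, without the preallocated buffers and MPI reduction; the names here are illustrative:

import numpy as np
from scipy.sparse.linalg import LinearOperator

def make_s_operator(oks, shift=0.01):
    n_samp, n_par = oks.shape
    oks_conj = np.conjugate(oks)

    def matvec(v):
        v_tilde = np.matmul(oks, v) / float(n_samp)       # O v / N
        return np.matmul(v_tilde, oks_conj) + shift * v   # O^dagger v_tilde + shift v

    return LinearOperator((n_par, n_par), matvec=matvec, dtype=oks.dtype)

Such an operator can then be passed, together with the force vector, to an iterative solver like scipy.sparse.linalg.cg.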
Example #3
def _S_grad_mul(oks, v):
    r"""
    Computes y = 1/N * ( O^\dagger * O * v ) where v is a vector of
    length n_parameters, and O is a matrix (n_samples, n_parameters)
    """
    # TODO apply the transpose of sum_inplace (allreduce) to v here
    # in order to get correct transposition with MPI
    v_tilde = jnp.matmul(oks, v) / (oks.shape[0] * n_nodes)
    y = jnp.matmul(oks.conjugate().transpose(), v_tilde)
    return _sum_inplace(y)
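On a single process (n_nodes == 1, with _sum_inplace a no-op) the two lazy products above agree with explicitly forming S = O^dagger O / N. The dense variant below is only an illustrative comparison: it needs O(n_parameters^2) memory for S, whereas the lazy form costs two matrix-vector products.

import jax.numpy as jnp

def _s_grad_mul_dense(oks, v):
    n_samp = oks.shape[0]
    s_matrix = jnp.matmul(oks.conjugate().transpose(), oks) / n_samp  # S = O^dagger O / N
    return jnp.matmul(s_matrix, v)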
Example #4
    def compute_update(self, oks, grad, out=None):
        r"""
        Solves the SR flow equation for the parameter update ẋ.

        The SR update is computed by solving the linear equation
           Sẋ = f
        where S is the covariance matrix of the partial derivatives
        O_i(v_j) = ∂/∂x_i log Ψ(v_j) and f is a generalized force (the loss
        gradient).

        Args:
            oks: The matrix of log-derivatives,
                O_i(v_j)
            grad: The vector of forces f.
            out: Output array for the update ẋ.
        """

        oks -= _mean(oks, axis=0)

        if self.has_complex_parameters is None:
            raise ValueError(
                "has_complex_parameters not set: this SR object is not properly initialized."
            )

        n_samp = _sum_inplace(_np.atleast_1d(oks.shape[0]))

        n_par = grad.shape[0]

        if out is None:
            out = _np.zeros(n_par, dtype=_np.complex128)

        if self._has_complex_parameters:
            if self._use_iterative:
                op = self._linear_operator(oks, n_samp)

                if self._x0 is None:
                    self._x0 = _np.zeros(n_par, dtype=_np.complex128)

                out[:], info = self._sparse_solver(
                    op,
                    grad,
                    x0=self._x0,
                    tol=self.sparse_tol,
                    maxiter=self.sparse_maxiter,
                )
                if info < 0:
                    raise RuntimeError("SR sparse solver did not converge.")

                self._x0 = out
            else:
                self._S = _np.matmul(oks.conj().T, oks, self._S)
                self._S = _sum_inplace(self._S)
                self._S /= float(n_samp)

                self._apply_preconditioning(grad)

                if self._lsq_solver == "Cholesky":
                    c, low = _cho_factor(self._S, check_finite=False)
                    out[:] = _cho_solve((c, low), grad)

                else:
                    out[:], residuals, self._last_rank, s_vals = _lstsq(
                        self._S,
                        grad,
                        cond=self._svd_threshold,
                        lapack_driver=self._lapack_driver,
                    )

                self._revert_preconditioning(out)

        else:
            if self._use_iterative:
                op = self._linear_operator(oks, n_samp)

                if self._x0 is None:
                    self._x0 = _np.zeros(n_par)

                out[:].real, info = self._sparse_solver(
                    op,
                    grad.real,
                    x0=self._x0,
                    tol=self.sparse_tol,
                    maxiter=self.sparse_maxiter,
                )
                if info < 0:
                    raise RuntimeError("SR sparse solver did not converge.")
                self._x0 = out.real
            else:
                self._S = _np.matmul(oks.conj().T, oks, self._S)
                self._S = _sum_inplace(self._S)
                self._S /= float(n_samp)

                self._apply_preconditioning(grad)

                if self._lsq_solver == "Cholesky":
                    c, low = _cho_factor(self._S, check_finite=False)
                    out[:].real = _cho_solve((c, low), grad)
                else:
                    out[:].real, residuals, self._last_rank, s_vals = _lstsq(
                        self._S.real,
                        grad.real,
                        cond=self._svd_threshold,
                        lapack_driver=self._lapack_driver,
                    )

                self._revert_preconditioning(out.real)

            out.imag.fill(0.0)

        if _n_nodes > 1:
            # synchronize the update across MPI ranks (in-place buffer broadcast)
            self._comm.Bcast(out, root=0)
            self._comm.barrier()

        return out
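As a standalone, single-process sketch of the dense Cholesky branch for real parameters: a plain diagonal shift (as the `shift` in Example #2) stands in for the `_apply_preconditioning` / `_revert_preconditioning` pair, and the function name is illustrative, not the library's API.

import numpy as np
from scipy.linalg import cho_factor, cho_solve

def sr_dense_update_sketch(oks, grad, diag_shift=0.01):
    n_samp = oks.shape[0]
    oks = oks - oks.mean(axis=0)                                   # center: ΔO = O - <O>
    s_matrix = np.matmul(oks.conj().T, oks).real / float(n_samp)   # covariance matrix S
    s_matrix[np.diag_indices_from(s_matrix)] += diag_shift         # regularization
    c, low = cho_factor(s_matrix, check_finite=False)
    return cho_solve((c, low), grad)                               # solve S ẋ = f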
Example #5
def _subtract_mean_from_oks(oks):
    return oks - _sum_inplace(jnp.sum(oks, axis=0) / (oks.shape[0] * n_nodes))
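The helper builds the global column mean by allreducing the local column sums, each already divided by the global sample count (oks.shape[0] * n_nodes). With a single rank, where _sum_inplace is the identity, it reduces to ordinary centering, as in this illustrative one-liner:

import jax.numpy as jnp

def _subtract_mean_single_process(oks):
    return oks - jnp.mean(oks, axis=0)   # ΔO = O - <O>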
Example #6
    def acceptance(self):
        """The measured acceptance probability."""
        return _sum_inplace(self._accepted_samples) / _sum_inplace(
            self._total_samples)
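All of the examples above rely on `_sum_inplace` to combine per-rank partial results. Based on the comment "MPI Allreduce w/ MPI_SUM" in Example #1, a hedged mpi4py sketch of such a helper could look as follows; the library's actual implementation may differ.

import numpy as np
from mpi4py import MPI

_comm = MPI.COMM_WORLD
_n_nodes = _comm.Get_size()

def sum_inplace_sketch(x):
    """Element-wise sum of a numpy array over all MPI ranks (no-op on one rank)."""
    if _n_nodes == 1:
        return x
    x = np.ascontiguousarray(x)
    _comm.Allreduce(MPI.IN_PLACE, x, op=MPI.SUM)
    return x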