def Odagger_DeltaO_v(samples, params, v, forward_fn, vjp_fun=None):
    r"""
    Compute \langle O^\dagger \Delta O \rangle v,
    where \Delta O = O - \langle O \rangle.

    Optionally, pass vjp_fun so it can be reused.
    """
    # Here the allreduce is deferred until after the dot product,
    # so that only scalars instead of vectors have to be summed.
    # The vjp_fun is returned so that it can be reused in OH_w below.
    omean, vjp_fun = O_mean(
        samples,
        params,
        forward_fn,
        return_vjp_fun=True,
        vjp_fun=vjp_fun,
        allreduce=False,
    )
    omeanv = tree_dot(omean, v)  # omeanv = omean.dot(v); is a scalar
    omeanv = _sum_inplace(omeanv)  # MPI Allreduce w/ MPI_SUM

    # v_tilde is an array of size n_samples; each MPI rank has its own slice
    v_tilde = O_jvp(samples, params, v, forward_fn)
    # v_tilde -= omeanv (elementwise)
    v_tilde = v_tilde - omeanv
    # v_tilde /= n_samples (elementwise)
    v_tilde = v_tilde * (1.0 / (samples.shape[0] * n_nodes))

    return OH_w(samples, params, v_tilde, forward_fn, vjp_fun=vjp_fun)
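# Hedged sketch (not part of the library): a dense single-process equivalent
# of Odagger_DeltaO_v, useful as a sanity check. It assumes n_nodes == 1 and
# a dense Jacobian `O` of shape (n_samples, n_parameters);
# `odagger_deltao_v_dense` is a hypothetical helper name.
import jax.numpy as jnp


def odagger_deltao_v_dense(O, v):
    # ΔO = O - ⟨O⟩, subtracting the sample mean of each column
    delta_O = O - jnp.mean(O, axis=0, keepdims=True)
    # ⟨O† ΔO⟩ v = (1/N) O† (ΔO v), evaluated right-to-left as above
    return jnp.matmul(O.conjugate().T, jnp.matmul(delta_O, v)) / O.shape[0]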
def matvec(v):
    v_tilde = self._v_tilde
    res = self._res_t

    # v_tilde = O v / N, using the preallocated buffer as matmul output
    v_tilde = _np.matmul(oks, v, v_tilde) / float(n_samp)
    # res = O^dagger v_tilde (the row vector v_tilde contracts the sample axis)
    res = _np.matmul(v_tilde, oks_conj, res)
    # allreduce the n_par-sized result, then apply the diagonal shift
    res = _sum_inplace(res) + shift * v

    return res
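# Hedged sketch: wrapping a matvec of this shape in a SciPy LinearOperator so
# it can feed an iterative solver (e.g. scipy.sparse.linalg.cg, since
# S + shift*I is Hermitian positive definite). Single process assumed (the
# _sum_inplace allreduce is dropped); `make_shifted_S_operator` is a
# hypothetical helper, not library API.
from scipy.sparse.linalg import LinearOperator


def make_shifted_S_operator(oks, shift, n_samp):
    oks_conj = oks.conjugate()
    n_par = oks.shape[1]

    def matvec(v):
        # (S + shift * I) v with S = (1/N) O† O
        v_tilde = oks @ v / float(n_samp)
        return v_tilde @ oks_conj + shift * v

    return LinearOperator((n_par, n_par), matvec=matvec, dtype=oks.dtype)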
def _S_grad_mul(oks, v):
    r"""
    Computes y = 1/N * (O^\dagger * O * v),
    where v is a vector of length n_parameters
    and O is a matrix of shape (n_samples, n_parameters).
    """
    # TODO apply the transpose of sum_inplace (allreduce) to v here
    # in order to get correct transposition with MPI
    v_tilde = jnp.matmul(oks, v) / (oks.shape[0] * n_nodes)
    y = jnp.matmul(oks.conjugate().transpose(), v_tilde)
    return _sum_inplace(y)
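# Hedged sketch: a tiny single-process check (n_nodes == 1, _sum_inplace
# acting as the identity) that _S_grad_mul applies S = (1/N) O† O. The
# shapes and PRNG key below are illustrative only.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
oks_test = jax.random.normal(key, (64, 8))
v_test = jnp.ones(8)

S_dense = jnp.matmul(oks_test.conjugate().T, oks_test) / oks_test.shape[0]
# under the assumptions above:
#     _S_grad_mul(oks_test, v_test) ≈ S_dense @ v_test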
def compute_update(self, oks, grad, out=None):
    r"""
    Solves the SR flow equation for the parameter update ẋ.

    The SR update is computed by solving the linear equation

        Sẋ = f,

    where S is the covariance matrix of the partial derivatives
    O_i(v_j) = ∂/∂x_i log Ψ(v_j) and f is a generalized force
    (the loss gradient).

    Args:
        oks: The matrix of log-derivatives, O_i(v_j).
        grad: The vector of forces f.
        out: Output array for the update ẋ.
    """
    oks -= _mean(oks, axis=0)

    if self.has_complex_parameters is None:
        raise ValueError(
            "has_complex_parameters not set: this SR object is not properly initialized."
        )

    n_samp = _sum_inplace(_np.atleast_1d(oks.shape[0]))
    n_par = grad.shape[0]

    if out is None:
        out = _np.zeros(n_par, dtype=_np.complex128)

    if self._has_complex_parameters:
        if self._use_iterative:
            op = self._linear_operator(oks, n_samp)

            if self._x0 is None:
                self._x0 = _np.zeros(n_par, dtype=_np.complex128)

            out[:], info = self._sparse_solver(
                op,
                grad,
                x0=self._x0,
                tol=self.sparse_tol,
                maxiter=self.sparse_maxiter,
            )
            if info < 0:
                raise RuntimeError("SR sparse solver did not converge.")
            self._x0 = out
        else:
            self._S = _np.matmul(oks.conj().T, oks, self._S)
            self._S = _sum_inplace(self._S)
            self._S /= float(n_samp)

            self._apply_preconditioning(grad)

            if self._lsq_solver == "Cholesky":
                c, low = _cho_factor(self._S, check_finite=False)
                out[:] = _cho_solve((c, low), grad)
            else:
                out[:], residuals, self._last_rank, s_vals = _lstsq(
                    self._S,
                    grad,
                    cond=self._svd_threshold,
                    lapack_driver=self._lapack_driver,
                )

            self._revert_preconditioning(out)
    else:
        if self._use_iterative:
            op = self._linear_operator(oks, n_samp)

            if self._x0 is None:
                self._x0 = _np.zeros(n_par)

            out[:].real, info = self._sparse_solver(
                op,
                grad.real,
                x0=self._x0,
                tol=self.sparse_tol,
                maxiter=self.sparse_maxiter,
            )
            if info < 0:
                raise RuntimeError("SR sparse solver did not converge.")
            self._x0 = out.real
        else:
            self._S = _np.matmul(oks.conj().T, oks, self._S)
            # allreduce S across ranks, as in the complex branch above,
            # before normalizing by the global sample count
            self._S = _sum_inplace(self._S)
            self._S /= float(n_samp)

            self._apply_preconditioning(grad)

            if self._lsq_solver == "Cholesky":
                c, low = _cho_factor(self._S, check_finite=False)
                out[:].real = _cho_solve((c, low), grad)
            else:
                out[:].real, residuals, self._last_rank, s_vals = _lstsq(
                    self._S.real,
                    grad.real,
                    cond=self._svd_threshold,
                    lapack_driver=self._lapack_driver,
                )

            self._revert_preconditioning(out.real)

        out.imag.fill(0.0)

    if _n_nodes > 1:
        self._comm.bcast(out, root=0)
        self._comm.barrier()

    return out
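# Hedged sketch: the direct (non-iterative) branch above condensed into a
# standalone function. It builds S = (1/N) O† O from mean-centered oks and
# solves the shifted system with a Cholesky factorization; `sr_update_dense`
# and the explicit `shift` argument are illustrative, not library API.
import numpy as _np
from scipy.linalg import cho_factor as _cho_factor, cho_solve as _cho_solve


def sr_update_dense(oks, grad, shift=1e-2):
    oks = oks - oks.mean(axis=0)
    S = _np.matmul(oks.conj().T, oks) / oks.shape[0]
    # a diagonal shift regularizes the often ill-conditioned covariance
    c, low = _cho_factor(S + shift * _np.eye(S.shape[0]), check_finite=False)
    return _cho_solve((c, low), grad)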
def _subtract_mean_from_oks(oks):
    return oks - _sum_inplace(jnp.sum(oks, axis=0) / (oks.shape[0] * n_nodes))
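# Note: with a single process (n_nodes == 1, _sum_inplace the identity) this
# reduces to plain column-wise centering, oks - jnp.mean(oks, axis=0).
# Across MPI ranks, the local column sums are allreduced and divided by the
# global sample count, so every rank subtracts the same global mean.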
def acceptance(self):
    """The measured acceptance probability."""
    return _sum_inplace(self._accepted_samples) / _sum_inplace(
        self._total_samples
    )
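# Note: both counters are allreduced before dividing, so this is the global
# acceptance ratio, not the mean of per-rank ratios. E.g. with two ranks at
# (accepted, total) = (40, 100) and (30, 50), the global value is
# 70 / 150 ≈ 0.467, whereas averaging per-rank ratios would give 0.5.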