def _init_state(sampler, machine, parameters, key):
    # NumPy RNG seeded from the JAX PRNG key
    rgen = np.random.default_rng(np.asarray(key))

    σ = np.zeros((sampler.n_batches, sampler.hilbert.size), dtype=sampler.dtype)
    # infer the model output dtype without actually evaluating it
    ma_out = jax.eval_shape(machine.apply, parameters, σ)

    state = MetropolisNumpySamplerState(
        σ=σ,
        σ1=np.copy(σ),
        log_values=np.zeros(sampler.n_batches, dtype=ma_out.dtype),
        log_values_1=np.zeros(sampler.n_batches, dtype=ma_out.dtype),
        log_prob_corr=np.zeros(
            sampler.n_batches, dtype=nkjax.dtype_real(ma_out.dtype)
        ),
        rng=rgen,
        rule_state=sampler.rule.init_state(sampler, machine, parameters, rgen),
    )

    # If the chains are not reset before every sampling run, draw a random
    # initial configuration once here.
    if not sampler.reset_chains:
        key = jnp.asarray(
            state.rng.integers(0, 1 << 32, size=2, dtype=np.uint32), dtype=np.uint32
        )
        state.σ = np.copy(
            sampler.rule.random_state(sampler, machine, parameters, state, key)
        )

    return state
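# A minimal, self-contained sketch (not part of the sampler above) of the
# key-handling pattern used in `_init_state`: a JAX PRNG key seeds a NumPy
# Generator, and fresh pairs of uint32 drawn from that Generator can in turn
# be used as new JAX keys. All names below are local to this example.
import numpy as np
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
rgen = np.random.default_rng(np.asarray(key))  # NumPy RNG seeded from the JAX key
new_key = jnp.asarray(
    rgen.integers(0, 1 << 32, size=2, dtype=np.uint32), dtype=np.uint32
)  # shaped like a raw JAX PRNG key, usable to seed further random draws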
def __call__(self, σr, σc):
    # Symmetric and antisymmetric dense layers acting on the sum and
    # difference of the row and column configurations.
    U_S = nknn.Dense(
        name="Symm",
        features=int(self.alpha * σr.shape[-1]),
        dtype=self.dtype,
        use_bias=False,
        kernel_init=self.kernel_init,
        precision=self.precision,
    )
    U_A = nknn.Dense(
        name="ASymm",
        features=int(self.alpha * σr.shape[-1]),
        dtype=self.dtype,
        use_bias=False,
        kernel_init=self.kernel_init,
        precision=self.precision,
    )

    y = U_S(0.5 * (σr + σc)) + 1j * U_A(0.5 * (σr - σc))

    if self.use_bias:
        # real-valued bias added to the complex pre-activation
        bias = self.param(
            "bias",
            self.bias_init,
            (int(self.alpha * σr.shape[-1]),),
            nkjax.dtype_real(self.dtype),
        )
        y = y + bias

    y = self.activation(y)

    return y.sum(axis=-1)
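# Why the symmetric/antisymmetric split above: with real kernels, exchanging
# the row and column configurations conjugates the complex pre-activation,
# which (for activations that commute with complex conjugation) makes the
# encoded density matrix Hermitian. Minimal NumPy sketch of that property
# (toy shapes and stand-in kernels, not the Flax module above):
import numpy as np

rng = np.random.default_rng(0)
n, alpha = 4, 2
W_S = rng.normal(size=(n, alpha * n))  # stand-in for the "Symm" kernel
W_A = rng.normal(size=(n, alpha * n))  # stand-in for the "ASymm" kernel
σr = rng.choice([-1.0, 1.0], size=n)
σc = rng.choice([-1.0, 1.0], size=n)

y = 0.5 * (σr + σc) @ W_S + 1j * (0.5 * (σr - σc) @ W_A)
y_swapped = 0.5 * (σc + σr) @ W_S + 1j * (0.5 * (σc - σr) @ W_A)
assert np.allclose(y_swapped, np.conj(y))  # swapping σr and σc conjugates y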
def random_state(hilb: Particle, key, batches: int, *, dtype):
    """Positions particles according to a normal distribution if no periodic
    boundary conditions are applied in a spatial dimension. Otherwise the
    particles are positioned evenly along the box from 0 to L, with Gaussian
    noise of a certain width."""
    pbc = np.array(hilb.n_particles * hilb.pbc)
    boundary = np.tile(pbc, (batches, 1))

    Ls = np.array(hilb.n_particles * hilb.extent)
    modulus = np.where(np.equal(pbc, False), jnp.inf, Ls)

    min_modulus = np.min(modulus)

    # use real dtypes because this does not work with complex ones.
    gaussian = jax.random.normal(
        key, shape=(batches, hilb.size), dtype=nkjax.dtype_real(dtype)
    )

    width = min_modulus / (4.0 * hilb.n_particles)
    # The width gives the noise level. In the periodic case the particles are
    # evenly distributed between 0 and min(L). The distance between the
    # particle coordinates is therefore given by min(L) / hilb.N. To avoid
    # particles having coincident positions, the noise level should be smaller
    # than half this distance. We choose width = min(L) / (4*hilb.N).
    noise = gaussian * width
    uniform = jnp.tile(jnp.linspace(0.0, min_modulus, hilb.size), (batches, 1))

    select = np.equal(boundary, False)
    rs = select * gaussian + np.logical_not(select) * ((uniform + noise) % modulus)

    return jnp.asarray(rs, dtype=dtype)
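# Minimal usage sketch of `random_state` with a hypothetical stand-in for the
# Particle Hilbert space; only the attributes the function actually reads
# (n_particles, pbc, extent, size) are provided. Assumes the module-level
# imports used above (np, jnp, jax, nkjax) are available.
from types import SimpleNamespace
import jax
import jax.numpy as jnp

_hilb = SimpleNamespace(n_particles=3, pbc=(True,), extent=(10.0,), size=3)
_rs = random_state(_hilb, jax.random.PRNGKey(0), batches=2, dtype=jnp.float32)
print(_rs.shape)  # (2, 3): positions spread along the box of length 10,
                  # plus small Gaussian noise, folded back by the PBC modulo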
def _statistics(data, batch_size):
    data = jnp.atleast_1d(data)
    if data.ndim == 1:
        data = data.reshape((1, -1))

    if data.ndim > 2:
        raise NotImplementedError("Statistics are implemented only for ndim<=2")

    mean = _mean(data)
    variance = _var(data)
    ts = _total_size(data)

    bare_var = variance

    batch_var, n_batches = _batch_variance(data)

    l_block = max(1, data.shape[1] // batch_size)
    block_var, n_blocks = _block_variance(data, l_block)

    tau_batch = ((ts / n_batches) * batch_var / bare_var - 1) * 0.5
    tau_block = ((ts / n_blocks) * block_var / bare_var - 1) * 0.5

    batch_good = (tau_batch < 6 * data.shape[1]) * (n_batches >= batch_size)
    block_good = (tau_block < 6 * l_block) * (n_blocks >= batch_size)

    stat_dtype = nkjax.dtype_real(data.dtype)

    # if batch_good:
    #     error_of_mean = jnp.sqrt(batch_var / n_batches)
    #     tau_corr = jnp.max(0, tau_batch)
    # elif block_good:
    #     error_of_mean = jnp.sqrt(block_var / n_blocks)
    #     tau_corr = jnp.max(0, tau_block)
    # else:
    #     error_of_mean = jnp.nan
    #     tau_corr = jnp.nan
    #
    # jax style

    def batch_good_err(args):
        batch_var, tau_batch, *_ = args
        error_of_mean = jnp.sqrt(batch_var / n_batches)
        tau_corr = jnp.clip(tau_batch, 0)
        return jnp.asarray(error_of_mean, dtype=stat_dtype), jnp.asarray(
            tau_corr, dtype=stat_dtype
        )

    def block_good_err(args):
        _, _, block_var, tau_block = args
        error_of_mean = jnp.sqrt(block_var / n_blocks)
        tau_corr = jnp.clip(tau_block, 0)
        return jnp.asarray(error_of_mean, dtype=stat_dtype), jnp.asarray(
            tau_corr, dtype=stat_dtype
        )

    def nan_err(args):
        return jnp.asarray(jnp.nan, dtype=stat_dtype), jnp.asarray(
            jnp.nan, dtype=stat_dtype
        )

    def batch_not_good(args):
        batch_var, tau_batch, block_var, tau_block, block_good = args
        return jax.lax.cond(
            block_good,
            block_good_err,
            nan_err,
            (batch_var, tau_batch, block_var, tau_block),
        )

    error_of_mean, tau_corr = jax.lax.cond(
        batch_good,
        batch_good_err,
        batch_not_good,
        (batch_var, tau_batch, block_var, tau_block, block_good),
    )

    if n_batches > 1:
        N = data.shape[-1]

        # V_loc = _np.var(data, axis=-1, ddof=0)
        # W_loc = _np.mean(V_loc)
        # W = _mean(W_loc)
        # # This approximation seems to hold well enough for larger n_samples
        W = variance

        R_hat = jnp.sqrt((N - 1) / N + batch_var / W)
    else:
        R_hat = jnp.nan

    res = Stats(mean, error_of_mean, variance, tau_corr, R_hat)

    return res
def _statistics(data, batch_size):
    data = jnp.atleast_1d(data)
    if data.ndim == 1:
        data = data.reshape((1, -1))

    if data.ndim > 2:
        raise NotImplementedError("Statistics are implemented only for ndim<=2")

    mean = _mean(data)
    variance = _var(data)
    ts = _total_size(data)

    bare_var = variance

    batch_var, n_batches = _batch_variance(data)

    l_block = max(1, data.shape[1] // batch_size)
    block_var, n_blocks = _block_variance(data, l_block)

    tau_batch = ((ts / n_batches) * batch_var / bare_var - 1) * 0.5
    tau_block = ((ts / n_blocks) * block_var / bare_var - 1) * 0.5

    batch_good = (tau_batch < 6 * data.shape[1]) * (n_batches >= batch_size)
    block_good = (tau_block < 6 * l_block) * (n_blocks >= batch_size)

    stat_dtype = nkjax.dtype_real(data.dtype)

    # if batch_good:
    #     error_of_mean = jnp.sqrt(batch_var / n_batches)
    #     tau_corr = jnp.max(0, tau_batch)
    # elif block_good:
    #     error_of_mean = jnp.sqrt(block_var / n_blocks)
    #     tau_corr = jnp.max(0, tau_block)
    # else:
    #     error_of_mean = jnp.nan
    #     tau_corr = jnp.nan
    #
    # jax style

    def batch_good_err(args):
        batch_var, tau_batch, *_ = args
        error_of_mean = jnp.sqrt(batch_var / n_batches)
        tau_corr = jnp.clip(tau_batch, 0)
        return jnp.asarray(error_of_mean, dtype=stat_dtype), jnp.asarray(
            tau_corr, dtype=stat_dtype
        )

    def block_good_err(args):
        _, _, block_var, tau_block = args
        error_of_mean = jnp.sqrt(block_var / n_blocks)
        tau_corr = jnp.clip(tau_block, 0)
        return jnp.asarray(error_of_mean, dtype=stat_dtype), jnp.asarray(
            tau_corr, dtype=stat_dtype
        )

    def nan_err(args):
        return jnp.asarray(jnp.nan, dtype=stat_dtype), jnp.asarray(
            jnp.nan, dtype=stat_dtype
        )

    def batch_not_good(args):
        batch_var, tau_batch, block_var, tau_block, block_good = args
        return jax.lax.cond(
            block_good,
            block_good_err,
            nan_err,
            (batch_var, tau_batch, block_var, tau_block),
        )

    error_of_mean, tau_corr = jax.lax.cond(
        batch_good,
        batch_good_err,
        batch_not_good,
        (batch_var, tau_batch, block_var, tau_block, block_good),
    )

    if n_batches > 1:
        N = data.shape[-1]

        if not config.FLAGS["NETKET_USE_PLAIN_RHAT"]:
            # compute split-chain batch variance
            local_batch_size = data.shape[0]
            if N % 2 == 0:
                # split each chain in the middle,
                # like [[1 2 3 4]] -> [[1 2][3 4]]
                batch_var, _ = _batch_variance(
                    data.reshape(2 * local_batch_size, N // 2)
                )
            else:
                # drop the last sample of each chain for an even split,
                # like [[1 2 3 4 5]] -> [[1 2][3 4]]
                batch_var, _ = _batch_variance(
                    data[:, :-1].reshape(2 * local_batch_size, N // 2)
                )

        # V_loc = _np.var(data, axis=-1, ddof=0)
        # W_loc = _np.mean(V_loc)
        # W = _mean(W_loc)
        # # This approximation seems to hold well enough for larger n_samples
        W = variance

        R_hat = jnp.sqrt((N - 1) / N + batch_var / W)
    else:
        R_hat = jnp.nan

    res = Stats(mean, error_of_mean, variance, tau_corr, R_hat)

    return res
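# Quick numerical illustration of the estimators computed above, using plain
# NumPy stand-ins for the private helpers (_batch_variance and friends); this
# is a sketch of the formulas, not the implementation.
import numpy as np

_rng = np.random.default_rng(0)
n_chains, N = 8, 4000
# AR(1) chains with correlation rho = 0.9; the batch estimate of the
# autocorrelation time should come out around rho / (1 - rho) = 9,
# up to sizeable estimator noise with only 8 chains.
rho = 0.9
eps = _rng.normal(size=(n_chains, N))
chains = np.empty((n_chains, N))
chains[:, 0] = eps[:, 0]
for t in range(1, N):
    chains[:, t] = rho * chains[:, t - 1] + np.sqrt(1 - rho**2) * eps[:, t]

bare_var = chains.var()
batch_var = chains.mean(axis=1).var()       # variance of the per-chain means
tau_batch = (N * batch_var / bare_var - 1) * 0.5
# split-chain batch variance, as in the branch guarded by NETKET_USE_PLAIN_RHAT
split_var = chains.reshape(2 * n_chains, N // 2).mean(axis=1).var()
R_hat = np.sqrt((N - 1) / N + split_var / bare_var)
print(tau_batch, R_hat)                     # tau of order 10, R_hat close to 1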