def transition(self, state, state_1, log_prob_corr):
    """Propose a single-site local move for every row of ``state``.

    Each row of ``state`` is copied into ``state_1``. Then one site, drawn
    uniformly from ``[0, self.size)``, is overwritten with a value taken
    from ``self.local_states``, and ``log_prob_corr`` is zeroed.

    The replacement index is drawn from ``n_states - 1`` candidates and then
    shifted up by one when the candidate value is ``>=`` the current value.
    NOTE(review): this yields a value *different* from the current one only
    if ``self.local_states`` is sorted ascending and contains the current
    value. Presumably guaranteed by the owning class; confirm.

    Args:
        state: current configurations; rows are indexed along axis 0.
        state_1: output buffer, same shape as ``state``; overwritten.
        log_prob_corr: output buffer; filled with zeros.
    """
    n_rows = state.shape[0]
    for row in range(n_rows):
        state_1[row] = state[row]

        site = _random.randint(0, self.size, size=())
        pick = _random.randint(0, self.n_states - 1, size=())

        # Skip over the current value by shifting candidates at/above it.
        current = state[row, site]
        if self.local_states[pick] >= current:
            pick = pick + 1
        state_1[row, site] = self.local_states[pick]

    log_prob_corr.fill(0.0)
def _choose(states, sections, out, w): low_range = 0 for i, s in enumerate(sections): n_rand = _random.randint(low_range, s, size=()) out[i] = states[n_rand] w[i] = math.log(s - low_range) low_range = s
def test_correct_sampling():
    """Check each sampler's histogram against the exact |psi|**power law.

    For every sampler in ``samplers``:

    - Pick ``power`` at random from {1, 2} and set it as ``machine_pow``.
    - Compute the exact normalised distribution ``|psi|**power`` from
      ``machine.to_array()``.
    - Run ``n_rep`` repetitions. Each discards a burn-in batch, histograms
      a fresh batch over the Hilbert-space basis, and runs a chi-square
      test against the exact frequencies.

    The per-repetition p-values are combined with Fisher's method. The test
    passes if the combined p-value or the best single p-value exceeds 0.01.
    """
    for name, sa in samplers.items():
        print("Sampler test: %s" % name)

        ma = sa.machine
        hi = ma.hilbert
        if ma.input_size == 2 * hi.size:
            hi = nk.hilbert.DoubledHilbert(hi)
        n_states = hi.n_states

        n_samples = max(40 * n_states, 10000)

        # Exponent of |psi| to sample from. Named `power` rather than `ord`
        # so the builtin is not shadowed.
        power = randint(1, 3, size=()).item()
        assert power in (1, 2)

        sa.machine_pow = power

        ps = np.absolute(ma.to_array()) ** power
        ps /= ps.sum()

        n_rep = 6
        pvalues = np.zeros(n_rep)

        sa.reset(True)

        for jrep in range(n_rep):
            # Burn-in phase: this batch is discarded (overwritten below).
            samples = sa.generate_samples(n_samples // 10)
            assert (samples.shape[1], samples.shape[2]) == sa.sample_shape

            samples = sa.generate_samples(n_samples)
            assert samples.shape[2] == ma.input_size
            sttn = hi.states_to_numbers(
                np.asarray(samples.reshape(-1, ma.input_size))
            )
            n_s = sttn.size

            # Histogram of visited basis states.
            unique, counts = np.unique(sttn, return_counts=True)
            hist_samp = np.zeros(n_states)
            hist_samp[unique] = counts

            # Expected frequencies under the exact distribution.
            f_exp = n_s * ps
            _, pvalues[jrep] = chisquare(hist_samp, f_exp=f_exp)

        _, pval = combine_pvalues(pvalues, method="fisher")
        assert pval > 0.01 or np.max(pvalues) > 0.01
def _transition(state, state_1, log_prob_corr, clusters): clusters_size = clusters.shape[0] for k in range(state.shape[0]): state_1[k] = state[k] # pick a random cluster cl = _random.randint(0, clusters_size, size=()) # sites to be exchanged si = clusters[cl][0] sj = clusters[cl][1] state_1[k, si], state_1[k, sj] = state[k, sj], state[k, si] log_prob_corr[:] = 0.0
def __init__(self, machine, kernel, n_chains=16, sweep_size=None, rng_key=None):
    """Build a Metropolis-Hastings sampler around a JAX-jitted kernel.

    Args:
        machine: passed straight to the base-class constructor.
        kernel: object providing ``random_state`` and ``transition``
            callables; both are wrapped with ``jax.jit``.
        n_chains: number of chains (also passed to the base class).
        sweep_size: stored on ``self.sweep_size``.
        rng_key: JAX PRNG key. When ``None``, a key is created from a seed
            drawn in ``[0, 2**32)`` with the module-level ``_random``.
    """
    super().__init__(machine, n_chains)

    self._random_state_kernel = jax.jit(kernel.random_state)
    self._transition_kernel = jax.jit(kernel.transition)

    if rng_key is None:
        seed = _random.randint(low=0, high=2 ** 32, size=()).item()
        rng_key = jax.random.PRNGKey(seed)
    self._rng_key = rng_key

    # NOTE(review): these assignments may go through setters on the base
    # class; kept in their original order.
    self.machine_pow = 2
    self.n_chains = n_chains
    self.sweep_size = sweep_size
def _exchange_step_kernel(log_values, machine_pow, beta, proposed_beta, prob, beta_stats, accepted_samples): # Choose a random swap order (odd/even swap) swap_order = _random.randint(0, 2, size=()).item() n_replicas = beta.shape[0] for i in range(swap_order, n_replicas, 2): inn = (i + 1) % n_replicas proposed_beta[i] = beta[inn] proposed_beta[inn] = beta[i] for i in range(n_replicas): prob[i] = math.exp(machine_pow * (proposed_beta[i] - beta[i]) * log_values[i].real) for i in range(swap_order, n_replicas, 2): inn = (i + 1) % n_replicas prob[i] *= prob[inn] if prob[i] > _random.uniform(0, 1): # swapping status beta[i], beta[inn] = beta[inn], beta[i] accepted_samples[i], accepted_samples[inn] = ( accepted_samples[inn], accepted_samples[i], ) if beta_stats[0] == i: beta_stats[0] = inn elif beta_stats[0] == inn: beta_stats[0] = i # Update statistics to compute diffusion coefficient of replicas # Total exchange steps performed beta_stats[-1] += 1 delta = beta_stats[0] - beta_stats[1] beta_stats[1] += delta / float(beta_stats[-1]) delta2 = beta_stats[0] - beta_stats[1] beta_stats[2] += delta * delta2
def random_state(self, state):
    """Overwrite every entry of ``state`` with a random local value.

    Each entry independently receives ``self.local_states[k]``, where ``k``
    is drawn uniformly from ``[0, self.n_states)``.

    Args:
        state: 2-D buffer (rows x sites); modified in place.
    """
    n_rows, n_sites = state.shape[0], state.shape[1]
    for row in range(n_rows):
        for site in range(n_sites):
            idx = _random.randint(0, self.n_states, size=())
            state[row, site] = self.local_states[idx]