def cosine_decay_scheduler(step, steps_per_cycle, t_mul=1, m_mul=1., alpha=0.):
    """Return a cosine-decay-with-restarts scaling factor for `step`.

    Args:
        step: int; Current step.
        steps_per_cycle: int; Length, in steps, of the first decay cycle.
        t_mul: int; Growth factor of successive cycle lengths (1 keeps every
            cycle the same length).
        m_mul: float; Factor applied to the peak value at each restart.
        alpha: float; Floor of the schedule, as a fraction of the initial value.

    Returns:
        Scaling factor applied to the learning rate on the given step.
    """
    cycles_elapsed = step / float(steps_per_cycle)

    if t_mul == 1.0:
        # Equal-length cycles: the restart index is just the integer part.
        restart_idx = jnp.floor(cycles_elapsed)
        phase = cycles_elapsed - restart_idx
    else:
        # Geometric cycle lengths: invert the partial sum 1 + t_mul + t_mul**2 + ...
        restart_idx = jnp.floor(
            jnp.log(1.0 - cycles_elapsed * (1.0 - t_mul)) / jnp.log(t_mul))
        completed = (1.0 - t_mul**restart_idx) / (1.0 - t_mul)
        phase = (cycles_elapsed - completed) / t_mul**restart_idx

    peak = m_mul**restart_idx
    decayed = jnp.maximum(
        0.0, 0.5 * peak * (1.0 + jnp.cos(jnp.pi * (phase % 1.0))))
    return (1 - alpha) * decayed + alpha
def delta_r(ri, rj, box=None):
    """Return ri - rj, optionally wrapped into the periodic box.

    The inputs may be (N, N, 3) or (B, 3) arrays; the last axis holds the
    three Cartesian components. When `box` is given, axes are wrapped in the
    order 2, 1, 0. For each axis k, the row `box[k]` is subtracted a whole
    number of times, with the count chosen by rounding the k-th component
    divided by `box[k][k]`.
    """
    disp = ri - rj
    if box is None:
        return disp
    for axis in (2, 1, 0):
        n_images = np.floor(
            np.expand_dims(disp[..., axis], axis=-1) / box[axis][axis] + 0.5)
        disp -= box[axis] * n_images
    return disp
def interpolate_bilinear(  # type: ignore
    im: np.ndarray, rows: np.ndarray, cols: np.ndarray
) -> np.ndarray:
    """Bilinearly sample `im` at fractional (rows, cols) positions.

    `im` is indexed as (..., row, col, channel). Neighbouring pixel indices
    are clamped to the image, so out-of-range queries reuse edge pixels.
    Based on http://stackoverflow.com/a/12729229
    """
    n_rows, n_cols = im.shape[-3:-1]

    c0 = np.floor(cols).astype(int)
    c1 = c0 + 1
    r0 = np.floor(rows).astype(int)
    r1 = r0 + 1

    def pixel(r, c):
        # Gather the pixels at the clamped (row, col) indices.
        return im[..., np.clip(r, 0, n_rows - 1), np.clip(c, 0, n_cols - 1), :]

    def weight(col_part, row_part):
        return np.expand_dims(col_part * row_part, -1)

    return (weight(c1 - cols, r1 - rows) * pixel(r0, c0)
            + weight(c1 - cols, rows - r0) * pixel(r1, c0)
            + weight(cols - c0, r1 - rows) * pixel(r0, c1)
            + weight(cols - c0, rows - r0) * pixel(r1, c1))
def __init__(self, lower, upper):
    # Support constraint built from the raw (un-floored) bounds.
    self.support = constraints.integer_interval(lower, upper)
    self.event_shape = ()
    lower_shape, upper_shape = np.shape(lower), np.shape(upper)
    self.batch_shape = broadcast_batch_shape(lower_shape, upper_shape)
    # Bounds are stored rounded down with np.floor; they are not broadcast to
    # batch_shape here.
    self.lower, self.upper = np.floor(lower), np.floor(upper)
def __init__(self, lower, upper):
    # Support constraint built from the raw (un-floored) bounds.
    self.support = constraints.integer_interval(lower, upper)
    self.event_shape = ()
    lower, upper = promote_shapes(lower, upper)
    shape = lax.broadcast_shapes(jnp.shape(lower), jnp.shape(upper))
    self.batch_shape = shape
    # Floor both bounds, then broadcast each to the common batch shape.
    floor_lo, floor_hi = jnp.floor(lower), jnp.floor(upper)
    self.lower = jnp.broadcast_to(floor_lo, shape)
    self.upper = jnp.broadcast_to(floor_hi, shape)
def bilinear_resample(x, warp):
    """Bilinearly sample each image of a batch at floating-point locations.

    Args:
        x: Images of shape ``batch_shape + [height, width, num_feats]``.
        warp: Sample locations of shape ``batch_shape + [..., 2]``. The last
            axis holds ``(x, y)`` = ``(column, row)`` pixel coordinates into
            the matching image of ``x``. Neighbouring pixels outside the image
            are clamped to the border.

    Returns:
        Array of shape ``batch_shape + [num_samples, num_feats]``, where
        ``num_samples`` is the number of warp points per batch element.
    """
    batch_shape = list(x.shape[:-3])
    height, width = list(x.shape[-3:-1])
    num_feats = x.shape[-1]
    max_x = width - 1
    max_y = height - 1

    # B
    batch_shape_flat = _reduce(_mul, batch_shape, 1)
    # (B*H*W) x D, where H x W is the *input* image size
    data_flat = _jnp.reshape(x, [batch_shape_flat * height * width, -1])
    # (B*S) x 2, where S is the number of sample points per batch element
    warp_flat = _jnp.reshape(warp, [-1, 2])
    idx_size = warp_flat.shape[0] // batch_shape_flat if batch_shape_flat else 0

    # Start of each batch element's image inside data_flat. This must be the
    # input image size (height * width), not the number of sample points.
    # The two only coincide when the warp grid matches the image size.
    batch_offsets = _jnp.arange(batch_shape_flat) * (height * width)
    # B x S, flattened to (B*S)
    base_grid = _jnp.tile(_jnp.expand_dims(batch_offsets, 1), [1, idx_size])
    base = _jnp.reshape(base_grid, [-1])

    warp_floor = _jnp.floor(warp_flat)
    warp_floored = warp_floor.astype(_jnp.int32)
    bilinear_weights = warp_flat - warp_floor

    # (B*S) integer corner coordinates, clamped to the image
    x0 = _jnp.clip(warp_floored[:, 0], 0, max_x)
    x1 = _jnp.clip(warp_floored[:, 0] + 1, 0, max_x)
    y0 = _jnp.clip(warp_floored[:, 1], 0, max_y)
    y1 = _jnp.clip(warp_floored[:, 1] + 1, 0, max_y)

    base_y0 = base + y0 * width
    base_y1 = base + y1 * width
    idx_a = base_y0 + x0
    idx_b = base_y1 + x0
    idx_c = base_y0 + x1
    idx_d = base_y1 + x1

    # (B*S) x D
    Ia = _jnp.take(data_flat, idx_a, axis=0)
    Ib = _jnp.take(data_flat, idx_b, axis=0)
    Ic = _jnp.take(data_flat, idx_c, axis=0)
    Id = _jnp.take(data_flat, idx_d, axis=0)

    # (B*S)
    xw = bilinear_weights[:, 0]
    yw = bilinear_weights[:, 1]

    # (B*S) x 1
    wa = _jnp.expand_dims((1 - xw) * (1 - yw), 1)
    wb = _jnp.expand_dims((1 - xw) * yw, 1)
    wc = _jnp.expand_dims(xw * (1 - yw), 1)
    wd = _jnp.expand_dims(xw * yw, 1)

    # (B*S) x D
    resampled_flat = wa * Ia + wb * Ib + wc * Ic + wd * Id

    # batch_shape x S x D
    return _jnp.reshape(resampled_flat, batch_shape + [-1, num_feats])
def fourierfreq(ncoeff, delta, CONDTHRESH=1e8):
    """Return the vector of integer Fourier frequencies kept for a basis.

    Args:
        ncoeff: Number of points / coefficients of the basis.
        delta: Length-scale parameter. Together with `ncoeff` and
            `CONDTHRESH` it sets the frequency cutoff `maxfreq`.
        CONDTHRESH: Threshold entering the cutoff through sqrt(0.5*log(.)).
            Presumably a condition-number bound. TODO confirm against caller.

    Returns:
        1-D array of frequencies: non-negative (cosine) frequencies first,
        followed by the negative (sine) frequencies in increasing order.
    """
    maxfreq = np.floor(
        ncoeff / (np.pi * delta) * np.sqrt(0.5 * np.log(CONDTHRESH))).astype(int)
    if maxfreq < ncoeff / 2:
        # Symmetric set: 0..maxfreq and -maxfreq..-1, i.e. 2*maxfreq + 1
        # frequencies (never more than ncoeff). The previous lower bound
        # -maxfreq + 1 silently dropped the -maxfreq term.
        wvec = np.hstack([np.arange(maxfreq + 1), np.arange(-maxfreq, 0)])
    else:
        # The cutoff exceeds what ncoeff points resolve: use the full real-DFT set.
        ncos = np.ceil((ncoeff + 1) / 2)
        nsin = np.floor((ncoeff - 1) / 2)
        wvec = np.hstack([np.arange(ncos), np.arange(-nsin, 0)])
    return wvec
def recenter(conf, b):
    """Wrap 2-D coordinates into the periodic square box of side `b`.

    Each coordinate is shifted by an integer multiple of `b`, chosen with
    floor(), along each axis. The second column is wrapped first, then the
    first.
    """
    box = jnp.array([[b, 0.], [0., b]])
    shift = jnp.zeros_like(conf)

    n_y = jnp.floor(conf[:, 1] / box[1][1])
    shift += jnp.expand_dims(box[1], axis=0) * jnp.expand_dims(n_y, axis=-1)

    n_x = jnp.floor((conf[:, 0] - shift[:, 0]) / box[0][0])
    shift += jnp.expand_dims(box[0], axis=0) * jnp.expand_dims(n_x, axis=-1)

    return conf - shift
def recenter(conf, b):
    """Wrap each 2-D atom position into the periodic square box of side `b`.

    Processes atoms one at a time. The y component is wrapped first, then x.
    The result is collected into an `np.array`.
    """
    box = jnp.array([[b, 0.0], [0.0, b]])

    def _wrap(atom):
        offset = jnp.array([0.0, 0.0])
        offset += box[1] * jnp.floor(atom[1] / box[1][1])
        offset += box[0] * jnp.floor((atom[0] - offset[0]) / box[0][0])
        return atom - offset

    return np.array([_wrap(atom) for atom in conf])
def rescale_coordinates(conf, indices, box, scales):
    """Wrap each molecule's centroid into the periodic box, moving atoms rigidly.

    Args:
        conf: (N, 3) atom coordinates.
        indices: (N,) non-negative integer molecule id for each atom.
        box: 3x3 box. For axis k, row `box[k]` is subtracted floor(c_k / box[k][k])
            times from each centroid (axes processed in the order 2, 1, 0).
        scales: NOTE(review): currently unused. Despite the function name, no
            rescaling is applied. Kept for interface compatibility.

    Returns:
        `conf` with every atom of a molecule shifted by the same offset, so
        that the molecule's centroid lies in the wrapped box.
    """
    mol_sizes = np.expand_dims(onp.bincount(indices), axis=1)
    # Per-molecule centroid of the input configuration (the original summed the
    # undefined name `coords`).
    mol_centers = jax.ops.segment_sum(conf, indices) / mol_sizes

    new_centers = mol_centers - box[2] * np.floor(
        np.expand_dims(mol_centers[..., 2], axis=-1) / box[2][2])
    new_centers -= box[1] * np.floor(
        np.expand_dims(new_centers[..., 1], axis=-1) / box[1][1])
    new_centers -= box[0] * np.floor(
        np.expand_dims(new_centers[..., 0], axis=-1) / box[0][0])

    offset = new_centers - mol_centers
    # Broadcast each molecule's offset back to its atoms.
    return conf + offset[indices]
def interp_regular_1d(x: np.ndarray,
                      xmin: float,
                      xmax: float,
                      yp: np.ndarray) -> np.ndarray:
    """Piecewise-linear interpolation of samples on a uniform grid.

    The samples `yp` sit at `len(yp)` evenly spaced points from `xmin` to
    `xmax` inclusive. Queries outside [xmin, xmax] are clamped to the grid, so
    they return `yp[0]` or `yp[-1]` (constant extrapolation).

    Args:
        x: The x-coordinates at which to evaluate the interpolated values.
        xmin: The lower bound of the regular input x-coordinate grid.
        xmax: The upper bound of the regular input x-coordinate grid.
        yp: The y coordinates of the data points.

    Returns:
        y: The interpolated values, same shape as x.
    """
    n_points = len(yp)
    # Fractional grid index of every query, clamped to [0, n_points - 1].
    pos = np.clip((x - xmin) / (xmax - xmin) * (n_points - 1), 0, n_points - 1)
    # Right/left neighbours; at the last point use the final segment so the
    # interpolation weight becomes exactly 1.
    hi = np.minimum(np.floor(pos) + 1, n_points - 1)
    lo = np.maximum(hi - 1, 0)
    frac = pos - lo
    return frac * yp[hi.astype(np.int32)] + (1 - frac) * yp[lo.astype(np.int32)]
def hsv_to_rgb(h, s, v):
    """Convert hue/saturation/value arrays to an (r, g, b) tuple.

    Reference TF implementation:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc

    Inputs in [0, 1] are the supported range. Outside it, this function
    matches the TF implementation (the hue is wrapped with ``h % 1``).

    Args:
        h: A float tensor of arbitrary shape for the hue (0-1 values).
        s: A float tensor of the same shape for the saturation (0-1 values).
        v: A float tensor of the same shape for the value channel (0-1 values).

    Returns:
        An (r, g, b) tuple, each with the same dimension as the inputs.
    """
    chroma = s * v
    base = v - chroma
    hue6 = (h % 1.) * 6.
    # Secondary (ramping) component inside the current 60-degree sector.
    ramp = chroma * (1 - jnp.abs(hue6 % 2. - 1))
    sector = jnp.floor(hue6).astype(jnp.int32)

    def channel(full, partial):
        # `full` sectors get the chroma, `partial` sectors the ramp, others 0.
        in_full = (sector == full[0]) | (sector == full[1])
        in_partial = (sector == partial[0]) | (sector == partial[1])
        return jnp.where(in_full, chroma, jnp.where(in_partial, ramp, 0)) + base

    red = channel((0, 5), (1, 4))
    green = channel((1, 2), (0, 3))
    blue = channel((3, 4), (2, 5))
    return red, green, blue
def _cell_dimensions(spatial_dimension, box_size, minimum_cell_size):
    """Compute the number of cells-per-side and total number of cells in a box.

    Args:
        spatial_dimension: Number of spatial dimensions.
        box_size: Box side length. May be a Python scalar, a 0-d array (a
            cubic box), or an array holding one side length per dimension.
        minimum_cell_size: Smallest allowed cell side length.

    Returns:
        (box_size, cell_size, cells_per_side, cell_count). `box_size` is cast
        with f32 when it was a Python int/float or an int32/int64 array.

    Raises:
        ValueError: if an array box has fewer than 3 cells along some
            dimension, if a per-dimension box's size differs from
            `spatial_dimension`, or if the box has more than two dimensions.
    """
    if isinstance(box_size, (int, float)):
        box_size = f32(box_size)

    if (isinstance(box_size, np.ndarray) and
            (box_size.dtype == np.int32 or box_size.dtype == np.int64)):
        box_size = f32(box_size)

    cells_per_side = np.floor(box_size / minimum_cell_size)
    cell_size = box_size / cells_per_side
    cells_per_side = np.array(cells_per_side, dtype=np.int64)

    if not isinstance(box_size, np.ndarray):
        cell_count = cells_per_side ** spatial_dimension
    elif box_size.ndim > 2:
        raise ValueError('Box must either be a scalar or a vector.')
    else:
        if box_size.ndim > 0 and box_size.size != spatial_dimension:
            raise ValueError(
                'Box has {} entries but spatial_dimension is {}.'.format(
                    box_size.size, spatial_dimension))
        flat_cells_per_side = np.reshape(cells_per_side, (-1,))
        if any(cells < 3 for cells in flat_cells_per_side):
            raise ValueError(
                ('Box must be at least 3x the size of the grid spacing in each '
                 'dimension.'))
        if box_size.ndim == 0:
            # A 0-d array (f32 may return one for a Python scalar) is a cubic
            # box: every dimension has the same number of cells. Treating it
            # as a 1-element vector undercounted the cells.
            cell_count = cells_per_side ** spatial_dimension
        else:
            cell_count = reduce(mul, flat_cells_per_side, 1)

    return box_size, cell_size, cells_per_side, int(cell_count)
def _cell_dimensions(spatial_dimension, box_size, minimum_cell_size):
    """Compute the number of cells-per-side and total number of cells in a box.

    Returns (box_size, cell_size, cells_per_side, cell_count), where
    `box_size` has been cast with f32 if it was a Python int/float or an
    int32/int64 array.
    """
    if isinstance(box_size, (int, float)):
        box_size = f32(box_size)

    # NOTE(schsam): Should we auto-cast based on box_size? I can't imagine a case
    # in which the box_size would not be accurately represented by an f32.
    if (isinstance(box_size, np.ndarray) and
            (box_size.dtype == np.int32 or box_size.dtype == np.int64)):
        box_size = f32(box_size)

    cells_per_side = np.floor(box_size / minimum_cell_size)
    cell_size = box_size / cells_per_side
    cells_per_side = np.array(cells_per_side, dtype=np.int64)

    # Non-array boxes (e.g. scalar types) are treated as cubic.
    if not isinstance(box_size, np.ndarray):
        return (box_size, cell_size, cells_per_side,
                int(cells_per_side**spatial_dimension))

    if box_size.ndim == 0:
        cell_count = cells_per_side**spatial_dimension
    elif box_size.ndim in (1, 2):
        assert box_size.size == spatial_dimension
        flat_cells = np.reshape(cells_per_side, (-1, ))
        if any(cells < 3 for cells in flat_cells):
            raise ValueError((
                'Box must be at least 3x the size of the grid spacing in each '
                'dimension.'))
        cell_count = reduce(mul, flat_cells, 1)
    else:
        raise ValueError('Box must either be a scalar or a vector.')

    return box_size, cell_size, cells_per_side, int(cell_count)
def two_normalize(m):
    """Divide each matrix of `m` by a power of two so its norm is close to 1.

    The norm is taken over axes (2, 3). The exponent is floor(log2(norm)),
    which puts the rescaled norm in roughly [1, 2).

    Returns:
        (rescaled m, exponents summed over axis 0), i.e. log2 of the total
        factor that was divided out along axis 0.
    """
    frob = np.linalg.norm(m, axis=(2, 3), keepdims=True)
    exponent = np.floor(np.log2(frob))
    scaled = m / (2**exponent)
    return scaled, np.sum(exponent, axis=0)
def realfftbasis(nx):
    """
    Basis of sines+cosines for nn-point discrete fourier transform (DFT).

    Ported from MatLab code:
    https://github.com/leaduncker/SimpleEvidenceOpt/blob/master/util/realfftbasis.m

    Returns:
        B: matrix with one row per entry of `wvec` and `nx` columns. Cosine
           rows (frequencies >= 0) come first, then sine rows (frequencies
           < 0). All rows are divided by sqrt(nx / 2).
        wvec: frequency vector, non-negative frequencies first.
    """
    import numpy as np

    n_cos = np.ceil((nx + 1) / 2)
    n_sin = np.floor((nx - 1) / 2)
    freqs = np.concatenate([np.arange(0., n_cos), np.arange(-n_sin, 0.)])

    grid = np.arange(nx)
    cos_rows = np.cos(np.outer(freqs[freqs >= 0] * 2 * np.pi / nx, grid))
    sin_rows = np.sin(np.outer(freqs[freqs < 0] * 2 * np.pi / nx, grid))
    basis = np.vstack([cos_rows, sin_rows]) / np.sqrt(nx * 0.5)
    return basis, freqs
def sample(self, key, sample_shape=()):
    """Draw samples of shape ``sample_shape + self.batch_shape``.

    Inverse-transform draw: floor(log(1 - u) / log(1 - probs)) with
    u ~ Uniform[0, 1), computed in the dtype of `self.probs`.
    """
    assert is_prng_key(key)
    p = self.probs
    u = random.uniform(key, sample_shape + self.batch_shape, jnp.result_type(p))
    return jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
def test_flip_state_fock_infinite():
    """flip_state on a Fock space built without n_max changes only the chosen
    sites, and both input and flipped states stay non-negative."""
    hi = Fock(N=2)
    rng = nk.jax.PRNGSeq(1)
    N_batches = 20

    states = hi.random_state(rng.next(), N_batches, dtype=jnp.int64)

    # One random site index per batch entry.
    ids = jnp.asarray(
        jnp.floor(hi.size * jax.random.uniform(rng.next(), shape=(N_batches, ))),
        dtype=int,
    )

    new_states, old_vals = nk.hilbert.random.flip_state(
        hi, rng.next(), states, ids)

    assert new_states.shape == states.shape

    assert np.all(states >= 0)
    # The output of flip_state is what is under test; it must also hold valid
    # (non-negative) occupation numbers.
    assert np.all(new_states >= 0)

    # Undo the flips; every other entry must be untouched.
    states_np = np.asarray(states)
    states_new_np = np.array(new_states)
    for row, col in enumerate(ids):
        states_new_np[row, col] = states_np[row, col]

    np.testing.assert_allclose(states_np, states_new_np)
def uniform_stochastic_quantize(v: jnp.ndarray,
                                num_levels: int,
                                rng: PRNGKey,
                                v_min: Optional[float] = None,
                                v_max: Optional[float] = None) -> jnp.ndarray:
    """Uniform stochastic algorithm in https://arxiv.org/pdf/1611.00429.pdf.

    Each value is mapped to [0, 1] and randomly rounded to one of its two
    neighbouring levels among `num_levels` equally spaced ones. The chance of
    rounding up equals its relative position between those levels. The
    result is then mapped back to [v_min, v_max].

    Args:
        v: vector to be quantized.
        num_levels: Number of levels of quantization.
        rng: jax random key.
        v_min: minimum threshold for quantization. If None, sets it to jnp.amin(v).
        v_max: maximum threshold for quantization. If None, sets it to jnp.amax(v).

    Returns:
        Quantized array.
    """
    lo = jnp.amin(v) if v_min is None else v_min
    hi = jnp.amax(v) if v_max is None else v_max

    unit = jnp.nan_to_num((v - lo) / (hi - lo))
    unit = jnp.maximum(0., jnp.minimum(unit, 1.))

    steps = num_levels - 1
    upper = jnp.ceil(unit * steps) / steps
    lower = jnp.floor(unit * steps) / steps

    noise = jax.random.uniform(key=rng, shape=unit.shape)
    # 0/0 (value already on a level) becomes 0, so such values stay put.
    p_up = jnp.nan_to_num((unit - lower) / (upper - lower))
    rounded = jnp.where(noise > p_up, lower, upper)

    return lo + rounded * (hi - lo)
def bcL(self, rng=None):
    """bcL creates a random boundary condition for a L-shaped domain.

    Each side is a random cubic polynomial, without constant term, in
    ``x = sin(pi * self.bcmesh)``. The sine makes the boundary periodic and
    free of discontinuities.

    Args:
        rng: optional ``jax.random`` key. Defaults to ``random.PRNGKey(1)`` so
            results are reproducible.

    Returns:
        A square array (side len(self.bcmesh); assumes bcmesh is 1-D) that is
        zero except at the boundary entries assigned below, scaled by
        ``(self.n + 1)**2``.
    """
    # `rng` is always set from here on, so no separate "no key" path is needed
    # (the previous else-branch for rng=None was unreachable).
    if rng is None:
        rng = random.PRNGKey(1)
    n = self.n
    x = onp.sin(self.bcmesh * np.pi)
    # Position of the L's inner corner (1-based); a Python int so it can be used
    # directly in slices.
    n_y = int(np.floor((n + 1) / 2)) - 1

    # 4 standard-normal coefficients per side; the 4th (constant term) is
    # currently unused.
    coeffs = random.multivariate_normal(rng, np.zeros(16), np.diag(np.ones(16)))

    left = coeffs[0] * x**3 + coeffs[1] * x**2 + coeffs[2] * x  #+ coeffs[3]
    right = coeffs[4] * x**3 + coeffs[5] * x**2 + coeffs[6] * x  #+ coeffs[7]
    lower = coeffs[8] * x**3 + coeffs[9] * x**2 + coeffs[10] * x  #+ coeffs[11]
    upper = coeffs[12] * x**3 + coeffs[13] * x**2 + coeffs[14] * x  #+ coeffs[15]

    shape = 2 * x.shape
    source = onp.zeros(shape)

    source[0, :] = upper
    source[n_y - 1, n_y - 1:] = lower[:n - n_y + 1]
    source[n_y - 1:, n_y - 1] = right[:n - n_y + 1]
    source[:, 0] = left
    # Reversed slices: this makes the correct order of boundary conditions.
    source[-1, :n_y - 1] = right[n:n - n_y:-1]
    source[:n_y - 1, -1] = lower[n:n - n_y:-1]
    return source * (n + 1)**2
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: PRNGKey,
         sample: Optional[bool]=False,
         **kwargs
) -> Mapping[str, jnp.ndarray]:
    """Dequantize discrete data (forward) or quantize continuous data (sample).

    Forward (``sample == False``): the flow is run with ``sample=True``,
    conditioned on the discrete ``x``, to produce noise. ``z = x + noise`` is
    returned, and ``log_det`` is lowered by the flow's log_det + log_pz.

    Sample mode: ``x`` is continuous. It is floored to int32 ``z`` and the
    fractional part is scored by the flow (``sample=False``). The flow is
    conditioned on the *floored* values, matching the forward direction.
    Previously it was conditioned on the continuous ``x``, which leaked the
    noise into the conditioner.
    """
    x = inputs["x"]
    log_det = jnp.zeros(self.batch_shape)
    flow = self.flow if self.flow is not None else self.default_flow()

    if sample == False:  # noqa: E712 -- keeps original handling of non-bool `sample`
        flow_inputs = {"x": jnp.zeros(x.shape), "condition": x}
        outputs = flow(flow_inputs, rng, sample=True)
        noise = outputs["x"]
        z = x + noise
    else:
        z_floor = jnp.floor(x)
        z = z_floor.astype(jnp.int32)
        noise = x - z_floor
        flow_inputs = {"x": noise, "condition": z_floor}
        outputs = flow(flow_inputs, rng, sample=False)

    log_qugx = outputs["log_det"] + outputs["log_pz"]
    log_det -= log_qugx
    return {"x": z, "log_det": log_det}
def __call__(self, x, mask_props=None, is_training=True):
    """Flatten `x`, apply `self.nlayers` hidden layers, then a 10-unit Linear.

    Each hidden layer: Linear(self.nhid), optional BatchNorm, optional mask
    that keeps only the first floor(mask_props[layer] * width) units, then the
    activation named by `self.activation`. 'linear' and unrecognised names
    leave the values unchanged.
    """
    activations = (('relu', jax.nn.relu),
                   ('sigmoid', jax.nn.sigmoid),
                   ('tanh', jnp.tanh))

    h = hk.Flatten()(x)
    for layer_idx in range(self.nlayers):
        h = hk.Linear(self.nhid, with_bias=self.with_bias)(h)
        if self.batch_norm:
            norm = hk.BatchNorm(create_scale=False, create_offset=False,
                                decay_rate=0.9)
            h = norm(h, is_training)
        if mask_props is not None:
            width = h.shape[1]
            n_keep = jnp.floor(mask_props[layer_idx] * width).astype(jnp.int32)
            h = jnp.where(jnp.arange(width) < n_keep, h, jnp.zeros(h.shape))
        for name, fn in activations:
            if self.activation == name:
                h = fn(h)
                break
    return hk.Linear(10, with_bias=self.with_bias)(h)
def _binom_inv_body_fn(val):
    """One loop step: draw floor(log1p(-u) / log1_p) + 1 and add it to the sum.

    `val` is (iteration count, PRNG key, accumulated sum). `log1_p` comes from
    the enclosing scope. Presumably log1p(-p), which would make each draw
    geometric on {1, 2, ...}. TODO confirm.
    """
    count, key, acc = val
    key, subkey = random.split(key)
    u = random.uniform(subkey)
    wait = np.floor(np.log1p(-u) / log1_p) + 1
    return count + 1, key, acc + wait
def interpolate1d(x, values, tangents):
    r"""Perform cubic hermite spline interpolation on a 1D spline.

    The x coordinates of the spline knots are at [0 : len(values)-1].
    Queries outside of the range of the spline are computed using linear
    extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline
    for details, where "x" corresponds to `x`, "p" corresponds to `values`, and
    "m" corresponds to `tangents`.

    Args:
        x: A tensor containing the set of values to be used for interpolation
            into the spline.
        values: A vector containing the value of each knot of the spline being
            interpolated into. Must be the same length as `tangents`.
        tangents: A vector containing the tangent (derivative) of each knot of
            the spline being interpolated into. Must be the same length as
            `values` and the same type as `x`.

    Returns:
        The result of interpolating along the spline defined by `values`, and
        `tangents`, using `x` as the query values. Will be the same shape as `x`.
    """
    assert len(values.shape) == 1
    assert len(tangents.shape) == 1
    assert values.shape[0] == tangents.shape[0]

    # Left knot of the segment for each query. Clipping sends out-of-range
    # queries to the first/last segment, so `t` falls below 0 or above 1.
    lo = jnp.int32(jnp.floor(jnp.clip(x, 0., values.shape[0] - 2)))
    hi = lo + 1
    t = x - lo

    # Hermite basis polynomials evaluated at t.
    t2 = t**2
    t3 = t * t2
    h01 = -2 * t3 + 3 * t2
    h00 = 1 - h01
    h11 = t3 - t2
    h10 = h11 - t2 + t

    mid = (jnp.take(values, lo) * h00 + jnp.take(values, hi) * h01 +
           jnp.take(tangents, lo) * h10 + jnp.take(tangents, hi) * h11)
    # Linear extrapolation using the end tangents.
    below = tangents[0] * t + values[0]
    above = tangents[-1] * (t - 1) + values[-1]

    return jnp.where(t < 0., below, jnp.where(t > 1., above, mid))
def schedule(count):
    """Exponential decay of `init_value` by `decay_rate` per `transition_steps`.

    `init_value`, `decay_rate`, `transition_begin`, `transition_steps` and
    `staircase` come from the enclosing scope. While
    count <= transition_begin, `init_value` is returned unchanged.
    """
    # Bind a new name instead of `count -= ...`. For a mutable (NumPy) `count`
    # the augmented assignment would modify the caller's array in place.
    elapsed = count - transition_begin
    p = elapsed / transition_steps
    if staircase:
        # Decay in discrete jumps, once per completed transition period.
        p = jnp.floor(p)
    return jnp.where(
        elapsed <= 0, init_value, init_value * jnp.power(decay_rate, p))
def getE(M, e):
    """JAX autograd compatible version of the Solver of Kepler's Equation for the "eccentric anomaly", E.

    Args:
        M : Mean anomaly (radians). Reduced modulo 2*pi before solving, so any
            real value is accepted.
        e : Eccentricity

    Returns:
        Eccentric anomaly for M reduced modulo 2*pi.
    """
    pi = jnp.pi
    # Reduce M into [0, 2*pi).
    Mt = M - (jnp.floor(M / (2. * pi)) * 2. * pi)
    # Reflect the second half-period onto [0, pi]; the reflection is undone on
    # the result below. The test must use the reduced anomaly Mt. Testing the
    # raw M mis-reflects every M outside [0, 2*pi) (e.g. M = 7 or M = -1).
    flip = Mt > pi
    Mt = jnp.where(flip, 2. * pi - Mt, Mt)
    # Maps -0.0 to +0.0; otherwise a no-op.
    Mt = jnp.where(Mt == 0.0, 0.0, Mt)
    alpha = _alpha(e, Mt)
    d = _d(alpha, e)
    r = _r(alpha, d, Mt, e)
    q = _q(alpha, d, e, Mt)
    w = _w(r, q)
    E1 = _E1(d, r, w, q, Mt)
    f = _f01234(e, E1, Mt)
    d3 = _d3(E1, f)
    d4 = _d4(E1, f, d3)
    d5 = _d5(E1, f, d4)
    # Eq. 29
    E5 = E1 + d5
    # Undo the reflection applied to Mt.
    return jnp.where(flip, 2. * pi - E5, E5)
def sample(self, key, sample_shape=()):
    """Draw samples of shape ``sample_shape + self.batch_shape``.

    Inverse-transform draw floor(log(1 - u) / log(1 - sigmoid(logits))) with
    u ~ Uniform[0, 1). log(1 - sigmoid(l)) is written as -softplus(l).
    """
    assert is_prng_key(key)
    logits = self.logits
    u = random.uniform(key, sample_shape + self.batch_shape, jnp.result_type(logits))
    return jnp.floor(jnp.log1p(-u) / -softplus(logits))
def logpdf(self, k):
    """Binomial log-pmf of `k` for `self.n` trials with success prob `self.p`.

    `k` is floored first, so non-integer inputs are truncated downward.
    """
    k = jnp.floor(k)
    n, p = self.n, self.p
    log_choose = gammaln(n + 1) - (gammaln(k + 1) + gammaln(n - k + 1))
    log_kernel = xlogy(k, p) + xlog1py(n - k, -p)
    return log_kernel + log_choose
def test_flip_state_discrete(hi: DiscreteHilbert):
    """flip_state on a discrete space changes only the chosen sites and keeps
    every site value inside ``hi.states_at_index`` for that site."""
    rng = nk.jax.PRNGSeq(1)
    N_batches = 20

    states = hi.random_state(rng.next(), N_batches)

    # One random site index per batch entry.
    ids = jnp.asarray(
        jnp.floor(hi.size * jax.random.uniform(rng.next(), shape=(N_batches, ))),
        dtype=int,
    )

    new_states, old_vals = nk.hilbert.random.flip_state(
        hi, rng.next(), states, ids)

    assert new_states.shape == states.shape

    # Validate the flipped output too. Checking only `states` never looked at
    # what flip_state produced.
    for batch in (states, new_states):
        for state in batch:
            assert all(val in hi.states_at_index(i) for i, val in enumerate(state))

    # Undo the flips; every other entry must be untouched.
    states_np = np.asarray(states)
    states_new_np = np.array(new_states)
    for row, col in enumerate(ids):
        states_new_np[row, col] = states_np[row, col]

    np.testing.assert_allclose(states_np, states_new_np)
def test_flip_state(hi):
    """flip_state changes only the chosen sites. For discrete spaces, flipped
    values must belong to ``hi.local_states``."""
    rng = nk.jax.PRNGSeq(1)
    N_batches = 20

    states = hi.random_state(rng.next(), N_batches)

    # One random site index per batch entry.
    ids = jnp.asarray(
        jnp.floor(hi.size * jax.random.uniform(rng.next(), shape=(N_batches, ))),
        dtype=int,
    )

    new_states, old_vals = nk.hilbert.random.flip_state(
        hi, rng.next(), states, ids)

    assert new_states.shape == states.shape
    # Only discrete spaces expose `local_states`. Guarding here avoids the
    # NameError a non-discrete `hi` used to hit. np.isin replaces the
    # deprecated np.in1d.
    if isinstance(hi, DiscreteHilbert):
        assert np.all(np.isin(new_states.reshape(-1), hi.local_states))

    # Undo the flips; every other entry must be untouched.
    states_np = np.asarray(states)
    states_new_np = np.array(new_states)
    for row, col in enumerate(ids):
        states_new_np[row, col] = states_np[row, col]

    np.testing.assert_allclose(states_np, states_new_np)