def iter_func(position):
    """Apply ``num_iters`` mixing rounds to ``position`` and return the result.

    Every round splits ``position`` into a high part (``position >> bits_lower``)
    and a low part (``position & mask_lower``), XORs the low part with
    ``hash_func_in`` of the high part, and rebuilds the value as the high part
    plus the masked, mixed low part shifted left by ``bits_upper``.

    ``num_iters``, ``bits_lower``, ``bits_upper``, ``mask_lower``,
    ``seed_offst`` and ``hash_func_in`` come from the enclosing scope.
    """
    for round_idx in range(num_iters):
        round_u32 = jnp.uint32(round_idx)
        high = jnp.right_shift(position, bits_lower)
        low = jnp.bitwise_and(position, mask_lower)
        # The round counter is folded into the hash input so rounds differ.
        round_hash = hash_func_in(high + seed_offst + round_u32)
        mixed_low = jnp.bitwise_and(jnp.bitwise_xor(low, round_hash), mask_lower)
        position = high + jnp.left_shift(mixed_low, bits_upper)
    return position
def threefry_seed(seed: int) -> jnp.ndarray:
    """Create a single raw threefry PRNG key given an integer seed.

    Args:
      seed: a 64- or 32-bit integer used as the value of the key. Python
        ints and scalar integer arrays are both accepted.

    Returns:
      The PRNG key contents, modeled as an array of shape (2,) and dtype
      uint32. The key is constructed from a 64-bit seed by effectively
      bit-casting to a pair of uint32 values (or from a 32-bit seed by first
      padding out with zeros).

    Raises:
      TypeError: if ``seed`` is not a scalar, or does not have an integer
        dtype.
    """
    # Avoid OverflowError in X32 mode by first converting ints to int64.
    # This breaks JIT invariance for large ints, but supports the common
    # use-case of instantiating with Python hashes in X32 mode.
    if isinstance(seed, int):
        seed_arr = jnp.asarray(np.int64(seed))
    else:
        seed_arr = jnp.asarray(seed)

    if seed_arr.shape:
        raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
    if not np.issubdtype(seed_arr.dtype, np.integer):
        raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

    def _as_uint32_word(value):
        # Each half of the key is a length-1 uint32 vector.
        return lax.reshape(lax.convert_element_type(value, np.uint32), [1])

    # The shift amount must share the seed's dtype; build it with public
    # NumPy API rather than the private ``lax._const`` helper.
    shift_by_32 = np.array(32, dtype=seed_arr.dtype)
    high_word = _as_uint32_word(lax.shift_right_logical(seed_arr, shift_by_32))
    low_word = _as_uint32_word(
        jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
    return lax.concatenate([high_word, low_word], 0)
def lennard_jones_exclusion(conf, lj_params, exclusion_idxs, lj_scales, cutoff, groups=None):
    """Sum the scaled Lennard-Jones 6-12 energy over an explicit pair list.

    Combining rules are ``sig_ij = (sig_i + sig_j) / 2`` and
    ``eps_ij = sqrt(eps_i * eps_j)``.

    Parameters
    ----------
    conf:
        Coordinates, indexed by the atom indices in ``exclusion_idxs``.
    lj_params:
        2D array; column 0 holds sigma, column 1 holds epsilon.
    exclusion_idxs:
        Array whose second dimension is 2; each row is a (src, dst) pair.
    lj_scales:
        One scale factor per row of ``exclusion_idxs``.
    cutoff:
        If not None, pairs whose distance is not strictly below ``cutoff``
        contribute zero.
    groups:
        Per-atom values; ``(groups[src] & groups[dst]) > 0`` is forwarded to
        ``distance``. NOTE(review): ``groups`` is indexed unconditionally, so
        the default ``None`` cannot actually be used.

    Returns
    -------
    The sum of the per-pair energies. The exclusion energy is not divided
    by two.
    """
    box = None
    assert box is None
    assert exclusion_idxs.shape[1] == 2
    assert exclusion_idxs.shape[0] == lj_scales.shape[0]

    first = exclusion_idxs[:, 0]
    second = exclusion_idxs[:, 1]

    pair_mask = np.bitwise_and(groups[first], groups[second]) > 0
    dij = distance(conf[first], conf[second], box, pair_mask)

    sigma = lj_params[:, 0]
    epsilon = lj_params[:, 1]
    sig_ij = (sigma[first] + sigma[second]) / 2
    eps_ij = np.sqrt(epsilon[first] * epsilon[second])

    use_cutoff = cutoff is not None
    if use_cutoff:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    ratio = sig_ij / dij
    sig2 = ratio * ratio
    sig6 = sig2 * sig2 * sig2

    energies = lj_scales * 4 * eps_ij * (sig6 - 1.0) * sig6

    if use_cutoff:
        energies = np.where(dij > cutoff, np.zeros_like(energies), energies)

    # A pair of an atom with itself contributes nothing.
    energies = np.where(first == second, np.zeros_like(energies), energies)

    return np.sum(energies)
def cond_fn(*args):
    """Return whether iteration should continue.

    ``args[0]`` is a 5-element state whose first entry is the iteration
    counter and whose third entry is the ``done`` flags. The result is true
    while the counter is below 100 and not every ``done`` flag is set.
    """
    step, _, finished, _, _ = args[0]
    below_limit = step < 100
    some_pending = jnp.logical_not(jnp.all(finished))
    return jnp.bitwise_and(below_limit, some_pending)
def bitwise_and(x1, x2):
    """Compute ``jnp.bitwise_and`` of two operands and wrap it in a ``JaxArray``.

    Either operand may be a ``JaxArray``, in which case its ``.value`` is
    used; any other operand is passed to ``jnp.bitwise_and`` unchanged.
    """
    lhs = x1.value if isinstance(x1, JaxArray) else x1
    rhs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.bitwise_and(lhs, rhs))
def cond_fn(curr):
    """Return whether sampling should continue for the state ``curr``.

    True while ``curr.i`` is below ``SineBivariateVonMises.max_sample_iter``
    and not every entry of ``curr.done`` is set.
    """
    under_limit = curr.i < SineBivariateVonMises.max_sample_iter
    unfinished = jnp.logical_not(jnp.all(curr.done))
    return jnp.bitwise_and(under_limit, unfinished)
def lennard_jones(conf, lj_params, cutoff, groups=None):
    """
    Implements a non-periodic LJ612 potential using the Lorentz-Berthelot
    combining rules, where sig_ij = (sig_i + sig_j)/2 and
    eps_ij = sqrt(eps_i * eps_j).

    Parameters
    ----------
    conf: np.array
        atomic coordinates; ``conf.shape[0]`` is taken as the number of atoms

    lj_params: np.array
        2D array of per-atom parameters; column 0 is sigma, column 1 is
        epsilon

    cutoff: float or None
        If not None, every pair whose distance is not strictly less than
        cutoff has its epsilon set to zero.

    groups: np.array
        per-atom values; ``(groups[i] & groups[j]) > 0`` is computed for all
        pairs and passed to ``distance``. NOTE(review): groups is expanded
        unconditionally, so the default of None cannot actually be used.

    Returns
    -------
    Sum over the full pair matrix of the pair energies, halved because each
    pair appears twice. Diagonal (self) entries are masked out.
    """
    box = None
    assert box is None

    sig = lj_params[:, 0]
    eps = lj_params[:, 1]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = (sig_i + sig_j)/2

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)
    eps_ij = np.sqrt(eps_i * eps_j)

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)

    gi = np.expand_dims(groups, axis=0)
    gj = np.expand_dims(groups, axis=1)
    gij = np.bitwise_and(gi, gj) > 0

    dij = distance(ri, rj, box, gij)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    # 1 everywhere except the diagonal, which would pair an atom with itself.
    N = conf.shape[0]
    keep_mask = np.ones((N, N)) - np.eye(N)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij/dij
    sig2 = sig2*sig2
    sig6 = sig2*sig2*sig2

    eij = 4*eps_ij*(sig6-1.0)*sig6

    eij = np.where(keep_mask, eij, np.zeros_like(eij))

    return np.sum(eij/2)
def cond_fn(curr):
    """Return whether sampling should continue for the state ``curr``.

    True while ``curr.i`` is below ``Sine.max_sample_iter`` and not every
    entry of ``curr.done`` is set.
    """
    under_limit = curr.i < Sine.max_sample_iter
    unfinished = jnp.logical_not(jnp.all(curr.done))
    return jnp.bitwise_and(under_limit, unfinished)
def _arctan2(x, y, fill_zero: Optional[float] = None): if fill_zero is not None: return np.where(np.bitwise_and(x == 0., y == 0.), fill_zero, np.arctan2(x, y)) return np.arctan2(x, y)