Beispiel #1
0
        def iter_func(position):
            """Apply ``num_iters`` keyed mixing rounds to ``position``.

            Each round:
              * splits ``position`` into ``high = position >> bits_lower`` and
                ``low = position & mask_lower``;
              * XORs ``low`` with ``hash_func_in(high + seed_offst + round)``,
                where ``round`` is the round number as a uint32;
              * recombines as ``high + ((mixed & mask_lower) << bits_upper)``.

            All of ``num_iters``, ``bits_lower``, ``bits_upper``, ``mask_lower``,
            ``seed_offst`` and ``hash_func_in`` come from the enclosing scope.
            """
            for round_no in range(num_iters):
                round_key = jnp.uint32(round_no)
                high = jnp.right_shift(position, bits_lower)
                low = jnp.bitwise_and(position, mask_lower)
                mixed = jnp.bitwise_xor(
                    low, hash_func_in(high + seed_offst + round_key))
                shifted = jnp.left_shift(
                    jnp.bitwise_and(mixed, mask_lower), bits_upper)
                position = high + shifted
            return position
Beispiel #2
0
def threefry_seed(seed: int) -> jnp.ndarray:
    """Build a raw threefry PRNG key from an integer seed.

    Args:
      seed: a 64- or 32-bit integer (a Python int or an integer scalar
        array) used as the value of the key.

    Returns:
      A uint32 array of shape (2,). Element 0 holds the seed logically
      shifted right by 32 bits (the high word). Element 1 holds the seed
      masked with 0xFFFFFFFF (the low word). This effectively bit-casts a
      64-bit seed to two uint32 values; a 32-bit seed is padded with a zero
      high word.

    Raises:
      TypeError: if the seed is not a scalar or does not have an integer
        dtype.
    """
    # Python ints are routed through np.int64 first to avoid an
    # OverflowError in X32 mode. This breaks JIT invariance for large ints,
    # but supports instantiating keys from Python hashes in X32 mode.
    seed_arr = jnp.asarray(np.int64(seed) if isinstance(seed, int) else seed)

    if seed_arr.shape:
        raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
    if not np.issubdtype(seed_arr.dtype, np.integer):
        raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

    def to_word(bits):
        # Cast to uint32 and give the scalar shape (1,) so the two words
        # can be concatenated along axis 0.
        return lax.reshape(lax.convert_element_type(bits, np.uint32), [1])

    high_word = to_word(
        lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
    low_word = to_word(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
    return lax.concatenate([high_word, low_word], 0)
Beispiel #3
0
def lennard_jones_exclusion(conf, lj_params, exclusion_idxs, lj_scales, cutoff, groups=None):
    """Return the scaled Lennard-Jones 12-6 energy of an explicit list of atom pairs.

    Pair parameters use the Lorentz-Berthelot combining rules
    ``sig_ij = (sig_i + sig_j) / 2`` and ``eps_ij = sqrt(eps_i * eps_j)``.
    Each pair contributes ``lj_scales[k] * 4 * eps_ij * (sig6 - 1) * sig6``
    with ``sig6 = (sig_ij / d_ij) ** 6``, and the contributions are summed.

    Parameters
    ----------
    conf: np.array indexed by atom
        atomic coordinates; presumably shape [num_atoms, 3] (TODO confirm).
    lj_params: np.array of shape [num_atoms, >=2]
        column 0 is sig, column 1 is eps.
    exclusion_idxs: integer np.array of shape [num_pairs, 2]
        (src, dst) atom indices of each pair.
    lj_scales: np.array with num_pairs entries
        per-pair multiplier applied to the energy.
    cutoff: float or None
        if not None, pairs with d_ij >= cutoff contribute zero energy.
    groups: integer np.array indexed by atom
        per-atom bitmask. ``(groups[src] & groups[dst]) > 0`` is passed to
        ``distance``. Required despite the ``None`` default.

    Returns
    -------
    Scalar sum of the pair energies (not divided by two).

    Raises
    ------
    AssertionError
        if ``exclusion_idxs`` / ``lj_scales`` shapes are inconsistent.
    TypeError
        if ``groups`` is None.
    """
    # Only the non-periodic case is supported; `box` is forwarded to
    # `distance` as None.
    box = None

    assert exclusion_idxs.shape[1] == 2
    assert exclusion_idxs.shape[0] == lj_scales.shape[0]

    src_idxs = exclusion_idxs[:, 0]
    dst_idxs = exclusion_idxs[:, 1]
    ri = conf[src_idxs]
    rj = conf[dst_idxs]

    if groups is None:
        # Same exception type as indexing None, but with an actionable message.
        raise TypeError("lennard_jones_exclusion requires `groups` (got None)")
    # True where the two atoms' group bitmasks share at least one set bit.
    gij = np.bitwise_and(groups[src_idxs], groups[dst_idxs]) > 0
    dij = distance(ri, rj, box, gij)

    # Lorentz-Berthelot combining rules.
    sig_params = lj_params[:, 0]
    sig_ij = (sig_params[src_idxs] + sig_params[dst_idxs]) / 2

    eps_params = lj_params[:, 1]
    eps_ij = np.sqrt(eps_params[src_idxs] * eps_params[dst_idxs])

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_exc = lj_scales * 4 * eps_ij * (sig6 - 1.0) * sig6

    if cutoff is not None:
        eij_exc = np.where(dij > cutoff, np.zeros_like(eij_exc), eij_exc)

    # Self-pairs (src == dst) would divide by a (presumably) zero distance and
    # yield inf/nan. Drop them whether or not a cutoff is in use; previously
    # this mask was applied only when a cutoff was given.
    eij_exc = np.where(src_idxs == dst_idxs, np.zeros_like(eij_exc), eij_exc)

    # The exclusion energy is not divided by two.
    return np.sum(eij_exc)
Beispiel #4
0
 def cond_fn(*args, max_iters=100):
     """Return True while the loop should keep running.

     ``args[0]`` is the loop state, a 5-tuple whose first element is the
     iteration counter ``i`` and whose third element holds the ``done``
     flags. The predicate is True while ``i < max_iters`` and not every
     ``done`` entry is set.

     Args:
       *args: only ``args[0]`` (the loop state) is read.
       max_iters: keyword-only iteration cap; defaults to 100.

     Returns:
       A boolean scalar.
     """
     i, _, done, _, _ = args[0]
     # Both operands are booleans, so logical_and states the intent; for
     # boolean inputs it gives the same result as bitwise_and.
     return jnp.logical_and(i < max_iters, jnp.logical_not(jnp.all(done)))
Beispiel #5
0
def bitwise_and(x1, x2):
  """Compute the element-wise bitwise AND of ``x1`` and ``x2``.

  Any ``JaxArray`` argument is unwrapped to its ``.value`` before calling
  ``jnp.bitwise_and``; other arguments are passed through unchanged. The
  result is always wrapped in a new ``JaxArray``.
  """
  lhs, rhs = [x.value if isinstance(x, JaxArray) else x for x in (x1, x2)]
  return JaxArray(jnp.bitwise_and(lhs, rhs))
Beispiel #6
0
 def cond_fn(curr):
     """Return True while the loop should continue.

     Continues while ``curr.i`` is below
     ``SineBivariateVonMises.max_sample_iter`` and at least one entry of
     ``curr.done`` is still False.
     """
     under_cap = curr.i < SineBivariateVonMises.max_sample_iter
     pending = jnp.logical_not(jnp.all(curr.done))
     return jnp.logical_and(under_cap, pending)
Beispiel #7
0
def lennard_jones(conf, lj_params, cutoff, groups=None):
    """
    Implements a non-periodic LJ612 potential using the Lorentz-Berthelot combining
    rules, where sig_ij = (sig_i + sig_j)/2 and eps_ij = sqrt(eps_i * eps_j).

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    lj_params: shape [num_atoms, 2] np.array
        per-atom parameters; column 0 is sig and column 1 is eps. They are
        combined pairwise with the rules above.

    cutoff: float or None
        If not None, pairs whose distance is greater than or equal to cutoff
        contribute zero energy. If None, no cutoff is applied.

    groups: shape [num_atoms,] integer np.array
        per-atom bitmask. For each pair (i, j), (groups[i] & groups[j]) > 0
        is passed to ``distance``. Required despite the None default.

    Returns
    -------
    scalar
        Sum of 4*eps_ij*(sig6 - 1)*sig6 over all ordered pairs i != j,
        halved (the (i, j) and (j, i) terms are presumably equal, so each
        unordered pair is counted once).

    Raises
    ------
    TypeError
        If groups is None.
    """
    # Only the non-periodic case is supported; `box` is forwarded to
    # `distance` as None.
    box = None

    sig = lj_params[:, 0]
    eps = lj_params[:, 1]

    # Pairwise combining rules by broadcasting a row vector against a column vector.
    sig_ij = (np.expand_dims(sig, 0) + np.expand_dims(sig, 1)) / 2
    eps_ij = np.sqrt(np.expand_dims(eps, 0) * np.expand_dims(eps, 1))

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)

    if groups is None:
        # Same exception type as before, but with an actionable message.
        raise TypeError("lennard_jones requires `groups` (got None)")
    gij = np.bitwise_and(np.expand_dims(groups, axis=0),
                         np.expand_dims(groups, axis=1)) > 0

    dij = distance(ri, rj, box, gij)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    N = conf.shape[0]
    # Boolean mask that is False on the diagonal (self-interactions).
    keep_mask = np.logical_not(np.eye(N, dtype=bool))

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij = 4 * eps_ij * (sig6 - 1.0) * sig6

    eij = np.where(keep_mask, eij, np.zeros_like(eij))
    return np.sum(eij / 2)
Beispiel #8
0
 def cond_fn(curr):
     """Return True while the loop should continue.

     Continues while ``curr.i`` is below ``Sine.max_sample_iter`` and at
     least one entry of ``curr.done`` is still False.
     """
     under_cap = curr.i < Sine.max_sample_iter
     pending = jnp.logical_not(jnp.all(curr.done))
     return jnp.logical_and(under_cap, pending)
Beispiel #9
0
def _arctan2(x, y, fill_zero: Optional[float] = None):
    """Compute ``np.arctan2(x, y)`` with optional replacement at the origin.

    Arguments are passed straight through, so the result is the
    quadrant-aware arctan of ``x / y`` (``x`` is the first argument of
    ``np.arctan2``).

    Args:
        x: first argument of ``np.arctan2``.
        y: second argument of ``np.arctan2``.
        fill_zero: if not None, entries where both ``x`` and ``y`` equal 0
            are replaced by this value.

    Returns:
        ``np.arctan2(x, y)``, with origin entries replaced by ``fill_zero``
        when it is given.
    """
    angle = np.arctan2(x, y)
    if fill_zero is None:
        return angle
    at_origin = np.logical_and(x == 0., y == 0.)
    return np.where(at_origin, fill_zero, angle)