Exemple #1
0
  def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype, format):
    """Check `smap.pair_neighbor_list` against `smap.pair` for a scalar sigma."""

    def truncated_square(dr, sigma):
      # Quadratic inside the cutoff `sigma`, zero outside of it.
      return np.where(dr < sigma, dr ** 2, f32(0.))

    particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
    box_size = 4. * particle_count ** (1. / spatial_dimension)

    key = random.PRNGKey(0)
    key, _ = random.split(key)
    displacement_fn, _ = space.periodic(box_size)
    metric_fn = space.metric(displacement_fn)

    via_neighbors = jit(
      smap.pair_neighbor_list(truncated_square, metric_fn, sigma=1.0))
    via_all_pairs = jit(smap.pair(truncated_square, metric_fn, sigma=1.0))

    for _ in range(STOCHASTIC_SAMPLES):
      key, position_key = random.split(key)
      positions = box_size * random.uniform(
        position_key, (particle_count, spatial_dimension), dtype=dtype)
      # The carried `key` (not a fresh split) seeds the cutoff draw.
      cutoff = random.uniform(key, (), minval=0.5, maxval=2.5)
      list_fn = partition.neighbor_list(
        displacement_fn, box_size, cutoff, 0.0, format=format)
      neighbors = list_fn.allocate(positions)

      expected = via_all_pairs(positions, sigma=cutoff)
      actual = via_neighbors(positions, neighbors, sigma=cutoff)
      self.assertAllClose(expected, actual)
    def test_pair_neighbor_list_force_scalar_diverging_potential(
            self, spatial_dimension, dtype):
        """Compare neighbor-list and all-pairs forces for a dr**-6 potential."""

        def potential(dr, sigma):
            # Inverse sixth power inside the cutoff `sigma`, zero outside.
            return np.where(dr < sigma, dr**-6, f32(0.))

        particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 4. * particle_count**(1. / spatial_dimension)

        key = random.PRNGKey(0)
        key, _ = random.split(key)
        displacement_fn, _ = space.periodic(box_size)
        metric_fn = space.metric(displacement_fn)

        neighbor_energy = smap.pair_neighbor_list(potential, metric_fn)
        force_via_neighbors = jit(quantity.force(neighbor_energy))
        all_pairs_energy = smap.pair(potential, metric_fn)
        force_via_all_pairs = jit(quantity.force(all_pairs_energy))

        for _ in range(STOCHASTIC_SAMPLES):
            key, position_key = random.split(key)
            positions = box_size * random.uniform(
                position_key, (particle_count, spatial_dimension),
                dtype=dtype)
            # The carried `key` (not a fresh split) seeds the cutoff draw.
            cutoff = random.uniform(key, (), minval=0.5, maxval=4.5)
            list_fn = partition.neighbor_list(displacement_fn, box_size,
                                              cutoff, 0.0)
            neighbors = list_fn(positions)

            expected = force_via_all_pairs(positions, sigma=cutoff)
            actual = force_via_neighbors(positions, neighbors, sigma=cutoff)
            self.assertAllClose(expected, actual)
    def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension,
                                                      dtype):
        """Check neighbor-list vs all-pairs energy with per-species sigma."""

        def truncated_square(dr, sigma):
            # Quadratic inside the cutoff `sigma`, zero outside of it.
            return np.where(dr < sigma, dr**2, f32(0.))

        particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * particle_count**(1. / spatial_dimension)

        # Three species labels (0, 1, 2) assigned by particle index.
        particle_index = np.arange(particle_count)
        species = np.zeros((particle_count, ), np.int32)
        species = np.where(particle_index > particle_count / 3, 1, species)
        species = np.where(particle_index > 2 * particle_count / 3, 2,
                           species)

        key = random.PRNGKey(0)
        key, _ = random.split(key)
        displacement_fn, _ = space.periodic(box_size)
        metric_fn = space.metric(displacement_fn)

        via_neighbors = jit(
            smap.pair_neighbor_list(truncated_square, metric_fn,
                                    species=species))
        via_all_pairs = jit(
            smap.pair(truncated_square, metric_fn, species=species))

        for _ in range(STOCHASTIC_SAMPLES):
            key, position_key = random.split(key)
            positions = box_size * random.uniform(
                position_key, (particle_count, spatial_dimension),
                dtype=dtype)
            # Draw a 3x3 table of cutoffs and symmetrize it.
            raw_sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
            sigma = 0.5 * (raw_sigma + raw_sigma.T)
            list_fn = partition.neighbor_list(displacement_fn, box_size,
                                              np.max(sigma), 0.)
            neighbors = list_fn(positions)

            expected = via_all_pairs(positions, sigma=sigma)
            actual = via_neighbors(positions, neighbors, sigma=sigma)
            self.assertAllClose(expected, actual)
Exemple #4
0
    def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype,
                                               format):
        """Compare Lennard-Jones forces from a neighbor list with all pairs."""
        key = random.PRNGKey(1)

        box_size = f32(15.0)
        displacement_fn, _ = space.periodic(box_size)
        metric = space.metric(displacement_fn)
        reference_force = quantity.force(
            energy.lennard_jones_pair(displacement_fn))

        positions = box_size * random.uniform(
            key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        list_fn, neighbor_energy = energy.lennard_jones_neighbor_list(
            displacement_fn, box_size, format=format)
        neighbor_force = quantity.force(neighbor_energy)

        neighbors = list_fn.allocate(positions)

        # Explicit tolerances are only passed for f32 with OrderedSparse;
        # every other combination relies on assertAllClose's defaults.
        tolerances = {}
        if dtype == f32 and format is partition.OrderedSparse:
            tolerances = dict(atol=5e-5, rtol=5e-5)

        expected = np.array(reference_force(positions), dtype=dtype)
        actual = neighbor_force(positions, neighbors)
        self.assertAllClose(expected, actual, **tolerances)
Exemple #5
0
  def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype):
    """Check `smap.pair` with run-time species against a manual double sum."""
    square = lambda dr, param=1.0: param * dr ** 2
    # Symmetric 2x2 table of per-species-pair coefficients.
    params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

    key = random.PRNGKey(0)
    key, species_key = random.split(key)
    species = random.randint(species_key, (PARTICLE_COUNT,), 0, 2)
    displacement_fn, _ = space.free()
    pair_metric = space.metric(displacement_fn)

    mapped_square = smap.pair(square, pair_metric, species=2, param=params)

    all_pairs_metric = space.map_product(pair_metric)

    for _ in range(STOCHASTIC_SAMPLES):
      key, position_key = random.split(key)
      positions = random.uniform(
        position_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      # Accumulate half of every species-block sum, block by block.
      expected = 0.0
      for first in range(2):
        for second in range(2):
          coefficient = params[first, second]
          first_positions = positions[species == first]
          second_positions = positions[species == second]
          distances = all_pairs_metric(first_positions, second_positions)
          expected = expected + 0.5 * np.sum(square(distances, coefficient))

      self.assertAllClose(mapped_square(positions, species),
                          np.array(expected, dtype=dtype))
Exemple #6
0
    def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension,
                                                     dtype):
        """Check neighbor-list vs all-pairs energy with an (N, N) sigma."""

        def truncated_square(dr, sigma):
            # Quadratic inside the cutoff `sigma`, zero outside of it.
            return np.where(dr < sigma, dr**2, f32(0.))

        if dtype == np.float32:
            tol = 2e-10
        else:
            tol = None

        particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * particle_count**(1. / spatial_dimension)

        key = random.PRNGKey(0)
        key, _ = random.split(key)
        displacement_fn, _ = space.periodic(box_size)
        metric_fn = space.metric(displacement_fn)

        via_neighbors = jit(
            smap.pair_neighbor_list(truncated_square, metric_fn))
        via_all_pairs = jit(smap.pair(truncated_square, metric_fn))

        for _ in range(STOCHASTIC_SAMPLES):
            key, position_key = random.split(key)
            positions = box_size * random.uniform(
                position_key, (particle_count, spatial_dimension),
                dtype=dtype)
            # Draw one cutoff per particle pair and symmetrize the table.
            raw_sigma = random.uniform(key, (particle_count, particle_count),
                                       minval=0.5, maxval=1.5)
            sigma = 0.5 * (raw_sigma + raw_sigma.T)
            list_fn = jit(
                partition.neighbor_list(displacement_fn, box_size,
                                        np.max(sigma), positions))
            neighbor_idx = list_fn(positions)

            expected = via_all_pairs(positions, sigma=sigma)
            actual = via_neighbors(positions, neighbor_idx, sigma=sigma)
            self.assertAllClose(expected, actual, True, tol, tol)
Exemple #7
0
 def test_bks(self, dtype):
     """Check the BKS silica pair energy against a stored reference value."""
     lattice_constant = 3.5660930663857577e+01
     displacement_fn, _ = space.periodic(lattice_constant)
     metric_fn = space.metric(displacement_fn)
     # 1000 repetitions of the species pattern [0, 1, 1].
     species = np.tile(np.array([0, 1, 1]), 1000)
     positions = test_util.load_silica_data()
     bks_energy = energy.bks_silica_pair(metric_fn, species=species)
     self.assertAllClose(-857939.528386092, bks_energy(positions))
Exemple #8
0
 def test_bks_neighbor_list(self, dtype, format):
   """Check the neighbor-list BKS silica energy against a stored reference."""
   lattice_constant = 3.5660930663857577e+01
   displacement_fn, _ = space.periodic(lattice_constant)
   metric_fn = space.metric(displacement_fn)
   # 1000 repetitions of the species pattern [0, 1, 1].
   species = np.tile(np.array([0, 1, 1]), 1000)
   positions = test_util.load_silica_data()
   list_fn, neighbor_energy = energy.bks_silica_neighbor_list(
     metric_fn, lattice_constant, species=species, format=format)
   neighbors = list_fn.allocate(positions)
   self.assertAllClose(-857939.528386092,
                       neighbor_energy(positions, neighbors))
Exemple #9
0
def eam(displacement: DisplacementFn,
        charge_fn: Callable[[Array], Array],
        embedding_fn: Callable[[Array], Array],
        pairwise_fn: Callable[[Array], Array],
        axis: Tuple[int, ...] = None) -> Callable[[Array], Array]:
    """Build an embedded atom model (EAM) energy function.

    EAM approximates the interaction between metallic atoms with two
    per-atom contributions: a pairwise energy and the energy of embedding
    the atom in the background charge density produced by the other atoms.
    For a single atomic species the model is fixed by three functions:
      1) the charge density contributed by an atom as a function of distance,
      2) the energy of embedding an atom in a given charge density,
      3) the pairwise energy.
    These are commonly provided as spline fits, following [1], but any
    function with the right signature (for example a neural network) can be
    supplied instead.

    Args:
      displacement: A function that produces an ndarray of shape [n, m,
        spatial_dimension] of particle displacements from particle positions
        specified as ndarrays of shape [n, spatial_dimension] and [m,
        spatial_dimension] respectively.
      charge_fn: A function that takes an ndarray of shape [n, m] of
        distances between particles and returns a matrix of charge
        contributions.
      embedding_fn: A function that takes an ndarray of shape [n] of charges
        and returns an ndarray of shape [n] with the energy cost of embedding
        each atom into that charge.
      pairwise_fn: A function that takes an ndarray of shape [n, m] of
        distances and returns an ndarray of shape [n, m] of pairwise
        energies.
      axis: The axis (or axes) over which the per-atom energies are summed.

    Returns:
      A function `energy(R, **kwargs)` that computes the EAM energy of atoms
      whose positions are given by an [n, spatial_dimension] ndarray. Extra
      keyword arguments are forwarded to the metric.

    [1] Y. Mishin, D. Farkas, M.J. Mehl, DA Papaconstantopoulos, "Interatomic
    potentials for monoatomic metals from experimental data and ab initio
    calculations." Physical Review B, 59 (1999)
    """
    all_pairs_metric = space.map_product(space.metric(displacement))

    def energy(R, **kwargs):
        distances = all_pairs_metric(R, R, **kwargs)

        # Embedding term: sum the charge contributions along axis 1, then
        # embed the resulting per-atom charge.
        charge_per_atom = util.high_precision_sum(charge_fn(distances),
                                                  axis=1)
        embedding_per_atom = embedding_fn(charge_per_atom)

        # Pair term: apply the diagonal mask, sum along axis 1 and halve.
        masked_pairs = smap._diagonal_mask(pairwise_fn(distances))
        pair_per_atom = util.high_precision_sum(masked_pairs,
                                                axis=1) / f32(2.0)

        return util.high_precision_sum(embedding_per_atom + pair_per_atom,
                                       axis=axis)

    return energy
Exemple #10
0
 def test_bks(self, dtype):
     """Check the BKS silica pair energy on positions loaded from disk."""
     lattice_constant = 3.5660930663857577e+01
     displacement_fn, _ = space.periodic(lattice_constant)
     metric_fn = space.metric(displacement_fn)
     # 1000 repetitions of the species pattern [0, 1, 1].
     species = np.tile(np.array([0, 1, 1]), 1000)
     # The data path is resolved relative to the current working directory.
     data_path = os.path.join(os.getcwd(), 'tests/data/silica_positions.npy')
     with open(data_path, 'rb') as data_file:
         positions = np.array(np.load(data_file))
     bks_energy = energy.bks_silica_pair(metric_fn, species=species)
     self.assertAllClose(-857939.528386092, bks_energy(positions))
    def test_pair_scalar_dummy_arg(self, spatial_dimension, dtype):
        """Call a `smap.pair` function with an extra keyword argument `t`."""
        # The pair function swallows any keyword it does not use.
        square = lambda dr, param=f32(1.0), **unused_kwargs: param * dr**2

        key = random.PRNGKey(0)
        key, _ = random.split(key)
        positions = random.normal(key, (PARTICLE_COUNT, spatial_dimension),
                                  dtype=dtype)
        displacement_fn, _ = space.free()

        pair_energy = smap.pair(square, space.metric(displacement_fn))

        # No assertion: the call itself is the whole test.
        pair_energy(positions, t=f32(0))
Exemple #12
0
    def test_neighbor_list_build_time_dependent(self, dtype, dim):
        """Check neighbor distances against brute force in a box that depends on t."""
        key = random.PRNGKey(1)

        # A box function is only defined for two and three dimensions.
        if dim == 2:
            box_fn = lambda t: np.array([[9.0, t], [0.0, 3.75]], f32)
        elif dim == 3:
            box_fn = lambda t: np.array([[9.0, 0.0, t], [0.0, 4.0, 0.0],
                                         [0.0, 0.0, 7.25]])
        shortest_side = np.min(np.diag(box_fn(0.)))
        cutoff = f32(1.23)
        # TODO(schsam): Get cell-list working with anisotropic cell sizes.
        cell_size = cutoff / shortest_side

        displacement_fn, _ = space.periodic_general(box_fn)
        metric_fn = space.metric(displacement_fn)

        positions = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
        particle_count = positions.shape[0]
        build_neighbors = partition.neighbor_list(
            metric_fn, 1., cutoff, 0.0, 1.1,
            cell_size=cell_size, t=np.array(0.))

        # The list is built at t=0.25, the same time used for the distances.
        neighbor_idx = build_neighbors(positions, t=np.array(0.25)).idx
        neighbor_positions = positions[neighbor_idx]
        # Only entries whose index is below the particle count are kept.
        keep = neighbor_idx < particle_count

        metric_at_t = partial(metric_fn, t=f32(0.25))
        to_neighbors = vmap(vmap(metric_at_t, (None, 0)))
        found = to_neighbors(positions, neighbor_positions)

        to_everyone = space.map_product(metric_at_t)
        exact = to_everyone(positions, positions)

        found = np.sort(np.where(found < cutoff, found, 0) * keep, axis=1)
        exact = np.sort(np.where(exact < cutoff, exact, 0), axis=1)

        for row in range(found.shape[0]):
            found_row = found[row]
            exact_row = exact[row]
            # Zeros mark entries that were masked or beyond the cutoff.
            self.assertAllClose(found_row[found_row > 0.],
                                exact_row[exact_row > 0.])
Exemple #13
0
    def test_lennard_jones_cell_list_energy(self, spatial_dimension, dtype):
        """Compare the cell-list Lennard-Jones energy with the all-pairs one."""
        key = random.PRNGKey(1)

        box_size = f32(15.0)
        displacement_fn, _ = space.periodic(box_size)
        metric = space.metric(displacement_fn)
        reference_energy = energy.lennard_jones_pair(displacement_fn)

        positions = box_size * random.uniform(
            key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        cell_list_energy = energy.lennard_jones_cell_list(
            displacement_fn, box_size, positions)

        expected = np.array(reference_energy(positions), dtype=dtype)
        self.assertAllClose(expected, cell_list_energy(positions), True)
Exemple #14
0
def _canonicalize_displacement_or_metric(displacement_or_metric):
    """Return a metric, given either a displacement or a metric function.

    The callable is abstractly evaluated on a pair of `(1, dim)` float arrays
    for trial dimensions `dim = 0, 1, 2, 3`, with the keyword `t=0`. For the
    first dimension that evaluates without a `ValueError`, a rank-2 result
    means the callable is returned unchanged; any other rank means it is
    wrapped with `space.metric`.

    Args:
      displacement_or_metric: A displacement function or a metric function.

    Returns:
      `displacement_or_metric` itself, or `space.metric` applied to it.

    Raises:
      ValueError: If every trial dimension raised a `ValueError`.
    """
    for dim in range(4):
        try:
            R = ShapedArray((1, dim), f32)
            dR_or_dr = pe.abstract_eval_fun(displacement_or_metric, R, R, t=0)
            if len(dR_or_dr.shape) == 2:
                return displacement_or_metric
            else:
                return space.metric(displacement_or_metric)
        except ValueError:
            # This trial dimension is not accepted; try the next one.
            continue
    # The two literals are concatenated, so the first one must end in a
    # space; otherwise the message reads "largerthan".
    raise ValueError(
        'Canonicalize displacement not implemented for spatial dimension larger '
        'than 4.')
Exemple #15
0
def main(unused_argv):
    """Minimize a 2D binary soft-sphere mixture with FIRE and print progress."""
    # System description.
    N = 500
    dimension = 2
    box_size = f32(25.0)

    # Periodic box helpers and the corresponding metric.
    displacement, shift = space.periodic(box_size)
    metric = space.metric(displacement)

    # Random initial positions drawn with JAX's random number generator.
    key = random.PRNGKey(0)
    key, position_key = random.split(key)
    R = random.uniform(position_key, (N, dimension),
                       minval=0.0, maxval=box_size, dtype=f32)

    # A 50:50 mixture of two particle types, one small and one large.
    sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
    half = N // 2
    species = np.array([0] * half + [1] * half, dtype=i32)

    # Energy function and the force derived from it.
    energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
    force_fn = quantity.force(energy_fn)

    # FIRE minimizer.
    init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
    opt_state = init_fn(R)

    minimize_steps = 50
    print_every = 10
    row_format = '{:.2f}\t{:.2f}\t{:.2f}'

    print('Minimizing.')
    print('Step\tEnergy\tMax Force')
    print('-----------------------------------')
    for step in range(minimize_steps):
        opt_state = apply_fn(opt_state)

        if step % print_every != 0:
            continue

        # Report the energy and the largest force component every few steps.
        R = opt_state.position
        print(row_format.format(step, energy_fn(R), np.max(force_fn(R))))
Exemple #16
0
  def test_pair_cell_list_energy(self, spatial_dimension, dtype):
    """Compare the cell-list soft-sphere energy with the all-pairs energy."""
    key = random.PRNGKey(1)

    box_size = f32(9.0)
    cell_size = f32(1.0)
    displacement_fn, _ = space.periodic(box_size)
    metric_fn = space.metric(displacement_fn)
    reference_energy = energy.soft_sphere_pair(displacement_fn)
    product_energy = smap.cartesian_product(energy.soft_sphere, metric_fn)

    positions = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    cell_energy = smap.cell_list(
      product_energy, box_size, cell_size, positions)

    expected = np.array(reference_energy(positions), dtype=dtype)
    self.assertAllClose(expected, cell_energy(positions), True)
Exemple #17
0
    def test_morse_small_neighbor_list_energy(self, spatial_dimension, dtype):
        """Compare Morse energies (neighbor list vs all pairs) for 10 particles."""
        key = random.PRNGKey(1)

        box_size = f32(5.0)
        displacement_fn, _ = space.periodic(box_size)
        metric = space.metric(displacement_fn)
        reference_energy = energy.morse_pair(displacement_fn)

        positions = box_size * random.uniform(
            key, (10, spatial_dimension), dtype=dtype)
        list_fn, neighbor_energy = energy.morse_neighbor_list(
            displacement_fn, box_size)

        neighbors = list_fn(positions)
        expected = np.array(reference_energy(positions), dtype=dtype)
        self.assertAllClose(expected, neighbor_energy(positions, neighbors))
Exemple #18
0
  def test_cell_list_incommensurate(self, spatial_dimension, dtype):
    """Check the cell-list energy when the box is not a multiple of the cell."""
    key = random.PRNGKey(1)

    # 12.1 is not an integer multiple of the 3.0 cell size.
    box_size = f32(12.1)
    cell_size = f32(3.0)
    displacement_fn, _ = space.periodic(box_size)
    reference_energy = energy.soft_sphere_pair(displacement_fn)

    positions = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    product_energy = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement_fn))
    cell_energy = jit(
      smap.cell_list(product_energy, box_size, cell_size, positions))

    expected = np.array(reference_energy(positions), dtype=dtype)
    self.assertAllClose(expected, cell_energy(positions), True)
Exemple #19
0
    def test_lennard_jones_cell_neighbor_list_energy(self, spatial_dimension,
                                                     dtype, format):
        """Compare Lennard-Jones energies from a neighbor list with all pairs."""
        key = random.PRNGKey(1)

        box_size = f32(15)
        displacement_fn, _ = space.periodic(box_size)
        metric = space.metric(displacement_fn)
        reference_energy = energy.lennard_jones_pair(displacement_fn)

        positions = box_size * random.uniform(
            key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        list_fn, neighbor_energy = energy.lennard_jones_neighbor_list(
            displacement_fn, box_size, format=format)

        neighbors = list_fn.allocate(positions)
        expected = np.array(reference_energy(positions), dtype=dtype)
        self.assertAllClose(expected, neighbor_energy(positions, neighbors))
Exemple #20
0
  def test_cell_list_direct_force_jit(self, spatial_dimension, dtype):
    """Compare a jitted cell-list force with the all-pairs soft-sphere force."""
    key = random.PRNGKey(1)

    box_size = f32(9.0)
    cell_size = f32(1.0)
    displacement_fn, _ = space.periodic(box_size)
    reference_force = quantity.force(
      energy.soft_sphere_pair(displacement_fn))

    positions = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    # Differentiate the cartesian-product energy first, then wrap the
    # resulting force function in a cell list and jit it.
    product_energy = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement_fn))
    product_force = quantity.force(product_energy)
    cell_force = jit(
      smap.cell_list(product_force, box_size, cell_size, positions))

    expected = np.array(reference_force(positions), dtype=dtype)
    self.assertAllClose(expected, cell_force(positions), True)
Exemple #21
0
  def test_canonicalize_displacement_or_metric(self, spatial_dimension, dtype):
    """Check that canonicalizing a displacement matches `space.metric`."""
    key = random.PRNGKey(0)

    displacement_fn, _ = space.periodic_general(np.eye(spatial_dimension))
    expected_metric = space.map_product(space.metric(displacement_fn))
    canonical_metric = space.map_product(
      space.canonicalize_displacement_or_metric(displacement_fn))

    for _ in range(STOCHASTIC_SAMPLES):
      # Three-way split as before; only the second key is consumed.
      key, position_key, _ = random.split(key, 3)

      positions = random.normal(
        position_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      self.assertAllClose(expected_metric(positions, positions),
                          canonical_metric(positions, positions), True)
    def test_bond_no_type_static(self, spatial_dimension, dtype):
        """Check `smap.bond` with a fixed bond list against a manual sum."""
        harmonic = lambda dr, **kwargs: (dr - f32(1))**f32(2)
        displacement_fn, _ = space.free()
        metric_fn = space.metric(displacement_fn)

        # Two bonds, fixed at construction time: (0, 1) and (0, 2).
        bond_energy = smap.bond(harmonic, metric_fn,
                                np.array([[0, 1], [0, 2]], i32))

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, position_key = random.split(key)
            positions = random.uniform(
                position_key, (PARTICLE_COUNT, spatial_dimension),
                dtype=dtype)

            first_bond = harmonic(metric_fn(positions[0], positions[1]))
            second_bond = harmonic(metric_fn(positions[0], positions[2]))
            expected = first_bond + second_bond

            self.assertAllClose(bond_energy(positions), dtype(expected))
Exemple #23
0
    def test_morse_neighbor_list_force(self, spatial_dimension, dtype):
        """Compare Morse forces from a neighbor list with the all-pairs force."""
        key = random.PRNGKey(1)

        box_size = f32(15.0)
        displacement_fn, _ = space.periodic(box_size)
        metric = space.metric(displacement_fn)
        reference_force = quantity.force(energy.morse_pair(displacement_fn))

        positions = box_size * random.uniform(
            key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        list_fn, neighbor_energy = energy.morse_neighbor_list(
            displacement_fn, box_size)
        neighbor_force = quantity.force(neighbor_energy)

        neighbors = list_fn(positions)
        expected = np.array(reference_force(positions), dtype=dtype)
        self.assertAllClose(expected, neighbor_force(positions, neighbors))
Exemple #24
0
    def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype):
        """Compare Lennard-Jones forces from a neighbor list with all pairs."""
        key = random.PRNGKey(1)

        box_size = f32(15.0)
        displacement_fn, _ = space.periodic(box_size)
        metric = space.metric(displacement_fn)
        reference_force = quantity.force(
            energy.lennard_jones_pair(displacement_fn))

        positions = box_size * random.uniform(
            key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        # The positions are passed to the neighbor-list constructor here.
        list_fn, neighbor_energy = energy.lennard_jones_neighbor_list(
            displacement_fn, box_size, positions)
        neighbor_force = quantity.force(neighbor_energy)

        neighbor_idx = list_fn(positions)
        expected = np.array(reference_force(positions), dtype=dtype)
        self.assertAllClose(expected,
                            neighbor_force(positions, neighbor_idx), True)
Exemple #25
0
    def test_neighbor_list_build_sparse(self, dtype, dim):
        """Check sparse-format neighbor distances against a brute-force table."""
        key = random.PRNGKey(1)

        if dim == 3:
            box_size = np.array([9.0, 4.0, 7.25], f32)
        else:
            box_size = np.array([9.0, 4.25], f32)
        cutoff = f32(1.23)

        displacement_fn, _ = space.periodic(box_size)
        metric_fn = space.metric(displacement_fn)

        positions = box_size * random.uniform(
            key, (PARTICLE_COUNT, dim), dtype=dtype)
        particle_count = positions.shape[0]
        build_neighbors = partition.neighbor_list(
            displacement_fn, box_size, cutoff, 0.0, 1.1,
            format=partition.Sparse)

        neighbors = build_neighbors.allocate(positions)
        pair_mask = partition.neighbor_list_mask(neighbors)

        # The sparse index holds two rows of particle indices, one per pair
        # endpoint.
        first_idx = neighbors.idx[0]
        second_idx = neighbors.idx[1]
        pair_metric = space.map_bond(metric_fn)
        found = pair_metric(positions[first_idx], positions[second_idx])

        table_metric = space.map_product(metric_fn)
        exact = table_metric(positions, positions)

        found = np.where(found < cutoff, found, f32(0)) * pair_mask
        # Zero the diagonal so self-distances are dropped.
        off_diagonal = 1. - np.eye(exact.shape[0])
        exact = np.where(exact < cutoff, exact, f32(0)) * off_diagonal

        exact = np.sort(exact, axis=1)

        for particle in range(particle_count):
            found_row = found[first_idx == particle]
            found_row = np.sort(found_row[found_row > 0.])

            exact_row = exact[particle]
            exact_row = np.array(exact_row[exact_row > 0.], dtype)

            self.assertAllClose(found_row, exact_row)
Exemple #26
0
    def test_pair_grid_energy(self, spatial_dimension, dtype):
        """Compare the grid-wrapped soft-sphere energy with the direct one."""
        key = random.PRNGKey(1)

        box_size = f16(9.0)
        cell_size = f16(2.0)
        displacement_fn, _ = space.periodic(box_size)
        metric_fn = space.metric(displacement_fn)
        # Dynamic species; the reduction is over axis 1 with keepdims set.
        pair_energy = smap.pair(energy.soft_sphere, metric_fn,
                                quantity.Dynamic,
                                reduce_axis=(1, ), keepdims=True)

        positions = box_size * random.uniform(
            key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        grid_energy = smap.grid(pair_energy, box_size, cell_size, positions)
        # Every particle is assigned species 0.
        species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)

        expected = np.array(pair_energy(positions, species, 1), dtype=dtype)
        self.assertAllClose(expected, grid_energy(positions), True)
Exemple #27
0
  def test_bond_params_dynamic(self, spatial_dimension, dtype):
    """Check `smap.bond` with bonds and per-bond sigma supplied at call time."""
    harmonic = lambda dr, sigma, **kwargs: (dr - sigma) ** f32(2)
    displacement_fn, _ = space.free()
    metric_fn = space.metric(displacement_fn)

    # One sigma per bond, passed when the mapped function is called.
    per_bond_sigma = np.array([1.0, 2.0], f32)

    bond_energy = smap.bond(harmonic, metric_fn, sigma=1.0)
    bonds = np.array([[0, 1], [0, 2]], i32)

    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, position_key = random.split(key)
      positions = random.uniform(
        position_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      first_bond = harmonic(metric_fn(positions[0], positions[1]), 1)
      second_bond = harmonic(metric_fn(positions[0], positions[2]), 2)
      expected = first_bond + second_bond

      self.assertAllClose(
        bond_energy(positions, bonds, sigma=per_bond_sigma), dtype(expected))
  def test_neighbor_list_build(self, dtype, dim):
    """Check dense neighbor-list distances against a brute-force table.

    For every particle, the sorted distances below the cutoff found through
    the neighbor list must match those from the full distance table with the
    diagonal removed.
    """
    key = random.PRNGKey(1)

    # Compare by value: `dim is 3` tests object identity against a literal,
    # which is implementation-dependent and a SyntaxWarning on Python 3.8+.
    box_size = (
      np.array([9.0, 4.0, 7.25], f32) if dim == 3 else
      np.array([9.0, 4.25], f32))
    cutoff = f32(1.23)

    displacement, _ = space.periodic(box_size)
    metric = space.metric(displacement)

    R = box_size * random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
    N = R.shape[0]
    neighbor_list_fn = partition.neighbor_list(
      displacement, box_size, cutoff, R)

    idx = neighbor_list_fn(R)
    R_neigh = R[idx]
    # Only entries whose index is below the particle count are kept.
    mask = idx < N

    d = vmap(vmap(metric, (None, 0)))
    dR = d(R, R_neigh)

    d_exact = space.map_product(metric)
    dR_exact = d_exact(R, R)

    dR = np.where(dR < cutoff, dR, f32(0)) * mask
    # Zero the diagonal so self-distances are dropped.
    mask_exact = 1. - np.eye(dR_exact.shape[0])
    dR_exact = np.where(dR_exact < cutoff, dR_exact, f32(0)) * mask_exact

    dR = np.sort(dR, axis=1)
    dR_exact = np.sort(dR_exact, axis=1)

    for i in range(dR.shape[0]):
      dR_row = dR[i]
      dR_row = dR_row[dR_row > 0.]

      dR_exact_row = dR_exact[i]
      dR_exact_row = np.array(dR_exact_row[dR_exact_row > 0.], dtype)

      self.assertAllClose(dR_row, dR_exact_row, True)
Exemple #29
0
  def test_cell_list_force_nonuniform(self, spatial_dimension, dtype):
    """Compare cell-list and all-pairs forces in a box with unequal sides."""
    key = random.PRNGKey(1)

    # The box is kept with a leading axis of length one; `box_size[0]` is the
    # plain vector of side lengths handed to `space.periodic`.
    if spatial_dimension == 2:
      box_size = f32(np.array([[8.0, 10.0]]))
    else:
      box_size = f32(np.array([[8.0, 10.0, 12.0]]))
    cell_size = f32(2.0)
    displacement, _ = space.periodic(box_size[0])
    energy_fn = energy.soft_sphere_pair(displacement)
    force_fn = quantity.force(energy_fn)

    R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    cell_energy_fn = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
    cell_force_fn = quantity.force(cell_energy_fn)
    cell_force_fn = smap.cell_list(cell_force_fn, box_size, cell_size, R)
    # The unused squared-difference local `df` was removed: it evaluated both
    # force functions a second time and its value was never read.
    self.assertAllClose(
      np.array(force_fn(R), dtype=dtype), cell_force_fn(R), True)
Exemple #30
0
def hybrid_swap_mc(
        space_fns: space.Space,
        energy_fn: Callable[[Array, Array], Array],
        neighbor_fn: partition.NeighborFn,
        dt: float,
        kT: float,
        t_md: float,
        N_swap: int,
        sigma_fn: Optional[Callable[[Array], Array]] = None) -> Simulator:
    """Simulation of Hybrid Swap Monte-Carlo.

  This code simulates the hybrid Swap Monte Carlo algorithm introduced in [1].
  Here an NVT simulation is performed for `t_md` time and then `N_swap` MC
  moves are performed that swap the radii of randomly chosen particles. The
  random swaps are accepted with a Metropolis-Hastings step. Each call to the
  step function runs molecular dynamics for `t_md` and then performs the swaps.

  Note that this code doesn't feature some of the convenience functions in the
  other simulations. In particular, there is no support for dynamic keyword
  arguments and the energy function must be a simple callable of two variables:
  the distance between adjacent particles and the diameter of the particles.
  If you want support for a better notion of potential or dynamic keyword
  arguments, please file an issue!

  Args:
    space_fns: A tuple of a displacement function and a shift function defined
      in `space.py`.
    energy_fn: A function that computes the energy between one pair of
      particles as a function of the distance between the particles and the
      diameter. This function should not have been passed to `smap.xxx`.
    neighbor_fn: A function to construct neighbor lists outlined in
      `partition.py`.
    dt: The timestep used for the continuous time MD portion of the simulation.
    kT: The temperature of heat bath that the system is coupled to during MD.
    t_md: The time of each MD block.
    N_swap: The number of swapping moves between MD blocks.
    sigma_fn: An optional function for combining radii if they are to be
      non-additive. When omitted, the arithmetic mean of the two values is
      used.

  Returns:
    A pair `(init_fn, block_fn)`.
    `init_fn(key, position, sigma, nbrs=None)` builds the neighbor list,
    initializes the Nose-Hoover NVT state and returns a `SwapMCState`.
    `block_fn(state)` runs `int(t_md // dt)` MD steps followed by `N_swap`
    swap attempts and returns the updated `SwapMCState`.

  [1] L. Berthier, E. Flenner, C. J. Fullerton, C. Scalliet, and M. Singh.
      "Efficient swap algorithms for molecular dynamics simulations of
       equilibrium supercooled liquids"
      J. Stat. Mech. (2019) 064004
  """
    displacement_fn, shift_fn = space_fns
    metric_fn = space.metric(displacement_fn)
    nbr_metric_fn = space.map_neighbor(metric_fn)

    # Number of MD steps per block; floor division drops a partial last step.
    md_steps = int(t_md // dt)

    # Canonicalize the argument names to be dr and sigma.
    wrapped_energy_fn = lambda dr, sigma: energy_fn(dr, sigma)
    if sigma_fn is None:
        # Default combination rule: arithmetic mean of the two values.
        sigma_fn = lambda si, sj: 0.5 * (si + sj)
    nbr_energy_fn = smap.pair_neighbor_list(wrapped_energy_fn,
                                            metric_fn,
                                            sigma=sigma_fn)

    nvt_init_fn, nvt_step_fn = nvt_nose_hoover(nbr_energy_fn,
                                               shift_fn,
                                               dt,
                                               kT=kT,
                                               chain_length=3)

    def init_fn(key, position, sigma, nbrs=None):
        # One half of the key seeds the MD state; the other half is stored in
        # the returned state for the swap moves.
        key, sim_key = random.split(key)
        nbrs = neighbor_fn(position, nbrs)  # pytype: disable=wrong-arg-count
        md_state = nvt_init_fn(sim_key, position, neighbor=nbrs, sigma=sigma)
        return SwapMCState(md_state, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def md_step_fn(i, state):
        # Body for `lax.fori_loop`: one NVT step, then a neighbor-list update
        # at the new positions. The loop index `i` is unused.
        md, sigma, key, nbrs = dataclasses.unpack(state)
        md = nvt_step_fn(md, neighbor=nbrs, sigma=sigma)  # pytype: disable=wrong-keyword-args
        nbrs = neighbor_fn(md.position, nbrs)
        return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def swap_step_fn(i, state):
        # Body for `lax.fori_loop`: one attempted swap. The loop index `i` is
        # unused.
        md, sigma, key, nbrs = dataclasses.unpack(state)

        N = md.position.shape[0]

        # Swap a random pair of particle radii.
        key, particle_key, accept_key = random.split(key, 3)
        ij = random.randint(particle_key, (2, ), jnp.array(0), jnp.array(N))
        new_sigma = sigma.at[ij].set([sigma[ij[1]], sigma[ij[0]]])

        # Collect neighborhoods around the two swapped particles.
        nbrs_ij = nbrs.idx[ij]
        R_ij = md.position[ij]
        R_neigh = md.position[nbrs_ij]

        # The trailing axis lets the two chosen values combine with every
        # entry of the corresponding neighbor row.
        sigma_ij = sigma[ij][:, None]
        sigma_neigh = sigma[nbrs_ij]

        new_sigma_ij = new_sigma[ij][:, None]
        new_sigma_neigh = new_sigma[nbrs_ij]

        dR = nbr_metric_fn(R_ij, R_neigh)

        # Compute the energy before the swap.
        # Entries whose neighbor index is not below N are masked out
        # (presumably padding slots of the neighbor list; confirm against
        # `partition.py`).
        energy = energy_fn(dR, sigma_fn(sigma_ij, sigma_neigh))
        energy = jnp.sum(energy * (nbrs_ij < N))

        # Compute the energy after the swap.
        new_energy = energy_fn(dR, sigma_fn(new_sigma_ij, new_sigma_neigh))
        new_energy = jnp.sum(new_energy * (nbrs_ij < N))

        # Accept or reject with a metropolis probability.
        p = random.uniform(accept_key, ())
        accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - energy) / kT))
        # Select between the swapped and the original radii; the MD state and
        # the neighbor list are passed through unchanged.
        sigma = jnp.where(p < accept_prob, new_sigma, sigma)

        return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def block_fn(state):
        # One block: all MD steps first, then all swap attempts.
        state = lax.fori_loop(0, md_steps, md_step_fn, state)
        state = lax.fori_loop(0, N_swap, swap_step_fn, state)
        return state

    return init_fn, block_fn