def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype, format):
  """Neighbor-list pair energy should match the dense pair energy.

  Uses a truncated quadratic potential with a scalar `sigma` that is
  resampled for every stochastic sample.
  """
  def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr ** 2, f32(0.))

  N = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 4. * N ** (1. / spatial_dimension)

  key = random.PRNGKey(0)
  key, split = random.split(key)

  disp, _ = space.periodic(box_size)
  d = space.metric(disp)

  neighbor_energy = jit(
      smap.pair_neighbor_list(truncated_square, d, sigma=1.0))
  dense_energy = jit(smap.pair(truncated_square, d, sigma=1.0))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    positions = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
    sigma = random.uniform(key, (), minval=0.5, maxval=2.5)

    # A fresh neighbor list is allocated for every sampled `sigma`.
    neighbor_fn = partition.neighbor_list(
        disp, box_size, sigma, 0.0, format=format)
    nbrs = neighbor_fn.allocate(positions)

    expected = dense_energy(positions, sigma=sigma)
    actual = neighbor_energy(positions, nbrs, sigma=sigma)
    self.assertAllClose(expected, actual)
def test_pair_neighbor_list_force_scalar_diverging_potential(
    self, spatial_dimension, dtype):
  """Neighbor-list forces should match dense forces for a `dr**-6` potential.

  The potential is `dr**-6` below `sigma` and zero otherwise; `sigma` is
  resampled for every stochastic sample.
  """
  def potential(dr, sigma):
    return np.where(dr < sigma, dr**-6, f32(0.))

  N = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 4. * N**(1. / spatial_dimension)

  key = random.PRNGKey(0)
  key, split = random.split(key)

  disp, _ = space.periodic(box_size)
  d = space.metric(disp)

  neighbor_force = jit(
      quantity.force(smap.pair_neighbor_list(potential, d)))
  dense_force = jit(quantity.force(smap.pair(potential, d)))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    positions = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
    sigma = random.uniform(key, (), minval=0.5, maxval=4.5)

    neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0)
    nbrs = neighbor_fn(positions)

    expected = dense_force(positions, sigma=sigma)
    actual = neighbor_force(positions, nbrs, sigma=sigma)
    self.assertAllClose(expected, actual)
def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension,
                                                  dtype):
  """Neighbor-list and dense pair energies should agree with three species.

  Particles are assigned species 0, 1 or 2 by index, and a symmetric (3, 3)
  `sigma` matrix is resampled for every stochastic sample.
  """
  def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr**2, f32(0.))

  N = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 2. * N**(1. / spatial_dimension)

  # Species 1 starts after index N / 3 and species 2 after 2 * N / 3.
  species = np.zeros((N, ), np.int32)
  species = np.where(np.arange(N) > N / 3, 1, species)
  species = np.where(np.arange(N) > 2 * N / 3, 2, species)

  key = random.PRNGKey(0)
  key, split = random.split(key)

  disp, _ = space.periodic(box_size)
  d = space.metric(disp)

  neighbor_energy = jit(
      smap.pair_neighbor_list(truncated_square, d, species=species))
  dense_energy = jit(smap.pair(truncated_square, d, species=species))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    positions = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)

    # Symmetrize so that sigma[i, j] == sigma[j, i].
    sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
    sigma = 0.5 * (sigma + sigma.T)

    neighbor_fn = partition.neighbor_list(
        disp, box_size, np.max(sigma), 0.)
    nbrs = neighbor_fn(positions)

    expected = dense_energy(positions, sigma=sigma)
    actual = neighbor_energy(positions, nbrs, sigma=sigma)
    self.assertAllClose(expected, actual)
def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype,
                                           format):
  """Lennard-Jones neighbor-list forces should match the pairwise forces.

  A looser tolerance (5e-5) is used for f32 with the OrderedSparse format.
  """
  key = random.PRNGKey(1)

  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  exact_force_fn = quantity.force(energy.lennard_jones_pair(displacement))

  r = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
      displacement, box_size, format=format)
  force_fn = quantity.force(energy_fn)
  nbrs = neighbor_fn.allocate(r)

  loosen = dtype == f32 and format is partition.OrderedSparse
  tolerances = {'atol': 5e-5, 'rtol': 5e-5} if loosen else {}

  expected = np.array(exact_force_fn(r), dtype=dtype)
  self.assertAllClose(expected, force_fn(r, nbrs), **tolerances)
def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype):
  """`smap.pair` with runtime species should match a per-species-pair sum."""
  def square(dr, param=1.0):
    return param * dr ** 2

  params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

  key = random.PRNGKey(0)
  key, split = random.split(key)
  species = random.randint(split, (PARTICLE_COUNT,), 0, 2)

  displacement, _ = space.free()
  metric = space.metric(displacement)

  # `species=2` gives the species count; the assignment is passed at call time.
  mapped_square = smap.pair(square, metric, species=2, param=params)

  metric = space.map_product(metric)

  species_pairs = [(a, b) for a in range(2) for b in range(2)]

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    total = 0.0
    for i, j in species_pairs:
      param = params[i, j]
      R_1 = R[species == i]
      R_2 = R[species == j]
      # The reference weights every (i, j) block by one half.
      total = total + 0.5 * np.sum(square(metric(R_1, R_2), param))

    self.assertAllClose(
        mapped_square(R, species), np.array(total, dtype=dtype))
def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension,
                                                 dtype):
  """Neighbor-list and dense pair energies should agree for (N, N) sigma.

  A symmetric per-pair `sigma` matrix is resampled for every stochastic
  sample.
  """
  def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr**2, f32(0.))

  tol = 2e-10 if dtype == np.float32 else None

  N = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 2. * N**(1. / spatial_dimension)

  key = random.PRNGKey(0)
  key, split = random.split(key)

  disp, _ = space.periodic(box_size)
  d = space.metric(disp)

  neighbor_energy = jit(smap.pair_neighbor_list(truncated_square, d))
  dense_energy = jit(smap.pair(truncated_square, d))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)

    # Symmetrize so that sigma[i, j] == sigma[j, i].
    sigma = random.uniform(key, (N, N), minval=0.5, maxval=1.5)
    sigma = 0.5 * (sigma + sigma.T)

    neighbor_fn = jit(
        partition.neighbor_list(disp, box_size, np.max(sigma), R))
    idx = neighbor_fn(R)

    expected = dense_energy(R, sigma=sigma)
    actual = neighbor_energy(R, idx, sigma=sigma)
    self.assertAllClose(expected, actual, True, tol, tol)
def test_bks(self, dtype):
  """BKS silica pair energy of the stored configuration matches a reference."""
  LATCON = 3.5660930663857577e+01
  expected_energy = -857939.528386092

  displacement, shift = space.periodic(LATCON)
  dist_fun = space.metric(displacement)

  # The species pattern [0, 1, 1] is tiled 1000 times (3000 entries).
  unit_species = np.array([0, 1, 1])
  species = np.tile(unit_species, 1000)

  R_f = test_util.load_silica_data()

  energy_fn = energy.bks_silica_pair(dist_fun, species=species)
  self.assertAllClose(expected_energy, energy_fn(R_f))
def test_bks_neighbor_list(self, dtype, format):
  """BKS silica neighbor-list energy matches the stored reference value."""
  LATCON = 3.5660930663857577e+01
  expected_energy = -857939.528386092

  displacement, shift = space.periodic(LATCON)
  dist_fun = space.metric(displacement)

  # The species pattern [0, 1, 1] is tiled 1000 times (3000 entries).
  unit_species = np.array([0, 1, 1])
  species = np.tile(unit_species, 1000)

  R_f = test_util.load_silica_data()

  neighbor_fn, energy_nei = energy.bks_silica_neighbor_list(
      dist_fun, LATCON, species=species, format=format)
  nbrs = neighbor_fn.allocate(R_f)

  self.assertAllClose(expected_energy, energy_nei(R_f, nbrs))
def eam(displacement: DisplacementFn,
        charge_fn: Callable[[Array], Array],
        embedding_fn: Callable[[Array], Array],
        pairwise_fn: Callable[[Array], Array],
        axis: Tuple[int, ...] = None) -> Callable[[Array], Array]:
  """Interatomic potential as approximated by embedded atom model (EAM).

  This code implements the EAM approximation to interactions between metallic
  atoms. In EAM, the potential energy of an atom is the sum of two terms: a
  pairwise energy and an embedding energy due to the interaction between the
  atom and the background charge density. The EAM potential for a single
  atomic species is often determined by three functions:
    1) Charge density contribution of an atom as a function of distance.
    2) Energy of embedding an atom in the background charge density.
    3) Pairwise energy.
  These three functions are usually provided as spline fits, and we follow the
  implementation and spline fits given by [1]. Note that in the current
  implementation, the three functions listed above can also be expressed by
  any function with the correct signature, including neural networks.

  Args:
    displacement: A function that produces an ndarray of shape [n, m,
      spatial_dimension] of particle displacements from particle positions
      specified as an ndarray of shape [n, spatial_dimension] and [m,
      spatial_dimension] respectively.
    charge_fn: A function that takes an ndarray of shape [n, m] of distances
      between particles and returns a matrix of charge contributions.
    embedding_fn: Function that takes an ndarray of shape [n] of charges and
      returns an ndarray of shape [n] of the energy cost of embedding an atom
      into the charge.
    pairwise_fn: A function that takes an ndarray of shape [n, m] of distances
      and returns an ndarray of shape [n, m] of pairwise energies.
    axis: Specifies which axis the total energy should be summed over.

  Returns:
    A function that computes the EAM energy of a set of atoms with positions
    given by an [n, spatial_dimension] ndarray.

  [1] Y. Mishin, D. Farkas, M.J. Mehl, DA Papaconstantopoulos, "Interatomic
  potentials for monoatomic metals from experimental data and ab initio
  calculations." Physical Review B, 59 (1999)
  """
  pairwise_distance = space.map_product(space.metric(displacement))

  def energy(R, **kwargs):
    dr = pairwise_distance(R, R, **kwargs)

    # Embedding term: sum charge contributions along axis 1, then embed.
    charge_per_atom = util.high_precision_sum(charge_fn(dr), axis=1)
    embedding_energy = embedding_fn(charge_per_atom)

    # Pair term: the diagonal is masked and the row sum is halved.
    masked_pair = smap._diagonal_mask(pairwise_fn(dr))
    pairwise_energy = (
        util.high_precision_sum(masked_pair, axis=1) / f32(2.0))

    return util.high_precision_sum(
        embedding_energy + pairwise_energy, axis=axis)

  return energy
def test_bks(self, dtype):
  """BKS silica pair energy of the positions on disk matches a reference.

  Positions are read from `tests/data/silica_positions.npy`, resolved against
  the current working directory.
  """
  LATCON = 3.5660930663857577e+01
  expected_energy = -857939.528386092

  displacement, shift = space.periodic(LATCON)
  dist_fun = space.metric(displacement)

  # The species pattern [0, 1, 1] is tiled 1000 times (3000 entries).
  unit_species = np.array([0, 1, 1])
  species = np.tile(unit_species, 1000)

  current_dir = os.getcwd()
  filename = os.path.join(current_dir, 'tests/data/silica_positions.npy')
  with open(filename, 'rb') as f:
    loaded = np.load(f)
    R_f = np.array(loaded)

  energy_fn = energy.bks_silica_pair(dist_fun, species=species)
  self.assertAllClose(expected_energy, energy_fn(R_f))
def test_pair_scalar_dummy_arg(self, spatial_dimension, dtype):
  """A pair-mapped function should accept an extra keyword argument (`t`).

  The test only checks that the call completes; no value is asserted.
  """
  def square(dr, param=f32(1.0), **unused_kwargs):
    return param * dr**2

  key = random.PRNGKey(0)
  key, split = random.split(key)
  R = random.normal(key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  displacement, shift = space.free()
  metric = space.metric(displacement)

  mapped = smap.pair(square, metric)
  mapped(R, t=f32(0))
def test_neighbor_list_build_time_dependent(self, dtype, dim):
  """Neighbor list built with a time-dependent box matches exact distances.

  For each particle, the sorted in-cutoff distances obtained through the
  neighbor list at `t = 0.25` are compared with those from the full pairwise
  distance matrix.
  """
  key = random.PRNGKey(1)

  if dim == 2:
    def box_fn(t):
      return np.array([[9.0, t], [0.0, 3.75]], f32)
  elif dim == 3:
    def box_fn(t):
      return np.array([[9.0, 0.0, t], [0.0, 4.0, 0.0], [0.0, 0.0, 7.25]])

  min_length = np.min(np.diag(box_fn(0.)))
  cutoff = f32(1.23)
  # TODO(schsam): Get cell-list working with anisotropic cell sizes.
  cell_size = cutoff / min_length

  displacement, _ = space.periodic_general(box_fn)
  metric = space.metric(displacement)

  R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
  N = R.shape[0]

  neighbor_list_fn = partition.neighbor_list(
      metric, 1., cutoff, 0.0, 1.1, cell_size=cell_size, t=np.array(0.))
  idx = neighbor_list_fn(R, t=np.array(0.25)).idx

  R_neigh = R[idx]
  # Entries with index >= N are masked out of the comparison.
  mask = idx < N

  metric = partial(metric, t=f32(0.25))

  neighbor_distance = vmap(vmap(metric, (None, 0)))
  dR = neighbor_distance(R, R_neigh)

  exact_distance = space.map_product(metric)
  dR_exact = exact_distance(R, R)

  # Zero out everything beyond the cutoff; zeros are dropped row by row below.
  dR = np.where(dR < cutoff, dR, 0) * mask
  dR_exact = np.where(dR_exact < cutoff, dR_exact, 0)

  dR = np.sort(dR, axis=1)
  dR_exact = np.sort(dR_exact, axis=1)

  for row in range(dR.shape[0]):
    neighbor_row = dR[row]
    neighbor_row = neighbor_row[neighbor_row > 0.]

    exact_row = dR_exact[row]
    exact_row = exact_row[exact_row > 0.]

    self.assertAllClose(neighbor_row, exact_row)
def test_lennard_jones_cell_list_energy(self, spatial_dimension, dtype):
  """Lennard-Jones cell-list energy should match the exact pair energy."""
  key = random.PRNGKey(1)

  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  exact_energy_fn = energy.lennard_jones_pair(displacement)

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  energy_fn = energy.lennard_jones_cell_list(displacement, box_size, R)

  expected = np.array(exact_energy_fn(R), dtype=dtype)
  self.assertAllClose(expected, energy_fn(R), True)
def _canonicalize_displacement_or_metric(displacement_or_metric):
  """Checks whether or not a displacement or metric was provided.

  The function is abstractly evaluated (shapes only) on a pair of inputs of
  shape `(1, dim)` for `dim` in 0, 1, 2, 3. For the first dimension in which
  evaluation succeeds, the argument is returned unchanged if the abstract
  output has rank 2; otherwise it is wrapped with `space.metric`.

  Args:
    displacement_or_metric: A callable taking two position arrays and a keyword
      argument `t`.
      NOTE(review): a rank-2 output is presumably what identifies a metric
      here - confirm against the callers.

  Returns:
    Either `displacement_or_metric` itself or
    `space.metric(displacement_or_metric)`.

  Raises:
    ValueError: If abstract evaluation raised `ValueError` for every spatial
      dimension that was tried.
  """
  for dim in range(4):
    try:
      R = ShapedArray((1, dim), f32)
      dR_or_dr = pe.abstract_eval_fun(displacement_or_metric, R, R, t=0)
      if len(dR_or_dr.shape) == 2:
        return displacement_or_metric
      else:
        return space.metric(displacement_or_metric)
    except ValueError:
      # The function rejected this spatial dimension; try the next one.
      continue
  # The original message was missing the space between its two string pieces
  # ("largerthan") and named 4 although only dimensions 0..3 are probed.
  raise ValueError(
      'Canonicalize displacement not implemented for spatial dimension larger '
      'than 3.')
def main(unused_argv):
  """Minimizes a 2D bidisperse soft-sphere system with FIRE descent.

  Sets up 500 particles in a periodic box, runs 50 minimization steps and
  prints the step, the energy and the maximum force component every 10 steps.
  """
  key = random.PRNGKey(0)

  # Setup some variables describing the system.
  N = 500
  dimension = 2
  box_size = f32(25.0)

  # Create helper functions to define a periodic box of some size.
  displacement, shift = space.periodic(box_size)
  metric = space.metric(displacement)

  # Use JAX's random number generator to generate random initial positions.
  key, split = random.split(key)
  R = random.uniform(
      split, (N, dimension), minval=0.0, maxval=box_size, dtype=f32)

  # The system ought to be a 50:50 mixture of two types of particles, one
  # large and one small.
  sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
  N_2 = N // 2
  species = np.array([0] * N_2 + [1] * N_2, dtype=i32)

  # Create an energy function.
  energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
  force_fn = quantity.force(energy_fn)

  # Create a minimizer.
  init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
  opt_state = init_fn(R)

  # Minimize the system.
  minimize_steps = 50
  print_every = 10

  print('Minimizing.')
  print('Step\tEnergy\tMax Force')
  print('-----------------------------------')
  for step in range(minimize_steps):
    opt_state = apply_fn(opt_state)

    if step % print_every != 0:
      continue

    R = opt_state.position
    current_energy = energy_fn(R)
    max_force = np.max(force_fn(R))
    print('{:.2f}\t{:.2f}\t{:.2f}'.format(step, current_energy, max_force))
def test_pair_cell_list_energy(self, spatial_dimension, dtype):
  """Soft-sphere energy through `smap.cell_list` matches the exact energy."""
  key = random.PRNGKey(1)

  box_size = f32(9.0)
  cell_size = f32(1.0)

  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  exact_energy_fn = energy.soft_sphere_pair(displacement)
  energy_fn = smap.cartesian_product(energy.soft_sphere, metric)

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  cell_energy_fn = smap.cell_list(energy_fn, box_size, cell_size, R)

  expected = np.array(exact_energy_fn(R), dtype=dtype)
  self.assertAllClose(expected, cell_energy_fn(R), True)
def test_morse_small_neighbor_list_energy(self, spatial_dimension, dtype):
  """Morse neighbor-list energy matches the pair energy for 10 particles."""
  key = random.PRNGKey(1)

  box_size = f32(5.0)
  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  exact_energy_fn = energy.morse_pair(displacement)

  R = box_size * random.uniform(key, (10, spatial_dimension), dtype=dtype)

  neighbor_fn, energy_fn = energy.morse_neighbor_list(displacement, box_size)
  nbrs = neighbor_fn(R)

  expected = np.array(exact_energy_fn(R), dtype=dtype)
  self.assertAllClose(expected, energy_fn(R, nbrs))
def test_cell_list_incommensurate(self, spatial_dimension, dtype):
  """Cell-list energy is correct when the box is not a multiple of the cell.

  The box size is 12.1 and the cell size is 3.0.
  """
  key = random.PRNGKey(1)

  box_size = f32(12.1)
  cell_size = f32(3.0)

  displacement, _ = space.periodic(box_size)
  energy_fn = energy.soft_sphere_pair(displacement)

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  product_energy = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
  cell_list_energy = jit(
      smap.cell_list(product_energy, box_size, cell_size, R))

  expected = np.array(energy_fn(R), dtype=dtype)
  self.assertAllClose(expected, cell_list_energy(R), True)
def test_lennard_jones_cell_neighbor_list_energy(self, spatial_dimension,
                                                 dtype, format):
  """Lennard-Jones neighbor-list energy should match the exact pair energy."""
  key = random.PRNGKey(1)

  box_size = f32(15)
  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  exact_energy_fn = energy.lennard_jones_pair(displacement)

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
      displacement, box_size, format=format)
  nbrs = neighbor_fn.allocate(R)

  expected = np.array(exact_energy_fn(R), dtype=dtype)
  self.assertAllClose(expected, energy_fn(R, nbrs))
def test_cell_list_direct_force_jit(self, spatial_dimension, dtype):
  """A jitted cell-list force function should match the exact force."""
  key = random.PRNGKey(1)

  box_size = f32(9.0)
  cell_size = f32(1.0)

  displacement, _ = space.periodic(box_size)
  energy_fn = energy.soft_sphere_pair(displacement)
  force_fn = quantity.force(energy_fn)

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  # Here the force function itself is wrapped by the cell list.
  grid_energy_fn = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
  grid_force_fn = quantity.force(grid_energy_fn)
  grid_force_fn = jit(
      smap.cell_list(grid_force_fn, box_size, cell_size, R))

  expected = np.array(force_fn(R), dtype=dtype)
  self.assertAllClose(expected, grid_force_fn(R), True)
def test_canonicalize_displacement_or_metric(self, spatial_dimension, dtype):
  """Canonicalizing a displacement should reproduce `space.metric` of it."""
  key = random.PRNGKey(0)

  displacement, _ = space.periodic_general(np.eye(spatial_dimension))

  reference_metric = space.map_product(space.metric(displacement))
  canonical_metric = space.map_product(
      space.canonicalize_displacement_or_metric(displacement))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split1, split2 = random.split(key, 3)
    R = random.normal(
        split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    expected = reference_metric(R, R)
    actual = canonical_metric(R, R)
    self.assertAllClose(expected, actual, True)
def test_bond_no_type_static(self, spatial_dimension, dtype):
  """`smap.bond` with fixed bonds (0, 1) and (0, 2) matches a manual sum."""
  def harmonic(dr, **kwargs):
    return (dr - f32(1))**f32(2)

  disp, _ = space.free()
  metric = space.metric(disp)

  bonds = np.array([[0, 1], [0, 2]], i32)
  mapped = smap.bond(harmonic, metric, bonds)

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    first_bond = harmonic(metric(R[0], R[1]))
    second_bond = harmonic(metric(R[0], R[2]))
    accum = first_bond + second_bond

    self.assertAllClose(mapped(R), dtype(accum))
def test_morse_neighbor_list_force(self, spatial_dimension, dtype):
  """Morse neighbor-list forces should match the exact pairwise forces."""
  key = random.PRNGKey(1)

  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  exact_force_fn = quantity.force(energy.morse_pair(displacement))

  r = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  neighbor_fn, energy_fn = energy.morse_neighbor_list(displacement, box_size)
  force_fn = quantity.force(energy_fn)
  nbrs = neighbor_fn(r)

  expected = np.array(exact_force_fn(r), dtype=dtype)
  self.assertAllClose(expected, force_fn(r, nbrs))
def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype):
  """Lennard-Jones neighbor-list forces should match the pairwise forces."""
  key = random.PRNGKey(1)

  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  exact_force_fn = quantity.force(energy.lennard_jones_pair(displacement))

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  # In this variant the positions are also passed when building the list.
  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
      displacement, box_size, R)
  force_fn = quantity.force(energy_fn)
  idx = neighbor_fn(R)

  expected = np.array(exact_force_fn(R), dtype=dtype)
  self.assertAllClose(expected, force_fn(R, idx), True)
def test_neighbor_list_build_sparse(self, dtype, dim):
  """Sparse neighbor list reproduces every in-cutoff distance per particle.

  For each particle, the sorted in-cutoff distances gathered from the sparse
  neighbor list are compared with those from the full distance matrix.
  """
  key = random.PRNGKey(1)

  if dim == 3:
    box_size = np.array([9.0, 4.0, 7.25], f32)
  else:
    box_size = np.array([9.0, 4.25], f32)
  cutoff = f32(1.23)

  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  R = box_size * random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
  N = R.shape[0]

  neighbor_fn = partition.neighbor_list(
      displacement, box_size, cutoff, 0.0, 1.1, format=partition.Sparse)
  nbrs = neighbor_fn.allocate(R)
  mask = partition.neighbor_list_mask(nbrs)

  # Distances along the (idx[0], idx[1]) pairs stored in the neighbor list.
  bond_distance = space.map_bond(metric)
  dR = bond_distance(R[nbrs.idx[0]], R[nbrs.idx[1]])

  exact_distance = space.map_product(metric)
  dR_exact = exact_distance(R, R)

  dR = np.where(dR < cutoff, dR, f32(0)) * mask

  # Remove the diagonal of the exact distance matrix.
  mask_exact = 1. - np.eye(dR_exact.shape[0])
  dR_exact = np.where(dR_exact < cutoff, dR_exact, f32(0)) * mask_exact
  dR_exact = np.sort(dR_exact, axis=1)

  for particle in range(N):
    neighbor_row = dR[nbrs.idx[0] == particle]
    neighbor_row = neighbor_row[neighbor_row > 0.]
    neighbor_row = np.sort(neighbor_row)

    exact_row = dR_exact[particle]
    exact_row = np.array(exact_row[exact_row > 0.], dtype)

    self.assertAllClose(neighbor_row, exact_row)
def test_pair_grid_energy(self, spatial_dimension, dtype):
  """Soft-sphere energy through `smap.grid` matches the direct evaluation.

  The direct evaluation uses an all-zero species array with one species.
  """
  key = random.PRNGKey(1)

  box_size = f16(9.0)
  cell_size = f16(2.0)

  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  energy_fn = smap.pair(
      energy.soft_sphere,
      metric,
      quantity.Dynamic,
      reduce_axis=(1, ),
      keepdims=True)

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  grid_energy_fn = smap.grid(energy_fn, box_size, cell_size, R)

  species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)

  expected = np.array(energy_fn(R, species, 1), dtype=dtype)
  self.assertAllClose(expected, grid_energy_fn(R), True)
def test_bond_params_dynamic(self, spatial_dimension, dtype):
  """`smap.bond` with bonds and per-bond sigma supplied at call time.

  The mapped function is built with `sigma=1.0` and then called with bonds
  (0, 1) and (0, 2) and `sigma = [1.0, 2.0]`.
  """
  def harmonic(dr, sigma, **kwargs):
    return (dr - sigma) ** f32(2)

  disp, _ = space.free()
  metric = space.metric(disp)

  sigma = np.array([1.0, 2.0], f32)
  mapped = smap.bond(harmonic, metric, sigma=1.0)
  bonds = np.array([[0, 1], [0, 2]], i32)

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    first_bond = harmonic(metric(R[0], R[1]), 1)
    second_bond = harmonic(metric(R[0], R[2]), 2)
    accum = first_bond + second_bond

    self.assertAllClose(mapped(R, bonds, sigma=sigma), dtype(accum))
def test_neighbor_list_build(self, dtype, dim):
  """Neighbor list reproduces every in-cutoff distance for each particle.

  For each particle, the sorted in-cutoff distances obtained through the
  neighbor list are compared with those from the full distance matrix (with
  the diagonal removed).
  """
  key = random.PRNGKey(1)

  # Compare by value: `dim is 3` tested object identity, which only works by
  # accident of small-int caching and emits a SyntaxWarning on Python >= 3.8.
  box_size = (
      np.array([9.0, 4.0, 7.25], f32) if dim == 3 else
      np.array([9.0, 4.25], f32))

  cutoff = f32(1.23)

  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  R = box_size * random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
  N = R.shape[0]

  neighbor_list_fn = partition.neighbor_list(
      displacement, box_size, cutoff, R)

  idx = neighbor_list_fn(R)
  R_neigh = R[idx]
  # Entries with index >= N are masked out of the comparison.
  mask = idx < N

  d = vmap(vmap(metric, (None, 0)))
  dR = d(R, R_neigh)

  d_exact = space.map_product(metric)
  dR_exact = d_exact(R, R)

  # Zero out everything beyond the cutoff; zeros are dropped row by row below.
  dR = np.where(dR < cutoff, dR, f32(0)) * mask
  mask_exact = 1. - np.eye(dR_exact.shape[0])
  dR_exact = np.where(dR_exact < cutoff, dR_exact, f32(0)) * mask_exact

  dR = np.sort(dR, axis=1)
  dR_exact = np.sort(dR_exact, axis=1)

  for i in range(dR.shape[0]):
    dR_row = dR[i]
    dR_row = dR_row[dR_row > 0.]

    dR_exact_row = dR_exact[i]
    dR_exact_row = np.array(dR_exact_row[dR_exact_row > 0.], dtype)

    self.assertAllClose(dR_row, dR_exact_row, True)
def test_cell_list_force_nonuniform(self, spatial_dimension, dtype):
  """Cell-list forces match exact forces in a box with unequal side lengths.

  The box is 8 x 10 in two dimensions and 8 x 10 x 12 otherwise.
  """
  key = random.PRNGKey(1)

  # The box is kept with a leading axis of length 1; the flat `box_size[0]` is
  # what is handed to `space.periodic`.
  if spatial_dimension == 2:
    box_size = f32(np.array([[8.0, 10.0]]))
  else:
    box_size = f32(np.array([[8.0, 10.0, 12.0]]))

  cell_size = f32(2.0)

  displacement, _ = space.periodic(box_size[0])
  energy_fn = energy.soft_sphere_pair(displacement)
  force_fn = quantity.force(energy_fn)

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  cell_energy_fn = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
  cell_force_fn = quantity.force(cell_energy_fn)
  cell_force_fn = smap.cell_list(cell_force_fn, box_size, cell_size, R)

  # The unused `df` squared-difference local was removed: it evaluated both
  # force functions an extra time and its value was never read.
  self.assertAllClose(
      np.array(force_fn(R), dtype=dtype), cell_force_fn(R), True)
def hybrid_swap_mc(
    space_fns: space.Space,
    energy_fn: Callable[[Array, Array], Array],
    neighbor_fn: partition.NeighborFn,
    dt: float,
    kT: float,
    t_md: float,
    N_swap: int,
    sigma_fn: Optional[Callable[[Array], Array]] = None) -> Simulator:
  """Simulation of Hybrid Swap Monte-Carlo.

  This code simulates the hybrid Swap Monte Carlo algorithm introduced in [1].
  Here an NVT simulation is performed for `t_md` time and then `N_swap` MC
  moves are performed that swap the radii of randomly chosen particles. The
  random swaps are accepted with a Metropolis-Hastings step. Each call to the
  step function runs molecular dynamics for `t_md` and then performs the
  swaps.

  Note that this code doesn't feature some of the convenience functions in the
  other simulations. In particular, there is no support for dynamics keyword
  arguments and the energy function must be a simple callable of two
  variables: the distance between adjacent particles and the diameter of the
  particles. If you want support for a better notion of potential or dynamic
  keyword arguments, please file an issue!

  Args:
    space_fns: A tuple of a displacement function and a shift function defined
      in `space.py`.
    energy_fn: A function that computes the energy between one pair of
      particles as a function of the distance between the particles and the
      diameter. This function should not have been passed to `smap.xxx`.
    neighbor_fn: A function to construct neighbor lists outlined in
      `partition.py`.
    dt: The timestep used for the continuous time MD portion of the
      simulation.
    kT: The temperature of heat bath that the system is coupled to during MD.
    t_md: The time of each MD block.
    N_swap: The number of swapping moves between MD blocks.
    sigma_fn: An optional function for combining radii if they are to be
      non-additive. Defaults to the arithmetic mean of the two radii.

  Returns:
    See above.

  [1] L. Berthier, E. Flenner, C. J. Fullerton, C. Scalliet, and M. Singh.
  "Efficient swap algorithms for molecular dynamics simulations of
  equilibrium supercooled liquids" J. Stat. Mech. (2019) 064004
  """
  displacement_fn, shift_fn = space_fns
  metric_fn = space.metric(displacement_fn)
  nbr_metric_fn = space.map_neighbor(metric_fn)

  # Number of MD steps per block (floor division of the block time by dt).
  md_steps = int(t_md // dt)

  # Canonicalize the argument names to be dr and sigma.
  def wrapped_energy_fn(dr, sigma):
    return energy_fn(dr, sigma)

  if sigma_fn is None:
    def _additive_sigma(si, sj):
      return 0.5 * (si + sj)
    sigma_fn = _additive_sigma

  nbr_energy_fn = smap.pair_neighbor_list(
      wrapped_energy_fn, metric_fn, sigma=sigma_fn)

  nvt_init_fn, nvt_step_fn = nvt_nose_hoover(
      nbr_energy_fn, shift_fn, dt, kT=kT, chain_length=3)

  def init_fn(key, position, sigma, nbrs=None):
    key, sim_key = random.split(key)
    nbrs = neighbor_fn(position, nbrs)  # pytype: disable=wrong-arg-count
    md_state = nvt_init_fn(sim_key, position, neighbor=nbrs, sigma=sigma)
    return SwapMCState(md_state, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

  def md_step_fn(i, state):
    md, sigma, key, nbrs = dataclasses.unpack(state)
    md = nvt_step_fn(md, neighbor=nbrs, sigma=sigma)  # pytype: disable=wrong-keyword-args
    nbrs = neighbor_fn(md.position, nbrs)
    return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

  def swap_step_fn(i, state):
    md, sigma, key, nbrs = dataclasses.unpack(state)

    N = md.position.shape[0]

    # Swap a random pair of particle radii.
    key, particle_key, accept_key = random.split(key, 3)
    ij = random.randint(particle_key, (2, ), jnp.array(0), jnp.array(N))
    new_sigma = sigma.at[ij].set([sigma[ij[1]], sigma[ij[0]]])

    # Collect neighborhoods around the two swapped particles.
    nbrs_ij = nbrs.idx[ij]
    R_ij = md.position[ij]
    R_neigh = md.position[nbrs_ij]
    dR = nbr_metric_fn(R_ij, R_neigh)

    # Neighbor entries with index >= N are excluded from the energy sums.
    valid = nbrs_ij < N

    def local_energy(radii):
      # Energy of the two neighborhoods for a given assignment of radii.
      radii_ij = radii[ij][:, None]
      radii_neigh = radii[nbrs_ij]
      pair_energy = energy_fn(dR, sigma_fn(radii_ij, radii_neigh))
      return jnp.sum(pair_energy * valid)

    # Compute the energy before and after the swap.
    energy = local_energy(sigma)
    new_energy = local_energy(new_sigma)

    # Accept or reject with a metropolis probability.
    p = random.uniform(accept_key, ())
    accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - energy) / kT))
    sigma = jnp.where(p < accept_prob, new_sigma, sigma)

    return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

  def block_fn(state):
    # One block: `md_steps` MD steps followed by `N_swap` swap attempts.
    state = lax.fori_loop(0, md_steps, md_step_fn, state)
    state = lax.fori_loop(0, N_swap, swap_step_fn, state)
    return state

  return init_fn, block_fn