def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype, format):
  """Compare neighbor-list Lennard-Jones forces with the all-pairs forces.

  Positions are drawn uniformly in a periodic box. The force derived from
  `energy.lennard_jones_neighbor_list` must match the force derived from
  `energy.lennard_jones_pair`.
  """
  rng = random.PRNGKey(1)
  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  # NOTE(review): `metric` is built but never consumed by this test.
  metric = space.metric(displacement)

  reference_force_fn = quantity.force(energy.lennard_jones_pair(displacement))

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(rng, shape, dtype=dtype)

  neighbor_fn, nl_energy_fn = energy.lennard_jones_neighbor_list(
      displacement, box_size, format=format)
  nl_force_fn = quantity.force(nl_energy_fn)
  nbrs = neighbor_fn.allocate(positions)

  # Only the float32 + OrderedSparse combination gets explicit tolerances;
  # every other combination relies on the defaults of `assertAllClose`.
  tolerances = {}
  if dtype == f32 and format is partition.OrderedSparse:
    tolerances = dict(atol=5e-5, rtol=5e-5)

  expected = np.array(reference_force_fn(positions), dtype=dtype)
  self.assertAllClose(expected, nl_force_fn(positions, nbrs), **tolerances)
def test_pair_correlation_neighbor_list_species(self, dim, dtype, format):
  """Compare per-species pair correlation with and without a neighbor list.

  The test is skipped for the `partition.OrderedSparse` format. Otherwise
  the particle-averaged outputs of `quantity.pair_correlation` and
  `quantity.pair_correlation_neighbor_list` must agree for both species.
  """
  if format is partition.OrderedSparse:
    self.skipTest('OrderedSparse not supported for pair correlation '
                  'function.')

  particle_count = 100
  box_length = 10.
  displacement, _ = space.periodic(box_length)
  positions = random.uniform(
      random.PRNGKey(0), (particle_count, dim), dtype=dtype)
  # First half of the particles is species 0, the rest is species 1.
  species = np.where(np.arange(particle_count) < particle_count // 2, 0, 1)
  radii = np.linspace(0, 2, 60, dtype=dtype)

  exact_fn = quantity.pair_correlation(displacement, radii, f32(0.1), species)
  neighbor_fn, neighbor_g_fn = quantity.pair_correlation_neighbor_list(
      displacement, box_length, radii, f32(0.1), species, format=format)
  nbrs = neighbor_fn.allocate(positions)

  # Each call yields one array per species; average over the leading axis.
  exact_0, exact_1 = exact_fn(positions)
  exact_0 = np.mean(exact_0, axis=0)
  exact_1 = np.mean(exact_1, axis=0)

  neigh_0, neigh_1 = neighbor_g_fn(positions, neighbor=nbrs)
  neigh_0 = np.mean(neigh_0, axis=0)
  neigh_1 = np.mean(neigh_1, axis=0)

  for expected, actual in ((exact_0, neigh_0), (exact_1, neigh_1)):
    self.assertAllClose(expected, actual)
def test_pair_neighbor_list_force_scalar_diverging_potential(
    self, spatial_dimension, dtype, format):
  """Compare neighbor-list and all-pairs forces for a `dr ** -6` potential.

  For each stochastic sample, random positions and a random scalar cutoff
  `sigma` are drawn. The force from `smap.pair_neighbor_list` must match
  the force from `smap.pair`.
  """
  def potential(dr, sigma):
    # Inverse sixth power inside the cutoff, zero at or beyond it.
    return np.where(dr < sigma, dr ** -6, f32(0.))

  particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 4. * particle_count ** (1. / spatial_dimension)
  shape = (particle_count, spatial_dimension)

  key = random.PRNGKey(0)
  # The key is advanced once before the sampling loop; the subkey is unused.
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  neighbor_force = jit(quantity.force(
      smap.pair_neighbor_list(potential, metric, sigma=1.0)))
  all_pairs_force = jit(quantity.force(
      smap.pair(potential, metric, sigma=1.0)))

  for _ in range(STOCHASTIC_SAMPLES):
    key, position_key = random.split(key)
    R = box_size * random.uniform(position_key, shape, dtype=dtype)
    # `key` itself (not a fresh subkey) seeds the cutoff.
    sigma = random.uniform(key, (), minval=0.5, maxval=4.5)
    neighbor_fn = partition.neighbor_list(
        disp, box_size, sigma, 0.0, format=format)
    nbrs = neighbor_fn.allocate(R)

    expected = all_pairs_force(R, sigma=sigma)
    self.assertAllClose(expected, neighbor_force(R, nbrs, sigma=sigma))
def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension, dtype):
  """Compare `smap.pair_neighbor_list` with `smap.pair` for a per-pair sigma.

  Each sample draws positions and a symmetric (N, N) `sigma` matrix, then
  checks that both mappings of a truncated square potential agree.
  """
  def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr**2, f32(0.))

  particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 2. * particle_count**(1. / spatial_dimension)
  shape = (particle_count, spatial_dimension)

  key = random.PRNGKey(0)
  # The key is advanced once before the sampling loop; the subkey is unused.
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  via_neighbors = jit(smap.pair_neighbor_list(truncated_square, metric))
  via_all_pairs = jit(smap.pair(truncated_square, metric))

  for _ in range(STOCHASTIC_SAMPLES):
    key, position_key = random.split(key)
    R = box_size * random.uniform(position_key, shape, dtype=dtype)
    # `key` itself (not a fresh subkey) seeds the sigma matrix.
    raw_sigma = random.uniform(
        key, (particle_count, particle_count), minval=0.5, maxval=1.5)
    # Average with the transpose so that sigma[i, j] == sigma[j, i].
    sigma = 0.5 * (raw_sigma + raw_sigma.T)
    neighbor_fn = partition.neighbor_list(disp, box_size, np.max(sigma), 0.)
    nbrs = neighbor_fn(R)

    expected = via_all_pairs(R, sigma=sigma)
    self.assertAllClose(expected, via_neighbors(R, nbrs, sigma=sigma))
def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype):
  """Compare `smap.pair_neighbor_list` with `smap.pair` for a scalar sigma.

  Each sample draws positions and one scalar `sigma`, builds a neighbor
  list for that cutoff and checks that both mappings of a truncated square
  potential agree.
  """
  def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr**2, f32(0.))

  # float32 gets an explicit tolerance; otherwise None is passed through.
  if dtype == np.float32:
    tol = 2e-10
  else:
    tol = None

  particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 4. * particle_count**(1. / spatial_dimension)
  shape = (particle_count, spatial_dimension)

  key = random.PRNGKey(0)
  # The key is advanced once before the sampling loop; the subkey is unused.
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  via_neighbors = jit(smap.pair_neighbor_list(truncated_square, metric))
  via_all_pairs = jit(smap.pair(truncated_square, metric))

  for _ in range(STOCHASTIC_SAMPLES):
    key, position_key = random.split(key)
    R = box_size * random.uniform(position_key, shape, dtype=dtype)
    # `key` itself (not a fresh subkey) seeds the cutoff.
    sigma = random.uniform(key, (), minval=0.5, maxval=2.5)
    neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R))
    idx = neighbor_fn(R)

    expected = via_all_pairs(R, sigma=sigma)
    actual = via_neighbors(R, idx, sigma=sigma)
    self.assertAllClose(expected, actual, True, tol, tol)
def test_periodic_against_periodic_general(self, spatial_dimension, dtype):
  """Check `space.periodic` against `space.periodic_general` on diagonal boxes.

  For every sample a random box is drawn. Displacements and shifts from
  `space.periodic` on box-scaled positions must match those from
  `space.periodic_general` (built from the diagonal transform) on the
  unscaled positions, and both outputs must keep the requested dtype.
  """
  key = random.PRNGKey(0)
  particle_shape = (PARTICLE_COUNT, spatial_dimension)

  for _ in range(STOCHASTIC_SAMPLES):
    key, box_key, position_key, step_key = random.split(key, 4)

    max_box_size = f16(10.0)
    box_size = max_box_size * random.uniform(
        box_key, (spatial_dimension,), dtype=dtype)
    transform = np.diag(box_size)

    R = random.uniform(position_key, particle_shape, dtype=dtype)
    R_scaled = R * box_size
    dR = random.normal(step_key, particle_shape, dtype=dtype)

    pair_disp, shift_fn = space.periodic(box_size)
    general_pair_disp, general_shift_fn = space.periodic_general(transform)
    all_pairs_disp = space.map_product(pair_disp)
    general_all_pairs_disp = space.map_product(general_pair_disp)

    # Displacements: same values and the requested dtype.
    self.assertAllClose(all_pairs_disp(R_scaled, R_scaled),
                        general_all_pairs_disp(R, R), True)
    assert all_pairs_disp(R_scaled, R_scaled).dtype == dtype

    # Shifts: the general result is multiplied by the box before comparing.
    self.assertAllClose(shift_fn(R_scaled, dR),
                        general_shift_fn(R, dR) * box_size, True)
    assert shift_fn(R_scaled, dR).dtype == dtype
def test_pair_neighbor_list_vector(self, spatial_dimension, dtype, format):
  """Compare vector-valued `smap.pair_neighbor_list` output with `smap.pair`.

  Skipped for `partition.OrderedSparse`. The mapped function returns the
  squared displacement components inside the cutoff, and both mappings are
  built with `reduce_axis=(1,)`.
  """
  if format is partition.OrderedSparse:
    self.skipTest('Vector valued pair_neighbor_list not supported.')

  def truncated_square(dR, sigma):
    # Append a trailing axis of size one to the distances so that they
    # broadcast against the displacement components.
    dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1,))
    return np.where(dr < sigma, dR ** 2, f32(0.))

  particle_count = PARTICLE_COUNT
  box_size = 2. * particle_count ** (1. / spatial_dimension)
  shape = (particle_count, spatial_dimension)

  key = random.PRNGKey(0)
  # The key is advanced once before the sampling loop; the subkey is unused.
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)

  via_neighbors = jit(smap.pair_neighbor_list(
      truncated_square, disp, sigma=1.0, reduce_axis=(1,)))
  via_all_pairs = jit(smap.pair(
      truncated_square, disp, sigma=1.0, reduce_axis=(1,)))

  for _ in range(STOCHASTIC_SAMPLES):
    key, position_key = random.split(key)
    R = box_size * random.uniform(position_key, shape, dtype=dtype)
    # `key` itself (not a fresh subkey) seeds the cutoff.
    sigma = random.uniform(key, (), minval=0.5, maxval=1.5)
    neighbor_fn = partition.neighbor_list(
        disp, box_size, sigma, 0., format=format)
    nbrs = neighbor_fn.allocate(R)

    expected = via_all_pairs(R, sigma=sigma)
    self.assertAllClose(expected, via_neighbors(R, nbrs, sigma=sigma))
def test_pair_neighbor_list_scalar_nonadditive(self, spatial_dimension, dtype, format):
  """Check a multiplicative parameter-combining rule in `pair_neighbor_list`.

  The neighbor-list mapping receives per-particle `sigma` values together
  with a combinator that multiplies them. The all-pairs reference receives
  the explicit outer product `sigma[i] * sigma[j]`. Both must agree.
  """
  def truncated_square(dR, sigma):
    dr = space.distance(dR)
    return np.where(dr < sigma, dr**2, f32(0.))

  def multiply(x, y):
    return x * y

  particle_count = PARTICLE_COUNT
  box_size = 2. * particle_count**(1. / spatial_dimension)
  shape = (particle_count, spatial_dimension)

  key = random.PRNGKey(0)
  # The key is advanced once before the sampling loop; the subkey is unused.
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)

  via_neighbors = jit(
      smap.pair_neighbor_list(truncated_square, disp, sigma=multiply))
  via_all_pairs = jit(smap.pair(truncated_square, disp, sigma=1.0))

  for _ in range(STOCHASTIC_SAMPLES):
    key, position_key = random.split(key)
    R = box_size * random.uniform(position_key, shape, dtype=dtype)
    # `key` itself (not a fresh subkey) seeds the per-particle sigmas.
    sigma = random.uniform(key, (particle_count,), minval=0.5, maxval=1.5)
    sigma_pair = sigma[:, None] * sigma[None, :]
    # The largest pair value is max(sigma)**2 because the rule is a product.
    neighbor_fn = partition.neighbor_list(
        disp, box_size, np.max(sigma)**2, 0., format=format)
    nbrs = neighbor_fn.allocate(R)

    expected = via_all_pairs(R, sigma=sigma_pair)
    self.assertAllClose(expected, via_neighbors(R, nbrs, sigma=sigma))
def test_pair_neighbor_list_scalar_params_species_dynamic(
    self, spatial_dimension, dtype, format):
  """Check `pair_neighbor_list` when species are supplied at call time.

  The all-pairs reference is built with a fixed `species` array. The
  neighbor-list version is built without one and receives `species` as a
  keyword on every call. Both use a symmetric 3x3 `sigma` table.
  """
  def truncated_square(dr, sigma, **kwargs):
    return np.where(dr < sigma, dr ** 2, f32(0.))

  particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 2. * particle_count ** (1. / spatial_dimension)
  shape = (particle_count, spatial_dimension)

  # Three species, assigned by where the particle index falls relative to
  # N / 3 and 2 * N / 3.
  index = np.arange(particle_count)
  species = np.zeros((particle_count,), np.int32)
  species = np.where(index > particle_count / 3, 1, species)
  species = np.where(index > 2 * particle_count / 3, 2, species)

  key = random.PRNGKey(0)
  # The key is advanced once before the sampling loop; the subkey is unused.
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  via_neighbors = jit(
      smap.pair_neighbor_list(truncated_square, metric, sigma=1.0))
  via_all_pairs = jit(
      smap.pair(truncated_square, metric, species=species, sigma=1.0))

  for _ in range(STOCHASTIC_SAMPLES):
    key, position_key = random.split(key)
    R = box_size * random.uniform(position_key, shape, dtype=dtype)
    # `key` itself (not a fresh subkey) seeds the sigma table.
    raw_sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
    sigma = 0.5 * (raw_sigma + raw_sigma.T)
    neighbor_fn = partition.neighbor_list(
        disp, box_size, np.max(sigma), 0., format=format)
    nbrs = neighbor_fn.allocate(R)

    expected = via_all_pairs(R, sigma=sigma)
    actual = via_neighbors(R, nbrs, sigma=sigma, species=species)
    self.assertAllClose(expected, actual)
def test_radial_symmetry_functions_neighbor_list(self, N_types, N_etas, dtype, dim):
  """Compare radial symmetry functions with and without a neighbor list.

  `nn.radial_symmetry_functions` and
  `nn.radial_symmetry_functions_neighbor_list` are built from the same
  arguments and must produce matching outputs on random positions.
  """
  particle_count = 128
  box_size = 12.0
  r_cutoff = 3.

  displacement, shift = space.periodic(box_size)
  position_key, species_key = random.split(random.PRNGKey(0))
  # NOTE(review): positions are drawn without `dtype`; only the linspace
  # below uses the `dtype` parameter. Confirm this is intended.
  R = box_size * random.uniform(position_key, (particle_count, dim))
  species = random.choice(species_key, N_types, (particle_count,))

  neighbor_fn = partition.neighbor_list(displacement, box_size, r_cutoff, 0.)

  exact_fn = nn.radial_symmetry_functions(
      displacement, species,
      np.linspace(1.0, 2.0, N_etas, dtype=dtype), r_cutoff)
  neighbor_list_fn = nn.radial_symmetry_functions_neighbor_list(
      displacement, species,
      np.linspace(1.0, 2.0, N_etas, dtype=dtype), r_cutoff)

  nbrs = neighbor_fn(R)
  expected = exact_fn(R)
  actual = neighbor_list_fn(R, neighbor=nbrs)
  self.assertAllClose(expected, actual)
def test_pair_neighbor_list_vector(self, spatial_dimension, dtype):
  """Compare vector-valued `smap.pair_neighbor_list` output with `smap.pair`.

  The mapped function returns the squared displacement components inside
  the cutoff. Both mappings are built with `reduce_axis=(1,)` and compared
  on random positions with a random scalar `sigma`.
  """
  def truncated_square(dR, sigma):
    # Append a trailing axis of size one to the distances so that they
    # broadcast against the displacement components.
    dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1, ))
    return np.where(dr < sigma, dR**2, f32(0.))

  if dtype == np.float32:
    tol = 5e-6
  else:
    tol = 1e-14

  particle_count = PARTICLE_COUNT
  box_size = 2. * particle_count**(1. / spatial_dimension)
  shape = (particle_count, spatial_dimension)

  key = random.PRNGKey(0)
  # The key is advanced once before the sampling loop; the subkey is unused.
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)

  via_neighbors = jit(
      smap.pair_neighbor_list(truncated_square, disp, reduce_axis=(1, )))
  via_all_pairs = jit(
      smap.pair(truncated_square, disp, reduce_axis=(1, )))

  for _ in range(STOCHASTIC_SAMPLES):
    key, position_key = random.split(key)
    R = box_size * random.uniform(position_key, shape, dtype=dtype)
    # `key` itself (not a fresh subkey) seeds the cutoff.
    sigma = random.uniform(key, (), minval=0.5, maxval=1.5)
    neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R))
    idx = neighbor_fn(R)

    expected = via_all_pairs(R, sigma=sigma)
    actual = via_neighbors(R, idx, sigma=sigma)
    self.assertAllClose(expected, actual, True, tol, tol)
def test_stress_non_minimized_periodic(self, dim, dtype):
  """Compare `quantity.stress` with a hand-written pairwise stress sum.

  Random positions are placed in a periodic box sized by
  `quantity.box_size_at_number_density(N, 0.8, dim)`. The reference stress
  is assembled explicitly from pair displacements and the derivative of
  `energy.soft_sphere`.
  """
  particle_count = 64
  key = random.PRNGKey(0)
  box = quantity.box_size_at_number_density(particle_count, 0.8, dim)
  displacement_fn, _ = space.periodic(box)

  pos = random.uniform(key, (particle_count, dim)) * box
  energy_fn = energy.soft_sphere_pair(displacement_fn)

  def reference_stress(R):
    pair_disp = space.map_product(displacement_fn)(R, R)
    pair_dist = space.distance(pair_disp)
    dU_dr = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
    volume = quantity.volume(dim, box)
    # The factor 0.5 halves each term of the sum over both pair indices.
    half_slope = 0.5 * dU_dr(pair_dist)[:, :, None, None]
    # Adding the identity puts ones on the diagonal of the divisor,
    # presumably to avoid dividing by the zero self-distance.
    divisor = (pair_dist + jnp.eye(particle_count))[:, :, None, None]
    return -jnp.sum(
        half_slope * pair_disp[:, :, None, :] * pair_disp[:, :, :, None]
        / (volume * divisor),
        axis=(0, 1))

  expected = reference_stress(pos)
  actual = quantity.stress(energy_fn, pos, box)

  tol = 1e-7 if dtype is f64 else 2e-5
  self.assertAllClose(expected, actual, atol=tol, rtol=tol)
def test_periodic_against_periodic_general_grad(self, spatial_dimension, dtype):
  """Compare gradients through `space.periodic` and `space.periodic_general`.

  For every sample a random diagonal box is drawn. The gradient of the
  summed squared displacements is taken with `space.periodic` at the
  box-scaled positions and with `space.periodic_general` at the unscaled
  positions. The two gradients must agree within a dtype-dependent
  tolerance, and the general gradient must keep the requested dtype.
  """
  key = random.PRNGKey(0)
  tol = 1e-5 if dtype is f32 else 1e-13

  for _ in range(STOCHASTIC_SAMPLES):
    # The four-way split is kept so the random stream is unchanged, even
    # though only two of the subkeys are consumed here.
    key, box_key, position_key, _ = random.split(key, 4)

    max_box_size = f32(10.0)
    box_size = max_box_size * random.uniform(
        box_key, (spatial_dimension,), dtype=dtype)
    transform = jnp.diag(box_size)

    R = random.uniform(
        position_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    R_scaled = R * box_size

    disp_fn, _ = space.periodic(box_size)
    general_disp_fn, _ = space.periodic_general(transform)
    disp_fn = space.map_product(disp_fn)
    general_disp_fn = space.map_product(general_disp_fn)

    grad_fn = grad(lambda R: jnp.sum(disp_fn(R, R) ** 2))
    general_grad_fn = grad(lambda R: jnp.sum(general_disp_fn(R, R) ** 2))

    general_grad = general_grad_fn(R)
    # Pass the dtype-dependent tolerance computed above to the comparison.
    self.assertAllClose(grad_fn(R_scaled), general_grad, atol=tol, rtol=tol)
    self.assertEqual(general_grad.dtype, dtype)
def test_nvt_nose_hoover_jammed(self, dtype, sy_steps):
  """Check that the Nose-Hoover invariant stays constant on a jammed state.

  A stored test state is loaded and integrated for `DYNAMICS_STEPS` steps
  with `simulate.nvt_nose_hoover`. The value of
  `simulate.nvt_nose_hoover_invariant` recorded after every step must stay
  close to its initial value, and positions must keep the requested dtype.
  """
  key = random.PRNGKey(0)

  state = test_util.load_test_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic(state.box[0, 0])

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
  invariant = partial(simulate.nvt_nose_hoover_invariant, E)

  kT = 1e-3
  init_fn, apply_fn = simulate.nvt_nose_hoover(
      E, shift_fn, 1e-3, kT=kT, sy_steps=sy_steps)
  apply_fn = jit(apply_fn)

  state = init_fn(key, state.real_position)

  E_initial = invariant(state, kT) * np.ones((DYNAMICS_STEPS,))

  def step_fn(i, state_and_invariants):
    state, invariants = state_and_invariants
    state = apply_fn(state)
    # `x.at[i].set(v)` is the functional update that supersedes
    # `jax.ops.index_update`, which no longer exists in current JAX.
    invariants = invariants.at[i].set(invariant(state, kT))
    return state, invariants

  Es = np.zeros((DYNAMICS_STEPS,))
  state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_nve_jammed(self, spatial_dimension, dtype):
  """Check that total energy is conserved by NVE dynamics on a jammed state.

  A stored test state is loaded and integrated for `DYNAMICS_STEPS` steps
  with `simulate.nve`. The potential plus kinetic energy recorded after
  every step must stay close to its initial value, and positions must keep
  the requested dtype.
  """
  key = random.PRNGKey(0)

  state = test_util.load_test_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic(state.box[0, 0])

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)

  init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
  apply_fn = jit(apply_fn)

  state = init_fn(key, state.real_position, kT=1e-3)

  def total_energy(state):
    """Return potential energy plus kinetic energy of `state`."""
    kinetic = quantity.kinetic_energy(state.velocity, state.mass)
    return E(state.position) + kinetic

  E_initial = total_energy(state) * np.ones((DYNAMICS_STEPS,))

  def step_fn(i, state_and_energies):
    state, energies = state_and_energies
    state = apply_fn(state)
    # `x.at[i].set(v)` is the functional update that supersedes
    # `jax.ops.index_update`, which no longer exists in current JAX.
    energies = energies.at[i].set(total_energy(state))
    return state, energies

  Es = np.zeros((DYNAMICS_STEPS,))
  state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension, dtype):
  """Compare `pair_neighbor_list` with `pair` for a fixed species array.

  Both mappings are built with the same three-species assignment. Each
  sample draws positions and a symmetric 3x3 `sigma` table, then checks
  that the two mappings of a truncated square potential agree.
  """
  def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr**2, f32(0.))

  # float32 gets an explicit tolerance; otherwise None is passed through.
  if dtype == np.float32:
    tol = 2e-6
  else:
    tol = None

  particle_count = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 2. * particle_count**(1. / spatial_dimension)
  shape = (particle_count, spatial_dimension)

  # Three species, assigned by where the particle index falls relative to
  # N / 3 and 2 * N / 3.
  index = np.arange(particle_count)
  species = np.zeros((particle_count, ), np.int32)
  species = np.where(index > particle_count / 3, 1, species)
  species = np.where(index > 2 * particle_count / 3, 2, species)

  key = random.PRNGKey(0)
  # The key is advanced once before the sampling loop; the subkey is unused.
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  via_neighbors = jit(
      smap.pair_neighbor_list(truncated_square, metric, species=species))
  via_all_pairs = jit(smap.pair(truncated_square, metric, species=species))

  for _ in range(STOCHASTIC_SAMPLES):
    key, position_key = random.split(key)
    R = box_size * random.uniform(position_key, shape, dtype=dtype)
    # `key` itself (not a fresh subkey) seeds the sigma table.
    raw_sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
    sigma = 0.5 * (raw_sigma + raw_sigma.T)
    neighbor_fn = jit(
        partition.neighbor_list(disp, box_size, np.max(sigma), R))
    idx = neighbor_fn(R)

    expected = via_all_pairs(R, sigma=sigma)
    actual = via_neighbors(R, idx, sigma=sigma)
    self.assertAllClose(expected, actual, True, tol, tol)
def test_bks(self, dtype):
  """Check the BKS silica pair energy of stored positions against a constant.

  Positions come from `test_util.load_silica_data()`. The species pattern
  `[0, 1, 1]` is tiled 1000 times.
  """
  box_length = 3.5660930663857577e+01
  expected_energy = -857939.528386092

  positions = test_util.load_silica_data()
  species = np.tile(np.array([0, 1, 1]), 1000)

  displacement, shift = space.periodic(box_length)
  metric = space.metric(displacement)
  energy_fn = energy.bks_silica_pair(metric, species=species)

  self.assertAllClose(expected_energy, energy_fn(positions))
def test_bks_neighbor_list(self, dtype, format):
  """Check the neighbor-list BKS silica energy against a constant.

  Positions come from `test_util.load_silica_data()`. The species pattern
  `[0, 1, 1]` is tiled 1000 times, and the neighbor list is allocated in
  the requested `format`.
  """
  box_length = 3.5660930663857577e+01
  expected_energy = -857939.528386092

  displacement, shift = space.periodic(box_length)
  metric = space.metric(displacement)
  species = np.tile(np.array([0, 1, 1]), 1000)
  positions = test_util.load_silica_data()

  neighbor_fn, energy_fn = energy.bks_silica_neighbor_list(
      metric, box_length, species=species, format=format)
  nbrs = neighbor_fn.allocate(positions)

  self.assertAllClose(expected_energy, energy_fn(positions, nbrs))
def test_bks(self, dtype):
  """Check the BKS silica pair energy of positions loaded from disk.

  Positions are read from `tests/data/silica_positions.npy`, resolved
  against the current working directory. The species pattern `[0, 1, 1]`
  is tiled 1000 times.
  """
  box_length = 3.5660930663857577e+01
  expected_energy = -857939.528386092

  displacement, shift = space.periodic(box_length)
  metric = space.metric(displacement)
  species = np.tile(np.array([0, 1, 1]), 1000)

  # NOTE(review): the path depends on the process working directory.
  data_path = os.path.join(os.getcwd(), 'tests/data/silica_positions.npy')
  with open(data_path, 'rb') as handle:
    positions = np.array(np.load(handle))

  energy_fn = energy.bks_silica_pair(metric, species=species)
  self.assertAllClose(expected_energy, energy_fn(positions))
def test_soft_sphere_cell_list_energy(self, spatial_dimension, dtype):
  """Compare the cell-list soft-sphere energy with the all-pairs energy."""
  box_size = f32(15.0)
  rng = random.PRNGKey(1)
  displacement, _ = space.periodic(box_size)

  reference_energy_fn = energy.soft_sphere_pair(displacement)

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(rng, shape, dtype=dtype)

  cell_list_energy_fn = energy.soft_sphere_cell_list(
      displacement, box_size, positions)

  expected = np.array(reference_energy_fn(positions), dtype=dtype)
  self.assertAllClose(expected, cell_list_energy_fn(positions), True)
def test_pressure_jammed_periodic(self, dtype):
  """Compare `quantity.pressure` with the pressure stored in a jammed state.

  The state is loaded via `test_util.load_jammed_state`. The periodic space
  is built from the diagonal of the stored box.
  """
  # NOTE(review): this key is created but never consumed by the test.
  key = random.PRNGKey(0)

  jammed = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic(jnp.diag(jammed.box))
  energy_fn = energy.soft_sphere_pair(
      displacement_fn, jammed.species, jammed.sigma)
  positions = jammed.real_position

  if dtype is f64:
    tol = 1e-7
  else:
    tol = 2e-5

  computed = quantity.pressure(energy_fn, positions, jammed.box)
  self.assertAllClose(computed, jammed.pressure, atol=tol, rtol=tol)
def test_nve_neighbor_list(self, spatial_dimension, dtype):
  """Compare NVE trajectories with and without a neighbor list.

  Particles start on a regular lattice with 8 sites per side. A
  Lennard-Jones system is integrated twice from the same initial state,
  once with `energy.lennard_jones_neighbor_list` and once with
  `energy.lennard_jones_pair`. Up to 20 chunks of 100 steps are attempted.
  A chunk during which the neighbor list overflowed is discarded and the
  list is rebuilt from the last accepted positions. Final positions must
  agree and keep the requested dtype.
  """
  Nx = 8
  spacing = f32(1.25)

  tol = 5e-12 if dtype == np.float64 else 5e-3

  L = Nx * spacing
  # One lattice site per integer index tuple; this works for any dimension.
  lattice_shape = (Nx,) * spatial_dimension
  R = np.stack([np.array(r) for r in onp.ndindex(*lattice_shape)]) * spacing
  R = np.array(R, dtype)

  displacement, shift = space.periodic(L)

  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement, L)
  exact_energy_fn = energy.lennard_jones_pair(displacement)

  init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
  exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)

  nbrs = neighbor_fn(R)
  state = init_fn(random.PRNGKey(0), R, neighbor=nbrs)
  exact_state = exact_init_fn(random.PRNGKey(0), R)

  def body_fn(i, carry):
    state, nbrs, exact_state = carry
    # Refresh the neighbor list before every neighbor-list step.
    nbrs = neighbor_fn(state.position, nbrs)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs, exact_apply_fn(exact_state)

  for _ in range(20):
    new_state, nbrs, new_exact_state = lax.fori_loop(
        0, 100, body_fn, (state, nbrs, exact_state))
    if nbrs.did_buffer_overflow:
      # Drop this chunk and rebuild the list from the last accepted state.
      nbrs = neighbor_fn(state.position)
    else:
      state = new_state
      exact_state = new_exact_state

  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(state.position, exact_state.position, atol=tol, rtol=tol)
def test_behler_parrinello_network(self, N_types, dtype):
  """Smoke-test `energy.behler_parrinello` on three fixed particles.

  The network output and its gradient with respect to positions must be
  free of NaNs, and the gradient must have shape (3, 3).
  """
  rng = random.PRNGKey(1)
  positions = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
  # Species are only supplied when more than one type is requested.
  if N_types > 1:
    species = np.array([1, 1, N_types])
  else:
    species = None

  box_size = f32(1.5)
  displacement, _ = space.periodic(box_size)

  nn_init, nn_apply = energy.behler_parrinello(displacement, species)
  params = nn_init(rng, positions)

  # Differentiate with respect to the positions (argument index 1).
  position_grad_fn = jit(grad(nn_apply, argnums=1))
  position_grad = position_grad_fn(params, positions)
  nn_energy = jit(nn_apply)(params, positions)

  self.assertAllClose(np.any(np.isnan(nn_energy)), False)
  self.assertAllClose(np.any(np.isnan(position_grad)), False)
  self.assertAllClose(position_grad.shape, [3, 3])
def test_pair_grid_force_incommensurate(self, spatial_dimension, dtype):
  """Compare `smap.grid` forces with direct forces for box 12.1, cell 3.0.

  The direct force is evaluated with an all-zero species array and a
  species count of 1. The gridded force is evaluated on positions only.
  """
  box_size = f32(12.1)
  cell_size = f32(3.0)
  rng = random.PRNGKey(1)

  displacement, _ = space.periodic(box_size)
  direct_force_fn = quantity.force(
      energy.soft_sphere_pair(displacement, quantity.Dynamic))

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(rng, shape, dtype=dtype)

  gridded_force_fn = jit(
      smap.grid(direct_force_fn, box_size, cell_size, positions))

  species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)
  expected = np.array(direct_force_fn(positions, species, 1), dtype=dtype)
  self.assertAllClose(expected, gridded_force_fn(positions), True)
def main(unused_argv):
  """Minimize a two-species soft-sphere system and print progress.

  500 particles are placed uniformly at random in a 2D periodic box of
  side 25. The system is relaxed with `minimize.fire_descent` for 50
  steps. Every 10th step prints the step number, the energy and the
  largest entry of the force array.
  """
  particle_count = 500
  dimension = 2
  box_size = f32(25.0)

  key = random.PRNGKey(0)

  # Periodic box helpers.
  displacement, shift = space.periodic(box_size)
  # NOTE(review): `metric` is built but not used below.
  metric = space.metric(displacement)

  # Uniform random starting positions inside the box.
  key, position_key = random.split(key)
  positions = random.uniform(
      position_key, (particle_count, dimension),
      minval=0.0, maxval=box_size, dtype=f32)

  # Equal numbers of two species with a species-pair sigma table.
  sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
  half = int(particle_count / 2)
  species = np.array([0] * half + [1] * half, dtype=i32)

  energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
  force_fn = quantity.force(energy_fn)

  init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
  opt_state = init_fn(positions)

  minimize_steps = 50
  print_every = 10

  print('Minimizing.')
  print('Step\tEnergy\tMax Force')
  print('-----------------------------------')

  for step in range(minimize_steps):
    opt_state = apply_fn(opt_state)
    if step % print_every != 0:
      continue
    positions = opt_state.position
    row = '{:.2f}\t{:.2f}\t{:.2f}'.format(
        step, energy_fn(positions), np.max(force_fn(positions)))
    print(row)
def test_morse_small_neighbor_list_energy(self, spatial_dimension, dtype):
  """Compare neighbor-list and all-pairs Morse energy for ten particles."""
  box_size = f32(5.0)
  rng = random.PRNGKey(1)
  displacement, _ = space.periodic(box_size)
  # NOTE(review): `metric` is built but never consumed by this test.
  metric = space.metric(displacement)

  reference_energy_fn = energy.morse_pair(displacement)

  positions = box_size * random.uniform(
      rng, (10, spatial_dimension), dtype=dtype)

  neighbor_fn, nl_energy_fn = energy.morse_neighbor_list(
      displacement, box_size)
  nbrs = neighbor_fn(positions)

  expected = np.array(reference_energy_fn(positions), dtype=dtype)
  self.assertAllClose(expected, nl_energy_fn(positions, nbrs))
def test_pair_cell_list_energy(self, spatial_dimension, dtype):
  """Compare `smap.cell_list` soft-sphere energy with the all-pairs energy.

  The cell-list energy wraps `smap.cartesian_product(energy.soft_sphere,
  metric)` using box size 9.0 and cell size 1.0.
  """
  box_size = f32(9.0)
  cell_size = f32(1.0)
  rng = random.PRNGKey(1)

  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  reference_energy_fn = energy.soft_sphere_pair(displacement)
  product_energy_fn = smap.cartesian_product(energy.soft_sphere, metric)

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(rng, shape, dtype=dtype)

  cell_energy_fn = smap.cell_list(
      product_energy_fn, box_size, cell_size, positions)

  expected = np.array(reference_energy_fn(positions), dtype=dtype)
  self.assertAllClose(expected, cell_energy_fn(positions), True)
def test_cell_list_incommensurate(self, spatial_dimension, dtype):
  """Compare cell-list and all-pairs energy for box 12.1 with cell 3.0.

  The cell-list energy is jitted and wraps
  `smap.cartesian_product(energy.soft_sphere, metric)`.
  """
  box_size = f32(12.1)
  cell_size = f32(3.0)
  rng = random.PRNGKey(1)

  displacement, _ = space.periodic(box_size)
  reference_energy_fn = energy.soft_sphere_pair(displacement)

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(rng, shape, dtype=dtype)

  product_energy_fn = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
  cell_energy_fn = jit(
      smap.cell_list(product_energy_fn, box_size, cell_size, positions))

  expected = np.array(reference_energy_fn(positions), dtype=dtype)
  self.assertAllClose(expected, cell_energy_fn(positions), True)
def test_lennard_jones_cell_list_force(self, spatial_dimension, dtype):
  """Compare cell-list Lennard-Jones forces with the all-pairs forces."""
  box_size = f32(15.0)
  rng = random.PRNGKey(1)
  displacement, _ = space.periodic(box_size)
  # NOTE(review): `metric` is built but never consumed by this test.
  metric = space.metric(displacement)

  reference_force_fn = quantity.force(energy.lennard_jones_pair(displacement))

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(rng, shape, dtype=dtype)

  cell_list_energy_fn = energy.lennard_jones_cell_list(
      displacement, box_size, positions)
  cell_list_force_fn = quantity.force(cell_list_energy_fn)

  expected = np.array(reference_force_fn(positions), dtype=dtype)
  self.assertAllClose(expected, cell_list_force_fn(positions), True)
def test_morse_neighbor_list_force(self, spatial_dimension, dtype):
  """Compare neighbor-list Morse forces with the all-pairs forces."""
  box_size = f32(15.0)
  rng = random.PRNGKey(1)
  displacement, _ = space.periodic(box_size)
  # NOTE(review): `metric` is built but never consumed by this test.
  metric = space.metric(displacement)

  reference_force_fn = quantity.force(energy.morse_pair(displacement))

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(rng, shape, dtype=dtype)

  neighbor_fn, nl_energy_fn = energy.morse_neighbor_list(
      displacement, box_size)
  nl_force_fn = quantity.force(nl_energy_fn)
  nbrs = neighbor_fn(positions)

  expected = np.array(reference_force_fn(positions), dtype=dtype)
  self.assertAllClose(expected, nl_force_fn(positions, nbrs))