def test_periodic_general_wrapped_vs_unwrapped(self, spatial_dimension, dtype):
    """Wrapped and unwrapped periodic_general shifts agree up to periodic images.

    For randomly perturbed (symmetric) boxes, particles are moved repeatedly
    with both the default shift and the ``wrapped=False`` shift. Displacements
    back to the starting positions must coincide after every step, and the
    unwrapped trajectory must end up outside the unit cell.
    """
    key = random.PRNGKey(0)
    identity = np.eye(spatial_dimension, dtype=dtype)

    for _ in range(STOCHASTIC_SAMPLES):
        key, key_positions, key_box = random.split(key, 3)

        box_noise = random.normal(
            key_box, (spatial_dimension, spatial_dimension), dtype=dtype)
        # Symmetric perturbation of the identity box.
        box = identity + box_noise + np.transpose(box_noise)

        start = random.uniform(
            key_positions, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        wrapped = start
        unwrapped = start

        displacement_fn, wrapped_shift = space.periodic_general(box)
        _, unwrapped_shift = space.periodic_general(box, wrapped=False)
        pair_displacement = space.map_product(displacement_fn)

        for _ in range(SHIFT_STEPS):
            key, key_step = random.split(key)
            step = random.normal(
                key_step, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
            wrapped = wrapped_shift(wrapped, step)
            unwrapped = unwrapped_shift(unwrapped, step)
            self.assertAllClose(pair_displacement(wrapped, start),
                                pair_displacement(unwrapped, start),
                                True)

        # At least one unwrapped coordinate must have left [0, 1].
        assert not (np.all(unwrapped > 0) and np.all(unwrapped < 1))
def test_periodic_against_periodic_general_grad(self, spatial_dimension, dtype):
    """Gradients through periodic and diagonal periodic_general spaces agree.

    The gradient of the summed squared pair displacements computed with a
    diagonal ``periodic_general`` box on positions ``R`` must match the one
    computed with ``space.periodic`` on ``R * box_size``.
    """
    key = random.PRNGKey(0)
    # Per-dtype tolerance for the gradient comparison below.
    tol = 1e-5 if dtype is f32 else 1e-13

    for _ in range(STOCHASTIC_SAMPLES):
        # Keep the 4-way split (last key unused) so the sampled boxes and
        # positions follow the same key sequence as before.
        key, split1, split2, _ = random.split(key, 4)

        max_box_size = f32(10.0)
        box_size = max_box_size * random.uniform(
            split1, (spatial_dimension,), dtype=dtype)
        transform = jnp.diag(box_size)

        R = random.uniform(
            split2, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        R_scaled = R * box_size

        disp_fn, _ = space.periodic(box_size)
        general_disp_fn, _ = space.periodic_general(transform)

        disp_fn = space.map_product(disp_fn)
        general_disp_fn = space.map_product(general_disp_fn)

        grad_fn = grad(lambda R: jnp.sum(disp_fn(R, R) ** 2))
        general_grad_fn = grad(lambda R: jnp.sum(general_disp_fn(R, R) ** 2))

        general_grad = general_grad_fn(R)
        self.assertAllClose(grad_fn(R_scaled), general_grad,
                            atol=tol, rtol=tol)
        assert general_grad.dtype == dtype
def test_periodic_against_periodic_general(self, spatial_dimension, dtype):
    """A diagonal periodic_general box reproduces space.periodic.

    Displacements and shifts from ``periodic_general(diag(box_size))`` on
    positions ``R`` must match ``periodic(box_size)`` on ``R * box_size``,
    and the ``periodic`` results must keep the dtype under test.
    """
    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
        key, split1, split2, split3 = random.split(key, 4)

        # f32 scalar, as in the gradient variant; 10.0 is exact in either
        # precision, so the sampled boxes and promoted dtype are unchanged.
        max_box_size = f32(10.0)
        box_size = max_box_size * random.uniform(
            split1, (spatial_dimension,), dtype=dtype)
        transform = np.diag(box_size)

        R = random.uniform(
            split2, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        R_scaled = R * box_size

        dR = random.normal(
            split3, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

        disp_fn, shift_fn = space.periodic(box_size)
        general_disp_fn, general_shift_fn = space.periodic_general(transform)

        disp_fn = space.map_product(disp_fn)
        general_disp_fn = space.map_product(general_disp_fn)

        self.assertAllClose(disp_fn(R_scaled, R_scaled),
                            general_disp_fn(R, R), True)
        assert disp_fn(R_scaled, R_scaled).dtype == dtype

        self.assertAllClose(shift_fn(R_scaled, dR),
                            general_shift_fn(R, dR) * box_size, True)
        assert shift_fn(R_scaled, dR).dtype == dtype
def test_nve_jammed_periodic_general(self, dtype, coords):
    """NVE dynamics on the stored jammed system conserves total energy.

    ``coords`` selects the stored ``<coords>_position`` attribute to simulate
    and, when equal to 'fractional', enables fractional coordinates in the
    periodic_general space.
    """
    key = random.PRNGKey(0)
    loaded = test_util.load_test_state('simulation_test_state.npy', dtype)
    use_fractional = coords == 'fractional'
    displacement_fn, shift_fn = space.periodic_general(loaded.box,
                                                       use_fractional)

    E = energy.soft_sphere_pair(displacement_fn, loaded.species, loaded.sigma)
    init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
    apply_fn = jit(apply_fn)

    state = init_fn(key, getattr(loaded, coords + '_position'), kT=1e-3)

    def total_energy(sim_state):
        potential = E(sim_state.position)
        kinetic = quantity.kinetic_energy(sim_state.velocity, sim_state.mass)
        return potential + kinetic

    E_initial = total_energy(state) * np.ones((DYNAMICS_STEPS, ))

    def step_fn(i, carry):
        sim_state, energies = carry
        sim_state = apply_fn(sim_state)
        energies = ops.index_update(energies, i, total_energy(sim_state))
        return sim_state, energies

    initial_carry = (state, np.zeros((DYNAMICS_STEPS, )))
    state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, initial_carry)

    tol = 1e-3 if dtype is f32 else 1e-7
    self.assertEqual(state.position.dtype, dtype)
    self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_eam_neighbor_list(self, num_repetitions, dtype, format):
    """Neighbor-list EAM energy per primitive cell matches the reference value."""
    if format is partition.OrderedSparse:
        self.skipTest('OrderedSparse neighbor lists not supported for EAM '
                      'potential.')

    unit_cell = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype)
    latvec = unit_cell * f32(4.05 / 2)
    basis = np.array([[0, 0, 0]], dtype=dtype)
    atoms_repeated, latvec_repeated = lattice_repeater(
        basis, latvec, num_repetitions)

    # Express positions in units of the supercell lattice vectors.
    inverse = onp.linalg.inv(onp.array(latvec_repeated))
    R = np.dot(atoms_repeated, np.array(inverse, dtype=dtype))

    displacement, _ = space.periodic_general(latvec_repeated)
    box_size = np.linalg.det(latvec_repeated) ** (1 / 3)
    neighbor_fn, energy_fn = energy.eam_neighbor_list(
        displacement, box_size, *make_eam_test_splines(), format=format)

    nbrs = neighbor_fn.allocate(R)
    E_per_cell = energy_fn(R, nbrs) / num_repetitions ** 3

    expected = dtype(-3.3633387837793505)
    if dtype is f64:
        self.assertAllClose(E_per_cell, expected, atol=1e-8, rtol=1e-8)
    else:
        self.assertAllClose(E_per_cell, expected)
def test_stress_non_minimized_periodic_general(self, dim, dtype, coords):
    """Autodiff stress matches an explicit pair-virial sum for random positions.

    ``coords`` selects whether positions are fractional ('fractional') or
    scaled by the box into real space.
    """
    key = random.PRNGKey(0)
    N = 64

    box = quantity.box_size_at_number_density(N, 0.8, dim)
    displacement_fn, _ = space.periodic_general(box, coords == 'fractional')

    # Sample in the dtype under test; otherwise ``dtype`` would only pick
    # the tolerance and never the precision being exercised.
    pos = random.uniform(key, (N, dim), dtype=dtype)
    pos = pos if coords == 'fractional' else pos * box

    energy_fn = energy.soft_sphere_pair(displacement_fn)

    def exact_stress(R):
        dR = space.map_product(displacement_fn)(R, R)
        dr = space.distance(dR)
        g = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
        V = quantity.volume(dim, box)
        # Each unordered pair appears twice in the N x N map_product sum.
        dUdr = 0.5 * g(dr)[:, :, None, None]
        # Add the identity so self-pairs (dr == 0) do not divide by zero;
        # those terms vanish anyway because dR is zero on the diagonal.
        dr = (dr + jnp.eye(N))[:, :, None, None]
        return -jnp.sum(dUdr * dR[:, :, None, :] * dR[:, :, :, None] / (V * dr),
                        axis=(0, 1))

    expected_stress = exact_stress(pos)
    ad_stress = quantity.stress(energy_fn, pos, box)

    tol = 1e-7 if dtype is f64 else 2e-5
    self.assertAllClose(expected_stress, ad_stress, atol=tol, rtol=tol)
def test_periodic_general_time_dependence(self, spatial_dimension, dtype):
    """A time-dependent box evaluated at ``t`` matches a static box ``T(t)``.

    ``T(t)`` linearly interpolates between two random boxes. Displacements
    and shifts from ``periodic_general(T)`` called with ``t=t_g`` must equal
    those from ``periodic_general(T(t_g))``.
    """
    key = random.PRNGKey(0)
    eye = np.eye(spatial_dimension, dtype=dtype)
    shape = (spatial_dimension, spatial_dimension)

    for _ in range(STOCHASTIC_SAMPLES):
        key, split_T0_scale, split_T0_dT = random.split(key, 3)
        key, split_T1_scale, split_T1_dT = random.split(key, 3)
        key, split_t, split_R, split_dR = random.split(key, 4)

        # Both boxes are drawn in the dtype under test (previously only T_1
        # was; T_0 was drawn at default precision and cast afterwards).
        size_0 = 10.0 * random.uniform(split_T0_scale, (), dtype=dtype)
        dtransform_0 = 0.5 * random.normal(split_T0_dT, shape, dtype=dtype)
        T_0 = np.array(size_0 * (eye + dtransform_0), dtype=dtype)

        size_1 = 10.0 * random.uniform(split_T1_scale, (), dtype=dtype)
        dtransform_1 = 0.5 * random.normal(split_T1_dT, shape, dtype=dtype)
        T_1 = np.array(size_1 * (eye + dtransform_1), dtype=dtype)

        T = lambda t: t * T_0 + (f32(1.0) - t) * T_1

        t_g = random.uniform(split_t, (), dtype=dtype)

        disp_fn, shift_fn = space.periodic_general(T)
        true_disp_fn, true_shift_fn = space.periodic_general(T(t_g))

        disp_fn = partial(disp_fn, t=t_g)

        disp_fn = space.map_product(disp_fn)
        true_disp_fn = space.map_product(true_disp_fn)

        R = random.uniform(
            split_R, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        dR = random.normal(
            split_dR, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

        self.assertAllClose(disp_fn(R, R),
                            np.array(true_disp_fn(R, R), dtype=dtype), True)
        self.assertAllClose(shift_fn(R, dR, t=t_g),
                            np.array(true_shift_fn(R, dR), dtype=dtype), True)
def test_stillinger_weber(self, dtype, num_repetitions):
    """Stillinger-Weber energy per atom of a two-atom-basis supercell.

    The reference value -4.336503 is compared per atom after tiling the
    basis ``num_repetitions`` times along each lattice vector.
    """
    lattice_vectors = np.array([[0, .5, .5],
                                [.5, 0, .5],
                                [.5, .5, 0]]) * 5.428
    positions = np.array([[0, 0, 0], [0.25, 0.25, 0.25]])
    positions = lattice(positions, num_repetitions, lattice_vectors)
    lattice_vectors *= num_repetitions

    displacement, _ = space.periodic_general(lattice_vectors)
    energy_fn = jit(energy.stillinger_weber_energy(displacement))

    # NOTE(review): ``dtype`` is not applied to the inputs, so every
    # parameterization runs at the default precision — confirm intended.
    self.assertAllClose(
        energy_fn(positions) / positions.shape[0], -4.336503)
def test_pressure_jammed(self, dtype, coords):
    """quantity.pressure on the stored jammed system reproduces its pressure.

    ``coords`` selects the stored ``<coords>_position`` attribute and, when
    equal to 'fractional', enables fractional coordinates in the space.
    """
    state = test_util.load_test_state('simulation_test_state.npy', dtype)
    displacement_fn, _ = space.periodic_general(state.box,
                                                coords == 'fractional')

    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
    pos = getattr(state, coords + '_position')

    self.assertAllClose(quantity.pressure(E, pos, state.box), state.pressure)
def test_neighbor_list_build_time_dependent(self, dtype, dim):
    """A neighbor list built for a time-dependent box finds all close pairs.

    The list is built at ``t=0.25`` and, row by row, the in-cutoff distances
    it yields must equal the in-cutoff distances from the full pairwise
    metric at the same time.
    """
    key = random.PRNGKey(1)

    if dim == 2:
        def box_fn(t):
            return np.array([[9.0, t], [0.0, 3.75]], f32)
    elif dim == 3:
        def box_fn(t):
            return np.array([[9.0, 0.0, t], [0.0, 4.0, 0.0], [0.0, 0.0, 7.25]])

    min_length = np.min(np.diag(box_fn(0.)))
    cutoff = f32(1.23)

    # TODO(schsam): Get cell-list working with anisotropic cell sizes.
    cell_size = cutoff / min_length

    displacement, _ = space.periodic_general(box_fn)
    metric = space.metric(displacement)

    R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
    N = R.shape[0]

    neighbor_list_fn = partition.neighbor_list(
        metric, 1., cutoff, 0.0, 1.1, cell_size=cell_size, t=np.array(0.))
    idx = neighbor_list_fn(R, t=np.array(0.25)).idx
    R_neigh = R[idx]
    # Index entries >= N are treated as empty slots and masked out.
    mask = idx < N

    metric_at_t = partial(metric, t=f32(0.25))
    neighbor_metric = vmap(vmap(metric_at_t, (None, 0)))
    full_metric = space.map_product(metric_at_t)

    dR = neighbor_metric(R, R_neigh)
    dR_exact = full_metric(R, R)

    dR = np.sort(np.where(dR < cutoff, dR, 0) * mask, axis=1)
    dR_exact = np.sort(np.where(dR_exact < cutoff, dR_exact, 0), axis=1)

    for row, exact_row in zip(dR, dR_exact):
        self.assertAllClose(row[row > 0.], exact_row[exact_row > 0.])
def test_periodic_general_dynamic(self, spatial_dimension, dtype):
    """Overriding ``box=`` at call time matches a space built on that box.

    A space built on ``T_0`` but called with ``box=T_1`` must produce the
    same displacements and shifts as ``periodic_general(T_1)``.
    """
    key = random.PRNGKey(0)
    eye = jnp.eye(spatial_dimension, dtype=dtype)
    shape = (spatial_dimension, spatial_dimension)

    for _ in range(STOCHASTIC_SAMPLES):
        key, split_T0_scale, split_T0_dT = random.split(key, 3)
        key, split_T1_scale, split_T1_dT = random.split(key, 3)
        key, split_t, split_R, split_dR = random.split(key, 4)

        # Both boxes are drawn in the dtype under test (previously only T_1
        # was; T_0 was drawn at default precision and cast afterwards).
        size_0 = 10.0 * random.uniform(split_T0_scale, (), dtype=dtype)
        dtransform_0 = 0.5 * random.normal(split_T0_dT, shape, dtype=dtype)
        T_0 = jnp.array(size_0 * (eye + dtransform_0), dtype=dtype)

        size_1 = 10.0 * random.uniform(split_T1_scale, (), dtype=dtype)
        dtransform_1 = 0.5 * random.normal(split_T1_dT, shape, dtype=dtype)
        T_1 = jnp.array(size_1 * (eye + dtransform_1), dtype=dtype)

        disp_fn, shift_fn = space.periodic_general(T_0)
        true_disp_fn, true_shift_fn = space.periodic_general(T_1)

        disp_fn = partial(disp_fn, box=T_1)

        disp_fn = space.map_product(disp_fn)
        true_disp_fn = space.map_product(true_disp_fn)

        R = random.uniform(
            split_R, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        dR = random.normal(
            split_dR, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

        self.assertAllClose(disp_fn(R, R),
                            jnp.array(true_disp_fn(R, R), dtype=dtype))
        self.assertAllClose(shift_fn(R, dR, box=T_1),
                            jnp.array(true_shift_fn(R, dR), dtype=dtype))
def make_periodic_general_test_system(N, dim, dtype, box_format):
    """Build one soft-sphere system expressed through three periodic spaces.

    Args:
      N: number of particles.
      dim: spatial dimension.
      dtype: dtype for the box and the sampled positions.
      box_format: one of ``BOX_FORMATS``; 'vector' and 'matrix' build a
        per-axis vector or a diagonal matrix, anything else a scalar box.

    Returns:
      ``(R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g))`` where ``R_f`` are
      uniform samples, ``R = space.transform(box, R_f)``, and each pair is a
      shift function plus a jitted soft-sphere energy built from
      ``space.periodic``, ``space.periodic_general(box)`` and
      ``space.periodic_general(box, fractional_coordinates=False)``.
    """
    assert box_format in BOX_FORMATS

    side = quantity.box_size_at_number_density(N, 1.0, dim)
    box = dtype(side)
    if box_format == 'vector':
        box = jnp.array(jnp.ones(dim) * side, dtype)
    elif box_format == 'matrix':
        box = jnp.array(jnp.eye(dim) * side, dtype)

    # space.periodic takes per-axis side lengths, so unpack a matrix box.
    periodic_box = jnp.diag(box) if box_format == 'matrix' else box
    d, s = space.periodic(periodic_box)
    d_gf, s_gf = space.periodic_general(box)
    d_g, s_g = space.periodic_general(box, fractional_coordinates=False)

    R_f = random.uniform(random.PRNGKey(0), (N, dim), dtype=dtype)
    R = space.transform(box, R_f)

    E, E_gf, E_g = (jit(energy.soft_sphere_pair(disp))
                    for disp in (d, d_gf, d_g))

    return R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g)
def test_eam(self, num_repetitions, dtype):
    """EAM energy per primitive cell of a tiled fcc lattice matches reference."""
    latvec = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]],
                      dtype=dtype) * f32(4.05 / 2)
    atoms = np.array([[0, 0, 0]], dtype=dtype)
    atoms_repeated, latvec_repeated = lattice_repeater(
        atoms, latvec, num_repetitions)

    # Cast the inverse explicitly so the positions, and therefore the energy
    # compared with check_dtypes=True below, are in the dtype under test.
    inv_latvec = np.array(onp.linalg.inv(onp.array(latvec_repeated)),
                          dtype=dtype)
    displacement, _ = space.periodic_general(latvec_repeated)

    # NOTE(review): charge_fn / embedding_fn / pairwise_fn are not defined in
    # this function; presumably module-level spline fixtures — confirm.
    assert charge_fn(dtype(1.0)).dtype == dtype
    assert embedding_fn(dtype(1.0)).dtype == dtype
    assert pairwise_fn(dtype(1.0)).dtype == dtype

    eam_energy = energy.eam(displacement, charge_fn, embedding_fn, pairwise_fn)
    self.assertAllClose(
        eam_energy(np.dot(atoms_repeated, inv_latvec)) /
        f32(num_repetitions**3),
        dtype(-3.363338), True)
def test_canonicalize_displacement_or_metric(self, spatial_dimension, dtype):
    """Canonicalizing a displacement gives the same pairwise metric as space.metric."""
    displacement, _ = space.periodic_general(np.eye(spatial_dimension))

    expected_fn = space.map_product(space.metric(displacement))
    canonical_fn = space.map_product(
        space.canonicalize_displacement_or_metric(displacement))

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
        key, sample_key, _ = random.split(key, 3)
        R = random.normal(
            sample_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        self.assertAllClose(expected_fn(R, R), canonical_fn(R, R), True)
def test_stress_lammps_periodic_general(self, dim, dtype):
    """Energy and autodiff stress match LAMMPS reference data.

    Uses a Lennard-Jones pair potential truncated at r = 2.5 (no shift) and
    includes the kinetic contribution via the stored velocities.
    """
    (box, R, V), (E, C) = test_util.load_lammps_stress_data(dtype)

    displacement_fn, _ = space.periodic_general(box)
    energy_fn = smap.pair(
        lambda dr, **kwargs: jnp.where(dr < f32(2.5),
                                       energy.lennard_jones(dr),
                                       f32(0.0)),
        space.canonicalize_displacement_or_metric(displacement_fn))

    ad_stress = quantity.stress(energy_fn, R, box, velocity=V)

    tol = 5e-5
    self.assertAllClose(energy_fn(R) / len(R), E, atol=tol, rtol=tol)
    self.assertAllClose(C, ad_stress, atol=tol, rtol=tol)
def test_EMT_from_db_nbrlist(self, spatial_dimension, dtype, low_pressure):
    """Athermal elastic moduli computed with neighbor lists match stored data.

    For each stored configuration: the gradient must be below a threshold,
    ``C`` must have the expected dtype and rank-4 shape, and when the
    minimization converged ``C`` must match the reference elements and
    satisfy the checked index symmetries.
    """
    if dtype == jnp.float32:
        max_grad_thresh = 1e-5
        atol = 1e-4
        rtol = 1e-3
    else:
        max_grad_thresh = 1e-10
        atol = 1e-8
        rtol = 1e-5

    for index in range(NUM_SAMPLES):
        cijkl, R, sigma, box = test_util.load_elasticity_test_data(
            spatial_dimension, low_pressure, dtype, index)

        displacement, _ = space.periodic_general(
            box, fractional_coordinates=True)
        neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
            displacement, box, sigma=sigma, fractional_coordinates=True)
        nbrs = neighbor_fn.allocate(R)

        # Gradients at the stored configurations must be (near) zero.
        assert jnp.max(jnp.abs(grad(energy_fn)(R, nbrs))) < max_grad_thresh

        EMT_fn = jit(
            elasticity.athermal_moduli(energy_fn, check_convergence=True))
        C, converged = EMT_fn(R, box, neighbor=nbrs)

        assert C.dtype == dtype
        assert C.shape == (spatial_dimension, spatial_dimension,
                           spatial_dimension, spatial_dimension)

        if converged:
            self.assertAllClose(cijkl, elasticity._extract_elements(C, False),
                                atol=atol, rtol=rtol)

            # Make sure the symmetries are there.
            self.assertAllClose(C, jnp.einsum("ijkl->jikl", C))
            self.assertAllClose(C, jnp.einsum("ijkl->ijlk", C))
            self.assertAllClose(C, jnp.einsum("ijkl->lkij", C))
def test_eam(self, num_repetitions, dtype):
    """EAM energy per primitive cell of a tiled fcc lattice matches reference.

    Also checks that each spline returned by ``make_eam_test_splines``
    preserves the dtype under test.
    """
    scale = f32(4.05 / 2)
    latvec = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) * scale
    atoms = np.array([[0, 0, 0]], dtype=dtype)
    atoms_repeated, latvec_repeated = lattice_repeater(
        atoms, latvec, num_repetitions)
    inv_latvec = np.array(onp.linalg.inv(onp.array(latvec_repeated)),
                          dtype=dtype)
    displacement, _ = space.periodic_general(latvec_repeated)

    charge_fn, embedding_fn, pairwise_fn, _ = make_eam_test_splines()
    probe = np.array(1.0, dtype)
    for spline in (charge_fn, embedding_fn, pairwise_fn):
        assert spline(probe).dtype == dtype

    eam_energy = energy.eam(displacement, charge_fn, embedding_fn, pairwise_fn)
    fractional = np.dot(atoms_repeated, inv_latvec)
    E = eam_energy(fractional) / num_repetitions ** 3

    reference = dtype(-3.3633387837793505)
    if dtype is not f64:
        self.assertAllClose(E, reference)
        return
    self.assertAllClose(E, reference, atol=1e-8, rtol=1e-8)
def test_stillinger_weber_neighbor_list(self, dtype, num_repetitions, format):
    """Neighbor-list Stillinger-Weber energy per atom matches the reference."""
    unsupported = [partition.OrderedSparse, partition.Sparse]
    if format in unsupported:
        self.skipTest(f'{format} not supported for Stillinger-Weber.')

    lattice_vectors = np.array([[0, .5, .5],
                                [.5, 0, .5],
                                [.5, .5, 0]]) * 5.428
    basis = np.array([[0, 0, 0], [0.25, 0.25, 0.25]])
    positions = lattice(basis, num_repetitions, lattice_vectors)
    lattice_vectors *= num_repetitions

    displacement, _ = space.periodic_general(lattice_vectors)
    # NOTE(review): lattice_vectors is already scaled by num_repetitions, so
    # the extra factor makes box_size grow as num_repetitions**2 — confirm
    # this is intended for the neighbor-list sizing.
    box_size = np.linalg.det(lattice_vectors) ** (1 / 3) * num_repetitions

    neighbor_fn, energy_fn = energy.stillinger_weber_neighbor_list(
        displacement, box_size, fractional_coordinates=True, format=format)
    nbrs = neighbor_fn.allocate(positions)

    per_atom = energy_fn(positions, neighbor=nbrs) / positions.shape[0]
    self.assertAllClose(per_atom, -4.336503155764325)
def test_npt_nose_hoover_jammed(self, dtype, sy_steps):
    """NPT Nose-Hoover on the jammed system conserves its invariant and pressure.

    ``sy_steps`` is forwarded as the ``sy_steps`` option to both the barostat
    and the thermostat chain kwargs.
    """
    key = random.PRNGKey(0)
    state = test_util.load_test_state('simulation_test_state.npy', dtype)
    displacement_fn, shift_fn = space.periodic_general(state.box)

    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
    invariant = partial(simulate.npt_nose_hoover_invariant, E)
    pressure_fn = partial(quantity.pressure, E)

    # Key by the option *name*. ``{sy_steps: sy_steps}`` uses the parameter's
    # int value as the key, which cannot name a keyword option, so the
    # sy_steps parameterization would not be applied.
    nhc_kwargs = {'sy_steps': sy_steps}
    kT = 1e-3
    P = state.pressure

    init_fn, apply_fn = simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT,
                                                 nhc_kwargs, nhc_kwargs)
    apply_fn = jit(apply_fn)
    state = init_fn(key, state.fractional_position, state.box)

    E_initial = invariant(state, P, kT) * np.ones((DYNAMICS_STEPS, ))
    P_target = P * np.ones((DYNAMICS_STEPS, ))

    def step_fn(i, carry):
        state, energies, pressures = carry
        state = apply_fn(state)
        energies = ops.index_update(energies, i, invariant(state, P, kT))

        box = simulate.npt_box(state)
        KE = quantity.kinetic_energy(state.velocity, state.mass)
        pressures = ops.index_update(pressures, i,
                                     pressure_fn(state.position, box, KE))
        return state, energies, pressures

    Es = np.zeros((DYNAMICS_STEPS, ))
    Ps = np.zeros((DYNAMICS_STEPS, ))
    state, Es, Ps = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es, Ps))

    tol = 1e-3 if dtype is f32 else 1e-7
    self.assertEqual(state.position.dtype, dtype)
    self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
    self.assertAllClose(Ps, P_target, rtol=0.05, atol=0.05)
def test_eam(self, num_repetitions, dtype):
    """EAM energy per primitive cell of a tiled fcc lattice matches reference.

    Also checks that each spline returned by ``make_eam_test_splines``
    preserves the dtype under test.
    """
    scale = f32(4.05 / 2)
    latvec = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) * scale
    atoms = np.array([[0, 0, 0]], dtype=dtype)
    atoms_repeated, latvec_repeated = lattice_repeater(
        atoms, latvec, num_repetitions)
    inv_latvec = np.array(onp.linalg.inv(onp.array(latvec_repeated)),
                          dtype=dtype)
    displacement, _ = space.periodic_general(latvec_repeated)

    charge_fn, embedding_fn, pairwise_fn = make_eam_test_splines()
    probe = np.array(1.0, dtype)
    for spline in (charge_fn, embedding_fn, pairwise_fn):
        assert spline(probe).dtype == dtype

    eam_energy = energy.eam(displacement, charge_fn, embedding_fn, pairwise_fn)

    tol = 1e-5 if dtype == np.float32 else 1e-6
    cell_count = np.array(num_repetitions**3, dtype)
    per_cell = eam_energy(np.dot(atoms_repeated, inv_latvec)) / cell_count
    self.assertAllClose(per_cell, dtype(-3.363338), True, tol, tol)
def get_displacement(self, atoms: Atoms):
    """Return the (displacement, shift) space suited to ``atoms``.

    Periodic along every axis: a real-space (non-fractional)
    ``periodic_general`` space over ``self.box``. Any non-periodic axis:
    ``space.free()``.
    """
    if all(atoms.get_pbc()):
        return space.periodic_general(self.box, fractional_coordinates=False)
    return space.free()