def test_stress_non_minimized_periodic_general(self, dim, dtype, coords):
    """Check autodiff stress against a directly computed pair virial.

    Particles are placed uniformly at random, so the configuration is not
    energy-minimized. The box comes from ``quantity.box_size_at_number_density``
    at density 0.8 and is used with ``space.periodic_general``. The stress
    returned by ``quantity.stress`` is compared with one assembled by hand
    from the soft-sphere pair forces.

    Args:
      dim: Spatial dimension of the system.
      dtype: Floating-point dtype used to generate the particle positions.
        It also selects the comparison tolerance (tighter for ``f64``).
      coords: ``'fractional'`` keeps the uniform samples as fractional
        coordinates. Any other value scales them by ``box`` into real space.
    """
    key = random.PRNGKey(0)
    N = 64

    box = quantity.box_size_at_number_density(N, 0.8, dim)
    fractional = coords == 'fractional'
    displacement_fn, _ = space.periodic_general(box, fractional)

    # Generate positions in the requested dtype so the f32 and f64
    # parameterizations actually exercise different precisions.
    pos = random.uniform(key, (N, dim), dtype=dtype)
    pos = pos if fractional else pos * box

    energy_fn = energy.soft_sphere_pair(displacement_fn)

    def analytic_stress(R):
        # sigma_ab = -(1/V) * sum_{i,j} 0.5 * U'(r_ij) * dR_a * dR_b / r_ij
        dR = space.map_product(displacement_fn)(R, R)
        dr = space.distance(dR)
        dU = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
        V = quantity.volume(dim, box)
        # map_product over (R, R) visits each unordered pair twice (ij and ji).
        dUdr = 0.5 * dU(dr)[:, :, None, None]
        # Self-pairs have dR == 0. Adding the identity only keeps the
        # division finite on the diagonal, whose contribution vanishes anyway.
        dr = (dr + jnp.eye(N))[:, :, None, None]
        return -jnp.sum(
            dUdr * dR[:, :, None, :] * dR[:, :, :, None] / (V * dr),
            axis=(0, 1))

    exact_stress = analytic_stress(pos)
    ad_stress = quantity.stress(energy_fn, pos, box)

    tol = 1e-7 if dtype is f64 else 2e-5
    self.assertAllClose(exact_stress, ad_stress, atol=tol, rtol=tol)
def test_stress_lammps_periodic_general(self, dim, dtype):
    """Compare per-particle energy and stress against LAMMPS reference data.

    The box, positions and velocities, plus reference energy and stress, are
    loaded with ``test_util.load_lammps_stress_data(dtype)``. Presumably
    ``E`` is energy per particle, since it is compared with
    ``energy_fn(R) / len(R)``. ``C`` is presumably the reference stress
    tensor (TODO confirm against the loader).

    Args:
      dim: Not referenced in this body. It appears to come only from the
        shared test parameterization.
      dtype: Floating-point dtype requested from the data loader.
    """
    (box, R, V), (E, C) = test_util.load_lammps_stress_data(dtype)
    displacement_fn, _ = space.periodic_general(box)

    cutoff = 2.5

    def truncated_lj(dr, **kwargs):
        # Lennard-Jones truncated (not shifted) at the cutoff. Plain Python
        # literals adopt the dtype of ``dr``, whereas f32 constants would pin
        # float32 regardless of the ``dtype`` parameter.
        return jnp.where(dr < cutoff, energy.lennard_jones(dr), 0.0)

    energy_fn = smap.pair(
        truncated_lj,
        space.canonicalize_displacement_or_metric(displacement_fn))
    # Velocities are passed so the stress includes the kinetic contribution
    # (assumed to be how quantity.stress uses ``velocity`` -- TODO confirm).
    ad_stress = quantity.stress(energy_fn, R, box, velocity=V)

    tol = 5e-5
    self.assertAllClose(energy_fn(R) / len(R), E, atol=tol, rtol=tol)
    self.assertAllClose(C, ad_stress, atol=tol, rtol=tol)