Beispiel #1
0
    def test_pressure_non_minimized_periodic_general(self, dim, dtype, coords):
        """Compare `quantity.pressure` with an explicitly computed virial.

        Particles are placed uniformly at random (no energy minimization) in a
        `space.periodic_general` box sized for number density 0.8. The
        automatically differentiated pressure from `quantity.pressure` is
        checked against ``-trace(stress) / dim``. Here the stress is the
        explicit pairwise virial of the default soft-sphere potential.
        """
        key = random.PRNGKey(0)
        N = 64
        fractional = coords == 'fractional'

        box = quantity.box_size_at_number_density(N, 0.8, dim)
        displacement_fn, _ = space.periodic_general(box, fractional)

        # Draw positions in the requested precision so that `dtype` affects
        # the computation itself, not only the tolerance chosen below.
        pos = random.uniform(key, (N, dim), dtype=dtype)
        pos = pos if fractional else pos * box

        energy_fn = energy.soft_sphere_pair(displacement_fn)

        # Loop-invariant pieces of the reference stress computation.
        g = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
        V = quantity.volume(dim, box)

        def exact_stress(R):
            # All ordered pairs (i, j): an (N, N, dim) displacement array.
            dR = space.map_product(displacement_fn)(R, R)
            dr = space.distance(dR)
            # Each unordered pair appears twice ((i, j) and (j, i)), hence 0.5.
            dUdr = 0.5 * g(dr)[:, :, None, None]
            # Pad the zero self-distances on the diagonal so the division is
            # finite; those terms vanish anyway because dR[i, i] is zero.
            dr = (dr + jnp.eye(N))[:, :, None, None]
            outer = dR[:, :, None, :] * dR[:, :, :, None]
            return jnp.sum(dUdr * outer / (V * dr), axis=(0, 1))

        exact_pressure = -1 / dim * jnp.trace(exact_stress(pos))
        ad_pressure = quantity.pressure(energy_fn, pos, box)

        tol = 1e-7 if dtype is f64 else 2e-5

        self.assertAllClose(exact_pressure, ad_pressure, atol=tol, rtol=tol)
Beispiel #2
0
  def test_pressure_jammed(self, dtype, coords):
    """Check `quantity.pressure` against the pressure stored in a saved state.

    The state is loaded from 'simulation_test_state.npy' in the requested
    `dtype`. The space is `space.periodic_general` over `state.box`, using
    fractional coordinates when ``coords == 'fractional'``. Positions are read
    from the state's ``<coords>_position`` attribute.
    """
    # NOTE(review): test_pressure_jammed_periodic loads this same file via
    # test_util.load_jammed_state; confirm both loaders exist and agree.
    state = test_util.load_test_state('simulation_test_state.npy', dtype)
    displacement_fn, _ = space.periodic_general(state.box,
                                                coords == 'fractional')

    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
    pos = getattr(state, coords + '_position')

    # Precision-dependent tolerance, matching the other pressure tests.
    tol = 1e-7 if dtype is f64 else 2e-5

    self.assertAllClose(quantity.pressure(E, pos, state.box), state.pressure,
                        atol=tol, rtol=tol)
Beispiel #3
0
  def test_pressure_jammed_periodic(self, dtype):
    """Check `quantity.pressure` in a `space.periodic` box for a saved state.

    The pressure computed from the state's real-space positions is compared
    with the stored `state.pressure`, using a precision-dependent tolerance.
    """
    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    # NOTE(review): presumably state.box is a (diagonal) box matrix;
    # space.periodic is given only its diagonal as side lengths, while
    # quantity.pressure below receives the full state.box.
    displacement_fn, _ = space.periodic(jnp.diag(state.box))

    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
    pos = state.real_position

    tol = 1e-7 if dtype is f64 else 2e-5

    self.assertAllClose(quantity.pressure(E, pos, state.box), state.pressure,
                        atol=tol, rtol=tol)