Esempio n. 1
0
    def test_stress_non_minimized_periodic_general(self, dim, dtype, coords):
        """Check `quantity.stress` against a hand-summed pairwise stress.

        A soft-sphere system is placed at uniformly random (i.e. not
        energy-minimized) positions inside a `space.periodic_general` box,
        and the stress returned by `quantity.stress` is compared with an
        explicit double sum over all particle pairs.

        Args:
            dim: Spatial dimension of the system.
            dtype: Only used to select the comparison tolerance.
            coords: `'fractional'` keeps positions in the unit cube; any
                other value scales them by the box size.
        """
        particle_count = 64
        rng = random.PRNGKey(0)

        box = quantity.box_size_at_number_density(particle_count, 0.8, dim)
        is_fractional = coords == 'fractional'
        displacement_fn, _ = space.periodic_general(box, is_fractional)

        pos = random.uniform(rng, (particle_count, dim))
        if not is_fractional:
            pos = pos * box

        energy_fn = energy.soft_sphere_pair(displacement_fn)

        def reference_stress(R):
            # All-pairs displacement vectors and their lengths.
            pair_dR = space.map_product(displacement_fn)(R, R)
            pair_dist = space.distance(pair_dR)
            # Elementwise derivative of the pair potential w.r.t. distance.
            dU_dr_fn = jnp.vectorize(grad(energy.soft_sphere),
                                     signature='()->()')
            volume = quantity.volume(dim, box)
            # The factor 0.5 compensates for visiting every pair twice.
            half_dU_dr = 0.5 * dU_dr_fn(pair_dist)[:, :, None, None]
            # Adding the identity keeps the diagonal (self-pair) entries of
            # the denominator away from zero.
            safe_dist = (pair_dist +
                         jnp.eye(particle_count))[:, :, None, None]
            outer = (half_dU_dr * pair_dR[:, :, None, :] *
                     pair_dR[:, :, :, None])
            return -jnp.sum(outer / (volume * safe_dist), axis=(0, 1))

        expected = reference_stress(pos)
        actual = quantity.stress(energy_fn, pos, box)

        if dtype is f64:
            tol = 1e-7
        else:
            tol = 2e-5

        self.assertAllClose(expected, actual, atol=tol, rtol=tol)
Esempio n. 2
0
    def test_pressure_non_minimized_free(self, dim, dtype):
        """Check `quantity.pressure` against a hand-summed pairwise value.

        A soft-sphere system is placed at uniformly random (i.e. not
        energy-minimized) positions using the `space.free` displacement.
        The reference pressure is `-1 / dim` times the trace of an explicit
        double sum over all particle pairs.

        Args:
            dim: Spatial dimension of the system.
            dtype: Only used to select the comparison tolerance.
        """
        particle_count = 64
        rng = random.PRNGKey(0)

        box = quantity.box_size_at_number_density(particle_count, 0.8, dim)
        displacement_fn, _ = space.free()

        pos = random.uniform(rng, (particle_count, dim)) * box

        energy_fn = energy.soft_sphere_pair(displacement_fn)

        def reference_stress(R):
            # All-pairs displacement vectors and their lengths.
            pair_dR = space.map_product(displacement_fn)(R, R)
            pair_dist = space.distance(pair_dR)
            # Elementwise derivative of the pair potential w.r.t. distance.
            dU_dr_fn = jnp.vectorize(grad(energy.soft_sphere),
                                     signature='()->()')
            volume = quantity.volume(dim, box)
            # The factor 0.5 compensates for visiting every pair twice.
            half_dU_dr = 0.5 * dU_dr_fn(pair_dist)[:, :, None, None]
            # Adding the identity keeps the diagonal (self-pair) entries of
            # the denominator away from zero.
            safe_dist = (pair_dist +
                         jnp.eye(particle_count))[:, :, None, None]
            outer = (half_dU_dr * pair_dR[:, :, None, :] *
                     pair_dR[:, :, :, None])
            return jnp.sum(outer / (volume * safe_dist), axis=(0, 1))

        expected = -1 / dim * jnp.trace(reference_stress(pos))
        actual = quantity.pressure(energy_fn, pos, box)

        if dtype is f64:
            tol = 1e-7
        else:
            tol = 2e-5

        self.assertAllClose(expected, actual, atol=tol, rtol=tol)
Esempio n. 3
0
    def test_nvt_nose_hoover(self, spatial_dimension, dtype, sy_steps):
        """Run Nose-Hoover NVT dynamics and check its conserved quantity.

        For each of `STOCHASTIC_SAMPLES` random draws of positions, target
        temperature and masses, the system is advanced for `DYNAMICS_STEPS`
        steps and three things are verified:

        * the final temperature is within 10% of the target temperature,
        * `simulate.nvt_nose_hoover_invariant` is unchanged (to a
          dtype-dependent relative tolerance) between the initial and the
          final state,
        * positions keep the requested dtype.

        Args:
            spatial_dimension: Spatial dimension of the system.
            dtype: Floating point type used for positions, temperature and
                masses.
            sy_steps: Forwarded to `simulate.nvt_nose_hoover` as `sy_steps`.
        """
        key = random.PRNGKey(0)

        box_size = quantity.box_size_at_number_density(PARTICLE_COUNT,
                                                       f32(1.2),
                                                       spatial_dimension)
        displacement_fn, shift_fn = space.periodic(box_size)

        # Pair every particle index with its rolled-by-one neighbour, which
        # connects the particles in a closed ring of spring bonds.
        bonds_i = np.arange(PARTICLE_COUNT)
        bonds_j = np.roll(bonds_i, 1)
        bonds = np.stack([bonds_i, bonds_j])

        E = energy.simple_spring_bond(displacement_fn, bonds)

        invariant = partial(simulate.nvt_nose_hoover_invariant, E)

        for _ in range(STOCHASTIC_SAMPLES):
            key, pos_key, vel_key, T_key, masses_key = random.split(key, 5)

            R = box_size * random.uniform(pos_key,
                                          (PARTICLE_COUNT, spatial_dimension),
                                          dtype=dtype)
            T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
            mass = 1 + random.uniform(masses_key, (PARTICLE_COUNT, ),
                                      dtype=dtype)
            # The integrator depends on the sampled temperature, so it has to
            # be rebuilt (and re-jitted) for every sample.
            init_fn, apply_fn = simulate.nvt_nose_hoover(E,
                                                         shift_fn,
                                                         1e-3,
                                                         T,
                                                         sy_steps=sy_steps)
            apply_fn = jit(apply_fn)

            state = init_fn(vel_key, R, mass=mass)

            initial = invariant(state, T)

            for _ in range(DYNAMICS_STEPS):
                state = apply_fn(state)

            T_final = quantity.temperature(state.velocity, state.mass)
            # Use a TestCase assertion rather than a bare `assert`: it is not
            # stripped under `python -O` and reports both operands on failure.
            self.assertLess(np.abs(T_final - T) / T, 0.1)
            tol = 5e-4 if dtype is f32 else 1e-6
            self.assertAllClose(invariant(state, T), initial, rtol=tol)
            self.assertEqual(state.position.dtype, dtype)
Esempio n. 4
0
def make_periodic_general_test_system(N, dim, dtype, box_format):
  """Build equivalent soft-sphere systems in three periodic spaces.

  The same box is handed to `space.periodic`, to `space.periodic_general`
  with its default settings, and to `space.periodic_general` with
  `fractional_coordinates=False`.

  Args:
    N: Number of particles.
    dim: Spatial dimension.
    dtype: Floating point type of the box and of the positions.
    box_format: Must be a member of `BOX_FORMATS`. `'vector'` yields a
      length-`dim` array, `'matrix'` a diagonal `dim x dim` array, and any
      other accepted value a scalar.

  Returns:
    A tuple `(R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g))`: the uniformly
    sampled positions `R_f`, those positions passed through
    `space.transform(box, R_f)`, the box, and one `(shift_fn, energy_fn)`
    pair per space. Every energy function is jit-compiled.
  """
  assert box_format in BOX_FORMATS

  box_size = quantity.box_size_at_number_density(N, 1.0, dim)
  if box_format == 'vector':
    box = jnp.array(jnp.ones(dim) * box_size, dtype)
  elif box_format == 'matrix':
    box = jnp.array(jnp.eye(dim) * box_size, dtype)
  else:
    box = dtype(box_size)

  # `space.periodic` is given the diagonal of a matrix-shaped box rather than
  # the matrix itself.
  if box_format == 'matrix':
    periodic_side = jnp.diag(box)
  else:
    periodic_side = box

  d, s = space.periodic(periodic_side)
  d_gf, s_gf = space.periodic_general(box)
  d_g, s_g = space.periodic_general(box, fractional_coordinates=False)

  rng = random.PRNGKey(0)
  R_f = random.uniform(rng, (N, dim), dtype=dtype)
  R = space.transform(box, R_f)

  E = jit(energy.soft_sphere_pair(d))
  E_gf = jit(energy.soft_sphere_pair(d_gf))
  E_g = jit(energy.soft_sphere_pair(d_g))

  periodic_system = (s, E)
  general_default_system = (s_gf, E_gf)
  general_real_system = (s_g, E_g)

  return (R_f, R, box, periodic_system, general_default_system,
          general_real_system)