Пример #1
0
    def test_soft_sphere(self, spatial_dimension, alpha, dtype):
        """Check boundary values of `energy.soft_sphere` on random parameters.

        Draws STOCHASTIC_SAMPLES random (sigma, epsilon) pairs. Sigma comes
        from `random.uniform(..., minval=0.0, maxval=3.0)` and epsilon from
        `random.uniform(..., minval=0.0, maxval=4.0)`; both are cast to
        `dtype`. For each pair it asserts that:

        * the energy at separation 0 equals ``epsilon / alpha``;
        * the energy at separation ``sigma`` is zero;
        * the gradient w.r.t. separation at ``2 * sigma`` is zero;
        * when ``alpha > 2``, the gradient at separation ``sigma`` is zero.

        ``spatial_dimension`` is not read by this body; presumably it is
        supplied by the test parameterization.
        """
        alpha = f32(alpha)
        # Gradient w.r.t. the first argument (the separation); built once.
        energy_grad = grad(energy.soft_sphere)
        zero = np.array(0.0, dtype=dtype)

        rng = random.PRNGKey(0)
        for _ in range(STOCHASTIC_SAMPLES):
            rng, sigma_rng, epsilon_rng = random.split(rng, 3)

            # Draw one scalar per parameter and cast it to the tested dtype.
            sigma_draw = random.uniform(sigma_rng, (1, ), minval=0.0, maxval=3.0)
            epsilon_draw = random.uniform(epsilon_rng, (1, ),
                                          minval=0.0, maxval=4.0)
            sigma = np.array(sigma_draw[0], dtype=dtype)
            epsilon = np.array(epsilon_draw[0], dtype=dtype)

            # Separation 0: expected energy is epsilon / alpha.
            energy_at_origin = energy.soft_sphere(dtype(0), sigma, epsilon,
                                                  alpha)
            self.assertAllClose(energy_at_origin, epsilon / alpha)

            # Separation sigma: expected energy is zero.
            energy_at_sigma = energy.soft_sphere(dtype(sigma), sigma, epsilon,
                                                 alpha)
            self.assertAllClose(energy_at_sigma, zero)

            # Separation 2 * sigma: expected gradient is zero.
            grad_at_two_sigma = energy_grad(dtype(2 * sigma), sigma, epsilon,
                                            alpha)
            self.assertAllClose(grad_at_two_sigma, zero)

            # Only for alpha > 2 is the gradient at separation sigma checked.
            if alpha > 2.0:
                grad_at_sigma = energy_grad(dtype(sigma), sigma, epsilon, alpha)
                self.assertAllClose(grad_at_sigma, zero)
Пример #2
0
  def test_swap_mc_jammed(self, dtype):
    """Run hybrid swap MC on a stored jammed packing and check temperatures.

    Loads ``simulation_test_state.npy`` with ``test_util.load_jammed_state``
    and builds a periodic space and a neighbor list. It then runs
    ``DYNAMICS_STEPS`` steps of ``simulate.hybrid_swap_mc`` inside
    ``lax.fori_loop``, recording ``quantity.temperature`` after each step.

    After discarding the first ``equilibration_steps`` recorded temperatures,
    it asserts that:

    * every remaining temperature is close to ``kT`` (loose tolerances);
    * their mean is close to ``kT`` (tight tolerance);
    * the final MC state's ``sigma`` is not identical to the initial
      per-particle ``sigma``. Presumably this means at least one swap was
      accepted.

    Args:
      dtype: dtype forwarded to ``test_util.load_jammed_state``.
    """
    key = random.PRNGKey(0)

    # Kept distinct from the MC `state` below so the two are not shadowed.
    jammed = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    box_size = jammed.box[0, 0]
    space_fn = space.periodic(box_size)
    displacement_fn, _ = space_fn

    # Per-particle diameters: the diagonal of `jammed.sigma` (presumably a
    # species-by-species matrix — confirm) indexed by each particle's species.
    sigma = np.diag(jammed.sigma)[jammed.species]

    energy_fn = lambda dr, sigma: energy.soft_sphere(dr, sigma=sigma)
    neighbor_fn = partition.neighbor_list(displacement_fn,
                                          box_size,
                                          np.max(sigma) + 0.1,
                                          dr_threshold=0.5)

    dt = 1e-3
    kT = 1e-2
    t_md = 0.1
    N_swap = 10
    init_fn, apply_fn = simulate.hybrid_swap_mc(space_fn,
                                                energy_fn,
                                                neighbor_fn,
                                                dt,
                                                kT,
                                                t_md,
                                                N_swap)
    state = init_fn(key, jammed.real_position, sigma)

    Ts = np.zeros((DYNAMICS_STEPS,))

    def step_fn(i, state_and_temp):
      # One MC step, then record the temperature of the MD sub-state at slot i.
      state, temp = state_and_temp
      state = apply_fn(state)
      temp = temp.at[i].set(quantity.temperature(state.md.velocity))
      return state, temp

    state, Ts = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Ts))

    # Leading steps excluded from the temperature checks (warm-up window).
    equilibration_steps = 10
    Ts_eq = Ts[equilibration_steps:]

    tol = 5e-4
    self.assertAllClose(Ts_eq,
                        kT * np.ones((DYNAMICS_STEPS - equilibration_steps,)),
                        rtol=5e-1,
                        atol=5e-3)
    self.assertAllClose(np.mean(Ts_eq), kT, rtol=tol, atol=tol)
    self.assertFalse(np.all(state.sigma == sigma))