Exemple #1
0
    def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype,
                                               format):
        """Neighbor-list Lennard-Jones forces match the all-pairs forces."""
        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)
        reference_force = quantity.force(
            energy.lennard_jones_pair(displacement))

        positions = box_size * random.uniform(
            random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension),
            dtype=dtype)
        neighbor_fn, nbr_energy = energy.lennard_jones_neighbor_list(
            displacement, box_size, format=format)
        nbr_force = quantity.force(nbr_energy)

        nbrs = neighbor_fn.allocate(positions)
        expected = np.array(reference_force(positions), dtype=dtype)
        actual = nbr_force(positions, nbrs)

        # Single-precision OrderedSparse is compared with an explicit,
        # looser tolerance; every other combination uses the default.
        if dtype == f32 and format is partition.OrderedSparse:
            self.assertAllClose(expected, actual, atol=5e-5, rtol=5e-5)
        else:
            self.assertAllClose(expected, actual)
Exemple #2
0
    def test_pair_correlation_neighbor_list_species(self, dim, dtype, format):
        """Species-resolved g(r) agrees between dense and neighbor-list forms."""
        if format is partition.OrderedSparse:
            self.skipTest('OrderedSparse not supported for pair correlation '
                          'function.')

        particle_count = 100
        box_length = 10.
        displacement, _ = space.periodic(box_length)
        # NOTE(review): positions are drawn in the unit cube rather than
        # scaled by `box_length`; presumably deliberate — confirm.
        positions = random.uniform(random.PRNGKey(0), (particle_count, dim),
                                   dtype=dtype)
        # First half of the particles is species 0, second half species 1.
        species = np.where(np.arange(particle_count) < particle_count // 2,
                           0, 1)
        radii = np.linspace(0, 2, 60, dtype=dtype)
        dense_g = quantity.pair_correlation(displacement, radii, f32(0.1),
                                            species)
        nbr_fn, nbr_g = quantity.pair_correlation_neighbor_list(
            displacement, box_length, radii, f32(0.1), species, format=format)
        nbrs = nbr_fn.allocate(positions)

        dense_0, dense_1 = dense_g(positions)
        dense_0 = np.mean(dense_0, axis=0)
        dense_1 = np.mean(dense_1, axis=0)

        neigh_0, neigh_1 = nbr_g(positions, neighbor=nbrs)
        neigh_0 = np.mean(neigh_0, axis=0)
        neigh_1 = np.mean(neigh_1, axis=0)

        self.assertAllClose(dense_0, neigh_0)
        self.assertAllClose(dense_1, neigh_1)
Exemple #3
0
  def test_pair_neighbor_list_force_scalar_diverging_potential(
      self, spatial_dimension, dtype, format):
    """Forces of a truncated r**-6 potential agree with and without a
    neighbor list."""
    def potential(dr, sigma):
      return np.where(dr < sigma, dr ** -6, f32(0.))

    N = NEIGHBOR_LIST_PARTICLE_COUNT
    box_size = 4. * N ** (1. / spatial_dimension)

    disp, _ = space.periodic(box_size)
    metric = space.metric(disp)

    nbr_force = jit(quantity.force(
      smap.pair_neighbor_list(potential, metric, sigma=1.0)))
    dense_force = jit(quantity.force(smap.pair(potential, metric, sigma=1.0)))

    # The initial split only advances the key; its second output is unused.
    key, _ = random.split(random.PRNGKey(0))
    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
      # NOTE(review): `key` is sampled here and split again on the next
      # iteration; JAX recommends drawing from a fresh subkey instead.
      sigma = random.uniform(key, (), minval=0.5, maxval=4.5)
      neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0,
                                            format=format)
      nbrs = neighbor_fn.allocate(R)
      self.assertAllClose(dense_force(R, sigma=sigma),
                          nbr_force(R, nbrs, sigma=sigma))
Exemple #4
0
    def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension,
                                                     dtype):
        """A symmetric per-pair (N x N) sigma gives the same energy with and
        without a neighbor list."""
        def truncated_square(dr, sigma):
            return np.where(dr < sigma, dr**2, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        disp, _ = space.periodic(box_size)
        metric = space.metric(disp)

        nbr_energy = jit(smap.pair_neighbor_list(truncated_square, metric))
        dense_energy = jit(smap.pair(truncated_square, metric))

        key, _ = random.split(random.PRNGKey(0))
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            raw = random.uniform(key, (N, N), minval=0.5, maxval=1.5)
            # Symmetrize so that sigma[i, j] == sigma[j, i].
            sigma = 0.5 * (raw + raw.T)
            neighbor_fn = partition.neighbor_list(disp, box_size,
                                                  np.max(sigma), 0.)
            nbrs = neighbor_fn(R)
            self.assertAllClose(dense_energy(R, sigma=sigma),
                                nbr_energy(R, nbrs, sigma=sigma))
Exemple #5
0
    def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype):
        """A scalar sigma gives the same energy with and without a neighbor
        list."""
        def truncated_square(dr, sigma):
            return np.where(dr < sigma, dr**2, f32(0.))

        # NOTE(review): 2e-10 is far below float32 resolution; confirm this
        # tolerance is intended.
        if dtype == np.float32:
            tol = 2e-10
        else:
            tol = None

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 4. * N**(1. / spatial_dimension)

        disp, _ = space.periodic(box_size)
        metric = space.metric(disp)

        nbr_energy = jit(smap.pair_neighbor_list(truncated_square, metric))
        dense_energy = jit(smap.pair(truncated_square, metric))

        key, _ = random.split(random.PRNGKey(0))
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (), minval=0.5, maxval=2.5)
            neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma,
                                                      R))
            idx = neighbor_fn(R)
            self.assertAllClose(dense_energy(R, sigma=sigma),
                                nbr_energy(R, idx, sigma=sigma), True,
                                tol, tol)
Exemple #6
0
    def test_periodic_against_periodic_general(self, spatial_dimension, dtype):
        """`periodic` on scaled positions matches `periodic_general` on the
        unscaled positions with a diagonal box, for displacement and shift."""
        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, box_key, pos_key, step_key = random.split(key, 4)

            box_size = f16(10.0) * random.uniform(box_key,
                                                  (spatial_dimension, ),
                                                  dtype=dtype)
            transform = np.diag(box_size)

            R = random.uniform(pos_key, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            R_scaled = R * box_size
            dR = random.normal(step_key, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)

            disp_fn, shift_fn = space.periodic(box_size)
            general_disp_fn, general_shift_fn = space.periodic_general(
                transform)
            pair_disp = space.map_product(disp_fn)
            general_pair_disp = space.map_product(general_disp_fn)

            displacements = pair_disp(R_scaled, R_scaled)
            self.assertAllClose(displacements, general_pair_disp(R, R), True)
            assert displacements.dtype == dtype

            shifted = shift_fn(R_scaled, dR)
            self.assertAllClose(shifted,
                                general_shift_fn(R, dR) * box_size, True)
            assert shifted.dtype == dtype
Exemple #7
0
  def test_pair_neighbor_list_vector(self, spatial_dimension, dtype, format):
    """A vector-valued pair function reduced over axis 1 agrees between the
    neighbor-list and dense implementations."""
    if format is partition.OrderedSparse:
      self.skipTest('Vector valued pair_neighbor_list not supported.')

    def truncated_square(dR, sigma):
      # Add a trailing singleton axis so the distance mask broadcasts over dR.
      dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1,))
      return np.where(dr < sigma, dR ** 2, f32(0.))

    N = PARTICLE_COUNT
    box_size = 2. * N ** (1. / spatial_dimension)
    disp, _ = space.periodic(box_size)

    nbr_energy = jit(smap.pair_neighbor_list(
      truncated_square, disp, sigma=1.0, reduce_axis=(1,)))
    dense_energy = jit(smap.pair(
      truncated_square, disp, sigma=1.0, reduce_axis=(1,)))

    key, _ = random.split(random.PRNGKey(0))
    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
      sigma = random.uniform(key, (), minval=0.5, maxval=1.5)
      neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.,
                                            format=format)
      nbrs = neighbor_fn.allocate(R)
      self.assertAllClose(dense_energy(R, sigma=sigma),
                          nbr_energy(R, nbrs, sigma=sigma))
Exemple #8
0
    def test_pair_neighbor_list_scalar_nonadditive(self, spatial_dimension,
                                                   dtype, format):
        """Per-particle sigmas combined by multiplication in the neighbor-list
        function match the explicit sigma_i * sigma_j matrix in the dense
        function."""
        def truncated_square(dR, sigma):
            dr = space.distance(dR)
            return np.where(dr < sigma, dr**2, f32(0.))

        def combine(x, y):
            # Pairwise combination rule for the per-particle sigmas.
            return x * y

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)
        disp, _ = space.periodic(box_size)

        nbr_energy = jit(
            smap.pair_neighbor_list(truncated_square, disp, sigma=combine))
        dense_energy = jit(smap.pair(truncated_square, disp, sigma=1.0))

        key, _ = random.split(random.PRNGKey(0))
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (N, ), minval=0.5, maxval=1.5)
            sigma_pair = np.outer(sigma, sigma)
            # The largest combined sigma is max(sigma)**2, used as the cutoff.
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  np.max(sigma)**2,
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(dense_energy(R, sigma=sigma_pair),
                                nbr_energy(R, nbrs, sigma=sigma))
Exemple #9
0
  def test_pair_neighbor_list_scalar_params_species_dynamic(
      self, spatial_dimension, dtype, format):
    """Species supplied at call time to the neighbor-list function match
    species fixed at construction time for the dense function."""
    def truncated_square(dr, sigma, **kwargs):
      return np.where(dr < sigma, dr ** 2, f32(0.))

    N = NEIGHBOR_LIST_PARTICLE_COUNT
    box_size = 2. * N ** (1. / spatial_dimension)
    # Three species in index order: label 1 past N/3, label 2 past 2N/3.
    species = np.zeros((N,), np.int32)
    for threshold, label in ((N / 3, 1), (2 * N / 3, 2)):
      species = np.where(np.arange(N) > threshold, label, species)

    disp, _ = space.periodic(box_size)
    metric = space.metric(disp)

    nbr_energy = jit(
      smap.pair_neighbor_list(truncated_square, metric, sigma=1.0))
    dense_energy = jit(
      smap.pair(truncated_square, metric, species=species, sigma=1.0))

    key, _ = random.split(random.PRNGKey(0))
    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype)
      raw = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
      sigma = 0.5 * (raw + raw.T)
      neighbor_fn = partition.neighbor_list(disp, box_size, np.max(sigma), 0.,
                                            format=format)
      nbrs = neighbor_fn.allocate(R)
      self.assertAllClose(
        dense_energy(R, sigma=sigma),
        nbr_energy(R, nbrs, sigma=sigma, species=species))
Exemple #10
0
  def test_radial_symmetry_functions_neighbor_list(self,
                                                   N_types,
                                                   N_etas,
                                                   dtype,
                                                   dim):
    """Radial symmetry functions agree with and without a neighbor list."""
    N = 128
    box_size = 12.0
    r_cutoff = 3.

    displacement, _ = space.periodic(box_size)
    position_key, species_key = random.split(random.PRNGKey(0))
    # NOTE(review): positions use the default dtype; `dtype` only affects
    # the eta grid — confirm this is intended.
    R = box_size * random.uniform(position_key, (N, dim))
    species = random.choice(species_key, N_types, (N,))

    neighbor_fn = partition.neighbor_list(displacement, box_size, r_cutoff, 0.)

    etas = np.linspace(1.0, 2.0, N_etas, dtype=dtype)
    dense_gr = nn.radial_symmetry_functions(displacement, species, etas,
                                            r_cutoff)
    neighbor_gr = nn.radial_symmetry_functions_neighbor_list(
      displacement, species, etas, r_cutoff)

    nbrs = neighbor_fn(R)
    expected = dense_gr(R)
    self.assertAllClose(expected, neighbor_gr(R, neighbor=nbrs))
Exemple #11
0
    def test_pair_neighbor_list_vector(self, spatial_dimension, dtype):
        """A vector-valued pair function reduced over axis 1 agrees between
        the neighbor-list and dense implementations."""
        def truncated_square(dR, sigma):
            # Trailing singleton axis lets the distance mask broadcast.
            dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1, ))
            return np.where(dr < sigma, dR**2, f32(0.))

        if dtype == np.float32:
            tol = 5e-6
        else:
            tol = 1e-14

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)
        disp, _ = space.periodic(box_size)

        nbr_energy = jit(
            smap.pair_neighbor_list(truncated_square, disp, reduce_axis=(1, )))
        dense_energy = jit(
            smap.pair(truncated_square, disp, reduce_axis=(1, )))

        key, _ = random.split(random.PRNGKey(0))
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (), minval=0.5, maxval=1.5)
            neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma,
                                                      R))
            idx = neighbor_fn(R)
            self.assertAllClose(dense_energy(R, sigma=sigma),
                                nbr_energy(R, idx, sigma=sigma), True,
                                tol, tol)
Exemple #12
0
    def test_stress_non_minimized_periodic(self, dim, dtype):
        """Autodiff stress matches an explicit pairwise expression.

        Random (non-minimized) soft-sphere positions at number density 0.8
        in a periodic box are generated in the parameterized `dtype`, so the
        dtype-specific tolerance below applies to the precision actually
        under test.
        """
        key = random.PRNGKey(0)
        N = 64

        box = quantity.box_size_at_number_density(N, 0.8, dim)
        displacement_fn, _ = space.periodic(box)

        # Use the parameterized dtype; without it every case ran at the
        # default precision regardless of `dtype`.
        pos = random.uniform(key, (N, dim), dtype=dtype) * box

        energy_fn = energy.soft_sphere_pair(displacement_fn)

        def explicit_stress(R):
            """-sum_ij 0.5 * U'(r_ij) * dR_ij (outer) dR_ij / (V * r_ij)."""
            dR = space.map_product(displacement_fn)(R, R)
            dr = space.distance(dR)
            dU = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
            V = quantity.volume(dim, box)
            # The full i, j product visits each pair twice, hence the 0.5.
            dUdr = 0.5 * dU(dr)[:, :, None, None]
            # Adding the identity avoids dividing by zero on the i == j
            # diagonal, where the displacement itself is zero.
            dr = (dr + jnp.eye(N))[:, :, None, None]
            return -jnp.sum(dUdr * dR[:, :, None, :] * dR[:, :, :, None] /
                            (V * dr),
                            axis=(0, 1))

        expected_stress = explicit_stress(pos)
        ad_stress = quantity.stress(energy_fn, pos, box)

        tol = 1e-7 if dtype is f64 else 2e-5

        self.assertAllClose(expected_stress, ad_stress, atol=tol, rtol=tol)
Exemple #13
0
  def test_periodic_against_periodic_general_grad(self, spatial_dimension, dtype):
    """Gradients through `periodic` and `periodic_general` agree.

    The sum of squared pairwise displacements is differentiated with respect
    to the scaled positions for `periodic` and with respect to the unscaled
    positions for `periodic_general` (diagonal box). The two gradients must
    match to a dtype-specific tolerance and keep the input dtype.
    """
    key = random.PRNGKey(0)

    tol = 1e-5 if dtype is f32 else 1e-13

    for _ in range(STOCHASTIC_SAMPLES):
      # The fourth subkey is unused, but the four-way split is kept so the
      # sampled boxes and positions stay the same.
      key, split1, split2, _ = random.split(key, 4)

      max_box_size = f32(10.0)
      box_size = max_box_size * random.uniform(
        split1, (spatial_dimension,), dtype=dtype)
      transform = jnp.diag(box_size)

      R = random.uniform(
        split2, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      R_scaled = R * box_size

      disp_fn, _ = space.periodic(box_size)
      general_disp_fn, _ = space.periodic_general(transform)

      disp_fn = space.map_product(disp_fn)
      general_disp_fn = space.map_product(general_disp_fn)

      grad_fn = grad(lambda R: jnp.sum(disp_fn(R, R) ** 2))
      general_grad_fn = grad(lambda R: jnp.sum(general_disp_fn(R, R) ** 2))

      general_grad = general_grad_fn(R)
      # Apply the dtype-specific tolerance computed above.
      self.assertAllClose(grad_fn(R_scaled), general_grad, atol=tol, rtol=tol)
      assert general_grad.dtype == dtype
Exemple #14
0
    def test_nvt_nose_hoover_jammed(self, dtype, sy_steps):
        """The Nose-Hoover invariant stays at its initial value over
        DYNAMICS_STEPS steps started from the stored test state."""
        key = random.PRNGKey(0)

        loaded = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic(loaded.box[0, 0])

        E = energy.soft_sphere_pair(displacement_fn, loaded.species,
                                    loaded.sigma)
        invariant = partial(simulate.nvt_nose_hoover_invariant, E)

        kT = 1e-3
        init_fn, apply_fn = simulate.nvt_nose_hoover(E, shift_fn, 1e-3,
                                                     kT=kT, sy_steps=sy_steps)
        apply_fn = jit(apply_fn)

        state = init_fn(key, loaded.real_position)
        E_initial = invariant(state, kT) * np.ones((DYNAMICS_STEPS, ))

        def record_step(i, carry):
            # Advance one step and store the invariant at index i.
            sim_state, energies = carry
            sim_state = apply_fn(sim_state)
            energies = ops.index_update(energies, i, invariant(sim_state, kT))
            return sim_state, energies

        state, Es = lax.fori_loop(0, DYNAMICS_STEPS, record_step,
                                  (state, np.zeros((DYNAMICS_STEPS, ))))

        if dtype is f32:
            tol = 1e-3
        else:
            tol = 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
Exemple #15
0
    def test_nve_jammed(self, spatial_dimension, dtype):
        """Potential plus kinetic energy stays at its initial value under
        NVE over DYNAMICS_STEPS steps from the stored test state."""
        key = random.PRNGKey(0)

        loaded = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic(loaded.box[0, 0])

        E = energy.soft_sphere_pair(displacement_fn, loaded.species,
                                    loaded.sigma)

        init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
        apply_fn = jit(apply_fn)

        state = init_fn(key, loaded.real_position, kT=1e-3)

        def total_energy(s):
            # Potential energy plus kinetic energy of a simulation state.
            return (E(s.position) +
                    quantity.kinetic_energy(s.velocity, s.mass))

        E_initial = total_energy(state) * np.ones((DYNAMICS_STEPS, ))

        def record_step(i, carry):
            sim_state, energies = carry
            sim_state = apply_fn(sim_state)
            return sim_state, ops.index_update(energies, i,
                                               total_energy(sim_state))

        state, Es = lax.fori_loop(0, DYNAMICS_STEPS, record_step,
                                  (state, np.zeros((DYNAMICS_STEPS, ))))

        if dtype is f32:
            tol = 1e-3
        else:
            tol = 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
Exemple #16
0
    def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension,
                                                      dtype):
        """A symmetric per-species (3 x 3) sigma gives matching energies with
        and without a neighbor list."""
        def truncated_square(dr, sigma):
            return np.where(dr < sigma, dr**2, f32(0.))

        if dtype == np.float32:
            tol = 2e-6
        else:
            tol = None

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)
        # Three species in index order: label 1 past N/3, label 2 past 2N/3.
        species = np.zeros((N, ), np.int32)
        for threshold, label in ((N / 3, 1), (2 * N / 3, 2)):
            species = np.where(np.arange(N) > threshold, label, species)

        disp, _ = space.periodic(box_size)
        metric = space.metric(disp)

        nbr_energy = jit(
            smap.pair_neighbor_list(truncated_square, metric, species=species))
        dense_energy = jit(smap.pair(truncated_square, metric,
                                     species=species))

        key, _ = random.split(random.PRNGKey(0))
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            raw = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
            sigma = 0.5 * (raw + raw.T)
            neighbor_fn = jit(
                partition.neighbor_list(disp, box_size, np.max(sigma), R))
            idx = neighbor_fn(R)
            self.assertAllClose(dense_energy(R, sigma=sigma),
                                nbr_energy(R, idx, sigma=sigma), True,
                                tol, tol)
Exemple #17
0
 def test_bks(self, dtype):
     """BKS silica energy on the stored configuration matches a reference."""
     box_length = 3.5660930663857577e+01
     displacement, _ = space.periodic(box_length)
     # Species pattern [0, 1, 1] repeated 1000 times (3000 entries).
     species = np.tile(np.array([0, 1, 1]), 1000)
     energy_fn = energy.bks_silica_pair(space.metric(displacement),
                                        species=species)
     positions = test_util.load_silica_data()
     self.assertAllClose(-857939.528386092, energy_fn(positions))
Exemple #18
0
 def test_bks_neighbor_list(self, dtype, format):
   """Neighbor-list BKS silica energy matches the stored reference value."""
   box_length = 3.5660930663857577e+01
   displacement, _ = space.periodic(box_length)
   species = np.tile(np.array([0, 1, 1]), 1000)
   positions = test_util.load_silica_data()
   neighbor_fn, energy_fn = energy.bks_silica_neighbor_list(
     space.metric(displacement), box_length, species=species, format=format)
   nbrs = neighbor_fn.allocate(positions)
   self.assertAllClose(-857939.528386092, energy_fn(positions, nbrs))
Exemple #19
0
 def test_bks(self, dtype):
     """BKS silica energy on the stored configuration matches a reference.

     Positions are read from ``tests/data/silica_positions.npy`` relative to
     the current working directory. If that file does not exist,
     ``data/silica_positions.npy`` next to this module is tried instead, so
     the test does not depend on being launched from the repository root.
     NOTE(review): the fallback assumes this module lives in ``tests/``.
     """
     LATCON = 3.5660930663857577e+01
     displacement, shift = space.periodic(LATCON)
     dist_fun = space.metric(displacement)
     species = np.tile(np.array([0, 1, 1]), 1000)
     filename = os.path.join(os.getcwd(), 'tests/data/silica_positions.npy')
     if not os.path.exists(filename):
         filename = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 'data', 'silica_positions.npy')
     with open(filename, 'rb') as f:
         R_f = np.array(np.load(f))
     energy_fn = energy.bks_silica_pair(dist_fun, species=species)
     self.assertAllClose(-857939.528386092, energy_fn(R_f))
Exemple #20
0
    def test_soft_sphere_cell_list_energy(self, spatial_dimension, dtype):
        """Cell-list soft-sphere energy matches the dense pair energy."""
        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)
        reference_energy = energy.soft_sphere_pair(displacement)

        R = box_size * random.uniform(random.PRNGKey(1),
                                      (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        cell_energy = energy.soft_sphere_cell_list(displacement, box_size, R)

        expected = np.array(reference_energy(R), dtype=dtype)
        self.assertAllClose(expected, cell_energy(R), True)
Exemple #21
0
  def test_pressure_jammed_periodic(self, dtype):
    """Pressure from the soft-sphere energy matches the stored pressure."""
    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    displacement_fn, _ = space.periodic(jnp.diag(state.box))
    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)

    if dtype is f64:
      tol = 1e-7
    else:
      tol = 2e-5

    measured = quantity.pressure(E, state.real_position, state.box)
    self.assertAllClose(measured, state.pressure, atol=tol, rtol=tol)
Exemple #22
0
    def test_nve_neighbor_list(self, spatial_dimension, dtype):
        """NVE driven by a neighbor-list LJ energy tracks NVE driven by the
        dense LJ energy from a lattice start, rebuilding the list when it
        overflows."""
        Nx = 8
        spacing = f32(1.25)

        if dtype == np.float64:
            tol = 5e-12
        else:
            tol = 5e-3

        L = Nx * spacing
        # Square (2D) or cubic (3D) lattice with Nx points per side.
        if spatial_dimension == 2:
            lattice = onp.ndindex(Nx, Nx)
        elif spatial_dimension == 3:
            lattice = onp.ndindex(Nx, Nx, Nx)
        R = np.stack([np.array(site) for site in lattice]) * spacing
        R = np.array(R, dtype)

        displacement, shift = space.periodic(L)

        neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
            displacement, L)
        exact_energy_fn = energy.lennard_jones_pair(displacement)

        init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
        exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift,
                                                     1e-3)

        nbrs = neighbor_fn(R)
        state = init_fn(random.PRNGKey(0), R, neighbor=nbrs)
        exact_state = exact_init_fn(random.PRNGKey(0), R)

        def body_fn(i, carry):
            # Update the neighbor list from the current positions, then take
            # one step with each integrator.
            sim, nbr, exact = carry
            nbr = neighbor_fn(sim.position, nbr)
            sim = apply_fn(sim, neighbor=nbr)
            return sim, nbr, exact_apply_fn(exact)

        for _ in range(20):
            candidate, nbrs, candidate_exact = lax.fori_loop(
                0, 100, body_fn, (state, nbrs, exact_state))
            if nbrs.did_buffer_overflow:
                # Discard this chunk and rebuild from the last accepted state.
                nbrs = neighbor_fn(state.position)
                continue
            state, exact_state = candidate, candidate_exact

        assert state.position.dtype == dtype
        self.assertAllClose(state.position,
                            exact_state.position,
                            atol=tol,
                            rtol=tol)
Exemple #23
0
 def test_behler_parrinello_network(self, N_types, dtype):
     """The Behler-Parrinello network yields a NaN-free energy and a NaN-free
     (3, 3) position gradient on a three-particle periodic configuration."""
     key = random.PRNGKey(1)
     R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
     species = np.array([1, 1, N_types]) if N_types > 1 else None
     box_size = f32(1.5)
     displacement, _ = space.periodic(box_size)
     nn_init, nn_apply = energy.behler_parrinello(displacement, species)
     params = nn_init(key, R)
     # Gradient of the energy with respect to positions (argument 1). This is
     # the negative of the force, which is sufficient for NaN/shape checks.
     nn_grad_fn = grad(nn_apply, argnums=1)
     nn_grad = jit(nn_grad_fn)(params, R)
     nn_energy = jit(nn_apply)(params, R)
     # Use boolean/equality assertions rather than assertAllClose so a
     # failure reports what actually went wrong.
     self.assertFalse(np.any(np.isnan(nn_energy)))
     self.assertFalse(np.any(np.isnan(nn_grad)))
     self.assertEqual(nn_grad.shape, (3, 3))
Exemple #24
0
    def test_pair_grid_force_incommensurate(self, spatial_dimension, dtype):
        """Grid-mapped forces match direct forces when the box length (12.1)
        is not an integer multiple of the cell size (3.0)."""
        box_size = f32(12.1)
        cell_size = f32(3.0)
        displacement, _ = space.periodic(box_size)
        force_fn = quantity.force(
            energy.soft_sphere_pair(displacement, quantity.Dynamic))

        R = box_size * random.uniform(random.PRNGKey(1),
                                      (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        grid_force_fn = jit(smap.grid(force_fn, box_size, cell_size, R))
        species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)
        expected = np.array(force_fn(R, species, 1), dtype=dtype)
        self.assertAllClose(expected, grid_force_fn(R), True)
Exemple #25
0
def main(unused_argv):
    """Minimize a 2D soft-sphere binary mixture with FIRE and print progress.

    Every `print_every` steps this prints the step number, the total energy
    and the largest per-particle force magnitude.
    """
    key = random.PRNGKey(0)

    # Setup some variables describing the system.
    N = 500
    dimension = 2
    box_size = f32(25.0)

    # Create helper functions to define a periodic box of some size.
    displacement, shift = space.periodic(box_size)

    # Use JAX's random number generator to generate random initial positions.
    key, split = random.split(key)
    R = random.uniform(split, (N, dimension),
                       minval=0.0,
                       maxval=box_size,
                       dtype=f32)

    # The system ought to be a 50:50 mixture of two types of particles, one
    # small (sigma 1.0) and one large (sigma 1.4). For odd N the second
    # species gets the extra particle so `species` always has N entries.
    sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
    N_0 = N // 2
    species = np.array([0] * N_0 + [1] * (N - N_0), dtype=i32)

    # Create an energy function.
    energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
    force_fn = quantity.force(energy_fn)

    # Create a minimizer.
    init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
    opt_state = init_fn(R)

    # Minimize the system.
    minimize_steps = 50
    print_every = 10

    print('Minimizing.')
    print('Step\tEnergy\tMax Force')
    print('-----------------------------------')
    for step in range(minimize_steps):
        opt_state = apply_fn(opt_state)

        if step % print_every == 0:
            R = opt_state.position
            # Largest Euclidean force norm over particles; a plain max over
            # raw components would report a signed component, not a
            # magnitude.
            max_force = np.max(np.linalg.norm(force_fn(R), axis=-1))
            print('{}\t{:.2f}\t{:.2f}'.format(step, energy_fn(R), max_force))
Exemple #26
0
    def test_morse_small_neighbor_list_energy(self, spatial_dimension, dtype):
        """Neighbor-list Morse energy matches the dense Morse energy for a
        small (10-particle) system."""
        box_size = f32(5.0)
        displacement, _ = space.periodic(box_size)
        reference_energy = energy.morse_pair(displacement)

        R = box_size * random.uniform(random.PRNGKey(1),
                                      (10, spatial_dimension),
                                      dtype=dtype)
        neighbor_fn, nbr_energy = energy.morse_neighbor_list(
            displacement, box_size)

        nbrs = neighbor_fn(R)
        expected = np.array(reference_energy(R), dtype=dtype)
        self.assertAllClose(expected, nbr_energy(R, nbrs))
Exemple #27
0
  def test_pair_cell_list_energy(self, spatial_dimension, dtype):
    """A cell-list-mapped soft-sphere energy matches the dense pair energy."""
    box_size = f32(9.0)
    cell_size = f32(1.0)
    displacement, _ = space.periodic(box_size)
    reference_energy = energy.soft_sphere_pair(displacement)
    pairwise_energy = smap.cartesian_product(energy.soft_sphere,
                                             space.metric(displacement))

    R = box_size * random.uniform(
      random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    cell_energy_fn = smap.cell_list(pairwise_energy, box_size, cell_size, R)
    expected = np.array(reference_energy(R), dtype=dtype)
    self.assertAllClose(expected, cell_energy_fn(R), True)
Exemple #28
0
  def test_cell_list_incommensurate(self, spatial_dimension, dtype):
    """Cell-list energy matches the dense energy when the box length (12.1)
    is not an integer multiple of the cell size (3.0)."""
    box_size = f32(12.1)
    cell_size = f32(3.0)
    displacement, _ = space.periodic(box_size)
    reference_energy = energy.soft_sphere_pair(displacement)

    R = box_size * random.uniform(
      random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    pairwise_energy = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
    cell_list_energy = jit(
      smap.cell_list(pairwise_energy, box_size, cell_size, R))
    expected = np.array(reference_energy(R), dtype=dtype)
    self.assertAllClose(expected, cell_list_energy(R), True)
Exemple #29
0
    def test_lennard_jones_cell_list_force(self, spatial_dimension, dtype):
        """Cell-list Lennard-Jones forces match the dense forces."""
        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)
        reference_force = quantity.force(
            energy.lennard_jones_pair(displacement))

        R = box_size * random.uniform(random.PRNGKey(1),
                                      (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        cell_energy = energy.lennard_jones_cell_list(displacement, box_size, R)
        cell_force = quantity.force(cell_energy)

        expected = np.array(reference_force(R), dtype=dtype)
        self.assertAllClose(expected, cell_force(R), True)
Exemple #30
0
    def test_morse_neighbor_list_force(self, spatial_dimension, dtype):
        """Neighbor-list Morse forces match the dense Morse forces."""
        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)
        reference_force = quantity.force(energy.morse_pair(displacement))

        positions = box_size * random.uniform(
            random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension),
            dtype=dtype)
        neighbor_fn, nbr_energy = energy.morse_neighbor_list(
            displacement, box_size)
        nbr_force = quantity.force(nbr_energy)

        nbrs = neighbor_fn(positions)
        expected = np.array(reference_force(positions), dtype=dtype)
        self.assertAllClose(expected, nbr_force(positions, nbrs))