示例#1
0
    def test_nve_jammed_periodic_general(self, dtype, coords):
        """NVE dynamics of the stored jammed state conserves total energy.

        The stored configuration is integrated with
        `simulate.nve(E, shift_fn, 1e-3)` (initial kT = 1e-3) in a
        `space.periodic_general` box for `DYNAMICS_STEPS` steps. Potential plus
        kinetic energy is recorded after every step and compared with its
        initial value.

        Args:
          dtype: Floating point dtype the test state is loaded in.
          coords: 'fractional' enables fractional coordinates in
            `space.periodic_general`. It also selects which stored position
            array (`<coords>_position`) is simulated.
        """
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic_general(
            state.box, coords == 'fractional')

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)

        init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
        apply_fn = jit(apply_fn)

        state = init_fn(key, getattr(state, coords + '_position'), kT=1e-3)

        def E_T(state):
            # Conserved quantity of NVE: potential + kinetic energy.
            return (E(state.position) +
                    quantity.kinetic_energy(state.velocity, state.mass))

        E_initial = E_T(state) * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_and_energies):
            # Named `energies` so the `energy` module is not shadowed here.
            state, energies = state_and_energies
            state = apply_fn(state)
            # `jax.ops.index_update` was removed from JAX. The functional
            # `.at[...].set(...)` update replaces it.
            energies = energies.at[i].set(E_T(state))
            return state, energies

        Es = np.zeros((DYNAMICS_STEPS, ))
        state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
示例#2
0
    def test_nvt_nose_hoover_jammed(self, dtype, sy_steps):
        """NVT Nose-Hoover dynamics of the jammed state conserves its invariant.

        The stored configuration is integrated in a periodic box with
        `simulate.nvt_nose_hoover(E, shift_fn, 1e-3, kT=1e-3,
        sy_steps=sy_steps)` for `DYNAMICS_STEPS` steps.
        `simulate.nvt_nose_hoover_invariant` is evaluated after every step and
        compared with its initial value.

        Args:
          dtype: Floating point dtype the test state is loaded in.
          sy_steps: Forwarded unchanged to `simulate.nvt_nose_hoover`.
        """
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic(state.box[0, 0])

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)
        invariant = partial(simulate.nvt_nose_hoover_invariant, E)

        kT = 1e-3
        init_fn, apply_fn = simulate.nvt_nose_hoover(E,
                                                     shift_fn,
                                                     1e-3,
                                                     kT=kT,
                                                     sy_steps=sy_steps)
        apply_fn = jit(apply_fn)

        state = init_fn(key, state.real_position)

        E_initial = invariant(state, kT) * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_and_energies):
            # Named `energies` so the `energy` module is not shadowed here.
            state, energies = state_and_energies
            state = apply_fn(state)
            # `jax.ops.index_update` was removed from JAX. The functional
            # `.at[...].set(...)` update replaces it.
            energies = energies.at[i].set(invariant(state, kT))
            return state, energies

        Es = np.zeros((DYNAMICS_STEPS, ))
        state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
示例#3
0
    def test_pressure_non_minimized_free(self, dim, dtype):
        """Autodiff pressure matches an explicit virial sum in free space.

        N = 64 soft spheres are placed uniformly at random in a cube of side
        `box` (number density 0.8, non-periodic space). `quantity.pressure` is
        compared with -trace(S) / dim, where S is a stress tensor summed
        explicitly over all particle pairs.

        NOTE(review): `dtype` only selects the tolerance below. Positions are
        drawn with JAX's default float dtype, so both dtype cases run at the
        same precision. Confirm whether `dtype=dtype` was intended.
        """
        N = 64
        key = random.PRNGKey(0)

        box = quantity.box_size_at_number_density(N, 0.8, dim)
        displacement_fn, _ = space.free()
        energy_fn = energy.soft_sphere_pair(displacement_fn)

        positions = random.uniform(key, (N, dim)) * box

        def virial_stress(R):
            pair_dR = space.map_product(displacement_fn)(R, R)
            pair_dr = space.distance(pair_dR)
            dU_dr = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
            volume = quantity.volume(dim, box)
            # map_product(R, R) yields both (i, j) and (j, i). The 0.5 undoes
            # that double counting.
            half_dU = 0.5 * dU_dr(pair_dr)[:, :, None, None]
            # Put 1 on the diagonal to avoid 0 / 0 for i == j. The
            # self-displacement is zero there, so those terms vanish anyway.
            safe_dr = (pair_dr + jnp.eye(N))[:, :, None, None]
            return jnp.sum(half_dU * pair_dR[:, :, None, :] *
                           pair_dR[:, :, :, None] / (volume * safe_dr),
                           axis=(0, 1))

        expected = -1 / dim * jnp.trace(virial_stress(positions))
        computed = quantity.pressure(energy_fn, positions, box)

        tol = 1e-7 if dtype is f64 else 2e-5
        self.assertAllClose(expected, computed, atol=tol, rtol=tol)
示例#4
0
    def test_nve_jammed(self, spatial_dimension, dtype):
        """Total energy stays constant during NVE dynamics of a jammed state.

        The stored jammed configuration is simulated in a periodic box with
        `simulate.nve(E, shift_fn, 1e-3)` (initial kT = 1e-3) for
        `DYNAMICS_STEPS` steps. Potential plus kinetic energy after every step
        is compared with its starting value. `spatial_dimension` is not read
        here.
        """
        loaded = test_util.load_jammed_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic(loaded.box[0, 0])
        energy_fn = energy.soft_sphere_pair(
            displacement_fn, loaded.species, loaded.sigma)

        init_fn, step = simulate.nve(energy_fn, shift_fn, 1e-3)
        step = jit(step)

        sim_state = init_fn(random.PRNGKey(0), loaded.real_position, kT=1e-3)

        def total_energy(s):
            potential = energy_fn(s.position)
            kinetic = quantity.kinetic_energy(s.velocity, s.mass)
            return potential + kinetic

        reference = total_energy(sim_state) * np.ones((DYNAMICS_STEPS, ))

        def body(i, carry):
            s, history = carry
            s = step(s)
            return s, history.at[i].set(total_energy(s))

        history = np.zeros((DYNAMICS_STEPS, ))
        sim_state, history = lax.fori_loop(
            0, DYNAMICS_STEPS, body, (sim_state, history))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(sim_state.position.dtype, dtype)
        self.assertAllClose(history, reference, rtol=tol, atol=tol)
示例#5
0
    def test_stress_non_minimized_periodic_general(self, dim, dtype, coords):
        """Autodiff stress matches an explicit virial sum in a periodic box.

        N = 64 soft spheres (number density 0.8) are placed at random in a
        `space.periodic_general` box. `quantity.stress` is compared with the
        negated pairwise virial sum.

        NOTE(review): `dtype` only selects the tolerance below. Positions are
        drawn with JAX's default float dtype. Confirm whether `dtype=dtype`
        was intended.
        """
        N = 64
        key = random.PRNGKey(0)

        box = quantity.box_size_at_number_density(N, 0.8, dim)
        fractional = coords == 'fractional'
        displacement_fn, _ = space.periodic_general(box, fractional)
        energy_fn = energy.soft_sphere_pair(displacement_fn)

        positions = random.uniform(key, (N, dim))
        if not fractional:
            # Real-space coordinates: scale the unit-cube samples by the box.
            positions = positions * box

        def virial_stress(R):
            pair_dR = space.map_product(displacement_fn)(R, R)
            pair_dr = space.distance(pair_dR)
            dU_dr = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
            volume = quantity.volume(dim, box)
            # Both (i, j) and (j, i) are present, hence the 0.5.
            half_dU = 0.5 * dU_dr(pair_dr)[:, :, None, None]
            # Avoid 0 / 0 on the diagonal, where the displacement is zero.
            safe_dr = (pair_dr + jnp.eye(N))[:, :, None, None]
            return -jnp.sum(half_dU * pair_dR[:, :, None, :] *
                            pair_dR[:, :, :, None] / (volume * safe_dr),
                            axis=(0, 1))

        expected = virial_stress(positions)
        computed = quantity.stress(energy_fn, positions, box)

        tol = 1e-7 if dtype is f64 else 2e-5
        self.assertAllClose(expected, computed, atol=tol, rtol=tol)
示例#6
0
    def test_nve_ensemble_time_dependence(self, spatial_dimension, dtype):
        """NVE stepping with an explicit time argument stays energy-stable.

        Randomly placed free-space soft spheres with random masses in
        [0.1, 5.0) are advanced for `SHORT_DYNAMICS_STEPS` steps. The current
        time `t` is passed to the jitted step function on every call. After
        each step the total energy must be within 1% of its initial value, and
        positions must keep `dtype`.
        """
        dt = 1e-3
        key = random.PRNGKey(0)
        # The split stays at 4 so the position, velocity and mass keys are
        # unchanged. The second key is not needed by this test.
        pos_key, _, vel_key, mass_key = random.split(key, 4)
        R = random.normal(pos_key, (PARTICLE_COUNT, spatial_dimension),
                          dtype=dtype)
        mass = random.uniform(mass_key, (PARTICLE_COUNT, ),
                              minval=0.1,
                              maxval=5.0,
                              dtype=dtype)
        displacement, shift = space.free()

        E = energy.soft_sphere_pair(displacement)

        init_fn, apply_fn = simulate.nve(E, shift, dt)
        apply_fn = jit(apply_fn)

        state = init_fn(vel_key, R, mass=mass)

        def E_T(state):
            # Conserved quantity of NVE: potential + kinetic energy.
            return (E(state.position) +
                    quantity.kinetic_energy(state.velocity, state.mass))

        E_initial = E_T(state)

        for t in range(SHORT_DYNAMICS_STEPS):
            state = apply_fn(state, t=t * dt)
            E_total = E_T(state)
            # TestCase assertions report the offending values on failure.
            self.assertLess(np.abs(E_total - E_initial), E_initial * 0.01)
            self.assertEqual(state.position.dtype, dtype)
示例#7
0
  def test_pressure_jammed(self, dtype, coords):
    """Pressure recomputed for the stored jammed state matches the stored one.

    Args:
      dtype: Floating point dtype the test state is loaded in.
      coords: 'fractional' enables fractional coordinates in
        `space.periodic_general`. It also names the stored position array
        (`<coords>_position`) that is passed to `quantity.pressure`.
    """
    state = test_util.load_test_state('simulation_test_state.npy', dtype)
    use_fractional = coords == 'fractional'
    displacement_fn, _ = space.periodic_general(state.box, use_fractional)

    energy_fn = energy.soft_sphere_pair(
        displacement_fn, state.species, state.sigma)
    positions = getattr(state, coords + '_position')

    measured = quantity.pressure(energy_fn, positions, state.box)
    self.assertAllClose(measured, state.pressure)
示例#8
0
def run(N=32, n_iter=1000, with_jit=True):
    """Time individual steps of a 2D Brownian soft-sphere simulation.

    Args:
      N: Number of particles, placed uniformly at random in the unit square.
      n_iter: Number of integration steps to time.
      with_jit: If True, the step function is wrapped with `jax.jit`.

    Returns:
      A list of `n_iter` wall-clock durations in nanoseconds, one per step.
      Each duration waits for the step's result, so it covers the actual
      computation and not only its asynchronous dispatch. With
      `with_jit=True` the first entry also includes one-off compilation.
    """
    import time

    import jax.numpy as jnp
    from jax import jit, random, tree_util
    from jax_md import space, energy, simulate

    # MD configs
    dt = 1e-1
    temperature = 0.1

    # R: current position
    # dR: displacement
    # displacement(Ra, Rb):
    #   dR = Ra - Rb
    # periodic displacement(Ra, Rb):
    #   dR = Ra - Rb
    #   np.mod(dR + side * f32(0.5), side) - f32(0.5) * side
    # periodic shift:
    #   np.mod(R + dR, side)
    # shift:
    #   R + dR
    # This benchmark uses free (non-periodic) space.
    displacement, shift = space.free()

    # Simulation init: soft-sphere pair energy with default parameters
    # (sigma = 1, epsilon = 1, alpha = 2).
    # dr: pairwise distances
    # sigma: particle diameter
    # epsilon: interaction energy scale
    # alpha: interaction stiffness exponent
    # dr = distance(R)
    # U(dr) = np.where(dr < sigma, epsilon / alpha * (1 - dr / sigma) ** alpha, 0)
    # energy_fn(R) = diagonal_mask(U(dr))
    energy_fn = energy.soft_sphere_pair(displacement)

    # force(energy) = -d(energy)/dR
    # xi = random.normal(R.shape, R.dtype)
    # gamma = 0.1
    # nu = 1 / (mass * gamma)
    # dR = force(R) * dt * nu + np.sqrt(2 * temperature * dt * nu) * xi
    # BrownianState(position, mass, rng)
    pos_key, sim_key = random.split(random.PRNGKey(0))
    R = random.uniform(pos_key, (N, 2), dtype=jnp.float32)
    init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, temperature)
    if with_jit:
        apply_fn = jit(apply_fn)
    state = init_fn(sim_key, R)

    def _wait(leaf):
        # JAX dispatches asynchronously: `apply_fn` returns before the device
        # has finished. Blocking on each array leaf makes the timer measure
        # the step itself rather than only its dispatch.
        if hasattr(leaf, 'block_until_ready'):
            leaf.block_until_ready()
        return leaf

    # Start simulation
    times = []
    for _ in range(n_iter):
        time_start = time.perf_counter_ns()
        state = apply_fn(state)
        tree_util.tree_map(_wait, state)
        time_end = time.perf_counter_ns()
        times.append(time_end - time_start)

    # Finish with profiling times
    return times
示例#9
0
    def test_soft_sphere_cell_list_energy(self, spatial_dimension, dtype):
        """Cell-list soft-sphere energy agrees with the all-pairs energy.

        Random positions in a periodic box of side 15 are evaluated with
        `energy.soft_sphere_pair` as the reference and with
        `energy.soft_sphere_cell_list`, which is built from the same positions.
        """
        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)

        shape = (PARTICLE_COUNT, spatial_dimension)
        R = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)

        reference_fn = energy.soft_sphere_pair(displacement)
        cell_list_fn = energy.soft_sphere_cell_list(displacement, box_size, R)

        expected = np.array(reference_fn(R), dtype=dtype)
        self.assertAllClose(expected, cell_list_fn(R), True)
示例#10
0
  def test_pressure_jammed_periodic(self, dtype):
    """Pressure of the stored jammed state, recomputed in a periodic box.

    `space.periodic` is given the diagonal of `state.box` as side lengths.
    `quantity.pressure` is evaluated on the real-space positions and compared
    with the stored `state.pressure`.
    """
    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    side_lengths = jnp.diag(state.box)
    displacement_fn, _ = space.periodic(side_lengths)

    energy_fn = energy.soft_sphere_pair(
        displacement_fn, state.species, state.sigma)
    computed = quantity.pressure(energy_fn, state.real_position, state.box)

    tol = 2e-5 if dtype is not f64 else 1e-7
    self.assertAllClose(computed, state.pressure, atol=tol, rtol=tol)
示例#11
0
    def test_pair_grid_force_incommensurate(self, spatial_dimension, dtype):
        """Grid-mapped forces match direct forces for an incommensurate grid.

        The box side (12.1) is not a multiple of the cell size (3.0).
        """
        box_size, cell_size = f32(12.1), f32(3.0)
        displacement, _ = space.periodic(box_size)
        force_fn = quantity.force(
            energy.soft_sphere_pair(displacement, quantity.Dynamic))

        shape = (PARTICLE_COUNT, spatial_dimension)
        R = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)
        grid_force_fn = jit(smap.grid(force_fn, box_size, cell_size, R))

        # All particles share species 0.
        species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)
        expected = np.array(force_fn(R, species, 1), dtype=dtype)
        self.assertAllClose(expected, grid_force_fn(R), True)
示例#12
0
def make_periodic_general_test_system(N, dim, dtype, box_format):
  """Build one random system under three periodic-space conventions.

  Positions are drawn uniformly as fractional coordinates and mapped to real
  space with `space.transform`. A jitted soft-sphere pair energy is then built
  on each of:

    * `space.periodic` on the side lengths (the diagonal for a matrix box),
    * `space.periodic_general(box)` with its default coordinate convention,
    * `space.periodic_general(box, fractional_coordinates=False)`.

  Args:
    N: Number of particles.
    dim: Spatial dimension.
    dtype: Floating point dtype. It is called directly to build a scalar box
      and used as the array dtype otherwise.
    box_format: Must be in `BOX_FORMATS`. 'vector' builds a length-`dim` side
      vector, 'matrix' a diagonal `dim x dim` matrix, anything else a scalar.

  Returns:
    `(R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g))`: fractional positions,
    real positions, the box, and the (shift, energy) pairs for the three
    spaces listed above, in that order.
  """
  assert box_format in BOX_FORMATS

  side = quantity.box_size_at_number_density(N, 1.0, dim)
  if box_format == 'vector':
    box = jnp.array(jnp.ones(dim) * side, dtype)
  elif box_format == 'matrix':
    box = jnp.array(jnp.eye(dim) * side, dtype)
  else:
    box = dtype(side)

  # `space.periodic` takes side lengths, so a matrix box is reduced to its
  # diagonal.
  d, s = space.periodic(jnp.diag(box) if box_format == 'matrix' else box)
  d_gf, s_gf = space.periodic_general(box)
  d_g, s_g = space.periodic_general(box, fractional_coordinates=False)

  R_f = random.uniform(random.PRNGKey(0), (N, dim), dtype=dtype)
  R = space.transform(box, R_f)

  E, E_gf, E_g = [
      jit(energy.soft_sphere_pair(disp)) for disp in (d, d_gf, d_g)
  ]

  return R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g)
示例#13
0
def main(unused_argv):
    """Minimize a 2D two-species soft-sphere system with FIRE, printing progress.

    N = 500 particles are placed at random in a periodic box of side 25,
    split between species 0 and 1. FIRE descent runs for 50 steps. Every 10
    steps the step number, the energy and the largest absolute force
    component are printed.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    key = random.PRNGKey(0)

    # Setup some variables describing the system.
    N = 500
    dimension = 2
    box_size = f32(25.0)

    # Create helper functions to define a periodic box of some size.
    displacement, shift = space.periodic(box_size)

    # Use JAX's random number generator to generate random initial positions.
    key, split = random.split(key)
    R = random.uniform(split, (N, dimension),
                       minval=0.0,
                       maxval=box_size,
                       dtype=f32)

    # The system ought to be a 50:50 mixture of two types of particles, one
    # large and one small.
    sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
    N_2 = N // 2
    # `N - N_2` keeps the species array at length N even for an odd N.
    species = np.array([0] * N_2 + [1] * (N - N_2), dtype=i32)

    # Create an energy function.
    energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
    force_fn = quantity.force(energy_fn)

    # Create a minimizer.
    init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
    opt_state = init_fn(R)

    # Minimize the system.
    minimize_steps = 50
    print_every = 10

    print('Minimizing.')
    print('Step\tEnergy\tMax Force')
    print('-----------------------------------')
    for step in range(minimize_steps):
        opt_state = apply_fn(opt_state)

        if step % print_every == 0:
            R = opt_state.position
            # Report the largest force *magnitude* component. A plain max
            # would ignore large negative components.
            max_force = np.max(np.abs(force_fn(R)))
            print('{:d}\t{:.2f}\t{:.2f}'.format(step, energy_fn(R),
                                                max_force))
示例#14
0
  def test_cell_list_incommensurate(self, spatial_dimension, dtype):
    """Cell-list energy matches the direct pair energy on an uneven grid.

    The periodic box side (12.1) is not a multiple of the cell size (3.0).
    """
    box_size = f32(12.1)
    cell_size = f32(3.0)
    displacement, _ = space.periodic(box_size)

    R = box_size * random.uniform(
      random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    direct_energy = energy.soft_sphere_pair(displacement)
    pairwise = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
    binned_energy = jit(smap.cell_list(pairwise, box_size, cell_size, R))

    expected = np.array(direct_energy(R), dtype=dtype)
    self.assertAllClose(expected, binned_energy(R), True)
示例#15
0
  def test_pair_cell_list_energy(self, spatial_dimension, dtype):
    """Cell-list pair energy equals `energy.soft_sphere_pair` in a 9 x 9 box.

    The cell list is built from `smap.cartesian_product(energy.soft_sphere,
    metric)` with unit-sized cells.
    """
    box_size, cell_size = f32(9.0), f32(1.0)
    displacement, _ = space.periodic(box_size)
    metric = space.metric(displacement)

    reference_energy = energy.soft_sphere_pair(displacement)
    pair_energy = smap.cartesian_product(energy.soft_sphere, metric)

    R = box_size * random.uniform(
      random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    cell_energy = smap.cell_list(pair_energy, box_size, cell_size, R)

    expected = np.array(reference_energy(R), dtype=dtype)
    self.assertAllClose(expected, cell_energy(R), True)
示例#16
0
  def test_cell_list_direct_force_jit(self, spatial_dimension, dtype):
    """Jitted cell-list forces match the forces of the direct pair energy."""
    box_size = f32(9.0)
    cell_size = f32(1.0)
    displacement, _ = space.periodic(box_size)

    direct_force = quantity.force(energy.soft_sphere_pair(displacement))

    R = box_size * random.uniform(
      random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    pair_energy = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
    cell_force = jit(smap.cell_list(quantity.force(pair_energy),
                                    box_size, cell_size, R))

    self.assertAllClose(
      np.array(direct_force(R), dtype=dtype), cell_force(R), True)
示例#17
0
    def test_soft_sphere_neighbor_list_energy(self, spatial_dimension, dtype,
                                              format):
        """Neighbor-list soft-sphere energy equals the all-pairs energy.

        Args:
          format: Forwarded as `format=` to `energy.soft_sphere_neighbor_list`
            (presumably the neighbor-list layout under test).
        """
        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)

        positions = box_size * random.uniform(
            random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension),
            dtype=dtype)

        reference_fn = energy.soft_sphere_pair(displacement)
        neighbor_fn, nbr_energy_fn = energy.soft_sphere_neighbor_list(
            displacement, box_size, format=format)
        neighbors = neighbor_fn.allocate(positions)

        expected = np.array(reference_fn(positions), dtype=dtype)
        self.assertAllClose(expected, nbr_energy_fn(positions, neighbors))
示例#18
0
    def test_pair_grid_force_nonuniform(self, spatial_dimension, dtype):
        """Grid-mapped forces match direct forces in a box with unequal sides.

        The box is 8 x 10 in 2D and 8 x 10 x 12 otherwise.
        """
        sides = [8.0, 10.0] if spatial_dimension == 2 else [8.0, 10.0, 12.0]
        # Shape (1, spatial_dimension), so it broadcasts against the
        # (PARTICLE_COUNT, spatial_dimension) samples below.
        box_size = f32(np.array([sides]))
        cell_size = f32(2.0)

        displacement, _ = space.periodic(box_size[0])
        force_fn = quantity.force(
            energy.soft_sphere_pair(displacement, quantity.Dynamic))

        R = box_size * random.uniform(random.PRNGKey(1),
                                      (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        grid_force_fn = smap.grid(force_fn, box_size, cell_size, R)

        # All particles share species 0.
        species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)
        expected = np.array(force_fn(R, species, 1), dtype=dtype)
        self.assertAllClose(expected, grid_force_fn(R), True)
示例#19
0
    def test_EMT_from_db_dynamic(self, spatial_dimension, dtype, low_pressure):
        """Elastic moduli with a dynamically passed sigma match stored data.

        For each of `NUM_SAMPLES` stored configurations this test does the
        following:
          * checks that the energy gradient is below a dtype-dependent
            threshold;
          * computes the elastic tensor with `elasticity.athermal_moduli`;
          * checks its dtype and rank-4 shape.

        When the moduli computation reports convergence, the tensor is also
        compared with the stored reference and checked for index symmetries.

        NOTE(review): samples that do not converge are skipped silently, so
        the test passes even if no sample converges.
        """
        if dtype == jnp.float32:
            max_grad_thresh, atol, rtol = 1e-5, 1e-4, 1e-3
        else:
            max_grad_thresh, atol, rtol = 1e-10, 1e-8, 1e-5

        expected_shape = (spatial_dimension, ) * 4

        for index in range(NUM_SAMPLES):
            cijkl, R, sigma, box = test_util.load_elasticity_test_data(
                spatial_dimension, low_pressure, dtype, index)
            # Map the stored positions to real space and reduce the box to a
            # scalar side length for `space.periodic`.
            R = space.transform(box, R)
            side = box[0, 0]

            displacement, _ = space.periodic(side)
            # The energy is built with a placeholder sigma=1.0. The per-sample
            # sigma is supplied as a keyword at every call below.
            energy_fn = energy.soft_sphere_pair(displacement, sigma=1.0)
            grad_max = jnp.max(jnp.abs(grad(energy_fn)(R, sigma=sigma)))
            assert (grad_max < max_grad_thresh)

            moduli_fn = jit(
                elasticity.athermal_moduli(energy_fn, check_convergence=True))
            C, converged = moduli_fn(R, side, sigma=sigma)
            assert (C.dtype == dtype)
            assert (C.shape == expected_shape)
            if not converged:
                continue

            self.assertAllClose(cijkl,
                                elasticity._extract_elements(C, False),
                                atol=atol,
                                rtol=rtol)

            # Index symmetries of the elastic tensor.
            for pattern in ("ijkl->jikl", "ijkl->ijlk", "ijkl->lkij"):
                self.assertAllClose(C, jnp.einsum(pattern, C))
示例#20
0
    def test_npt_nose_hoover_jammed(self, dtype, sy_steps):
        """NPT Nose-Hoover on a jammed state conserves its invariant and
        holds the stored pressure.

        The stored configuration is integrated for `DYNAMICS_STEPS` steps via
        `simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT, ...)`, with
        kT = 1e-3 and P = `state.pressure`. After each step the test records:
          * the NPT invariant;
          * the instantaneous pressure, including the kinetic contribution.

        The invariant must match its initial value. The pressure must match P
        within rtol = atol = 0.05.

        Args:
          dtype: Floating point dtype the test state is loaded in.
          sy_steps: Value of the `sy_steps` option, passed in both chain-kwargs
            arguments of `simulate.npt_nose_hoover`.
        """
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic_general(state.box)

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)
        invariant = partial(simulate.npt_nose_hoover_invariant, E)
        pressure_fn = partial(quantity.pressure, E)

        # Key by the option *name*. `{sy_steps: sy_steps}` used the integer
        # value as the key, so the setting was never supplied under the
        # `sy_steps` name.
        nhc_kwargs = {'sy_steps': sy_steps}
        kT = 1e-3
        P = state.pressure
        init_fn, apply_fn = simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT,
                                                     nhc_kwargs, nhc_kwargs)
        apply_fn = jit(apply_fn)

        state = init_fn(key, state.fractional_position, state.box)

        E_initial = invariant(state, P, kT) * np.ones((DYNAMICS_STEPS, ))
        P_target = P * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_energies_pressures):
            # Named `energies` so the `energy` module is not shadowed here.
            state, energies, pressures = state_energies_pressures
            state = apply_fn(state)
            # `jax.ops.index_update` was removed from JAX. The functional
            # `.at[...].set(...)` update replaces it.
            energies = energies.at[i].set(invariant(state, P, kT))
            box = simulate.npt_box(state)
            KE = quantity.kinetic_energy(state.velocity, state.mass)
            p = pressure_fn(state.position, box, KE)
            pressures = pressures.at[i].set(p)
            return state, energies, pressures

        Es = np.zeros((DYNAMICS_STEPS, ))
        Ps = np.zeros((DYNAMICS_STEPS, ))
        state, Es, Ps = lax.fori_loop(0, DYNAMICS_STEPS, step_fn,
                                      (state, Es, Ps))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
        self.assertAllClose(Ps, P_target, rtol=0.05, atol=0.05)
示例#21
0
    def test_EMT_from_db_fraccoord(self, spatial_dimension, dtype,
                                   low_pressure):
        """Elastic moduli in fractional coordinates match stored reference data.

        Each stored configuration is placed in a `space.periodic_general` box
        with `fractional_coordinates=True`. Its energy gradient must be below
        a dtype-dependent threshold. `elasticity.athermal_moduli` must return
        a rank-4 tensor of `dtype`. When the computation reports convergence,
        the tensor is also compared with the reference and checked for index
        symmetries.

        NOTE(review): non-converged samples are skipped silently.
        """
        if dtype == jnp.float32:
            max_grad_thresh, atol, rtol = 1e-5, 1e-4, 1e-3
        else:
            max_grad_thresh, atol, rtol = 1e-10, 1e-8, 1e-5

        expected_shape = (spatial_dimension, ) * 4

        for index in range(NUM_SAMPLES):
            cijkl, R, sigma, box = test_util.load_elasticity_test_data(
                spatial_dimension, low_pressure, dtype, index)

            displacement, _ = space.periodic_general(
                box, fractional_coordinates=True)
            energy_fn = energy.soft_sphere_pair(displacement, sigma=sigma)
            assert (jnp.max(jnp.abs(grad(energy_fn)(R))) < max_grad_thresh)

            moduli_fn = jit(
                elasticity.athermal_moduli(energy_fn, check_convergence=True))
            C, converged = moduli_fn(R, box)
            assert (C.dtype == dtype)
            assert (C.shape == expected_shape)
            if not converged:
                continue

            self.assertAllClose(cijkl,
                                elasticity._extract_elements(C, False),
                                atol=atol,
                                rtol=rtol)

            # Index symmetries of the elastic tensor.
            for pattern in ("ijkl->jikl", "ijkl->ijlk", "ijkl->lkij"):
                self.assertAllClose(C, jnp.einsum(pattern, C))
示例#22
0
  def test_cell_list_force_nonuniform(self, spatial_dimension, dtype):
    """Cell-list forces match direct forces in a box with unequal sides.

    The box is 8 x 10 in 2D and 8 x 10 x 12 otherwise, with 2.0-wide cells.
    """
    key = random.PRNGKey(1)

    if spatial_dimension == 2:
      box_size = f32(np.array([[8.0, 10.0]]))
    else:
      box_size = f32(np.array([[8.0, 10.0, 12.0]]))
    cell_size = f32(2.0)
    displacement, _ = space.periodic(box_size[0])
    energy_fn = energy.soft_sphere_pair(displacement)
    force_fn = quantity.force(energy_fn)

    R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    cell_energy_fn = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
    cell_force_fn = quantity.force(cell_energy_fn)
    cell_force_fn = smap.cell_list(cell_force_fn, box_size, cell_size, R)

    # Evaluate each force field exactly once and compare.
    expected = np.array(force_fn(R), dtype=dtype)
    actual = cell_force_fn(R)
    self.assertAllClose(expected, actual, True)