示例#1
0
  def test_triplet_static_species_scalar(self, spatial_dimension, dtype):
    """Triplet energy with static species and a per-species-triplet param.

    The reference value is half the sum of squared pair distances over every
    ordered species pair (i, j) with i, j in {0, 1}; the triplet energy is
    divided by the particle count before comparison.
    """
    key = random.PRNGKey(0)
    # The angle function only looks at the first bond vector dR1.
    angle_fn = lambda dR1, dR2, param=5.0: param * np.sum(np.square(dR1))
    # NOTE(review): each params[i, j, :] sums to 2, which appears to be why the
    # reference below does not use params; confirm against smap.triplet.
    params = f32(np.array([[[1., 1.], [2., 0.]], [[0., 2.], [1., 1.]]]))

    count = PARTICLE_COUNT // 50
    key, split = random.split(key)
    species = random.randint(split, (count,), 0, 2)
    displacement, _ = space.free()
    # Squared distance between two positions.
    metric = lambda Ra, Rb, **kwargs: \
      np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)
    triplet_square = smap.triplet(angle_fn,
                                  displacement,
                                  species=species,
                                  param=params,
                                  reduce_axis=None)

    metric = space.map_product(metric)
    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
          split, (count, spatial_dimension), dtype=dtype)
      total = 0.
      for i in range(2):
        for j in range(2):
          R_1 = R[species == i]
          R_2 = R[species == j]
          total += 0.5 * np.sum(metric(R_1, R_2))
      self.assertAllClose(triplet_square(R) / count,
                          np.array(total, dtype=dtype))
示例#2
0
    def test_brownian(self, spatial_dimension, dtype):
        """Free Brownian particles spread with the theoretical variance.

        With zero potential energy, the per-coordinate variance after time t
        is expected to be 2 T t / (mass * gamma). The empirical variance must
        match it to within 1%, and positions must keep ``dtype``.
        """
        key = random.PRNGKey(0)
        key, T_split, mass_split = random.split(key, 3)

        _, shift = space.free()
        # Constant zero energy: particles move under the noise alone.
        energy_fn = lambda R, **kwargs: f32(0)

        # Use the parameterized dimension (not a hard-coded 2) so that every
        # spatial_dimension the test is run with is actually exercised.
        R = np.zeros((BROWNIAN_PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        mass = random.uniform(mass_split, (),
                              minval=0.1,
                              maxval=10.0,
                              dtype=dtype)
        T = random.uniform(T_split, (), minval=0.3, maxval=1.4, dtype=dtype)

        dt = f32(1e-2)
        gamma = f32(0.1)

        init_fn, apply_fn = simulate.brownian(energy_fn,
                                              shift,
                                              dt,
                                              T,
                                              gamma=gamma)
        apply_fn = jit(apply_fn)

        state = init_fn(key, R, mass)

        sim_t = f32(BROWNIAN_DYNAMICS_STEPS * dt)
        for _ in range(BROWNIAN_DYNAMICS_STEPS):
            state = apply_fn(state)

        # np.var flattens, pooling every coordinate of every particle. Since all
        # particles start at the origin, this is the per-coordinate variance.
        msd = np.var(state.position)
        th_msd = dtype(2 * T / (mass * gamma) * sim_t)
        assert np.abs(msd - th_msd) / msd < 1e-2
        assert state.position.dtype == dtype
示例#3
0
    def test_nve_ensemble_time_dependence(self, spatial_dimension, dtype):
        """NVE integration with an explicit time argument conserves energy.

        Each step passes ``t`` to the integrator. Total (potential + kinetic)
        energy of a soft-sphere system must stay within 1% of its initial
        value, and positions must keep ``dtype``.
        """
        key = random.PRNGKey(0)
        # The second subkey is not needed by this test (there is no harmonic
        # centre). The four-way split is kept so the other keys are unchanged.
        pos_key, _, vel_key, mass_key = random.split(key, 4)
        R = random.normal(pos_key, (PARTICLE_COUNT, spatial_dimension),
                          dtype=dtype)
        mass = random.uniform(mass_key, (PARTICLE_COUNT, ),
                              minval=0.1,
                              maxval=5.0,
                              dtype=dtype)
        displacement, shift = space.free()

        E = energy.soft_sphere_pair(displacement)

        # Single source of truth for the step size used by both the
        # integrator and the time passed at every step.
        dt = 1e-3
        init_fn, apply_fn = simulate.nve(E, shift, dt)
        apply_fn = jit(apply_fn)

        state = init_fn(vel_key, R, mass=mass)

        E_T = lambda state: \
            E(state.position) + quantity.kinetic_energy(state.velocity, state.mass)
        E_initial = E_T(state)

        for step in range(SHORT_DYNAMICS_STEPS):
            state = apply_fn(state, t=step * dt)
            E_total = E_T(state)
            assert np.abs(E_total - E_initial) < E_initial * 0.01
            assert state.position.dtype == dtype
示例#4
0
  def test_nvt_langevin(self, spatial_dimension, dtype):
    """Langevin NVT dynamics thermalize to the target temperature.

    For each stochastic sample, particles in the harmonic well
    sum((R - R0)**2) start at T_initial = 1. The instantaneous temperature is
    recorded every ``sample_every`` steps after ``burn_in_steps``, and its
    mean must be within 0.1 of the target T.
    """
    key = random.PRNGKey(0)
    # Equilibration steps that are excluded from the temperature average.
    burn_in_steps = 4000
    sample_every = 100
    # The free-space shift does not depend on the sample, so build it once.
    _, shift = space.free()

    for _ in range(STOCHASTIC_SAMPLES):
      key, R_key, R0_key, T_key, masses_key = random.split(key, 5)
      # Give the integrator its own subkey. This way ``key`` is only ever
      # split (to seed the next sample) and never also consumed by init_fn.
      key, init_key = random.split(key)

      R = random.normal(
        R_key, (LANGEVIN_PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      R0 = random.normal(
        R0_key, (LANGEVIN_PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      E = functools.partial(
          lambda R, R0, **kwargs: np.sum((R - R0) ** 2), R0=R0)

      T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
      mass = random.uniform(
        masses_key, (LANGEVIN_PARTICLE_COUNT,), minval=0.1, maxval=10.0, dtype=dtype)
      init_fn, apply_fn = simulate.nvt_langevin(E, shift, f32(1e-2), T, gamma=f32(0.3))
      apply_fn = jit(apply_fn)

      state = init_fn(init_key, R, mass=mass, T_initial=dtype(1.0))

      T_list = []
      for step in range(LANGEVIN_DYNAMICS_STEPS):
        state = apply_fn(state)
        if step > burn_in_steps and step % sample_every == 0:
          T_list += [quantity.temperature(state.velocity, state.mass)]

      # With no samples the mean would be NaN and the check below would fail
      # cryptically; fail with an explicit message instead.
      assert T_list, 'LANGEVIN_DYNAMICS_STEPS too small to sample temperature.'
      T_emp = np.mean(np.array(T_list))
      assert np.abs(T_emp - T) < 0.1
      assert state.position.dtype == dtype
示例#5
0
    def test_nve_ensemble(self, spatial_dimension, dtype):
        """NVE dynamics in a harmonic well conserve total energy.

        Particles are tethered to random centres R0 by sum((R - R0)**2). After
        every step, potential + kinetic energy must stay within 1% of the
        starting value, and positions must retain ``dtype``.
        """
        pos_key, center_key, vel_key, mass_key = random.split(
            random.PRNGKey(0), 4)
        shape = (PARTICLE_COUNT, spatial_dimension)
        R = random.normal(pos_key, shape, dtype=dtype)
        R0 = random.normal(center_key, shape, dtype=dtype)
        mass = random.uniform(
            mass_key, (PARTICLE_COUNT, ), minval=0.1, maxval=5.0, dtype=dtype)

        shift = space.free()[1]

        def E(R, **kwargs):
            return np.sum((R - R0)**2)

        def total_energy(state):
            kinetic = quantity.kinetic_energy(state.velocity, state.mass)
            return E(state.position) + kinetic

        init_fn, step_fn = simulate.nve(E, shift, 1e-3)
        step_fn = jit(step_fn)

        state = init_fn(vel_key, R, mass=mass)
        reference = total_energy(state)

        for _ in range(DYNAMICS_STEPS):
            state = step_fn(state)
            drift = np.abs(total_energy(state) - reference)
            assert drift < reference * 0.01
            assert state.position.dtype == dtype
示例#6
0
    def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype):
        """Dynamic-species smap.pair agrees with an explicit per-species sum.

        Species are supplied when the mapped function is called
        (``quantity.Dynamic``). For every ordered species pair (a, b), the
        reference applies ``square`` with ``params[a, b]`` to the all-pairs
        squared distances between species a and species b. The total is then
        halved.
        """
        key = random.PRNGKey(0)

        def square(dr, param=1.0):
            return param * dr**2

        params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

        key, species_key = random.split(key)
        species = random.randint(species_key, (PARTICLE_COUNT, ), 0, 2)
        displacement, _ = space.free()

        def metric(Ra, Rb, **kwargs):
            return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        mapped_square = smap.pair(square,
                                  metric,
                                  species=quantity.Dynamic,
                                  param=params)
        all_pairs = space.map_product(metric)

        for _ in range(STOCHASTIC_SAMPLES):
            key, sample_key = random.split(key)
            R = random.uniform(sample_key, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            expected = 0.0
            for a in range(2):
                group_a = R[species == a]
                for b in range(2):
                    group_b = R[species == b]
                    expected = expected + 0.5 * np.sum(
                        square(all_pairs(group_a, group_b), params[a, b]))
            self.assertAllClose(mapped_square(R, species, 2),
                                np.array(expected, dtype=dtype))
示例#7
0
    def test_stress_non_minimized_free(self, dim, dtype):
        """Autodiff stress agrees with an explicit pairwise stress expression.

        A random (non-minimized) soft-sphere configuration at number density
        0.8 is generated in free space. ``quantity.stress`` is compared with a
        direct pairwise sum built from ``grad(energy.soft_sphere)``.
        """
        key = random.PRNGKey(0)
        N = 64

        box = quantity.box_size_at_number_density(N, 0.8, dim)
        displacement_fn, _ = space.free()

        # Draw positions in the dtype under test, so that the dtype-dependent
        # tolerance below matches the precision actually used.
        pos = random.uniform(key, (N, dim), dtype=dtype) * box

        energy_fn = energy.soft_sphere_pair(displacement_fn)

        def compute_exact_stress(R):
            dR = space.map_product(displacement_fn)(R, R)
            dr = space.distance(dR)
            g = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
            V = quantity.volume(dim, box)
            # Each unordered pair appears twice in the all-pairs sum.
            dUdr = 0.5 * g(dr)[:, :, None, None]
            # Add the identity so self-pair distances are 1 rather than 0,
            # which avoids dividing by zero. Those terms are multiplied by
            # dR_ii = 0.
            dr = (dr + jnp.eye(N))[:, :, None, None]
            return -jnp.sum(dUdr * dR[:, :, None, :] * dR[:, :, :, None] /
                            (V * dr),
                            axis=(0, 1))

        # Keep the helper and its result under different names (no shadowing).
        exact_stress = compute_exact_stress(pos)
        ad_stress = quantity.stress(energy_fn, pos, box)

        tol = 1e-7 if dtype is f64 else 2e-5

        self.assertAllClose(exact_stress, ad_stress, atol=tol, rtol=tol)
示例#8
0
  def test_gradient_descent(self, spatial_dimension, dtype):
    """Gradient descent strictly lowers a harmonic energy at every step.

    Positions are fed straight to the optimizer's apply function, and its
    output is evaluated as positions. Both the energy sum((R - R0)**2) and
    the squared distance to R0 must decrease monotonically and keep ``dtype``.
    """
    key = random.PRNGKey(0)
    shape = (PARTICLE_COUNT, spatial_dimension)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split, split0 = random.split(key, 3)
      R = random.uniform(split, shape, dtype=dtype)
      R0 = random.uniform(split0, shape, dtype=dtype)

      # Named energy_fn so the local does not shadow the ``energy`` module.
      def energy_fn(R, **kwargs):
        return np.sum((R - R0) ** 2)

      shift_fn = space.free()[1]
      _, opt_apply = minimize.gradient_descent(energy_fn, shift_fn, f32(1e-1))

      prev_energy = energy_fn(R)
      prev_sq_dist = np.sum((R - R0) ** 2)

      for _ in range(OPTIMIZATION_STEPS):
        R = opt_apply(R)
        cur_energy = energy_fn(R)
        cur_sq_dist = np.sum((R - R0) ** 2)
        assert cur_energy < prev_energy
        assert cur_energy.dtype == dtype
        assert cur_sq_dist < prev_sq_dist
        assert cur_sq_dist.dtype == dtype
        prev_energy, prev_sq_dist = cur_energy, cur_sq_dist
示例#9
0
    def test_pair_dynamic_species_vector(self, spatial_dimension, dtype):
        """Dynamic-species smap.pair on displacement vectors.

        For every ordered species pair (i, j), the reference applies
        ``square`` with ``params[i, j]`` to the all-pairs displacement vectors
        between the two species groups. The total is then halved.
        """
        key = random.PRNGKey(0)

        # dr has the spatial components on axis 2 (all-pairs layout).
        square = lambda dr, param=1.0: param * np.sum(dr**2, axis=2)
        params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

        key, split = random.split(key)
        species = random.randint(split, (PARTICLE_COUNT, ), 0, 2)
        disp, _ = space.free()

        mapped_square = smap.pair(square,
                                  disp,
                                  species=quantity.Dynamic,
                                  param=params)

        # All-pairs displacements via the shared helper, as the sibling tests
        # do, instead of hand-nesting two vmaps.
        disp = space.map_product(disp)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            total = 0.0
            for i in range(2):
                for j in range(2):
                    param = params[i, j]
                    R_1 = R[species == i]
                    R_2 = R[species == j]
                    total = total + 0.5 * np.sum(square(disp(R_1, R_2), param))
            self.assertAllClose(mapped_square(R, species, 2),
                                np.array(total, dtype=dtype))
示例#10
0
    def test_fire_descent(self, spatial_dimension, dtype):
        """FIRE descent lowers a harmonic energy on every triple of steps.

        Each outer iteration advances the optimizer three times under jit.
        Both the energy and the squared distance to the target R0 must
        strictly decrease and stay in ``dtype``.
        """
        key = random.PRNGKey(0)
        shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split, split0 = random.split(key, 3)
            R = random.uniform(split, shape, dtype=dtype)
            R0 = random.uniform(split0, shape, dtype=dtype)

            # Named energy_fn so the local does not shadow the energy module.
            def energy_fn(R, **kwargs):
                return np.sum((R - R0)**2)

            shift_fn = space.free()[1]

            opt_init, opt_apply = minimize.fire_descent(energy_fn, shift_fn)
            advance = jit(lambda s: opt_apply(opt_apply(opt_apply(s))))

            opt_state = opt_init(R)
            prev_energy = energy_fn(R)
            prev_sq_dist = np.sum((R - R0)**2)

            for _ in range(OPTIMIZATION_STEPS):
                opt_state = advance(opt_state)
                pos = opt_state.position
                cur_energy = energy_fn(pos)
                cur_sq_dist = np.sum((pos - R0)**2)
                assert cur_energy < prev_energy
                assert cur_energy.dtype == dtype
                assert cur_sq_dist < prev_sq_dist
                assert cur_sq_dist.dtype == dtype
                prev_energy, prev_sq_dist = cur_energy, cur_sq_dist
示例#11
0
  def test_graph_network_neighbor_list_moving(self,
                                              spatial_dimension,
                                              dtype,
                                              format):
    """Neighbor-list GNN energy matches the dense GNN after small moves.

    The neighbor list is allocated for the initial positions. Particles are
    then displaced by at most 0.05 per coordinate, and the list is used
    without updating. Dense format is compared at the default tolerance;
    other formats at 2e-4.
    """
    if format is partition.OrderedSparse:
      self.skipTest('OrderedSparse format incompatible with GNN '
                    'force field.')

    key = random.PRNGKey(0)
    shape = (32, spatial_dimension)
    R = random.uniform(key, shape, dtype=dtype)

    d, _ = space.free()
    cutoff, dr_threshold = 0.3, 0.1

    init_fn, energy_fn = energy.graph_network(d, cutoff)
    params = init_fn(key, R)

    neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
        d, 1.0, cutoff, dr_threshold, format=format)

    nbrs = neighbor_fn.allocate(R)
    jitter_key = random.fold_in(key, 1)
    R = R + random.uniform(jitter_key, shape,
                           minval=-0.05, maxval=0.05, dtype=dtype)

    reference_E = energy_fn(params, R)
    nl_E = nl_energy_fn(params, R, nbrs)
    if format is partition.Dense:
      self.assertAllClose(reference_E, nl_E)
    else:
      self.assertAllClose(reference_E, nl_E, rtol=2e-4, atol=2e-4)
示例#12
0
    def test_cell_list_overflow(self):
        """Moving points into contact after allocation overflows the buffer.

        The list is allocated while all four points are farther apart than
        ``r_cutoff``. After the first two points are placed on top of each
        other, the update must report ``did_buffer_overflow``. The neighbor
        index array must stay int32 throughout.
        """
        displacement_fn, _ = space.free()

        neighbor_list_fn = partition.neighbor_list(
            displacement_fn,
            box_size=100.0,
            r_cutoff=3.0,
            dr_threshold=0.0,
        )

        # Neighbouring points are sqrt(200) ~ 14.1 apart, far beyond r_cutoff.
        far_apart = jnp.array([
            [20.0, 20.0],
            [30.0, 30.0],
            [40.0, 40.0],
            [50.0, 50.0],
        ])
        neighbors = neighbor_list_fn.allocate(far_apart)
        self.assertEqual(neighbors.idx.dtype, jnp.int32)

        # The first two points now coincide.
        overlapping = jnp.array([
            [20.0, 20.0],
            [20.0, 20.0],
            [40.0, 40.0],
            [50.0, 50.0],
        ])
        neighbors = neighbors.update(overlapping)
        self.assertTrue(neighbors.did_buffer_overflow)
        self.assertEqual(neighbors.idx.dtype, jnp.int32)
示例#13
0
 def test_radial_symmetry_functions(self, N_types, N_etas, dtype):
   """Radial symmetry functions: output shape and one reference value.

   Three atoms are featurized with ``N_etas`` eta values spaced over [1, 2].
   The output must have one row per atom and ``N_types * N_etas`` columns.
   """
   displacement, _ = space.free()
   gr = nn.radial_symmetry_functions(displacement,
                                     np.array([1, 1, N_types]),
                                     np.linspace(1.0, 2.0, N_etas, dtype=dtype),
                                     4)
   R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
   gr_out = gr(R)
   # Shapes are integer tuples: compare them exactly, not with a tolerance.
   self.assertEqual(gr_out.shape, (3, N_types * N_etas))
   self.assertAllClose(gr_out[2, 0], dtype(0.411717), rtol=1e-6, atol=1e-6)
示例#14
0
 def test_angular_symmetry_functions(self, N_types, N_etas, dtype):
   """Angular symmetry functions: output shape and one reference value.

   Three atoms are featurized with ``N_etas`` identical (eta, lambda, zeta)
   settings and cutoff 8.0. The output must have one row per atom and
   ``N_etas * N_types * (N_types + 1) // 2`` columns.
   """
   displacement, _ = space.free()
   gr = nn.angular_symmetry_functions(
       displacement, np.array([1, 1, N_types]),
       etas=np.array([1e-4 / (0.529177 ** 2)] * N_etas, dtype),
       lambdas=np.array([-1.0] * N_etas, dtype),
       zetas=np.array([1.0] * N_etas, dtype),
       cutoff_distance=8.0)
   R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
   gr_out = gr(R)
   # Shapes are integer tuples: compare them exactly, not with a tolerance.
   self.assertEqual(gr_out.shape,
                    (3, N_etas * N_types * (N_types + 1) // 2))
   self.assertAllClose(gr_out[2, 0], dtype(1.577944), rtol=1e-6, atol=1e-6)
示例#15
0
def run(N=32, n_iter=1000, with_jit=True):
    """Time ``n_iter`` Brownian-dynamics steps of ``N`` soft spheres in 2D.

    Args:
        N: Number of particles.
        n_iter: Number of integration steps to time.
        with_jit: If True, the step function is wrapped in ``jax.jit``.

    Returns:
        A list of ``n_iter`` wall-clock durations in nanoseconds, one per
        step. With ``with_jit=True``, the first entry presumably includes
        tracing and compilation of the step function.
    """
    import time

    import jax.numpy as jnp
    from jax import random, jit
    from jax_md import space, energy, simulate

    # MD configs
    dt = 1e-1
    temperature = 0.1

    # R: current position
    # dR: displacement
    # displacement(Ra, Rb):
    #   dR = Ra - Rb
    # periodic displacement(Ra, Rb):
    #   dR = Ra - Rb
    #   np.mod(dR + side * f32(0.5), side) - f32(0.5) * side
    # periodic shift:
    #   np.mod(R + dR, side)
    # shift:
    #   R + dR
    displacement, shift = space.free()

    # Simulation init
    # dr: pairwise distances
    # epsilon: interaction energy scale (const)
    # alpha: interaction stiffness
    # dr = distance(R)
    # U(dr) = np.where(dr < 1.0, (1 - dr) ** 2, 0)
    # energy_fn(R) = diagonal_mask(U(dr))
    energy_fn = energy.soft_sphere_pair(displacement)

    # force(energy) = -d(energy)/dR
    # xi = random.normal(R.shape, R.dtype)
    # gamma = 0.1
    # nu = 1 / (mass * gamma)
    # dR = force(R) * dt * nu + np.sqrt(2 * temperature * dt * nu) * xi
    # BrownianState(position, mass, rng)
    pos_key, sim_key = random.split(random.PRNGKey(0))
    R = random.uniform(pos_key, (N, 2), dtype=jnp.float32)
    init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, temperature)
    if with_jit:
        apply_fn = jit(apply_fn)
    state = init_fn(sim_key, R)

    # Start simulation
    times = []
    for _ in range(n_iter):
        time_start = time.perf_counter_ns()
        state = apply_fn(state)
        # JAX dispatches work asynchronously. Wait for the result, so the
        # interval covers the computation rather than only the dispatch.
        state.position.block_until_ready()
        time_end = time.perf_counter_ns()
        times.append(time_end - time_start)

    # Finish with profiling times
    return times
示例#16
0
 def test_cosine_angles(self, dtype):
     """cosine_angles on the all-pairs displacements of three 2D points.

     The points are (0, 0), (0, 1) and (1, 1). Entry [i, j, k] is the cosine
     of the angle between dR_ij and dR_ik; it is zero wherever one of the two
     displacements is zero.
     """
     pair_disp = space.map_product(space.free()[0])
     R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype)
     cangles = quantity.cosine_angles(pair_disp(R, R))
     c45 = 1 / np.sqrt(2)
     expected = np.array([[[0, 0, 0], [0, 1, c45], [0, c45, 1]],
                          [[1, 0, 0], [0, 0, 0], [0, 0, 1]],
                          [[1, c45, 0], [c45, 1, 0], [0, 0, 0]]],
                         dtype=dtype)
     self.assertAllClose(cangles, expected)
示例#17
0
    def test_custom_mask_function(self):
        """A custom mask function can drop neighbor pairs based on particle ids.

        Ten particles sit at the origin, so without masking each one neighbors
        the other nine (90 entries). Pairs whose ids differ by at most 3 are
        masked out, which leaves 42 entries.
        """
        displacement_fn, _ = space.free()

        n_particles = 10
        R = jnp.broadcast_to(jnp.zeros(3), (n_particles, 3))

        def acceptable_id_pair(id1, id2):
            """Allow an interaction only if the ids differ by more than 3.

            This disables e.g. the 1-2, 1-3 and 1-4 interactions.
            """
            return jnp.abs(id1 - id2) > 3

        def mask_id_based(idx: Array, ids: Array, mask_val: int,
                          _acceptable_id_pair: Callable) -> Array:
            """Apply ``_acceptable_id_pair`` to every neighbor-list entry.

            Row i of ``idx`` belongs to particle i, and each value in the row
            is the index of a neighbor. Rejected entries become ``mask_val``.
            """
            def row_mask(row_idx, id1, all_ids):
                # Mask for one particle's row of neighbor indices.
                neighbor_ids = all_ids.at[row_idx].get()
                return vmap(_acceptable_id_pair, in_axes=(None, 0))(id1,
                                                                    neighbor_ids)

            keep = vmap(row_mask, in_axes=(0, 0, None))(idx, ids, ids)
            return jnp.where(keep, idx, mask_val)

        ids = jnp.arange(n_particles)  # Each particle's id is its index.
        mask_val = n_particles
        custom_mask_function = partial(mask_id_based,
                                       ids=ids,
                                       mask_val=mask_val,
                                       _acceptable_id_pair=acceptable_id_pair)

        neighbor_list_fn = partition.neighbor_list(
            displacement_fn,
            box_size=1.0,
            r_cutoff=3.0,
            dr_threshold=0.0,
            custom_mask_function=custom_mask_function,
        )

        neighbors = neighbor_list_fn.allocate(R)
        neighbors = neighbors.update(R)
        # Unmasked: 10 particles x 9 others = 90 entries. The kept ordered
        # pairs have |id1 - id2| >= 4: 2 * (6 + 5 + 4 + 3 + 2 + 1) = 42.
        self.assertEqual(42, (neighbors.idx != mask_val).sum())
示例#18
0
    def test_pair_scalar_dummy_arg(self, spatial_dimension, dtype):
        """smap.pair accepts extra keyword arguments such as ``t``.

        The pair function swallows unknown kwargs, so calling the mapped
        energy with ``t=...`` must succeed and return a scalar.
        """
        key = random.PRNGKey(0)

        square = lambda dr, param=f32(1.0), **unused_kwargs: param * dr**2

        # Draw from the split subkey rather than reusing the parent key.
        _, split = random.split(key)
        R = random.normal(split, (PARTICLE_COUNT, spatial_dimension),
                          dtype=dtype)
        displacement, _ = space.free()

        mapped = smap.pair(square, space.metric(displacement))

        E = mapped(R, t=f32(0))
        # Beyond "does not raise": the reduced pair energy is a scalar.
        self.assertEqual(E.shape, ())
示例#19
0
    def test_langevin_harmonic(self):
        """Langevin dynamics of an ensemble of 1D harmonic oscillators.

        N independent copies start at x = 1 and are integrated for ``steps``
        steps. The ensemble-mean position and the velocity autocorrelation
        <v(t) v(0)> are compared with closed-form expressions.
        """
        alpha = 1.0

        def E(x):
            return jnp.sum(0.5 * alpha * x**2)

        _, shift = space.free()

        N = 10000
        steps = 1000
        kT = 0.25
        dt = 1e-4
        gamma = 3
        mass = 2.0
        tol = 1e-3

        X = jnp.ones((N, 1, 1))
        keys = random.split(random.PRNGKey(0), N)

        init_fn, step_fn = simulate.nvt_langevin(E, shift, dt, kT, gamma,
                                                 False)
        batched_step = jit(vmap(step_fn))

        state = vmap(init_fn, (0, 0, None))(keys, X, mass)
        v0 = state.velocity

        for _ in range(steps):
            state = batched_step(state)

        # beta_1 and beta_2 are the roots of
        # beta**2 - gamma * beta + alpha / mass = 0.
        d = jnp.sqrt(gamma**2 / 4 - alpha / mass)
        beta_1 = gamma / 2 + d
        beta_2 = gamma / 2 - d

        def decay_1(t):
            return jnp.exp(-beta_1 * t)

        def decay_2(t):
            return jnp.exp(-beta_2 * t)

        A = -beta_2 / (beta_1 - beta_2)
        B = beta_1 / (beta_1 - beta_2)
        Z = kT / (2 * d * mass)

        def expected_position(t):
            return A * decay_1(t) + B * decay_2(t)

        def expected_velocity_autocorrelation(t):
            return Z * (-beta_2 * decay_2(t) + beta_1 * decay_1(t))

        t_final = steps * dt
        self.assertAllClose(jnp.mean(state.position),
                            expected_position(t_final),
                            rtol=tol,
                            atol=tol)
        self.assertAllClose(jnp.mean(state.velocity * v0),
                            expected_velocity_autocorrelation(t_final),
                            rtol=tol,
                            atol=tol)
示例#20
0
    def test_nvt_nose_hoover_ensemble(self, spatial_dimension, dtype):
        """Nose-Hoover NVT reaches the target T and conserves its invariant.

        For each stochastic sample, particles in the harmonic well
        sum((R - R0)**2) start at T_initial = 1 and are integrated for
        DYNAMICS_STEPS. Afterwards the temperature must be within 0.1 of T,
        and the thermostat invariant within 1% of its starting value.
        """
        key = random.PRNGKey(0)
        dof = spatial_dimension * PARTICLE_COUNT
        _, shift = space.free()

        def invariant(energy_fn, T, state):
            """Return the conserved quantity of the Nose-Hoover dynamics."""
            accum = energy_fn(state.position) + quantity.kinetic_energy(
                state.velocity, state.mass)
            # The first chain element's xi term is weighted by the particle
            # degree-of-freedom count; later elements by one.
            accum = accum + (state.v_xi[0]) ** 2 * state.Q[0] * 0.5 + \
                dof * T * state.xi[0]
            for xi, v_xi, Q in zip(state.xi[1:], state.v_xi[1:], state.Q[1:]):
                accum = accum + v_xi**2 * Q * 0.5 + T * xi
            return accum

        shape = (PARTICLE_COUNT, spatial_dimension)
        for _ in range(STOCHASTIC_SAMPLES):
            key, pos_key, center_key, vel_key, T_key, masses_key = \
                random.split(key, 6)

            R = random.normal(pos_key, shape, dtype=dtype)
            R0 = random.normal(center_key, shape, dtype=dtype)

            E = functools.partial(lambda R, R0, **kwargs: np.sum((R - R0)**2),
                                  R0=R0)

            T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
            mass = random.uniform(masses_key, (PARTICLE_COUNT, ),
                                  minval=0.1, maxval=10.0, dtype=dtype)
            init_fn, step_fn = simulate.nvt_nose_hoover(E, shift, 1e-3, T,
                                                        tau=10)
            step_fn = jit(step_fn)

            state = init_fn(vel_key, R, mass=mass, T_initial=dtype(1.0))
            start = invariant(E, T, state)

            for _ in range(DYNAMICS_STEPS):
                state = step_fn(state)

            T_final = quantity.temperature(state.velocity, state.mass)
            assert np.abs(T_final - T) < 0.1
            assert np.abs(invariant(E, T, state) - start) < start * 0.01
            assert state.position.dtype == dtype
示例#21
0
  def test_pair_no_species_vector(self, spatial_dimension, dtype):
    """smap.pair over displacement vectors equals half the all-pairs sum."""
    def square(dr):
      return np.sum(dr ** 2, axis=2)

    displacement, _ = space.free()
    mapped_square = smap.pair(square, displacement)
    all_pairs = space.map_product(displacement)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
          split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      expected = np.array(0.5 * np.sum(square(all_pairs(R, R))), dtype=dtype)
      self.assertAllClose(mapped_square(R), expected)
示例#22
0
    def test_graph_network_shape_dtype(self, spatial_dimension, dtype):
        """The graph-network energy is a scalar in the input dtype."""
        key = random.PRNGKey(0)
        R = random.uniform(key, (32, spatial_dimension), dtype=dtype)

        displacement = space.free()[0]
        init_fn, energy_fn = energy.graph_network(displacement, 0.2)
        params = init_fn(key, R)

        E_out = energy_fn(params, R)
        assert E_out.shape == ()
        assert E_out.dtype == dtype
示例#23
0
 def test_simple_spring(self, spatial_dimension, dtype):
   """Simple spring bond energy equals (dist - length)**alpha / alpha.

   Two particles, at the origin and at the all-ones corner, are bonded. The
   rest length and the exponent alpha are drawn independently for each
   sample.
   """
   key = random.PRNGKey(0)
   disp, _ = space.free()
   # The origin and the all-ones corner are sqrt(spatial_dimension) apart.
   # Building them generically avoids an undefined ``dist`` outside 2D/3D.
   R = np.array([[0.] * spatial_dimension, [1.] * spatial_dimension],
                dtype=dtype)
   dist = np.sqrt(float(spatial_dimension))
   bonds = np.array([[0, 1]], np.int32)
   for _ in range(STOCHASTIC_SAMPLES):
     key, l_key, a_key = random.split(key, 3)
     # Use the dedicated subkeys. Then length and alpha are independent, and
     # ``key`` is only ever split, never consumed.
     length = random.uniform(l_key, (), minval=0.1, maxval=3.0)
     alpha = random.uniform(a_key, (), minval=2., maxval=4.)
     E = energy.simple_spring_bond(disp, bonds, length=length, alpha=alpha)
     E_exact = dtype((dist - length) ** alpha / alpha)
     self.assertAllClose(E(R), E_exact, check_dtypes=True)
示例#24
0
    def test_cosine_angles_neighbors(self, dtype):
        """cosine_angles on per-particle neighbor displacements.

        Each of the three particles has two explicit neighbor positions.
        Entry [i, j, k] is compared against the cosine between particle i's
        j-th and k-th displacement.
        """
        displacement, _ = space.free()
        # Inner vmap: one particle against each of its neighbors.
        # Outer vmap: over the particles.
        neighbor_disp = vmap(vmap(displacement, (None, 0)), 0)

        R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype)
        R_neigh = np.array(
            [[[0, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]],
            dtype=dtype)

        c45 = 1 / np.sqrt(2)
        expected = np.array(
            [[[1, c45], [c45, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
            dtype=dtype)
        self.assertAllClose(
            quantity.cosine_angles(neighbor_disp(R, R_neigh)), expected)
示例#25
0
    def test_bond_no_type_static(self, spatial_dimension, dtype):
        """smap.bond over a fixed bond list equals the sum of bond energies.

        Bonds (0, 1) and (0, 2) use a harmonic potential with rest length 1.
        """
        def harmonic(dr, **kwargs):
            return (dr - f32(1))**f32(2)

        metric = space.metric(space.free()[0])
        bonds = np.array([[0, 1], [0, 2]], i32)
        mapped = smap.bond(harmonic, metric, bonds)

        key = random.PRNGKey(0)
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            e01 = harmonic(metric(R[0], R[1]))
            e02 = harmonic(metric(R[0], R[2]))
            self.assertAllClose(mapped(R), dtype(e01 + e02))
示例#26
0
    def test_pair_no_species_scalar(self, spatial_dimension, dtype):
        """smap.pair over a scalar metric equals half the all-pairs sum."""
        def square(dr):
            return dr**2

        displacement, _ = space.free()

        def metric(Ra, Rb, **kwargs):
            return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        mapped_square = smap.pair(square, metric)
        all_pairs = space.map_product(metric)

        key = random.PRNGKey(0)
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            expected = np.array(0.5 * np.sum(square(all_pairs(R, R))),
                                dtype=dtype)
            self.assertAllClose(mapped_square(R), expected)
示例#27
0
    def test_graph_network_neighbor_list(self, spatial_dimension, dtype):
        """Neighbor-list GNN energy equals the dense GNN energy."""
        key = random.PRNGKey(0)
        R = random.uniform(key, (32, spatial_dimension), dtype=dtype)

        d, _ = space.free()
        cutoff = 0.2

        init_fn, energy_fn = energy.graph_network(d, cutoff)
        params = init_fn(key, R)

        neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
            d, 1.0, cutoff, 0.0)
        nbrs = neighbor_fn(R)

        dense_E = energy_fn(params, R)
        self.assertAllClose(dense_E, nl_energy_fn(params, R, nbrs))
示例#28
0
  def test_pair_no_species_vector_nonadditive(self, spatial_dimension, dtype):
    """Per-particle params are combined multiplicatively by smap.pair.

    The mapped function is built with ``params=lambda x, y: x * y``. Its
    result is checked against a reference that uses the outer product
    params_i * params_j.
    """
    def square(dr, params):
      return params * np.sum(dr ** 2, axis=2)

    displacement, _ = space.free()
    mapped_square = smap.pair(square, displacement, params=lambda x, y: x * y)
    all_pairs = space.map_product(displacement)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
      key, R_key, params_key = random.split(key, 3)
      R = random.uniform(
          R_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      params = random.uniform(
          params_key, (PARTICLE_COUNT,), dtype=dtype, minval=0.1, maxval=1.5)
      pair_params = params[None, :] * params[:, None]
      expected = np.array(0.5 * np.sum(square(all_pairs(R, R), pair_params)),
                          dtype=dtype)
      self.assertAllClose(mapped_square(R, params=params), expected)
示例#29
0
文件: nn_test.py 项目: zheshen/jax-md
 def test_behler_parrinello_symmetry_functions_neighbor_list(self,
                                                             N_types,
                                                             N_etas,
                                                             dtype):
   """Neighbor-list Behler-Parrinello features: shape and one reference value.

   The expected width is the radial block (N_types * N_etas) plus the angular
   block (N_etas * N_types * (N_types + 1) // 2), for each of the three atoms.
   """
   displacement, _ = space.free()
   neighbor_fn = partition.neighbor_list(displacement, 10.0, 8.0, 0.0)
   gr = nn.behler_parrinello_symmetry_functions_neighbor_list(
           displacement, np.array([1, 1, N_types]),
           radial_etas=np.array([1e-4 / (0.529177 ** 2)] * N_etas, dtype),
           angular_etas=np.array([1e-4 / (0.529177 ** 2)] * N_etas, dtype),
           lambdas=np.array([-1.0] * N_etas, dtype),
           zetas=np.array([1.0] * N_etas, dtype),
           cutoff_distance=8.0)
   R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
   nbrs = neighbor_fn(R)
   gr_out = gr(R, neighbor=nbrs)
   # Shapes are integer tuples: compare them exactly, not with a tolerance.
   self.assertEqual(gr_out.shape,
                    (3, N_etas * (N_types + N_types * (N_types + 1) // 2)))
   self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)
示例#30
0
  def test_pair_no_species_scalar_dynamic(self, spatial_dimension, dtype):
    """Per-particle epsilon passed at call time is averaged over each pair.

    The reference combines per-particle values as the arithmetic mean
    0.5 * (eps_i + eps_j) before applying the pair function.
    """
    def square(dr, epsilon):
      return epsilon * dr ** 2

    displacement, _ = space.free()

    def metric(Ra, Rb, **kwargs):
      return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

    mapped_square = smap.pair(square, metric, epsilon=1.0)
    all_pairs = space.map_product(metric)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
      key, split1, split2 = random.split(key, 3)
      R = random.uniform(
          split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      epsilon = random.uniform(split2, (PARTICLE_COUNT,), dtype=dtype)
      pair_epsilon = 0.5 * (epsilon[:, np.newaxis] + epsilon[np.newaxis, :])
      expected = np.array(
          0.5 * np.sum(square(all_pairs(R, R), pair_epsilon)), dtype=dtype)
      self.assertAllClose(mapped_square(R, epsilon=epsilon), expected)