Beispiel #1
0
    def test_periodic_general_wrapped_vs_unwrapped(self, spatial_dimension,
                                                   dtype):
        """Check that wrapped and unwrapped shifts give equal displacements.

        For several random boxes, the same random moves are applied with the
        default shift function and with the one built using ``wrapped=False``.
        After every move the displacements back to the starting positions
        must agree. At the end of each sample the unwrapped positions are
        required not to lie entirely inside the open unit cube.
        """
        key = random.PRNGKey(0)

        identity = np.eye(spatial_dimension, dtype=dtype)
        particle_shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split_R, split_T = random.split(key, 3)

            # Box transform: identity plus a symmetrised random perturbation.
            perturbation = random.normal(
                split_T, (spatial_dimension, spatial_dimension), dtype=dtype)
            T = identity + perturbation + np.transpose(perturbation)

            R0 = random.uniform(split_R, particle_shape, dtype=dtype)
            wrapped_R = R0
            unwrapped_R = R0

            displacement, shift = space.periodic_general(T)
            _, unwrapped_shift = space.periodic_general(T, wrapped=False)

            pair_displacement = space.map_product(displacement)

            for _ in range(SHIFT_STEPS):
                key, step_key = random.split(key)
                dR = random.normal(step_key, particle_shape, dtype=dtype)
                wrapped_R = shift(wrapped_R, dR)
                unwrapped_R = unwrapped_shift(unwrapped_R, dR)
                self.assertAllClose(pair_displacement(wrapped_R, R0),
                                    pair_displacement(unwrapped_R, R0),
                                    True)

            stayed_in_unit_cube = (np.all(unwrapped_R > 0)
                                   and np.all(unwrapped_R < 1))
            assert not stayed_in_unit_cube
Beispiel #2
0
  def test_periodic_against_periodic_general_grad(self, spatial_dimension, dtype):
    """Compare gradients of `space.periodic` and `space.periodic_general`.

    For random diagonal boxes, the gradient of the summed squared pairwise
    displacement is evaluated with the `periodic` displacement on positions
    scaled by the box, and with the `periodic_general` displacement on the
    unscaled positions. Both gradients must agree, and the general gradient
    must come back in the requested dtype.
    """
    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      # The four-way split is kept so the sampled boxes and positions stay
      # the same as before; the last subkey is intentionally unused.
      key, split1, split2, _unused_split = random.split(key, 4)

      max_box_size = f32(10.0)
      box_size = max_box_size * random.uniform(
        split1, (spatial_dimension,), dtype=dtype)
      transform = jnp.diag(box_size)

      R = random.uniform(
        split2, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      R_scaled = R * box_size

      # Only the displacement functions are needed for the gradient check.
      disp_fn, _ = space.periodic(box_size)
      general_disp_fn, _ = space.periodic_general(transform)

      disp_fn = space.map_product(disp_fn)
      general_disp_fn = space.map_product(general_disp_fn)

      grad_fn = grad(lambda R: jnp.sum(disp_fn(R, R) ** 2))
      general_grad_fn = grad(lambda R: jnp.sum(general_disp_fn(R, R) ** 2))

      self.assertAllClose(grad_fn(R_scaled), general_grad_fn(R))
      assert general_grad_fn(R).dtype == dtype
Beispiel #3
0
    def test_periodic_against_periodic_general(self, spatial_dimension, dtype):
        """Check `space.periodic` against `space.periodic_general`.

        A random box of side lengths is used directly with `periodic` and as
        a diagonal matrix with `periodic_general`. Displacements computed on
        box-scaled positions must match the general displacements on the
        unscaled positions, and the shifted positions must match once the
        general result is multiplied by the box. Result dtypes are checked.
        """
        key = random.PRNGKey(0)
        particle_shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            key, box_key, position_key, move_key = random.split(key, 4)

            max_box_size = f16(10.0)
            unit_sides = random.uniform(box_key, (spatial_dimension, ),
                                        dtype=dtype)
            box_size = max_box_size * unit_sides
            transform = np.diag(box_size)

            R = random.uniform(position_key, particle_shape, dtype=dtype)
            R_scaled = R * box_size

            dR = random.normal(move_key, particle_shape, dtype=dtype)

            disp_fn, shift_fn = space.periodic(box_size)
            general_disp_fn, general_shift_fn = space.periodic_general(
                transform)

            pair_disp = space.map_product(disp_fn)
            general_pair_disp = space.map_product(general_disp_fn)

            # Displacements: periodic on scaled vs. general on unscaled.
            periodic_disp = pair_disp(R_scaled, R_scaled)
            self.assertAllClose(periodic_disp, general_pair_disp(R, R), True)
            assert periodic_disp.dtype == dtype

            # Shifts: the general result is rescaled by the box to compare.
            periodic_shift = shift_fn(R_scaled, dR)
            self.assertAllClose(periodic_shift,
                                general_shift_fn(R, dR) * box_size, True)
            assert periodic_shift.dtype == dtype
Beispiel #4
0
    def test_nve_jammed_periodic_general(self, dtype, coords):
        """Run NVE dynamics on the stored jammed state and check energy.

        The state is loaded from 'simulation_test_state.npy'. `coords`
        selects which `<coords>_position` attribute is simulated and whether
        `space.periodic_general` is told to use fractional coordinates. The
        total (potential + kinetic) energy recorded after every step must
        stay within tolerance of its initial value.
        """
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        use_fractional = coords == 'fractional'
        displacement_fn, shift_fn = space.periodic_general(
            state.box, use_fractional)

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)

        init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
        apply_fn = jit(apply_fn)

        state = init_fn(key, getattr(state, coords + '_position'), kT=1e-3)

        def total_energy(sim_state):
            # Potential energy plus kinetic energy of the given state.
            kinetic = quantity.kinetic_energy(sim_state.velocity,
                                              sim_state.mass)
            return E(sim_state.position) + kinetic

        E_initial = total_energy(state) * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, carry):
            # Advance one step and record the total energy at index i.
            sim_state, trace = carry
            sim_state = apply_fn(sim_state)
            trace = ops.index_update(trace, i, total_energy(sim_state))
            return sim_state, trace

        energy_trace = np.zeros((DYNAMICS_STEPS, ))
        state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn,
                                  (state, energy_trace))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
Beispiel #5
0
 def test_eam_neighbor_list(self, num_repetitions, dtype, format):
     """Check the neighbor-list EAM energy per atom on a repeated lattice.

     The test is skipped for `partition.OrderedSparse`. Positions are
     expressed through the inverse of the repeated lattice vectors, the
     energy is divided by ``num_repetitions**3`` and compared with a
     reference value (explicit 1e-8 tolerances for f64, the defaults of
     `assertAllClose` otherwise).
     """
     if format is partition.OrderedSparse:
         self.skipTest('OrderedSparse neighbor lists not supported for EAM '
                       'potential.')

     primitive_cell = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]],
                               dtype=dtype)
     latvec = primitive_cell * f32(4.05 / 2)
     atoms = np.array([[0, 0, 0]], dtype=dtype)
     atoms_repeated, latvec_repeated = lattice_repeater(
         atoms, latvec, num_repetitions)

     # Inverse lattice computed with plain NumPy, then cast to `dtype`.
     inverse = onp.linalg.inv(onp.array(latvec_repeated))
     inv_latvec = np.array(inverse, dtype=dtype)
     R = np.dot(atoms_repeated, inv_latvec)

     displacement, _ = space.periodic_general(latvec_repeated)
     box_size = np.linalg.det(latvec_repeated)**(1 / 3)
     neighbor_fn, energy_fn = energy.eam_neighbor_list(
         displacement, box_size, *make_eam_test_splines(), format=format)
     nbrs = neighbor_fn.allocate(R)

     E = energy_fn(R, nbrs) / num_repetitions**3
     expected = dtype(-3.3633387837793505)
     tolerances = {'atol': 1e-8, 'rtol': 1e-8} if dtype is f64 else {}
     self.assertAllClose(E, expected, **tolerances)
Beispiel #6
0
    def test_stress_non_minimized_periodic_general(self, dim, dtype, coords):
        """Compare `quantity.stress` with a hand-written pairwise sum.

        Random positions (fractional, or scaled by the box when `coords` is
        not 'fractional') are used with a soft-sphere pair energy. The
        stress returned by `quantity.stress` must match the explicit sum
        over pairs within a dtype-dependent tolerance.
        """
        key = random.PRNGKey(0)
        N = 64

        box = quantity.box_size_at_number_density(N, 0.8, dim)
        is_fractional = coords == 'fractional'
        displacement_fn, _ = space.periodic_general(box, is_fractional)

        pos = random.uniform(key, (N, dim))
        if not is_fractional:
            pos = pos * box

        energy_fn = energy.soft_sphere_pair(displacement_fn)

        def pairwise_stress(R):
            # Explicit sum over all pairs of dU/dr * dR (x) dR / (V * r).
            dR = space.map_product(displacement_fn)(R, R)
            dr = space.distance(dR)
            dU = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
            V = quantity.volume(dim, box)
            dUdr = 0.5 * dU(dr)[:, :, None, None]
            # The identity is added before dividing; presumably this keeps
            # the diagonal (self-pair) entries from dividing by zero.
            dr = (dr + jnp.eye(N))[:, :, None, None]
            outer = dR[:, :, None, :] * dR[:, :, :, None]
            return -jnp.sum(dUdr * outer / (V * dr), axis=(0, 1))

        expected_stress = pairwise_stress(pos)
        ad_stress = quantity.stress(energy_fn, pos, box)

        tol = 1e-7 if dtype is f64 else 2e-5

        self.assertAllClose(expected_stress, ad_stress, atol=tol, rtol=tol)
Beispiel #7
0
    def test_periodic_general_time_dependence(self, spatial_dimension, dtype):
        """Check `space.periodic_general` with a time-dependent box function.

        Two random box matrices are linearly interpolated by a function of
        ``t``. Displacements and shifts from the space built on that function
        (evaluated with ``t=t_g``) must equal those from a space built on the
        box matrix already evaluated at ``t_g``.
        """
        key = random.PRNGKey(0)

        identity = np.eye(spatial_dimension)
        matrix_shape = (spatial_dimension, spatial_dimension)
        particle_shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split_T0_scale, split_T0_dT = random.split(key, 3)
            key, split_T1_scale, split_T1_dT = random.split(key, 3)
            key, split_t, split_R, split_dR = random.split(key, 4)

            # First end point: sampled without a dtype, cast when assembled.
            size_0 = 10.0 * random.uniform(split_T0_scale, ())
            dtransform_0 = 0.5 * random.normal(split_T0_dT, matrix_shape)
            T_0 = np.array(size_0 * (identity + dtransform_0), dtype=dtype)

            # Second end point: sampled directly in the requested dtype.
            size_1 = 10.0 * random.uniform(split_T1_scale, (), dtype=dtype)
            dtransform_1 = 0.5 * random.normal(split_T1_dT, matrix_shape,
                                               dtype=dtype)
            T_1 = np.array(size_1 * (identity + dtransform_1), dtype=dtype)

            def box_at(t):
                # Linear interpolation: T_0 at t == 1, T_1 at t == 0.
                return t * T_0 + (f32(1.0) - t) * T_1

            t_g = random.uniform(split_t, (), dtype=dtype)

            disp_fn, shift_fn = space.periodic_general(box_at)
            true_disp_fn, true_shift_fn = space.periodic_general(box_at(t_g))

            pair_disp = space.map_product(partial(disp_fn, t=t_g))
            true_pair_disp = space.map_product(true_disp_fn)

            R = random.uniform(split_R, particle_shape, dtype=dtype)
            dR = random.normal(split_dR, particle_shape, dtype=dtype)

            expected_disp = np.array(true_pair_disp(R, R), dtype=dtype)
            self.assertAllClose(pair_disp(R, R), expected_disp, True)

            expected_shift = np.array(true_shift_fn(R, dR), dtype=dtype)
            self.assertAllClose(shift_fn(R, dR, t=t_g), expected_shift, True)
Beispiel #8
0
 def test_stillinger_weber(self, dtype, num_repetitions):
     """Check the Stillinger-Weber energy per atom on a repeated lattice.

     A two-atom basis is replicated `num_repetitions` times with `lattice`,
     the lattice vectors are scaled by the same factor to form the periodic
     box, and the jitted energy divided by the number of positions is
     compared with the reference value -4.336503.
     """
     # The duplicated assignment target in the original was a typo.
     lattice_vectors = np.array([[0, .5, .5], [.5, 0, .5],
                                 [.5, .5, 0]]) * 5.428
     positions = np.array([[0, 0, 0], [0.25, 0.25, 0.25]])
     positions = lattice(positions, num_repetitions, lattice_vectors)
     lattice_vectors *= num_repetitions
     # Only the displacement function is needed for the energy.
     displacement, _ = space.periodic_general(lattice_vectors)
     energy_fn = jit(energy.stillinger_weber_energy(displacement))
     self.assertAllClose(
         energy_fn(positions) / positions.shape[0], -4.336503)
Beispiel #9
0
  def test_pressure_jammed(self, dtype, coords):
    """Check `quantity.pressure` against the pressure stored with the state.

    The state is loaded from 'simulation_test_state.npy'. `coords` selects
    the `<coords>_position` attribute and whether `space.periodic_general`
    is told to use fractional coordinates.
    """
    state = test_util.load_test_state('simulation_test_state.npy', dtype)
    # Only the displacement function is needed; no dynamics are run here.
    displacement_fn, _ = space.periodic_general(state.box,
                                                coords == 'fractional')

    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
    pos = getattr(state, coords + '_position')

    self.assertAllClose(quantity.pressure(E, pos, state.box), state.pressure)
Beispiel #10
0
    def test_neighbor_list_build_time_dependent(self, dtype, dim):
        """Check neighbor-list distances for a time-dependent box.

        A neighbor list is created with ``t=0`` and queried with ``t=0.25``.
        For every particle, the distances to its listed neighbors that are
        below the cutoff must equal the below-cutoff distances from the
        exhaustive all-pairs computation at ``t=0.25``.
        """
        key = random.PRNGKey(1)

        # Box matrix whose last entry of the first row depends on t.
        if dim == 2:
            box_fn = lambda t: np.array([[9.0, t], [0.0, 3.75]], f32)
        elif dim == 3:
            box_fn = lambda t: np.array([[9.0, 0.0, t], [0.0, 4.0, 0.0],
                                         [0.0, 0.0, 7.25]])
        shortest_side = np.min(np.diag(box_fn(0.)))
        cutoff = f32(1.23)
        # TODO(schsam): Get cell-list working with anisotropic cell sizes.
        cell_size = cutoff / shortest_side

        displacement, _ = space.periodic_general(box_fn)
        metric = space.metric(displacement)

        R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
        N = R.shape[0]
        neighbor_list_fn = partition.neighbor_list(
            metric, 1., cutoff, 0.0, 1.1,
            cell_size=cell_size, t=np.array(0.))

        idx = neighbor_list_fn(R, t=np.array(0.25)).idx
        R_neigh = R[idx]
        # Entries with idx >= N are masked out of the comparison.
        mask = idx < N

        metric_at_t = partial(metric, t=f32(0.25))
        neighbor_distance = vmap(vmap(metric_at_t, (None, 0)))
        dR = neighbor_distance(R, R_neigh)

        all_pairs_distance = space.map_product(metric_at_t)
        dR_exact = all_pairs_distance(R, R)

        # Zero out everything at or beyond the cutoff, then sort each row so
        # the surviving (positive) distances can be compared directly.
        dR = np.where(dR < cutoff, dR, 0) * mask
        dR_exact = np.where(dR_exact < cutoff, dR_exact, 0)

        dR = np.sort(dR, axis=1)
        dR_exact = np.sort(dR_exact, axis=1)

        for listed_row, exact_row in zip(dR, dR_exact):
            listed_positive = listed_row[listed_row > 0.]
            exact_positive = exact_row[exact_row > 0.]

            self.assertAllClose(listed_positive, exact_positive)
Beispiel #11
0
    def test_periodic_general_dynamic(self, spatial_dimension, dtype):
        """Check that a box passed at call time overrides the built-in one.

        A space is built with box ``T_0`` but evaluated with ``box=T_1``.
        Its displacements and shifts must equal those of a space built
        directly with ``T_1``.
        """
        key = random.PRNGKey(0)

        identity = jnp.eye(spatial_dimension)
        matrix_shape = (spatial_dimension, spatial_dimension)
        particle_shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split_T0_scale, split_T0_dT = random.split(key, 3)
            key, split_T1_scale, split_T1_dT = random.split(key, 3)
            key, split_t, split_R, split_dR = random.split(key, 4)

            # Construction-time box: sampled without a dtype, cast afterwards.
            size_0 = 10.0 * random.uniform(split_T0_scale, ())
            dtransform_0 = 0.5 * random.normal(split_T0_dT, matrix_shape)
            T_0 = jnp.array(size_0 * (identity + dtransform_0), dtype=dtype)

            # Call-time box: sampled directly in the requested dtype.
            size_1 = 10.0 * random.uniform(split_T1_scale, (), dtype=dtype)
            dtransform_1 = 0.5 * random.normal(split_T1_dT, matrix_shape,
                                               dtype=dtype)
            T_1 = jnp.array(size_1 * (identity + dtransform_1), dtype=dtype)

            disp_fn, shift_fn = space.periodic_general(T_0)
            true_disp_fn, true_shift_fn = space.periodic_general(T_1)

            pair_disp = space.map_product(partial(disp_fn, box=T_1))
            true_pair_disp = space.map_product(true_disp_fn)

            R = random.uniform(split_R, particle_shape, dtype=dtype)
            dR = random.normal(split_dR, particle_shape, dtype=dtype)

            expected_disp = jnp.array(true_pair_disp(R, R), dtype=dtype)
            self.assertAllClose(pair_disp(R, R), expected_disp)

            expected_shift = jnp.array(true_shift_fn(R, dR), dtype=dtype)
            self.assertAllClose(shift_fn(R, dR, box=T_1), expected_shift)
Beispiel #12
0
def make_periodic_general_test_system(N, dim, dtype, box_format):
  """Build matching `periodic` and `periodic_general` soft-sphere systems.

  Args:
    N: number of particles.
    dim: spatial dimension.
    dtype: scalar type used for the box and the sampled positions.
    box_format: one of `BOX_FORMATS`; 'vector' and 'matrix' produce an array
      of ones or an identity matrix scaled by the box size, any other
      accepted value produces a scalar box.

  Returns:
    A tuple `(R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g))`: the uniformly
    sampled positions `R_f`, their image `R` under
    `space.transform(box, R_f)`, the box, and (shift function, jitted
    soft-sphere energy) pairs for `space.periodic`, `space.periodic_general`
    with default arguments, and `space.periodic_general` with
    `fractional_coordinates=False`.
  """
  assert box_format in BOX_FORMATS

  box_size = quantity.box_size_at_number_density(N, 1.0, dim)
  if box_format == 'vector':
    box = jnp.array(jnp.ones(dim) * box_size, dtype)
  elif box_format == 'matrix':
    box = jnp.array(jnp.eye(dim) * box_size, dtype)
  else:
    box = dtype(box_size)

  # `space.periodic` is given the diagonal when the box is a matrix.
  periodic_box = jnp.diag(box) if box_format == 'matrix' else box
  d, s = space.periodic(periodic_box)
  d_gf, s_gf = space.periodic_general(box)
  d_g, s_g = space.periodic_general(box, fractional_coordinates=False)

  key = random.PRNGKey(0)

  R_f = random.uniform(key, (N, dim), dtype=dtype)
  R = space.transform(box, R_f)

  periodic_system = (s, jit(energy.soft_sphere_pair(d)))
  general_fractional_system = (s_gf, jit(energy.soft_sphere_pair(d_gf)))
  general_real_system = (s_g, jit(energy.soft_sphere_pair(d_g)))

  return (R_f, R, box, periodic_system, general_fractional_system,
          general_real_system)
Beispiel #13
0
 def test_eam(self, num_repetitions, dtype):
     """Check the EAM energy per atom on a repeated lattice.

     Positions are expressed through the inverse of the repeated lattice
     vectors. The energy divided by ``num_repetitions**3`` is compared with
     the reference value -3.363338, with dtype checking enabled.
     """
     latvec = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]],
                       dtype=dtype) * f32(4.05 / 2)
     atoms = np.array([[0, 0, 0]], dtype=dtype)
     atoms_repeated, latvec_repeated = lattice_repeater(
         atoms, latvec, num_repetitions)
     # Pin the dtype of the inverse explicitly so the positions, and hence
     # the dtype-checked energy below, stay in the dtype under test.
     inv_latvec = np.array(onp.linalg.inv(onp.array(latvec_repeated)),
                           dtype=dtype)
     # Only the displacement function is needed for the energy.
     displacement, _ = space.periodic_general(latvec_repeated)
     # NOTE(review): charge_fn, embedding_fn and pairwise_fn are not defined
     # in this method; presumably they are module-level names -- confirm.
     assert charge_fn(dtype(1.0)).dtype == dtype
     assert embedding_fn(dtype(1.0)).dtype == dtype
     assert pairwise_fn(dtype(1.0)).dtype == dtype
     eam_energy = energy.eam(displacement, charge_fn, embedding_fn,
                             pairwise_fn)
     self.assertAllClose(
         eam_energy(np.dot(atoms_repeated, inv_latvec)) /
         f32(num_repetitions**3), dtype(-3.363338), True)
Beispiel #14
0
  def test_canonicalize_displacement_or_metric(self, spatial_dimension, dtype):
    """Check canonicalizing a displacement against `space.metric`.

    For a unit-box `periodic_general` displacement, the function returned by
    `space.canonicalize_displacement_or_metric` must give the same pairwise
    values as `space.metric` on random positions.
    """
    key = random.PRNGKey(0)

    displacement, _ = space.periodic_general(np.eye(spatial_dimension))
    reference_metric = space.map_product(space.metric(displacement))
    canonical_metric = space.map_product(
      space.canonicalize_displacement_or_metric(displacement))

    particle_shape = (PARTICLE_COUNT, spatial_dimension)
    for _ in range(STOCHASTIC_SAMPLES):
      # Three-way split kept as in the original; the last subkey is unused.
      key, position_key, _unused_split = random.split(key, 3)

      R = random.normal(position_key, particle_shape, dtype=dtype)

      self.assertAllClose(reference_metric(R, R), canonical_metric(R, R), True)
Beispiel #15
0
    def test_stress_lammps_periodic_general(self, dim, dtype):
        """Compare energy and stress with the loaded LAMMPS reference data.

        The box, positions and velocities come from
        `test_util.load_lammps_stress_data`. A Lennard-Jones pair energy that
        is set to zero at distances of 2.5 and above is evaluated; the energy
        per particle and the stress from `quantity.stress` must match the
        loaded reference values within 5e-5.
        """
        (box, R, V), (E, C) = test_util.load_lammps_stress_data(dtype)

        displacement_fn, _ = space.periodic_general(box)
        energy_fn = smap.pair(
            lambda dr, **kwargs: jnp.where(dr < f32(2.5),
                                           energy.lennard_jones(dr), f32(0.0)),
            space.canonicalize_displacement_or_metric(displacement_fn))

        ad_stress = quantity.stress(energy_fn, R, box, velocity=V)

        tol = 5e-5

        self.assertAllClose(energy_fn(R) / len(R), E, atol=tol, rtol=tol)
        self.assertAllClose(C, ad_stress, atol=tol, rtol=tol)
Beispiel #16
0
    def test_EMT_from_db_nbrlist(self, spatial_dimension, dtype, low_pressure):
        """Check athermal elastic moduli computed with a neighbor list.

        For each stored sample, the loaded configuration must have a maximum
        absolute energy gradient below a dtype-dependent threshold. The
        modulus tensor from `elasticity.athermal_moduli` must have the
        requested dtype and a rank-4 shape. When the computation reports
        convergence, its extracted elements must match the stored reference
        and the tensor must be unchanged under the index permutations
        checked below.
        """
        if dtype == jnp.float32:
            max_grad_thresh = 1e-5
            atol = 1e-4
            rtol = 1e-3
        else:
            max_grad_thresh = 1e-10
            atol = 1e-8
            rtol = 1e-5

        expected_shape = (spatial_dimension, ) * 4

        for index in range(NUM_SAMPLES):
            cijkl, R, sigma, box = test_util.load_elasticity_test_data(
                spatial_dimension, low_pressure, dtype, index)

            # Only the displacement function is needed here.
            displacement, _ = space.periodic_general(
                box, fractional_coordinates=True)
            neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
                displacement, box, sigma=sigma, fractional_coordinates=True)
            nbrs = neighbor_fn.allocate(R)
            assert (jnp.max(jnp.abs(grad(energy_fn)(R, nbrs))) <
                    max_grad_thresh)

            EMT_fn = jit(
                elasticity.athermal_moduli(energy_fn, check_convergence=True))
            C, converged = EMT_fn(R, box, neighbor=nbrs)
            assert C.dtype == dtype
            assert C.shape == expected_shape
            if converged:
                self.assertAllClose(cijkl,
                                    elasticity._extract_elements(C, False),
                                    atol=atol,
                                    rtol=rtol)

                # Make sure the symmetries are there.
                for permutation in ("ijkl->jikl", "ijkl->ijlk", "ijkl->lkij"):
                    self.assertAllClose(C, jnp.einsum(permutation, C))
Beispiel #17
0
 def test_eam(self, num_repetitions, dtype):
   """Check the EAM energy per atom on a repeated lattice.

   The spline functions from `make_eam_test_splines` must return values in
   the requested dtype. The energy divided by ``num_repetitions ** 3`` is
   compared with a reference value (explicit 1e-8 tolerances for f64, the
   defaults of `assertAllClose` otherwise).
   """
   primitive_cell = np.array(
       [[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype)
   latvec = primitive_cell * f32(4.05 / 2)
   atoms = np.array([[0, 0, 0]], dtype=dtype)
   atoms_repeated, latvec_repeated = lattice_repeater(
       atoms, latvec, num_repetitions)

   # Inverse lattice computed with plain NumPy, then cast to `dtype`.
   inverse = onp.linalg.inv(onp.array(latvec_repeated))
   inv_latvec = np.array(inverse, dtype=dtype)

   displacement, _ = space.periodic_general(latvec_repeated)
   charge_fn, embedding_fn, pairwise_fn, _ = make_eam_test_splines()

   one = np.array(1.0, dtype)
   assert charge_fn(one).dtype == dtype
   assert embedding_fn(one).dtype == dtype
   assert pairwise_fn(one).dtype == dtype

   eam_energy = energy.eam(displacement, charge_fn, embedding_fn, pairwise_fn)
   positions = np.dot(atoms_repeated, inv_latvec)
   E = eam_energy(positions) / num_repetitions ** 3

   expected = dtype(-3.3633387837793505)
   tolerances = {'atol': 1e-8, 'rtol': 1e-8} if dtype is f64 else {}
   self.assertAllClose(E, expected, **tolerances)
Beispiel #18
0
 def test_stillinger_weber_neighbor_list(self, dtype, num_repetitions,
                                         format):
   """Check the neighbor-list Stillinger-Weber energy per atom.

   The test is skipped for the `OrderedSparse` and `Sparse` formats. A
   two-atom basis is replicated with `lattice`, and the energy divided by
   the number of positions is compared with a reference value.
   """
   unsupported_formats = [partition.OrderedSparse, partition.Sparse]
   if format in unsupported_formats:
     self.skipTest(f'{format} not supported for Stillinger-Weber.')

   lattice_vectors = np.array([[0, .5, .5],
                               [.5, 0, .5],
                               [.5, .5, 0]]) * 5.428
   basis = np.array([[0, 0, 0], [0.25, 0.25, 0.25]])
   positions = lattice(basis, num_repetitions, lattice_vectors)
   lattice_vectors *= num_repetitions

   displacement, _ = space.periodic_general(lattice_vectors)
   # Computed from the already-scaled lattice vectors, as in the original.
   cell_volume = np.linalg.det(lattice_vectors)
   box_size = cell_volume ** (1/3) * num_repetitions

   neighbor_fn, energy_fn = energy.stillinger_weber_neighbor_list(
       displacement, box_size, fractional_coordinates=True, format=format)
   nbrs = neighbor_fn.allocate(positions)

   N = positions.shape[0]
   energy_per_atom = energy_fn(positions, neighbor=nbrs) / N
   self.assertAllClose(energy_per_atom, -4.336503155764325)
Beispiel #19
0
    def test_npt_nose_hoover_jammed(self, dtype, sy_steps):
        """Run NPT Nose-Hoover dynamics on the stored jammed state.

        The state is loaded from 'simulation_test_state.npy' and simulated
        at its stored pressure with kT = 1e-3. After every step the value of
        `simulate.npt_nose_hoover_invariant` and the instantaneous pressure
        are recorded. The invariant must stay within a dtype-dependent
        tolerance of its initial value and the pressure within 0.05 of the
        target.
        """
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic_general(state.box)

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)
        invariant = partial(simulate.npt_nose_hoover_invariant, E)
        pressure_fn = partial(quantity.pressure, E)

        # The option must be keyed by its name. The original used the integer
        # value as its own key (`{sy_steps: sy_steps}`), which cannot match a
        # named option, so the parametrized value was never applied.
        # NOTE(review): 'sy_steps' is presumably the option name the
        # integrator looks up -- confirm against simulate.npt_nose_hoover.
        nhc_kwargs = {'sy_steps': sy_steps}
        kT = 1e-3
        P = state.pressure
        init_fn, apply_fn = simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT,
                                                     nhc_kwargs, nhc_kwargs)
        apply_fn = jit(apply_fn)

        state = init_fn(key, state.fractional_position, state.box)

        E_initial = invariant(state, P, kT) * np.ones((DYNAMICS_STEPS, ))
        P_target = P * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_energy_pressure):
            # Advance one step, then record the invariant and the pressure.
            state, energy, pressure = state_energy_pressure
            state = apply_fn(state)
            energy = ops.index_update(energy, i, invariant(state, P, kT))
            box = simulate.npt_box(state)
            KE = quantity.kinetic_energy(state.velocity, state.mass)
            p = pressure_fn(state.position, box, KE)
            pressure = ops.index_update(pressure, i, p)
            return state, energy, pressure

        Es = np.zeros((DYNAMICS_STEPS, ))
        Ps = np.zeros((DYNAMICS_STEPS, ))
        state, Es, Ps = lax.fori_loop(0, DYNAMICS_STEPS, step_fn,
                                      (state, Es, Ps))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
        self.assertAllClose(Ps, P_target, rtol=0.05, atol=0.05)
Beispiel #20
0
 def test_eam(self, num_repetitions, dtype):
     """Check the EAM energy per atom on a repeated lattice.

     The spline functions from `make_eam_test_splines` must return values in
     the requested dtype. The energy divided by ``num_repetitions**3`` is
     compared with the reference value -3.363338, with dtype checking
     enabled and a dtype-dependent tolerance.
     """
     primitive_cell = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]],
                               dtype=dtype)
     latvec = primitive_cell * f32(4.05 / 2)
     atoms = np.array([[0, 0, 0]], dtype=dtype)
     atoms_repeated, latvec_repeated = lattice_repeater(
         atoms, latvec, num_repetitions)

     # Inverse lattice computed with plain NumPy, then cast to `dtype`.
     inverse = onp.linalg.inv(onp.array(latvec_repeated))
     inv_latvec = np.array(inverse, dtype=dtype)

     displacement, _ = space.periodic_general(latvec_repeated)
     charge_fn, embedding_fn, pairwise_fn = make_eam_test_splines()

     one = np.array(1.0, dtype)
     assert charge_fn(one).dtype == dtype
     assert embedding_fn(one).dtype == dtype
     assert pairwise_fn(one).dtype == dtype

     eam_energy = energy.eam(displacement, charge_fn, embedding_fn,
                             pairwise_fn)
     tol = 1e-5 if dtype == np.float32 else 1e-6

     positions = np.dot(atoms_repeated, inv_latvec)
     atom_count = np.array(num_repetitions**3, dtype)
     energy_per_atom = eam_energy(positions) / atom_count
     self.assertAllClose(energy_per_atom, dtype(-3.363338), True, tol, tol)
Beispiel #21
0
    def get_displacement(self, atoms: Atoms):
        """Return the JAX-MD space matching the periodicity of `atoms`.

        When every entry of `atoms.get_pbc()` is truthy, the result of
        `space.periodic_general` built on `self.box` with
        `fractional_coordinates=False` is returned; otherwise the result of
        `space.free()`.
        """
        fully_periodic = all(atoms.get_pbc())
        if fully_periodic:
            return space.periodic_general(self.box,
                                          fractional_coordinates=False)
        return space.free()