コード例 #1
0
    def test_transform_inverse(self, spatial_dimension, dtype):
        """Round-tripping through a matrix and its inverse is the identity.

        For several random (positions, matrix) pairs, transforming the
        positions by the matrix and then by `space._small_inverse` of that
        matrix must reproduce the original positions.
        """
        rng_key = random.PRNGKey(0)
        positions_shape = (PARTICLE_COUNT, spatial_dimension)
        matrix_shape = (spatial_dimension, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            rng_key, positions_key, matrix_key = random.split(rng_key, 3)

            positions = random.normal(positions_key, positions_shape,
                                      dtype=dtype)
            matrix = random.normal(matrix_key, matrix_shape, dtype=dtype)
            matrix_inverse = space._small_inverse(matrix)

            forward = space.transform(matrix, positions)
            round_trip = space.transform(matrix_inverse, forward)

            self.assertAllClose(positions, round_trip, True)
コード例 #2
0
def pressure(energy_fn: EnergyFn,
             position: Array,
             box: Box,
             kinetic_energy: float = 0.0,
             **kwargs) -> float:
    """Computes the internal pressure of a system.

    The pressure is evaluated as

        P = (2 * KE - dim * V * dU/dV) / (dim * V),

    where dU/dV is obtained by differentiating the energy with respect to an
    isotropic rescaling of `box` while `position` is held fixed.

    Note: This function requires that `energy_fn` take a `box` keyword
    argument. Most frequently, this is accomplished by using
    `periodic_general` boundary conditions combined with any of the energy
    functions in `energy.py`. This will not work with `space.periodic`.

    Args:
      energy_fn: Energy function that accepts `box=` as a keyword argument.
      position: Particle positions of shape `(N, dim)`. Presumably these are
        the (fractional) coordinates `energy_fn` expects when given a `box`,
        so that rescaling the box rescales all interparticle distances.
      box: The simulation box; its volume is `volume(dim, box)`.
      kinetic_energy: Total kinetic energy; the default of 0 yields only the
        configurational contribution.
      **kwargs: Extra keyword arguments forwarded to `energy_fn`.

    Returns:
      The scalar internal pressure.
    """
    dim = position.shape[1]

    vol_0 = volume(dim, box)
    # Isotropically rescale the box so that its volume becomes `vol`.
    box_fn = lambda vol: (vol / vol_0)**(1 / dim) * box

    def U(vol):
        return energy_fn(position, box=box_fn(vol), **kwargs)

    # Because `position` is held fixed while the box is rescaled, dU/dV
    # already contains the full configurational (virial) contribution. An
    # explicit sum(R . F) term must NOT be added on top, or the virial would
    # be counted twice.
    dUdV = grad(U)(vol_0)

    return 1 / (dim * vol_0) * (2 * kinetic_energy - dim * vol_0 * dUdV)
コード例 #3
0
    def test_transform_grad(self, spatial_dimension):
        """Check the gradient that flows back through `space.transform`.

        Differentiating sum(transform(T, R)**2) with respect to R must give
        the same result as differentiating sum(X**2) directly at
        X = transform(T, R).
        """
        def squared_norm(X):
            return np.sum(X**2)

        def squared_norm_of_transformed(T, R):
            return np.sum(space.transform(T, R)**2)

        rng_key = random.PRNGKey(0)
        for _ in range(STOCHASTIC_SAMPLES):
            rng_key, positions_key, matrix_key = random.split(rng_key, 3)

            positions = random.normal(positions_key,
                                      (PARTICLE_COUNT, spatial_dimension))
            matrix = random.normal(matrix_key,
                                   (spatial_dimension, spatial_dimension))

            transformed = space.transform(matrix, positions)

            expected = grad(squared_norm)(transformed)
            actual = grad(squared_norm_of_transformed, 1)(matrix, positions)

            self.assertAllClose(expected, actual, True)
コード例 #4
0
ファイル: space_test.py プロジェクト: cagrikymk/jax-md
  def test_periodic_general_deform_shift(self, spatial_dimension, dtype,
                                         box_format):
    """Shifting into a deformed box agrees across coordinate conventions.

    A real-coordinate shift given `new_box=deformed_box` must match a
    fractional-coordinate shift whose result is mapped into real space with
    `deformed_box`.
    """
    particle_count = 16
    system = make_periodic_general_test_system(
        particle_count, spatial_dimension, dtype, box_format)
    R_f, R, box, _, (shift_frac, energy_frac), (shift_real, energy_real) = system

    deformed_box = box * 0.9

    real_step = grad(energy_real)(R)
    R_new = shift_real(R, real_step, new_box=deformed_box)

    frac_step = grad(energy_frac)(R_f)
    R_gf_new = space.transform(deformed_box, shift_frac(R_f, frac_step))

    self.assertAllClose(R_new, R_gf_new)
コード例 #5
0
ファイル: space_test.py プロジェクト: cagrikymk/jax-md
  def test_periodic_general_shift(self, spatial_dimension, dtype, box_format):
    """Shift functions agree across the three periodic conventions.

    Each shift function applies a gradient step of its matching energy:
    `space.periodic` and real-coordinate `periodic_general` on `R`, and
    fractional `periodic_general` on `R_f`. All three must end at the same
    real positions.
    """
    particle_count = 16
    (R_f, R, box,
     (shift, energy_fn),
     (shift_frac, energy_frac),
     (shift_real, energy_real)) = make_periodic_general_test_system(
         particle_count, spatial_dimension, dtype, box_format)

    periodic_result = shift(R, grad(energy_fn)(R))
    fractional_result = shift_frac(R_f, grad(energy_frac)(R_f))
    real_result = shift_real(R, grad(energy_real)(R))

    self.assertAllClose(periodic_result,
                        space.transform(box, fractional_result))
    self.assertAllClose(periodic_result, real_result)
コード例 #6
0
    def box_force(alpha, vol, box_fn, position, velocity, mass, force,
                  pressure, **kwargs):
        """Generalized force acting on the box-volume degree of freedom.

        Returns

            alpha * sum(m v^2) - dim * V * dU/dV - dim * V * pressure,

        where dU/dV is computed by rescaling the box through `box_fn` while
        `position` is held fixed.

        Args:
          alpha: Prefactor multiplying sum(m v^2).
          vol: Current box volume.
          box_fn: Maps a volume to the box passed to `energy_fn`.
          position: Particle positions of shape `(N, dim)`.
          velocity: Particle velocities, multiplied elementwise with `mass`.
          mass: Particle masses, broadcastable against `velocity**2`.
          force: Unused. Kept only so existing call sites remain valid; the
            virial it would contribute is already contained in dU/dV.
          pressure: Presumably the target (external) pressure — confirm with
            the caller.
          **kwargs: Forwarded to `energy_fn`.
        """
        _, dim = position.shape

        def U(vol):
            return energy_fn(position, box=box_fn(vol), **kwargs)

        dUdV = grad(U)
        KE2 = util.high_precision_sum(velocity**2 * mass)

        # dU/dV at fixed `position` already includes the virial; adding an
        # explicit sum(R . F) term here would count it twice.
        return alpha * KE2 - dim * vol * dUdV(vol) - pressure * vol * dim
コード例 #7
0
    def test_transform(self, spatial_dimension, dtype):
        """`space.transform(T, R)` must agree with the explicit `np.dot(R, T)`."""
        rng_key = random.PRNGKey(0)
        positions_shape = (PARTICLE_COUNT, spatial_dimension)
        matrix_shape = (spatial_dimension, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            rng_key, positions_key, matrix_key = random.split(rng_key, 3)
            positions = random.normal(positions_key, positions_shape,
                                      dtype=dtype)
            matrix = random.normal(matrix_key, matrix_shape, dtype=dtype)

            expected = np.array(np.dot(positions, matrix), dtype=dtype)
            actual = space.transform(matrix, positions)

            self.assertAllClose(expected, actual, True)
コード例 #8
0
    def test_EMT_from_db_dynamic(self, spatial_dimension, dtype, low_pressure):
        """Check elastic moduli against reference data with a dynamic `sigma`.

        The energy function is built with a deliberately wrong `sigma=1.0`.
        The correct `sigma` from the test data is then supplied as a keyword
        argument at call time. For every sample:
          * the largest energy-gradient magnitude must be below a threshold;
          * `elasticity.athermal_moduli` must return a rank-4 tensor of the
            requested dtype;
          * when `converged` is truthy, the tensor must match the stored
            `cijkl` and satisfy the checked index symmetries.
        """
        # Looser tolerances for single precision.
        if dtype == jnp.float32:
            max_grad_thresh, atol, rtol = 1e-5, 1e-4, 1e-3
        else:
            max_grad_thresh, atol, rtol = 1e-10, 1e-8, 1e-5

        moduli_shape = (spatial_dimension,) * 4

        for index in range(NUM_SAMPLES):
            cijkl, R, sigma, box = test_util.load_elasticity_test_data(
                spatial_dimension, low_pressure, dtype, index)
            # NOTE(review): the stored positions appear to be fractional and
            # are mapped through `box`. Keeping only box[0, 0] as the side
            # length assumes a cubic box — confirm against the data files.
            R = space.transform(box, R)
            box = box[0, 0]

            displacement, shift = space.periodic(box)
            # The wrong `sigma` is baked in here, so the correct one must be
            # passed dynamically on every call below.
            energy_fn = energy.soft_sphere_pair(displacement, sigma=1.0)
            maxgrad = jnp.max(jnp.abs(grad(energy_fn)(R, sigma=sigma)))
            assert maxgrad < max_grad_thresh

            EMT_fn = jit(
                elasticity.athermal_moduli(energy_fn, check_convergence=True))
            C, converged = EMT_fn(R, box, sigma=sigma)
            assert C.dtype == dtype
            assert C.shape == moduli_shape

            # NOTE: samples whose `converged` flag is falsy skip the value
            # comparisons entirely.
            if converged:
                self.assertAllClose(cijkl,
                                    elasticity._extract_elements(C, False),
                                    atol=atol,
                                    rtol=rtol)

                # Check the index symmetries of the elastic tensor.
                self.assertAllClose(C, jnp.einsum("ijkl->jikl", C))
                self.assertAllClose(C, jnp.einsum("ijkl->ijlk", C))
                self.assertAllClose(C, jnp.einsum("ijkl->lkij", C))
コード例 #9
0
    def potential(R: space.Array, neighbor: NeighborList, *args,
                  **kwargs) -> PotentialProperties:
        """Compute energies, forces and stress for configuration `R`.

        Before `energy_fn` is evaluated, the box is strained by the symmetric
        part of `deformation`. Forces are minus the gradient of the total
        energy with respect to `R`. Stress is the gradient with respect to
        `deformation`, divided by `det(box)`.

        Args:
          R: Particle positions.
          neighbor: Neighbor list forwarded to `energy_fn` as `neighbor=`.
          *args: Extra positional arguments forwarded to `energy_fn`.
          **kwargs: Extra keyword arguments forwarded to `energy_fn`.

        Returns:
          A tuple `(total_energy, atomwise_energies, forces, stress)`.
        """
        # NOTE(review): `deformation`, `box`, `dtype` and `energy_fn` are read
        # from the enclosing scope. `deformation` is presumably the reference
        # (zero) deformation — confirm with the enclosing function.
        dim = jnp.shape(box)[-1]

        def transform_box_fn(deformation):
            # Symmetrize the deformation tensor and apply it to the box, so
            # the resulting stress tensor is symmetric.
            strain = (deformation + deformation.T) * 0.5
            return space.transform(jnp.eye(dim, dtype=dtype) + strain, box)

        def energy_and_atomwise_fn(R, deformation):
            # The summed energy is differentiated; per-atom energies are
            # carried along as auxiliary output.
            atomwise = energy_fn(R,
                                 *args,
                                 **kwargs,
                                 box=transform_box_fn(deformation),
                                 neighbor=neighbor)
            return jnp.sum(atomwise), atomwise

        # One forward/backward pass yields both gradients and the per-atom
        # energies. `neighbor` and `*args` are closed over rather than
        # re-passed positionally, so a non-empty `*args` can no longer collide
        # with the `neighbor=` keyword.
        (dE_dR, dE_ddeformation), atomwise_energies = grad(
            energy_and_atomwise_fn, argnums=(0, 1),
            has_aux=True)(R, deformation)

        total_energy = jnp.sum(atomwise_energies)
        forces = -dE_dR
        stress = dE_ddeformation / jnp.linalg.det(box)

        return total_energy, atomwise_energies, forces, stress
コード例 #10
0
ファイル: space_test.py プロジェクト: cagrikymk/jax-md
def make_periodic_general_test_system(N, dim, dtype, box_format):
  """Build a random periodic test system under three boundary conventions.

  Returns `(R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g))`:
    * `R_f`: uniform random positions in [0, 1) of shape `(N, dim)`;
    * `R`: `space.transform(box, R_f)`;
    * `box`: a scalar, a vector or a diagonal matrix, per `box_format`;
    * each `(shift, energy)` pair comes from `space.periodic`,
      `space.periodic_general(box)` and
      `space.periodic_general(box, fractional_coordinates=False)`
      respectively, with jitted soft-sphere pair energies.
  """
  assert box_format in BOX_FORMATS

  side = quantity.box_size_at_number_density(N, 1.0, dim)
  if box_format == 'vector':
    box = jnp.array(jnp.ones(dim) * side, dtype)
  elif box_format == 'matrix':
    box = jnp.array(jnp.eye(dim) * side, dtype)
  else:
    box = dtype(side)

  # A matrix box is handed to `space.periodic` as its diagonal.
  periodic_box = jnp.diag(box) if box_format == 'matrix' else box
  d, s = space.periodic(periodic_box)
  d_gf, s_gf = space.periodic_general(box)
  d_g, s_g = space.periodic_general(box, fractional_coordinates=False)

  R_f = random.uniform(random.PRNGKey(0), (N, dim), dtype=dtype)
  R = space.transform(box, R_f)

  E, E_gf, E_g = (jit(energy.soft_sphere_pair(displacement))
                  for displacement in (d, d_gf, d_g))

  return R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g)