Example #1
0
    def test_EMT_from_db_nbrlist(self, spatial_dimension, dtype, low_pressure):
        """Compare neighbor-list athermal moduli against stored reference data.

        For each of the ``NUM_SAMPLES`` samples returned by
        ``test_util.load_elasticity_test_data``, a soft-sphere neighbor-list
        energy is built in fractional coordinates, the gradient of that energy
        at the loaded positions is required to be below a dtype-dependent
        threshold, and the tensor returned by ``elasticity.athermal_moduli``
        is checked for dtype, shape, agreement with the reference ``cijkl``
        values, and index symmetries.

        Args:
          spatial_dimension: Dimension forwarded to the data loader; also
            fixes the expected shape of the rank-4 result.
          dtype: Requested dtype. ``jnp.float32`` selects the looser set of
            tolerances; anything else selects the tighter set.
          low_pressure: Forwarded unchanged to the data loader.
        """
        if dtype == jnp.float32:
            max_grad_thresh = 1e-5
            atol = 1e-4
            rtol = 1e-3
        else:
            max_grad_thresh = 1e-10
            atol = 1e-8
            rtol = 1e-5

        expected_shape = (spatial_dimension,) * 4

        for index in range(NUM_SAMPLES):
            cijkl, R, sigma, box = test_util.load_elasticity_test_data(
                spatial_dimension, low_pressure, dtype, index)

            # Only the displacement function is needed; the shift function
            # returned alongside it is not used by this test.
            displacement, _ = space.periodic_general(
                box, fractional_coordinates=True)
            neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
                displacement, box, sigma=sigma, fractional_coordinates=True)
            nbrs = neighbor_fn.allocate(R)

            # The loaded configuration must have a (numerically) vanishing
            # energy gradient before the moduli are evaluated.
            max_grad = jnp.max(jnp.abs(grad(energy_fn)(R, nbrs)))
            self.assertLess(max_grad, max_grad_thresh)

            EMT_fn = jit(
                elasticity.athermal_moduli(energy_fn, check_convergence=True))
            C, converged = EMT_fn(R, box, neighbor=nbrs)
            self.assertEqual(C.dtype, dtype)
            self.assertEqual(C.shape, expected_shape)

            # NOTE(review): samples for which `converged` is falsy skip the
            # value and symmetry checks below without failing the test.
            if converged:
                self.assertAllClose(cijkl,
                                    elasticity._extract_elements(C, False),
                                    atol=atol,
                                    rtol=rtol)

                # C must be invariant under i<->j, under k<->l, and under
                # the index permutation ijkl -> lkij.
                self.assertAllClose(C, jnp.einsum("ijkl->jikl", C))
                self.assertAllClose(C, jnp.einsum("ijkl->ijlk", C))
                self.assertAllClose(C, jnp.einsum("ijkl->lkij", C))
Example #2
0
    def test_EMT_from_db_dynamic(self, spatial_dimension, dtype, low_pressure):
        """Compare athermal moduli with a call-time ``sigma`` to reference data.

        For each of the ``NUM_SAMPLES`` samples returned by
        ``test_util.load_elasticity_test_data``, a soft-sphere pair energy is
        built with a placeholder ``sigma=1.0`` and the loaded ``sigma`` is
        passed as a keyword argument at call time instead. The gradient of the
        energy at the loaded positions is required to be below a
        dtype-dependent threshold, and the tensor returned by
        ``elasticity.athermal_moduli`` is checked for dtype, shape, agreement
        with the reference ``cijkl`` values, and index symmetries.

        Args:
          spatial_dimension: Dimension forwarded to the data loader; also
            fixes the expected shape of the rank-4 result.
          dtype: Requested dtype. ``jnp.float32`` selects the looser set of
            tolerances; anything else selects the tighter set.
          low_pressure: Forwarded unchanged to the data loader.
        """
        if dtype == jnp.float32:
            max_grad_thresh = 1e-5
            atol = 1e-4
            rtol = 1e-3
        else:
            max_grad_thresh = 1e-10
            atol = 1e-8
            rtol = 1e-5

        expected_shape = (spatial_dimension,) * 4

        for index in range(NUM_SAMPLES):
            cijkl, R, sigma, box = test_util.load_elasticity_test_data(
                spatial_dimension, low_pressure, dtype, index)
            # Map the loaded positions through the box, then keep only the
            # [0, 0] entry of the box as the size given to space.periodic.
            # NOTE(review): presumably the loaded box has equal diagonal
            # entries so a single side length describes it -- confirm.
            R = space.transform(box, R)
            box = box[0, 0]

            # Only the displacement function is needed; the shift function
            # returned alongside it is not used by this test.
            displacement, _ = space.periodic(box)
            # The energy function is deliberately built with a placeholder
            # sigma; the real per-sample sigma is supplied as a keyword
            # argument on every call below.
            energy_fn = energy.soft_sphere_pair(displacement, sigma=1.0)
            maxgrad = jnp.max(jnp.abs(grad(energy_fn)(R, sigma=sigma)))
            self.assertLess(maxgrad, max_grad_thresh)

            EMT_fn = jit(
                elasticity.athermal_moduli(energy_fn, check_convergence=True))
            C, converged = EMT_fn(R, box, sigma=sigma)
            self.assertEqual(C.dtype, dtype)
            self.assertEqual(C.shape, expected_shape)

            # NOTE(review): samples for which `converged` is falsy skip the
            # value and symmetry checks below without failing the test.
            if converged:
                self.assertAllClose(cijkl,
                                    elasticity._extract_elements(C, False),
                                    atol=atol,
                                    rtol=rtol)

                # C must be invariant under i<->j, under k<->l, and under
                # the index permutation ijkl -> lkij.
                self.assertAllClose(C, jnp.einsum("ijkl->jikl", C))
                self.assertAllClose(C, jnp.einsum("ijkl->ijlk", C))
                self.assertAllClose(C, jnp.einsum("ijkl->lkij", C))