def test_transform_inverse(self, spatial_dimension, dtype):
  """Transforming by T and then by T's inverse must recover the positions.

  `space._small_inverse` supplies the inverse. The round-trip result is
  compared to the original positions with dtype checking enabled.
  """
  rng = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    rng, position_key, matrix_key = random.split(rng, 3)
    positions = random.normal(
        position_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    matrix = random.normal(
        matrix_key, (spatial_dimension, spatial_dimension), dtype=dtype)
    matrix_inv = space._small_inverse(matrix)
    forward = space.transform(matrix, positions)
    round_trip = space.transform(matrix_inv, forward)
    # Third positional argument enables dtype checking.
    self.assertAllClose(positions, round_trip, True)
def pressure(energy_fn: EnergyFn, position: Array, box: Box,
             kinetic_energy: float = 0.0, **kwargs) -> float:
  """Computes the internal pressure of a system.

  Note: This function requires that `energy_fn` take a `box` keyword argument.
  Most frequently, this is accomplished by using `periodic_general` boundary
  conditions combined with any of the energy functions in `energy.py`. This
  will not work with `space.periodic`.

  The result is `(2 * KE + sum(R * F) - d * V * dU/dV) / (d * V)`. The terms
  are:
    - `d`: the spatial dimension, `position.shape[1]`.
    - `V`: `volume(d, box)`.
    - `R`: `space.transform(box, position)`.
    - `F`: the force evaluated in `box`.
    - `dU/dV`: the derivative of the energy as the box is rescaled
      isotropically in volume, evaluated at `V`.

  Args:
    energy_fn: Energy function accepting `position` and a `box=` keyword.
    position: Particle positions. They are mapped to real space through
      `space.transform(box, position)`, so they are presumably fractional
      coordinates. Only `shape[1]` (the dimension) is read directly.
    box: Simulation box.
    kinetic_energy: Kinetic energy of the system; defaults to 0.0.
    **kwargs: Extra keyword arguments forwarded to `energy_fn`.

  Returns:
    The scalar internal pressure.
  """
  dim = position.shape[1]
  reference_volume = volume(dim, box)

  def energy_at_volume(vol):
    # Isotropically rescale `box` so that its volume becomes `vol`.
    scaled_box = (vol / reference_volume)**(1 / dim) * box
    return energy_fn(position, box=scaled_box, **kwargs)

  forces = force(energy_fn)(position, box=box, **kwargs)
  real_position = space.transform(box, position)
  virial = util.high_precision_sum(real_position * forces)
  dU_dV = grad(energy_at_volume)(reference_volume)
  return 1 / (dim * reference_volume) * (
      2 * kinetic_energy + virial - dim * reference_volume * dU_dV)
def test_transform_grad(self, spatial_dimension):
  """Gradients w.r.t. R taken through `space.transform` match direct ones.

  The gradient of `sum(transform(T, R)**2)` w.r.t. R is compared with the
  gradient of `sum(x**2)` evaluated at the already transformed positions.
  """

  def squared_norm(x):
    return np.sum(x**2)

  def squared_norm_after_transform(matrix, x):
    return np.sum(space.transform(matrix, x)**2)

  rng = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    rng, position_key, matrix_key = random.split(rng, 3)
    positions = random.normal(position_key,
                              (PARTICLE_COUNT, spatial_dimension))
    matrix = random.normal(matrix_key,
                           (spatial_dimension, spatial_dimension))
    transformed = space.transform(matrix, positions)

    expected = grad(squared_norm)(transformed)
    # argnums=1 differentiates w.r.t. the positions, not the matrix.
    actual = grad(squared_norm_after_transform, 1)(matrix, positions)
    self.assertAllClose(expected, actual, True)
def test_periodic_general_deform_shift(self, spatial_dimension, dtype,
                                       box_format):
  """The real-space shift with `new_box` matches the fractional shift.

  The real-space `periodic_general` shift is applied with `new_box` set to a
  box shrunk by 10%. It must agree with the fractional-coordinate shift
  mapped into that same deformed box.
  """
  N = 16
  (R_f, R, box, (_s, _E), (shift_frac, energy_frac),
   (shift_real, energy_real)) = make_periodic_general_test_system(
       N, spatial_dimension, dtype, box_format)

  new_box = box * 0.9
  shifted_real = shift_real(R, grad(energy_real)(R), new_box=new_box)
  shifted_frac = shift_frac(R_f, grad(energy_frac)(R_f))
  self.assertAllClose(shifted_real, space.transform(new_box, shifted_frac))
def test_periodic_general_shift(self, spatial_dimension, dtype, box_format):
  """Periodic and periodic_general shifts agree when shifting by the gradient.

  Three shifts are compared, each moving positions by its own energy gradient:
    - the `space.periodic` shift;
    - the fractional `periodic_general` shift, mapped back through
      `space.transform(box, ...)`;
    - the real-space `periodic_general` shift.
  Both `periodic_general` results must match the `space.periodic` one.
  """
  N = 16
  (R_f, R, box, (shift, energy_fn), (shift_frac, energy_frac),
   (shift_real, energy_real)) = make_periodic_general_test_system(
       N, spatial_dimension, dtype, box_format)

  expected = shift(R, grad(energy_fn)(R))
  frac_result = shift_frac(R_f, grad(energy_frac)(R_f))
  real_result = shift_real(R, grad(energy_real)(R))

  self.assertAllClose(expected, space.transform(box, frac_result))
  self.assertAllClose(expected, real_result)
def box_force(alpha, vol, box_fn, position, velocity, mass, force, pressure,
              **kwargs):
  """Computes the generalized force acting on the box volume `vol`.

  Returns `alpha * sum(m v^2) + sum(R * F) - d * vol * dU/dV - P * vol * d`.
  The terms are:
    - `R`: `position` mapped into `box_fn(vol)`.
    - `F`: the supplied `force`.
    - `d`: the spatial dimension.
    - `dU/dV`: the derivative of `energy_fn` as the box is produced by
      `box_fn`, evaluated at `vol`.

  NOTE(review): `energy_fn` is a free variable taken from the enclosing
  scope, which is not visible here.

  Args:
    alpha: Coefficient multiplying the kinetic term `sum(m v^2)`.
    vol: Current box volume.
    box_fn: Maps a volume to a box.
    position: Particle positions; must be 2-D, `(N, dim)`.
    velocity: Particle velocities.
    mass: Particle masses.
    force: Forces on the particles.
    pressure: Target (external) pressure.
    **kwargs: Extra keyword arguments forwarded to `energy_fn`.
  """
  # Unpacking (rather than indexing) keeps the requirement that `position`
  # is exactly 2-D.
  _, dim = position.shape

  def potential_at_volume(v):
    return energy_fn(position, box=box_fn(v), **kwargs)

  twice_kinetic = util.high_precision_sum(velocity**2 * mass)
  real_position = space.transform(box_fn(vol), position)
  virial = util.high_precision_sum(real_position * force)
  dU_dV = grad(potential_at_volume)(vol)
  return (alpha * twice_kinetic + virial - dim * vol * dU_dV
          - pressure * vol * dim)
def test_transform(self, spatial_dimension, dtype):
  """`space.transform(T, R)` must equal an explicit `np.dot(R, T)`."""
  rng = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    rng, position_key, matrix_key = random.split(rng, 3)
    positions = random.normal(
        position_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    matrix = random.normal(
        matrix_key, (spatial_dimension, spatial_dimension), dtype=dtype)
    expected = np.array(np.dot(positions, matrix), dtype=dtype)
    actual = space.transform(matrix, positions)
    # Third positional argument enables dtype checking.
    self.assertAllClose(expected, actual, True)
def test_EMT_from_db_dynamic(self, spatial_dimension, dtype, low_pressure):
  """Checks athermal elastic moduli when `sigma` is supplied at call time.

  For each stored sample:
    - The energy function is built with a wrong `sigma` (1.0).
    - The correct per-sample `sigma` is passed dynamically to both the
      force-balance check and the jitted `elasticity.athermal_moduli`
      function.
    - When the moduli calculation reports convergence, the tensor is compared
      with the stored reference values and checked for its index symmetries.
  """
  if dtype == jnp.float32:
    max_grad_thresh, atol, rtol = 1e-5, 1e-4, 1e-3
  else:
    max_grad_thresh, atol, rtol = 1e-10, 1e-8, 1e-5

  expected_shape = (spatial_dimension,) * 4

  for index in range(NUM_SAMPLES):
    cijkl, R, sigma, box = test_util.load_elasticity_test_data(
        spatial_dimension, low_pressure, dtype, index)
    R = space.transform(box, R)
    # Only the [0, 0] entry of the stored box is used for `space.periodic`.
    # This presumably assumes a cubic box; confirm against the test data.
    box = box[0, 0]
    displacement, _ = space.periodic(box)

    # Below we use the wrong sigma, so we must pass it dynamically.
    energy_fn = energy.soft_sphere_pair(displacement, sigma=1.0)
    maxgrad = jnp.max(jnp.abs(grad(energy_fn)(R, sigma=sigma)))
    # unittest assertions (unlike bare `assert`) survive `python -O` and
    # report the offending values on failure.
    self.assertLess(maxgrad, max_grad_thresh)

    EMT_fn = jit(
        elasticity.athermal_moduli(energy_fn, check_convergence=True))
    C, converged = EMT_fn(R, box, sigma=sigma)
    self.assertEqual(C.dtype, dtype)
    self.assertEqual(C.shape, expected_shape)

    if converged:
      self.assertAllClose(cijkl,
                          elasticity._extract_elements(C, False),
                          atol=atol,
                          rtol=rtol)

      # Make sure the symmetries are there. The two minor symmetries
      # (ij <-> ji, kl <-> lk), combined with C_ijkl == C_lkij, imply the
      # major symmetry C_ijkl == C_klij.
      self.assertAllClose(C, jnp.einsum("ijkl->jikl", C))
      self.assertAllClose(C, jnp.einsum("ijkl->ijlk", C))
      self.assertAllClose(C, jnp.einsum("ijkl->lkij", C))
def potential(R: space.Array, neighbor: NeighborList, *args,
              **kwargs) -> PotentialProperties:
  """Evaluates energies, forces and stress of `energy_fn` under a box strain.

  The strain comes from `deformation`, which is symmetrized and added to the
  identity; the result is applied to `box` via `space.transform`.
  Forces are the negated gradient of the total strained energy w.r.t. `R`.
  Stress is the gradient w.r.t. `deformation`, divided by `det(box)`.

  NOTE(review): `deformation`, `box`, `dtype` and `energy_fn` are free
  variables from the enclosing scope, which is not visible here. The identity
  is hard-coded as 3x3, so `box` presumably describes a 3-D cell — confirm.

  Args:
    R: Particle positions, forwarded unchanged to `energy_fn`.
    neighbor: Neighbor list, forwarded to `energy_fn` as `neighbor=`.
    *args: Extra positional arguments forwarded to `energy_fn` after `R`.
    **kwargs: Extra keyword arguments forwarded to `energy_fn`.

  Returns:
    `(total_energy, atomwise_energies, forces, stress)`.
  """

  def transform_box_fn(deformation):
    # Symmetrize the deformation tensor and apply (I + eps) to the box.
    strain = jnp.eye(3, dtype=dtype) + (deformation + deformation.T) * 0.5
    return space.transform(strain, box)

  # Atomwise and total energy functions acting on the transformed box.
  # `neighbor` is always passed positionally (third slot). Any extra `args`
  # therefore land in `*args`, instead of being bound to `neighbor` and
  # colliding with a `neighbor=` keyword.
  def strained_energy_fn(R, deformation, neighbor, *args, **kwargs):
    return energy_fn(R, *args, **kwargs,
                     box=transform_box_fn(deformation),
                     neighbor=neighbor)

  def total_strained_energy_fn(R, deformation, neighbor, *args, **kwargs):
    return jnp.sum(
        strained_energy_fn(R, deformation, neighbor, *args, **kwargs))

  box_volume = jnp.linalg.det(box)

  atomwise_energies = strained_energy_fn(R, deformation, neighbor, *args,
                                         **kwargs)
  # Equal to total_strained_energy_fn(...), without re-evaluating energy_fn.
  total_energy = jnp.sum(atomwise_energies)
  # Forces: negated gradient w.r.t. positions.
  forces = grad(total_strained_energy_fn, argnums=0)(
      R, deformation, neighbor, *args, **kwargs) * -1
  # Stress: gradient w.r.t. the deformation tensor, per unit box volume.
  stress = grad(total_strained_energy_fn, argnums=1)(
      R, deformation, neighbor, *args, **kwargs) / box_volume
  return total_energy, atomwise_energies, forces, stress
def make_periodic_general_test_system(N, dim, dtype, box_format):
  """Builds matching periodic and periodic_general systems for tests.

  The box is sized for `N` particles at number density 1.0 in `dim`
  dimensions. For `box_format` 'vector' it is a `dim`-vector, for 'matrix' a
  diagonal `dim x dim` matrix, and otherwise a `dtype` scalar. Positions are
  drawn uniformly in [0, 1) as fractional coordinates and mapped into the box
  for real-space use.

  Returns:
    `(R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g))`, where:
      - `R_f`: the fractional positions.
      - `R = space.transform(box, R_f)`: the real positions.
      - `(s, E)`: from `space.periodic`.
      - `(s_gf, E_gf)`: from `space.periodic_general(box)`.
      - `(s_g, E_g)`: from `space.periodic_general(box,
        fractional_coordinates=False)`.
    Each `E*` is a jitted `energy.soft_sphere_pair` built on the matching
    displacement function.
  """
  assert box_format in BOX_FORMATS
  box_size = quantity.box_size_at_number_density(N, 1.0, dim)

  if box_format == 'vector':
    box = jnp.array(jnp.ones(dim) * box_size, dtype)
  elif box_format == 'matrix':
    box = jnp.array(jnp.eye(dim) * box_size, dtype)
  else:
    box = dtype(box_size)

  # `space.periodic` takes per-axis side lengths, so a matrix box is reduced
  # to its diagonal.
  side_lengths = jnp.diag(box) if box_format == 'matrix' else box
  disp, shift = space.periodic(side_lengths)
  disp_frac, shift_frac = space.periodic_general(box)
  disp_real, shift_real = space.periodic_general(
      box, fractional_coordinates=False)

  R_f = random.uniform(random.PRNGKey(0), (N, dim), dtype=dtype)
  R = space.transform(box, R_f)

  E, E_frac, E_real = (
      jit(energy.soft_sphere_pair(fn)) for fn in (disp, disp_frac, disp_real))

  return (R_f, R, box, (shift, E), (shift_frac, E_frac),
          (shift_real, E_real))