def npt_box(state: NPTNoseHooverState) -> Box:
  """Get the current box from an NPT simulation.

  The box is the reference box rescaled isotropically to the current volume
  V_0 * exp(dim * state.box_position), where V_0 is the reference volume.
  The volume/box bookkeeping is shared with `_npt_box_info`.
  """
  volume, box_from_volume = _npt_box_info(state)
  return box_from_volume(volume)
def _npt_box_info(
    state: NPTNoseHooverState) -> Tuple[float, Callable[[float], Box]]:
  """Gets the current volume and a function to compute the box from volume.

  Args:
    state: NPT Nose-Hoover state; reads `position`, `reference_box` and
      `box_position`.

  Returns:
    A pair `(V, box_fn)`. `V = V_0 * exp(dim * state.box_position)` is the
    current volume, with `V_0` the volume of `state.reference_box`.
    `box_fn(volume)` returns the reference box scaled isotropically by
    `(volume / V_0) ** (1 / dim)`, i.e. a box (not a float) whose volume is
    `volume`.
  """
  dim = state.position.shape[1]
  ref = state.reference_box
  V_0 = quantity.volume(dim, ref)
  V = V_0 * jnp.exp(dim * state.box_position)

  def box_from_volume(volume):
    # Isotropic rescaling: each box length scales as volume ** (1 / dim).
    return (volume / V_0)**(1 / dim) * ref

  return V, box_from_volume
def exact_stress(R):
  """Compute the pairwise virial stress of a soft-sphere system at R.

  Reads `displacement_fn`, `dim`, `box` and `N` from the enclosing scope.
  Returns a (dim, dim) array:
    -sum_{i,j} 0.5 * U'(r_ij) * dR_ij (x) dR_ij / (V * r_ij).
  """
  pair_disp = space.map_product(displacement_fn)(R, R)
  pair_dist = space.distance(pair_disp)
  volume = quantity.volume(dim, box)
  dU_dr_fn = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
  # map_product visits every pair twice (i, j) and (j, i), hence the 1/2.
  half_dU_dr = 0.5 * dU_dr_fn(pair_dist)[:, :, None, None]
  # Shift the diagonal (self-pairs, distance 0) to 1 so the division below is
  # finite; those entries are multiplied by a zero displacement anyway.
  safe_dist = (pair_dist + jnp.eye(N))[:, :, None, None]
  weighted = (half_dU_dr
              * pair_disp[:, :, None, :]
              * pair_disp[:, :, :, None]
              / (volume * safe_dist))
  return -jnp.sum(weighted, axis=(0, 1))
def calculate_emt(R: Array, box: Array, **kwargs) -> Array:
  """Calculate the elastic modulus tensor.

  energy_fn(R) corresponds to the state around which we are expanding.

  NOTE(review): this function reads `energy_fn`, `tether_strength`, `cg_tol`,
  `gradient_check` and `check_convergence` from an enclosing scope that is
  not visible here; presumably they are arguments of the factory that builds
  this function -- confirm against the caller.

  Args:
    R: array of shape (N, dimension) of particle positions. This does not
      generalize to arbitrary dimensions; only dimension == 2 and
      dimension == 3 are implemented.
    box: A box specifying the shape of the simulation volume. Used only to
      infer the volume of the unit cell.
    **kwargs: Extra keyword arguments forwarded to every call of `energy_fn`.

  Returns:
    C, or the tuple (C, converged) when `check_convergence` is truthy.
    C is the elastic modulus tensor as an array of shape
    (dimension, dimension, dimension, dimension) that respects the major and
    minor symmetries. converged is a boolean flag that is True when the
    relative residual of the conjugate-gradient solve is below `cg_tol`.
    If `gradient_check` is not None and the largest absolute component of
    the energy gradient at R exceeds it, every entry of C is NaN.

  Raises:
    AssertionError: if R is not a 2d or 3d configuration.
  """
  dim = R.shape[-1]
  if dim not in (2, 3):
    raise AssertionError('Only implemented for 2d and 3d systems.')

  # Compare dtypes by equality; an identity test (`is not`) only works as
  # long as numpy happens to hand back a cached dtype instance.
  if R.dtype != jnp.dtype('float64'):
    logging.warning('Elastic modulus calculations can sometimes lose '
                    'precision when not using 64-bit precision.')

  def setup_energy_fn_general(strain_tensor):
    """Return U(R, gamma): energy under the map I + gamma * strain_tensor."""
    identity = jnp.eye(dim, dtype=R.dtype)

    @jit
    def energy_fn_general(R, gamma):
      perturbation = identity + gamma * strain_tensor
      return energy_fn(R, perturbation=perturbation, **kwargs)

    return energy_fn_general

  def get_affine_response(strain_tensor):
    """Return (d2U/dR dgamma, d2U/dgamma2) at gamma = 0 for one strain."""
    energy_fn_general = setup_energy_fn_general(strain_tensor)
    d2U_dRdgamma = jacfwd(
        jacrev(energy_fn_general, argnums=0), argnums=1)(R, 0.)
    d2U_dgamma2 = jacfwd(
        jacrev(energy_fn_general, argnums=1), argnums=1)(R, 0.)
    return d2U_dRdgamma, d2U_dgamma2

  strain_tensors = _get_strain_tensor_list(dim, R.dtype)
  # Affine responses for all strain tensors at once.
  d2U_dRdgamma_all, d2U_dgamma2_all = vmap(get_affine_response)(
      strain_tensors)

  # Solve the system of equations (H + tether_strength * 1) x = d2U/dR dgamma
  # for every strain, where H is the Hessian of the energy with respect to R.
  # H is applied matrix-free through Hessian-vector products. This is
  # functionally equivalent to a dense solve with
  # hessian(energy_fn_Ronly)(R) reshaped to (R.size, R.size) plus
  # tether_strength * identity, without ever materializing the Hessian.
  energy_fn_Ronly = partial(energy_fn, **kwargs)

  def hvp(f, primals, tangents):
    # Forward-over-reverse Hessian-vector product.
    return jvp(grad(f), primals, tangents)[1]

  def hvp_specific_with_tether(v):
    return hvp(energy_fn_Ronly, (R,), (v,)) + tether_strength * v

  batched_hvp = vmap(hvp_specific_with_tether)
  non_affine_response_all = jsp.sparse.linalg.cg(
      batched_hvp, d2U_dRdgamma_all, tol=cg_tol)[0]

  # Check convergence explicitly through the relative residual.
  residual = jnp.linalg.norm(
      batched_hvp(non_affine_response_all) - d2U_dRdgamma_all)
  converged = residual / jnp.linalg.norm(d2U_dRdgamma_all) < cg_tol

  # Per-strain response: affine term minus the non-affine correction.
  response_all = d2U_dgamma2_all - jnp.einsum(
      'nij,nij->n', d2U_dRdgamma_all, non_affine_response_all)

  vol_0 = quantity.volume(dim, box)
  response_all = response_all / vol_0
  C = _convert_responses_to_elastic_constants(response_all)

  # JAX does not allow proper runtime error handling in jitted functions.
  # Instead, if the user requests a gradient check and the check fails, C is
  # replaced by NaNs. While this doesn't raise an exception, it is at least
  # very "loud".
  if gradient_check is not None:
    maxgrad = jnp.amax(jnp.abs(grad(energy_fn)(R, **kwargs)))
    C = lax.cond(maxgrad > gradient_check,
                 lambda _: jnp.nan * C,
                 lambda _: C,
                 None)

  if check_convergence:
    return C, converged
  return C