示例#1
0
def npt_box(state: NPTNoseHooverState) -> Box:
    """Return the simulation box corresponding to an NPT state.

    The current volume is ``V_0 * exp(dim * state.box_position)``, where
    ``V_0`` is the volume of ``state.reference_box``. The reference box is
    then multiplied by the scalar ``(V / V_0) ** (1 / dim)``.
    """
    spatial_dim = state.position.shape[1]
    reference = state.reference_box
    reference_volume = quantity.volume(spatial_dim, reference)
    current_volume = reference_volume * jnp.exp(
        spatial_dim * state.box_position)
    # A single scalar factor is applied to every entry of the reference box.
    scale = (current_volume / reference_volume)**(1 / spatial_dim)
    return scale * reference
示例#2
0
def _npt_box_info(
        state: NPTNoseHooverState) -> Tuple[float, Callable[[float], float]]:
    """Return the current volume and a volume-to-box conversion function.

    The current volume is ``V_0 * exp(dim * state.box_position)``, with
    ``V_0`` the volume of ``state.reference_box``. The returned function maps
    a volume ``V`` to ``(V / V_0) ** (1 / dim) * state.reference_box``.
    """
    n_dim = state.position.shape[1]
    reference = state.reference_box
    reference_volume = quantity.volume(n_dim, reference)

    def box_from_volume(volume):
        # Scale the reference box by the dim-th root of the volume ratio.
        return (volume / reference_volume)**(1 / n_dim) * reference

    current_volume = reference_volume * jnp.exp(n_dim * state.box_position)
    return current_volume, box_from_volume
示例#3
0
 def exact_stress(R):
   """Reference pairwise stress for a soft-sphere system at positions R."""
   # Displacements and distances between every ordered pair (i, j).
   pair_disp = space.map_product(displacement_fn)(R, R)
   pair_dist = space.distance(pair_disp)
   # Derivative of the soft-sphere pair energy with respect to distance.
   pair_dU = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
   volume = quantity.volume(dim, box)
   # Every pair appears twice in the full (i, j) map; halve to compensate.
   half_dU = 0.5 * pair_dU(pair_dist)[:, :, None, None]
   # Adding the identity keeps the zero self-distances out of the denominator.
   safe_dist = (pair_dist + jnp.eye(N))[:, :, None, None]
   summand = (half_dU * pair_disp[:, :, None, :] * pair_disp[:, :, :, None] /
              (volume * safe_dist))
   return -jnp.sum(summand, axis=(0, 1))
示例#4
0
    def calculate_emt(R: Array, box: Array, **kwargs) -> Array:
        """Calculate the elastic modulus tensor.

        energy_fn(R) corresponds to the state around which we are expanding.

        Args:
          R: array of shape (N, dimension) of particle positions. This does
            not generalize to arbitrary dimensions and is only implemented for
            dimension == 2 and dimension == 3.
          box: A box specifying the shape of the simulation volume. Used to
            infer the volume of the unit cell.
          **kwargs: Extra keyword arguments forwarded to ``energy_fn``.

        Returns:
          C, or the tuple (C, converged) when ``check_convergence`` is truthy.
          C is the elastic modulus tensor as an array of shape (dimension,
          dimension, dimension, dimension) that respects the major and minor
          symmetries. ``converged`` is True when the relative residual of the
          conjugate-gradient solve is below ``cg_tol``. If ``gradient_check``
          is not None and the largest absolute component of the energy
          gradient at R exceeds it, every entry of C is NaN.

        Raises:
          AssertionError: if the last dimension of R is not 2 or 3.
        """
        dim = R.shape[-1]
        if dim not in (2, 3):
            raise AssertionError('Only implemented for 2d and 3d systems.')

        # Compare dtypes by equality rather than identity: `is not` relies on
        # numpy interning dtype objects and can misreport equal dtypes.
        if R.dtype != jnp.dtype('float64'):
            logging.warning('Elastic modulus calculations can sometimes lose '
                            'precision when not using 64-bit precision.')

        def setup_energy_fn_general(strain_tensor):
            """Build energy_fn(R, gamma) evaluated under I + gamma * strain."""
            I = jnp.eye(dim, dtype=R.dtype)

            @jit
            def energy_fn_general(R, gamma):
                perturbation = I + gamma * strain_tensor
                return energy_fn(R, perturbation=perturbation, **kwargs)

            return energy_fn_general

        def get_affine_response(strain_tensor):
            """Mixed and pure second derivatives w.r.t. strain at gamma = 0."""
            energy_fn_general = setup_energy_fn_general(strain_tensor)
            d2U_dRdgamma = jacfwd(jacrev(energy_fn_general, argnums=0),
                                  argnums=1)(R, 0.)
            d2U_dgamma2 = jacfwd(jacrev(energy_fn_general, argnums=1),
                                 argnums=1)(R, 0.)
            return d2U_dRdgamma, d2U_dgamma2

        strain_tensors = _get_strain_tensor_list(dim, R.dtype)
        d2U_dRdgamma_all, d2U_dgamma2_all = vmap(get_affine_response)(
            strain_tensors)

        # Solve the system of equations.
        energy_fn_Ronly = partial(energy_fn, **kwargs)

        def hvp(f, primals, tangents):
            """Hessian-vector product of f via forward-over-reverse autodiff."""
            return jvp(grad(f), primals, tangents)[1]

        def hvp_specific_with_tether(v):
            """Apply (H + tether_strength * I) to v, with H the Hessian at R."""
            return hvp(energy_fn_Ronly, (R, ), (v, )) + tether_strength * v

        # Solves (H + tether_strength * I) x = d2U_dRdgamma for every strain
        # at once without materializing H. A dense alternative would build the
        # (R.size, R.size) Hessian plus tether_strength * identity and call
        # jnp.linalg.solve on the flattened right-hand sides.
        non_affine_response_all = jsp.sparse.linalg.cg(
            vmap(hvp_specific_with_tether), d2U_dRdgamma_all, tol=cg_tol)[0]

        residual = jnp.linalg.norm(
            vmap(hvp_specific_with_tether)(non_affine_response_all) -
            d2U_dRdgamma_all)
        converged = residual / jnp.linalg.norm(d2U_dRdgamma_all) < cg_tol

        response_all = d2U_dgamma2_all - jnp.einsum(
            "nij,nij->n", d2U_dRdgamma_all, non_affine_response_all)

        vol_0 = quantity.volume(dim, box)
        response_all = response_all / vol_0
        C = _convert_responses_to_elastic_constants(response_all)

        # JAX does not allow proper runtime error handling in jitted function.
        # Instead, if the user requests a gradient check and the check fails,
        # we convert C into jnp.nan's. While this doesn't raise an exception,
        # it at least is very "loud".
        if gradient_check is not None:
            maxgrad = jnp.amax(jnp.abs(grad(energy_fn)(R, **kwargs)))
            C = lax.cond(maxgrad > gradient_check, lambda _: jnp.nan * C,
                         lambda _: C, None)

        if check_convergence:
            return C, converged
        return C