def soft_sphere_cell_list(
    displacement_or_metric,
    box_size,
    R_example,
    species=None,
    sigma=1.0,
    epsilon=1.0,
    alpha=2.0):
  """Build a soft-sphere energy function that is evaluated via a cell list.

  Args:
    displacement_or_metric: Displacement or metric function; it is passed
      through `space.canonicalize_displacement_or_metric` before use.
    box_size: Forwarded to `smap.cell_list`.
    R_example: Forwarded to `smap.cell_list`.
    species: Optional species specification, forwarded to `smap.cell_list`.
    sigma: Soft-sphere `sigma` parameter; its maximum is used as the cell
      list cutoff.
    epsilon: Soft-sphere `epsilon` parameter.
    alpha: Soft-sphere `alpha` parameter.

  Returns:
    Whatever `smap.cell_list` returns for the constructed energy function.

  Raises:
    ValueError: If `species` is None while any of `sigma`, `epsilon` or
      `alpha` is not a scalar.
  """
  sigma, epsilon, alpha = (
      np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))

  # A non-empty `.shape` means a non-scalar parameter was supplied.
  if species is None and (sigma.shape or epsilon.shape or alpha.shape):
    raise ValueError(
        ('At the moment per-particle (as opposed to per-species) parameters are'
         ' not supported using cell lists. Please open a feature request!'))

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.cartesian_product(
      soft_sphere, metric, sigma=sigma, epsilon=epsilon, alpha=alpha)

  cutoff = np.max(sigma)
  return smap.cell_list(energy_fn, box_size, cutoff, R_example, species)
def lennard_jones_neighbor_list(
    displacement_or_metric,
    box_size,
    species=None,
    sigma=1.0,
    epsilon=1.0,
    alpha=2.0,
    r_onset=2.0,
    r_cutoff=2.5,
    dr_threshold=0.5):
  """Build a neighbor-list function and a Lennard-Jones energy over it.

  `r_onset`, `r_cutoff` and `dr_threshold` are all multiplied by the largest
  `sigma` before being used.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Optional species specification for `smap.pair_neighbor_list`.
    sigma: Lennard-Jones `sigma` parameter.
    epsilon: Lennard-Jones `epsilon` parameter.
    alpha: Accepted for interface compatibility; not used by this function.
    r_onset: Onset of the multiplicative cutoff, in units of max(sigma).
    r_cutoff: Cutoff distance, in units of max(sigma).
    dr_threshold: Neighbor list threshold, in units of max(sigma).

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  # TODO(schsam) Optimize this.
  sigma = np.array(sigma, dtype=f32)
  epsilon = np.array(epsilon, dtype=f32)

  onset = np.array(r_onset * np.max(sigma), dtype=f32)
  cutoff = np.array(r_cutoff * np.max(sigma), dtype=f32)
  threshold = np.array(np.max(sigma) * dr_threshold, dtype=f32)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, threshold)

  truncated_lj = multiplicative_isotropic_cutoff(lennard_jones, onset, cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      truncated_lj, metric, species=species, sigma=sigma, epsilon=epsilon)

  return neighbor_fn, energy_fn
def lennard_jones_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 1.0,
    alpha: Array = 2.0,
    r_onset: float = 2.0,
    r_cutoff: float = 2.5,
    dr_threshold: float = 0.5,
    per_particle: bool = False
    ) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Build a neighbor-list function and a Lennard-Jones energy over it.

  `r_onset`, `r_cutoff` and `dr_threshold` are all multiplied by the largest
  `sigma` before being used.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Optional species specification for `smap.pair_neighbor_list`.
    sigma: Lennard-Jones `sigma` parameter.
    epsilon: Lennard-Jones `epsilon` parameter.
    alpha: Accepted for interface compatibility; not used by this function.
    r_onset: Onset of the multiplicative cutoff, in units of max(sigma).
    r_cutoff: Cutoff distance, in units of max(sigma).
    dr_threshold: Neighbor list threshold, in units of max(sigma).
    per_particle: When True, `reduce_axis=(1,)` is passed to
      `smap.pair_neighbor_list`; otherwise `reduce_axis=None`.

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  sigma = np.array(sigma, dtype=f32)
  epsilon = np.array(epsilon, dtype=f32)

  onset = np.array(r_onset * np.max(sigma), dtype=f32)
  cutoff = np.array(r_cutoff * np.max(sigma), dtype=f32)
  threshold = np.array(np.max(sigma) * dr_threshold, dtype=f32)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, threshold)

  if per_particle:
    reduce_axis = (1,)
  else:
    reduce_axis = None

  truncated_lj = multiplicative_isotropic_cutoff(lennard_jones, onset, cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      truncated_lj,
      metric,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      reduce_axis=reduce_axis)

  return neighbor_fn, energy_fn
def bks_neighbor_list(displacement_or_metric,
                      box_size,
                      species,
                      Q_sq,
                      exp_coeff,
                      exp_decay,
                      attractive_coeff,
                      repulsive_coeff,
                      coulomb_alpha,
                      cutoff,
                      dr_threshold=0.8):
  """Build a neighbor-list function and a BKS energy function over it.

  `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff` and `repulsive_coeff`
  are converted to `f32` arrays; `coulomb_alpha` and `cutoff` are forwarded
  to the `bks` potential unchanged.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Species specification for `smap.pair_neighbor_list`.
    Q_sq, exp_coeff, exp_decay, attractive_coeff, repulsive_coeff,
      coulomb_alpha: Parameters of the `bks` potential.
    cutoff: Cutoff used both for the neighbor list and the potential.
    dr_threshold: Neighbor list threshold (cast with `f32`).

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  Q_sq = np.array(Q_sq, dtype=f32)
  exp_coeff = np.array(exp_coeff, dtype=f32)
  exp_decay = np.array(exp_decay, dtype=f32)
  attractive_coeff = np.array(attractive_coeff, dtype=f32)
  repulsive_coeff = np.array(repulsive_coeff, dtype=f32)
  threshold = f32(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, threshold)

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      bks,
      metric,
      species=species,
      Q_sq=Q_sq,
      exp_coeff=exp_coeff,
      exp_decay=exp_decay,
      attractive_coeff=attractive_coeff,
      repulsive_coeff=repulsive_coeff,
      coulomb_alpha=coulomb_alpha,
      cutoff=cutoff)

  return neighbor_fn, energy_fn
def soft_sphere_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 1.0,
    alpha: Array = 2.0,
    dr_threshold: float = 0.2,
    per_particle: bool = False
    ) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Build a neighbor-list function and a soft-sphere energy over it.

  The neighbor list cutoff is the largest `sigma`, and `dr_threshold` is
  scaled by that cutoff.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Optional species specification for `smap.pair_neighbor_list`.
    sigma: Soft-sphere `sigma` parameter.
    epsilon: Soft-sphere `epsilon` parameter.
    alpha: Soft-sphere `alpha` parameter.
    dr_threshold: Neighbor list threshold, in units of max(sigma).
    per_particle: When True, `reduce_axis=(1,)` is passed to
      `smap.pair_neighbor_list`; otherwise `reduce_axis=None`.

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  sigma = maybe_downcast(sigma)
  epsilon = maybe_downcast(epsilon)
  alpha = maybe_downcast(alpha)

  cutoff = np.max(sigma)
  threshold = cutoff * maybe_downcast(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, threshold)

  if per_particle:
    reduce_axis = (1,)
  else:
    reduce_axis = None

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      soft_sphere,
      metric,
      ignore_unused_parameters=True,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=reduce_axis)

  return neighbor_fn, energy_fn
def morse_neighbor_list(displacement_or_metric,
                        box_size,
                        species=None,
                        sigma=1.0,
                        epsilon=5.0,
                        alpha=5.0,
                        r_onset=2.0,
                        r_cutoff=2.5,
                        dr_threshold=0.5,
                        per_particle=False):
  """Build a neighbor-list function and a Morse energy function over it.

  Unlike the Lennard-Jones wrappers, `r_onset`, `r_cutoff` and
  `dr_threshold` are used as given (only cast to `f32`), not rescaled by
  `sigma`.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Optional species specification for `smap.pair_neighbor_list`.
    sigma: Morse `sigma` parameter.
    epsilon: Morse `epsilon` parameter.
    alpha: Morse `alpha` parameter.
    r_onset: Onset of the multiplicative cutoff.
    r_cutoff: Cutoff distance, also used as the neighbor list cutoff.
    dr_threshold: Neighbor list threshold.
    per_particle: When True, `reduce_axis=(1,)` is passed to
      `smap.pair_neighbor_list`; otherwise `reduce_axis=None`.

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  # TODO(cpgoodri) Optimize this.
  sigma = np.array(sigma, dtype=f32)
  epsilon = np.array(epsilon, dtype=f32)
  alpha = np.array(alpha, dtype=f32)
  onset = np.array(r_onset, dtype=f32)
  cutoff = np.array(r_cutoff, dtype=f32)
  threshold = np.array(dr_threshold, dtype=f32)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, threshold)

  if per_particle:
    reduce_axis = (1,)
  else:
    reduce_axis = None

  truncated_morse = multiplicative_isotropic_cutoff(morse, onset, cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      truncated_morse,
      metric,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=reduce_axis)

  return neighbor_fn, energy_fn
def lennard_jones_cell_list(
    displacement_or_metric,
    box_size,
    R_example,
    species=None,
    sigma=1.0,
    epsilon=1.0,
    alpha=2.0,
    r_onset=2.0,
    r_cutoff=2.5):
  """Build a Lennard-Jones energy function that is evaluated via a cell list.

  `r_onset` and `r_cutoff` are multiplied by the largest `sigma`; the scaled
  cutoff is also used as the cell list cutoff.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `smap.cell_list`.
    R_example: Forwarded to `smap.cell_list`.
    species: Optional species specification, forwarded to `smap.cell_list`.
    sigma: Lennard-Jones `sigma` parameter.
    epsilon: Lennard-Jones `epsilon` parameter.
    alpha: Accepted for interface compatibility; not used by this function.
    r_onset: Onset of the multiplicative cutoff, in units of max(sigma).
    r_cutoff: Cutoff distance, in units of max(sigma).

  Returns:
    Whatever `smap.cell_list` returns for the constructed energy function.

  Raises:
    ValueError: If `species` is None while `sigma` or `epsilon` is not a
      scalar.
  """
  sigma = np.array(sigma, dtype=f32)
  epsilon = np.array(epsilon, dtype=f32)
  onset = f32(r_onset * np.max(sigma))
  cutoff = f32(r_cutoff * np.max(sigma))

  # A non-empty `.shape` means a non-scalar parameter was supplied.
  if species is None and (sigma.shape or epsilon.shape):
    raise ValueError(
        ('At the moment per-particle (as opposed to per-species) parameters are'
         ' not supported using cell lists. Please open a feature request!'))

  truncated_lj = multiplicative_isotropic_cutoff(lennard_jones, onset, cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.cartesian_product(
      truncated_lj, metric, sigma=sigma, epsilon=epsilon)

  return smap.cell_list(energy_fn, box_size, cutoff, R_example, species)
def morse_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 5.0,
    alpha: Array = 5.0,
    r_onset: float = 2.0,
    r_cutoff: float = 2.5,
    dr_threshold: float = 0.5,
    per_particle: bool = False
    ) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Build a neighbor-list function and a Morse energy function over it.

  All numeric arguments are passed through `maybe_downcast`; `r_onset`,
  `r_cutoff` and `dr_threshold` are otherwise used as given.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Optional species specification for `smap.pair_neighbor_list`.
    sigma: Morse `sigma` parameter.
    epsilon: Morse `epsilon` parameter.
    alpha: Morse `alpha` parameter.
    r_onset: Onset of the multiplicative cutoff.
    r_cutoff: Cutoff distance, also used as the neighbor list cutoff.
    dr_threshold: Neighbor list threshold.
    per_particle: When True, `reduce_axis=(1,)` is passed to
      `smap.pair_neighbor_list`; otherwise `reduce_axis=None`.

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  sigma = maybe_downcast(sigma)
  epsilon = maybe_downcast(epsilon)
  alpha = maybe_downcast(alpha)
  onset = maybe_downcast(r_onset)
  cutoff = maybe_downcast(r_cutoff)
  threshold = maybe_downcast(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, threshold)

  if per_particle:
    reduce_axis = (1,)
  else:
    reduce_axis = None

  truncated_morse = multiplicative_isotropic_cutoff(morse, onset, cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      truncated_morse,
      metric,
      ignore_unused_parameters=True,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=reduce_axis)

  return neighbor_fn, energy_fn
def soft_sphere_neighbor_list(displacement_or_metric,
                              box_size,
                              species=None,
                              sigma=1.0,
                              epsilon=1.0,
                              alpha=2.0,
                              dr_threshold=0.2):
  """Build a neighbor-list function and a soft-sphere energy over it.

  The neighbor list cutoff is the largest `sigma`, and `dr_threshold` is
  scaled by that cutoff.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Optional species specification for `smap.pair_neighbor_list`.
    sigma: Soft-sphere `sigma` parameter.
    epsilon: Soft-sphere `epsilon` parameter.
    alpha: Soft-sphere `alpha` parameter.
    dr_threshold: Neighbor list threshold, in units of max(sigma).

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  sigma, epsilon, alpha = (
      np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))

  cutoff = f32(np.max(sigma))
  threshold = f32(cutoff * dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, threshold)

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      soft_sphere,
      metric,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha)

  return neighbor_fn, energy_fn
def radial_symmetry_functions(
    displacement_or_metric: DisplacementOrMetricFn,
    species: Optional[Array],
    etas: Array,
    cutoff_distance: float) -> Callable[[Array], Array]:
  """Returns a function that computes radial symmetry functions.

  Args:
    displacement_or_metric: A function that produces an `[N_atoms, M_atoms,
      spatial_dimension]` of particle displacements from particle positions
      specified as an `[N_atoms, spatial_dimension]` and `[M_atoms,
      spatial_dimension]` respectively.
    species: An `[N_atoms]` that contains the species of each particle, or
      `None` to treat all particles as a single type. Any array-like value
      is accepted; it is converted to a NumPy array.
    etas: List of radial symmetry function parameters that control the
      spatial extension.
    cutoff_distance: Neighbors whose distance is larger than cutoff_distance
      do not contribute to each others symmetry functions. The contribution
      of a neighbor to the symmetry function and its derivative goes to zero
      at this distance.

  Returns:
    A function that computes the radial symmetry function from input
    `[N_atoms, spatial_dimension]` and returns `[N_atoms, N_etas * N_types]`
    where N_etas is the number of eta parameters, N_types is the number of
    types of particles in the system.
  """
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

  radial_fn = lambda eta, dr: (
      jnp.exp(-eta * dr**2) * _behler_parrinello_cutoff_fn(dr, cutoff_distance))
  # Evaluate every eta against the same set of distances.
  radial_fn = vmap(radial_fn, (0, None))

  if species is None:
    def compute_fn(R: Array, **kwargs) -> Array:
      _metric = partial(metric, **kwargs)
      _metric = space.map_product(_metric)
      return util.high_precision_sum(radial_fn(etas, _metric(R, R)), axis=1).T
  else:
    # Every non-None `species` takes this branch. Restricting it to
    # `jnp.ndarray` left `compute_fn` undefined for NumPy arrays or lists and
    # made the final `return` raise UnboundLocalError.
    species = onp.array(species)

    def compute_fn(R: Array, **kwargs) -> Array:
      _metric = partial(metric, **kwargs)
      _metric = space.map_product(_metric)

      def return_radial(atom_type):
        """Returns the radial symmetry functions for neighbor type atom_type."""
        R_neigh = R[species == atom_type, :]
        dr = _metric(R, R_neigh)
        return util.high_precision_sum(radial_fn(etas, dr), axis=1).T

      # One column block per distinct species value.
      return jnp.hstack(
          [return_radial(atom_type) for atom_type in onp.unique(species)])

  return compute_fn
def pair_correlation(displacement_or_metric: Union[DisplacementFn, MetricFn],
                     radii: Array,
                     sigma: float,
                     species: Array = None):
  r"""Computes the pair correlation function at a mesh of distances.

  The pair correlation function measures the number of particles at a given
  distance from a central particle. The pair correlation function is defined
  by $g(r) = <\sum_{i\neq j}\delta(r - |r_i - r_j|)>.$ We make the
  approximation
  $\delta(r) \approx {1 \over \sqrt{2\pi\sigma^2}} e^{-r^2 / (2\sigma^2)}$.

  Note: as implemented, the constant Gaussian prefactor is not applied; each
  Gaussian term is instead divided by `radii**(dim - 1)`.

  Args:
    displacement_or_metric: A function that computes the displacement or
      distance between two points.
    radii: An array of radii at which we would like to compute g(r).
    sigma: A float specifying the width of the approximating Gaussian.
    species: An optional array specifying the species of each particle. If
      species is None then we compute a single g(r) for all particles,
      otherwise we compute one g(r) for each species.

  Returns:
    A function `g_fn` that computes the pair correlation function for a
    collection of particles. With `species` given, `g_fn` returns a list with
    one entry per distinct species value.

  Raises:
    TypeError: If `species` is given but is not a `jnp.ndarray` of integers.
  """
  d = space.canonicalize_displacement_or_metric(displacement_or_metric)
  d = space.map_product(d)

  def pairwise(dr, dim):
    return jnp.exp(-f32(0.5) * (dr - radii)**2 / sigma**2) / radii**(dim - 1)
  # Apply the scalar-distance kernel over both axes of the distance matrix.
  pairwise = vmap(vmap(pairwise, (0, None)), (0, None))

  if species is None:
    def g_fn(R):
      dim = R.shape[-1]
      # Zero on the diagonal so a particle does not count itself.
      mask = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
      return jnp.sum(mask[:, :, jnp.newaxis] * pairwise(d(R, R), dim),
                     axis=(1,))
  else:
    if not (isinstance(species, jnp.ndarray) and is_integer(species)):
      raise TypeError('Malformed species; expecting array of integers.')
    species_types = jnp.unique(species)

    def g_fn(R):
      dim = R.shape[-1]
      g_R = []
      mask = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
      for s in species_types:
        Rs = R[species == s]
        mask_s = mask[:, species == s, jnp.newaxis]
        g_R.append(jnp.sum(mask_s * pairwise(d(Rs, R), dim), axis=(1,)))
      return g_R

  return g_fn
def soft_sphere_pair(
    displacement_or_metric,
    species=None,
    sigma=1.0,
    epsilon=1.0,
    alpha=2.0):
  """Build a soft-sphere energy function over all pairs via `smap.pair`.

  Args:
    displacement_or_metric: Displacement or metric function.
    species: Optional species specification, forwarded to `smap.pair`.
    sigma: Soft-sphere `sigma` parameter (converted to an `f32` array).
    epsilon: Soft-sphere `epsilon` parameter (converted to an `f32` array).
    alpha: Soft-sphere `alpha` parameter (converted to an `f32` array).

  Returns:
    The function returned by `smap.pair`.
  """
  sigma, epsilon, alpha = (
      np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(
      soft_sphere,
      metric,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha)
def lennard_jones_pair(
    displacement_or_metric,
    species=None,
    sigma=1.0,
    epsilon=1.0,
    r_onset=2.0,
    r_cutoff=2.5):
  """Build a truncated Lennard-Jones energy over all pairs via `smap.pair`.

  Args:
    displacement_or_metric: Displacement or metric function.
    species: Optional species specification, forwarded to `smap.pair`.
    sigma: Lennard-Jones `sigma` parameter.
    epsilon: Lennard-Jones `epsilon` parameter.
    r_onset: Onset of the multiplicative cutoff, in units of max(sigma).
    r_cutoff: Cutoff distance, in units of max(sigma).

  Returns:
    The function returned by `smap.pair`.
  """
  sigma = np.array(sigma, dtype=f32)
  epsilon = np.array(epsilon, dtype=f32)
  onset = f32(r_onset * np.max(sigma))
  cutoff = f32(r_cutoff * np.max(sigma))

  truncated_lj = multiplicative_isotropic_cutoff(lennard_jones, onset, cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(
      truncated_lj, metric, species=species, sigma=sigma, epsilon=epsilon)
def morse_pair(
    displacement_or_metric,
    species=None,
    sigma=1.0,
    epsilon=5.0,
    alpha=5.0,
    r_onset=2.0,
    r_cutoff=2.5):
  """Build a truncated Morse energy function over all pairs via `smap.pair`.

  `r_onset` and `r_cutoff` are handed to `multiplicative_isotropic_cutoff`
  exactly as given (no casting or rescaling).

  Args:
    displacement_or_metric: Displacement or metric function.
    species: Optional species specification, forwarded to `smap.pair`.
    sigma: Morse `sigma` parameter (converted to an `f32` array).
    epsilon: Morse `epsilon` parameter (converted to an `f32` array).
    alpha: Morse `alpha` parameter (converted to an `f32` array).
    r_onset: Onset of the multiplicative cutoff.
    r_cutoff: Cutoff distance.

  Returns:
    The function returned by `smap.pair`.
  """
  sigma, epsilon, alpha = (
      np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))

  truncated_morse = multiplicative_isotropic_cutoff(morse, r_onset, r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(
      truncated_morse,
      metric,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha)
def simple_spring_bond(
    displacement_or_metric,
    bond,
    bond_type=None,
    length=1,
    epsilon=1,
    alpha=2):
  """Build a spring energy function over bonded particles via `smap.bond`.

  Args:
    displacement_or_metric: Displacement or metric function.
    bond: Bond specification, forwarded to `smap.bond`.
    bond_type: Optional bond type specification, forwarded to `smap.bond`.
    length: `simple_spring` `length` parameter (converted to `f32` array).
    epsilon: `simple_spring` `epsilon` parameter (converted to `f32` array).
    alpha: `simple_spring` `alpha` parameter (converted to `f32` array).

  Returns:
    The function returned by `smap.bond`.
  """
  length, epsilon, alpha = (
      np.array(value, dtype=f32) for value in (length, epsilon, alpha))
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.bond(
      simple_spring,
      metric,
      bond,
      bond_type,
      length=length,
      epsilon=epsilon,
      alpha=alpha)
def radial_symmetry_functions_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    species: Array,
    etas: Array,
    cutoff_distance: float) -> Callable[[Array, NeighborList], Array]:
  """Returns a neighbor-list based radial symmetry function calculator.

  Args:
    displacement_or_metric: A displacement or metric function; it is
      canonicalized to a metric and mapped over neighbors with
      `space.map_neighbor`.
    species: An `[N_atoms]` that contains the species of each particle.
    etas: List of radial symmetry function parameters that control the
      spatial extension.
    cutoff_distance: Neighbors whose distance is larger than cutoff_distance
      do not contribute to each others symmetry functions. The contribution
      of a neighbor to the symmetry function and its derivative goes to zero
      at this distance.

  Returns:
    A function of positions `[N_atoms, spatial_dimension]` and a neighbor
    list. Its result stacks one transposed block per distinct species value
    side by side, i.e. `[N_atoms, N_etas * N_types]` where N_etas is the
    number of eta parameters and N_types is the number of types of particles
    in the system (assuming `neighbor.idx` has one row per atom -- TODO
    confirm against `partition`).
  """
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

  def _radial(eta, dr):
    envelope = _behler_parrinello_cutoff_fn(dr, cutoff_distance)
    return np.exp(-eta * dr**2) * envelope

  def compute_fun(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    neighbor_metric = space.map_neighbor(partial(metric, **kwargs))

    def return_radial(atom_type):
      """Returns the radial symmetry functions for neighbor type atom_type."""
      idx = neighbor.idx
      neighbor_positions = R[idx]
      neighbor_species = species[idx]
      # Keep only entries whose index is a real particle (< N) and whose
      # species matches the requested type.
      keep = np.logical_and(idx < R.shape[0], neighbor_species == atom_type)
      dr = neighbor_metric(R, neighbor_positions)
      per_eta = vmap(_radial, (0, None))(etas, dr)
      masked = per_eta * keep[np.newaxis, :, :]
      return util.high_precision_sum(masked, axis=2).T

    blocks = [return_radial(atom_type) for atom_type in np.unique(species)]
    return np.hstack(blocks)

  return compute_fun
def pair_correlation(displacement_or_metric, rs, sigma):
  """Build a function computing a Gaussian-smoothed pair correlation.

  Args:
    displacement_or_metric: Displacement or metric function.
    rs: Radii at which the correlation is evaluated; `1e-7` is added before
      conversion to an `f32` array.
    sigma: Width of the Gaussian used for smoothing.

  Returns:
    A function `compute_fun(R, **dynamic_kwargs)` that averages normalized
    Gaussians of (pair distance - rs) over axis 1 and divides by
    `rs**(dim - 1)`, where `dim = R.shape[1]`.
  """
  pairwise_metric = space.map_product(
      space.canonicalize_displacement_or_metric(displacement_or_metric))
  sigma = f32(sigma)
  rs = np.array(rs + 1e-7, dtype=f32)

  def compute_fun(R, **dynamic_kwargs):
    distances = pairwise_metric(R, R, **dynamic_kwargs)
    # Distances at or below 1e-7 are replaced by 1e7.
    distances = np.where(distances > f32(1e-7), distances, f32(1e7))
    spatial_dim = R.shape[1]
    offsets = distances[:, :, np.newaxis] - rs
    weights = np.exp(-f32(0.5) * offsets**2 / sigma**2)
    normalized = weights / np.sqrt(2 * np.pi * sigma**2)
    return np.mean(normalized, axis=1) / rs**(spatial_dim - 1)

  return compute_fun
def test_canonicalize_displacement_or_metric(self, spatial_dimension, dtype):
  """Checks that canonicalizing a displacement matches `space.metric`.

  Both functions are built from the same `space.periodic_general`
  displacement, mapped with `space.map_product`, and compared on random
  positions for `STOCHASTIC_SAMPLES` draws.
  """
  key = random.PRNGKey(0)

  displacement, _ = space.periodic_general(np.eye(spatial_dimension))
  expected_fn = space.map_product(space.metric(displacement))
  actual_fn = space.map_product(
      space.canonicalize_displacement_or_metric(displacement))

  for _ in range(STOCHASTIC_SAMPLES):
    # Split three ways to keep the random stream unchanged; only one subkey
    # is consumed.
    key, position_key, _unused_key = random.split(key, 3)
    R = random.normal(
        position_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    self.assertAllClose(expected_fn(R, R), actual_fn(R, R), True)
def simple_spring_bond(displacement_or_metric: DisplacementOrMetricFn,
                       bond: Array,
                       bond_type: Array = None,
                       length: Array = 1,
                       epsilon: Array = 1,
                       alpha: Array = 2) -> Callable[[Array], Array]:
  """Build a spring energy function over bonded particles via `smap.bond`.

  Args:
    displacement_or_metric: Displacement or metric function.
    bond: Bond specification, forwarded to `smap.bond`.
    bond_type: Optional bond type specification, forwarded to `smap.bond`.
    length: `simple_spring` `length` parameter (converted to `f32` array).
    epsilon: `simple_spring` `epsilon` parameter (converted to `f32` array).
    alpha: `simple_spring` `alpha` parameter (converted to `f32` array).

  Returns:
    The function returned by `smap.bond`.
  """
  length, epsilon, alpha = (
      np.array(value, dtype=f32) for value in (length, epsilon, alpha))
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.bond(
      simple_spring,
      metric,
      bond,
      bond_type,
      length=length,
      epsilon=epsilon,
      alpha=alpha)
def bistable_spring_bond(displacement_or_metric,
                         bond,
                         bond_type=None,
                         r0=1,
                         a2=2,
                         a4=5):
  """Build a `bistable_spring` energy over bonded particles via `smap.bond`.

  Args:
    displacement_or_metric: Displacement or metric function.
    bond: Bond specification, forwarded to `smap.bond`.
    bond_type: Optional bond type specification, forwarded to `smap.bond`.
    r0: `bistable_spring` `r0` parameter (converted to a float32 array).
    a2: `bistable_spring` `a2` parameter (converted to a float32 array).
    a4: `bistable_spring` `a4` parameter (converted to a float32 array).

  Returns:
    The function returned by `smap.bond`.
  """
  r0, a2, a4 = (jnp.array(value, dtype=jnp.float32) for value in (r0, a2, a4))
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.bond(bistable_spring, metric, bond, bond_type,
                   r0=r0, a2=a2, a4=a4)
def test_stress_lammps_periodic_general(self, dim, dtype):
  """Compares energy and autodiff stress against stored LAMMPS reference data.

  The reference configuration `(box, R, V)` and expected values `(E, C)` come
  from `test_util.load_lammps_stress_data`. The energy is a Lennard-Jones
  pair potential hard-truncated at 2.5, in a `space.periodic_general` box.
  """
  (box, R, V), (E, C) = test_util.load_lammps_stress_data(dtype)

  displacement_fn, _ = space.periodic_general(box)

  # Pairs at or beyond the cutoff contribute exactly zero energy.
  energy_fn = smap.pair(
      lambda dr, **kwargs: jnp.where(dr < f32(2.5),
                                     energy.lennard_jones(dr),
                                     f32(0.0)),
      space.canonicalize_displacement_or_metric(displacement_fn))

  ad_stress = quantity.stress(energy_fn, R, box, velocity=V)

  tol = 5e-5

  # The reference energy is compared on a per-particle basis.
  self.assertAllClose(energy_fn(R) / len(R), E, atol=tol, rtol=tol)
  self.assertAllClose(C, ad_stress, atol=tol, rtol=tol)
def soft_sphere_pair(displacement_or_metric: DisplacementOrMetricFn,
                     species: Array = None,
                     sigma: Array = 1.0,
                     epsilon: Array = 1.0,
                     alpha: Array = 2.0,
                     per_particle: bool = False):
  """Build a soft-sphere energy function over all pairs via `smap.pair`.

  Args:
    displacement_or_metric: Displacement or metric function.
    species: Optional species specification, forwarded to `smap.pair`.
    sigma: Soft-sphere `sigma` parameter (converted to an `f32` array).
    epsilon: Soft-sphere `epsilon` parameter (converted to an `f32` array).
    alpha: Soft-sphere `alpha` parameter (converted to an `f32` array).
    per_particle: When True, `reduce_axis=(1,)` is passed to `smap.pair`;
      otherwise `reduce_axis=None`.

  Returns:
    The function returned by `smap.pair`.
  """
  sigma, epsilon, alpha = (
      np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))

  if per_particle:
    reduce_axis = (1,)
  else:
    reduce_axis = None

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(
      soft_sphere,
      metric,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=reduce_axis)
def simple_spring_bond(displacement_or_metric: DisplacementOrMetricFn,
                       bond: Array,
                       bond_type: Array = None,
                       length: Array = 1,
                       epsilon: Array = 1,
                       alpha: Array = 2) -> Callable[[Array], Array]:
  """Build a spring energy function over bonded particles via `smap.bond`.

  Args:
    displacement_or_metric: Displacement or metric function.
    bond: Bond specification, forwarded to `smap.bond`.
    bond_type: Optional bond type specification, forwarded to `smap.bond`.
    length: `simple_spring` `length` parameter (via `maybe_downcast`).
    epsilon: `simple_spring` `epsilon` parameter (via `maybe_downcast`).
    alpha: `simple_spring` `alpha` parameter (via `maybe_downcast`).

  Returns:
    The function returned by `smap.bond`.
  """
  length, epsilon, alpha = (
      maybe_downcast(value) for value in (length, epsilon, alpha))
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.bond(
      simple_spring,
      metric,
      bond,
      bond_type,
      ignore_unused_parameters=True,
      length=length,
      epsilon=epsilon,
      alpha=alpha)
def bks_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array,
    Q_sq: Array,
    exp_coeff: Array,
    exp_decay: Array,
    attractive_coeff: Array,
    repulsive_coeff: Array,
    coulomb_alpha: Array,
    cutoff: float,
    dr_threshold: float = 0.8,
    fractional_coordinates: bool = False,
    ) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Build a neighbor-list function and a BKS energy function over it.

  `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff`, `repulsive_coeff` and
  `dr_threshold` go through `maybe_downcast`; `coulomb_alpha` and `cutoff`
  are forwarded unchanged.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Species specification for `smap.pair_neighbor_list`.
    Q_sq, exp_coeff, exp_decay, attractive_coeff, repulsive_coeff,
      coulomb_alpha: Parameters of the `bks` potential.
    cutoff: Cutoff used both for the neighbor list and the potential.
    dr_threshold: Neighbor list threshold.
    fractional_coordinates: Forwarded to `partition.neighbor_list`.

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  Q_sq = maybe_downcast(Q_sq)
  exp_coeff = maybe_downcast(exp_coeff)
  exp_decay = maybe_downcast(exp_decay)
  attractive_coeff = maybe_downcast(attractive_coeff)
  repulsive_coeff = maybe_downcast(repulsive_coeff)
  threshold = maybe_downcast(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric,
      box_size,
      cutoff,
      threshold,
      fractional_coordinates=fractional_coordinates)

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      bks,
      metric,
      species=species,
      ignore_unused_parameters=True,
      Q_sq=Q_sq,
      exp_coeff=exp_coeff,
      exp_decay=exp_decay,
      attractive_coeff=attractive_coeff,
      repulsive_coeff=repulsive_coeff,
      coulomb_alpha=coulomb_alpha,
      cutoff=cutoff)

  return neighbor_fn, energy_fn
def soft_sphere_pair(displacement_or_metric: DisplacementOrMetricFn,
                     species: Array = None,
                     sigma: Array = 1.0,
                     epsilon: Array = 1.0,
                     alpha: Array = 2.0,
                     per_particle: bool = False):
  """Build a soft-sphere energy function over all pairs via `smap.pair`.

  Args:
    displacement_or_metric: Displacement or metric function.
    species: Optional species specification, forwarded to `smap.pair`.
    sigma: Soft-sphere `sigma` parameter (via `maybe_downcast`).
    epsilon: Soft-sphere `epsilon` parameter (via `maybe_downcast`).
    alpha: Soft-sphere `alpha` parameter (via `maybe_downcast`).
    per_particle: When True, `reduce_axis=(1,)` is passed to `smap.pair`;
      otherwise `reduce_axis=None`.

  Returns:
    The function returned by `smap.pair`.
  """
  sigma, epsilon, alpha = (
      maybe_downcast(value) for value in (sigma, epsilon, alpha))

  if per_particle:
    reduce_axis = (1,)
  else:
    reduce_axis = None

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(
      soft_sphere,
      metric,
      ignore_unused_parameters=True,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=reduce_axis)
def pair_correlation(displacement_or_metric, rs, sigma):
  """Build a function computing a Gaussian-smoothed pair correlation.

  Args:
    displacement_or_metric: Displacement or metric function.
    rs: Radii at which the correlation is evaluated; `1e-7` is added before
      conversion to an `f32` array.
    sigma: Width of the Gaussian used for smoothing.

  Returns:
    A function `compute_fun(R, **dynamic_kwargs)` that averages normalized
    Gaussians of (pair distance - rs) over axis 1 and divides by
    `rs**(dim - 1)`, where `dim = R.shape[1]`.
  """
  pairwise_metric = space.map_product(
      space.canonicalize_displacement_or_metric(displacement_or_metric))

  sigma = f32(sigma)
  # NOTE(schsam): This seems rather harmless, but possibly something to look at
  rs = np.array(rs + 1e-7, dtype=f32)

  # TODO(schsam): Get this working with cell list .
  def compute_fun(R, **dynamic_kwargs):
    distances = pairwise_metric(R, R, **dynamic_kwargs)
    # TODO(schsam): Clean up.
    distances = np.where(distances > f32(1e-7), distances, f32(1e7))
    spatial_dim = R.shape[1]
    offsets = distances[:, :, np.newaxis] - rs
    weights = np.exp(-f32(0.5) * offsets**2 / sigma**2)
    normalized = weights / np.sqrt(2 * np.pi * sigma**2)
    return np.mean(normalized, axis=1) / rs**(spatial_dim - 1)

  return compute_fun
def lennard_jones_pair(displacement_or_metric: DisplacementOrMetricFn,
                       species: Array = None,
                       sigma: Array = 1.0,
                       epsilon: Array = 1.0,
                       r_onset: Array = 2.0,
                       r_cutoff: Array = 2.5,
                       per_particle: bool = False) -> Callable[[Array], Array]:
  """Build a truncated Lennard-Jones energy over all pairs via `smap.pair`.

  Args:
    displacement_or_metric: Displacement or metric function.
    species: Optional species specification, forwarded to `smap.pair`.
    sigma: Lennard-Jones `sigma` parameter.
    epsilon: Lennard-Jones `epsilon` parameter.
    r_onset: Onset of the multiplicative cutoff, in units of max(sigma).
    r_cutoff: Cutoff distance, in units of max(sigma).
    per_particle: When True, `reduce_axis=(1,)` is passed to `smap.pair`;
      otherwise `reduce_axis=None`.

  Returns:
    The function returned by `smap.pair`.
  """
  sigma = np.array(sigma, dtype=f32)
  epsilon = np.array(epsilon, dtype=f32)
  onset = r_onset * np.max(sigma)
  cutoff = r_cutoff * np.max(sigma)

  if per_particle:
    reduce_axis = (1,)
  else:
    reduce_axis = None

  truncated_lj = multiplicative_isotropic_cutoff(lennard_jones, onset, cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(
      truncated_lj,
      metric,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      reduce_axis=reduce_axis)
def morse_pair(displacement_or_metric: DisplacementOrMetricFn,
               species: Array = None,
               sigma: Array = 1.0,
               epsilon: Array = 5.0,
               alpha: Array = 5.0,
               r_onset: float = 2.0,
               r_cutoff: float = 2.5,
               per_particle: bool = False) -> Callable[[Array], Array]:
  """Build a truncated Morse energy function over all pairs via `smap.pair`.

  `r_onset` and `r_cutoff` are handed to `multiplicative_isotropic_cutoff`
  exactly as given (no casting or rescaling).

  Args:
    displacement_or_metric: Displacement or metric function.
    species: Optional species specification, forwarded to `smap.pair`.
    sigma: Morse `sigma` parameter (converted to an `f32` array).
    epsilon: Morse `epsilon` parameter (converted to an `f32` array).
    alpha: Morse `alpha` parameter (converted to an `f32` array).
    r_onset: Onset of the multiplicative cutoff.
    r_cutoff: Cutoff distance.
    per_particle: When True, `reduce_axis=(1,)` is passed to `smap.pair`;
      otherwise `reduce_axis=None`.

  Returns:
    The function returned by `smap.pair`.
  """
  sigma, epsilon, alpha = (
      np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))

  if per_particle:
    reduce_axis = (1,)
  else:
    reduce_axis = None

  truncated_morse = multiplicative_isotropic_cutoff(morse, r_onset, r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(
      truncated_morse,
      metric,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=reduce_axis)
def bks_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array,
    Q_sq: Array,
    exp_coeff: Array,
    exp_decay: Array,
    attractive_coeff: Array,
    repulsive_coeff: Array,
    coulomb_alpha: Array,
    cutoff: float,
    dr_threshold: float = 0.8
    ) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Build a neighbor-list function and a BKS energy function over it.

  `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff` and `repulsive_coeff`
  are converted to `f32` arrays; `coulomb_alpha` and `cutoff` are forwarded
  to the `bks` potential unchanged.

  Args:
    displacement_or_metric: Displacement or metric function.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Species specification for `smap.pair_neighbor_list`.
    Q_sq, exp_coeff, exp_decay, attractive_coeff, repulsive_coeff,
      coulomb_alpha: Parameters of the `bks` potential.
    cutoff: Cutoff used both for the neighbor list and the potential.
    dr_threshold: Neighbor list threshold (cast with `f32`).

  Returns:
    A pair `(neighbor_fn, energy_fn)`.
  """
  Q_sq = np.array(Q_sq, dtype=f32)
  exp_coeff = np.array(exp_coeff, dtype=f32)
  exp_decay = np.array(exp_decay, dtype=f32)
  attractive_coeff = np.array(attractive_coeff, dtype=f32)
  repulsive_coeff = np.array(repulsive_coeff, dtype=f32)
  threshold = f32(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, threshold)

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      bks,
      metric,
      species=species,
      Q_sq=Q_sq,
      exp_coeff=exp_coeff,
      exp_decay=exp_decay,
      attractive_coeff=attractive_coeff,
      repulsive_coeff=repulsive_coeff,
      coulomb_alpha=coulomb_alpha,
      cutoff=cutoff)

  return neighbor_fn, energy_fn
def harmonic_morse_pair(displacement_or_metric,
                        species=None,
                        D0=5.0,
                        alpha=10.0,
                        r0=1.0,
                        k=50.0):
  """Build a `harmonic_morse` energy function over all pairs via `smap.pair`.

  Args:
    displacement_or_metric: Displacement or metric function.
    species: Optional species specification, forwarded to `smap.pair`.
    D0: `harmonic_morse` `D0` parameter (converted to a float32 array).
    alpha: `harmonic_morse` `alpha` parameter (converted to a float32 array).
    r0: `harmonic_morse` `r0` parameter (converted to a float32 array).
    k: `harmonic_morse` `k` parameter (converted to a float32 array).

  Returns:
    The function returned by `smap.pair`.
  """
  # Normalize every potential parameter to a float32 array.
  D0, alpha, r0, k = (
      jnp.array(value, dtype=jnp.float32) for value in (D0, alpha, r0, k))

  # Hand the potential, its parameters and the canonicalized metric to
  # `smap.pair`.
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(
      harmonic_morse,
      metric,
      species=species,
      D0=D0,
      alpha=alpha,
      r0=r0,
      k=k)