def energy(R, **kwargs):
    """Sum an embedding term and a halved pairwise term over all particles.

    Args:
      R: Particle positions; passed twice to `metric` to obtain pair distances.
      **kwargs: Forwarded unchanged to `metric`.

    Returns:
      The per-row sum of embedding and pairwise contributions, reduced over
      `axis` (taken from the enclosing scope).
    """
    dr = metric(R, R, **kwargs)

    # Embedding contribution: reduce the charge over axis 1, then embed it.
    charge_per_row = util.high_precision_sum(charge_fn(dr), axis=1)
    embedded = embedding_fn(charge_per_row)

    # Pairwise contribution, halved.
    # presumably `_diagonal_mask` removes the i == i entries and the factor
    # of two compensates for each pair appearing twice -- TODO confirm.
    masked_pairs = smap._diagonal_mask(pairwise_fn(dr))
    half_pair_sum = util.high_precision_sum(masked_pairs, axis=1) / f32(2.0)

    return util.high_precision_sum(embedded + half_pair_sum, axis=axis)
def fn_mapped(R: Array, neighbor: partition.NeighborList, **dynamic_kwargs) -> Array:
    """Evaluate `fn` on the pairs listed in `neighbor` and reduce the result.

    Args:
      R: Particle positions; `R.shape[0]` is used as the particle count.
      neighbor: Neighbor list. Its `format` selects between the sparse path
        (`idx[0]` / `idx[1]` index the two ends of each listed pair) and the
        dense path (`idx` is used to gather neighbors for every particle).
      **dynamic_kwargs: Per-call overrides. They are forwarded to the
        displacement/metric function and merged over the enclosing `kwargs`
        before `fn` is called. A `species` entry replaces the enclosing
        `species`.

    Returns:
      If the enclosing `reduce_axis` is None, the sum over every entry divided
      by the normalization. Otherwise the output reduced over `reduce_axis`
      (for sparse formats, scattered per particle with `segment_sum` when
      axis 0 is not reduced), divided by the normalization.

    Raises:
      ValueError: If `reduce_axis` contains axis 0 but not axis 1, or if
        per-particle values are requested from an `OrderedSparse` list.
    """
    d = partial(displacement_or_metric, **dynamic_kwargs)
    _species = dynamic_kwargs.get('species', species)

    is_sparse = partition.is_sparse(neighbor.format)

    # presumably 2.0 compensates for every pair being listed from both ends,
    # which would not be the case for `OrderedSparse` -- TODO confirm.
    normalization = 2.0

    if is_sparse:
        d = space.map_bond(d)
        dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
        # Entries whose index is >= the particle count are masked out below.
        mask = neighbor.idx[0] < R.shape[0]
        if neighbor.format is partition.OrderedSparse:
            normalization = 1.0
    else:
        d = space.map_neighbor(d)
        R_neigh = R[neighbor.idx]
        dR = d(R, R_neigh)
        mask = neighbor.idx < R.shape[0]

    merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
    merged_kwargs = _neighborhood_kwargs_to_params(
        neighbor.format, neighbor.idx, _species, merged_kwargs,
        param_combinators)
    out = fn(dR, **merged_kwargs)

    # Append trailing singleton axes so the mask broadcasts against outputs
    # that carry extra dimensions.
    if out.ndim > mask.ndim:
        ddim = out.ndim - mask.ndim
        mask = jnp.reshape(mask, mask.shape + (1, ) * ddim)
    out *= mask

    if reduce_axis is None:
        return util.high_precision_sum(out) / normalization

    if 0 in reduce_axis and 1 not in reduce_axis:
        raise ValueError(
            'Reducing over axis 0 without also reducing over axis 1 is not '
            'supported.')

    if not is_sparse:
        return util.high_precision_sum(out, reduce_axis) / normalization

    # In sparse formats the two leading (particle, neighbor) axes collapse
    # into one pair axis, so every remaining axis index shifts down by one.
    _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

    if 0 in reduce_axis:
        # Divide by the normalization here as well, so that a full reduction
        # over a sparse list agrees with the dense and `reduce_axis=None`
        # paths above.
        return (util.high_precision_sum(out, (0, ) + _reduce_axis)
                / normalization)

    if neighbor.format is partition.OrderedSparse:
        raise ValueError(
            'Cannot report per-particle values with a neighbor '
            'list whose format is `OrderedSparse`. Please use '
            'either `Dense` or `Sparse`.')

    out = util.high_precision_sum(out, _reduce_axis)
    return ops.segment_sum(out, neighbor.idx[0], R.shape[0]) / normalization
def box_force(alpha, vol, box_fn, position, velocity, mass, force, pressure, **kwargs):
    """Combine kinetic, virial, energy-derivative and pressure terms at `vol`.

    Args:
      alpha: Scalar multiplying the summed `velocity**2 * mass` term.
      vol: Volume at which the box and the energy derivative are evaluated.
      box_fn: Callable mapping a volume to a box.
      position: Two-dimensional array; its second dimension is used as `dim`.
      velocity: Array combined elementwise with `mass`.
      mass: Array or scalar multiplying `velocity**2`.
      force: Array multiplied elementwise with the transformed positions.
      pressure: Scalar multiplying `vol * dim` in the final term.
      **kwargs: Forwarded to the enclosing `energy_fn`.

    Returns:
      `alpha * sum(m v^2) + sum(R * F) - dim * vol * dU/dvol - pressure * vol * dim`.
    """
    _, dim = position.shape

    def potential_at(v):
        return energy_fn(position, box=box_fn(v), **kwargs)

    energy_slope = grad(potential_at)

    mv2_sum = util.high_precision_sum(velocity**2 * mass)
    transformed = space.transform(box_fn(vol), position)
    virial = util.high_precision_sum(transformed * force)

    return (alpha * mv2_sum + virial
            - dim * vol * energy_slope(vol)
            - pressure * vol * dim)
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    """Angular sums over neighbor pairs, split by pair of neighbor species.

    Args:
      R: Particle positions.
      neighbor: Neighbor list whose `idx` gathers the neighbors of each
        particle.
      **kwargs: Forwarded to the enclosing `displacement` function.

    Returns:
      The masked angular sums for every unordered pair of species types
      (first type <= second type in the order given by `onp.unique`),
      stacked with `np.hstack`.
    """
    neighbor_displacement = space.map_neighbor(partial(displacement, **kwargs))
    idx = neighbor.idx
    positions_neigh = R[idx]
    neighbor_species = species[idx]
    atom_types = onp.unique(species)

    # One mask per species: neighbor index is in range and has that species.
    in_range = idx < len(R)
    type_masks = [
        np.logical_and(in_range, neighbor_species == t) for t in atom_types
    ]

    dR = neighbor_displacement(R, positions_neigh)
    all_angular = _all_pairs_angular(dR, dR)

    n_types = len(atom_types)
    blocks = [
        util.high_precision_sum(
            all_angular
            * type_masks[a][:, :, np.newaxis, np.newaxis]
            * type_masks[b][:, np.newaxis, :, np.newaxis],
            axis=[1, 2])
        for a in range(n_types)
        for b in range(a, n_types)
    ]
    return np.hstack(blocks)
def pressure(energy_fn: EnergyFn, position: Array, box: Box, kinetic_energy: float = 0.0, **kwargs) -> float:
    """Computes the internal pressure of a system.

    Note: This function requires that `energy_fn` take a `box` keyword
    argument. Most frequently, this is accomplished by using
    `periodic_general` boundary conditions combined with any of the energy
    functions in `energy.py`. This will not work with `space.periodic`.

    Args:
      energy_fn: Energy function called as `energy_fn(position, box=..., **kwargs)`.
      position: Array of particle positions; `position.shape[1]` is the
        spatial dimension.
      box: Box used for the reference volume and for transforming positions.
      kinetic_energy: Kinetic energy entering the `2 * KE` term.
      **kwargs: Forwarded to `energy_fn` and to the derived force function.

    Returns:
      `(2 * KE + sum(R * F) - dim * V * dU/dV) / (dim * V)` evaluated at the
      reference volume `V` of `box`.
    """
    dim = position.shape[1]
    vol_0 = volume(dim, box)

    def box_at(vol):
        # Rescale the box so that its volume changes by the ratio vol / vol_0.
        return (vol / vol_0)**(1 / dim) * box

    def potential_at(vol):
        return energy_fn(position, box=box_at(vol), **kwargs)

    energy_slope = grad(potential_at)

    forces = force(energy_fn)(position, box=box, **kwargs)
    transformed = space.transform(box, position)
    virial = util.high_precision_sum(transformed * forces)

    return 1 / (dim * vol_0) * (
        2 * kinetic_energy + virial - dim * vol_0 * energy_slope(vol_0))
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    """Radial function summed over each particle's listed neighbors.

    Args:
      R: Particle positions; `R.shape[0]` is the particle count.
      neighbor: Neighbor list whose `idx` gathers neighbors per particle.
      **kwargs: Forwarded to the enclosing `metric` function.

    Returns:
      `radial_fn(etas, dr)` with out-of-range neighbor entries zeroed,
      summed over axis 2 and transposed.
    """
    neighbor_metric = space.map_neighbor(partial(metric, **kwargs))
    idx = neighbor.idx
    positions_neigh = R[idx]

    # A leading axis is added so the mask broadcasts over the first axis of
    # the radial output.
    in_range = (idx < R.shape[0])[np.newaxis, :, :]

    dr = neighbor_metric(R, positions_neigh)
    masked_radial = radial_fn(etas, dr) * in_range
    return util.high_precision_sum(masked_radial, axis=2).T
def return_radial(atom_type):
    """Returns the radial symmetry functions for neighbor type atom_type.

    Only neighbor entries whose index is in range and whose species equals
    `atom_type` contribute; `R`, `neighbor`, `species`, `etas`, `_metric` and
    `radial_fn` come from the enclosing scope.
    """
    idx = neighbor.idx
    positions_neigh = R[idx]
    keep = np.logical_and(idx < R.shape[0], species[idx] == atom_type)

    dr = _metric(R, positions_neigh)
    # Map `radial_fn` over the leading axis of `etas` with `dr` held fixed.
    radial_per_eta = vmap(radial_fn, (0, None))
    radial = radial_per_eta(etas, dr)

    return util.high_precision_sum(radial * keep[np.newaxis, :, :], axis=2).T
def sym_fn(R: Array, neighbor: NeighborList, mask_i: Array = None, mask_j: Array = None, **kwargs) -> Array:
    """Angular sums over pairs of neighbors for `Dense` or `Sparse` lists.

    Args:
      R: Particle positions; `R.shape[0]` is the particle count.
      neighbor: Neighbor list; `format` must be `partition.Dense` or
        `partition.Sparse`.
      mask_i: Optional per-particle mask, gathered at the neighbor indices
        and applied along the first neighbor axis (dense path).
      mask_j: Optional per-particle mask, gathered at the neighbor indices
        and applied along the second neighbor axis (dense path).
      **kwargs: Forwarded to the enclosing `displacement` function.

    Returns:
      Dense: the masked angular values summed over both neighbor axes.
      Sparse: the angular values for all pairs of listed bonds, scattered
      per particle with `segment_sum`.

    Raises:
      ValueError: If the neighbor list format is neither `Dense` nor `Sparse`.
    """
    pair_displacement = partial(displacement, **kwargs)

    if neighbor.format is partition.Dense:
        idx = neighbor.idx
        dR = space.map_neighbor(pair_displacement)(R, R[idx])

        angular_over_pairs = vmap(
            vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0)
        all_angular = angular_over_pairs(dR, dR)

        picked_i = True if mask_i is None else mask_i[idx]
        picked_j = True if mask_j is None else mask_j[idx]

        # Combine the user masks with the in-range test, then place each on
        # its own neighbor axis.
        keep_i = ((idx < R.shape[0]) & picked_i)[:, :, jnp.newaxis, jnp.newaxis]
        keep_j = ((idx < R.shape[0]) & picked_j)[:, jnp.newaxis, :, jnp.newaxis]

        return util.high_precision_sum(all_angular * keep_i * keep_j,
                                       axis=[1, 2])

    if neighbor.format is partition.Sparse:
        first_idx = neighbor.idx[0]
        second_idx = neighbor.idx[1]
        dR = space.map_bond(pair_displacement)(R[first_idx], R[second_idx])

        angular_over_bonds = vmap(vmap(_batched_angular_fn, (0, None)),
                                  (None, 0))
        all_angular = angular_over_bonds(dR, dR)

        N = R.shape[0]
        picked_i = True if mask_i is None else mask_i[second_idx]
        picked_j = True if mask_j is None else mask_j[second_idx]
        keep_i = (first_idx < N) & picked_i
        keep_j = (first_idx < N) & picked_j

        # NOTE(review): this combined mask is built but never multiplied into
        # `all_angular` before the segment sum below; kept as in the original
        # -- confirm whether the sparse path is meant to apply it.
        mask = keep_i[:, None] & keep_j[None, :]
        mask = mask[:, :, None, None]

        # Flatten the two bond axes into one so each bond pair is a segment row.
        flat_angular = jnp.reshape(all_angular, (-1, ) + all_angular.shape[2:])
        segment_ids = jnp.repeat(first_idx, len(first_idx))
        return ops.segment_sum(flat_angular, segment_ids, N)

    raise ValueError()
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    """Angular values summed over every pair of listed neighbors.

    Args:
      R: Particle positions; `R.shape[0]` is the particle count.
      neighbor: Neighbor list whose `idx` gathers neighbors per particle.
      **kwargs: Forwarded to the enclosing `displacement` function.

    Returns:
      `_all_pairs_angular(dR, dR)` with out-of-range neighbor entries zeroed
      on both neighbor axes, summed over axes 1 and 2.
    """
    neighbor_displacement = space.map_neighbor(partial(displacement, **kwargs))
    idx = neighbor.idx
    positions_neigh = R[idx]
    in_range = idx < R.shape[0]

    dR = neighbor_displacement(R, positions_neigh)
    all_angular = _all_pairs_angular(dR, dR)

    # The same in-range mask is applied once per neighbor axis.
    keep_first = in_range[:, :, np.newaxis, np.newaxis]
    keep_second = in_range[:, np.newaxis, :, np.newaxis]

    return util.high_precision_sum(all_angular * keep_first * keep_second,
                                   axis=[1, 2])
def stress(energy_fn: EnergyFn, position: Array, box: Box, mass: Array = 1.0, velocity: Optional[Array] = None, **kwargs) -> Array:
    """Computes the internal stress of a system.

    Args:
      energy_fn: A function that computes the energy of the system. This
        function must take as an argument `perturbation` which perturbs the
        box shape. Any energy function constructed using `smap` or in
        `energy.py` with a standard space will satisfy this property.
      position: An array of particle positions.
      box: A box specifying the shape of the simulation volume. Used to infer
        the volume of the unit cell.
      mass: The mass of the particles; only used to compute the kinetic
        contribution if `velocity` is not None.
      velocity: An array of atomic velocities.
      **kwargs: Forwarded to `energy_fn`.

    Returns:
      `(VxV - dU/d(eps)) / volume`, where `dU/d(eps)` is the gradient of the
      energy with respect to the `(dim, dim)` perturbation, evaluated at zero
      perturbation, and `VxV` is the mass-weighted sum of velocity outer
      products (or 0.0 when `velocity` is None).
    """
    dim = position.shape[1]
    identity = jnp.eye(dim, dtype=position.dtype)
    no_strain = jnp.zeros((dim, dim), position.dtype)

    def strained_energy(eps):
        return energy_fn(position, perturbation=(identity + eps), **kwargs)

    energy_slope = grad(strained_energy)
    vol_0 = volume(dim, box)

    if velocity is None:
        kinetic = 0.0
    else:
        # NOTE(review): `mass` is combined with an (N, 1, dim) array using
        # plain broadcasting; confirm that callers pass a scalar or a shape
        # that broadcasts per particle.
        outer = mass * velocity[:, None, :] * velocity[:, :, None]
        kinetic = util.high_precision_sum(outer, axis=0)

    return 1 / vol_0 * (kinetic - energy_slope(no_strain))
def sym_fn(R: Array, neighbor: NeighborList, mask: Array = None, **kwargs) -> Array:
    """Radial function summed per particle for `Dense` or `Sparse` lists.

    Args:
      R: Particle positions; `R.shape[0]` is the particle count.
      neighbor: Neighbor list; `format` must be `partition.Dense` or
        `partition.Sparse`.
      mask: Optional per-particle mask, gathered at the neighbor indices and
        combined with the in-range test.
      **kwargs: Forwarded to the enclosing `metric` function.

    Returns:
      Dense: masked `radial_fn(etas, dr)` summed over axis 2 and transposed.
      Sparse: masked, transposed radial values scattered per particle with
      `segment_sum`.

    Raises:
      ValueError: If the neighbor list format is neither `Dense` nor `Sparse`.
    """
    pair_metric = partial(metric, **kwargs)

    if neighbor.format is partition.Dense:
        idx = neighbor.idx
        neighbor_metric = space.map_neighbor(pair_metric)
        positions_neigh = R[idx]

        selected = True if mask is None else mask[idx]
        keep = (idx < R.shape[0])[None, :, :] & selected

        dr = neighbor_metric(R, positions_neigh)
        return util.high_precision_sum(radial_fn(etas, dr) * keep, axis=2).T

    if neighbor.format is partition.Sparse:
        first_idx = neighbor.idx[0]
        second_idx = neighbor.idx[1]
        bond_metric = space.map_bond(pair_metric)

        dr = bond_metric(R[first_idx], R[second_idx])
        radial = radial_fn(etas, dr).T

        N = R.shape[0]
        selected = True if mask is None else mask[second_idx]
        keep = (first_idx < N) & selected

        return ops.segment_sum(radial * keep[:, None], first_idx, N)

    raise ValueError()
def return_radial(atom_type):
    """Returns the radial symmetry functions for neighbor type atom_type.

    Distances are taken from every particle in `R` to the subset of particles
    whose species equals `atom_type`; `R`, `species`, `etas`, `_metric` and
    `radial_fn` come from the enclosing scope.
    """
    of_type = species == atom_type
    positions_of_type = R[of_type, :]
    dr = _metric(R, positions_of_type)
    radial = radial_fn(etas, dr)
    return util.high_precision_sum(radial, axis=1).T
def compute_fn(R: Array, **kwargs) -> Array:
    """Radial function of all pairwise distances, summed over axis 1.

    Args:
      R: Particle positions; paired with themselves through
        `space.map_product`.
      **kwargs: Forwarded to the enclosing `metric` function.

    Returns:
      `radial_fn(etas, dr)` summed over axis 1 and transposed.
    """
    all_pairs_metric = space.map_product(partial(metric, **kwargs))
    dr = all_pairs_metric(R, R)
    radial = radial_fn(etas, dr)
    return util.high_precision_sum(radial, axis=1).T
def temperature(velocity: Array, mass: Array = 1.0) -> float:
    """Computes the temperature of a system with some velocities.

    Args:
      velocity: Two-dimensional array of velocities; its shape is unpacked as
        `(N, dim)`.
      mass: Scalar or array multiplied elementwise with `velocity**2`.

    Returns:
      The sum of `mass * velocity**2` divided by `N * dim`.
    """
    n_particles, dim = velocity.shape
    mv2_sum = util.high_precision_sum(mass * velocity**2)
    return mv2_sum / (n_particles * dim)
def kinetic_energy(velocity: Array, mass: Array = 1.0) -> float:
    """Computes the kinetic energy of a system with some velocities.

    Args:
      velocity: Array of velocities.
      mass: Scalar or array multiplied elementwise with `velocity**2`.

    Returns:
      Half the sum of `mass * velocity**2`.
    """
    mv2_sum = util.high_precision_sum(mass * velocity**2)
    return 0.5 * mv2_sum