def _get_neighborhood_matrix_params(format: partition.NeighborListFormat,
                                    idx: Array,
                                    params: Array,
                                    combinator: Callable) -> Array:
    """Evaluate a parameter on every neighbor pair of a neighbor list.

    Args:
      format: Neighbor list format. Sparse formats are handled bond-wise using
        `idx[0]` / `idx[1]` as the two ends of each bond; otherwise row `i` of
        `idx` lists the neighbors of particle `i`.
      idx: Neighbor indices of the list.
      params: One of
        * a scalar (Python int/float or 0-d array), returned unchanged;
        * a 1-D array of per-particle values, combined pairwise with
          `combinator`;
        * a 2-D array queried as `params[a, b]` for each neighbor pair (a, b).
      combinator: Binary function used to combine 1-D per-particle values.

    Returns:
      The parameter evaluated on each neighbor pair, or the scalar itself.

    Raises:
      NotImplementedError: For arrays of rank > 2 or unsupported types.
    """
    if util.is_array(params):
        if len(params.shape) == 1:
            if partition.is_sparse(format):
                return space.map_bond(combinator)(params[idx[0]],
                                                  params[idx[1]])
            # Dense: broadcast each particle's value against its neighbors'.
            return combinator(params[:, None], params[idx])
        elif len(params.shape) == 2:
            def query(id_a, id_b):
                return params[id_a, id_b]

            if partition.is_sparse(format):
                return space.map_bond(query)(idx[0], idx[1])
            # Dense: row i of `idx` is queried against particle i.
            query = vmap(vmap(query, (None, 0)))
            return query(jnp.arange(idx.shape[0], dtype=jnp.int32), idx)
        elif len(params.shape) == 0:
            return params
        raise NotImplementedError()
    if (isinstance(params, (int, float)) or
            jnp.issubdtype(params, jnp.integer) or
            jnp.issubdtype(params, jnp.floating)):
        return params
    raise NotImplementedError()
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    """Species-resolved angular features over a dense neighbor list.

    For each pair of species types (t_a, t_b) with t_a at or before t_b in the
    order of `onp.unique(species)`, the output of `_all_pairs_angular` is
    masked so that the neighbor on axis 1 has species t_a and the neighbor on
    axis 2 has species t_b. It is then summed over those two axes. The
    per-type-pair results are stacked with `np.hstack`.
    """
    pair_displacement = space.map_neighbor(partial(displacement, **kwargs))
    neighbor_positions = R[neighbor.idx]
    neighbor_species = species[neighbor.idx]
    species_types = onp.unique(species)

    # Slots whose index is not a valid particle index are padding.
    is_real = neighbor.idx < len(R)
    species_masks = [np.logical_and(is_real, neighbor_species == t)
                     for t in species_types]

    dR = pair_displacement(R, neighbor_positions)
    all_angular = _all_pairs_angular(dR, dR)

    features = []
    for i, mask_a in enumerate(species_masks):
        first = mask_a[:, :, np.newaxis, np.newaxis]
        # Only t_b at or after t_a, so each unordered type pair appears once.
        for mask_b in species_masks[i:]:
            second = mask_b[:, np.newaxis, :, np.newaxis]
            features.append(
                util.high_precision_sum(all_angular * first * second,
                                        axis=[1, 2]))
    return np.hstack(features)
def g_fn(R, neighbor):
    """Per-particle pair-correlation contributions, resolved by neighbor species.

    Args:
      R: Particle positions of shape (N, dim).
      neighbor: A `Dense` or `Sparse` neighbor list.

    Returns:
      A list with one array per entry of `species_types`. Each array holds, for
      every particle, the `pairwise` kernel summed over neighbors of that
      species.

    Raises:
      NotImplementedError: For `OrderedSparse` (or any other) formats.
    """
    N, dim = R.shape
    mask = partition.neighbor_list_mask(neighbor)
    if neighbor.format is partition.Dense:
        neighbor_species = species[neighbor.idx]
        R_neigh = R[neighbor.idx]
        d = space.map_neighbor(metric)
        _pairwise = vmap(vmap(pairwise, (0, None)), (0, None))
        # The kernel does not depend on the species, so evaluate it only once.
        g_pair = _pairwise(d(R, R_neigh), dim)
        g_R = [
            jnp.sum((mask * (neighbor_species == s))[:, :, jnp.newaxis] *
                    g_pair, axis=(1, ))
            for s in species_types
        ]
    elif neighbor.format is partition.Sparse:
        neighbor_species = species[neighbor.idx[1]]
        dr = space.map_bond(metric)(R[neighbor.idx[0]], R[neighbor.idx[1]])
        _pairwise = vmap(pairwise, (0, None))
        # Species-independent, so hoisted out of the per-species loop.
        g_pair = _pairwise(dr, dim)
        g_R = [
            ops.segment_sum((mask * (neighbor_species == s))[:, None] * g_pair,
                            neighbor.idx[0], N)
            for s in species_types
        ]
    else:
        raise NotImplementedError(
            'Pair correlation function does not support '
            'OrderedSparse neighbor lists.')
    return g_R
def model(R, neighbor, **kwargs):
    """Apply the graph network to positions `R` using a dense neighbor list.

    Node features come from `kwargs['nodes']` when present (canonicalized),
    otherwise from the enclosing `nodes`, or zeros of shape (N, 1) if that is
    None. Edges whose squared length is not below `r_cutoff**2` are redirected
    to the padding index N.
    """
    n_particles = R.shape[0]
    pair_displacement = space.map_neighbor(partial(displacement_fn, **kwargs))
    dR = pair_displacement(R, R[neighbor.idx])

    if 'nodes' not in kwargs:
        node_state = (nodes if nodes is not None
                      else jnp.zeros((n_particles, 1), R.dtype))
    else:
        node_state = _canonicalize_node_state(kwargs['nodes'])
    global_state = jnp.zeros((1, ), R.dtype)

    within_cutoff = space.square_distance(dR) < r_cutoff**2
    edge_idx = jnp.where(within_cutoff, neighbor.idx, n_particles)

    net = network(n_recurrences=n_recurrences,
                  mlp_sizes=mlp_sizes,
                  mlp_kwargs=mlp_kwargs,
                  polymer_length=polymer_length,
                  polymer_dimensions=polymer_dimensions)
    return net(nn.GraphTuple(node_state, dR, global_state, edge_idx))  # pytype: disable=wrong-arg-count
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    """Radial features per particle from a dense neighbor list.

    `radial_fn(etas, dr)` is masked so that padding slots (index not below
    `R.shape[0]`) do not contribute. It is summed over the neighbor axis and
    transposed.
    """
    pair_metric = space.map_neighbor(partial(metric, **kwargs))
    distances = pair_metric(R, R[neighbor.idx])
    valid = (neighbor.idx < R.shape[0])[np.newaxis, :, :]
    weighted = radial_fn(etas, distances) * valid
    return util.high_precision_sum(weighted, axis=2).T
def fn_mapped(R: Array, neighbor: partition.NeighborList,
              **dynamic_kwargs) -> Array:
    """Evaluate the pair function `fn` over a neighbor list and reduce it.

    Args:
      R: Particle positions.
      neighbor: A Dense, Sparse or OrderedSparse neighbor list.
      **dynamic_kwargs: Forwarded to the displacement/metric and, merged with
        the static kwargs, expanded onto neighbor pairs as parameters of `fn`.

    Returns:
      The masked output of `fn`, reduced according to `reduce_axis`. Axis 0 is
      the particle axis and axis 1 the neighbor axis, as in the Dense layout.

    Raises:
      ValueError: If `reduce_axis` contains 0 but not 1, or if per-particle
        values are requested from an `OrderedSparse` list.
    """
    d = partial(displacement_or_metric, **dynamic_kwargs)
    _species = dynamic_kwargs.get('species', species)

    # Dense and Sparse lists are normalized by 2 (each pair presumably appears
    # in both directions); OrderedSparse lists are normalized by 1.
    normalization = 2.0

    if partition.is_sparse(neighbor.format):
        d = space.map_bond(d)
        dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
        mask = neighbor.idx[0] < R.shape[0]
        if neighbor.format is partition.OrderedSparse:
            normalization = 1.0
    else:
        d = space.map_neighbor(d)
        R_neigh = R[neighbor.idx]
        dR = d(R, R_neigh)
        mask = neighbor.idx < R.shape[0]

    merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
    merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                   neighbor.idx,
                                                   _species,
                                                   merged_kwargs,
                                                   param_combinators)
    out = fn(dR, **merged_kwargs)
    # Broadcast the pair mask over any trailing output dimensions of `fn`.
    if out.ndim > mask.ndim:
        ddim = out.ndim - mask.ndim
        mask = jnp.reshape(mask, mask.shape + (1, ) * ddim)
    out *= mask

    if reduce_axis is None:
        return util.high_precision_sum(out) / normalization

    if 0 in reduce_axis and 1 not in reduce_axis:
        raise ValueError('reduce_axis cannot contain axis 0 (particles) '
                         'without also containing axis 1 (neighbors).')

    if not partition.is_sparse(neighbor.format):
        return util.high_precision_sum(out, reduce_axis) / normalization

    # Sparse outputs have a single bond axis in place of the dense
    # (particle, neighbor) axes, so trailing axes shift down by one.
    _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

    if 0 in reduce_axis:
        # Summing over every bond: normalize like the dense path does,
        # otherwise Sparse lists double count each pair.
        return (util.high_precision_sum(out, (0, ) + _reduce_axis) /
                normalization)

    if neighbor.format is partition.OrderedSparse:
        raise ValueError(
            'Cannot report per-particle values with a neighbor '
            'list whose format is `OrderedSparse`. Please use '
            'either `Dense` or `Sparse`.')

    out = util.high_precision_sum(out, _reduce_axis)
    return ops.segment_sum(out, neighbor.idx[0], R.shape[0]) / normalization
def sym_fn(R: Array,
           neighbor: NeighborList,
           mask_i: Array = None,
           mask_j: Array = None,
           **kwargs) -> Array:
    """Angular features per particle from a Dense or Sparse neighbor list.

    Args:
      R: Particle positions.
      neighbor: A `Dense` or `Sparse` neighbor list.
      mask_i: Optional per-particle boolean mask for one neighbor of each
        neighbor pair.
      mask_j: Optional per-particle boolean mask for the other neighbor.
      **kwargs: Forwarded to `displacement`.

    Returns:
      Per-particle sums of `_batched_angular_fn` over pairs of neighbors of
      that particle.

    Raises:
      ValueError: For unsupported neighbor list formats.
    """
    D_fn = partial(displacement, **kwargs)

    if neighbor.format is partition.Dense:
        D_fn = space.map_neighbor(D_fn)
        R_neigh = R[neighbor.idx]
        dR = D_fn(R, R_neigh)
        # all_angular[i, b, a] = f(dR[i, a], dR[i, b]).
        _all_pairs_angular = vmap(
            vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0)
        all_angular = _all_pairs_angular(dR, dR)

        mask_i = True if mask_i is None else mask_i[neighbor.idx]
        mask_j = True if mask_j is None else mask_j[neighbor.idx]

        mask_i = (neighbor.idx < R.shape[0]) & mask_i
        mask_i = mask_i[:, :, jnp.newaxis, jnp.newaxis]
        mask_j = (neighbor.idx < R.shape[0]) & mask_j
        mask_j = mask_j[:, jnp.newaxis, :, jnp.newaxis]

        return util.high_precision_sum(all_angular * mask_i * mask_j,
                                       axis=[1, 2])
    elif neighbor.format is partition.Sparse:
        D_fn = space.map_bond(D_fn)
        dR = D_fn(R[neighbor.idx[0]], R[neighbor.idx[1]])
        # all_angular[b, a] = f(dR[a], dR[b]) for every pair of bonds (a, b),
        # matching the (b, a) axis order of the dense branch.
        _all_pairs_angular = vmap(vmap(_batched_angular_fn, (0, None)),
                                  (None, 0))
        all_angular = _all_pairs_angular(dR, dR)

        N = R.shape[0]
        center = neighbor.idx[0]
        mask_i = True if mask_i is None else mask_i[neighbor.idx[1]]
        mask_j = True if mask_j is None else mask_j[neighbor.idx[1]]
        mask_i = (center < N) & mask_i
        mask_j = (center < N) & mask_j

        # Only two bonds that share the same central particle form a triplet;
        # padding bonds are removed by the validity masks above.
        same_center = center[:, None] == center[None, :]
        mask = mask_i[:, None] & mask_j[None, :] & same_center
        mask = jnp.reshape(mask, mask.shape + (1, ) * (all_angular.ndim - 2))
        all_angular = all_angular * mask

        all_angular = jnp.reshape(all_angular, (-1, ) + all_angular.shape[2:])
        # Row-major flattening sends (b, a) to b * B + a, so repeating the
        # centers B times assigns each entry to the center of bond b.
        neighbor_idx = jnp.repeat(center, len(center))
        return ops.segment_sum(all_angular, neighbor_idx, N)
    else:
        raise ValueError()
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    """Angular features per particle from a dense neighbor list.

    The output of `_all_pairs_angular` is masked so that both neighbors of
    each pair are real (index below `R.shape[0]`), then summed over the two
    neighbor axes.
    """
    pair_displacement = space.map_neighbor(partial(displacement, **kwargs))
    dR = pair_displacement(R, R[neighbor.idx])
    valid = neighbor.idx < R.shape[0]
    all_angular = _all_pairs_angular(dR, dR)
    first = valid[:, :, np.newaxis, np.newaxis]
    second = valid[:, np.newaxis, :, np.newaxis]
    return util.high_precision_sum(all_angular * first * second, axis=[1, 2])
def prune_neighbor_list_dense(position: Array, idx: Array, **kwargs) -> Array:
    """Compact a dense neighbor list to neighbors within `cutoff_sq`.

    Returns the packed index array, with one fewer column than `idx` and
    padding value N, together with the largest number of kept neighbors of
    any particle.
    """
    n_particles = position.shape[0]
    n_slots = idx.shape[1]
    pair_metric_sq = space.map_neighbor(partial(metric_sq, **kwargs))
    dr_sq = pair_metric_sq(position, position[idx])

    keep = (dr_sq < cutoff_sq) & (idx < n_particles)
    running_count = jnp.cumsum(keep, axis=1)
    # Kept neighbors are packed to the front of each row. Discarded ones are
    # all written to the last column, which is sliced off below.
    slot = jnp.where(keep, running_count - 1, n_slots - 1)
    rows = jnp.arange(idx.shape[0])[:, None]
    packed = jnp.full(idx.shape, n_particles, i32).at[rows, slot].set(idx)

    return packed[:, :-1], jnp.max(running_count[:, -1])
def prune_neighbor_list_dense(R, idx, **kwargs):
    """Pack neighbors closer than `cutoff_sq` to the front of each row.

    Returns `(pruned_idx, max_occupancy)`. `pruned_idx` has one fewer column
    than `idx`, with empty slots set to N. `max_occupancy` is the largest
    per-particle count of kept neighbors.
    """
    n_particles = R.shape[0]
    pair_metric_sq = space.map_neighbor(partial(metric_sq, **kwargs))
    dr_sq = pair_metric_sq(R, R[idx])

    keep = (dr_sq < cutoff_sq) & (idx < n_particles)
    running_count = jnp.cumsum(keep, axis=1)
    # Discarded entries are routed to the last column and then dropped.
    last_col = idx.shape[1] - 1
    slot = jnp.where(keep, running_count - 1, last_col)
    rows = jnp.arange(idx.shape[0])[:, None]
    packed = jnp.full(idx.shape, n_particles, jnp.int32)
    packed = packed.at[rows, slot].set(idx)

    occupancy = jnp.max(running_count[:, -1])
    return packed[:, :-1], occupancy
def g_fn(R, neighbor):
    """Per-particle pair-correlation contributions from a neighbor list.

    Sums the `pairwise` kernel of each particle's neighbor distances, masked
    by `partition.neighbor_list_mask`. Supports `Dense` and `Sparse` lists.

    Raises:
      NotImplementedError: For any other neighbor list format.
    """
    N, dim = R.shape
    mask = partition.neighbor_list_mask(neighbor)

    if neighbor.format is partition.Dense:
        dist = space.map_neighbor(metric)(R, R[neighbor.idx])
        kernel = vmap(vmap(pairwise, (0, None)), (0, None))
        return jnp.sum(mask[:, :, None] * kernel(dist, dim), axis=(1,))

    if neighbor.format is partition.Sparse:
        centers = neighbor.idx[0]
        dist = space.map_bond(metric)(R[centers], R[neighbor.idx[1]])
        kernel = vmap(pairwise, (0, None))
        return ops.segment_sum(mask[:, None] * kernel(dist, dim), centers, N)

    raise NotImplementedError('Pair correlation function does not support '
                              'OrderedSparse neighbor lists.')
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    """Species-resolved radial features over a dense neighbor list.

    For every species type in `onp.unique(species)`, the Gaussian radial
    functions (one per value in `etas`, times the Behler-Parrinello cutoff)
    are summed over neighbors of that type. The per-type blocks are stacked
    horizontally.
    """
    _metric = space.map_neighbor(partial(metric, **kwargs))

    def radial_fn(eta, dr):
        return (np.exp(-eta * dr**2) *
                _behler_parrinello_cutoff_fn(dr, cutoff_distance))

    # Geometry, neighbor species and the kernel are the same for every species
    # channel, so compute them once instead of once per type.
    R_neigh = R[neighbor.idx]
    species_neigh = species[neighbor.idx]
    dr = _metric(R, R_neigh)
    radial = vmap(radial_fn, (0, None))(etas, dr)
    is_real = neighbor.idx < R.shape[0]

    def return_radial(atom_type):
        """Returns the radial symmetry functions for neighbor type atom_type."""
        mask = np.logical_and(is_real, species_neigh == atom_type)
        return util.high_precision_sum(radial * mask[np.newaxis, :, :],
                                       axis=2).T

    return np.hstack(
        [return_radial(atom_type) for atom_type in onp.unique(species)])
def model(R, neighbor, **kwargs):
    """Apply `EnergyGraphNet` to positions `R` using a dense neighbor list.

    Node features come from `kwargs['nodes']` when present (canonicalized),
    otherwise from the enclosing `nodes`, or zeros of shape (N, 1) if that is
    None. Edges not shorter than `r_cutoff` are redirected to padding index N.
    """
    n_particles = R.shape[0]
    pair_displacement = space.map_neighbor(partial(displacement_fn, **kwargs))
    dR = pair_displacement(R, R[neighbor.idx])

    if 'nodes' not in kwargs:
        node_state = (nodes if nodes is not None
                      else np.zeros((n_particles, 1), R.dtype))
    else:
        node_state = _canonicalize_node_state(kwargs['nodes'])
    global_state = np.zeros((1, ), R.dtype)

    within_cutoff = space.square_distance(dR) < r_cutoff**2
    edge_idx = np.where(within_cutoff, neighbor.idx, n_particles)

    net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
    graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)
    return net(graph)
def sym_fn(R: Array, neighbor: NeighborList, mask: Array = None,
           **kwargs) -> Array:
    """Radial features per particle from a Dense or Sparse neighbor list.

    Args:
      R: Particle positions.
      neighbor: A `Dense` or `Sparse` neighbor list.
      mask: Optional per-particle boolean mask; a neighbor j contributes only
        if `mask[j]` is True.
      **kwargs: Forwarded to `metric`.

    Returns:
      Per-particle sums of `radial_fn(etas, dr)` over valid neighbors, with
      particles along the first axis.

    Raises:
      ValueError: For unsupported neighbor list formats.
    """
    base_metric = partial(metric, **kwargs)
    N = R.shape[0]

    if neighbor.format is partition.Dense:
        selected = True if mask is None else mask[neighbor.idx]
        weights = (neighbor.idx < N)[None, :, :] & selected
        dr = space.map_neighbor(base_metric)(R, R[neighbor.idx])
        return util.high_precision_sum(radial_fn(etas, dr) * weights,
                                       axis=2).T

    if neighbor.format is partition.Sparse:
        centers = neighbor.idx[0]
        others = neighbor.idx[1]
        dr = space.map_bond(base_metric)(R[centers], R[others])
        radial = radial_fn(etas, dr).T
        selected = True if mask is None else mask[others]
        weights = (centers < N) & selected
        return ops.segment_sum(radial * weights[:, None], centers, N)

    raise ValueError()
def pair_correlation_neighbor_list(
        displacement_or_metric: Union[DisplacementFn, MetricFn],
        box_size: Box,
        radii: Array,
        sigma: float,
        species: Array = None,
        dr_threshold: float = 0.5):
    r"""Computes the pair correlation function at a mesh of distances.

    The pair correlation function measures the number of particles at a given
    distance from a central particle. It is defined by
    $g(r) = \langle \sum_{i \neq j} \delta(r - |r_i - r_j|) \rangle$.
    The delta function is approximated by an unnormalized Gaussian,
    $\delta(r - r_{ij}) \approx e^{-(r - r_{ij})^2 / (2\sigma^2)}$. Each
    contribution is additionally divided by $r^{\mathrm{dim} - 1}$, where
    $\mathrm{dim}$ is the spatial dimension of the positions.

    This function uses neighbor lists to speed up the calculation. The
    neighbor list cutoff is `max(radii) + sigma`.

    Args:
      displacement_or_metric: A function that computes the displacement or
        distance between two points.
      box_size: The size of the box containing the particles.
      radii: An array of radii at which we would like to compute g(r).
      sigma: A float specifying the width of the approximating Gaussian.
      species: An optional integer array specifying the species of each
        particle. If species is None then we compute a single g(r) for all
        particles. Otherwise we compute one g(r) per species, counting only
        neighbors of that species.
      dr_threshold: A float specifying the halo size of the neighbor list.

    Returns:
      A pair of functions:
        `neighbor_fn` that constructs a neighbor list (see `neighbor_list` in
        `partition.py` for details).
        `g_fn` that computes the pair correlation function for a collection of
        particles given their position and a neighbor list. With `species`,
        `g_fn` returns a list with one entry per species. `g_fn` indexes
        `R[neighbor.idx]` directly, so it expects a dense neighbor list
        (NOTE(review): confirm this is the default format produced by
        `partition.neighbor_list`).

    Raises:
      TypeError: If `species` is given but is not an integer array.
    """
    d = space.canonicalize_displacement_or_metric(displacement_or_metric)
    d = space.map_neighbor(d)

    def _gaussian_kernel(dr, dim):
        return (jnp.exp(-f32(0.5) * (dr - radii)**2 / sigma**2) /
                radii**(dim - 1))

    # Map over particles and their neighbors; `dim` is shared.
    pairwise = vmap(vmap(_gaussian_kernel, (0, None)), (0, None))

    neighbor_fn = partition.neighbor_list(displacement_or_metric,
                                          box_size,
                                          jnp.max(radii) + sigma,
                                          dr_threshold)

    if species is None:
        def g_fn(R, neighbor):
            dim = R.shape[-1]
            R_neigh = R[neighbor.idx]
            mask = neighbor.idx < R.shape[0]
            return jnp.sum(mask[:, :, jnp.newaxis] *
                           pairwise(d(R, R_neigh), dim), axis=(1, ))
    else:
        if not (isinstance(species, jnp.ndarray) and is_integer(species)):
            raise TypeError('Malformed species; expecting array of integers.')
        species_types = jnp.unique(species)

        def g_fn(R, neighbor):
            dim = R.shape[-1]
            mask = neighbor.idx < R.shape[0]
            neighbor_species = species[neighbor.idx]
            R_neigh = R[neighbor.idx]
            # The kernel does not depend on species, so evaluate it once.
            g_pair = pairwise(d(R, R_neigh), dim)
            return [
                jnp.sum((mask * (neighbor_species == s))[:, :, jnp.newaxis] *
                        g_pair, axis=(1, ))
                for s in species_types
            ]

    return neighbor_fn, g_fn
def hybrid_swap_mc(
        space_fns: space.Space,
        energy_fn: Callable[[Array, Array], Array],
        neighbor_fn: partition.NeighborFn,
        dt: float,
        kT: float,
        t_md: float,
        N_swap: int,
        sigma_fn: Optional[Callable[[Array, Array], Array]] = None
) -> Simulator:
    """Simulation of Hybrid Swap Monte-Carlo.

    This code simulates the hybrid Swap Monte Carlo algorithm introduced in
    [1]. Each block runs NVT (Nose-Hoover) molecular dynamics for
    `int(t_md // dt)` steps. It then performs `N_swap` MC moves, each of which
    swaps the radii of two randomly chosen particles. A swap is accepted with
    a Metropolis-Hastings step at temperature `kT`. Each call to the block
    function runs molecular dynamics for `t_md` and then performs the swaps.

    Note that this code doesn't feature some of the convenience functions in
    the other simulations. In particular, there is no support for dynamic
    keyword arguments, and the energy function must be a simple callable of
    two variables: the distance between adjacent particles and the diameter
    of the particles. If you want support for a better notion of potential or
    dynamic keyword arguments, please file an issue!

    Args:
      space_fns: A tuple of a displacement function and a shift function
        defined in `space.py`.
      energy_fn: A function that computes the energy between one pair of
        particles as a function of the distance between the particles and the
        diameter. This function should not have been passed to `smap.xxx`.
      neighbor_fn: A function to construct neighbor lists outlined in
        `partition.py`.
      dt: The timestep used for the continuous time MD portion of the
        simulation.
      kT: The temperature of heat bath that the system is coupled to during
        MD.
      t_md: The time of each MD block.
      N_swap: The number of swapping moves between MD blocks.
      sigma_fn: An optional function `sigma_fn(sigma_i, sigma_j)` for
        combining radii if they are to be non-additive. Defaults to the
        arithmetic mean `0.5 * (sigma_i + sigma_j)`.

    Returns:
      A pair `(init_fn, block_fn)`:
        `init_fn(key, position, sigma, nbrs=None)` builds the neighbor list and
        the NVT state and returns a `SwapMCState(md, sigma, key, nbrs)`.
        `block_fn(state)` runs one MD block followed by `N_swap` swap attempts
        and returns the updated `SwapMCState`.

    [1] L. Berthier, E. Flenner, C. J. Fullerton, C. Scalliet, and M. Singh.
        "Efficient swap algorithms for molecular dynamics simulations of
         equilibrium supercooled liquids"
        J. Stat. Mech. (2019) 064004
    """
    displacement_fn, shift_fn = space_fns
    metric_fn = space.metric(displacement_fn)
    nbr_metric_fn = space.map_neighbor(metric_fn)

    # Number of MD integration steps per block (rounded down).
    md_steps = int(t_md // dt)

    # Canonicalize the argument names to be dr and sigma.
    wrapped_energy_fn = lambda dr, sigma: energy_fn(dr, sigma)
    if sigma_fn is None:
        sigma_fn = lambda si, sj: 0.5 * (si + sj)
    # The MD energy and the swap moves below both combine radii via sigma_fn.
    nbr_energy_fn = smap.pair_neighbor_list(wrapped_energy_fn,
                                            metric_fn,
                                            sigma=sigma_fn)

    nvt_init_fn, nvt_step_fn = nvt_nose_hoover(nbr_energy_fn,
                                               shift_fn,
                                               dt,
                                               kT=kT,
                                               chain_length=3)

    def init_fn(key, position, sigma, nbrs=None):
        # Keep one key for the MC moves and use the other for the MD state.
        key, sim_key = random.split(key)
        nbrs = neighbor_fn(position, nbrs)  # pytype: disable=wrong-arg-count
        md_state = nvt_init_fn(sim_key, position, neighbor=nbrs, sigma=sigma)
        return SwapMCState(md_state, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def md_step_fn(i, state):
        # One NVT step followed by a neighbor list refresh; `i` is unused.
        md, sigma, key, nbrs = dataclasses.unpack(state)
        md = nvt_step_fn(md, neighbor=nbrs, sigma=sigma)  # pytype: disable=wrong-keyword-args
        nbrs = neighbor_fn(md.position, nbrs)
        return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def swap_step_fn(i, state):
        md, sigma, key, nbrs = dataclasses.unpack(state)

        N = md.position.shape[0]

        # Swap a random pair of particle radii.
        # The two indices are drawn independently; if they coincide the swap
        # leaves sigma unchanged and the energy difference is zero.
        key, particle_key, accept_key = random.split(key, 3)
        ij = random.randint(particle_key, (2, ), jnp.array(0), jnp.array(N))
        new_sigma = sigma.at[ij].set([sigma[ij[1]], sigma[ij[0]]])

        # Collect neighborhoods around the two swapped particles.
        # Only these rows can change energy, so the rest of the system is
        # skipped. If j is a neighbor of i, that pair appears in both rows
        # before and after the swap with swapped arguments, so its
        # contribution cancels in new_energy - energy.
        nbrs_ij = nbrs.idx[ij]
        R_ij = md.position[ij]
        R_neigh = md.position[nbrs_ij]

        sigma_ij = sigma[ij][:, None]
        sigma_neigh = sigma[nbrs_ij]

        new_sigma_ij = new_sigma[ij][:, None]
        new_sigma_neigh = new_sigma[nbrs_ij]

        dR = nbr_metric_fn(R_ij, R_neigh)

        # Compute the energy before the swap.
        # Entries with index >= N are padding and are masked out.
        energy = energy_fn(dR, sigma_fn(sigma_ij, sigma_neigh))
        energy = jnp.sum(energy * (nbrs_ij < N))

        # Compute the energy after the swap.
        new_energy = energy_fn(dR, sigma_fn(new_sigma_ij, new_sigma_neigh))
        new_energy = jnp.sum(new_energy * (nbrs_ij < N))

        # Accept or reject with a metropolis probability.
        # Accept with probability min(1, exp(-(E_new - E_old) / kT)).
        p = random.uniform(accept_key, ())
        accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - energy) / kT))
        sigma = jnp.where(p < accept_prob, new_sigma, sigma)

        return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def block_fn(state):
        # md_steps MD steps, then N_swap swap attempts.
        state = lax.fori_loop(0, md_steps, md_step_fn, state)
        state = lax.fori_loop(0, N_swap, swap_step_fn, state)
        return state

    return init_fn, block_fn