Пример #1
0
def _get_neighborhood_matrix_params(format: partition.NeighborListFormat,
                                    idx: Array, params: Array,
                                    combinator: Callable) -> Array:
    """Expands per-particle or per-pair parameters onto a neighbor list.

    Args:
      format: Neighbor-list format. `partition.is_sparse(format)` selects the
        sparse layout, where `idx[0]` and `idx[1]` are the two index arrays
        of each pair. Otherwise `idx` is a dense table whose row `i` holds
        the neighbors of particle `i`.
      idx: Neighbor indices in the layout implied by `format`.
      params: A Python/NumPy scalar, a rank-0 array, a rank-1 array of
        per-particle values, or a rank-2 array of per-pair values.
      combinator: Combines two per-particle values into a per-pair value.
        Only used when `params` has rank 1.

    Returns:
      `params` itself for scalars and rank-0 arrays. Otherwise, per-pair
      parameters laid out like the neighbor list.

    Raises:
      NotImplementedError: If `params` is an array of rank > 2 or an
        unsupported non-array type.
    """
    if util.is_array(params):
        if len(params.shape) == 1:
            if partition.is_sparse(format):
                return space.map_bond(combinator)(params[idx[0]],
                                                  params[idx[1]])
            # Dense: combine each particle's value with the values of every
            # entry in its row of the neighbor table.
            return combinator(params[:, None], params[idx])
        elif len(params.shape) == 2:

            def query(id_a, id_b):
                return params[id_a, id_b]

            if partition.is_sparse(format):
                return space.map_bond(query)(idx[0], idx[1])
            # Dense: row i of `idx` lists the neighbors of particle i, so
            # pair each row index with every entry of that row.
            query = vmap(vmap(query, (None, 0)))
            return query(jnp.arange(idx.shape[0], dtype=jnp.int32), idx)
        elif len(params.shape) == 0:
            return params
        else:
            raise NotImplementedError()
    elif (isinstance(params, (int, float))
          or jnp.issubdtype(params, jnp.integer)
          or jnp.issubdtype(params, jnp.floating)):
        return params
    else:
        raise NotImplementedError()
Пример #2
0
    def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
        """Angular descriptors summed over neighbor pairs, per species pair.

        For each pair of species types `(types[a], types[b])` with `a <= b`
        (types ordered as returned by `onp.unique(species)`), the angular
        terms are summed over neighbor pairs. Only pairs where the axis-1
        neighbor has species `types[a]` and the axis-2 neighbor has species
        `types[b]` are included. The blocks are concatenated with `np.hstack`.
        """
        pair_displacement = space.map_neighbor(partial(displacement, **kwargs))

        neighbor_positions = R[neighbor.idx]
        neighbor_species = species[neighbor.idx]

        types = onp.unique(species)

        # Entries of the neighbor table that are >= len(R) do not refer to a
        # real particle and are masked out.
        valid = neighbor.idx < len(R)
        type_masks = [np.logical_and(valid, neighbor_species == t)
                      for t in types]

        dR = pair_displacement(R, neighbor_positions)
        angular = _all_pairs_angular(dR, dR)

        blocks = []
        for a, mask_a in enumerate(type_masks):
            first = mask_a[:, :, np.newaxis, np.newaxis]
            for mask_b in type_masks[a:]:
                second = mask_b[:, np.newaxis, :, np.newaxis]
                blocks.append(
                    util.high_precision_sum(angular * first * second,
                                            axis=[1, 2]))
        return np.hstack(blocks)
Пример #3
0
        def g_fn(R, neighbor):
            """Per-species pair correlation contributions for each particle.

            Args:
              R: Particle positions; `R.shape` is unpacked as `(N, dim)`.
              neighbor: A Dense or Sparse neighbor list.

            Returns:
              A list with one array per entry of `species_types`. Each array
              sums the smoothed pair terms of `pairwise` over the neighbors
              of that species.

            Raises:
              NotImplementedError: For any other neighbor-list format.
            """
            N, dim = R.shape
            mask = partition.neighbor_list_mask(neighbor)
            if neighbor.format is partition.Dense:
                neighbor_species = species[neighbor.idx]
                R_neigh = R[neighbor.idx]
                d = space.map_neighbor(metric)
                _pairwise = vmap(vmap(pairwise, (0, None)), (0, None))
                # The smoothed pair terms do not depend on species; evaluate
                # them once and only re-mask per species.
                contributions = _pairwise(d(R, R_neigh), dim)
                return [
                    jnp.sum((mask * (neighbor_species == s))[:, :, jnp.newaxis]
                            * contributions,
                            axis=(1, ))
                    for s in species_types
                ]
            elif neighbor.format is partition.Sparse:
                neighbor_species = species[neighbor.idx[1]]
                dr = space.map_bond(metric)(R[neighbor.idx[0]],
                                            R[neighbor.idx[1]])
                _pairwise = vmap(pairwise, (0, None))
                # Species-independent; computed once outside the loop.
                contributions = _pairwise(dr, dim)
                return [
                    ops.segment_sum(
                        (mask * (neighbor_species == s))[:, None] *
                        contributions,
                        neighbor.idx[0], N)
                    for s in species_types
                ]
            else:
                raise NotImplementedError(
                    'Pair correlation function does not support '
                    'OrderedSparse neighbor lists.')
Пример #4
0
    def model(R, neighbor, **kwargs):
        """Builds a graph from the neighbor list and evaluates `network` on it.

        Node features come from `kwargs['nodes']` if present. Otherwise the
        closed-over `nodes` is used, or a zero column if that is None.
        Neighbor entries whose squared distance is not below `r_cutoff**2`
        are replaced by the index `R.shape[0]` in the edge index array.
        """
        num_particles = R.shape[0]

        pair_disp = space.map_neighbor(partial(displacement_fn, **kwargs))
        dR = pair_disp(R, R[neighbor.idx])

        if 'nodes' not in kwargs:
            node_state = (jnp.zeros((num_particles, 1), R.dtype)
                          if nodes is None else nodes)
        else:
            node_state = _canonicalize_node_state(kwargs['nodes'])

        global_state = jnp.zeros((1, ), R.dtype)

        within_cutoff = space.square_distance(dR) < r_cutoff**2
        edge_idx = jnp.where(within_cutoff, neighbor.idx, num_particles)

        net = network(n_recurrences=n_recurrences,
                      mlp_sizes=mlp_sizes,
                      mlp_kwargs=mlp_kwargs,
                      polymer_length=polymer_length,
                      polymer_dimensions=polymer_dimensions)
        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)  # pytype: disable=wrong-arg-count
        return net(graph)
Пример #5
0
 def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
     """Radial descriptors summed over real neighbors of each particle.

     Assumes `radial_fn(etas, dr)` has a leading axis over `etas` (the mask
     is broadcast with a leading new axis); the result is transposed so that
     particles run along axis 0.
     """
     pair_metric = space.map_neighbor(partial(metric, **kwargs))
     dr = pair_metric(R, R[neighbor.idx])
     # Entries >= R.shape[0] in the neighbor table are not real particles.
     valid = (neighbor.idx < R.shape[0])[np.newaxis, :, :]
     summed = util.high_precision_sum(radial_fn(etas, dr) * valid, axis=2)
     return summed.T
Пример #6
0
    def fn_mapped(R: Array, neighbor: partition.NeighborList,
                  **dynamic_kwargs) -> Array:
        """Evaluates `fn` on every neighbor pair and reduces the result.

        Args:
          R: Particle positions. Neighbor indices >= `R.shape[0]` are
            treated as padding and masked out.
          neighbor: A Dense, Sparse or OrderedSparse neighbor list.
          **dynamic_kwargs: Forwarded to `displacement_or_metric`. They are
            also merged with the closed-over `kwargs` via `merge_dicts`
            (presumably with dynamic values taking precedence) before being
            converted to per-pair parameters for `fn`. A `species` entry
            replaces the closed-over `species`.

        Returns:
          If `reduce_axis` is None, the sum over all pair values. Otherwise
          the pair values reduced over `reduce_axis`, where axes are given in
          the dense `(particle, neighbor, ...)` layout. Sums over pairs are
          divided by `normalization`: 2 for Dense and Sparse lists, 1 for
          OrderedSparse.

        Raises:
          ValueError: If `reduce_axis` contains 0 but not 1, or if
            per-particle values are requested from an OrderedSparse list.
        """
        d = partial(displacement_or_metric, **dynamic_kwargs)
        _species = dynamic_kwargs.get('species', species)

        # Presumably compensates for each pair being listed twice (i->j and
        # j->i); OrderedSparse uses 1.
        normalization = 2.0

        if partition.is_sparse(neighbor.format):
            d = space.map_bond(d)
            dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
            mask = neighbor.idx[0] < R.shape[0]
            if neighbor.format is partition.OrderedSparse:
                normalization = 1.0
        else:
            d = space.map_neighbor(d)
            R_neigh = R[neighbor.idx]
            dR = d(R, R_neigh)
            mask = neighbor.idx < R.shape[0]

        merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
        merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                       neighbor.idx, _species,
                                                       merged_kwargs,
                                                       param_combinators)
        out = fn(dR, **merged_kwargs)
        if out.ndim > mask.ndim:
            # Broadcast the pair mask over any trailing output dimensions.
            ddim = out.ndim - mask.ndim
            mask = jnp.reshape(mask, mask.shape + (1, ) * ddim)
        out *= mask

        if reduce_axis is None:
            return util.high_precision_sum(out) / normalization

        if 0 in reduce_axis and 1 not in reduce_axis:
            raise ValueError(
                'reduce_axis must contain axis 1 whenever it contains axis 0.')

        if not partition.is_sparse(neighbor.format):
            return util.high_precision_sum(out, reduce_axis) / normalization

        # Sparse layout: dense axes 0 and 1 collapse into the single bond
        # axis 0, so dense axis a > 1 becomes sparse axis a - 1.
        _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

        if 0 in reduce_axis:
            # Summing over all bonds: apply the same pair normalization as
            # the dense and `reduce_axis is None` paths.
            return (util.high_precision_sum(out, (0, ) + _reduce_axis) /
                    normalization)

        if neighbor.format is partition.OrderedSparse:
            raise ValueError(
                'Cannot report per-particle values with a neighbor '
                'list whose format is `OrderedSparse`. Please use '
                'either `Dense` or `Sparse`.')

        out = util.high_precision_sum(out, _reduce_axis)
        return ops.segment_sum(out, neighbor.idx[0],
                               R.shape[0]) / normalization
Пример #7
0
    def sym_fn(R: Array,
               neighbor: NeighborList,
               mask_i: Array = None,
               mask_j: Array = None,
               **kwargs) -> Array:
        """Angular symmetry functions per particle from a neighbor list.

        Args:
          R: Particle positions. Neighbor indices >= `R.shape[0]` are padding.
          neighbor: A Dense or Sparse neighbor list.
          mask_i: Optional per-particle boolean array selecting which
            neighbors may act as one member of an angle triplet (None means
            all).
          mask_j: Same as `mask_i`, for the other member of the triplet.
          **kwargs: Forwarded to `displacement`.

        Returns:
          Angular terms summed over pairs of neighbors of each particle.

        Raises:
          ValueError: For neighbor-list formats other than Dense or Sparse.
        """
        D_fn = partial(displacement, **kwargs)

        if neighbor.format is partition.Dense:
            D_fn = space.map_neighbor(D_fn)

            R_neigh = R[neighbor.idx]

            dR = D_fn(R, R_neigh)
            _all_pairs_angular = vmap(
                vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0)
            all_angular = _all_pairs_angular(dR, dR)

            mask_i = True if mask_i is None else mask_i[neighbor.idx]
            mask_j = True if mask_j is None else mask_j[neighbor.idx]

            mask_i = (neighbor.idx < R.shape[0]) & mask_i
            mask_i = mask_i[:, :, jnp.newaxis, jnp.newaxis]
            mask_j = (neighbor.idx < R.shape[0]) & mask_j
            mask_j = mask_j[:, jnp.newaxis, :, jnp.newaxis]

            return util.high_precision_sum(all_angular * mask_i * mask_j,
                                           axis=[1, 2])
        elif neighbor.format is partition.Sparse:
            D_fn = space.map_bond(D_fn)
            center = neighbor.idx[0]
            dR = D_fn(R[center], R[neighbor.idx[1]])
            # all_angular[b, c] evaluates the angular function with bond c as
            # the first argument and bond b as the second, matching the
            # axis order of the dense branch.
            _all_pairs_angular = vmap(vmap(_batched_angular_fn, (0, None)),
                                      (None, 0))
            all_angular = _all_pairs_angular(dR, dR)

            N = R.shape[0]
            mask_i = True if mask_i is None else mask_i[neighbor.idx[1]]
            mask_j = True if mask_j is None else mask_j[neighbor.idx[1]]
            mask_i = (center < N) & mask_i
            mask_j = (center < N) & mask_j

            # Only two bonds that share the same central particle form an
            # angle; the dense layout provides this implicitly.
            same_center = center[:, None] == center[None, :]
            mask = mask_i[:, None] & mask_j[None, :] & same_center
            mask = jnp.reshape(mask,
                               mask.shape + (1, ) * (all_angular.ndim - 2))
            all_angular = jnp.where(mask, all_angular, 0)

            all_angular = jnp.reshape(all_angular,
                                      (-1, ) + all_angular.shape[2:])
            # Row-major flattening places pair (b, c) at b * n_bonds + c, so
            # repeating each center n_bonds times routes it to bond b's center.
            neighbor_idx = jnp.repeat(center, len(center))
            return ops.segment_sum(all_angular, neighbor_idx, N)
        else:
            raise ValueError(
                'Only Dense and Sparse neighbor lists are supported.')
Пример #8
0
        def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
            """Angular descriptors summed over all pairs of real neighbors."""
            pair_disp = space.map_neighbor(partial(displacement, **kwargs))

            dR = pair_disp(R, R[neighbor.idx])
            angular = _all_pairs_angular(dR, dR)

            # Entries of the neighbor table >= R.shape[0] are not particles.
            valid = neighbor.idx < R.shape[0]
            first = valid[:, :, np.newaxis, np.newaxis]
            second = valid[:, np.newaxis, :, np.newaxis]

            return util.high_precision_sum(angular * first * second,
                                           axis=[1, 2])
Пример #9
0
  def prune_neighbor_list_dense(position: Array, idx: Array,
                                **kwargs) -> 'Tuple[Array, Array]':
    """Compacts a dense candidate neighbor table to neighbors within cutoff.

    Args:
      position: Particle positions; `position.shape[0]` is the particle count N.
      idx: Dense candidate table; row i lists candidate neighbors of particle
        i, and entries >= N are treated as empty.
      **kwargs: Forwarded to `metric_sq`.

    Returns:
      A pair `(out_idx, max_occupancy)`. In each row of `out_idx`, the kept
      neighbors are packed to the left and the rest is filled with N; it has
      `idx.shape[1] - 1` columns. `max_occupancy` is the largest number of
      kept neighbors in any row. If it reaches `idx.shape[1]`, the last kept
      entry of that row did not fit.
    """
    d = partial(metric_sq, **kwargs)
    d = space.map_neighbor(d)

    N = position.shape[0]
    neigh_position = position[idx]
    dr_sq = d(position, neigh_position)

    # Keep candidates that are real particles and lie within the cutoff.
    mask = (dr_sq < cutoff_sq) & (idx < N)
    out_idx = N * jnp.ones(idx.shape, i32)

    # Kept entries scatter to consecutive columns 0, 1, ...; discarded ones
    # all land in the final column, which is sliced off below.
    cumsum = jnp.cumsum(mask, axis=1)
    index = jnp.where(mask, cumsum - 1, idx.shape[1] - 1)
    p_index = jnp.arange(idx.shape[0])[:, None]
    out_idx = out_idx.at[p_index, index].set(idx)
    max_occupancy = jnp.max(cumsum[:, -1])

    return out_idx[:, :-1], max_occupancy
Пример #10
0
  def prune_neighbor_list_dense(R, idx, **kwargs):
    """Packs in-cutoff neighbors to the left of each row of a dense table.

    Returns `(pruned, max_occupancy)`. `pruned` has one column fewer than
    `idx`, with unused slots set to `R.shape[0]`. `max_occupancy` is the
    largest count of kept neighbors over all rows.
    """
    sq_dist = space.map_neighbor(partial(metric_sq, **kwargs))

    num_particles = R.shape[0]
    dr_sq = sq_dist(R, R[idx])

    keep = (dr_sq < cutoff_sq) & (idx < num_particles)
    pruned = num_particles * jnp.ones(idx.shape, jnp.int32)

    running = jnp.cumsum(keep, axis=1)
    # Dropped candidates all go to the last column, which is discarded.
    target_col = jnp.where(keep, running - 1, idx.shape[1] - 1)
    rows = jnp.arange(idx.shape[0])[:, None]
    pruned = pruned.at[rows, target_col].set(idx)

    return pruned[:, :-1], jnp.max(running[:, -1])
Пример #11
0
 def g_fn(R, neighbor):
   """Smoothed pair-correlation contributions summed per particle.

   Dense and Sparse neighbor lists are supported; any other format raises
   NotImplementedError.
   """
   N, dim = R.shape
   valid = partition.neighbor_list_mask(neighbor)

   if neighbor.format is partition.Sparse:
     centers, others = neighbor.idx[0], neighbor.idx[1]
     dr = space.map_bond(metric)(R[centers], R[others])
     terms = vmap(pairwise, (0, None))(dr, dim)
     return ops.segment_sum(valid[:, None] * terms, centers, N)

   if neighbor.format is partition.Dense:
     pair_metric = space.map_neighbor(metric)
     batched = vmap(vmap(pairwise, (0, None)), (0, None))
     terms = batched(pair_metric(R, R[neighbor.idx]), dim)
     return jnp.sum(valid[:, :, None] * terms, axis=(1,))

   raise NotImplementedError('Pair correlation function does not support '
                             'OrderedSparse neighbor lists.')
Пример #12
0
    def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
        """Radial symmetry functions grouped by neighbor species.

        For each species type (in `onp.unique(species)` order), the terms
        `exp(-eta * dr**2) * cutoff(dr)`, one per entry of `etas`, are summed
        over neighbors of that type. The per-type 2-D blocks are concatenated
        with `np.hstack`.
        """
        _metric = partial(metric, **kwargs)
        _metric = space.map_neighbor(_metric)

        def radial_fn(eta, dr):
            return (np.exp(-eta * dr**2) *
                    _behler_parrinello_cutoff_fn(dr, cutoff_distance))

        # Species-independent quantities: compute once instead of once per
        # species type.
        R_neigh = R[neighbor.idx]
        species_neigh = species[neighbor.idx]
        dr = _metric(R, R_neigh)
        radial = vmap(radial_fn, (0, None))(etas, dr)
        # Neighbor entries >= R.shape[0] are not real particles.
        real_neighbor = neighbor.idx < R.shape[0]

        def return_radial(atom_type):
            """Returns the radial symmetry functions for neighbor type atom_type."""
            mask = np.logical_and(real_neighbor, species_neigh == atom_type)
            return util.high_precision_sum(radial * mask[np.newaxis, :, :],
                                           axis=2).T

        return np.hstack(
            [return_radial(atom_type) for atom_type in onp.unique(species)])
Пример #13
0
    def model(R, neighbor, **kwargs):
        """Builds a graph from the neighbor list and evaluates EnergyGraphNet.

        Node features come from `kwargs['nodes']` if present. Otherwise the
        closed-over `nodes` is used, or a zero column if that is None.
        Neighbor entries whose squared distance is not below `r_cutoff**2`
        are replaced by the index `R.shape[0]`.
        """
        num_particles = R.shape[0]

        pair_disp = space.map_neighbor(partial(displacement_fn, **kwargs))
        dR = pair_disp(R, R[neighbor.idx])

        if 'nodes' not in kwargs:
            node_state = (np.zeros((num_particles, 1), R.dtype)
                          if nodes is None else nodes)
        else:
            node_state = _canonicalize_node_state(kwargs['nodes'])

        global_state = np.zeros((1, ), R.dtype)

        within_cutoff = space.square_distance(dR) < r_cutoff**2
        edge_idx = np.where(within_cutoff, neighbor.idx, num_particles)

        net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)
        return net(graph)
Пример #14
0
 def sym_fn(R: Array,
            neighbor: NeighborList,
            mask: Array = None,
            **kwargs) -> Array:
     """Radial symmetry functions per particle from a neighbor list.

     Args:
       R: Particle positions; neighbor indices >= R.shape[0] are padding.
       neighbor: A Dense or Sparse neighbor list.
       mask: Optional per-particle boolean array. Only neighbors whose entry
         is True contribute; None means all neighbors contribute.
       **kwargs: Forwarded to `metric`.

     Returns:
       `radial_fn(etas, dr)` summed over neighbors, with particles along
       axis 0.

     Raises:
       ValueError: For neighbor-list formats other than Dense or Sparse.
     """
     pair_metric = partial(metric, **kwargs)
     N = R.shape[0]

     if neighbor.format is partition.Dense:
         dr = space.map_neighbor(pair_metric)(R, R[neighbor.idx])
         selected = True if mask is None else mask[neighbor.idx]
         selected = (neighbor.idx < N)[None, :, :] & selected
         summed = util.high_precision_sum(radial_fn(etas, dr) * selected,
                                          axis=2)
         return summed.T

     if neighbor.format is partition.Sparse:
         centers = neighbor.idx[0]
         others = neighbor.idx[1]
         dr = space.map_bond(pair_metric)(R[centers], R[others])
         radial = radial_fn(etas, dr).T
         selected = True if mask is None else mask[others]
         selected = (centers < N) & selected
         return ops.segment_sum(radial * selected[:, None], centers, N)

     raise ValueError()
Пример #15
0
def pair_correlation_neighbor_list(
        displacement_or_metric: Union[DisplacementFn, MetricFn],
        box_size: Box,
        radii: Array,
        sigma: float,
        species: Array = None,
        dr_threshold: float = 0.5):
    r"""Computes the pair correlation function at a mesh of distances.

    The pair correlation function measures the number of particles at a given
    distance from a central particle. The pair correlation function is defined
    by $g(r) = <\sum_{i\neq j}\delta(r - |r_i - r_j|)>.$ We make the
    approximation,
    $\delta(r) \approx {1 \over \sqrt{2\pi\sigma^2}} e^{-r^2 / (2\sigma^2)}$.
    The implementation evaluates only the exponential, without the
    $1/\sqrt{2\pi\sigma^2}$ prefactor, and divides each contribution by
    $r^{d-1}$, where $d$ is the spatial dimension.

    This function uses neighbor lists to speed up the calculation.

    Args:
      displacement_or_metric: A function that computes the displacement or
        distance between two points.
      box_size: The size of the box containing the particles.
      radii: An array of radii at which we would like to compute g(r).
      sigma: A float specifying the width of the approximating Gaussian.
      species: An optional array specifying the species of each particle. If
        species is None then we compute a single g(r) for all particles,
        otherwise we compute one g(r) for each species.
      dr_threshold: A float specifying the halo size of the neighbor list.

    Returns:
      A pair of functions: `neighbor_fn` that constructs a neighbor list (see
      `neighbor_list` in `partition.py` for details). `g_fn` that computes the
      pair correlation function for a collection of particles given their
      position and a neighbor list. With `species` given, `g_fn` returns a
      list with one array per species type.

    Raises:
      TypeError: If `species` is given but is not an integer `jnp.ndarray`.
    """
    d = space.canonicalize_displacement_or_metric(displacement_or_metric)
    d = space.map_neighbor(d)

    def pairwise(dr, dim):
        return jnp.exp(-f32(0.5) *
                       (dr - radii)**2 / sigma**2) / radii**(dim - 1)

    pairwise = vmap(vmap(pairwise, (0, None)), (0, None))

    # The neighbor cutoff must reach the largest radius plus the smoothing width.
    neighbor_fn = partition.neighbor_list(displacement_or_metric, box_size,
                                          jnp.max(radii) + sigma, dr_threshold)

    if species is None:

        def g_fn(R, neighbor):
            dim = R.shape[-1]
            R_neigh = R[neighbor.idx]
            # Neighbor entries >= R.shape[0] are padding.
            mask = neighbor.idx < R.shape[0]
            return jnp.sum(mask[:, :, jnp.newaxis] *
                           pairwise(d(R, R_neigh), dim),
                           axis=(1, ))
    else:
        if not (isinstance(species, jnp.ndarray) and is_integer(species)):
            raise TypeError('Malformed species; expecting array of integers.')
        species_types = jnp.unique(species)

        def g_fn(R, neighbor):
            dim = R.shape[-1]
            mask = neighbor.idx < R.shape[0]
            neighbor_species = species[neighbor.idx]
            R_neigh = R[neighbor.idx]
            # The smoothed pair terms do not depend on species; evaluate them
            # once and only re-mask per species.
            contributions = pairwise(d(R, R_neigh), dim)
            return [
                jnp.sum((mask * (neighbor_species == s))[:, :, jnp.newaxis] *
                        contributions,
                        axis=(1, ))
                for s in species_types
            ]

    return neighbor_fn, g_fn
Пример #16
0
def hybrid_swap_mc(
        space_fns: space.Space,
        energy_fn: Callable[[Array, Array], Array],
        neighbor_fn: partition.NeighborFn,
        dt: float,
        kT: float,
        t_md: float,
        N_swap: int,
        sigma_fn: Optional[Callable[[Array, Array], Array]] = None) -> Simulator:
    """Simulation of Hybrid Swap Monte-Carlo.

    This code simulates the hybrid Swap Monte Carlo algorithm introduced in [1].
    Here an NVT simulation is performed for `t_md` time and then `N_swap` MC
    moves are performed that swap the radii of randomly chosen particles. The
    random swaps are accepted with Metropolis-Hastings step. Each call to the
    step function runs molecular dynamics for `t_md` and then performs the swaps.

    Note that this code doesn't feature some of the convenience functions in the
    other simulations. In particular, there is no support for dynamics keyword
    arguments and the energy function must be a simple callable of two variables:
    the distance between adjacent particles and the diameter of the particles.
    If you want support for a better notion of potential or dynamic keyword
    arguments, please file an issue!

    Args:
      space_fns: A tuple of a displacement function and a shift function defined
        in `space.py`.
      energy_fn: A function that computes the energy between one pair of
        particles as a function of the distance between the particles and the
        diameter. This function should not have been passed to `smap.xxx`.
      neighbor_fn: A function to construct neighbor lists outlined in
        `partition.py`.
      dt: The timestep used for the continuous time MD portion of the simulation.
      kT: The temperature of heat bath that the system is coupled to during MD.
      t_md: The time of each MD block.
      N_swap: The number of swapping moves between MD blocks.
      sigma_fn: An optional function for combining the radii of two particles
        if they are to be non-additive. Defaults to `0.5 * (si + sj)`.

    Returns:
      A pair `(init_fn, block_fn)`. `init_fn(key, position, sigma, nbrs=None)`
      builds the initial `SwapMCState`. `block_fn(state)` runs the whole
      number of MD steps that fit in `t_md`, then `N_swap` swap attempts.

    [1] L. Berthier, E. Flenner, C. J. Fullerton, C. Scalliet, and M. Singh.
        "Efficient swap algorithms for molecular dynamics simulations of
         equilibrium supercooled liquids"
        J. Stat. Mech. (2019) 064004
    """
    displacement_fn, shift_fn = space_fns
    metric_fn = space.metric(displacement_fn)
    nbr_metric_fn = space.map_neighbor(metric_fn)

    # Float floor division is unreliable here (e.g. 1.0 // 0.1 == 9.0), so
    # divide first and floor with a small tolerance for exact multiples of dt.
    md_steps = int(t_md / dt + 1e-6)

    # Canonicalize the argument names to be dr and sigma.
    wrapped_energy_fn = lambda dr, sigma: energy_fn(dr, sigma)
    if sigma_fn is None:
        sigma_fn = lambda si, sj: 0.5 * (si + sj)
    nbr_energy_fn = smap.pair_neighbor_list(wrapped_energy_fn,
                                            metric_fn,
                                            sigma=sigma_fn)

    nvt_init_fn, nvt_step_fn = nvt_nose_hoover(nbr_energy_fn,
                                               shift_fn,
                                               dt,
                                               kT=kT,
                                               chain_length=3)

    def init_fn(key, position, sigma, nbrs=None):
        key, sim_key = random.split(key)
        nbrs = neighbor_fn(position, nbrs)  # pytype: disable=wrong-arg-count
        md_state = nvt_init_fn(sim_key, position, neighbor=nbrs, sigma=sigma)
        return SwapMCState(md_state, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def md_step_fn(i, state):
        # `i` is required by lax.fori_loop but unused.
        md, sigma, key, nbrs = dataclasses.unpack(state)
        md = nvt_step_fn(md, neighbor=nbrs, sigma=sigma)  # pytype: disable=wrong-keyword-args
        nbrs = neighbor_fn(md.position, nbrs)
        return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def swap_step_fn(i, state):
        # `i` is required by lax.fori_loop but unused.
        md, sigma, key, nbrs = dataclasses.unpack(state)

        N = md.position.shape[0]

        # Swap a random pair of particle radii.
        key, particle_key, accept_key = random.split(key, 3)
        ij = random.randint(particle_key, (2, ), jnp.array(0), jnp.array(N))
        new_sigma = sigma.at[ij].set([sigma[ij[1]], sigma[ij[0]]])

        # Collect neighborhoods around the two swapped particles.
        nbrs_ij = nbrs.idx[ij]
        R_ij = md.position[ij]
        R_neigh = md.position[nbrs_ij]

        sigma_ij = sigma[ij][:, None]
        sigma_neigh = sigma[nbrs_ij]

        new_sigma_ij = new_sigma[ij][:, None]
        new_sigma_neigh = new_sigma[nbrs_ij]

        dR = nbr_metric_fn(R_ij, R_neigh)

        # Neighbor entries >= N are padding and are masked out of both sums.
        valid = nbrs_ij < N

        # Compute the energy before the swap.
        energy = energy_fn(dR, sigma_fn(sigma_ij, sigma_neigh))
        energy = jnp.sum(energy * valid)

        # Compute the energy after the swap.
        new_energy = energy_fn(dR, sigma_fn(new_sigma_ij, new_sigma_neigh))
        new_energy = jnp.sum(new_energy * valid)

        # Accept or reject with a metropolis probability.
        p = random.uniform(accept_key, ())
        accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - energy) / kT))
        sigma = jnp.where(p < accept_prob, new_sigma, sigma)

        return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

    def block_fn(state):
        state = lax.fori_loop(0, md_steps, md_step_fn, state)
        state = lax.fori_loop(0, N_swap, swap_step_fn, state)
        return state

    return init_fn, block_fn