Example #1
def _unflatten_cell_buffer(arr: Array, cells_per_side: Array,
                           dim: int) -> Array:
    if (isinstance(cells_per_side, int) or isinstance(cells_per_side, float)
            or (util.is_array(cells_per_side) and not cells_per_side.shape)):
        cells_per_side = (int(cells_per_side), ) * dim
    elif util.is_array(cells_per_side) and len(cells_per_side.shape) == 1:
        cells_per_side = tuple([int(x) for x in cells_per_side[::-1]])
    elif util.is_array(cells_per_side) and len(cells_per_side.shape) == 2:
        cells_per_side = tuple([int(x) for x in cells_per_side[0][::-1]])
    else:
        raise ValueError(
            'cells_per_side must be a scalar, a 1d array, or a 2d array.')
    return jnp.reshape(arr, cells_per_side + (-1, ) + arr.shape[1:])
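
A minimal sketch (toy shapes, not from the source; it assumes util.is_array accepts jnp arrays, as in jax_md) of what this helper does: a flat buffer of cell_count * capacity rows is reshaped into a grid indexed by cell coordinates, with the grid dimensions reversed relative to cells_per_side:

import jax.numpy as jnp

# A 3 x 2 grid of cells (6 cells total), capacity 2, with 4 values per slot.
arr = jnp.arange(6 * 2 * 4.0).reshape(6 * 2, 4)
cells_per_side = jnp.array([2, 3])  # [cells_x, cells_y]
out = _unflatten_cell_buffer(arr, cells_per_side, 2)
print(out.shape)  # (3, 2, 2, 4): reversed grid dims, capacity, then data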
Example #2
def _is_variable_compatible_with_positions(R: Array) -> bool:
  return (util.is_array(R) and
          len(R.shape) == 2 and
          jnp.issubdtype(R.dtype, jnp.floating))
Example #3
def _get_neighborhood_matrix_params(format: partition.NeighborListFormat,
                                    idx: Array, params: Array,
                                    combinator: Callable) -> Array:
    if util.is_array(params):
        if len(params.shape) == 1:
            if partition.is_sparse(format):
                return space.map_bond(combinator)(params[idx[0]],
                                                  params[idx[1]])
            else:
                return combinator(params[:, None], params[idx])
        elif len(params.shape) == 2:

            def query(id_a, id_b):
                return params[id_a, id_b]

            if partition.is_sparse(format):
                return space.map_bond(query)(idx[0], idx[1])
            else:
                query = vmap(vmap(query, (None, 0)))
                return query(jnp.arange(idx.shape[0], dtype=jnp.int32), idx)
        elif len(params.shape) == 0:
            return params
        else:
            raise NotImplementedError()
    elif (isinstance(params, int) or isinstance(params, float)
          or jnp.issubdtype(params, jnp.integer)
          or jnp.issubdtype(params, jnp.floating)):
        return params
    else:
        raise NotImplementedError()
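
A small sketch (toy values, hypothetical parameter names) of the dense 1-d branch above: combinator(params[:, None], params[idx]) broadcasts an [N] parameter vector against an [N, K] neighbor index table to produce one parameter per (particle, neighbor) pair:

import jax.numpy as jnp

params = jnp.array([1.0, 2.0, 3.0])        # one parameter per particle
idx = jnp.array([[1, 2], [0, 2], [0, 1]])  # dense neighbor indices, [N, K]
combinator = lambda x, y: 0.5 * (x + y)    # the default averaging rule

# params[:, None] has shape [N, 1]; params[idx] has shape [N, K].
print(combinator(params[:, None], params[idx]))  # [N, K] pair parameters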
Example #4
def _get_bond_type_parameters(params: Array, bond_type: Array) -> Array:
    """Get parameters for interactions for bonds indexed by a bond-type."""
    # TODO(schsam): We should do better error checking here.
    assert util.is_array(bond_type)
    assert len(bond_type.shape) == 1

    if util.is_array(params):
        if len(params.shape) == 1:
            return params[bond_type]
        elif len(params.shape) == 0:
            return params
        else:
            raise ValueError(
                'Params must be a scalar or a 1d array if using a bond-type lookup.'
            )
    elif (isinstance(params, int) or isinstance(params, float)
          or jnp.issubdtype(params, jnp.integer)
          or jnp.issubdtype(params, jnp.floating)):
        return params
    raise NotImplementedError
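
A short sketch (toy values, hypothetical parameter names) of the 1-d lookup this function performs: indexing a per-bond-type parameter array by the bond-type array yields one parameter per bond:

import jax.numpy as jnp

spring_constants = jnp.array([10.0, 50.0])  # one value per bond type
bond_type = jnp.array([0, 0, 1, 0])         # type of each bond
print(spring_constants[bond_type])          # [10., 10., 50., 10.]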
Example #5
def _get_species_parameters(params: Array, species: Array) -> Array:
    """Get parameters for interactions between species pairs."""
    # TODO(schsam): We should do better error checking here.
    if util.is_array(params):
        if len(params.shape) == 2:
            return params[species]
        elif len(params.shape) == 0:
            return params
        else:
            raise ValueError(
                'Params must be a scalar or a 2d array if using a species lookup.'
            )
    return params
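
A brief sketch (toy values), assuming the species argument is a pair (i, j) as in the static-species branch of pair below, so that params[species] reads a single entry of the [max_species, max_species] table:

import jax.numpy as jnp

epsilon = jnp.array([[1.0, 0.5],
                     [0.5, 2.0]])  # [max_species, max_species] table
print(epsilon[(0, 1)])             # parameter for species pair (0, 1) -> 0.5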
Example #6
def _get_matrix_parameters(params: Array, combinator: Callable) -> Array:
    """Get an NxN parameter matrix from per-particle parameters."""
    if util.is_array(params):
        if len(params.shape) == 1:
            return combinator(params[:, jnp.newaxis], params[jnp.newaxis, :])
        elif len(params.shape) == 0 or len(params.shape) == 2:
            return params
        else:
            raise NotImplementedError
    elif (isinstance(params, int) or isinstance(params, float)
          or jnp.issubdtype(params, jnp.integer)
          or jnp.issubdtype(params, jnp.floating)):
        return params
    else:
        raise NotImplementedError
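
A small sketch (toy values, hypothetical parameter name) of the 1-d branch: broadcasting a per-particle vector against itself through a combinator yields the full NxN parameter matrix:

import jax.numpy as jnp

sigma = jnp.array([1.0, 2.0, 3.0])       # per-particle diameters
combinator = lambda a, b: 0.5 * (a + b)  # additive mixing rule
sigma_ij = combinator(sigma[:, jnp.newaxis], sigma[jnp.newaxis, :])
print(sigma_ij)                          # [N, N] pair diameters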
Example #7
def _neighborhood_kwargs_to_params(
        format: partition.NeighborListFormat, idx: Array, species: Array,
        kwargs: Dict[str,
                     Array], combinators: Dict[str,
                                               Callable]) -> Dict[str, Array]:
    out_dict = {}
    for k in kwargs:
        if species is None or (util.is_array(kwargs[k])
                               and kwargs[k].ndim == 1):
            combinator = combinators.get(k, lambda x, y: 0.5 * (x + y))
            out_dict[k] = _get_neighborhood_matrix_params(
                format, idx, kwargs[k], combinator)
        else:
            if k in combinators:
                raise ValueError(
                    'Custom combinators are not supported for '
                    'species-dependent parameters.')
            out_dict[k] = _get_neighborhood_species_params(
                format, idx, species, kwargs[k])
    return out_dict
Example #8
def _get_neighborhood_species_params(format: partition.NeighborListFormat,
                                     idx: Array, species: Array,
                                     params: Array) -> Array:
    """Get parameters for interactions between species pairs."""

    # TODO(schsam): We should do better error checking here.
    def lookup(species_a, species_b):
        return params[species_a, species_b]

    if util.is_array(params):
        if len(params.shape) == 2:
            if partition.is_sparse(format):
                return space.map_bond(lookup)(species[idx[0]], species[idx[1]])
            else:
                lookup = vmap(vmap(lookup, (None, 0)))
                return lookup(species, species[idx])
        elif len(params.shape) == 0:
            return params
        else:
            raise ValueError(
                'Params must be a scalar or a 2d array if using a species lookup.'
            )
    return params
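
A sketch (toy values) of the dense branch's double-vmap lookup: the inner vmap sweeps over a particle's neighbors and the outer vmap over particles, giving an [N, K] table of species-pair parameters:

import jax.numpy as jnp
from jax import vmap

epsilon = jnp.array([[1.0, 0.5],
                     [0.5, 2.0]])          # [max_species, max_species]
species = jnp.array([0, 0, 1])             # per-particle species
idx = jnp.array([[1, 2], [0, 2], [0, 1]])  # dense neighbor indices, [N, K]

lookup = vmap(vmap(lambda a, b: epsilon[a, b], (None, 0)))
print(lookup(species, species[idx]))       # [N, K] per-neighbor parameters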
Example #9
    def cell_list_fn(position: Array,
                     capacity_overflow_update: Optional[Tuple[
                         int, bool, Callable[..., CellList]]] = None,
                     extra_capacity: int = 0,
                     **kwargs) -> CellList:
        N = position.shape[0]
        dim = position.shape[1]

        if dim != 2 and dim != 3:
            # NOTE(schsam): Do we want to check this in compute_fn as well?
            raise ValueError(
                f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        if capacity_overflow_update is None:
            cell_capacity = _estimate_cell_capacity(position, box_size,
                                                    cell_size,
                                                    buffer_size_multiplier)
            cell_capacity += extra_capacity
            overflow = False
            update_fn = cell_list_fn
        else:
            cell_capacity, overflow, update_fn = capacity_overflow_update

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create cell list data.
        particle_id = lax.iota(i32, N)
        # NOTE(schsam): We use the convention that successfully copied
        # particles have their true id, whereas empty slots have id = N.
        # Then when we copy data back from the grid, copy it to an array of shape
        # [N + 1, output_dimension] and then truncate it to an array of shape
        # [N, output_dimension] which ignores the empty slots.
        cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                                  dtype=position.dtype)
        cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

        # It might be worth adding an occupied mask. However, that will involve
        # more compute since often we will do a mask for species that will include
        # an occupancy test. It seems easier to design around this empty_data_value
        # for now and revisit the issue if it comes up later.
        empty_kwarg_value = 10**5
        cell_kwargs = {}
        #  pytype: disable=attribute-error
        for k, v in kwargs.items():
            if not util.is_array(v):
                raise ValueError(
                    (f'Data must be specified as an ndarray. Found "{k}" '
                     f'with type {type(v)}.'))
            if v.shape[0] != position.shape[0]:
                raise ValueError(
                    ('Data must be specified per-particle (an ndarray '
                     f'with shape ({N}, ...)). Found "{k}" with '
                     f'shape {v.shape}.'))
            kwarg_shape = v.shape[1:] if v.ndim > 1 else (1, )
            cell_kwargs[k] = empty_kwarg_value * jnp.ones(
                (cell_count * cell_capacity, ) + kwarg_shape, v.dtype)
        #  pytype: enable=attribute-error
        indices = jnp.array(position / cell_size, dtype=i32)
        hashes = jnp.sum(indices * hash_multipliers, axis=1)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where
        # grid_id is a flat list that repeats 0, .., cell_capacity. So long as
        # there are fewer than cell_capacity particles per cell, each particle is
        # guaranteed to get a cell id that is unique.
        sort_map = jnp.argsort(hashes)
        sorted_position = position[sort_map]
        sorted_hash = hashes[sort_map]
        sorted_id = particle_id[sort_map]

        sorted_kwargs = {}
        for k, v in kwargs.items():
            sorted_kwargs[k] = v[sort_map]

        sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
        sorted_id = jnp.reshape(sorted_id, (N, 1))
        cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
        cell_position = _unflatten_cell_buffer(cell_position, cells_per_side,
                                               dim)
        cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

        for k, v in sorted_kwargs.items():
            if v.ndim == 1:
                v = jnp.reshape(v, v.shape + (1, ))
            cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
            cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k],
                                                    cells_per_side, dim)

        occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
        max_occupancy = jnp.max(occupancy)
        overflow = overflow | (max_occupancy >= cell_capacity)

        return CellList(cell_position, cell_id, cell_kwargs, overflow,
                        cell_capacity, update_fn)  # pytype: disable=wrong-arg-count
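
A self-contained sketch (toy sizes) of the sort-and-scatter trick described in the comments above: sorting by cell hash and assigning cell id = hash * capacity + slot gives every particle a unique buffer location, provided no cell exceeds its capacity:

import jax.numpy as jnp

# 5 particles hashed into 3 cells with capacity 2.
hashes = jnp.array([2, 0, 1, 0, 2])
cell_count, capacity, N = 3, 2, 5

sort_map = jnp.argsort(hashes)             # group particles by cell
sorted_hash = hashes[sort_map]
slots = jnp.mod(jnp.arange(N), capacity)   # repeating 0, ..., capacity - 1
cell_ids = sorted_hash * capacity + slots  # unique while occupancy <= capacity

buffer = jnp.full((cell_count * capacity,), N)  # id = N marks an empty slot
buffer = buffer.at[cell_ids].set(sort_map)      # one scatter fills every cell
print(buffer.reshape(cell_count, capacity))     # [[1 3] [2 5] [4 0]]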
Example #10
def cell_list(box_size: Box,
              minimum_cell_size: float,
              buffer_size_multiplier: float = 1.25) -> CellListFns:
    r"""Returns a function that partitions point data spatially.

  Given a set of points {x_i \in R^d} with associated data {k_i \in R^m} it is
  often useful to partition the points / data spatially. A simple partitioning
  that can be implemented efficiently within XLA is a dense partition into a
  uniform grid called a cell list.

  Since XLA requires that shapes be statically specified inside of a JIT block,
  the cell list code can operate in two modes: allocation and update.

  Allocation creates a new cell list that uses a set of input positions to
  estimate the capacity of the cell list. This capacity can be adjusted by
  setting the `buffer_size_multiplier` or setting the `extra_capacity`.
  Allocation cannot be JIT compiled.

  Updating takes a previously allocated cell list and places a new set of
  particles in the cells. Updating cannot resize the cell list and is therefore
  compatible with JIT. However, if the configuration has changed substantially
  it is possible that the existing cell list won't be large enough to
  accommodate all of the particles. In this case the `did_buffer_overflow` bit
  will be set to True.

  Args:
    box_size: A float or an ndarray of shape [spatial_dimension] specifying the
      size of the system. Note, this code is written for the case where the
      boundaries are periodic. If this is not the case, then the current code
      will be slightly less efficient.
    minimum_cell_size: A float specifying the minimum side length of each cell.
      Cells are enlarged so that they exactly fill the box.
    buffer_size_multiplier: A floating point multiplier that multiplies the
      estimated cell capacity to allow for fluctuations in the maximum cell
      occupancy.
  Returns:
    A CellListFns object that contains two methods, one to allocate the cell
    list and one to update the cell list. The update function can be called
    with either a cell list from which the capacity can be inferred or with
    an explicit integer denoting the capacity. Note that an existing cell list
    can also be updated by calling `cell_list.update(position)`.
  """

    if util.is_array(box_size):
        box_size = onp.array(box_size)
        if len(box_size.shape) == 1:
            box_size = jnp.reshape(box_size, (1, -1))

    if util.is_array(minimum_cell_size):
        minimum_cell_size = onp.array(minimum_cell_size)

    def cell_list_fn(position: Array,
                     capacity_overflow_update: Optional[Tuple[
                         int, bool, Callable[..., CellList]]] = None,
                     extra_capacity: int = 0,
                     **kwargs) -> CellList:
        N = position.shape[0]
        dim = position.shape[1]

        if dim != 2 and dim != 3:
            # NOTE(schsam): Do we want to check this in compute_fn as well?
            raise ValueError(
                f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        if capacity_overflow_update is None:
            cell_capacity = _estimate_cell_capacity(position, box_size,
                                                    cell_size,
                                                    buffer_size_multiplier)
            cell_capacity += extra_capacity
            overflow = False
            update_fn = cell_list_fn
        else:
            cell_capacity, overflow, update_fn = capacity_overflow_update

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create cell list data.
        particle_id = lax.iota(i32, N)
        # NOTE(schsam): We use the convention that successfully copied
        # particles have their true id, whereas empty slots have id = N.
        # Then when we copy data back from the grid, copy it to an array of shape
        # [N + 1, output_dimension] and then truncate it to an array of shape
        # [N, output_dimension] which ignores the empty slots.
        cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                                  dtype=position.dtype)
        cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

        # It might be worth adding an occupied mask. However, that will involve
        # more compute since often we will do a mask for species that will include
        # an occupancy test. It seems easier to design around this empty_data_value
        # for now and revisit the issue if it comes up later.
        empty_kwarg_value = 10**5
        cell_kwargs = {}
        #  pytype: disable=attribute-error
        for k, v in kwargs.items():
            if not util.is_array(v):
                raise ValueError(
                    (f'Data must be specified as an ndarray. Found "{k}" '
                     f'with type {type(v)}.'))
            if v.shape[0] != position.shape[0]:
                raise ValueError(
                    ('Data must be specified per-particle (an ndarray '
                     f'with shape ({N}, ...)). Found "{k}" with '
                     f'shape {v.shape}.'))
            kwarg_shape = v.shape[1:] if v.ndim > 1 else (1, )
            cell_kwargs[k] = empty_kwarg_value * jnp.ones(
                (cell_count * cell_capacity, ) + kwarg_shape, v.dtype)
        #  pytype: enable=attribute-error
        indices = jnp.array(position / cell_size, dtype=i32)
        hashes = jnp.sum(indices * hash_multipliers, axis=1)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where
        # grid_id is a flat list that repeats 0, .., cell_capacity. So long as
        # there are fewer than cell_capacity particles per cell, each particle is
        # guaranteed to get a cell id that is unique.
        sort_map = jnp.argsort(hashes)
        sorted_position = position[sort_map]
        sorted_hash = hashes[sort_map]
        sorted_id = particle_id[sort_map]

        sorted_kwargs = {}
        for k, v in kwargs.items():
            sorted_kwargs[k] = v[sort_map]

        sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
        sorted_id = jnp.reshape(sorted_id, (N, 1))
        cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
        cell_position = _unflatten_cell_buffer(cell_position, cells_per_side,
                                               dim)
        cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

        for k, v in sorted_kwargs.items():
            if v.ndim == 1:
                v = jnp.reshape(v, v.shape + (1, ))
            cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
            cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k],
                                                    cells_per_side, dim)

        occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
        max_occupancy = jnp.max(occupancy)
        overflow = overflow | (max_occupancy >= cell_capacity)

        return CellList(cell_position, cell_id, cell_kwargs, overflow,
                        cell_capacity, update_fn)  # pytype: disable=wrong-arg-count

    def allocate_fn(position: Array,
                    extra_capacity: int = 0,
                    **kwargs) -> CellList:
        return cell_list_fn(position, extra_capacity=extra_capacity, **kwargs)

    def update_fn(position: Array, cl_or_capacity: Union[CellList, int],
                  **kwargs) -> CellList:
        if isinstance(cl_or_capacity, int):
            capacity = int(cl_or_capacity)
            return cell_list_fn(position, (capacity, False, cell_list_fn),
                                **kwargs)
        cl = cl_or_capacity
        cl_data = (cl.cell_capacity, cl.did_buffer_overflow, cl.update_fn)
        return cell_list_fn(position, cl_data, **kwargs)

    return CellListFns(allocate_fn, update_fn)  # pytype: disable=wrong-arg-count
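
A hedged usage sketch (toy values) of the allocate/update pattern the docstring describes: allocation runs outside of jit to fix the capacity, updates can then be jit-compiled, and did_buffer_overflow signals when reallocation is needed:

import jax.numpy as jnp
from jax import jit, random

box_size = 10.0
position = random.uniform(random.PRNGKey(0), (128, 3)) * box_size

cl_fns = cell_list(box_size, minimum_cell_size=1.5)
cl = cl_fns.allocate(position)          # estimates capacity; not jit-able

@jit
def update(position, cl):
    return cl_fns.update(position, cl)  # fixed capacity; jit-compatible

cl = update(position, cl)
if cl.did_buffer_overflow:              # reallocate with extra headroom
    cl = cl_fns.allocate(position, extra_capacity=8)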
Example #11
def triplet(fn: Callable[..., Array],
            displacement_or_metric: DisplacementOrMetricFn,
            species: Optional[Array] = None,
            reduce_axis: Optional[Tuple[int, ...]] = None,
            keepdims: bool = False,
            ignore_unused_parameters: bool = False,
            **kwargs) -> Callable[..., Array]:
    """Promotes a function that acts on triples of particles to one on a system.

  Many empirical potentials in jax_md include three-body angular terms (e.g.
  Stillinger Weber). This utility function simplifies the loss computation
  in such cases by converting a function that takes in two pairwise displacements
  or distances to one that only requires the system as input.

  Args:
    fn: A function that takes an ndarray of two distances or displacements
        from a central atom, both of shape [n, m] or [n, m, d_in] respectively,
        as well as kwargs specifying parameters for the function.
    displacement_or_metric: A function that takes two ndarrays of positions of
        [spatial_dimensions] and [spatial_dimensions] respectively and
        returns an ndarray of distances or displacements of shape [] or [d_in]
        respectively. The metric can optionally take a floating point time as a
        third argument.
    species: A list of species for the different particles. This should either
      be None (in which case it is assumed that all the particles have the same
      species), an integer ndarray of shape [n] with species data, or an
      integer in which case the species data will be specified dynamically with
      `species` giving the maximum number of types of particles. Note that
      dynamic species specification is less efficient, because we cannot
      specialize shape information.
    reduce_axis: A list of axes to reduce over. This is supplied to jnp.sum
        and the same convention is used.
    keepdims: A boolean specifying whether the empty dimensions should be kept
        upon reduction. This is supplied to jnp.sum and so the same convention
        is used.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.triplet(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
        where no species information is provided, these should either be 1)
        a scalar, 2) an ndarray of shape [n] based on the central atom,
        3) an ndarray of shape [n, n, n] defining triplet interactions.
        If species information is provided, then the parameters should
        be specified as either 1) a scalar,  2) an ndarray of shape
        [max_species], 3) an ndarray of shape [max_species, max_species,
        max_species] defining triplet interactions.

  Returns:
    A function fn_mapped.

    If species is None or statically specified, then fn_mapped takes as
    arguments an ndarray of positions of shape [n, spatial_dimension].

    If species is dynamic then fn_mapped takes as input an ndarray of shape
    [n, spatial_dimension], an integer ndarray of species of shape [n], and
    an integer specifying the maximum species.

    The mapped function can also optionally take keyword arguments that get
    threaded through the metric.
  """
    merge_dicts = partial(util.merge_dicts,
                          ignore_unused_parameters=ignore_unused_parameters)

    def extract_parameters_by_dim(kwargs, dim: Union[int, List[int]] = 0):
        """Extract parameters from a dictionary via dimension."""
        if isinstance(dim, int):
            dim = [dim]
        return {
            name: value
            for name, value in kwargs.items() if value.ndim in dim
        }

    if species is None:

        def fn_mapped(R, **dynamic_kwargs) -> Array:
            d = space.map_product(
                partial(displacement_or_metric, **dynamic_kwargs))
            _kwargs = merge_dicts(kwargs, dynamic_kwargs)
            _kwargs = _kwargs_to_parameters(species, _kwargs, {})
            dR = d(R, R)
            compute_triplet = partial(fn, **_kwargs)
            output = vmap(vmap(vmap(compute_triplet, (None, 0)), (0, None)),
                          0)(dR, dR)
            return high_precision_sum(
                output, axis=reduce_axis, keepdims=keepdims) / 2.
    elif util.is_array(species):

        def fn_mapped(R, **dynamic_kwargs):
            d = partial(displacement_or_metric, **dynamic_kwargs)
            idx = onp.tile(onp.arange(R.shape[0]), [R.shape[0], 1])
            dR = vmap(vmap(d, (None, 0)))(R, R[idx])

            _kwargs = merge_dicts(kwargs, dynamic_kwargs)

            mapped_args = extract_parameters_by_dim(_kwargs, [3])
            mapped_args = {
                arg_name: arg_value[species]
                for arg_name, arg_value in mapped_args.items()
            }
            # While we support 2 dimensional inputs, these often make less sense
            # as the parameters do not depend on the central atom
            unmapped_args = extract_parameters_by_dim(_kwargs, [0])

            if extract_parameters_by_dim(_kwargs, [1, 2]):
                raise ValueError('Parameters of dimension 1 or 2 are not '
                                 'well defined for triplets.')

            def compute_triplet(dR, mapped_args, unmapped_args):
                paired_args = extract_parameters_by_dim(mapped_args, 2)
                paired_args.update(extract_parameters_by_dim(unmapped_args, 2))

                unpaired_args = extract_parameters_by_dim(mapped_args, 0)
                unpaired_args.update(
                    extract_parameters_by_dim(unmapped_args, 0))

                output_fn = lambda dR1, dR2, paired_args: fn(
                    dR1, dR2, **unpaired_args, **paired_args)
                neighbor_args = _neighborhood_kwargs_to_params(
                    partition.Dense, idx, species, paired_args, {})
                output_fn = vmap(vmap(output_fn, (None, 0, 0)), (0, None, 0))
                return output_fn(dR, dR, neighbor_args)

            output_fn = partial(compute_triplet, unmapped_args=unmapped_args)
            output = vmap(output_fn)(dR, mapped_args)
            return high_precision_sum(
                output, axis=reduce_axis, keepdims=keepdims) / 2.
    elif isinstance(species, int):
        raise NotImplementedError
    else:
        raise ValueError(
            'Species must be None, an ndarray, or an integer. Found {}.'.format(
                species))
    return fn_mapped
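
A hedged usage sketch of triplet with species=None; the angular function here is a toy, not a physical potential, and space.free supplies the displacement function:

import jax.numpy as jnp
from jax_md import space, smap

displacement, _ = space.free()

def angular(dR1, dR2, epsilon=1.0):
    # Toy three-body term built from the two displacements off the
    # central atom.
    return epsilon * jnp.sum(dR1 * dR2)

energy_fn = smap.triplet(angular, displacement, epsilon=0.5)
R = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print(energy_fn(R))  # scalar energy summed over all triplets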
Example #12
def pair(fn: Callable[..., Array],
         displacement_or_metric: DisplacementOrMetricFn,
         species: Optional[Array] = None,
         reduce_axis: Optional[Tuple[int, ...]] = None,
         keepdims: bool = False,
         ignore_unused_parameters: bool = False,
         **kwargs) -> Callable[..., Array]:
    """Promotes a function that acts on a pair of particles to one on a system.

  Args:
    fn: A function that takes an ndarray of pairwise distances or displacements
      of shape [n, m] or [n, m, d_in] respectively as well as kwargs specifying
      parameters for the function. fn returns an ndarray of evaluations of shape
      [n, m, d_out].
    displacement_or_metric: A function that takes two ndarrays of positions of
      [spatial_dimension] and [spatial_dimension] respectively and returns
      an ndarray of distances or displacements of shape [] or [d_in]
      respectively. The metric can optionally take a floating point time as a
      third argument.
    species: A list of species for the different particles. This should either
      be None (in which case it is assumed that all the particles have the same
      species), an integer ndarray of shape [n] with species data, or an
      integer in which case the species data will be specified dynamically with
      `species` giving the maximum number of types of particles. Note that
      dynamic species specification is less efficient, because we cannot
      specialize shape information.
    reduce_axis: A list of axes to reduce over. This is supplied to jnp.sum and
      so the same convention is used.
    keepdims: A boolean specifying whether the empty dimensions should be kept
      upon reduction. This is supplied to jnp.sum and so the same convention is
      used.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.pair(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either 1) a
      scalar, 2) an ndarray of shape [n], 3) an ndarray of shape [n, n],
      4) a binary function that determines how per-particle parameters are to
      be combined, or 5) a binary function as well as a default set of
      parameters as in 2). If unspecified then this is taken to be the average
      of the two per-particle parameters. If species information is provided
      then the parameters should be specified as either 1) a scalar or 2) an
      ndarray of shape [max_species, max_species].

  Returns:
    A function fn_mapped.

    If species is None or statically specified then fn_mapped takes as arguments
    an ndarray of positions of shape [n, spatial_dimension].

    If species is dynamic then fn_mapped takes as input an ndarray of shape
    [n, spatial_dimension], an integer ndarray of species of shape [n], and an
    integer specifying the maximum species.

    The mapped function can also optionally take keyword arguments that get
    threaded through the metric.
  """

    # Each application of vmap adds a single batch dimension. For computations
    # over all pairs of particles, we would like to promote the metric function
    # from one that computes the displacement / distance between two vectors to
    # one that acts over the cartesian product of two sets of vectors. This is
    # equivalent to two applications of vmap adding one batch dimension for the
    # first set and then one for the second.

    kwargs, param_combinators = _split_params_and_combinators(kwargs)

    merge_dicts = partial(util.merge_dicts,
                          ignore_unused_parameters=ignore_unused_parameters)

    if species is None:

        def fn_mapped(R: Array, **dynamic_kwargs) -> Array:
            d = space.map_product(
                partial(displacement_or_metric, **dynamic_kwargs))
            _kwargs = merge_dicts(kwargs, dynamic_kwargs)
            _kwargs = _kwargs_to_parameters(None, _kwargs, param_combinators)
            dr = d(R, R)
            # NOTE(schsam): Currently we place a diagonal mask no matter what function
            # we are mapping. Should this be an option?
            return high_precision_sum(_diagonal_mask(fn(dr, **_kwargs)),
                                      axis=reduce_axis,
                                      keepdims=keepdims) * f32(0.5)
    elif util.is_array(species):
        species = onp.array(species)
        _check_species_dtype(species)
        species_count = int(onp.max(species))
        if reduce_axis is not None or keepdims:
            # TODO(schsam): Support reduce_axis with static species.
            raise ValueError('reduce_axis and keepdims are not supported '
                             'with static species.')

        def fn_mapped(R, **dynamic_kwargs):
            U = f32(0.0)
            d = space.map_product(
                partial(displacement_or_metric, **dynamic_kwargs))
            for i in range(species_count + 1):
                for j in range(i, species_count + 1):
                    _kwargs = merge_dicts(kwargs, dynamic_kwargs)
                    s_kwargs = _kwargs_to_parameters((i, j), _kwargs,
                                                     param_combinators)
                    Ra = R[species == i]
                    Rb = R[species == j]
                    dr = d(Ra, Rb)
                    if j == i:
                        dU = high_precision_sum(
                            _diagonal_mask(fn(dr, **s_kwargs)))
                        U = U + f32(0.5) * dU
                    else:
                        dU = high_precision_sum(fn(dr, **s_kwargs))
                        U = U + dU
            return U
    elif isinstance(species, int):
        species_count = species

        def fn_mapped(R, species, **dynamic_kwargs):
            _check_species_dtype(species)
            U = f32(0.0)
            N = R.shape[0]
            d = space.map_product(
                partial(displacement_or_metric, **dynamic_kwargs))
            _kwargs = merge_dicts(kwargs, dynamic_kwargs)
            dr = d(R, R)
            for i in range(species_count):
                for j in range(species_count):
                    s_kwargs = _kwargs_to_parameters((i, j), _kwargs,
                                                     param_combinators)
                    mask_a = jnp.array(jnp.reshape(species == i, (N, )),
                                       dtype=R.dtype)
                    mask_b = jnp.array(jnp.reshape(species == j, (N, )),
                                       dtype=R.dtype)
                    mask = mask_a[:, jnp.newaxis] * mask_b[jnp.newaxis, :]
                    if i == j:
                        mask = mask * _diagonal_mask(mask)
                    dU = mask * fn(dr, **s_kwargs)
                    U = U + high_precision_sum(
                        dU, axis=reduce_axis, keepdims=keepdims)
            return U / f32(2.0)
    else:
        raise ValueError(
            'Species must be None, an ndarray, or an integer. Found {}.'.
            format(species))
    return fn_mapped
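
A hedged usage sketch of pair with species=None; the harmonic pair energy is a toy example and space.metric converts the displacement function into a distance function:

import jax.numpy as jnp
from jax_md import space, smap

displacement, _ = space.free()
metric = space.metric(displacement)

def harmonic(dr, sigma=1.0):
    # Toy pair energy: harmonic repulsion inside a cutoff at sigma.
    return jnp.where(dr < sigma, 0.5 * (1.0 - dr / sigma) ** 2, 0.0)

energy_fn = smap.pair(harmonic, metric, sigma=1.0)
R = jnp.array([[0.0, 0.0], [0.5, 0.0], [2.0, 0.0]])
print(energy_fn(R))  # sums over pairs with the diagonal masked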
Example #13
  def build_cells(R: Array, extra_capacity: int=0, **kwargs) -> CellList:
    N = R.shape[0]
    dim = R.shape[1]

    _cell_capacity = cell_capacity + extra_capacity

    if dim != 2 and dim != 3:
      # NOTE(schsam): Do we want to check this in compute_fn as well?
      raise ValueError(
          'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    neighborhood_tile_count = 3 ** dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(jnp.int64, N)
    # NOTE(schsam): We use the convention that successfully copied particles
    # have their true id, whereas empty slots have id = N.
    # Then when we copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    mask_id = jnp.ones((N,), jnp.int64) * N
    cell_R = jnp.zeros((cell_count * _cell_capacity, dim), dtype=R.dtype)
    cell_id = N * jnp.ones((cell_count * _cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10 ** 5
    cell_kwargs = {}
    for k, v in kwargs.items():
      if not util.is_array(v):
        raise ValueError((
          'Data must be specified as an ndarray. Found "{}" with '
          'type {}.'.format(k, type(v))))
      if v.shape[0] != R.shape[0]:
        raise ValueError(
          ('Data must be specified per-particle (an ndarray with shape '
           '(R.shape[0], ...)). Found "{}" with shape {}'.format(k, v.shape)))
      kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
      cell_kwargs[k] = empty_kwarg_value * jnp.ones(
        (cell_count * _cell_capacity,) + kwarg_shape, v.dtype)

    indices = jnp.array(R / cell_size, dtype=i32)
    hashes = jnp.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, .., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = jnp.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_kwargs = {}
    for k, v in kwargs.items():
      sorted_kwargs[k] = v[sort_map]

    sorted_cell_id = jnp.mod(lax.iota(jnp.int64, N), _cell_capacity)
    sorted_cell_id = sorted_hash * _cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_id = jnp.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(
        cell_id, sorted_cell_id, sorted_id)
    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
      if v.ndim == 1:
        v = jnp.reshape(v, v.shape + (1,))
      cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id, v)
      cell_kwargs[k] = _unflatten_cell_buffer(
        cell_kwargs[k], cells_per_side, dim)

    return CellList(cell_R, cell_id, cell_kwargs)  # pytype: disable=wrong-arg-count
Example #14
def cell_list(box_size: Box,
              minimum_cell_size: float,
              cell_capacity_or_example_R: Union[int, Array],
              buffer_size_multiplier: float=1.1
              ) -> Callable[[Array], CellList]:
  r"""Returns a function that partitions point data spatially.

  Given a set of points {x_i \in R^d} with associated data {k_i \in R^m} it is
  often useful to partition the points / data spatially. A simple partitioning
  that can be implemented efficiently within XLA is a dense partition into a
  uniform grid called a cell list.

  Since XLA requires that shapes be statically specified, we allocate fixed
  sized buffers for each cell. The size of this buffer can either be specified
  manually or it can be estimated automatically from a set of positions. Note,
  if the distribution of points changes significantly, it is likely that the
  buffer sizes will have to be adjusted.

  This partitioning will likely form the groundwork for parallelizing
  simulations over different accelerators.

  Args:
    box_size: A float or an ndarray of shape [spatial_dimension] specifying the
      size of the system. Note, this code is written for the case where the
      boundaries are periodic. If this is not the case, then the current code
      will be slightly less efficient.
    minimum_cell_size: A float specifying the minimum side length of each cell.
      Cells are enlarged so that they exactly fill the box.
    cell_capacity_or_example_R: Either an integer specifying the maximum
      number of particles that can be stored in each cell or an ndarray of
      positions of shape [particle_count, spatial_dimension] that is used to
      estimate the cell_capacity.
    buffer_size_multiplier: A floating point multiplier that multiplies the
      estimated cell capacity to allow for fluctuations in the maximum cell
      occupancy.
  Returns:
    A function `cell_list_fn(R, **kwargs)` that partitions positions, `R`, and
    side data specified by kwargs into a cell list. Returns a CellList
    containing the partition.
  """

  if util.is_array(box_size):
    box_size = onp.array(box_size)
    if len(box_size.shape) == 1:
      box_size = jnp.reshape(box_size, (1, -1))

  if util.is_array(minimum_cell_size):
    minimum_cell_size = onp.array(minimum_cell_size)

  cell_capacity = cell_capacity_or_example_R
  if _is_variable_compatible_with_positions(cell_capacity):
    cell_capacity = _estimate_cell_capacity(
      cell_capacity, box_size, minimum_cell_size, buffer_size_multiplier)
  elif not isinstance(cell_capacity, int):
    msg = (
        'cell_capacity_or_example_R must either be an integer '
        'specifying the cell capacity or a set of positions that will be used '
        'to estimate a cell capacity. Found {}.'.format(type(cell_capacity))
        )
    raise ValueError(msg)

  def build_cells(R: Array, extra_capacity: int=0, **kwargs) -> CellList:
    N = R.shape[0]
    dim = R.shape[1]

    _cell_capacity = cell_capacity + extra_capacity

    if dim != 2 and dim != 3:
      # NOTE(schsam): Do we want to check this in compute_fn as well?
      raise ValueError(
          'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    neighborhood_tile_count = 3 ** dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(jnp.int64, N)
    # NOTE(schsam): We use the convention that successfully copied particles
    # have their true id, whereas empty slots have id = N.
    # Then when we copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    mask_id = jnp.ones((N,), jnp.int64) * N
    cell_R = jnp.zeros((cell_count * _cell_capacity, dim), dtype=R.dtype)
    cell_id = N * jnp.ones((cell_count * _cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10 ** 5
    cell_kwargs = {}
    for k, v in kwargs.items():
      if not util.is_array(v):
        raise ValueError((
          'Data must be specified as an ndarray. Found "{}" with '
          'type {}.'.format(k, type(v))))
      if v.shape[0] != R.shape[0]:
        raise ValueError(
          ('Data must be specified per-particle (an ndarray with shape '
           '(R.shape[0], ...)). Found "{}" with shape {}'.format(k, v.shape)))
      kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
      cell_kwargs[k] = empty_kwarg_value * jnp.ones(
        (cell_count * _cell_capacity,) + kwarg_shape, v.dtype)

    indices = jnp.array(R / cell_size, dtype=i32)
    hashes = jnp.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, .., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = jnp.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_kwargs = {}
    for k, v in kwargs.items():
      sorted_kwargs[k] = v[sort_map]

    sorted_cell_id = jnp.mod(lax.iota(jnp.int64, N), _cell_capacity)
    sorted_cell_id = sorted_hash * _cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_id = jnp.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(
        cell_id, sorted_cell_id, sorted_id)
    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
      if v.ndim == 1:
        v = jnp.reshape(v, v.shape + (1,))
      cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id, v)
      cell_kwargs[k] = _unflatten_cell_buffer(
        cell_kwargs[k], cells_per_side, dim)

    return CellList(cell_R, cell_id, cell_kwargs)  # pytype: disable=wrong-arg-count
  return build_cells
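
A hedged sketch (toy values) of this older API: the cell capacity is estimated from an example configuration passed as cell_capacity_or_example_R, after which the returned build_cells function bins positions into the grid:

import jax.numpy as jnp
from jax import random

box_size = 10.0
example_R = random.uniform(random.PRNGKey(0), (128, 2)) * box_size

build_cells = cell_list(box_size, minimum_cell_size=1.0,
                        cell_capacity_or_example_R=example_R)
cl = build_cells(example_R)  # a CellList of gridded positions and ids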