Ejemplo n.º 1
0
class CellListFns:
    """Static struct holding the `allocate` and `update` callables of a cell list.

    Iterating over an instance yields `allocate` first and `update` second, so
    it can be unpacked as `allocate, update = cell_list_fns`.

    Attributes:
      allocate: Callable returning a `CellList`; its argument list is left
        unconstrained by the annotation (`Callable[..., CellList]`).
      update: Callable taking an array and either a `CellList` or an `int`,
        and returning a `CellList`.
    """
    allocate: Callable[..., CellList] = dataclasses.static_field()
    update: Callable[
        [Array, Union[CellList, int]], CellList
    ] = dataclasses.static_field()

    def __iter__(self):
        # Produces exactly the two callables, in declaration order, so tuple
        # unpacking into (allocate, update) keeps working.
        yield self.allocate
        yield self.update
Ejemplo n.º 2
0
class Disk:
    """Disk geometry elements.

    Args:
      position: An array of shape `(steps, count, dim)` or `(count, dim)`
        specifying possibly time-varying positions. Here `dim` is the spatial
        dimension.
      diameter: A value of shape `(steps, count)`, `(count,)`, or `()`
        specifying possibly time-varying / per-disk diameters. Defaults to
        `1.0`. It is stored on the instance as `size`.
      color: An array of shape `(steps, count, 3)`, `(count, 3)`, or `(3,)`
        specifying possibly time-varying / per-disk RGB colors. When `None`,
        defaults to the `(3,)` array `[0.8, 0.8, 1.0]`.

    Attributes:
      position: The `position` argument.
      size: The `diameter` argument.
      color: The `color` argument, or the default color when it was `None`.
      count: The number of disks, read from `position.shape[-2]`.
    """
    position: jnp.ndarray
    size: jnp.ndarray
    color: jnp.ndarray
    count: int = dataclasses.static_field()

    def __init__(self, position, diameter=1.0, color=None):
        if color is None:
            color = jnp.array([0.8, 0.8, 1.0])

        # object.__setattr__ bypasses any __setattr__ override on the class
        # (presumably a frozen dataclass -- confirm with the decorator).
        object.__setattr__(self, 'position', position)
        object.__setattr__(self, 'size', diameter)
        object.__setattr__(self, 'color', color)
        # The disk axis is second-to-last in both the `(count, dim)` and the
        # `(steps, count, dim)` layouts.
        object.__setattr__(self, 'count', position.shape[-2])

    def __repr__(self):
        return 'Disk'
Ejemplo n.º 3
0
class NeighborListFns:
    """A struct containing functions to allocate and update neighbor lists.

    Iterating over an instance yields `allocate` and then `update`, so it can
    be unpacked as `allocate, update = neighbor_fn`.

    Attributes:
      allocate: A function to allocate a new neighbor list. This function
        cannot be compiled, since it uses the values of positions to infer the
        shapes.
      update: A function to update a neighbor list given a new set of
        positions and a previously allocated neighbor list.
    """
    allocate: Callable[..., NeighborList] = dataclasses.static_field()
    update: Callable[[Array, NeighborList],
                     NeighborList] = dataclasses.static_field()

    def __call__(self,
                 position: Array,
                 neighbors: Optional[NeighborList] = None,
                 extra_capacity: int = 0,
                 **kwargs) -> NeighborList:
        """A function for backward compatibility with previous neighbor lists.

        Logs a deprecation warning on every call, then dispatches to
        `allocate` or `update`.

        Args:
          position: An `(N, dim)` array of particle positions.
          neighbors: An optional neighbor list object. If it is provided then
            the function updates the neighbor list, otherwise it allocates a
            new neighbor list.
          extra_capacity: Extra capacity to add if allocating the neighbor
            list. Ignored when `neighbors` is provided.
          **kwargs: Forwarded unchanged to `allocate` or `update`.
        Returns:
          A neighbor list object.
        """
        logging.warning(
            'Using a deprecated code path to create / update neighbor '
            'lists. It will be removed in a later version of JAX MD. '
            'Using `neighbor_fn.allocate` and `neighbor_fn.update` '
            'is preferred.')
        if neighbors is None:
            return self.allocate(position, extra_capacity, **kwargs)
        # `extra_capacity` only applies to allocation; it is not passed here.
        return self.update(position, neighbors, **kwargs)

    def __iter__(self):
        return iter((self.allocate, self.update))
Ejemplo n.º 4
0
class CellList:
    """Spatial partition of a system held as a cell list.

    Construction and specification are described in `cell_list(...)`. All
    buffers share a common leading shape `S`, where
      * `S = [cell_count_x, cell_count_y, cell_capacity]` in two dimensions,
      * `S = [cell_count_x, cell_count_y, cell_count_z, cell_capacity]` in
        three dimensions.
    Every cell is assumed to have the same capacity.

    Attributes:
      position_buffer: Floating point positions with shape
        `S + [spatial_dimension]`.
      id_buffer: int32 particle ids of shape `S`. Empty slots hold the id `N`,
        where `N` is the number of particles in the system.
      kwarg_buffers: Dictionary of arrays of shape `S + [...]` carrying side
        data placed into the cell list.
      did_buffer_overflow: Boolean recording whether the cell list exceeded
        its maximum allocated capacity.
      cell_capacity: Static integer giving the maximum capacity of each cell.
      update_fn: Static function that updates the cell list at a fixed
        capacity.
    """
    position_buffer: Array
    id_buffer: Array
    kwarg_buffers: Dict[str, Array]
    did_buffer_overflow: Array
    cell_capacity: int = dataclasses.static_field()
    update_fn: Callable[..., 'CellList'] = dataclasses.static_field()

    def update(self, position: Array, **kwargs) -> 'CellList':
        """Delegate to `update_fn` with this cell list's static data.

        `update_fn` is called with `position`, the tuple
        `(cell_capacity, did_buffer_overflow, update_fn)`, and any extra
        keyword arguments; its result is returned.
        """
        return self.update_fn(
            position,
            (self.cell_capacity, self.did_buffer_overflow, self.update_fn),
            **kwargs)
Ejemplo n.º 5
0
class NeighborList(object):
    """A struct containing the state of a Neighbor List.

    Attributes:
      idx: For an N particle system this is an `[N, max_occupancy]` array of
        integers such that `idx[i, j]` is the jth neighbor of particle i.
      reference_position: The positions of particles when the neighbor list
        was constructed. This is used to decide whether the neighbor list
        ought to be updated.
      did_buffer_overflow: A boolean that starts out False. If there are ever
        more neighbors than `max_occupancy` this is set to True to indicate
        that there was a buffer overflow. If this happens, it means that the
        results of the simulation will be incorrect and the simulation needs
        to be rerun using a larger buffer.
      cell_list_capacity: An optional integer specifying the capacity of the
        cell list used as an intermediate step in the creation of the
        neighbor list.
      max_occupancy: A static integer specifying the maximum size of the
        neighbor list. Changing this will invoke a recompilation.
      format: A NeighborListFormat enum specifying the format of the neighbor
        list.
      update_fn: A static python function used to update the neighbor list.
        `update` calls it as `update_fn(position, self, **kwargs)`.
    """
    idx: Array

    reference_position: Array

    did_buffer_overflow: Array

    cell_list_capacity: Optional[int] = dataclasses.static_field()

    max_occupancy: int = dataclasses.static_field()

    format: NeighborListFormat = dataclasses.static_field()
    update_fn: Callable[[Array, 'NeighborList'],
                        'NeighborList'] = dataclasses.static_field()

    def update(self, position: Array, **kwargs) -> 'NeighborList':
        """Return the result of `update_fn(position, self, **kwargs)`."""
        return self.update_fn(position, self, **kwargs)
Ejemplo n.º 6
0
class NeighborList(object):
  """A struct containing the state of a Neighbor List.

  Attributes:
    idx: For an N particle system this is an `[N, max_occupancy]` array of
      integers such that `idx[i, j]` is the jth neighbor of particle i.
    reference_position: The positions of particles when the neighbor list was
      constructed. This is used to decide whether the neighbor list ought to be
      updated.
    did_buffer_overflow: A boolean that starts out False. If there are ever
      more neighbors than `max_occupancy` this is set to True to indicate that
      there was a buffer overflow. If this happens, it means that the results
      of the simulation will be incorrect and the simulation needs to be rerun
      using a larger buffer.
    max_occupancy: A static integer specifying the maximum size of the
      neighbor list. Changing this will invoke a recompilation.
    cell_list_fn: A static python callable that is used to construct a cell
      list used in an intermediate step of the neighbor list calculation.
  """
  idx: np.ndarray
  reference_position: np.ndarray
  did_buffer_overflow: bool
  max_occupancy: int = dataclasses.static_field()
  cell_list_fn: Callable = dataclasses.static_field()
Ejemplo n.º 7
0
class NoseHooverChain:
    """State information for a Nose-Hoover chain.

    Attributes:
      position: An ndarray of shape [chain_length] that stores the position of
        the chain.
      velocity: An ndarray of shape [chain_length] that stores the velocity of
        the chain.
      mass: An ndarray of shape [chain_length] that stores the mass of the
        chain.
      tau: The desired period of oscillation for the chain. Longer periods
        result in better stability but worse temperature control.
      kinetic_energy: A float that stores the current kinetic energy of the
        system that the chain is coupled to.
      degrees_of_freedom: An integer specifying the number of degrees of
        freedom that the chain is coupled to.
    """
    position: Array
    velocity: Array
    mass: Array
    tau: Array
    kinetic_energy: Array
    degrees_of_freedom: int = dataclasses.static_field()