class CellListFns:
    """Functions to allocate and update cell lists.

    Attributes:
      allocate: Function that builds a new `CellList`; it accepts arbitrary
        arguments.
      update: Function taking positions and either an existing `CellList` or
        an `int` (per the annotation; presumably a cell capacity -- confirm
        against the caller) and returning a `CellList`.

    Iterating over an instance yields `allocate` then `update`, so it can be
    unpacked as `allocate_fn, update_fn = cell_list_fns`.
    """
    # `CellList` is quoted because it is defined later in this file and
    # class-body annotations are evaluated when the class is created (unless
    # postponed evaluation of annotations is enabled for the module).
    allocate: Callable[..., 'CellList'] = dataclasses.static_field()
    update: Callable[[Array, Union['CellList', int]], 'CellList'] = (
        dataclasses.static_field())

    def __iter__(self):
        # Fixed (allocate, update) order supports tuple unpacking.
        return iter((self.allocate, self.update))
class Disk:
    """Disk geometry elements.

    Args:
      position: An array of shape `(steps, count, dim)` or `(count, dim)`
        specifying possibly time-varying positions. Here `dim` is the spatial
        dimension.
      diameter: An array of shape `(steps, count)`, `(count,)`, or `()`
        specifying possibly time-varying / per-disk diameters. Stored on the
        `size` attribute. Defaults to `1.0`.
      color: An array of shape `(steps, count, 3)`, `(count, 3)`, or `(3,)`
        specifying possibly time-varying / per-disk RGB colors. Defaults to
        `jnp.array([0.8, 0.8, 1.0])`.

    Attributes:
      position: The `position` argument.
      size: The `diameter` argument.
      color: The `color` argument (or its default).
      count: The number of disks, read from `position.shape[-2]`.
    """
    position: jnp.ndarray
    size: jnp.ndarray
    color: jnp.ndarray
    count: int = dataclasses.static_field()

    def __init__(self, position, diameter=1.0, color=None):
        if color is None:
            # Built per call, so instances never share a default array.
            color = jnp.array([0.8, 0.8, 1.0])
        # NOTE(review): object.__setattr__ suggests this class is a frozen
        # dataclass whose decorator is outside this view -- confirm.
        object.__setattr__(self, 'position', position)
        object.__setattr__(self, 'size', diameter)
        object.__setattr__(self, 'color', color)
        # The second-to-last axis of `position` indexes the disks.
        object.__setattr__(self, 'count', position.shape[-2])

    def __repr__(self):
        # Short fixed repr; field values are intentionally not included.
        return 'Disk'
class NeighborListFns:
    """A struct containing functions to allocate and update neighbor lists.

    Attributes:
      allocate: A function to allocate a new neighbor list. This function
        cannot be compiled, since it uses the values of positions to infer the
        shapes.
      update: A function to update a neighbor list given a new set of
        positions and a previously allocated neighbor list.

    Iterating over an instance yields `allocate` then `update`.
    """
    # `NeighborList` is quoted because it is defined later in this file and
    # both class-body and signature annotations are evaluated eagerly (unless
    # postponed evaluation of annotations is enabled for the module).
    allocate: Callable[..., 'NeighborList'] = dataclasses.static_field()
    update: Callable[[Array, 'NeighborList'], 'NeighborList'] = (
        dataclasses.static_field())

    def __call__(self,
                 position: Array,
                 neighbors: Optional['NeighborList'] = None,
                 extra_capacity: int = 0,
                 **kwargs) -> 'NeighborList':
        """Allocate or update a neighbor list (deprecated entry point).

        Kept for backward compatibility with previous neighbor lists; prefer
        calling `allocate` and `update` directly. Logs a warning on every
        call.

        Args:
          position: An `(N, dim)` array of particle positions.
          neighbors: An optional neighbor list object. If it is provided then
            the function updates the neighbor list, otherwise it allocates a
            new neighbor list.
          extra_capacity: Extra capacity to add if allocating the neighbor
            list. Ignored when `neighbors` is provided.
          **kwargs: Forwarded unchanged to `allocate` or `update`.

        Returns:
          A neighbor list object.
        """
        logging.warning(
            'Using a deprecated code path to create / update neighbor '
            'lists. It will be removed in a later version of JAX MD. '
            'Using `neighbor_fn.allocate` and `neighbor_fn.update` '
            'is preferred.')
        if neighbors is None:
            return self.allocate(position, extra_capacity, **kwargs)
        return self.update(position, neighbors, **kwargs)

    def __iter__(self):
        # Fixed (allocate, update) order supports tuple unpacking.
        return iter((self.allocate, self.update))
class CellList:
    """Spatial partition of a system into a grid of equally sized cells.

    How the partition is built is described in `cell_list(...)`. Every buffer
    shares a common leading shape S:

      * two dimensions:   `S = [cell_count_x, cell_count_y, cell_capacity]`
      * three dimensions: `S = [cell_count_x, cell_count_y, cell_count_z,
        cell_capacity]`

    Every cell is assumed to have the same capacity.

    Attributes:
      position_buffer: Floating point positions of shape
        S + [spatial_dimension].
      id_buffer: int32 particle ids of shape S. Empty slots hold id = N, N
        being the number of particles in the system.
      kwarg_buffers: Dictionary of side-data arrays, each of shape S + [...].
      did_buffer_overflow: Boolean recording whether the cell list exceeded
        its maximum allocated capacity.
      cell_capacity: Maximum number of entries a single cell can hold.
      update_fn: Function that rebuilds the cell list at a fixed capacity.
    """
    position_buffer: Array
    id_buffer: Array
    kwarg_buffers: Dict[str, Array]

    did_buffer_overflow: Array

    cell_capacity: int = dataclasses.static_field()
    update_fn: Callable[..., 'CellList'] = dataclasses.static_field()

    def update(self, position: Array, **kwargs) -> 'CellList':
        """Rebuild this cell list for a new set of positions.

        `update_fn` receives the positions plus a tuple of
        (capacity, overflow flag, update function); any keyword arguments are
        forwarded unchanged.
        """
        capacity = self.cell_capacity
        overflow = self.did_buffer_overflow
        rebuild = self.update_fn
        return rebuild(position, (capacity, overflow, rebuild), **kwargs)
class NeighborList(object):
    """State carried by a neighbor list.

    Attributes:
      idx: For an N particle system, an `[N, max_occupancy]` integer array in
        which `idx[i, j]` is the j-th neighbor of particle i.
      reference_position: Particle positions at the moment the list was
        constructed; used to decide whether the neighbor list ought to be
        updated.
      did_buffer_overflow: Boolean that starts out False and is set to True
        if particles ever have more neighbors than the list can hold. When
        that happens the simulation results are incorrect and the simulation
        must be rerun with a larger buffer.
      cell_list_capacity: Optional capacity of the cell list used as an
        intermediate step when building the neighbor list.
      max_occupancy: Static maximum size of the neighbor list; changing it
        invokes a recompilation.
      format: `NeighborListFormat` enum giving the format of the list.
      update_fn: Static Python function used to update the neighbor list.
    """
    idx: Array
    reference_position: Array
    did_buffer_overflow: Array

    cell_list_capacity: Optional[int] = dataclasses.static_field()
    max_occupancy: int = dataclasses.static_field()
    format: NeighborListFormat = dataclasses.static_field()
    update_fn: Callable[[Array, 'NeighborList'], 'NeighborList'] = (
        dataclasses.static_field())

    def update(self, position: Array, **kwargs) -> 'NeighborList':
        """Return the neighbor list refreshed for `position`.

        Delegates to `update_fn`, passing this list as the previous state and
        forwarding any keyword arguments.
        """
        refresh = self.update_fn
        previous = self
        return refresh(position, previous, **kwargs)
# NOTE(review): `NeighborList` is also declared earlier in this file. Because
# this definition comes later, it shadows that one at module level (this
# version has no `update` method and no `format` / `update_fn` fields).
# Confirm which definition is intended.
class NeighborList(object):
    """A struct containing the state of a Neighbor List.

    Attributes:
      idx: For an N particle system this is an `[N, max_occupancy]` array of
        integers such that `idx[i, j]` is the jth neighbor of particle i.
      reference_position: The positions of particles when the neighbor list
        was constructed. This is used to decide whether the neighbor list
        ought to be updated.
      did_buffer_overflow: A boolean that starts out False. If there are ever
        more neighbors than max_occupancy this is set to True to indicate
        that there was a buffer overflow. If this happens, it means that the
        results of the simulation will be incorrect and the simulation needs
        to be rerun using a larger buffer.
      max_occupancy: A static integer specifying the maximum size of the
        neighbor list. Changing this will trigger a recompilation.
      cell_list_fn: A static python callable that is used to construct a cell
        list used in an intermediate step of the neighbor list calculation.
    """
    idx: np.ndarray
    reference_position: np.ndarray
    did_buffer_overflow: bool
    max_occupancy: int = dataclasses.static_field()
    cell_list_fn: Callable = dataclasses.static_field()
class NoseHooverChain:
    """State information for a Nose-Hoover chain.

    Attributes:
      position: An ndarray of shape [chain_length] that stores the position of
        the chain.
      velocity: An ndarray of shape [chain_length] that stores the velocity of
        the chain.
      mass: An ndarray of shape [chain_length] that stores the mass of the
        chain.
      tau: The desired period of oscillation for the chain. Longer periods
        result in better stability but worse temperature control.
      kinetic_energy: A float that stores the current kinetic energy of the
        system that the chain is coupled to.
      degrees_of_freedom: An integer specifying the number of degrees of
        freedom that the chain is coupled to.
    """
    position: Array
    velocity: Array
    mass: Array
    tau: Array
    kinetic_energy: Array
    degrees_of_freedom: int = dataclasses.static_field()