Example #1
0
def morse_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 5.0,
    alpha: Array = 5.0,
    r_onset: float = 2.0,
    r_cutoff: float = 2.5,
    dr_threshold: float = 0.5,
    per_particle: bool = False
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
    """Pair a neighbor-list builder with a truncated Morse energy function.

    Every numeric parameter is first passed through `maybe_downcast`. The
    neighbor list is built from `r_cutoff` and `dr_threshold`; the energy
    wraps `morse` in `multiplicative_isotropic_cutoff` using `r_onset` and
    `r_cutoff`.

    Args:
      displacement_or_metric: Displacement or metric function. It is handed
        to the neighbor list unchanged and canonicalized for the energy.
      box_size: Forwarded to `partition.neighbor_list`.
      species: Forwarded to `smap.pair_neighbor_list`.
      sigma: Morse parameter forwarded to the pair energy.
      epsilon: Morse parameter forwarded to the pair energy.
      alpha: Morse parameter forwarded to the pair energy.
      r_onset: Onset distance given to the multiplicative cutoff.
      r_cutoff: Cutoff distance used by the potential and the neighbor list.
      dr_threshold: Forwarded to `partition.neighbor_list`.
      per_particle: If True the energy is reduced over axis 1 only;
        otherwise `reduce_axis` is None.

    Returns:
      A `(neighbor_fn, energy_fn)` tuple.
    """
    sigma, epsilon, alpha = map(maybe_downcast, (sigma, epsilon, alpha))
    r_onset, r_cutoff, dr_threshold = map(
        maybe_downcast, (r_onset, r_cutoff, dr_threshold))

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, r_cutoff, dr_threshold)

    truncated_morse = multiplicative_isotropic_cutoff(
        morse, r_onset, r_cutoff)
    canonical_fn = space.canonicalize_displacement_or_metric(
        displacement_or_metric)
    reduce_axis = (1, ) if per_particle else None

    energy_fn = smap.pair_neighbor_list(
        truncated_morse,
        canonical_fn,
        ignore_unused_parameters=True,
        species=species,
        sigma=sigma,
        epsilon=epsilon,
        alpha=alpha,
        reduce_axis=reduce_axis)

    return neighbor_fn, energy_fn
Example #2
0
    def test_cell_list_overflow(self):
        """Check that an update with coincident particles flags an overflow.

        The list is allocated from well-separated positions and then updated
        with positions in which the first two particles coincide. The test
        expects `did_buffer_overflow` to be set and the index dtype to remain
        int32 both before and after the update.
        """
        displacement_fn, _ = space.free()

        neighbor_list_fn = partition.neighbor_list(
            displacement_fn,
            box_size=100.0,
            r_cutoff=3.0,
            dr_threshold=0.0,
        )

        # Every particle is far from every other one.
        spread_out = jnp.array([[x, x] for x in (20.0, 30.0, 40.0, 50.0)])
        neighbors = neighbor_list_fn.allocate(spread_out)
        self.assertEqual(neighbors.idx.dtype, jnp.int32)

        # The first two particles now sit on top of each other.
        crowded = jnp.array([[x, x] for x in (20.0, 20.0, 40.0, 50.0)])
        neighbors = neighbors.update(crowded)

        self.assertTrue(neighbors.did_buffer_overflow)
        self.assertEqual(neighbors.idx.dtype, jnp.int32)
Example #3
0
def bks_neighbor_list(displacement_or_metric,
                      box_size,
                      species,
                      Q_sq,
                      exp_coeff,
                      exp_decay,
                      attractive_coeff,
                      repulsive_coeff,
                      coulomb_alpha,
                      cutoff,
                      dr_threshold=0.8):
    """Convenience wrapper to compute BKS energy using a neighbor list.

    `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff` and
    `repulsive_coeff` are converted to `f32` arrays and `dr_threshold` to an
    `f32` scalar. `coulomb_alpha` and `cutoff` are forwarded unchanged.

    Args:
      displacement_or_metric: Displacement or metric function. It is handed
        to the neighbor list unchanged and canonicalized for the energy.
      box_size: Forwarded to `partition.neighbor_list`.
      species: Forwarded to `smap.pair_neighbor_list`.
      Q_sq, exp_coeff, exp_decay, attractive_coeff, repulsive_coeff,
        coulomb_alpha: Parameters forwarded by keyword to `bks`.
      cutoff: Neighbor-list cutoff, also forwarded by keyword to `bks`.
      dr_threshold: Forwarded to `partition.neighbor_list`.

    Returns:
      A `(neighbor_fn, energy_fn)` tuple.
    """
    # Keyword order matches the order in which they are forwarded below.
    cast_params = dict(
        Q_sq=np.array(Q_sq, f32),
        exp_coeff=np.array(exp_coeff, f32),
        exp_decay=np.array(exp_decay, f32),
        attractive_coeff=np.array(attractive_coeff, f32),
        repulsive_coeff=np.array(repulsive_coeff, f32),
    )
    dr_threshold = f32(dr_threshold)

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, cutoff, dr_threshold)

    canonical_fn = space.canonicalize_displacement_or_metric(
        displacement_or_metric)
    energy_fn = smap.pair_neighbor_list(
        bks,
        canonical_fn,
        species=species,
        **cast_params,
        coulomb_alpha=coulomb_alpha,
        cutoff=cutoff)

    return neighbor_fn, energy_fn
Example #4
0
def morse_neighbor_list(displacement_or_metric,
                        box_size,
                        species=None,
                        sigma=1.0,
                        epsilon=5.0,
                        alpha=5.0,
                        r_onset=2.0,
                        r_cutoff=2.5,
                        dr_threshold=0.5,
                        per_particle=False):  # TODO(cpgoodri) Optimize this.
    """Pair a neighbor-list builder with a truncated Morse energy function.

    All numeric parameters are converted to `f32` arrays. The neighbor list
    uses `r_cutoff` and `dr_threshold`; the energy wraps `morse` in
    `multiplicative_isotropic_cutoff` using `r_onset` and `r_cutoff`.

    Args:
      displacement_or_metric: Displacement or metric function. It is handed
        to the neighbor list unchanged and canonicalized for the energy.
      box_size: Forwarded to `partition.neighbor_list`.
      species: Forwarded to `smap.pair_neighbor_list`.
      sigma, epsilon, alpha: Morse parameters forwarded to the pair energy.
      r_onset: Onset distance given to the multiplicative cutoff.
      r_cutoff: Cutoff distance used by the potential and the neighbor list.
      dr_threshold: Forwarded to `partition.neighbor_list`.
      per_particle: If True the energy is reduced over axis 1 only;
        otherwise `reduce_axis` is None.

    Returns:
      A `(neighbor_fn, energy_fn)` tuple.
    """
    sigma, epsilon, alpha = (
        np.array(value, f32) for value in (sigma, epsilon, alpha))
    r_onset, r_cutoff, dr_threshold = (
        np.array(value, f32) for value in (r_onset, r_cutoff, dr_threshold))

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, r_cutoff, dr_threshold)

    truncated_morse = multiplicative_isotropic_cutoff(
        morse, r_onset, r_cutoff)
    canonical_fn = space.canonicalize_displacement_or_metric(
        displacement_or_metric)
    reduce_axis = (1, ) if per_particle else None

    energy_fn = smap.pair_neighbor_list(
        truncated_morse,
        canonical_fn,
        species=species,
        sigma=sigma,
        epsilon=epsilon,
        alpha=alpha,
        reduce_axis=reduce_axis)

    return neighbor_fn, energy_fn
Example #5
0
    def test_pair_neighbor_list_vector(self, spatial_dimension, dtype):
        """Compare a vector-valued pair function via all pairs and neighbors.

        For several random configurations in a periodic box, the result of
        `smap.pair` must match `smap.pair_neighbor_list` evaluated with a
        neighbor list whose cutoff is the sampled `sigma`.
        """
        key = random.PRNGKey(0)

        def truncated_square(dR, sigma):
            # Distances keep a trailing singleton axis so the comparison
            # broadcasts against the displacement components.
            distances = space.distance(dR)
            distances = np.reshape(distances, dR.shape[:-1] + (1, ))
            inside = distances < sigma
            return np.where(inside, dR**2, f32(0.))

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        key, _ = random.split(key)
        disp, _ = space.periodic(box_size)

        via_neighbors = jit(
            smap.pair_neighbor_list(truncated_square, disp,
                                    reduce_axis=(1, )))
        via_all_pairs = jit(
            smap.pair(truncated_square, disp, reduce_axis=(1, )))

        for _ in range(STOCHASTIC_SAMPLES):
            key, position_key = random.split(key)
            R = box_size * random.uniform(
                position_key, (N, spatial_dimension), dtype=dtype)
            sigma = random.uniform(key, (), minval=0.5, maxval=1.5)

            nbrs = partition.neighbor_list(disp, box_size, sigma, 0.)(R)

            expected = via_all_pairs(R, sigma=sigma)
            actual = via_neighbors(R, nbrs, sigma=sigma)
            self.assertAllClose(expected, actual)
Example #6
0
def soft_sphere_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 1.0,
    alpha: Array = 2.0,
    dr_threshold: float = 0.2,
    per_particle: bool = False
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
    """Pair a neighbor-list builder with a soft-sphere energy function.

    `sigma`, `epsilon` and `alpha` are passed through `maybe_downcast`. The
    neighbor-list cutoff is the largest `sigma`, and `dr_threshold` is scaled
    by that cutoff before being handed to the neighbor list.

    Args:
      displacement_or_metric: Displacement or metric function. It is handed
        to the neighbor list unchanged and canonicalized for the energy.
      box_size: Forwarded to `partition.neighbor_list`.
      species: Forwarded to `smap.pair_neighbor_list`.
      sigma: Soft-sphere parameter; its maximum sets the list cutoff.
      epsilon: Soft-sphere parameter forwarded to the pair energy.
      alpha: Soft-sphere parameter forwarded to the pair energy.
      dr_threshold: Multiplied by the list cutoff, then forwarded to
        `partition.neighbor_list`.
      per_particle: If True the energy is reduced over axis 1 only;
        otherwise `reduce_axis` is None.

    Returns:
      A `(neighbor_fn, energy_fn)` tuple.
    """
    sigma, epsilon, alpha = map(maybe_downcast, (sigma, epsilon, alpha))

    list_cutoff = np.max(sigma)
    scaled_threshold = list_cutoff * maybe_downcast(dr_threshold)

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, list_cutoff, scaled_threshold)

    canonical_fn = space.canonicalize_displacement_or_metric(
        displacement_or_metric)
    reduce_axis = (1, ) if per_particle else None

    energy_fn = smap.pair_neighbor_list(
        soft_sphere,
        canonical_fn,
        ignore_unused_parameters=True,
        species=species,
        sigma=sigma,
        epsilon=epsilon,
        alpha=alpha,
        reduce_axis=reduce_axis)

    return neighbor_fn, energy_fn
Example #7
0
  def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype, format):
    """Compare a scalar pair function via all pairs and via neighbor lists.

    For several random configurations in a periodic box, `smap.pair` and
    `smap.pair_neighbor_list` must agree when the neighbor list is allocated
    with the sampled `sigma` as cutoff and the requested `format`.
    """
    key = random.PRNGKey(0)

    def truncated_square(dr, sigma):
      inside = dr < sigma
      return np.where(inside, dr ** 2, f32(0.))

    N = NEIGHBOR_LIST_PARTICLE_COUNT
    box_size = 4. * N ** (1. / spatial_dimension)

    key, _ = random.split(key)
    disp, _ = space.periodic(box_size)
    d = space.metric(disp)

    via_neighbors = jit(
      smap.pair_neighbor_list(truncated_square, d, sigma=1.0))
    via_all_pairs = jit(smap.pair(truncated_square, d, sigma=1.0))

    for _ in range(STOCHASTIC_SAMPLES):
      key, position_key = random.split(key)
      R = box_size * random.uniform(
        position_key, (N, spatial_dimension), dtype=dtype)
      sigma = random.uniform(key, (), minval=0.5, maxval=2.5)

      neighbor_fn = partition.neighbor_list(
        disp, box_size, sigma, 0.0, format=format)
      nbrs = neighbor_fn.allocate(R)

      expected = via_all_pairs(R, sigma=sigma)
      actual = via_neighbors(R, nbrs, sigma=sigma)
      self.assertAllClose(expected, actual)
Example #8
0
def lennard_jones_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 1.0,
    alpha: Array = 2.0,
    r_onset: float = 2.0,
    r_cutoff: float = 2.5,
    dr_threshold: float = 0.5,
    per_particle: bool = False
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
    """Pair a neighbor-list builder with a truncated Lennard-Jones energy.

    `r_onset`, `r_cutoff` and `dr_threshold` are each multiplied by the
    largest `sigma` before use, and all numeric values are converted to
    `f32` arrays.

    Args:
      displacement_or_metric: Displacement or metric function. It is handed
        to the neighbor list unchanged and canonicalized for the energy.
      box_size: Forwarded to `partition.neighbor_list`.
      species: Forwarded to `smap.pair_neighbor_list`.
      sigma: Lennard-Jones parameter; its maximum scales the distances.
      epsilon: Lennard-Jones parameter forwarded to the pair energy.
      alpha: Accepted but not used by this function.
      r_onset: Onset distance of the cutoff, in units of the largest sigma.
      r_cutoff: Cutoff distance, in units of the largest sigma.
      dr_threshold: Neighbor-list threshold, in units of the largest sigma.
      per_particle: If True the energy is reduced over axis 1 only;
        otherwise `reduce_axis` is None.

    Returns:
      A `(neighbor_fn, energy_fn)` tuple.
    """
    sigma = np.array(sigma, f32)
    epsilon = np.array(epsilon, f32)

    # All three distances are expressed relative to the largest sigma.
    sigma_max = np.max(sigma)
    r_onset = np.array(r_onset * sigma_max, f32)
    r_cutoff = np.array(r_cutoff * sigma_max, f32)
    dr_threshold = np.array(sigma_max * dr_threshold, f32)

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, r_cutoff, dr_threshold)

    truncated_lj = multiplicative_isotropic_cutoff(
        lennard_jones, r_onset, r_cutoff)
    canonical_fn = space.canonicalize_displacement_or_metric(
        displacement_or_metric)
    reduce_axis = (1, ) if per_particle else None

    energy_fn = smap.pair_neighbor_list(
        truncated_lj,
        canonical_fn,
        species=species,
        sigma=sigma,
        epsilon=epsilon,
        reduce_axis=reduce_axis)

    return neighbor_fn, energy_fn
Example #9
0
def soft_sphere_neighbor_list(displacement_or_metric,
                              box_size,
                              species=None,
                              sigma=1.0,
                              epsilon=1.0,
                              alpha=2.0,
                              dr_threshold=0.2):
    """Pair a neighbor-list builder with a soft-sphere energy function.

    `sigma`, `epsilon` and `alpha` are converted to `f32` arrays. The
    neighbor-list cutoff is the largest `sigma`, and `dr_threshold` is scaled
    by that cutoff before being handed to the neighbor list.

    Args:
      displacement_or_metric: Displacement or metric function. It is handed
        to the neighbor list unchanged and canonicalized for the energy.
      box_size: Forwarded to `partition.neighbor_list`.
      species: Forwarded to `smap.pair_neighbor_list`.
      sigma: Soft-sphere parameter; its maximum sets the list cutoff.
      epsilon: Soft-sphere parameter forwarded to the pair energy.
      alpha: Soft-sphere parameter forwarded to the pair energy.
      dr_threshold: Multiplied by the list cutoff, then forwarded to
        `partition.neighbor_list`.

    Returns:
      A `(neighbor_fn, energy_fn)` tuple.
    """
    sigma, epsilon, alpha = (
        np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))

    list_cutoff = f32(np.max(sigma))
    scaled_threshold = f32(list_cutoff * dr_threshold)

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, list_cutoff, scaled_threshold)

    canonical_fn = space.canonicalize_displacement_or_metric(
        displacement_or_metric)
    energy_fn = smap.pair_neighbor_list(
        soft_sphere,
        canonical_fn,
        species=species,
        sigma=sigma,
        epsilon=epsilon,
        alpha=alpha)

    return neighbor_fn, energy_fn
Example #10
0
    def test_radial_symmetry_functions_neighbor_list(self, N_types, N_etas,
                                                     dtype, dim):
        """Radial symmetry functions must agree with and without neighbors.

        Uses 128 random particles with random species in a periodic box and
        compares `nn.radial_symmetry_functions` with its neighbor-list
        counterpart; the tolerance depends on whether x64 is enabled.
        """
        key = random.PRNGKey(0)

        N = 128
        box_size = 12.0
        r_cutoff = 3.

        displacement, _ = space.periodic(box_size)
        R_key, species_key = random.split(key)
        R = box_size * random.uniform(R_key, (N, dim))
        species = random.choice(species_key, N_types, (N, ))

        neighbor_fn = partition.neighbor_list(
            displacement, box_size, r_cutoff, 0.)

        # Both variants receive the same eta grid.
        etas = np.linspace(1.0, 2.0, N_etas, dtype=dtype)
        gr = nn.radial_symmetry_functions(
            displacement, species, etas, r_cutoff)
        gr_neigh = nn.radial_symmetry_functions_neighbor_list(
            displacement, species, etas, r_cutoff)

        nbrs = neighbor_fn(R)
        gr_exact = gr(R)
        gr_nbrs = gr_neigh(R, neighbor=nbrs)

        if FLAGS.jax_enable_x64:
            tol = 1e-13
        else:
            tol = 1e-6
        self.assertAllClose(gr_exact, gr_nbrs, atol=tol, rtol=tol)
Example #11
0
    def test_pair_neighbor_list_force_scalar_diverging_potential(
            self, spatial_dimension, dtype):
        """Forces of a diverging potential must agree with and without lists.

        The potential is `dr**-6` inside `sigma` and zero outside. Forces
        derived from `smap.pair` and from `smap.pair_neighbor_list` are
        compared on several random periodic configurations.
        """
        key = random.PRNGKey(0)

        def potential(dr, sigma):
            inside = dr < sigma
            return np.where(inside, dr**-6, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 4. * N**(1. / spatial_dimension)

        key, _ = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        force_via_neighbors = jit(
            quantity.force(smap.pair_neighbor_list(potential, d)))
        force_via_all_pairs = jit(quantity.force(smap.pair(potential, d)))

        for _ in range(STOCHASTIC_SAMPLES):
            key, position_key = random.split(key)
            R = box_size * random.uniform(
                position_key, (N, spatial_dimension), dtype=dtype)
            sigma = random.uniform(key, (), minval=0.5, maxval=4.5)

            nbrs = partition.neighbor_list(disp, box_size, sigma, 0.0)(R)

            expected = force_via_all_pairs(R, sigma=sigma)
            actual = force_via_neighbors(R, nbrs, sigma=sigma)
            self.assertAllClose(expected, actual)
Example #12
0
def lennard_jones_neighbor_list(
        displacement_or_metric,
        box_size,
        species=None,
        sigma=1.0,
        epsilon=1.0,
        alpha=2.0,
        r_onset=2.0,
        r_cutoff=2.5,
        dr_threshold=0.5):  # TODO(schsam) Optimize this.
    """Pair a neighbor-list builder with a truncated Lennard-Jones energy.

    `r_onset`, `r_cutoff` and `dr_threshold` are each multiplied by the
    largest `sigma` before use, and all numeric values are converted to
    `f32` arrays.

    Args:
      displacement_or_metric: Displacement or metric function. It is handed
        to the neighbor list unchanged and canonicalized for the energy.
      box_size: Forwarded to `partition.neighbor_list`.
      species: Forwarded to `smap.pair_neighbor_list`.
      sigma: Lennard-Jones parameter; its maximum scales the distances.
      epsilon: Lennard-Jones parameter forwarded to the pair energy.
      alpha: Accepted but not used by this function.
      r_onset: Onset distance of the cutoff, in units of the largest sigma.
      r_cutoff: Cutoff distance, in units of the largest sigma.
      dr_threshold: Neighbor-list threshold, in units of the largest sigma.

    Returns:
      A `(neighbor_fn, energy_fn)` tuple.
    """
    sigma = np.array(sigma, f32)
    epsilon = np.array(epsilon, f32)

    # All three distances are expressed relative to the largest sigma.
    sigma_max = np.max(sigma)
    r_onset = np.array(r_onset * sigma_max, f32)
    r_cutoff = np.array(r_cutoff * sigma_max, f32)
    dr_threshold = np.array(sigma_max * dr_threshold, f32)

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, r_cutoff, dr_threshold)

    truncated_lj = multiplicative_isotropic_cutoff(
        lennard_jones, r_onset, r_cutoff)
    canonical_fn = space.canonicalize_displacement_or_metric(
        displacement_or_metric)

    energy_fn = smap.pair_neighbor_list(
        truncated_lj,
        canonical_fn,
        species=species,
        sigma=sigma,
        epsilon=epsilon)

    return neighbor_fn, energy_fn
Example #13
0
  def test_pair_neighbor_list_scalar_nonadditive(
      self, spatial_dimension, dtype, format):
    """Check a per-particle sigma combined multiplicatively over pairs.

    The neighbor-list version receives a per-particle `sigma` together with
    a product combinator, while the all-pairs reference receives the explicit
    outer-product matrix. The two results must agree.
    """
    key = random.PRNGKey(0)

    def truncated_square(dR, sigma):
      distance = space.distance(dR)
      inside = distance < sigma
      return np.where(inside, distance ** 2, f32(0.))

    def multiply(x, y):
      return x * y

    N = PARTICLE_COUNT
    box_size = 2. * N ** (1. / spatial_dimension)

    key, _ = random.split(key)
    disp, _ = space.periodic(box_size)

    via_neighbors = jit(smap.pair_neighbor_list(
      truncated_square, disp, sigma=multiply))
    via_all_pairs = jit(smap.pair(truncated_square, disp, sigma=1.0))

    for _ in range(STOCHASTIC_SAMPLES):
      key, position_key = random.split(key)
      R = box_size * random.uniform(
        position_key, (N, spatial_dimension), dtype=dtype)
      sigma = random.uniform(key, (N,), minval=0.5, maxval=1.5)
      sigma_pair = sigma[:, None] * sigma[None, :]

      # The largest pair sigma is the square of the largest particle sigma.
      list_cutoff = np.max(sigma) ** 2
      neighbor_fn = partition.neighbor_list(
        disp, box_size, list_cutoff, 0., format=format)
      nbrs = neighbor_fn.allocate(R)

      expected = via_all_pairs(R, sigma=sigma_pair)
      actual = via_neighbors(R, nbrs, sigma=sigma)
      self.assertAllClose(expected, actual)
Example #14
0
    def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension,
                                                     dtype):
        """Check a pair function whose sigma is a symmetric (N, N) matrix.

        For several random periodic configurations the all-pairs result must
        match the neighbor-list result; an explicit tolerance is used only
        for float32.
        """
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            inside = dr < sigma
            return np.where(inside, dr**2, f32(0.))

        if dtype == np.float32:
            tol = 2e-10
        else:
            tol = None

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        key, _ = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        via_neighbors = jit(smap.pair_neighbor_list(truncated_square, d))
        via_all_pairs = jit(smap.pair(truncated_square, d))

        for _ in range(STOCHASTIC_SAMPLES):
            key, position_key = random.split(key)
            R = box_size * random.uniform(
                position_key, (N, spatial_dimension), dtype=dtype)

            # Symmetrize so sigma[i, j] == sigma[j, i].
            raw_sigma = random.uniform(key, (N, N), minval=0.5, maxval=1.5)
            sigma = 0.5 * (raw_sigma + raw_sigma.T)

            neighbor_fn = jit(
                partition.neighbor_list(disp, box_size, np.max(sigma), R))
            idx = neighbor_fn(R)

            expected = via_all_pairs(R, sigma=sigma)
            actual = via_neighbors(R, idx, sigma=sigma)
            self.assertAllClose(expected, actual, True, tol, tol)
Example #15
0
    def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension,
                                                      dtype):
        """Check a pair function whose sigma is indexed by three species.

        Particles are split by index into species 0, 1 and 2, and `sigma` is
        a symmetric (3, 3) matrix. The all-pairs and neighbor-list results
        must agree on several random periodic configurations.
        """
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            inside = dr < sigma
            return np.where(inside, dr**2, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        # Indices above N/3 become species 1, those above 2N/3 species 2.
        particle_index = np.arange(N)
        species = np.zeros((N, ), np.int32)
        species = np.where(particle_index > N / 3, 1, species)
        species = np.where(particle_index > 2 * N / 3, 2, species)

        key, _ = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        via_neighbors = jit(
            smap.pair_neighbor_list(truncated_square, d, species=species))
        via_all_pairs = jit(
            smap.pair(truncated_square, d, species=species))

        for _ in range(STOCHASTIC_SAMPLES):
            key, position_key = random.split(key)
            R = box_size * random.uniform(
                position_key, (N, spatial_dimension), dtype=dtype)

            # Symmetrize so sigma[a, b] == sigma[b, a].
            raw_sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
            sigma = 0.5 * (raw_sigma + raw_sigma.T)

            neighbor_fn = partition.neighbor_list(
                disp, box_size, np.max(sigma), 0.)
            nbrs = neighbor_fn(R)

            expected = via_all_pairs(R, sigma=sigma)
            actual = via_neighbors(R, nbrs, sigma=sigma)
            self.assertAllClose(expected, actual)
Example #16
0
    def test_custom_mask_function(self):
        """Check that a custom mask function removes id-based pairs.

        Ten particles are all placed at the origin, and pairs whose ids
        differ by 3 or less are masked out of the neighbor list. The test
        expects 42 unmasked entries to remain.
        """
        displacement_fn, _ = space.free()

        n_particles = 10
        R = jnp.broadcast_to(jnp.zeros(3), (n_particles, 3))

        def acceptable_id_pair(id1, id2):
            """Return True when the two ids are more than 3 apart.

            Pairs that are closer in id (e.g. 1-2 and 1-3 interactions) are
            therefore not allowed to interact.
            """
            return jnp.abs(id1 - id2) > 3

        def mask_id_based(idx: Array, ids: Array, mask_val: int,
                          _acceptable_id_pair: Callable) -> Array:
            """Apply `_acceptable_id_pair` to every entry of a neighbor list.

            The position along the first dimension of `idx` is the index of
            particle 1, and each stored value is the index of particle 2.
            Rejected entries are replaced by `mask_val`.
            """
            @partial(vmap, in_axes=(0, 0, None))
            def row_is_acceptable(neighbor_row, own_id, all_ids):
                neighbor_ids = all_ids.at[neighbor_row].get()
                check_row = vmap(_acceptable_id_pair, in_axes=(None, 0))
                return check_row(own_id, neighbor_ids)

            keep = row_is_acceptable(idx, ids, ids)
            return jnp.where(keep, idx, mask_val)

        ids = jnp.arange(n_particles)  # id is just particle index here.
        mask_val = n_particles
        custom_mask_function = partial(
            mask_id_based,
            ids=ids,
            mask_val=mask_val,
            _acceptable_id_pair=acceptable_id_pair)

        neighbor_list_fn = partition.neighbor_list(
            displacement_fn,
            box_size=1.0,
            r_cutoff=3.0,
            dr_threshold=0.0,
            custom_mask_function=custom_mask_function,
        )

        neighbors = neighbor_list_fn.allocate(R)
        neighbors = neighbors.update(R)

        # Without masking it's 9 neighbors (with mask self) -> 90 neighbors.
        # With masking -> 42.
        unmasked_count = (neighbors.idx != mask_val).sum()
        self.assertEqual(42, unmasked_count)
Example #17
0
    def test_neighbor_list_build_time_dependent(self, dtype, dim):
        """Neighbor distances must match the exact ones for a box at t=0.25.

        The box is a function of `t`. The neighbor list is constructed with
        `t=0` and evaluated with `t=0.25`; the sorted, positive, below-cutoff
        distances of every particle are then compared with the all-pairs
        distances computed at `t=0.25`.
        """
        key = random.PRNGKey(1)

        if dim == 2:
            def box_fn(t):
                return np.array([[9.0, t], [0.0, 3.75]], f32)
        elif dim == 3:
            def box_fn(t):
                return np.array([[9.0, 0.0, t], [0.0, 4.0, 0.0],
                                 [0.0, 0.0, 7.25]])

        min_length = np.min(np.diag(box_fn(0.)))
        cutoff = f32(1.23)
        # TODO(schsam): Get cell-list working with anisotropic cell sizes.
        cell_size = cutoff / min_length

        displacement, _ = space.periodic_general(box_fn)
        metric = space.metric(displacement)

        R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
        N = R.shape[0]
        neighbor_list_fn = partition.neighbor_list(
            metric, 1., cutoff, 0.0, 1.1,
            cell_size=cell_size,
            t=np.array(0.))

        idx = neighbor_list_fn(R, t=np.array(0.25)).idx
        R_neigh = R[idx]
        mask = idx < N

        # Both distance computations use the metric at the same time.
        metric_at_t = partial(metric, t=f32(0.25))
        neighbor_dist = vmap(vmap(metric_at_t, (None, 0)))(R, R_neigh)
        exact_dist = space.map_product(metric_at_t)(R, R)

        neighbor_dist = np.where(
            neighbor_dist < cutoff, neighbor_dist, 0) * mask
        exact_dist = np.where(exact_dist < cutoff, exact_dist, 0)

        neighbor_dist = np.sort(neighbor_dist, axis=1)
        exact_dist = np.sort(exact_dist, axis=1)

        for i in range(neighbor_dist.shape[0]):
            row = neighbor_dist[i]
            exact_row = exact_dist[i]
            self.assertAllClose(row[row > 0.], exact_row[exact_row > 0.])
Example #18
0
def behler_parrinello_neighbor_list(displacement: DisplacementFn,
                                    box_size: float,
                                    species: Array=None,
                                    mlp_sizes: Tuple[int, ...]=(30, 30),
                                    mlp_kwargs: Dict[str, Any]=None,
                                    sym_kwargs: Dict[str, Any]=None,
                                    dr_threshold: float=0.5
                                    ) -> Tuple[NeighborFn,
                                               nn.InitFn,
                                               Callable[[PyTree,
                                                         Array,
                                                         NeighborList],
                                                        Array]]:
  """Build a Behler-Parrinello style model evaluated through neighbor lists.

  Symmetry functions computed from the neighbor list are fed, per particle,
  through an MLP with a final output size of 1, and the outputs are summed.

  Args:
    displacement: Displacement function used by the neighbor list and the
      symmetry functions.
    box_size: Forwarded to `partition.neighbor_list`.
    species: Forwarded to the symmetry-function constructor.
    mlp_sizes: Hidden layer sizes; a final layer of size 1 is appended.
    mlp_kwargs: Extra keyword arguments for `hk.nets.MLP`. Defaults to a
      `tanh` activation when None.
    sym_kwargs: Keyword arguments for the symmetry functions. Its
      'cutoff_distance' entry, if present, also sets the neighbor-list
      cutoff (8.0 otherwise).
    dr_threshold: Forwarded to `partition.neighbor_list`.

  Returns:
    A `(neighbor_fn, init_fn, apply_fn)` tuple, where the last two come from
    the transformed haiku model.
  """
  sym_kwargs = {} if sym_kwargs is None else sym_kwargs
  mlp_kwargs = {'activation': np.tanh} if mlp_kwargs is None else mlp_kwargs

  # The neighbor list uses the same cutoff as the symmetry functions when
  # one is supplied.
  cutoff_distance = sym_kwargs.get('cutoff_distance', 8.0)

  neighbor_fn = partition.neighbor_list(
    displacement, box_size, cutoff_distance, dr_threshold)

  sym_fn = nn.behler_parrinello_symmetry_functions_neighbor_list(
    displacement, species, **sym_kwargs)

  @hk.without_apply_rng
  @hk.transform
  def model(R, neighbor, **kwargs):
    encoder = hk.nets.MLP(output_sizes=mlp_sizes+(1,),
                          activate_final=False,
                          name='BPEncoder',
                          **mlp_kwargs)
    per_particle_encoder = vmap(encoder)
    sym = sym_fn(R, neighbor, **kwargs)
    readout = per_particle_encoder(sym)
    return np.sum(readout)

  return neighbor_fn, model.init, model.apply
Example #19
0
  def test_swap_mc_jammed(self, dtype):
    """Run hybrid swap Monte Carlo on a jammed state and check temperature.

    After the first 10 steps the recorded temperatures must stay close to
    `kT`, both step by step (loose tolerance) and on average (tight
    tolerance), and the particle sizes must have changed from their
    initial values.
    """
    key = random.PRNGKey(0)

    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    box_length = state.box[0, 0]
    space_fn = space.periodic(box_length)
    displacement_fn, _ = space_fn

    sigma = np.diag(state.sigma)[state.species]

    def energy_fn(dr, sigma):
      return energy.soft_sphere(dr, sigma=sigma)

    neighbor_fn = partition.neighbor_list(
      displacement_fn, box_length, np.max(sigma) + 0.1, dr_threshold=0.5)

    kT = 1e-2
    t_md = 0.1
    N_swap = 10
    init_fn, apply_fn = simulate.hybrid_swap_mc(
      space_fn, energy_fn, neighbor_fn, 1e-3, kT, t_md, N_swap)
    state = init_fn(key, state.real_position, sigma)

    def step_fn(i, carry):
      # Advance one step and record the temperature at slot i.
      current, temps = carry
      current = apply_fn(current)
      temps = temps.at[i].set(quantity.temperature(current.md.velocity))
      return current, temps

    initial_temps = np.zeros((DYNAMICS_STEPS,))
    state, Ts = lax.fori_loop(
      0, DYNAMICS_STEPS, step_fn, (state, initial_temps))

    tol = 5e-4
    settled = Ts[10:]
    self.assertAllClose(settled,
                        kT * np.ones((DYNAMICS_STEPS - 10)),
                        rtol=5e-1,
                        atol=5e-3)
    self.assertAllClose(np.mean(settled), kT, rtol=tol, atol=tol)
    self.assertTrue(not np.all(state.sigma == sigma))
0
def bks_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array,
    Q_sq: Array,
    exp_coeff: Array,
    exp_decay: Array,
    attractive_coeff: Array,
    repulsive_coeff: Array,
    coulomb_alpha: Array,
    cutoff: float,
    dr_threshold: float = 0.8,
    fractional_coordinates: bool = False,
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
    """Pair a neighbor-list builder with a BKS energy function.

    `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff`, `repulsive_coeff`
    and `dr_threshold` are passed through `maybe_downcast`. `coulomb_alpha`
    and `cutoff` are forwarded unchanged.

    Args:
      displacement_or_metric: Displacement or metric function. It is handed
        to the neighbor list unchanged and canonicalized for the energy.
      box_size: Forwarded to `partition.neighbor_list`.
      species: Forwarded to `smap.pair_neighbor_list`.
      Q_sq, exp_coeff, exp_decay, attractive_coeff, repulsive_coeff,
        coulomb_alpha: Parameters forwarded by keyword to `bks`.
      cutoff: Neighbor-list cutoff, also forwarded by keyword to `bks`.
      dr_threshold: Forwarded to `partition.neighbor_list`.
      fractional_coordinates: Forwarded to `partition.neighbor_list`.

    Returns:
      A `(neighbor_fn, energy_fn)` tuple.
    """
    (Q_sq, exp_coeff, exp_decay,
     attractive_coeff, repulsive_coeff, dr_threshold) = map(
         maybe_downcast,
         (Q_sq, exp_coeff, exp_decay,
          attractive_coeff, repulsive_coeff, dr_threshold))

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric,
        box_size,
        cutoff,
        dr_threshold,
        fractional_coordinates=fractional_coordinates)

    canonical_fn = space.canonicalize_displacement_or_metric(
        displacement_or_metric)
    energy_fn = smap.pair_neighbor_list(
        bks,
        canonical_fn,
        species=species,
        ignore_unused_parameters=True,
        Q_sq=Q_sq,
        exp_coeff=exp_coeff,
        exp_decay=exp_decay,
        attractive_coeff=attractive_coeff,
        repulsive_coeff=repulsive_coeff,
        coulomb_alpha=coulomb_alpha,
        cutoff=cutoff)

    return neighbor_fn, energy_fn
Example #21
0
 def test_behler_parrinello_symmetry_functions_neighbor_list(self,
                                                             N_types,
                                                             N_etas,
                                                             dtype):
   """Check the shape and one reference value of the symmetry functions.

   Three particles in free space are evaluated through a neighbor list;
   the output shape and the entry at [2, 0] are compared against expected
   values.
   """
   displacement, _ = space.free()
   neighbor_fn = partition.neighbor_list(displacement, 10.0, 8.0, 0.0)

   # The radial and angular etas share the same value.
   eta = 1e-4/(0.529177 ** 2)
   gr = nn.behler_parrinello_symmetry_functions_neighbor_list(
           displacement,
           np.array([1, 1, N_types]),
           radial_etas=np.array([eta] * N_etas, dtype),
           angular_etas=np.array([eta] * N_etas, dtype),
           lambdas=np.array([-1.0] * N_etas, dtype),
           zetas=np.array([1.0] * N_etas, dtype),
           cutoff_distance=8.0)

   R = np.array([[0,0,0], [1,1,1], [1,1,0]], dtype)
   nbrs = neighbor_fn(R)
   gr_out = gr(R, neighbor=nbrs)

   feature_count = N_etas *  (N_types + N_types * (N_types + 1) // 2)
   self.assertAllClose(gr_out.shape, (3, feature_count))
   self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)
Example #22
0
    def test_neighbor_list_build_sparse(self, dtype, dim):
        """Sparse neighbor-list distances must match the exact ones.

        For each particle, the sorted positive below-cutoff distances taken
        from the sparse list are compared with the all-pairs distances
        (self-pairs excluded).
        """
        key = random.PRNGKey(1)

        if dim == 3:
            box_size = np.array([9.0, 4.0, 7.25], f32)
        else:
            box_size = np.array([9.0, 4.25], f32)
        cutoff = f32(1.23)

        displacement, _ = space.periodic(box_size)
        metric = space.metric(displacement)

        R = box_size * random.uniform(
            key, (PARTICLE_COUNT, dim), dtype=dtype)
        N = R.shape[0]
        neighbor_fn = partition.neighbor_list(
            displacement, box_size, cutoff, 0.0, 1.1,
            format=partition.Sparse)

        nbrs = neighbor_fn.allocate(R)
        mask = partition.neighbor_list_mask(nbrs)

        # Distances along the two index rows stored in the sparse list.
        first, second = nbrs.idx[0], nbrs.idx[1]
        dR = space.map_bond(metric)(R[first], R[second])
        dR_exact = space.map_product(metric)(R, R)

        dR = np.where(dR < cutoff, dR, f32(0)) * mask
        not_self = 1. - np.eye(dR_exact.shape[0])
        dR_exact = np.where(dR_exact < cutoff, dR_exact, f32(0)) * not_self
        dR_exact = np.sort(dR_exact, axis=1)

        for i in range(N):
            from_list = dR[nbrs.idx[0] == i]
            from_list = np.sort(from_list[from_list > 0.])

            exact_row = dR_exact[i]
            exact_row = np.array(exact_row[exact_row > 0.], dtype)

            self.assertAllClose(from_list, exact_row)
Example #23
0
  def test_neighbor_list_build(self, dtype, dim):
    """Neighbor-list distances must match the exact all-pairs distances.

    For each particle, the sorted positive below-cutoff distances obtained
    through the neighbor list are compared with those from the full
    distance matrix (self-pairs excluded).
    """
    key = random.PRNGKey(1)

    # Compare by value: `dim is 3` tests object identity, which only works
    # by accident of integer caching and raises a SyntaxWarning on
    # Python 3.8+.
    box_size = (
      np.array([9.0, 4.0, 7.25], f32) if dim == 3 else
      np.array([9.0, 4.25], f32))
    cutoff = f32(1.23)

    displacement, _ = space.periodic(box_size)
    metric = space.metric(displacement)

    R = box_size * random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
    N = R.shape[0]
    neighbor_list_fn = partition.neighbor_list(
      displacement, box_size, cutoff, R)

    idx = neighbor_list_fn(R)
    R_neigh = R[idx]
    # Entries at or beyond N are treated as padding and masked out.
    mask = idx < N

    d = vmap(vmap(metric, (None, 0)))
    dR = d(R, R_neigh)

    d_exact = space.map_product(metric)
    dR_exact = d_exact(R, R)

    dR = np.where(dR < cutoff, dR, f32(0)) * mask
    # Remove the zero self-distances on the diagonal of the exact matrix.
    mask_exact = 1. - np.eye(dR_exact.shape[0])
    dR_exact = np.where(dR_exact < cutoff, dR_exact, f32(0)) * mask_exact

    dR = np.sort(dR, axis=1)
    dR_exact = np.sort(dR_exact, axis=1)

    for i in range(dR.shape[0]):
      dR_row = dR[i]
      dR_row = dR_row[dR_row > 0.]

      dR_exact_row = dR_exact[i]
      dR_exact_row = np.array(dR_exact_row[dR_exact_row > 0.], dtype)

      self.assertAllClose(dR_row, dR_exact_row, True)
Example #24
0
  def test_angular_symmetry_functions_neighbor_list(self,
                                                    N_types,
                                                    N_etas,
                                                    dtype,
                                                    dim):
    """Angular symmetry functions must agree with and without neighbors.

    Uses 128 random particles with random species in a periodic box and
    compares `nn.angular_symmetry_functions` with its neighbor-list
    counterpart, both built from the same parameters.
    """
    key = random.PRNGKey(0)

    N = 128
    box_size = 12.0
    r_cutoff = 3.

    displacement, _ = space.periodic(box_size)
    R_key, species_key = random.split(key)
    R = box_size * random.uniform(R_key, (N, dim))
    species = random.choice(species_key, N_types, (N,))

    neighbor_fn = partition.neighbor_list(
      displacement, box_size, r_cutoff, 0.)

    # Identical keyword arguments are used for both variants.
    shared_kwargs = dict(
      etas=np.linspace(1., 2., N_etas, dtype=dtype),
      lambdas=np.array([-1.0] * N_etas, dtype),
      zetas=np.array([1.0] * N_etas, dtype),
      cutoff_distance=r_cutoff)

    gr = nn.angular_symmetry_functions(
      displacement, species, **shared_kwargs)
    gr_neigh = nn.angular_symmetry_functions_neighbor_list(
      displacement, species, **shared_kwargs)

    nbrs = neighbor_fn(R)
    gr_exact = gr(R)
    gr_nbrs = gr_neigh(R, neighbor=nbrs)

    self.assertAllClose(gr_exact, gr_nbrs)
Example #25
0
def bks_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array,
    Q_sq: Array,
    exp_coeff: Array,
    exp_decay: Array,
    attractive_coeff: Array,
    repulsive_coeff: Array,
    coulomb_alpha: Array,
    cutoff: float,
    dr_threshold: float = 0.8
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
    """Build a neighbor list and a matching BKS energy function.

    `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff` and
    `repulsive_coeff` are converted to `f32` arrays and `dr_threshold` to an
    `f32` scalar. `species`, `coulomb_alpha` and `cutoff` are forwarded
    exactly as given.

    Returns:
      A tuple `(neighbor_fn, energy_fn)`. `neighbor_fn` comes from
      `partition.neighbor_list` using `cutoff` as the cutoff radius and
      `dr_threshold` as the halo; `energy_fn` is `bks` mapped over neighbor
      lists by `smap.pair_neighbor_list`.
    """

    def _as_f32_array(value):
        return np.array(value, f32)

    # Keyword order mirrors the order expected by the `bks` pair function
    # call; only the five coefficient arrays are cast.
    pair_parameters = dict(
        species=species,
        Q_sq=_as_f32_array(Q_sq),
        exp_coeff=_as_f32_array(exp_coeff),
        exp_decay=_as_f32_array(exp_decay),
        attractive_coeff=_as_f32_array(attractive_coeff),
        repulsive_coeff=_as_f32_array(repulsive_coeff),
        coulomb_alpha=coulomb_alpha,
        cutoff=cutoff)

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, cutoff, f32(dr_threshold))

    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    energy_fn = smap.pair_neighbor_list(bks, metric, **pair_parameters)

    return neighbor_fn, energy_fn
Example #26
0
def graph_network_neighbor_list(
    displacement_fn: DisplacementFn,
    box_size: Box,
    r_cutoff: float,
    dr_threshold: float,
    nodes: Array = None,
    n_recurrences: int = 2,
    mlp_sizes: Tuple[int, ...] = (64, 64),
    mlp_kwargs: Dict[str, Any] = None
) -> Tuple[NeighborFn, nn.InitFn, Callable[[PyTree, Array, NeighborList],
                                           Array]]:
    """Convenience wrapper around EnergyGraphNet model using neighbor lists.

    Args:
      displacement_fn: Function to compute displacement between two
        positions. Any extra keyword arguments given to the model at call
        time are bound onto it before it is mapped over neighbors.
      box_size: The size of the simulation volume, used to construct the
        neighbor list.
      r_cutoff: A floating point cutoff. Only neighbor entries whose squared
        separation is below `r_cutoff**2` keep their index as a graph edge;
        every other entry is replaced by the particle count `N`.
      dr_threshold: A floating point number specifying a "halo" radius that
        we use for neighbor list construction. See `neighbor_list` for
        details.
      nodes: None or an ndarray specifying the state of the nodes. If None,
        an `[N, 1]` array of zeros with the dtype of the positions is used.
        A `nodes` keyword passed to the model at call time takes precedence.
        Often, for a system with multiple species, this could be the species
        id.
      n_recurrences: The number of steps of message passing in the graph
        network.
      mlp_sizes: A tuple specifying the layer-widths for the fully-connected
        networks used to update the states in the graph network.
      mlp_kwargs: A dict specifying args for the fully-connected networks
        used to update the states in the graph network.

    Returns:
      A triple `(neighbor_fn, init_fn, apply_fn)`. `neighbor_fn` builds the
      neighbor list (constructed with `mask_self=False`).
      `params = init_fn(key, R, neighbor)` instantiates the model parameters
      and `E = apply_fn(params, R, neighbor)` evaluates the model for a
      particular state.
    """

    nodes = _canonicalize_node_state(nodes)

    @hk.without_apply_rng
    @hk.transform
    def model(R, neighbor, **kwargs):
        n_particles = R.shape[0]

        # Displacements from each particle to every entry in its neighbor row.
        neighbor_displacement = space.map_neighbor(
            partial(displacement_fn, **kwargs))
        dR = neighbor_displacement(R, R[neighbor.idx])

        # Node features: call-time override first, then the value captured at
        # construction time, then a zero placeholder.
        if 'nodes' not in kwargs:
            if nodes is None:
                node_state = np.zeros((n_particles, 1), R.dtype)
            else:
                node_state = nodes
        else:
            node_state = _canonicalize_node_state(kwargs['nodes'])

        global_state = np.zeros((1, ), R.dtype)

        # Entries at or beyond the cutoff are redirected to index N;
        # presumably the graph network treats that index as "no edge" --
        # confirm against EnergyGraphNet.
        within_cutoff = space.square_distance(dR) < r_cutoff**2
        edge_idx = np.where(within_cutoff, neighbor.idx, n_particles)

        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)  # pytype: disable=wrong-arg-count
        return EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)(graph)

    neighbor_fn = partition.neighbor_list(displacement_fn,
                                          box_size,
                                          r_cutoff,
                                          dr_threshold,
                                          mask_self=False)

    return neighbor_fn, model.init, model.apply
Example #27
0
def pair_correlation_neighbor_list(
        displacement_or_metric: Union[DisplacementFn, MetricFn],
        box_size: Box,
        radii: Array,
        sigma: float,
        species: Array = None,
        dr_threshold: float = 0.5,
        eps: float = 1e-7,
        fractional_coordinates: bool = False,
        format: partition.NeighborListFormat = partition.Dense):
    r"""Computes the pair correlation function at a mesh of distances.

    The pair correlation function measures the number of particles at a given
    distance from a central particle. It is defined by
    $g(r) = <\sum_{i \neq j} \delta(r - |r_i - r_j|)>$, and the delta
    function is approximated by a Gaussian of width `sigma`,
    $\delta(r) \approx \frac{1}{\sqrt{2\pi\sigma^2}} e^{-r^2 / (2\sigma^2)}$.

    Concretely, each pair at distance `dr` contributes
    `exp(-(dr - radii)**2 / (2 * sigma**2)) / (radii + eps)**(dim - 1)` for
    every requested radius; the constant Gaussian prefactor is not applied.

    This function uses neighbor lists to speed up the calculation.

    Args:
      displacement_or_metric: A function that computes the displacement or
        distance between two points.
      box_size: The size of the box containing the particles.
      radii: An array of radii at which we would like to compute g(r).
      sigma: A float specifying the width of the approximating Gaussian.
      species: An optional array specifying the species of each particle. If
        species is None then we compute a single g(r) for all particles,
        otherwise we compute one g(r) for each species.
      dr_threshold: A float specifying the halo size of the neighbor list.
      eps: A small additive constant used to ensure stability if the radius
        is zero.
      fractional_coordinates: Bool determining whether positions are stored
        in the unit cube or not. Forwarded to `partition.neighbor_list`.
      format: The format of the neighbor lists. Must be `Dense` or `Sparse`.

    Returns:
      A pair of functions: `neighbor_fn` that constructs a neighbor list (see
      `neighbor_list` in `partition.py` for details). `g_fn` that computes
      the pair correlation function for a collection of particles given
      their position and a neighbor list. With `species=None`, `g_fn`
      returns a single array; otherwise it returns a list with one array per
      unique species value, in the order given by `jnp.unique(species)`.

    Raises:
      TypeError: If `species` is given but is not an integer `jnp.ndarray`.
      NotImplementedError: Raised by `g_fn` when the neighbor list is neither
        `Dense` nor `Sparse`.
    """
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    inv_rad = 1 / (radii + eps)

    def pairwise(dr, dim):
        return jnp.exp(-f32(0.5) *
                       (dr - radii)**2 / sigma**2) * inv_rad**(dim - 1)

    # The cutoff reaches one Gaussian width past the largest radius requested.
    # `fractional_coordinates` has to reach the neighbor list as well,
    # otherwise the flag is silently ignored.
    neighbor_fn = partition.neighbor_list(
        displacement_or_metric,
        box_size,
        jnp.max(radii) + sigma,
        dr_threshold,
        fractional_coordinates=fractional_coordinates,
        format=format)

    def _pair_contributions(R, neighbor):
        """Evaluates per-pair weights once for the given neighbor format.

        Returns `(neighbor_idx, accumulate)`: the index of the neighboring
        particle for every pair, and a function that applies a pair mask and
        sums the weights onto the central particles.
        """
        N, dim = R.shape

        if neighbor.format is partition.Dense:
            dr = space.map_neighbor(metric)(R, R[neighbor.idx])
            weights = vmap(vmap(pairwise, (0, None)), (0, None))(dr, dim)

            def accumulate(mask):
                return jnp.sum(mask[:, :, None] * weights, axis=(1, ))

            return neighbor.idx, accumulate

        if neighbor.format is partition.Sparse:
            dr = space.map_bond(metric)(R[neighbor.idx[0]],
                                        R[neighbor.idx[1]])
            weights = vmap(pairwise, (0, None))(dr, dim)

            def accumulate(mask):
                return ops.segment_sum(mask[:, None] * weights,
                                       neighbor.idx[0], N)

            return neighbor.idx[1], accumulate

        raise NotImplementedError(
            'Pair correlation function does not support '
            'OrderedSparse neighbor lists.')

    if species is None:

        def g_fn(R, neighbor):
            mask = partition.neighbor_list_mask(neighbor)
            _, accumulate = _pair_contributions(R, neighbor)
            return accumulate(mask)

    else:
        if not (isinstance(species, jnp.ndarray) and is_integer(species)):
            raise TypeError('Malformed species; expecting array of integers.')
        species_types = jnp.unique(species)

        def g_fn(R, neighbor):
            mask = partition.neighbor_list_mask(neighbor)
            # The pair weights do not depend on the species, so they are
            # evaluated once and only the mask changes per species.
            neighbor_idx, accumulate = _pair_contributions(R, neighbor)
            neighbor_species = species[neighbor_idx]
            return [
                accumulate(mask * (neighbor_species == s))
                for s in species_types
            ]

    return neighbor_fn, g_fn
Example #28
0
def pair_correlation_neighbor_list(
        displacement_or_metric: Union[DisplacementFn, MetricFn],
        box_size: Box,
        radii: Array,
        sigma: float,
        species: Array = None,
        dr_threshold: float = 0.5):
    r"""Computes the pair correlation function at a mesh of distances.

    The pair correlation function measures the number of particles at a given
    distance from a central particle. It is defined by
    $g(r) = <\sum_{i \neq j} \delta(r - |r_i - r_j|)>$, and the delta
    function is approximated by a Gaussian of width `sigma`,
    $\delta(r) \approx \frac{1}{\sqrt{2\pi\sigma^2}} e^{-r^2 / (2\sigma^2)}$.

    Concretely, each pair at distance `dr` contributes
    `exp(-(dr - radii)**2 / (2 * sigma**2)) / radii**(dim - 1)` for every
    requested radius; the constant Gaussian prefactor is not applied. Note
    that the division is not regularized, so a radius of exactly zero
    divides by zero whenever `dim > 1`.

    This function uses neighbor lists to speed up the calculation.

    Args:
      displacement_or_metric: A function that computes the displacement or
        distance between two points.
      box_size: The size of the box containing the particles.
      radii: An array of radii at which we would like to compute g(r).
      sigma: A float specifying the width of the approximating Gaussian.
      species: An optional array specifying the species of each particle. If
        species is None then we compute a single g(r) for all particles,
        otherwise we compute one g(r) for each species.
      dr_threshold: A float specifying the halo size of the neighbor list.

    Returns:
      A pair of functions: `neighbor_fn` that constructs a neighbor list (see
      `neighbor_list` in `partition.py` for details). `g_fn` that computes
      the pair correlation function for a collection of particles given
      their position and a neighbor list. With `species=None`, `g_fn`
      returns a single array; otherwise it returns a list with one array per
      unique species value, in the order given by `jnp.unique(species)`.

    Raises:
      TypeError: If `species` is given but is not an integer `jnp.ndarray`.
    """
    neighbor_metric = space.map_neighbor(
        space.canonicalize_displacement_or_metric(displacement_or_metric))

    def pairwise(dr, dim):
        return jnp.exp(-f32(0.5) *
                       (dr - radii)**2 / sigma**2) / radii**(dim - 1)

    # Map the scalar kernel over both axes of the neighbor distance array.
    pairwise_over_neighbors = vmap(vmap(pairwise, (0, None)), (0, None))

    # The cutoff reaches one Gaussian width past the largest radius requested.
    neighbor_fn = partition.neighbor_list(displacement_or_metric, box_size,
                                          jnp.max(radii) + sigma, dr_threshold)

    def _masked_pair_weights(R, neighbor):
        """Returns the valid-entry mask and the per-pair weights."""
        # Entries whose index is not a valid particle index are masked out.
        mask = neighbor.idx < R.shape[0]
        weights = pairwise_over_neighbors(
            neighbor_metric(R, R[neighbor.idx]), R.shape[-1])
        return mask, weights

    if species is None:

        def g_fn(R, neighbor):
            mask, weights = _masked_pair_weights(R, neighbor)
            return jnp.sum(mask[:, :, jnp.newaxis] * weights, axis=(1, ))
    else:
        if not (isinstance(species, jnp.ndarray) and is_integer(species)):
            raise TypeError('Malformed species; expecting array of integers.')
        species_types = jnp.unique(species)

        def g_fn(R, neighbor):
            # The pair weights do not depend on the species, so they are
            # evaluated once and only the mask changes per species.
            mask, weights = _masked_pair_weights(R, neighbor)
            neighbor_species = species[neighbor.idx]
            return [
                jnp.sum((mask * (neighbor_species == s))[:, :, jnp.newaxis] *
                        weights,
                        axis=(1, ))
                for s in species_types
            ]

    return neighbor_fn, g_fn