def morse_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 5.0,
    alpha: Array = 5.0,
    r_onset: float = 2.0,
    r_cutoff: float = 2.5,
    dr_threshold: float = 0.5,
    per_particle: bool = False
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Convenience wrapper to compute Morse using a neighbor list.

  Every numeric parameter is first passed through `maybe_downcast`. The
  neighbor list uses `r_cutoff` as its radius and `dr_threshold` as its halo;
  neither is rescaled by `sigma`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair. `energy_fn` evaluates `morse` smoothed
    by `multiplicative_isotropic_cutoff` between `r_onset` and `r_cutoff`.
    When `per_particle` is True the energy is reduced over axis 1 only;
    otherwise it is fully reduced.
  """
  sigma, epsilon, alpha, r_onset, r_cutoff, dr_threshold = (
      maybe_downcast(value)
      for value in (sigma, epsilon, alpha, r_onset, r_cutoff, dr_threshold))

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, r_cutoff, dr_threshold)

  smoothed_morse = multiplicative_isotropic_cutoff(morse, r_onset, r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  reduce_axis = (1, ) if per_particle else None

  energy_fn = smap.pair_neighbor_list(
      smoothed_morse,
      metric,
      ignore_unused_parameters=True,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=reduce_axis)
  return neighbor_fn, energy_fn
def test_cell_list_overflow(self):
  """An update that crowds particles together must flag a buffer overflow."""
  displacement_fn, _ = space.free()

  neighbor_list_fn = partition.neighbor_list(
      displacement_fn,
      box_size=100.0,
      r_cutoff=3.0,
      dr_threshold=0.0,
  )

  # Allocate with four particles that are all far from each other, so the
  # buffers are sized for a configuration without close pairs.
  spread_out = jnp.array([
      [20.0, 20.0],
      [30.0, 30.0],
      [40.0, 40.0],
      [50.0, 50.0],
  ])
  neighbors = neighbor_list_fn.allocate(spread_out)
  self.assertEqual(neighbors.idx.dtype, jnp.int32)

  # Move the first two points onto each other. The update must report an
  # overflow and keep the int32 index dtype.
  collided = jnp.array([
      [20.0, 20.0],
      [20.0, 20.0],
      [40.0, 40.0],
      [50.0, 50.0],
  ])
  neighbors = neighbors.update(collided)
  self.assertTrue(neighbors.did_buffer_overflow)
  self.assertEqual(neighbors.idx.dtype, jnp.int32)
def bks_neighbor_list(displacement_or_metric,
                      box_size,
                      species,
                      Q_sq,
                      exp_coeff,
                      exp_decay,
                      attractive_coeff,
                      repulsive_coeff,
                      coulomb_alpha,
                      cutoff,
                      dr_threshold=0.8):
  """Convenience wrapper to compute BKS energy using a neighbor list.

  The five BKS coefficient arguments are converted to float32 arrays and
  `dr_threshold` to a float32 scalar. `coulomb_alpha` and `cutoff` are
  forwarded unchanged. `cutoff` is used both as the neighbor-list radius and
  as a parameter of `bks`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  (Q_sq, exp_coeff, exp_decay, attractive_coeff,
   repulsive_coeff) = [
       np.array(coeff, f32)
       for coeff in (Q_sq, exp_coeff, exp_decay, attractive_coeff,
                     repulsive_coeff)
   ]
  dr_threshold = f32(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, dr_threshold)

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      bks,
      metric,
      species=species,
      Q_sq=Q_sq,
      exp_coeff=exp_coeff,
      exp_decay=exp_decay,
      attractive_coeff=attractive_coeff,
      repulsive_coeff=repulsive_coeff,
      coulomb_alpha=coulomb_alpha,
      cutoff=cutoff)
  return neighbor_fn, energy_fn
def morse_neighbor_list(displacement_or_metric,
                        box_size,
                        species=None,
                        sigma=1.0,
                        epsilon=5.0,
                        alpha=5.0,
                        r_onset=2.0,
                        r_cutoff=2.5,
                        dr_threshold=0.5,
                        per_particle=False):
  # TODO(cpgoodri) Optimize this.
  """Convenience wrapper to compute Morse using a neighbor list.

  Every numeric parameter is converted to a float32 array. The neighbor list
  uses `r_cutoff` as its radius and `dr_threshold` as its halo; neither is
  rescaled by `sigma`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair. `energy_fn` evaluates `morse` smoothed
    by `multiplicative_isotropic_cutoff` between `r_onset` and `r_cutoff`.
    With `per_particle=True` the energy is reduced over axis 1 only.
  """
  def to_f32(value):
    return np.array(value, f32)

  sigma = to_f32(sigma)
  epsilon = to_f32(epsilon)
  alpha = to_f32(alpha)
  r_onset = to_f32(r_onset)
  r_cutoff = to_f32(r_cutoff)
  dr_threshold = to_f32(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, r_cutoff, dr_threshold)

  if per_particle:
    reduce_axis = (1, )
  else:
    reduce_axis = None

  energy_fn = smap.pair_neighbor_list(
      multiplicative_isotropic_cutoff(morse, r_onset, r_cutoff),
      space.canonicalize_displacement_or_metric(displacement_or_metric),
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=reduce_axis)
  return neighbor_fn, energy_fn
def test_pair_neighbor_list_vector(self, spatial_dimension, dtype):
  """A vector-valued pair function agrees between pair and neighbor maps."""
  key = random.PRNGKey(0)

  def truncated_square(dR, sigma):
    # Append a trailing unit axis so the scalar distance broadcasts against
    # the displacement vector.
    dr = space.distance(dR)[..., None]
    return np.where(dr < sigma, dR**2, f32(0.))

  N = PARTICLE_COUNT
  box_size = 2. * N**(1. / spatial_dimension)
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)

  per_particle = (1, )
  neighbor_square = jit(
      smap.pair_neighbor_list(truncated_square, disp, reduce_axis=per_particle))
  mapped_square = jit(
      smap.pair(truncated_square, disp, reduce_axis=per_particle))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
    # `sigma` is drawn from `key`, which is split again on the next pass.
    sigma = random.uniform(key, (), minval=0.5, maxval=1.5)
    nbrs = partition.neighbor_list(disp, box_size, sigma, 0.)(R)
    expected = mapped_square(R, sigma=sigma)
    actual = neighbor_square(R, nbrs, sigma=sigma)
    self.assertAllClose(expected, actual)
def soft_sphere_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 1.0,
    alpha: Array = 2.0,
    dr_threshold: float = 0.2,
    per_particle: bool = False
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Convenience wrapper to compute soft spheres using a neighbor list.

  The neighbor-list radius is `max(sigma)`. The halo is `dr_threshold`
  expressed as a fraction of that radius. `sigma`, `epsilon`, `alpha` and
  `dr_threshold` are passed through `maybe_downcast`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair. With `per_particle=True` the energy is
    reduced over axis 1 only.
  """
  sigma, epsilon, alpha = (
      maybe_downcast(param) for param in (sigma, epsilon, alpha))

  list_cutoff = np.max(sigma)
  halo = list_cutoff * maybe_downcast(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, list_cutoff, halo)
  energy_fn = smap.pair_neighbor_list(
      soft_sphere,
      space.canonicalize_displacement_or_metric(displacement_or_metric),
      ignore_unused_parameters=True,
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      alpha=alpha,
      reduce_axis=(1, ) if per_particle else None)
  return neighbor_fn, energy_fn
def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype, format):
  """A scalar pair function agrees between pair and neighbor-list maps."""
  key = random.PRNGKey(0)

  def truncated_square(dr, sigma):
    # dr**2 where dr < sigma, and zero elsewhere.
    return np.where(dr < sigma, dr ** 2, f32(0.))

  N = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 4. * N ** (1. / spatial_dimension)
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  neighbor_square = jit(
      smap.pair_neighbor_list(truncated_square, metric, sigma=1.0))
  mapped_square = jit(smap.pair(truncated_square, metric, sigma=1.0))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
    sigma = random.uniform(key, (), minval=0.5, maxval=2.5)
    neighbor_fn = partition.neighbor_list(
        disp, box_size, sigma, 0.0, format=format)
    nbrs = neighbor_fn.allocate(R)
    self.assertAllClose(mapped_square(R, sigma=sigma),
                        neighbor_square(R, nbrs, sigma=sigma))
def lennard_jones_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array = None,
    sigma: Array = 1.0,
    epsilon: Array = 1.0,
    alpha: Array = 2.0,
    r_onset: float = 2.0,
    r_cutoff: float = 2.5,
    dr_threshold: float = 0.5,
    per_particle: bool = False
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Convenience wrapper to compute lennard-jones using a neighbor list.

  `r_onset`, `r_cutoff` and `dr_threshold` are given in units of `max(sigma)`
  and are converted to absolute float32 lengths here. `alpha` is accepted but
  not used by this wrapper.

  Returns:
    A `(neighbor_fn, energy_fn)` pair. `energy_fn` evaluates `lennard_jones`
    smoothed by `multiplicative_isotropic_cutoff`. With `per_particle=True`
    the energy is reduced over axis 1 only.
  """
  sigma = np.array(sigma, f32)
  epsilon = np.array(epsilon, f32)

  length_unit = np.max(sigma)
  r_onset = np.array(r_onset * length_unit, f32)
  r_cutoff = np.array(r_cutoff * length_unit, f32)
  dr_threshold = np.array(length_unit * dr_threshold, f32)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, r_cutoff, dr_threshold)

  pair_energy = multiplicative_isotropic_cutoff(lennard_jones, r_onset,
                                                r_cutoff)
  energy_fn = smap.pair_neighbor_list(
      pair_energy,
      space.canonicalize_displacement_or_metric(displacement_or_metric),
      species=species,
      sigma=sigma,
      epsilon=epsilon,
      reduce_axis=(1, ) if per_particle else None)
  return neighbor_fn, energy_fn
def soft_sphere_neighbor_list(displacement_or_metric,
                              box_size,
                              species=None,
                              sigma=1.0,
                              epsilon=1.0,
                              alpha=2.0,
                              dr_threshold=0.2):
  """Convenience wrapper to compute soft spheres using a neighbor list.

  The neighbor-list radius is the float32 value of `max(sigma)`. The halo is
  `dr_threshold` times that radius.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  sigma, epsilon, alpha = (
      np.array(param, dtype=f32) for param in (sigma, epsilon, alpha))

  list_cutoff = f32(np.max(sigma))
  halo = f32(list_cutoff * dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, list_cutoff, halo)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      soft_sphere, metric, species=species, sigma=sigma, epsilon=epsilon,
      alpha=alpha)
  return neighbor_fn, energy_fn
def test_radial_symmetry_functions_neighbor_list(self, N_types, N_etas, dtype,
                                                 dim):
  """Radial symmetry functions match between exact and neighbor versions."""
  key = random.PRNGKey(0)
  N = 128
  box_size = 12.0
  r_cutoff = 3.

  displacement, _ = space.periodic(box_size)
  R_key, species_key = random.split(key)
  R = box_size * random.uniform(R_key, (N, dim))
  species = random.choice(species_key, N_types, (N, ))

  etas = np.linspace(1.0, 2.0, N_etas, dtype=dtype)
  gr = nn.radial_symmetry_functions(displacement, species, etas, r_cutoff)
  gr_neigh = nn.radial_symmetry_functions_neighbor_list(
      displacement, species, etas, r_cutoff)

  neighbor_fn = partition.neighbor_list(displacement, box_size, r_cutoff, 0.)
  nbrs = neighbor_fn(R)

  tol = 1e-13 if FLAGS.jax_enable_x64 else 1e-6
  self.assertAllClose(gr(R), gr_neigh(R, neighbor=nbrs), atol=tol, rtol=tol)
def test_pair_neighbor_list_force_scalar_diverging_potential(
    self, spatial_dimension, dtype):
  """Forces from a potential diverging at r=0 agree with and without lists."""
  key = random.PRNGKey(0)

  def potential(dr, sigma):
    # dr**-6 inside sigma, zero outside.
    return np.where(dr < sigma, dr**-6, f32(0.))

  N = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 4. * N**(1. / spatial_dimension)
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  neighbor_force = jit(
      quantity.force(smap.pair_neighbor_list(potential, metric)))
  mapped_force = jit(quantity.force(smap.pair(potential, metric)))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
    sigma = random.uniform(key, (), minval=0.5, maxval=4.5)
    nbrs = partition.neighbor_list(disp, box_size, sigma, 0.0)(R)
    self.assertAllClose(mapped_force(R, sigma=sigma),
                        neighbor_force(R, nbrs, sigma=sigma))
def lennard_jones_neighbor_list(displacement_or_metric,
                                box_size,
                                species=None,
                                sigma=1.0,
                                epsilon=1.0,
                                alpha=2.0,
                                r_onset=2.0,
                                r_cutoff=2.5,
                                dr_threshold=0.5):
  # TODO(schsam) Optimize this.
  """Convenience wrapper to compute lennard-jones using a neighbor list.

  `r_onset`, `r_cutoff` and `dr_threshold` are given in units of `max(sigma)`
  and converted to absolute float32 lengths. `alpha` is accepted but unused.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  sigma = np.array(sigma, f32)
  epsilon = np.array(epsilon, f32)

  length_unit = np.max(sigma)
  r_onset = np.array(r_onset * length_unit, f32)
  r_cutoff = np.array(r_cutoff * length_unit, f32)
  dr_threshold = np.array(length_unit * dr_threshold, f32)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, r_cutoff, dr_threshold)
  smoothed_lj = multiplicative_isotropic_cutoff(lennard_jones, r_onset,
                                                r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = smap.pair_neighbor_list(
      smoothed_lj, metric, species=species, sigma=sigma, epsilon=epsilon)
  return neighbor_fn, energy_fn
def test_pair_neighbor_list_scalar_nonadditive(
    self, spatial_dimension, dtype, format):
  """A per-particle sigma mixed multiplicatively matches an explicit matrix."""
  key = random.PRNGKey(0)

  def truncated_square(dR, sigma):
    dr = space.distance(dR)
    return np.where(dr < sigma, dr ** 2, f32(0.))

  def multiplicative_mixing(sigma_i, sigma_j):
    return sigma_i * sigma_j

  N = PARTICLE_COUNT
  box_size = 2. * N ** (1. / spatial_dimension)
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)

  neighbor_square = jit(smap.pair_neighbor_list(
      truncated_square, disp, sigma=multiplicative_mixing))
  mapped_square = jit(smap.pair(truncated_square, disp, sigma=1.0))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
    sigma = random.uniform(key, (N,), minval=0.5, maxval=1.5)
    sigma_pair = sigma[:, None] * sigma[None, :]
    # The largest pair sigma is max(sigma)**2, so that is the list radius.
    neighbor_fn = partition.neighbor_list(
        disp, box_size, np.max(sigma) ** 2, 0., format=format)
    nbrs = neighbor_fn.allocate(R)
    self.assertAllClose(mapped_square(R, sigma=sigma_pair),
                        neighbor_square(R, nbrs, sigma=sigma))
def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension,
                                                 dtype):
  """A symmetric (N, N) sigma matrix agrees with and without neighbor lists."""
  key = random.PRNGKey(0)

  def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr**2, f32(0.))

  tol = 2e-10 if dtype == np.float32 else None

  N = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 2. * N**(1. / spatial_dimension)
  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  neighbor_square = jit(smap.pair_neighbor_list(truncated_square, metric))
  mapped_square = jit(smap.pair(truncated_square, metric))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
    raw_sigma = random.uniform(key, (N, N), minval=0.5, maxval=1.5)
    sigma = 0.5 * (raw_sigma + raw_sigma.T)  # symmetrize
    # NOTE(review): `R` is passed where newer call sites pass `dr_threshold`;
    # presumably this targets an older `neighbor_list` signature.
    neighbor_fn = jit(
        partition.neighbor_list(disp, box_size, np.max(sigma), R))
    idx = neighbor_fn(R)
    self.assertAllClose(mapped_square(R, sigma=sigma),
                        neighbor_square(R, idx, sigma=sigma),
                        True, tol, tol)
def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension,
                                                  dtype):
  """A (3, 3) per-species sigma agrees with and without neighbor lists."""
  key = random.PRNGKey(0)

  def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr**2, f32(0.))

  N = NEIGHBOR_LIST_PARTICLE_COUNT
  box_size = 2. * N**(1. / spatial_dimension)

  # Split the particles into three roughly equal species: 0, 1, 2.
  index = np.arange(N)
  species = np.where(
      index > 2 * N / 3, 2,
      np.where(index > N / 3, 1, np.zeros((N, ), np.int32)))

  key, _ = random.split(key)
  disp, _ = space.periodic(box_size)
  metric = space.metric(disp)

  neighbor_square = jit(
      smap.pair_neighbor_list(truncated_square, metric, species=species))
  mapped_square = jit(smap.pair(truncated_square, metric, species=species))

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
    raw_sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
    sigma = 0.5 * (raw_sigma + raw_sigma.T)
    nbrs = partition.neighbor_list(disp, box_size, np.max(sigma), 0.)(R)
    self.assertAllClose(mapped_square(R, sigma=sigma),
                        neighbor_square(R, nbrs, sigma=sigma))
def test_custom_mask_function(self):
  """A custom id-based mask removes the disallowed pairs from the list."""
  displacement_fn, _ = space.free()
  n_particles = 10
  # Every particle sits at the origin, so every pair is within the cutoff.
  R = jnp.broadcast_to(jnp.zeros(3), (n_particles, 3))

  def acceptable_id_pair(id1, id2):
    """Only allow interactions between ids that differ by more than 3.

    This disables, e.g., 1-2 and 1-3 interactions.
    """
    return jnp.abs(id1 - id2) > 3

  def mask_id_based(idx: Array, ids: Array, mask_val: int,
                    _acceptable_id_pair: Callable) -> Array:
    """Replace neighbor entries whose id pair is not acceptable by mask_val.

    Row `i` of `idx` belongs to particle `i`; each value in that row is the
    index of the other particle of the pair.
    """
    def row_mask(row_idx, id1, all_ids):
      neighbor_ids = all_ids.at[row_idx].get()
      return vmap(_acceptable_id_pair, in_axes=(None, 0))(id1, neighbor_ids)

    mask = vmap(row_mask, in_axes=(0, 0, None))(idx, ids, ids)
    return jnp.where(mask, idx, mask_val)

  ids = jnp.arange(n_particles)  # The id is just the particle index here.
  mask_val = n_particles
  custom_mask_function = partial(mask_id_based,
                                 ids=ids,
                                 mask_val=mask_val,
                                 _acceptable_id_pair=acceptable_id_pair)

  neighbor_list_fn = partition.neighbor_list(
      displacement_fn,
      box_size=1.0,
      r_cutoff=3.0,
      dr_threshold=0.0,
      custom_mask_function=custom_mask_function,
  )
  neighbors = neighbor_list_fn.allocate(R)
  neighbors = neighbors.update(R)

  # Without the custom mask each particle has 9 neighbors (self is masked),
  # giving 90 entries. Only pairs with |id1 - id2| > 3 survive the mask,
  # which leaves 42.
  self.assertEqual(42, (neighbors.idx != mask_val).sum())
def test_neighbor_list_build_time_dependent(self, dtype, dim):
  """A neighbor list built with a time-dependent box matches brute force.

  The box returned by `box_fn(t)` has the time `t` in an off-diagonal entry.
  The list is built at t=0.25, and its distances are compared against every
  pairwise distance evaluated at the same time.
  """
  key = random.PRNGKey(1)

  # Both branches build the box in float32 so the 2D and 3D cases agree
  # (previously only the 2D box specified a dtype).
  if dim == 2:
    box_fn = lambda t: np.array([[9.0, t], [0.0, 3.75]], f32)
  elif dim == 3:
    box_fn = lambda t: np.array(
        [[9.0, 0.0, t], [0.0, 4.0, 0.0], [0.0, 0.0, 7.25]], f32)
  else:
    raise ValueError(f'Unsupported spatial dimension: {dim}')

  min_length = np.min(np.diag(box_fn(0.)))
  cutoff = f32(1.23)
  # TODO(schsam): Get cell-list working with anisotropic cell sizes.
  cell_size = cutoff / min_length

  displacement, _ = space.periodic_general(box_fn)
  metric = space.metric(displacement)

  # Positions are drawn in the unit cube and the list is given box_size=1.;
  # presumably these are fractional coordinates — TODO confirm.
  R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
  N = R.shape[0]

  neighbor_list_fn = partition.neighbor_list(
      metric, 1., cutoff, 0.0, 1.1, cell_size=cell_size, t=np.array(0.))
  idx = neighbor_list_fn(R, t=np.array(0.25)).idx
  R_neigh = R[idx]
  mask = idx < N  # Padding entries hold the index N.

  metric = partial(metric, t=f32(0.25))
  d = vmap(vmap(metric, (None, 0)))
  dR = d(R, R_neigh)

  d_exact = space.map_product(metric)
  dR_exact = d_exact(R, R)

  dR = np.where(dR < cutoff, dR, 0) * mask
  dR_exact = np.where(dR_exact < cutoff, dR_exact, 0)

  dR = np.sort(dR, axis=1)
  dR_exact = np.sort(dR_exact, axis=1)

  # Zero distances (self pairs, masked or out-of-range entries) are dropped
  # before each row is compared.
  for dR_row, dR_exact_row in zip(dR, dR_exact):
    dR_row = dR_row[dR_row > 0.]
    dR_exact_row = dR_exact_row[dR_exact_row > 0.]
    self.assertAllClose(dR_row, dR_exact_row)
def behler_parrinello_neighbor_list(
    displacement: DisplacementFn,
    box_size: float,
    species: Array=None,
    mlp_sizes: Tuple[int, ...]=(30, 30),
    mlp_kwargs: Dict[str, Any]=None,
    sym_kwargs: Dict[str, Any]=None,
    dr_threshold: float=0.5
) -> Tuple[NeighborFn,
           nn.InitFn,
           Callable[[PyTree, Array, NeighborList], Array]]:
  """Convenience wrapper around a Behler-Parrinello model using neighbor lists.

  Args:
    displacement: Displacement function used by the neighbor list and the
      symmetry functions.
    box_size: Size of the simulation volume, used to build the neighbor list.
    species: Optional species array forwarded to the symmetry functions.
    mlp_sizes: Hidden-layer widths of the readout MLP; a final layer of width
      1 is appended.
    mlp_kwargs: Extra keyword arguments for `hk.nets.MLP`. Defaults to
      `{'activation': np.tanh}`.
    sym_kwargs: Keyword arguments forwarded to
      `nn.behler_parrinello_symmetry_functions_neighbor_list`. Its
      `cutoff_distance` entry (8.0 if absent) is also the neighbor-list
      radius.
    dr_threshold: Halo size of the neighbor list.

  Returns:
    `(neighbor_fn, init_fn, apply_fn)`. The model maps `(R, neighbor,
    **kwargs)` to the sum of the MLP readouts over the leading axis of the
    symmetry-function output.
  """
  sym_kwargs = {} if sym_kwargs is None else sym_kwargs
  if mlp_kwargs is None:
    mlp_kwargs = {'activation': np.tanh}

  cutoff_distance = sym_kwargs.get('cutoff_distance', 8.0)

  neighbor_fn = partition.neighbor_list(
      displacement, box_size, cutoff_distance, dr_threshold)
  sym_fn = nn.behler_parrinello_symmetry_functions_neighbor_list(
      displacement, species, **sym_kwargs)

  @hk.without_apply_rng
  @hk.transform
  def model(R, neighbor, **kwargs):
    readout_mlp = hk.nets.MLP(output_sizes=mlp_sizes + (1,),
                              activate_final=False,
                              name='BPEncoder',
                              **mlp_kwargs)
    # vmap applies the same MLP independently along the leading axis.
    readout = vmap(readout_mlp)(sym_fn(R, neighbor, **kwargs))
    return np.sum(readout)

  return neighbor_fn, model.init, model.apply
def test_swap_mc_jammed(self, dtype):
  """Hybrid swap MC on a jammed state thermalizes and swaps radii."""
  key = random.PRNGKey(0)
  state = test_util.load_jammed_state('simulation_test_state.npy', dtype)

  box_length = state.box[0, 0]
  space_fn = space.periodic(box_length)
  displacement_fn, _ = space_fn

  sigma = np.diag(state.sigma)[state.species]

  def energy_fn(dr, sigma):
    return energy.soft_sphere(dr, sigma=sigma)

  neighbor_fn = partition.neighbor_list(
      displacement_fn, box_length, np.max(sigma) + 0.1, dr_threshold=0.5)

  kT = 1e-2
  t_md = 0.1
  N_swap = 10
  init_fn, apply_fn = simulate.hybrid_swap_mc(
      space_fn, energy_fn, neighbor_fn, 1e-3, kT, t_md, N_swap)
  state = init_fn(key, state.real_position, sigma)

  def step_fn(i, carry):
    mc_state, temperatures = carry
    mc_state = apply_fn(mc_state)
    temperatures = temperatures.at[i].set(
        quantity.temperature(mc_state.md.velocity))
    return mc_state, temperatures

  state, Ts = lax.fori_loop(
      0, DYNAMICS_STEPS, step_fn, (state, np.zeros((DYNAMICS_STEPS,))))

  # The first 10 recorded temperatures are excluded from the checks.
  later_Ts = Ts[10:]
  self.assertAllClose(later_Ts, kT * np.ones(DYNAMICS_STEPS - 10),
                      rtol=5e-1, atol=5e-3)
  tol = 5e-4
  self.assertAllClose(np.mean(later_Ts), kT, rtol=tol, atol=tol)
  # At least one particle must have had its radius swapped.
  self.assertFalse(np.all(state.sigma == sigma))
def bks_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array,
    Q_sq: Array,
    exp_coeff: Array,
    exp_decay: Array,
    attractive_coeff: Array,
    repulsive_coeff: Array,
    coulomb_alpha: Array,
    cutoff: float,
    dr_threshold: float = 0.8,
    fractional_coordinates: bool = False,
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Convenience wrapper to compute BKS energy using a neighbor list.

  The five BKS coefficient arguments and `dr_threshold` are passed through
  `maybe_downcast`; `coulomb_alpha` and `cutoff` are forwarded unchanged.
  `cutoff` is both the neighbor-list radius and a `bks` parameter.
  `fractional_coordinates` is forwarded to `partition.neighbor_list`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  bks_params = {
      'Q_sq': maybe_downcast(Q_sq),
      'exp_coeff': maybe_downcast(exp_coeff),
      'exp_decay': maybe_downcast(exp_decay),
      'attractive_coeff': maybe_downcast(attractive_coeff),
      'repulsive_coeff': maybe_downcast(repulsive_coeff),
      'coulomb_alpha': coulomb_alpha,
      'cutoff': cutoff,
  }

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric,
      box_size,
      cutoff,
      maybe_downcast(dr_threshold),
      fractional_coordinates=fractional_coordinates)
  energy_fn = smap.pair_neighbor_list(
      bks,
      space.canonicalize_displacement_or_metric(displacement_or_metric),
      species=species,
      ignore_unused_parameters=True,
      **bks_params)
  return neighbor_fn, energy_fn
def test_behler_parrinello_symmetry_functions_neighbor_list(self, N_types,
                                                            N_etas, dtype):
  """Checks the output shape and one value of BP symmetry functions.

  Shapes are integer tuples, so they are compared exactly with `assertEqual`
  rather than approximately with `assertAllClose`.
  """
  displacement, _ = space.free()
  neighbor_fn = partition.neighbor_list(displacement, 10.0, 8.0, 0.0)

  # NOTE(review): 0.529177 is presumably the Bohr radius in Angstrom.
  eta = 1e-4 / (0.529177 ** 2)
  gr = nn.behler_parrinello_symmetry_functions_neighbor_list(
      displacement,
      np.array([1, 1, N_types]),
      radial_etas=np.array([eta] * N_etas, dtype),
      angular_etas=np.array([eta] * N_etas, dtype),
      lambdas=np.array([-1.0] * N_etas, dtype),
      zetas=np.array([1.0] * N_etas, dtype),
      cutoff_distance=8.0)

  R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
  nbrs = neighbor_fn(R)
  gr_out = gr(R, neighbor=nbrs)

  # Presumably one radial block per species plus one angular block per
  # unordered species pair, each of width N_etas.
  expected_features = N_etas * (N_types + N_types * (N_types + 1) // 2)
  self.assertEqual(tuple(gr_out.shape), (3, expected_features))
  self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)
def test_neighbor_list_build_sparse(self, dtype, dim):
  """A Sparse neighbor list yields the same distances as brute force."""
  key = random.PRNGKey(1)

  if dim == 3:
    box_size = np.array([9.0, 4.0, 7.25], f32)
  else:
    box_size = np.array([9.0, 4.25], f32)
  cutoff = f32(1.23)

  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  R = box_size * random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
  N = R.shape[0]

  neighbor_fn = partition.neighbor_list(
      displacement, box_size, cutoff, 0.0, 1.1, format=partition.Sparse)
  nbrs = neighbor_fn.allocate(R)
  mask = partition.neighbor_list_mask(nbrs)

  # Distance along every sparse edge, zeroed beyond the cutoff or if masked.
  i_idx, j_idx = nbrs.idx[0], nbrs.idx[1]
  dR = space.map_bond(metric)(R[i_idx], R[j_idx])
  dR = np.where(dR < cutoff, dR, f32(0)) * mask

  # Brute-force reference with the diagonal (self pairs) removed.
  dR_exact = space.map_product(metric)(R, R)
  off_diagonal = 1. - np.eye(dR_exact.shape[0])
  dR_exact = np.where(dR_exact < cutoff, dR_exact, f32(0)) * off_diagonal
  dR_exact = np.sort(dR_exact, axis=1)

  for i in range(N):
    row = dR[i_idx == i]
    row = np.sort(row[row > 0.])
    exact_row = dR_exact[i]
    exact_row = np.array(exact_row[exact_row > 0.], dtype)
    self.assertAllClose(row, exact_row)
def test_neighbor_list_build(self, dtype, dim):
  """A dense neighbor list yields the same distances as brute force."""
  key = random.PRNGKey(1)

  # Use `==`, not `is`. Identity comparison with an int literal raises a
  # SyntaxWarning (Python 3.8+) and is False for e.g. numpy integer dims,
  # which would silently select the 2D box for a 3D test.
  box_size = (
      np.array([9.0, 4.0, 7.25], f32) if dim == 3 else
      np.array([9.0, 4.25], f32))
  cutoff = f32(1.23)

  displacement, _ = space.periodic(box_size)
  metric = space.metric(displacement)

  R = box_size * random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
  N = R.shape[0]

  # NOTE(review): `R` occupies the slot newer call sites use for
  # `dr_threshold`, and calling the result returns indices directly;
  # presumably this targets an older `neighbor_list` API.
  neighbor_list_fn = partition.neighbor_list(
      displacement, box_size, cutoff, R)
  idx = neighbor_list_fn(R)
  R_neigh = R[idx]
  mask = idx < N

  d = vmap(vmap(metric, (None, 0)))
  dR = d(R, R_neigh)

  d_exact = space.map_product(metric)
  dR_exact = d_exact(R, R)

  dR = np.where(dR < cutoff, dR, f32(0)) * mask
  mask_exact = 1. - np.eye(dR_exact.shape[0])
  dR_exact = np.where(dR_exact < cutoff, dR_exact, f32(0)) * mask_exact

  dR = np.sort(dR, axis=1)
  dR_exact = np.sort(dR_exact, axis=1)

  for i in range(dR.shape[0]):
    dR_row = dR[i]
    dR_row = dR_row[dR_row > 0.]

    dR_exact_row = dR_exact[i]
    dR_exact_row = np.array(dR_exact_row[dR_exact_row > 0.], dtype)

    self.assertAllClose(dR_row, dR_exact_row, True)
def test_angular_symmetry_functions_neighbor_list(self, N_types, N_etas, dtype,
                                                  dim):
  """Angular symmetry functions match between exact and neighbor versions."""
  key = random.PRNGKey(0)
  N = 128
  box_size = 12.0
  r_cutoff = 3.

  displacement, _ = space.periodic(box_size)
  R_key, species_key = random.split(key)
  R = box_size * random.uniform(R_key, (N, dim))
  species = random.choice(species_key, N_types, (N,))

  # Both variants receive identical symmetry-function hyperparameters.
  shared_kwargs = dict(
      etas=np.linspace(1., 2., N_etas, dtype=dtype),
      lambdas=np.array([-1.0] * N_etas, dtype),
      zetas=np.array([1.0] * N_etas, dtype),
      cutoff_distance=r_cutoff)
  gr = nn.angular_symmetry_functions(displacement, species, **shared_kwargs)
  gr_neigh = nn.angular_symmetry_functions_neighbor_list(
      displacement, species, **shared_kwargs)

  neighbor_fn = partition.neighbor_list(displacement, box_size, r_cutoff, 0.)
  nbrs = neighbor_fn(R)

  self.assertAllClose(gr(R), gr_neigh(R, neighbor=nbrs))
def bks_neighbor_list(
    displacement_or_metric: DisplacementOrMetricFn,
    box_size: Box,
    species: Array,
    Q_sq: Array,
    exp_coeff: Array,
    exp_decay: Array,
    attractive_coeff: Array,
    repulsive_coeff: Array,
    coulomb_alpha: Array,
    cutoff: float,
    dr_threshold: float = 0.8
) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
  """Convenience wrapper to compute BKS energy using a neighbor list.

  The five BKS coefficient arguments become float32 arrays and
  `dr_threshold` a float32 scalar. `coulomb_alpha` and `cutoff` are
  forwarded unchanged, and `cutoff` doubles as the neighbor-list radius.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  bks_params = {
      name: np.array(value, f32)
      for name, value in (('Q_sq', Q_sq),
                          ('exp_coeff', exp_coeff),
                          ('exp_decay', exp_decay),
                          ('attractive_coeff', attractive_coeff),
                          ('repulsive_coeff', repulsive_coeff))
  }
  dr_threshold = f32(dr_threshold)

  neighbor_fn = partition.neighbor_list(
      displacement_or_metric, box_size, cutoff, dr_threshold)
  energy_fn = smap.pair_neighbor_list(
      bks,
      space.canonicalize_displacement_or_metric(displacement_or_metric),
      species=species,
      coulomb_alpha=coulomb_alpha,
      cutoff=cutoff,
      **bks_params)
  return neighbor_fn, energy_fn
def graph_network_neighbor_list(
    displacement_fn: DisplacementFn,
    box_size: Box,
    r_cutoff: float,
    dr_threshold: float,
    nodes: Array = None,
    n_recurrences: int = 2,
    mlp_sizes: Tuple[int, ...] = (64, 64),
    mlp_kwargs: Dict[str, Any] = None
) -> Tuple[NeighborFn, nn.InitFn,
           Callable[[PyTree, Array, NeighborList], Array]]:
  """Convenience wrapper around EnergyGraphNet model using neighbor lists.

  Args:
    displacement_fn: Function to compute displacement between two positions.
    box_size: The size of the simulation volume, used to construct neighbor
      list.
    r_cutoff: A floating point cutoff; Edges will be added to the graph
      for pairs of particles whose separation is smaller than the cutoff.
    dr_threshold: A floating point number specifying a "halo" radius that we
      use for neighbor list construction. See `neighbor_list` for details.
    nodes: None or an ndarray of shape `[N, node_dim]` specifying the state
      of the nodes. If None this is set to the zeros vector. Often, for a
      system with multiple species, this could be the species id. A `nodes`
      keyword passed at call time takes precedence over this value.
    n_recurrences: The number of steps of message passing in the graph
      network.
    mlp_sizes: A tuple specifying the layer-widths for the fully-connected
      networks used to update the states in the graph network.
    mlp_kwargs: A dict specifying args for the fully-connected networks used
      to update the states in the graph network.

  Returns:
    A triple of functions `(neighbor_fn, init_fn, apply_fn)`:
      `neighbor_fn` builds the neighbor list (constructed with
      `mask_self=False`); `params = init_fn(key, R, neighbor)` instantiates
      the model parameters; and `E = apply_fn(params, R, neighbor, **kwargs)`
      computes the energy for a particular state. Neighbor-list entries
      farther than `r_cutoff` (i.e. those only inside the halo) are masked
      out of the graph.
  """
  nodes = _canonicalize_node_state(nodes)

  @hk.without_apply_rng
  @hk.transform
  def model(R, neighbor, **kwargs):
    N = R.shape[0]

    d = partial(displacement_fn, **kwargs)
    d = space.map_neighbor(d)
    R_neigh = R[neighbor.idx]
    dR = d(R, R_neigh)

    if 'nodes' in kwargs:
      _nodes = _canonicalize_node_state(kwargs['nodes'])
    else:
      _nodes = np.zeros((N, 1), R.dtype) if nodes is None else nodes

    _globals = np.zeros((1, ), R.dtype)

    # The list also contains halo pairs beyond r_cutoff; mark those edges
    # with the padding index N so the graph ignores them.
    dr_2 = space.square_distance(dR)
    edge_idx = np.where(dr_2 < r_cutoff**2, neighbor.idx, N)

    net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
    return net(nn.GraphTuple(_nodes, dR, _globals, edge_idx))  # pytype: disable=wrong-arg-count

  neighbor_fn = partition.neighbor_list(displacement_fn,
                                        box_size,
                                        r_cutoff,
                                        dr_threshold,
                                        mask_self=False)
  init_fn, apply_fn = model.init, model.apply

  return neighbor_fn, init_fn, apply_fn
def pair_correlation_neighbor_list(
    displacement_or_metric: Union[DisplacementFn, MetricFn],
    box_size: Box,
    radii: Array,
    sigma: float,
    species: Array = None,
    dr_threshold: float = 0.5,
    eps: float = 1e-7,
    fractional_coordinates: bool = False,
    format: partition.NeighborListFormat = partition.Dense):
  r"""Computes the pair correlation function at a mesh of distances.

  The pair correlation function measures the number of particles at a given
  distance from a central particle. The pair correlation function is defined
  by $g(r) = <\sum_{i\neq j}\delta(r - |r_i - r_j|)>.$ We make the
  approximation,
  $\delta(r) \approx {1 \over \sqrt{2\pi\sigma^2}} e^{-r^2 / (2\sigma^2)}$.
  In the implementation the Gaussian is left unnormalized, and each
  contribution is weighted by $1 / (r + \epsilon)^{d - 1}$, where $d$ is the
  spatial dimension.

  This function uses neighbor lists to speed up the calculation.

  Args:
    displacement_or_metric: A function that computes the displacement or
      distance between two points.
    box_size: The size of the box containing the particles.
    radii: An array of radii at which we would like to compute g(r).
    sigma: A float specifying the width of the approximating Gaussian.
    species: An optional array specifying the species of each particle. If
      species is None then we compute a single g(r) for all particles,
      otherwise we compute one g(r) for each species.
    dr_threshold: A float specifying the halo size of the neighbor list.
    eps: A small additive constant used to ensure stability if the radius is
      zero.
    fractional_coordinates: Bool determining whether positions are stored in
      the unit cube or not. Forwarded to `partition.neighbor_list`.
    format: The format of the neighbor lists. Must be `Dense` or `Sparse`.

  Returns:
    A pair of functions: `neighbor_fn` that constructs a neighbor list (see
    `neighbor_list` in `partition.py` for details). `g_fn` that computes the
    pair correlation function for a collection of particles given their
    position and a neighbor list. With `species` given, `g_fn` returns a list
    with one array per unique species.
  """
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  inv_rad = 1 / (radii + eps)

  def pairwise(dr, dim):
    return jnp.exp(-f32(0.5) * (dr - radii)**2 / sigma**2) * inv_rad**(dim - 1)

  # Previously `fractional_coordinates` was accepted but silently dropped.
  neighbor_fn = partition.neighbor_list(displacement_or_metric,
                                        box_size,
                                        jnp.max(radii) + sigma,
                                        dr_threshold,
                                        format=format,
                                        fractional_coordinates=fractional_coordinates)

  def _neighbor_terms(R, neighbor):
    # Neighbor mask plus the Gaussian contribution of every neighbor entry,
    # laid out according to the neighbor-list format.
    dim = R.shape[-1]
    mask = partition.neighbor_list_mask(neighbor)
    if neighbor.format is partition.Dense:
      dr = space.map_neighbor(metric)(R, R[neighbor.idx])
      terms = vmap(vmap(pairwise, (0, None)), (0, None))(dr, dim)
    elif neighbor.format is partition.Sparse:
      dr = space.map_bond(metric)(R[neighbor.idx[0]], R[neighbor.idx[1]])
      terms = vmap(pairwise, (0, None))(dr, dim)
    else:
      raise NotImplementedError(
          'Pair correlation function does not support '
          'OrderedSparse neighbor lists.')
    return mask, terms

  def _reduce(R, neighbor, mask, terms):
    # Sum the masked contributions onto each central particle.
    if neighbor.format is partition.Dense:
      return jnp.sum(mask[:, :, None] * terms, axis=(1, ))
    return ops.segment_sum(mask[:, None] * terms, neighbor.idx[0], R.shape[0])

  if species is None:
    def g_fn(R, neighbor):
      mask, terms = _neighbor_terms(R, neighbor)
      return _reduce(R, neighbor, mask, terms)
  else:
    if not (isinstance(species, jnp.ndarray) and is_integer(species)):
      raise TypeError('Malformed species; expecting array of integers.')
    species_types = jnp.unique(species)

    def g_fn(R, neighbor):
      # Distances and Gaussians are species-independent, so compute them
      # once and only vary the mask per species.
      mask, terms = _neighbor_terms(R, neighbor)
      if neighbor.format is partition.Dense:
        neighbor_species = species[neighbor.idx]
      else:
        neighbor_species = species[neighbor.idx[1]]
      return [_reduce(R, neighbor, mask * (neighbor_species == s), terms)
              for s in species_types]

  return neighbor_fn, g_fn
def pair_correlation_neighbor_list(
    displacement_or_metric: Union[DisplacementFn, MetricFn],
    box_size: Box,
    radii: Array,
    sigma: float,
    species: Array = None,
    dr_threshold: float = 0.5):
  r"""Computes the pair correlation function at a mesh of distances.

  The pair correlation function measures the number of particles at a given
  distance from a central particle. The pair correlation function is defined
  by $g(r) = <\sum_{i\neq j}\delta(r - |r_i - r_j|)>.$ We make the
  approximation,
  $\delta(r) \approx {1 \over \sqrt{2\pi\sigma^2}} e^{-r^2 / (2\sigma^2)}$.
  In the implementation the Gaussian is left unnormalized, and each
  contribution is divided by $r^{d - 1}$, where $d$ is the spatial dimension.
  `radii` should therefore not contain zero.

  This function uses neighbor lists to speed up the calculation.

  Args:
    displacement_or_metric: A function that computes the displacement or
      distance between two points.
    box_size: The size of the box containing the particles.
    radii: An array of radii at which we would like to compute g(r).
    sigma: A float specifying the width of the approximating Gaussian.
    species: An optional array specifying the species of each particle. If
      species is None then we compute a single g(r) for all particles,
      otherwise we compute one g(r) for each species.
    dr_threshold: A float specifying the halo size of the neighbor list.

  Returns:
    A pair of functions: `neighbor_fn` that constructs a neighbor list (see
    `neighbor_list` in `partition.py` for details). `g_fn` that computes the
    pair correlation function for a collection of particles given their
    position and a neighbor list. With `species` given, `g_fn` returns a list
    with one array per unique species.
  """
  d = space.canonicalize_displacement_or_metric(displacement_or_metric)
  d = space.map_neighbor(d)

  def pairwise(dr, dim):
    return jnp.exp(-f32(0.5) * (dr - radii)**2 / sigma**2) / radii**(dim - 1)
  pairwise = vmap(vmap(pairwise, (0, None)), (0, None))

  neighbor_fn = partition.neighbor_list(displacement_or_metric,
                                        box_size,
                                        jnp.max(radii) + sigma,
                                        dr_threshold)

  if species is None:
    def g_fn(R, neighbor):
      dim = R.shape[-1]
      R_neigh = R[neighbor.idx]
      mask = neighbor.idx < R.shape[0]  # Padding entries hold index N.
      return jnp.sum(mask[:, :, jnp.newaxis] *
                     pairwise(d(R, R_neigh), dim), axis=(1, ))
  else:
    if not (isinstance(species, jnp.ndarray) and is_integer(species)):
      raise TypeError('Malformed species; expecting array of integers.')
    species_types = jnp.unique(species)

    def g_fn(R, neighbor):
      dim = R.shape[-1]
      mask = neighbor.idx < R.shape[0]
      neighbor_species = species[neighbor.idx]
      R_neigh = R[neighbor.idx]
      # The Gaussian terms do not depend on species; evaluate them once.
      terms = pairwise(d(R, R_neigh), dim)
      g_R = []
      for s in species_types:
        mask_s = mask * (neighbor_species == s)
        g_R += [jnp.sum(mask_s[:, :, jnp.newaxis] * terms, axis=(1, ))]
      return g_R

  return neighbor_fn, g_fn