def test_triplet_static_species_scalar(self, spatial_dimension, dtype):
  """Compare smap.triplet with static species against a brute-force sum.

  The triplet-mapped energy (divided by the particle count) is checked
  against half the sum of squared pairwise distances over every pair of
  species groups.
  """
  key = random.PRNGKey(0)
  # Triplet term: depends only on the first bond vector dR1. The 5.0 default
  # is presumably overridden by `param=params` passed to smap.triplet below.
  angle_fn = lambda dR1, dR2, param=5.0: param * np.sum(np.square(dR1))
  # NOTE(review): `square` is never used in this test.
  square = lambda dR, param: param * np.sum(np.square(dR))
  # 2x2x2 parameter table; presumably indexed by the species of the three
  # particles in a triplet -- confirm against smap.triplet.
  params = f32(np.array([[[1., 1.], [2., 0.]], [[0., 2.], [1., 1.]]]))
  # Small system; presumably because triplet sums grow cubically -- confirm.
  count = PARTICLE_COUNT // 50
  key, split = random.split(key)
  # Two species (0 and 1), assigned at random and fixed for the whole test.
  species = random.randint(split, (count,), 0, 2)
  displacement, _ = space.free()
  # Squared Euclidean distance between two points.
  metric = lambda Ra, Rb, **kwargs: \
      np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)
  triplet_square = smap.triplet(angle_fn, displacement, species=species,
                                param=params, reduce_axis=None)
  # Lift the metric to all pairs between two point sets.
  metric = space.map_product(metric)
  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = random.uniform(
        split, (count, spatial_dimension), dtype=dtype)
    # Brute-force reference: half the squared distances over all
    # species-group pairs (i, j), i.e. over all particle pairs.
    # NOTE(review): the reference uses neither `params` nor the 5.0
    # prefactor; confirm the triplet sum / count is expected to equal it.
    total = 0.
    for i in range(2):
      for j in range(2):
        R_1 = R[species == i]
        R_2 = R[species == j]
        total += 0.5 * np.sum(metric(R_1, R_2))
    self.assertAllClose(triplet_square(R) / count,
                        np.array(total, dtype=dtype))
def test_brownian(self, spatial_dimension, dtype):
  """Free Brownian particles spread with the expected variance.

  The energy is a constant, and the test expects the variance of all
  position coordinates after `sim_t` to match 2 * T * sim_t / (mass * gamma)
  to within 1%.
  """
  key = random.PRNGKey(0)
  key, T_split, mass_split = random.split(key, 3)

  _, shift = space.free()
  energy_fn = lambda R, **kwargs: f32(0)

  # Use the parameterized dimension instead of a hard-coded 2 so that every
  # `spatial_dimension` the test is run with is actually exercised.
  R = np.zeros((BROWNIAN_PARTICLE_COUNT, spatial_dimension), dtype=dtype)
  mass = random.uniform(
      mass_split, (), minval=0.1, maxval=10.0, dtype=dtype)
  T = random.uniform(T_split, (), minval=0.3, maxval=1.4, dtype=dtype)

  dt = f32(1e-2)
  gamma = f32(0.1)

  init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, T, gamma=gamma)
  apply_fn = jit(apply_fn)

  state = init_fn(key, R, mass)
  sim_t = f32(BROWNIAN_DYNAMICS_STEPS * dt)
  for _ in range(BROWNIAN_DYNAMICS_STEPS):
    state = apply_fn(state)

  msd = np.var(state.position)
  th_msd = dtype(2 * T / (mass * gamma) * sim_t)
  assert np.abs(msd - th_msd) / msd < 1e-2
  assert state.position.dtype == dtype
def test_nve_ensemble_time_dependence(self, spatial_dimension, dtype):
  """NVE integration conserves total energy when a time `t` is supplied.

  Soft spheres are integrated while passing the current simulation time to
  `apply_fn`. The total (potential + kinetic) energy must stay within 1% of
  its initial value at every step.
  """
  key = random.PRNGKey(0)
  # The second key used to feed an unused `R0` array. It is discarded here so
  # the remaining keys, and therefore the test's randomness, are unchanged.
  pos_key, _, vel_key, mass_key = random.split(key, 4)
  R = random.normal(
      pos_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
  mass = random.uniform(
      mass_key, (PARTICLE_COUNT,), minval=0.1, maxval=5.0, dtype=dtype)

  displacement, shift = space.free()
  E = energy.soft_sphere_pair(displacement)

  # Single source of truth for the step size; also used to compute `t` below.
  dt = 1e-3
  init_fn, apply_fn = simulate.nve(E, shift, dt)
  apply_fn = jit(apply_fn)
  state = init_fn(vel_key, R, mass=mass)

  def E_T(state):
    """Total energy: potential plus kinetic."""
    return E(state.position) + quantity.kinetic_energy(
        state.velocity, state.mass)

  E_initial = E_T(state)

  for t in range(SHORT_DYNAMICS_STEPS):
    state = apply_fn(state, t=t * dt)
    E_total = E_T(state)
    assert np.abs(E_total - E_initial) < E_initial * 0.01
    assert state.position.dtype == dtype
def test_nvt_langevin(self, spatial_dimension, dtype):
  """The Langevin thermostat brings the sampled temperature close to T.

  Particles sit in a harmonic well around random centers. After 4000 steps,
  the kinetic temperature is sampled every 100 steps, and its mean must be
  within 0.1 of the target.
  """
  def harmonic_well(R, R0, **kwargs):
    return np.sum((R - R0) ** 2)

  shape = (LANGEVIN_PARTICLE_COUNT, spatial_dimension)
  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, R_key, R0_key, T_key, masses_key = random.split(key, 5)

    R = random.normal(R_key, shape, dtype=dtype)
    R0 = random.normal(R0_key, shape, dtype=dtype)
    _, shift = space.free()
    E = functools.partial(harmonic_well, R0=R0)

    T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
    mass = random.uniform(
        masses_key, (LANGEVIN_PARTICLE_COUNT,),
        minval=0.1, maxval=10.0, dtype=dtype)

    init_fn, apply_fn = simulate.nvt_langevin(
        E, shift, f32(1e-2), T, gamma=f32(0.3))
    apply_fn = jit(apply_fn)
    state = init_fn(key, R, mass=mass, T_initial=dtype(1.0))

    samples = []
    for step in range(LANGEVIN_DYNAMICS_STEPS):
      state = apply_fn(state)
      # Ignore the first 4000 steps (presumably equilibration), then sample
      # sparsely.
      if step > 4000 and step % 100 == 0:
        samples.append(quantity.temperature(state.velocity, state.mass))

    T_emp = np.mean(np.array(samples))
    assert np.abs(T_emp - T) < 0.1
    assert state.position.dtype == dtype
def test_nve_ensemble(self, spatial_dimension, dtype):
  """NVE integration in a harmonic well conserves total energy to 1%."""
  key = random.PRNGKey(0)
  pos_key, center_key, vel_key, mass_key = random.split(key, 4)

  shape = (PARTICLE_COUNT, spatial_dimension)
  R = random.normal(pos_key, shape, dtype=dtype)
  R0 = random.normal(center_key, shape, dtype=dtype)
  mass = random.uniform(
      mass_key, (PARTICLE_COUNT,), minval=0.1, maxval=5.0, dtype=dtype)
  _, shift = space.free()

  def E(R, **kwargs):
    return np.sum((R - R0)**2)

  def total_energy(s):
    return E(s.position) + quantity.kinetic_energy(s.velocity, s.mass)

  init_fn, apply_fn = simulate.nve(E, shift, 1e-3)
  apply_fn = jit(apply_fn)
  state = init_fn(vel_key, R, mass=mass)

  E_initial = total_energy(state)
  for _ in range(DYNAMICS_STEPS):
    state = apply_fn(state)
    drift = np.abs(total_energy(state) - E_initial)
    assert drift < E_initial * 0.01
    assert state.position.dtype == dtype
def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype):
  """smap.pair with dynamic species matches a per-species brute-force sum."""
  key = random.PRNGKey(0)
  square = lambda dr, param=1.0: param * dr**2
  params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

  key, species_key = random.split(key)
  species = random.randint(species_key, (PARTICLE_COUNT,), 0, 2)

  displacement, _ = space.free()

  def metric(Ra, Rb, **kwargs):
    """Squared Euclidean distance."""
    return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

  mapped_square = smap.pair(
      square, metric, species=quantity.Dynamic, param=params)
  pairwise_metric = space.map_product(metric)

  for _ in range(STOCHASTIC_SAMPLES):
    key, sample_key = random.split(key)
    R = random.uniform(
        sample_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    total = 0.0
    for a in range(2):
      R_a = R[species == a]
      for b in range(2):
        R_b = R[species == b]
        total = total + 0.5 * np.sum(
            square(pairwise_metric(R_a, R_b), params[a, b]))

    self.assertAllClose(mapped_square(R, species, 2),
                        np.array(total, dtype=dtype))
def test_stress_non_minimized_free(self, dim, dtype):
  """Autodiff stress agrees with an explicit pairwise (virial-style) sum.

  Soft spheres are placed uniformly at random in a box sized for number
  density 0.8, so the configuration is not an energy minimum.
  """
  key = random.PRNGKey(0)
  N = 64
  box = quantity.box_size_at_number_density(N, 0.8, dim)
  displacement_fn, _ = space.free()
  # Draw positions in the dtype under test. Previously the default float
  # dtype was always used, so the dtype-dependent tolerance below did not
  # correspond to the precision actually exercised.
  pos = random.uniform(key, (N, dim), dtype=dtype) * box
  energy_fn = energy.soft_sphere_pair(displacement_fn)

  def compute_exact_stress(R):
    """Explicit stress from pairwise soft-sphere forces."""
    dR = space.map_product(displacement_fn)(R, R)
    dr = space.distance(dR)
    g = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
    V = quantity.volume(dim, box)
    # Each unordered pair appears twice in the N x N sum.
    dUdr = 0.5 * g(dr)[:, :, None, None]
    # Adding the identity keeps the self-pair terms (dr == 0, dR == 0) from
    # dividing by zero; those terms contribute zero anyway.
    dr = (dr + jnp.eye(N))[:, :, None, None]
    return -jnp.sum(dUdr * dR[:, :, None, :] * dR[:, :, :, None] / (V * dr),
                    axis=(0, 1))

  exact_stress = compute_exact_stress(pos)
  ad_stress = quantity.stress(energy_fn, pos, box)

  tol = 1e-7 if dtype is f64 else 2e-5
  self.assertAllClose(exact_stress, ad_stress, atol=tol, rtol=tol)
def test_gradient_descent(self, spatial_dimension, dtype):
  """Every gradient-descent step lowers both the energy and |R - R0|^2."""
  shape = (PARTICLE_COUNT, spatial_dimension)
  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, split, split0 = random.split(key, 3)
    R = random.uniform(split, shape, dtype=dtype)
    R0 = random.uniform(split0, shape, dtype=dtype)

    def energy_fn(R, **kwargs):
      return np.sum((R - R0) ** 2)

    _, shift_fn = space.free()
    _, opt_apply = minimize.gradient_descent(energy_fn, shift_fn, f32(1e-1))

    E_prev = energy_fn(R)
    dr_prev = np.sum((R - R0) ** 2)
    for _ in range(OPTIMIZATION_STEPS):
      R = opt_apply(R)
      E_next = energy_fn(R)
      dr_next = np.sum((R - R0) ** 2)

      assert E_next < E_prev
      assert E_next.dtype == dtype
      assert dr_next < dr_prev
      assert dr_next.dtype == dtype

      E_prev, dr_prev = E_next, dr_next
def test_pair_dynamic_species_vector(self, spatial_dimension, dtype):
  """smap.pair over displacement vectors with dynamic species vs brute force."""
  key = random.PRNGKey(0)
  square = lambda dr, param=1.0: param * np.sum(dr**2, axis=2)
  params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

  key, species_key = random.split(key)
  species = random.randint(species_key, (PARTICLE_COUNT,), 0, 2)

  disp, _ = space.free()
  mapped_square = smap.pair(
      square, disp, species=quantity.Dynamic, param=params)
  # All-pairs displacement between two point sets.
  pair_disp = vmap(vmap(disp, (0, None), 0), (None, 0), 0)

  for _ in range(STOCHASTIC_SAMPLES):
    key, sample_key = random.split(key)
    R = random.uniform(
        sample_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    total = 0.0
    for a in range(2):
      R_a = R[species == a]
      for b in range(2):
        R_b = R[species == b]
        total = total + 0.5 * np.sum(square(pair_disp(R_a, R_b), params[a, b]))

    self.assertAllClose(mapped_square(R, species, 2),
                        np.array(total, dtype=dtype))
def test_fire_descent(self, spatial_dimension, dtype):
  """Every three FIRE steps lower both the energy and |R - R0|^2."""
  shape = (PARTICLE_COUNT, spatial_dimension)
  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, split, split0 = random.split(key, 3)
    R = random.uniform(split, shape, dtype=dtype)
    R0 = random.uniform(split0, shape, dtype=dtype)

    def energy_fn(R, **kwargs):
      return np.sum((R - R0)**2)

    _, shift_fn = space.free()
    opt_init, opt_apply = minimize.fire_descent(energy_fn, shift_fn)
    opt_state = opt_init(R)

    E_prev = energy_fn(R)
    dr_prev = np.sum((R - R0)**2)

    @jit
    def three_steps(state):
      for _ in range(3):
        state = opt_apply(state)
      return state

    for _ in range(OPTIMIZATION_STEPS):
      opt_state = three_steps(opt_state)
      R = opt_state.position
      E_next = energy_fn(R)
      dr_next = np.sum((R - R0)**2)

      assert E_next < E_prev
      assert E_next.dtype == dtype
      assert dr_next < dr_prev
      assert dr_next.dtype == dtype

      E_prev, dr_prev = E_next, dr_next
def test_graph_network_neighbor_list_moving(self, spatial_dimension, dtype,
                                            format):
  """GNN energy via a neighbor list matches the dense energy after a move.

  The neighbor list is allocated once, then particles are shifted by up to
  0.05 per coordinate without updating the list.
  """
  if format is partition.OrderedSparse:
    self.skipTest('OrderedSparse format incompatible with GNN '
                  'force field.')

  n_particles = 32
  shape = (n_particles, spatial_dimension)
  cutoff = 0.3
  dr_threshold = 0.1

  key = random.PRNGKey(0)
  R = random.uniform(key, shape, dtype=dtype)
  d, _ = space.free()

  init_fn, energy_fn = energy.graph_network(d, cutoff)
  params = init_fn(key, R)

  neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
      d, 1.0, cutoff, dr_threshold, format=format)
  nbrs = neighbor_fn.allocate(R)

  key = random.fold_in(key, 1)
  R = R + random.uniform(key, shape, minval=-0.05, maxval=0.05, dtype=dtype)

  # Dense lists are compared with default tolerances; other formats loosely.
  tolerances = {} if format is partition.Dense else {'rtol': 2e-4,
                                                       'atol': 2e-4}
  expected = energy_fn(params, R)
  actual = nl_energy_fn(params, R, nbrs)
  self.assertAllClose(expected, actual, **tolerances)
def test_cell_list_overflow(self):
  """A neighbor buffer allocated for isolated particles overflows on contact.

  The neighbor index dtype must stay int32 both before and after the update.
  """
  displacement_fn, _ = space.free()
  neighbor_list_fn = partition.neighbor_list(
      displacement_fn,
      box_size=100.0,
      r_cutoff=3.0,
      dr_threshold=0.0,
  )

  # Every pair is far apart relative to r_cutoff, so allocation presumably
  # leaves no room for neighbors.
  isolated = jnp.array([
      [20.0, 20.0],
      [30.0, 30.0],
      [40.0, 40.0],
      [50.0, 50.0],
  ])
  neighbors = neighbor_list_fn.allocate(isolated)
  self.assertEqual(neighbors.idx.dtype, jnp.int32)

  # The first two particles now coincide, which the buffer cannot hold.
  touching = jnp.array([
      [20.0, 20.0],
      [20.0, 20.0],
      [40.0, 40.0],
      [50.0, 50.0],
  ])
  neighbors = neighbors.update(touching)
  self.assertTrue(neighbors.did_buffer_overflow)
  self.assertEqual(neighbors.idx.dtype, jnp.int32)
def test_radial_symmetry_functions(self, N_types, N_etas, dtype):
  """Radial symmetry functions have the expected shape and a reference value."""
  displacement, _ = space.free()
  gr = nn.radial_symmetry_functions(
      displacement,
      np.array([1, 1, N_types]),
      np.linspace(1.0, 2.0, N_etas, dtype=dtype),
      4)
  R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
  gr_out = gr(R)
  # Shapes are integer tuples: compare them exactly, not with tolerances.
  self.assertEqual(gr_out.shape, (3, N_types * N_etas))
  self.assertAllClose(gr_out[2, 0], dtype(0.411717), rtol=1e-6, atol=1e-6)
def test_angular_symmetry_functions(self, N_types, N_etas, dtype):
  """Angular symmetry functions have the expected shape and a reference value."""
  displacement, _ = space.free()
  gr = nn.angular_symmetry_functions(
      displacement,
      np.array([1, 1, N_types]),
      etas=np.array([1e-4/(0.529177 ** 2)] * N_etas, dtype),
      lambdas=np.array([-1.0] * N_etas, dtype),
      zetas=np.array([1.0] * N_etas, dtype),
      cutoff_distance=8.0)
  R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
  gr_out = gr(R)
  # One column per eta for each unordered species pair; shapes are compared
  # exactly rather than with numeric tolerances.
  self.assertEqual(gr_out.shape,
                   (3, N_etas * N_types * (N_types + 1) // 2))
  self.assertAllClose(gr_out[2, 0], dtype(1.577944), rtol=1e-6, atol=1e-6)
def run(N=32, n_iter=1000, with_jit=True):
  """Time individual Brownian-dynamics steps of N soft spheres in 2-D.

  Args:
    N: number of particles.
    n_iter: number of simulation steps to time.
    with_jit: whether to `jit`-compile the step function. When True, the
      first timed step also includes compilation time.

  Returns:
    A list of `n_iter` per-step wall-clock durations, in nanoseconds.
  """
  # `time` is imported here, like the JAX imports, so the function is
  # self-contained; previously it relied on a module-level import.
  import time

  import jax.numpy as jnp
  from jax import random, jit
  from jax_md import space, energy, simulate

  # MD configs
  dt = 1e-1
  temperature = 0.1

  # Free (non-periodic) space: displacement(Ra, Rb) = Ra - Rb and
  # shift(R, dR) = R + dR.
  displacement, shift = space.free()

  # Soft-sphere pair energy, e.g. U(dr) = (1 - dr) ** 2 for dr < 1, else 0,
  # summed over distinct pairs.
  energy_fn = energy.soft_sphere_pair(displacement)

  # Brownian update, roughly:
  #   nu = 1 / (mass * gamma)
  #   dR = force(R) * dt * nu + sqrt(2 * temperature * dt * nu) * xi
  pos_key, sim_key = random.split(random.PRNGKey(0))
  R = random.uniform(pos_key, (N, 2), dtype=jnp.float32)
  init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, temperature)
  if with_jit:
    apply_fn = jit(apply_fn)
  state = init_fn(sim_key, R)

  times = []
  for _ in range(n_iter):
    time_start = time.perf_counter_ns()
    state = apply_fn(state)
    # JAX dispatches work asynchronously. Wait for the result so that the
    # interval measures the step itself rather than only its dispatch.
    state.position.block_until_ready()
    time_end = time.perf_counter_ns()
    times.append(time_end - time_start)

  return times
def test_cosine_angles(self, dtype):
  """cosine_angles on a 3-point L-shape matches hand-computed cosines.

  Self-displacements are zero vectors, and the expected table has 0 wherever
  one of them is involved.
  """
  displacement, _ = space.free()
  pairwise_displacement = space.map_product(displacement)

  R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype)
  cangles = quantity.cosine_angles(pairwise_displacement(R, R))

  c45 = 1 / np.sqrt(2)
  expected = np.array([[[0, 0, 0],
                        [0, 1, c45],
                        [0, c45, 1]],
                       [[1, 0, 0],
                        [0, 0, 0],
                        [0, 0, 1]],
                       [[1, c45, 0],
                        [c45, 1, 0],
                        [0, 0, 0]]], dtype=dtype)
  self.assertAllClose(cangles, expected)
def test_custom_mask_function(self):
  """A custom mask function removes neighbor pairs based on particle ids."""
  displacement_fn, _ = space.free()
  n_particles = 10
  # All particles sit at the origin, so without masking every pair is within
  # the cutoff.
  R = jnp.broadcast_to(jnp.zeros(3), (n_particles, 3))

  def acceptable_id_pair(id1, id2):
    """Return True when two ids differ by more than 3.

    Pairs whose ids differ by 3 or less (e.g. 1-2, 1-3 and 1-4) are rejected.
    """
    return jnp.abs(id1 - id2) > 3

  def mask_id_based(idx: Array,
                    ids: Array,
                    mask_val: int,
                    _acceptable_id_pair: Callable) -> Array:
    """Replace neighbor entries rejected by `_acceptable_id_pair` with `mask_val`.

    Row r of `idx` belongs to particle r, and each entry of the row is the
    index of a candidate neighbor.
    """
    @partial(vmap, in_axes=(0, 0, None))
    def row_keep(neighbor_row, own_id, all_ids):
      neighbor_ids = all_ids.at[neighbor_row].get()
      return vmap(_acceptable_id_pair, in_axes=(None, 0))(own_id, neighbor_ids)

    keep = row_keep(idx, ids, ids)
    return jnp.where(keep, idx, mask_val)

  ids = jnp.arange(n_particles)  # Each particle's id equals its index.
  mask_val = n_particles
  custom_mask_function = partial(mask_id_based,
                                 ids=ids,
                                 mask_val=mask_val,
                                 _acceptable_id_pair=acceptable_id_pair)

  neighbor_list_fn = partition.neighbor_list(
      displacement_fn,
      box_size=1.0,
      r_cutoff=3.0,
      dr_threshold=0.0,
      custom_mask_function=custom_mask_function,
  )

  neighbors = neighbor_list_fn.allocate(R)
  neighbors = neighbors.update(R)

  # Without the id mask each particle has 9 neighbors (self excluded), 90 in
  # total. Keeping only pairs with |id1 - id2| > 3 leaves 42.
  self.assertEqual(42, (neighbors.idx != mask_val).sum())
def test_pair_scalar_dummy_arg(self, spatial_dimension, dtype):
  """smap.pair accepts an extra keyword (`t`) that the pair function ignores.

  The test passes as long as the call does not raise.
  """
  def square(dr, param=f32(1.0), **unused_kwargs):
    return param * dr**2

  key, _ = random.split(random.PRNGKey(0))
  R = random.normal(key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  displacement, _ = space.free()
  mapped = smap.pair(square, space.metric(displacement))
  mapped(R, t=f32(0))
def test_langevin_harmonic(self):
  """Langevin dynamics of an ensemble of 1-D harmonic oscillators vs theory.

  The mean position and the velocity autocorrelation <v(t) v(0)> after
  `steps * dt` time units are compared with closed-form expressions built
  from the two damping rates beta_1 and beta_2.
  """
  # Physical and numerical parameters.
  alpha = 1.0
  kT = 0.25
  gamma = 3
  mass = 2.0
  dt = 1e-4
  steps = 1000
  N = 10000
  tol = 1e-3

  def E(x):
    return jnp.sum(0.5 * alpha * x**2)

  _, shift = space.free()

  X = jnp.ones((N, 1, 1))
  keys = random.split(random.PRNGKey(0), N)

  init_fn, step_fn = simulate.nvt_langevin(E, shift, dt, kT, gamma, False)
  step_fn = jit(vmap(step_fn))
  state = vmap(init_fn, (0, 0, None))(keys, X, mass)
  v0 = state.velocity

  for _ in range(steps):
    state = step_fn(state)

  # Theoretical predictions.
  d = jnp.sqrt(gamma**2 / 4 - alpha / mass)
  beta_1 = gamma / 2 + d
  beta_2 = gamma / 2 - d
  A = -beta_2 / (beta_1 - beta_2)
  B = beta_1 / (beta_1 - beta_2)
  Z = kT / (2 * d * mass)

  def pos_fn(t):
    return A * jnp.exp(-beta_1 * t) + B * jnp.exp(-beta_2 * t)

  def vel_fn(t):
    return Z * (-beta_2 * jnp.exp(-beta_2 * t) + beta_1 * jnp.exp(-beta_1 * t))

  t = steps * dt
  self.assertAllClose(jnp.mean(state.position), pos_fn(t),
                      rtol=tol, atol=tol)
  self.assertAllClose(jnp.mean(state.velocity * v0), vel_fn(t),
                      rtol=tol, atol=tol)
def test_nvt_nose_hoover_ensemble(self, spatial_dimension, dtype):
  """Nose-Hoover NVT reaches temperature T and conserves its invariant.

  For random harmonic wells, the test runs DYNAMICS_STEPS steps and then
  checks three things: the kinetic temperature is within 0.1 of T, the
  extended-system invariant has drifted by less than 1%, and positions keep
  `dtype`.
  """
  key = random.PRNGKey(0)

  def invariant(T, state):
    """The conserved quantity for Nose-Hoover thermostat.

    Potential + kinetic energy, plus the thermostat chain's kinetic terms
    (0.5 * Q * v_xi**2) and its coupling terms. The first chain element is
    weighted by the number of degrees of freedom; the other elements are
    weighted by T alone.

    Reads `E` from the enclosing test's scope. `E` is assigned inside the
    sampling loop below, so this must only be called after that assignment.
    """
    accum = \
      E(state.position) + quantity.kinetic_energy(state.velocity, state.mass)
    # Degrees of freedom of the particle system.
    DOF = spatial_dimension * PARTICLE_COUNT
    accum = accum + (state.v_xi[0]) ** 2 * state.Q[0] * 0.5 + \
      DOF * T * state.xi[0]
    for xi, v_xi, Q in zip(state.xi[1:], state.v_xi[1:], state.Q[1:]):
      accum = accum + v_xi**2 * Q * 0.5 + T * xi
    return accum

  for _ in range(STOCHASTIC_SAMPLES):
    key, pos_key, center_key, vel_key, T_key, masses_key = \
        random.split(key, 6)

    R = random.normal(
        pos_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    R0 = random.normal(
        center_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    _, shift = space.free()

    # Harmonic well centered on R0, with R0 bound for this sample.
    E = functools.partial(lambda R, R0, **kwargs: np.sum((R - R0)**2), R0=R0)

    T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
    mass = random.uniform(
        masses_key, (PARTICLE_COUNT, ), minval=0.1, maxval=10.0, dtype=dtype)

    init_fn, apply_fn = simulate.nvt_nose_hoover(E, shift, 1e-3, T, tau=10)
    apply_fn = jit(apply_fn)
    # Start at temperature 1.0, which generally differs from the target T.
    state = init_fn(vel_key, R, mass=mass, T_initial=dtype(1.0))
    initial = invariant(T, state)

    for _ in range(DYNAMICS_STEPS):
      state = apply_fn(state)

    # Checked once, after all steps.
    assert np.abs(
        quantity.temperature(state.velocity, state.mass) - T) < 0.1
    assert np.abs(invariant(T, state) - initial) < initial * 0.01
    assert state.position.dtype == dtype
def test_pair_no_species_vector(self, spatial_dimension, dtype):
  """smap.pair on displacement vectors matches half the all-pairs sum."""
  def square(dr):
    return np.sum(dr ** 2, axis=2)

  disp, _ = space.free()
  mapped_square = smap.pair(square, disp)
  all_pairs_disp = space.map_product(disp)

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, sample_key = random.split(key)
    R = random.uniform(
        sample_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    # Each unordered pair appears twice in the all-pairs sum, hence 0.5.
    expected = np.array(0.5 * np.sum(square(all_pairs_disp(R, R))),
                        dtype=dtype)
    self.assertAllClose(mapped_square(R), expected)
def test_graph_network_shape_dtype(self, spatial_dimension, dtype):
  """The graph-network energy is a scalar of the requested dtype."""
  cutoff = 0.2
  key = random.PRNGKey(0)
  R = random.uniform(key, (32, spatial_dimension), dtype=dtype)

  d, _ = space.free()
  init_fn, energy_fn = energy.graph_network(d, cutoff)
  E_out = energy_fn(init_fn(key, R), R)

  assert E_out.shape == ()
  assert E_out.dtype == dtype
def test_simple_spring(self, spatial_dimension, dtype):
  """A single simple-spring bond matches (dist - length)**alpha / alpha.

  One bond joins the origin to the all-ones point, so its length is
  sqrt(spatial_dimension).
  """
  key = random.PRNGKey(0)
  disp, _ = space.free()

  # Works for any dimension; for 2 and 3 this reproduces the original
  # hand-written coordinates and distances exactly.
  R = np.array([[0.] * spatial_dimension, [1.] * spatial_dimension],
               dtype=dtype)
  dist = np.sqrt(float(spatial_dimension))
  bonds = np.array([[0, 1]], np.int32)

  for _ in range(STOCHASTIC_SAMPLES):
    key, l_key, a_key = random.split(key, 3)
    # Draw each parameter from its own key. Previously both used `key`,
    # which made them correlated and reused the key that seeds the next
    # iteration.
    length = random.uniform(l_key, (), minval=0.1, maxval=3.0)
    alpha = random.uniform(a_key, (), minval=2., maxval=4.)
    E = energy.simple_spring_bond(disp, bonds, length=length, alpha=alpha)
    # NOTE(review): when length > dist the base is negative, and a
    # non-integer alpha then gives NaN; confirm the comparison is not
    # vacuously NaN == NaN in that case.
    E_exact = dtype((dist - length) ** alpha / alpha)
    self.assertAllClose(E(R), E_exact, True)
def test_cosine_angles_neighbors(self, dtype):
  """cosine_angles on per-particle neighbor displacements.

  Particle 0 has neighbors at (0, 1) and (1, 1), 45 degrees apart. The
  other particles have two identical neighbors, so all their cosines are 1.
  """
  displacement, _ = space.free()
  # Displacement from each particle to each of its own neighbors.
  neighbor_displacement = vmap(vmap(displacement, (None, 0)), 0)

  R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype)
  R_neigh = np.array([[[0, 1], [1, 1]],
                      [[0, 0], [0, 0]],
                      [[0, 0], [0, 0]]], dtype=dtype)
  cangles = quantity.cosine_angles(neighbor_displacement(R, R_neigh))

  c45 = 1 / np.sqrt(2)
  expected = np.array([[[1, c45], [c45, 1]],
                       [[1, 1], [1, 1]],
                       [[1, 1], [1, 1]]], dtype=dtype)
  self.assertAllClose(cangles, expected)
def test_bond_no_type_static(self, spatial_dimension, dtype):
  """smap.bond over static bonds (0-1, 0-2) equals the sum of both terms."""
  def harmonic(dr, **kwargs):
    return (dr - f32(1))**f32(2)

  disp, _ = space.free()
  metric = space.metric(disp)
  bonds = np.array([[0, 1], [0, 2]], i32)
  mapped = smap.bond(harmonic, metric, bonds)

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, sample_key = random.split(key)
    R = random.uniform(
        sample_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    bond_01 = harmonic(metric(R[0], R[1]))
    bond_02 = harmonic(metric(R[0], R[2]))
    self.assertAllClose(mapped(R), dtype(bond_01 + bond_02))
def test_pair_no_species_scalar(self, spatial_dimension, dtype):
  """smap.pair on a scalar metric matches half the all-pairs sum."""
  def square(dr):
    return dr**2

  displacement, _ = space.free()

  def metric(Ra, Rb, **kwargs):
    """Squared Euclidean distance."""
    return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

  mapped_square = smap.pair(square, metric)
  all_pairs_metric = space.map_product(metric)

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, sample_key = random.split(key)
    R = random.uniform(
        sample_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    expected = np.array(0.5 * np.sum(square(all_pairs_metric(R, R))),
                        dtype=dtype)
    self.assertAllClose(mapped_square(R), expected)
def test_graph_network_neighbor_list(self, spatial_dimension, dtype):
  """GNN energy with a neighbor list matches the dense GNN energy."""
  key = random.PRNGKey(0)
  R = random.uniform(key, (32, spatial_dimension), dtype=dtype)

  d, _ = space.free()
  cutoff = 0.2

  init_fn, energy_fn = energy.graph_network(d, cutoff)
  params = init_fn(key, R)

  neighbor_fn, _, nl_energy_fn = \
      energy.graph_network_neighbor_list(d, 1.0, cutoff, 0.0)
  # Use the explicit `allocate` API rather than calling the neighbor-list
  # functions object directly.
  nbrs = neighbor_fn.allocate(R)

  self.assertAllClose(energy_fn(params, R), nl_energy_fn(params, R, nbrs))
def test_pair_no_species_vector_nonadditive(self, spatial_dimension, dtype):
  """smap.pair with a multiplicative per-particle parameter combinator.

  The pair parameter for particles (a, b) is params[a] * params[b].
  """
  def square(dr, params):
    return params * np.sum(dr ** 2, axis=2)

  disp, _ = space.free()
  mapped_square = smap.pair(square, disp, params=lambda x, y: x * y)
  all_pairs_disp = space.map_product(disp)

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, R_key, params_key = random.split(key, 3)
    R = random.uniform(
        R_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    params = random.uniform(
        params_key, (PARTICLE_COUNT,), dtype=dtype, minval=0.1, maxval=1.5)

    pair_params = params[None, :] * params[:, None]
    expected = np.array(
        0.5 * np.sum(square(all_pairs_disp(R, R), pair_params)), dtype=dtype)
    self.assertAllClose(mapped_square(R, params=params), expected)
def test_behler_parrinello_symmetry_functions_neighbor_list(self, N_types,
                                                           N_etas, dtype):
  """Neighbor-list Behler-Parrinello features: shape and a reference value."""
  displacement, _ = space.free()
  neighbor_fn = partition.neighbor_list(displacement, 10.0, 8.0, 0.0)
  gr = nn.behler_parrinello_symmetry_functions_neighbor_list(
      displacement,
      np.array([1, 1, N_types]),
      radial_etas=np.array([1e-4/(0.529177 ** 2)] * N_etas, dtype),
      angular_etas=np.array([1e-4/(0.529177 ** 2)] * N_etas, dtype),
      lambdas=np.array([-1.0] * N_etas, dtype),
      zetas=np.array([1.0] * N_etas, dtype),
      cutoff_distance=8.0)
  R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
  # Use the explicit `allocate` API rather than calling the neighbor-list
  # functions object directly.
  nbrs = neighbor_fn.allocate(R)
  gr_out = gr(R, neighbor=nbrs)
  # Radial features per species plus angular features per unordered species
  # pair, for each eta. Shapes are compared exactly, not with tolerances.
  self.assertEqual(gr_out.shape,
                   (3, N_etas * (N_types + N_types * (N_types + 1) // 2)))
  self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)
def test_pair_no_species_scalar_dynamic(self, spatial_dimension, dtype):
  """smap.pair with per-particle `epsilon` supplied at call time.

  The reference combines the per-particle epsilons pairwise by arithmetic
  mean.
  """
  def square(dr, epsilon):
    return epsilon * dr ** 2

  displacement, _ = space.free()

  def metric(Ra, Rb, **kwargs):
    """Squared Euclidean distance."""
    return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

  mapped_square = smap.pair(square, metric, epsilon=1.0)
  all_pairs_metric = space.map_product(metric)

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, split1, split2 = random.split(key, 3)
    R = random.uniform(
        split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    epsilon = random.uniform(split2, (PARTICLE_COUNT,), dtype=dtype)

    mat_epsilon = 0.5 * (epsilon[:, np.newaxis] + epsilon[np.newaxis, :])
    expected = np.array(
        0.5 * np.sum(square(all_pairs_metric(R, R), mat_epsilon)),
        dtype=dtype)
    self.assertAllClose(mapped_square(R, epsilon=epsilon), expected)