def test_periodic_displacement(self, spatial_dimension, dtype):
  """Check `space.periodic_displacement` against a brute-force image search.

  Random positions are drawn in the unit box and all pairwise displacements
  are formed. For each displacement, every image shift in
  {-1, 0, 1}^spatial_dimension is tried and the shortest shifted vector is
  kept as the reference. `space.periodic_displacement` with box size 1 must
  match that reference and return the requested dtype.

  The shifts are enumerated with `itertools.product`, so the reference works
  for any `spatial_dimension`. Previously only the 2D and 3D cases were
  handled; any other dimension silently compared against the unwrapped `dR`.
  """
  import itertools

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    dR = space.map_product(space.pairwise_displacement)(R, R)

    dR_wrapped = space.periodic_displacement(f32(1.0), dR)

    # Brute-force reference: start from the unshifted displacement and keep
    # whichever shifted image is shorter. The distances get a trailing
    # length-1 axis so they can be broadcast against the displacement
    # vectors in `np.where` (assumes `space.distance` reduces the last axis;
    # TODO confirm).
    dR_direct = dR
    dr_direct = space.distance(dR)
    dr_direct = np.reshape(dr_direct, dr_direct.shape + (1,))
    # `itertools.product` yields shifts in the same order as nested
    # `for i ...: for j ...:` loops, with the last coordinate varying fastest.
    for shift in itertools.product(range(-1, 2), repeat=spatial_dimension):
      dR_shifted = dR + np.array(shift, dtype=R.dtype)
      dr_shifted = space.distance(dR_shifted)
      dr_shifted = np.reshape(dr_shifted, dr_shifted.shape + (1,))
      # Compute the mask once, against the previous best distance, so the
      # vector and the distance are updated consistently.
      closer = dr_shifted < dr_direct
      dR_direct = np.where(closer, dR_shifted, dR_direct)
      dr_direct = np.where(closer, dr_shifted, dr_direct)
    dR_direct = np.array(dR_direct, dtype=dR.dtype)

    assert dR_wrapped.dtype == dtype
    self.assertAllClose(dR_wrapped, dR_direct, True)
def test_periodic_shift(self, spatial_dimension, dtype):
  """Check that `space.periodic_shift` wraps into the unit box reversibly.

  Random positions in the unit box are moved by small Gaussian steps whose
  components are clamped to [-0.49, 0.49]. The shifted positions must keep
  the input dtype and lie strictly inside (0, 1). Taking the periodic
  displacement of (shifted - original) must recover the applied step.
  """
  box_size = f32(1.0)
  # Keep each step component below half the box length, presumably so the
  # periodic displacement maps the difference back to the original step
  # without ambiguity.
  max_step = 0.49
  step_scale = np.sqrt(f32(0.1))

  rng = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    rng, position_key, step_key = random.split(rng, 3)
    shape = (PARTICLE_COUNT, spatial_dimension)

    R = random.uniform(position_key, shape, dtype=dtype)
    dR = step_scale * random.normal(step_key, shape, dtype=dtype)
    # Clamp the upper tail first, then the lower tail.
    dR = np.where(dR > max_step, f32(max_step), dR)
    dR = np.where(dR < -max_step, f32(-max_step), dR)

    R_shift = space.periodic_shift(box_size, R, dR)
    assert R_shift.dtype == R.dtype
    assert np.all(R_shift < 1.0)
    assert np.all(R_shift > 0.0)

    dR_after = space.periodic_displacement(box_size, R_shift - R)
    assert dR_after.dtype == R.dtype
    self.assertAllClose(dR_after, dR, True)