def test_triplet_static_species_scalar(self, spatial_dimension, dtype):
  """Check `smap.triplet` with a static per-species parameter array.

  The triplet function is built from `angle_fn` with `species` drawn at
  random from {0, 1} and `param=params`. For `STOCHASTIC_SAMPLES` random
  configurations, `triplet_square(R) / count` is compared with a reference
  that accumulates half the summed pairwise squared distances over every
  (species i, species j) block of particles.

  NOTE(review): `params` never enters the reference sum, so agreement relies
  on how `smap.triplet` weights the per-species parameters -- confirm this is
  the intended check.
  """
  key = random.PRNGKey(0)
  angle_fn = lambda dR1, dR2, param=5.0: param * np.sum(np.square(dR1))
  params = f32(np.array([[[1., 1.], [2., 0.]], [[0., 2.], [1., 1.]]]))

  count = PARTICLE_COUNT // 50
  key, split = random.split(key)
  species = random.randint(split, (count,), 0, 2)
  # `species` is fixed for the whole test, so build the per-species selection
  # masks once instead of recomputing them in every sample / pair iteration.
  species_masks = [species == s for s in range(2)]

  displacement, _ = space.free()
  metric = lambda Ra, Rb, **kwargs: \
      np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)
  triplet_square = smap.triplet(angle_fn,
                                displacement,
                                species=species,
                                param=params,
                                reduce_axis=None)
  # Squared distance between every point of the first set and every point
  # of the second set.
  pairwise_metric = space.map_product(metric)

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = random.uniform(split, (count, spatial_dimension), dtype=dtype)
    total = 0.
    # Same (i outer, j inner) accumulation order as the per-species loop.
    for mask_i in species_masks:
      for mask_j in species_masks:
        total += 0.5 * np.sum(pairwise_metric(R[mask_i], R[mask_j]))
    self.assertAllClose(triplet_square(R) / count,
                        np.array(total, dtype=dtype))
def test_triplet_no_species_scalar(self, spatial_dimension, dtype):
  """Check `smap.triplet` without species against a pairwise reference.

  For `STOCHASTIC_SAMPLES` random configurations, `triplet_square(R)`
  divided by `count` and by 2 is compared with half the summed pairwise
  squared distances, `0.5 * sum(metric(R, R))`.
  """
  key = random.PRNGKey(0)
  angle_fn = lambda dR1, dR2: np.sum(np.square(dR1) + np.square(dR2))
  displacement, _ = space.free()
  metric = lambda Ra, Rb, **kwargs: \
      np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)
  triplet_square = smap.triplet(angle_fn, displacement)
  # Squared distance between every pair of points from the two inputs.
  pairwise_metric = space.map_product(metric)

  count = PARTICLE_COUNT // 50
  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)
    R = random.uniform(split, (count, spatial_dimension), dtype=dtype)
    self.assertAllClose(
        triplet_square(R) / count / 2.,
        np.array(0.5 * np.sum(pairwise_metric(R, R)), dtype=dtype))