Example #1
0
  def test_triplet_static_species_scalar(self, spatial_dimension, dtype):
    """Check `smap.triplet` with static per-particle species and a param tensor.

    Builds a triplet energy from `angle_fn`, whose value depends only on the
    first displacement `dR1` and a `param` value, and compares
    `energy(R) / count` against half the summed squared pairwise distances,
    accumulated over every (species_i, species_j) block of particles.

    Args:
      spatial_dimension: Dimension of the random particle positions.
      dtype: Floating-point dtype of the positions and the expected value.
    """
    key = random.PRNGKey(0)
    # The default `param=5.0` is presumably overridden by the per-species
    # `params` handed to smap.triplet below — TODO confirm.
    angle_fn = lambda dR1, dR2, param=5.0: param * np.sum(np.square(dR1))
    # Literal shape is (2, 2, 2), matching the two species drawn below;
    # presumably indexed by (species_i, species_j, species_k) — verify.
    params = f32(np.array([[[1., 1.], [2., 0.]], [[0., 2.], [1., 1.]]]))

    count = PARTICLE_COUNT // 50
    key, split = random.split(key)
    species = random.randint(split, (count,), 0, 2)
    displacement, _ = space.free()

    def pair_sq_distance(Ra, Rb, **kwargs):
      # Squared norm of displacement(Ra, Rb), summed over the last axis.
      return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

    triplet_energy = smap.triplet(angle_fn,
                                  displacement,
                                  species=species,
                                  param=params,
                                  reduce_axis=None)
    # Vectorized over both point sets; presumably an (n_a, n_b) matrix.
    pairwise_sq_distances = space.map_product(pair_sq_distance)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
          split, (count, spatial_dimension), dtype=dtype)
      total = 0.
      for i in range(2):
        R_i = R[species == i]  # Independent of j, so selected once per i.
        for j in range(2):
          R_j = R[species == j]
          total += 0.5 * np.sum(pairwise_sq_distances(R_i, R_j))
      # NOTE(review): the expected value never reads `params`; it relies on
      # how smap.triplet combines them with the species — confirm.
      self.assertAllClose(triplet_energy(R) / count,
                          np.array(total, dtype=dtype))
Example #2
0
    def test_triplet_no_species_scalar(self, spatial_dimension, dtype):
        """Check `smap.triplet` without species against a pairwise sum.

        `angle_fn` sums the squared components of both displacements. The
        triplet energy divided by `count * 2` is compared against half the
        summed squared pairwise distances over all particle pairs.

        Args:
            spatial_dimension: Dimension of the random particle positions.
            dtype: Floating-point dtype of the positions and the expected
                value.
        """
        key = random.PRNGKey(0)

        angle_fn = lambda dR1, dR2: np.sum(np.square(dR1) + np.square(dR2))
        displacement, _ = space.free()

        def pair_sq_distance(Ra, Rb, **kwargs):
            # Squared norm of displacement(Ra, Rb), summed over the last axis.
            return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        triplet_energy = smap.triplet(angle_fn, displacement)
        # Vectorized over both point sets; presumably an (n_a, n_b) matrix.
        pairwise_sq_distances = space.map_product(pair_sq_distance)

        count = PARTICLE_COUNT // 50

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (count, spatial_dimension), dtype=dtype)

            expected = 0.5 * np.sum(pairwise_sq_distances(R, R))
            self.assertAllClose(
                triplet_energy(R) / count / 2.,
                np.array(expected, dtype=dtype))