Exemplo n.º 1
0
    def test_periodic_displacement(self, spatial_dimension, dtype):
        """Check `space.periodic_displacement` against a brute-force minimum image.

        Positions are sampled with `random.uniform` (default range [0, 1)) and
        the box side passed to `periodic_displacement` is 1.0, so every raw
        pairwise displacement component lies in (-1, 1).  The shortest periodic
        image of a displacement is therefore the raw displacement plus some
        shift whose components are each in {-1, 0, 1}.  This test tries all
        3**spatial_dimension such shifts, keeps the shortest result, and
        compares it with the wrapped displacement.

        Args:
          spatial_dimension: Number of spatial dimensions of the positions.
          dtype: Floating point dtype used for the sampled positions.
        """
        import itertools  # local import: the file's import block is not visible here

        # Every lattice shift with components in {-1, 0, 1}.  itertools.product
        # yields them in the same lexicographic order as explicit nested loops
        # over range(-1, 2), and unlike hand-written loops it handles any
        # spatial dimension (previously only 2 and 3 built a reference at all).
        shifts = list(itertools.product(range(-1, 2), repeat=spatial_dimension))

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)

            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            dR = space.map_product(space.pairwise_displacement)(R, R)

            dR_wrapped = space.periodic_displacement(f32(1.0), dR)

            # Start from the unshifted displacement and keep the shortest image.
            dR_direct = dR
            dr_direct = space.distance(dR)
            # Trailing unit axis so distances broadcast against the
            # displacement vectors in np.where below.
            dr_direct = np.reshape(dr_direct, dr_direct.shape + (1, ))

            for shift in shifts:
                dR_shifted = dR + np.array(shift, dtype=R.dtype)

                dr_shifted = space.distance(dR_shifted)
                dr_shifted = np.reshape(dr_shifted, dr_shifted.shape + (1, ))

                # Compute the mask once so the displacement and the distance
                # are both updated from the same (pre-update) comparison.
                closer = dr_shifted < dr_direct
                dR_direct = np.where(closer, dR_shifted, dR_direct)
                dr_direct = np.where(closer, dr_shifted, dr_direct)

            dR_direct = np.array(dR_direct, dtype=dR.dtype)
            assert dR_wrapped.dtype == dtype
            self.assertAllClose(dR_wrapped, dR_direct, True)
Exemplo n.º 2
0
    def test_periodic_shift(self, spatial_dimension, dtype):
        """Check that `space.periodic_shift` wraps into the box and is invertible.

        Random positions in the unit box are moved by small random steps with
        `periodic_shift` (box side 1.0).  The shifted positions must keep the
        input dtype and lie inside the box.  The minimum-image displacement
        between the shifted and original positions must recover the applied
        step.

        Args:
          spatial_dimension: Number of spatial dimensions of the positions.
          dtype: Floating point dtype used for positions and steps.
        """
        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split1, split2 = random.split(key, 3)

            R = random.uniform(split1, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            dR = np.sqrt(f32(0.1)) * random.normal(
                split2, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

            # Clip each step component to [-0.49, 0.49], i.e. strictly under
            # half the box side, so the minimum-image displacement computed
            # below is exactly dR rather than a different periodic image.
            dR = np.where(dR > 0.49, f32(0.49), dR)
            dR = np.where(dR < -0.49, f32(-0.49), dR)

            R_shift = space.periodic_shift(f32(1.0), R, dR)

            assert R_shift.dtype == R.dtype
            # Periodic wrapping maps coordinates into the half-open box
            # [0, 1), so an exact 0.0 is a legitimate wrapped coordinate; the
            # lower bound is inclusive (a strict `> 0.0` could fail spuriously).
            assert np.all(R_shift < 1.0)
            assert np.all(R_shift >= 0.0)

            dR_after = space.periodic_displacement(f32(1.0), R_shift - R)

            assert dR_after.dtype == R.dtype
            self.assertAllClose(dR_after, dR, True)