Ejemplo n.º 1
0
    def test_rotate_zonal_harmonics_random(self):
        """Tests the outputs of rotate_zonal_harmonics on random inputs.

        Draws random zonal coefficients, a random batch shape (0, 1 or 2
        leading dimensions, each of size 1..9) and random angles
        theta in [0, pi), phi in [0, 2*pi). It then checks that the rotated
        coefficients, evaluated at (theta, phi), match the original zonal
        coefficients evaluated at theta = phi = 0.
        """
        dtype = tf.float64
        max_band = 2
        # Zonal harmonics only have m = 0, so bands 0..max_band need exactly
        # max_band + 1 coefficients; derive it instead of hard-coding it.
        zonal_coeffs = tf.constant(
            np.random.uniform(-1.0, 1.0, size=[max_band + 1]), dtype=dtype)
        # Random batch shape: 0, 1 or 2 leading dimensions, each in [1, 9].
        tensor_size = np.random.randint(3)
        tensor_shape = np.random.randint(1, 10, size=tensor_size).tolist()
        theta = tf.constant(
            np.random.uniform(0.0, np.pi, size=tensor_shape + [1]),
            dtype=dtype)
        phi = tf.constant(
            np.random.uniform(0.0, 2.0 * np.pi, size=tensor_shape + [1]),
            dtype=dtype)

        rotated_zonal_coeffs = spherical_harmonics.rotate_zonal_harmonics(
            zonal_coeffs, theta, phi)
        # NOTE(review): tile_zonal_coefficients is assumed to expand the
        # per-band coefficients to one entry per (l, m) pair. The element-wise
        # product with the harmonics below relies on that alignment.
        tiled_zonal_coeffs = spherical_harmonics.tile_zonal_coefficients(
            zonal_coeffs)
        l, m = spherical_harmonics.generate_l_m_permutations(max_band)
        l = tf.broadcast_to(l, tensor_shape + l.shape.as_list())
        m = tf.broadcast_to(m, tensor_shape + m.shape.as_list())
        theta_zero = tf.constant(0.0, shape=tensor_shape + [1], dtype=dtype)
        phi_zero = tf.constant(0.0, shape=tensor_shape + [1], dtype=dtype)

        # Ground truth: the unrotated expansion evaluated at theta = phi = 0.
        gt_harmonics = spherical_harmonics.evaluate_spherical_harmonics(
            l, m, theta_zero, phi_zero)
        gt = tf.reduce_sum(
            input_tensor=tiled_zonal_coeffs * gt_harmonics, axis=-1)
        # Prediction: the rotated expansion evaluated at the rotation angles.
        # theta and phi already have shape tensor_shape + [1], matching the
        # ground-truth angles, so they are passed in directly.
        pred_harmonics = spherical_harmonics.evaluate_spherical_harmonics(
            l, m, theta, phi)
        pred = tf.reduce_sum(
            input_tensor=rotated_zonal_coeffs * pred_harmonics, axis=-1)

        self.assertAllClose(gt, pred)
Ejemplo n.º 2
0
    def test_generate_l_m_permutations_preset(self):
        """Compares the (l, m) enumeration up to band 2 with a fixed table.

        Each column of the two tables below is one (l, m) pair. For every
        band l the table lists m from -l to l in increasing order.
        """
        expected_l = (0, 1, 1, 1, 2, 2, 2, 2, 2)
        expected_m = (0, -1, 0, 1, -2, -1, 0, 1, 2)

        l_values, m_values = spherical_harmonics.generate_l_m_permutations(2)

        self.assertAllEqual(l_values, expected_l)
        self.assertAllEqual(m_values, expected_m)