Exemplo n.º 1
0
 def test_inverse(self):
   """Forward samples times their inverse samples should equal 1 + 0j."""
   dims = (100,)
   cis = rotation.sample_cis(dims, SEED_PAIR, inverse=False)
   cis_inv = rotation.sample_cis(dims, SEED_PAIR, inverse=True)
   # With the same seed, each inverse sample should cancel its forward
   # counterpart, so the element-wise product reverts to the identity.
   product = cis * cis_inv
   identity = tf.complex(tf.ones(dims), tf.zeros(dims))
   product_val, identity_val = self.evaluate([product, identity])
   self.assertAllClose(product_val, identity_val)
Exemplo n.º 2
0
 def test_different_samples_with_different_seeds(self):
   """Two distinct seed pairs must not yield element-wise identical samples."""
   dims = (100,)
   first = rotation.sample_cis(dims, seed=(42, 42))
   second = rotation.sample_cis(dims, seed=(4200, 4200))
   evaluated = self.evaluate([first, second])
   self.assertFalse(np.array_equal(*evaluated))
Exemplo n.º 3
0
 def test_expected_shape(self, shape, inverse):
   """The returned tensor's shape equals the requested `shape`."""
   sampled = rotation.sample_cis(shape, SEED_PAIR, inverse=inverse)
   self.assertAllEqual(sampled.shape, shape)
Exemplo n.º 4
0
 def test_output_dtype(self, inverse):
   """Samples come back as tf.complex64 for either value of `inverse`."""
   samples = rotation.sample_cis((100,), SEED_PAIR, inverse=inverse)
   self.assertEqual(samples.dtype, tf.complex64)
Exemplo n.º 5
0
 def test_unit_length(self, inverse):
   """Every sample lies on the unit circle, i.e. |z| == 1."""
   dims = (100,)
   samples = rotation.sample_cis(dims, SEED_PAIR, inverse=inverse)
   magnitudes = tf.math.abs(samples)
   self.assertAllClose(self.evaluate(magnitudes), np.ones(dims))
Exemplo n.º 6
0
 def test_uniform_angles(self):
   """The mean of many samples should sit near the origin, 0 + 0j.

   If the sampled angles are spread uniformly around the circle, their unit
   phasors average out toward zero; atol=0.1 leaves headroom for sampling
   noise at this sample count.
   """
   num_samples = 1000
   mean = tf.reduce_mean(rotation.sample_cis((num_samples,), SEED_PAIR))
   self.assertAllClose(self.evaluate(mean), 0 + 0j, atol=0.1)