def test_preprocess(self):
    """Checks the rank and shape of the tensor returned by _preprocess.

    Every case expects a rank-2 result whose last dimension has the size
    of the requested axis of the input.
    """
    length = 10

    # Case 1: rank-1 input, a single row is expected.
    vector = tf.random.uniform((length, ), dtype=tf.float64)
    flat = ops._preprocess(vector, axis=-1)
    self.assertEqual(flat.shape.rank, 2)
    self.assertEqual(flat.shape, (1, length))
    self.assertAllEqual(flat[0], vector)

    # Case 2: rank-2 input along the last axis, expected to come back as is.
    matrix = tf.random.uniform((3, length), dtype=tf.float64)
    flat = ops._preprocess(matrix, axis=-1)
    self.assertEqual(flat.shape.rank, 2)
    self.assertEqual(flat.shape, matrix.shape)
    self.assertAllEqual(flat, matrix)

    # Case 3: rank-2 input along axis 0, the two dimensions are swapped.
    matrix = tf.random.uniform((3, length), dtype=tf.float64)
    flat = ops._preprocess(matrix, axis=0)
    self.assertEqual(flat.shape.rank, 2)
    self.assertEqual(flat.shape, (matrix.shape[1], matrix.shape[0]))
    row = 1
    self.assertAllEqual(flat[row], matrix[:, row])

    # Case 4: rank-4 input, every dimension except `axis` is collapsed
    # into the leading dimension of the result.
    dims = [4, 21, 7, 10]
    tensor = tf.random.uniform(dims, dtype=tf.float64)
    axis = 2
    axis_size = dims[axis]
    other_dims = dims[:axis] + dims[axis + 1:]
    flat = ops._preprocess(tensor, axis=axis)
    self.assertEqual(flat.shape.rank, 2)
    self.assertEqual(flat.shape, (np.prod(other_dims), axis_size))
def test_postprocess(self):
    """Tests that _postprocess is the inverse of _preprocess.

    For every prefix of `shape` (ranks 1 through 4) and every
    non-negative axis of the resulting tensor, a random tensor is passed
    through _preprocess and then _postprocess, and the result must equal
    the original tensor.
    """
    shape = (4, 21, 7, 10)
    # The upper bound is len(shape) + 1 so that the full rank-4 shape is
    # exercised too; range(1, len(shape)) stopped at rank 3 and never
    # used the last dimension.
    for rank in range(1, len(shape) + 1):
        x = tf.random.uniform(shape[:rank])
        for axis in range(x.shape.rank):
            z = ops._postprocess(ops._preprocess(x, axis), x.shape, axis)
            self.assertAllEqual(x, z)