def test_rotate_shapes(self, dtype):
    """Rotating by zero degrees returns the input unchanged for several shapes.

    Covers two square single-channel shapes and one three-channel shape.
    `dtype` is the element type under test (presumably supplied by a test
    parameterization decorator outside this view).
    """
    angle = 0.
    candidate_shapes = ((3, 3), (5, 5), (224, 224, 3))
    for dims in candidate_shapes:
        blank = tf.zeros(dims, dtype=dtype)
        rotated = augment.rotate(image=blank, degrees=angle)
        # assertAllEqual also compares shapes, so a shape change would fail.
        self.assertAllEqual(blank, rotated)
def test_rotate(self, dtype):
    """A 90-degree rotation of a 3x3 ramp matches a counter-clockwise quarter-turn.

    The input grid is [[0, 1, 2], [3, 4, 5], [6, 7, 8]] cast to `dtype`.
    `dtype` is presumably supplied by a test parameterization decorator
    outside this view.
    """
    ramp = tf.cast(tf.range(9), dtype)
    grid = tf.reshape(ramp, (3, 3))
    quarter_turn = 90.
    # The top-right element (2) ends up top-left, i.e. counter-clockwise.
    want = [
        [2, 5, 8],
        [1, 4, 7],
        [0, 3, 6],
    ]
    got = augment.rotate(image=grid, degrees=quarter_turn)
    self.assertAllEqual(got, want)