def test_input(self):
     """Rotate a single image and check the output container and image shape.

     Relies on ``self.single_input`` and ``self.single_output_shape``, which are
     defined outside this method (presumably in setUp/setUpClass).
     """
     rotate = Rotate(image_in='x')
     output = rotate.forward(data=self.single_input, state={})
     with self.subTest('Check output type'):
         # assertIsInstance is the idiomatic unittest type check and reports
         # the actual type of ``output`` when it fails.
         self.assertIsInstance(output, list)
     with self.subTest('Check output image shape'):
         self.assertEqual(output[0].shape, self.single_output_shape)
 def test_input_image_and_mask(self):
     """Rotate an image together with its mask and check both output shapes.

     ``output[0]`` is the rotated image and ``output[1]`` the rotated mask; both
     are expected to match ``self.image_and_mask_output_shape``. The fixtures
     ``self.input_image_and_mask`` and ``self.image_and_mask_output_shape`` are
     defined outside this method (presumably in setUp/setUpClass).
     """
     rotate = Rotate(image_in='x', mask_in='x_mask')
     output = rotate.forward(data=self.input_image_and_mask, state={})
     with self.subTest('Check output type'):
         # assertIsInstance is the idiomatic unittest type check and reports
         # the actual type of ``output`` when it fails.
         self.assertIsInstance(output, list)
     with self.subTest('Check output image shape'):
         self.assertEqual(output[0].shape, self.image_and_mask_output_shape)
     with self.subTest('Check output mask shape'):
         # The mask is asserted against the same expected shape as the image.
         self.assertEqual(output[1].shape, self.image_and_mask_output_shape)