def test_single_input(self):
    """Run CoarseDropout.forward on ``self.single_input`` and check the result.

    Verifies that the op returns a list and that its first element has
    shape ``self.single_output_shape``.
    """
    coarse_dropout = CoarseDropout(inputs='x', outputs='x')
    output = coarse_dropout.forward(data=self.single_input, state={})
    with self.subTest('Check output type'):
        # assertIsInstance is the idiomatic type check and reports the actual type on failure.
        self.assertIsInstance(output, list)
    with self.subTest('Check output image shape'):
        self.assertEqual(output[0].shape, self.single_output_shape)
def test_multi_input(self):
    """Run CoarseDropout.forward on ``self.multi_input`` and check the result.

    Verifies that the op returns a list of exactly two outputs and that
    every output has shape ``self.multi_output_shape``.
    """
    coarse_dropout = CoarseDropout(inputs='x', outputs='x')
    output = coarse_dropout.forward(data=self.multi_input, state={})
    with self.subTest('Check output type'):
        # assertIsInstance is the idiomatic type check and reports the actual type on failure.
        self.assertIsInstance(output, list)
    with self.subTest('Check output list length'):
        self.assertEqual(len(output), 2)
    for idx, img_output in enumerate(output):
        # Tag each subTest with its index so a failing element is identifiable in the report.
        with self.subTest('Check output image shape', index=idx):
            self.assertEqual(img_output.shape, self.multi_output_shape)