def test_load_image(self): cdata = ChannelData(path=self.test_path, elevation='elevation') dlist = cdata.get_data_list() images = dlist.map(ChannelData.load_image) self.assertEqual( len(dlist), len(images) )
def test_random_rotation(self): cdata1 = ChannelData(path=self.test_path, elevation='elevation') cdata1.add_process(random_rotation) for image, label in cdata1.data.take(1): fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 10)) axs[0].imshow(image) axs[1].imshow(label) plt.savefig(self.output_path / 'test_rotation.png', bbox_inches='tight')
def test_add_process(self): # TODO: create a folder with just one image so we can compare process = tf.image.flip_left_right cdata1 = ChannelData(path=self.test_path, elevation='elevation') cdata2 = ChannelData(path=self.test_path, elevation='elevation') cdata2 = cdata2.add_process(process) for image, label in cdata1.data.take(1): expected_image = process(image) expected_label = process(label) for image, label in cdata2.data.take(1): fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10)) axs[0][0].imshow(expected_image) axs[0][1].imshow(image) axs[1][0].imshow(expected_label) axs[1][1].imshow(label) plt.savefig(self.output_path / 'test_add_process.png', bbox_inches='tight')
def test_bundle_process(self): process = tf.image.flip_left_right cdata1 = ChannelData(path=self.test_path, elevation='elevation') new_process = ChannelData._bundled_process(process) for image, label in cdata1.data.take(1): new_image, new_label = new_process(image, label) expected_image = process(image) expected_label = process(label) tf.debugging.assert_equal(expected_image, new_image) tf.debugging.assert_equal(expected_label, new_label) fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10)) axs[0][0].imshow(expected_image) axs[0][1].imshow(image) axs[1][0].imshow(expected_label) axs[1][1].imshow(label) plt.savefig(self.output_path / 'test_bundle.png', bbox_inches='tight')
def test_process_path(self): cdata = ChannelData(path=self.test_path, elevation='elevation') dlist = cdata.get_data_list() process = partial( ChannelData.process_path, mask_key=cdata.mask_key, image_key=cdata.image_key, **cdata.others ) labeled_ds = dlist.map(process) expected = [f for f in os.listdir(self.test_path) if 'mask' in f] self.assertEqual( len(expected), len(labeled_ds) ) for image, label in labeled_ds.take(1): self.assertEqual(image.shape[-1], 4) self.assertEqual(image.shape[0], label.shape[0])
def test_init_data(self): cdata = ChannelData(path=self.test_path, elevation='elevation') expected = [f for f in os.listdir(self.test_path) if 'mask' in f] self.assertEqual( len(expected), len(cdata.data) ) for image, label in cdata.data.take(1): self.assertEqual(image.shape[-1], 4) self.assertEqual(image.shape[0], label.shape[0])
def test_get_data_list(self): cdata = ChannelData(path=self.test_path, elevation='elevation') dlist = list(cdata.get_data_list().as_numpy_iterator()) expected = [f for f in os.listdir(self.test_path) if 'mask' in f] actual = [str(f).split('/')[-1][:-1] for f in dlist] self.assertListEqual(sorted(expected), sorted(actual))
def test_init_data(self): cdata = ChannelData(path=self.test_path, elevation='elevation') self.assertEqual('mask', cdata.mask_key)