# Setup Tiler and Merger
tiler = Tiler(data_shape=image.shape, tile_shape=(200, 200, 3), channel_dimension=2)
merger = Merger(tiler)

# Example 1: process all tiles one by one, i.e. batch_size=0
for tile_i, tile in tiler(image, batch_size=0):
    merger.add(tile_i, tile)
result_bs0 = merger.merge().astype(np.uint8)

# Example 2: process all tiles in batches of 1, i.e. batch_size=1
merger.reset()
for batch_i, batch in tiler(image, batch_size=1):
    merger.add_batch(batch_i, 1, batch)
result_bs1 = merger.merge().astype(np.uint8)

# Example 3: process all tiles in batches of 10, i.e. batch_size=10
merger.reset()
for batch_i, batch in tiler(image, batch_size=10):
    merger.add_batch(batch_i, 10, batch)
result_bs10 = merger.merge().astype(np.uint8)

# Example 4: process all tiles in batches of 10, but drop the batch that has
# fewer than batch_size tiles, i.e. drop_last=True
merger.reset()
for batch_i, batch in tiler(image, batch_size=10, drop_last=True):
    merger.add_batch(batch_i, 10, batch)
# Kept under its own name so Example 3's result is not overwritten: when the
# tile count is not a multiple of 10, the dropped tiles are never added, so
# this result may differ from the full-coverage results above.
result_bs10_drop = merger.merge().astype(np.uint8)

# Every batch size that covers all tiles must reproduce the same merge.
assert np.all(result_bs0 == result_bs1)
assert np.all(result_bs0 == result_bs10)
def test_batch_add(self):
    """Check that Merger.add_batch reassembles the data for several batch sizes.

    The 1-D data is split into 10 tiles of length 10. Covered cases:
    batch_size=1 (10 batches), batch_size=10 (a single batch), batch_size=8
    (a full batch plus a partial one), and batch_size=8 with drop_last=True,
    where the trailing tiles are never added so that region of the merge
    stays zero. Finally, out-of-range batch ids must raise IndexError from
    the merger.
    """
    tiler = Tiler(data_shape=self.data.shape, tile_shape=(10, ))
    merger = Merger(tiler)

    # batch_size=1: one batch per tile.
    batch1 = [x for _, x in tiler(self.data, False, batch_size=1)]
    np.testing.assert_equal(len(batch1), 10)
    np.testing.assert_equal(batch1[0].shape, (1, 10, ))
    for i, b in enumerate(batch1):
        merger.add_batch(i, 1, b)
    np.testing.assert_equal(merger.merge(), self.data)

    # batch_size=10: all 10 tiles fit in exactly one batch.
    merger.reset()
    batch10 = [x for _, x in tiler(self.data, False, batch_size=10)]
    np.testing.assert_equal(len(batch10), 1)
    np.testing.assert_equal(batch10[0].shape, (10, 10, ))
    for i, b in enumerate(batch10):
        merger.add_batch(i, 10, b)
    np.testing.assert_equal(merger.merge(), self.data)

    # batch_size=8: one full batch of 8 followed by a partial batch of 2.
    merger.reset()
    batch8 = [x for _, x in tiler(self.data, False, batch_size=8)]
    np.testing.assert_equal(len(batch8), 2)
    np.testing.assert_equal(batch8[0].shape, (8, 10, ))
    np.testing.assert_equal(batch8[1].shape, (2, 10, ))
    for i, b in enumerate(batch8):
        merger.add_batch(i, 8, b)
    np.testing.assert_equal(merger.merge(), self.data)

    # batch_size=8 with drop_last=True: the partial batch is skipped, so only
    # the first 80 elements are reconstructed and the rest stays zero.
    merger.reset()
    batch8_drop = [
        x for _, x in tiler(self.data, False, batch_size=8, drop_last=True)
    ]
    np.testing.assert_equal(len(batch8_drop), 1)
    np.testing.assert_equal(batch8_drop[0].shape, (8, 10, ))
    for i, b in enumerate(batch8_drop):
        merger.add_batch(i, 8, b)
    merged = merger.merge()
    np.testing.assert_equal(merged[:80], self.data[:80])
    np.testing.assert_equal(merged[80:], np.zeros((20, )))

    # Out-of-range batch ids must be rejected by the merger itself. batch10
    # holds a single batch, so it is indexed with [0]; indexing it with [9]
    # would raise IndexError from the list before add_batch ran, and the
    # check would pass without exercising the merger at all.
    with self.assertRaises(IndexError):
        merger.add_batch(-1, 10, batch10[0])
    with self.assertRaises(IndexError):
        merger.add_batch(10, 10, batch10[0])