def test_get_tile(self):
    """Check ``Tiler.get_tile``: tile-id bounds, copy vs. view semantics, callable data.

    The ``copy_data=False`` case deliberately writes through the returned tile
    into ``self.data``. The original value is restored in a ``finally`` so that
    a failing assertion cannot leave the fixture modified. The callable-data
    checks at the end of this test rely on the unmodified values.
    """
    tile_size = 10
    tiler = Tiler(data_shape=self.data.shape, tile_shape=(tile_size, ))

    # Tile ids outside [0, len(tiler)) must be rejected
    with self.assertRaises(IndexError):
        tiler.get_tile(self.data, len(tiler))
    with self.assertRaises(IndexError):
        tiler.get_tile(self.data, -1)

    # copy test: writing into a copied tile must not touch the source data
    t = tiler.get_tile(self.data, 0, copy_data=True)
    t[9] = 0
    np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], tiler.get_tile(self.data, 0))
    np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t)

    # view test: without copying, writes to the tile show up in the source data
    original_value = self.data[9]
    t = tiler.get_tile(self.data, 0, copy_data=False)
    try:
        t[9] = 0
        np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], tiler.get_tile(self.data, 0))
        np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t)
    finally:
        # always undo the write-through so later checks (and tests) see clean data
        self.data[9] = original_value

    # test callable data: the tiler calls fn with the tile start and width
    def fn(x, w):
        return self.data[x:x + w]

    t = tiler.get_tile(fn, 0)
    np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], t)
    t = tiler.get_tile(fn, 1)
    np.testing.assert_equal([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], t)
def test_iterator(self):
    """Check iterating ``tiler(data)``: copy vs. view semantics and batching.

    Iteration yields ``(tile_id, tile)`` pairs. The ``copy_data=False`` case
    writes through a view into ``self.data``. The original value is restored
    in a ``finally`` so the fixture is not left modified, whether the
    assertions pass or fail.
    """
    tile_size = 10
    tiler = Tiler(data_shape=self.data.shape, tile_shape=(tile_size, ))

    # copy test with iterator: modifying a yielded copy leaves the data intact
    t = list(tiler(self.data, copy_data=True))
    t[0][1][9] = 0
    np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], tiler.get_tile(self.data, 0))
    np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t[0][1])
    self.assertNotEqual(t[0][1][9], self.data[9])

    # view test with iterator: modifying a yielded view changes the data
    original_value = self.data[9]
    t = [tile for _, tile in tiler(self.data, copy_data=False)]
    try:
        t[0][9] = 0
        np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], tiler.get_tile(self.data, 0))
        np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t[0])
        self.assertEqual(t[0][9], self.data[9])
    finally:
        # undo the write-through so the (possibly shared) fixture stays clean
        self.data[9] = original_value

    # test batch size: negative values are rejected
    with self.assertRaises(ValueError):
        t = [x for _, x in tiler(self.data, batch_size=-1)]

    # batch_size=0 yields individual tiles without a batch axis
    t = [x for _, x in tiler(self.data, batch_size=0)]
    self.assertEqual(len(t), 10)
    np.testing.assert_equal(t[0].shape, (10, ))

    # batch_size >= 1 adds a leading batch axis
    t = [x for _, x in tiler(self.data, batch_size=1)]
    self.assertEqual(len(t), 10)
    np.testing.assert_equal(t[0].shape, (1, 10))

    t = [x for _, x in tiler(self.data, batch_size=10)]
    self.assertEqual(len(t), 1)
    np.testing.assert_equal(t[0].shape, (10, 10))

    # a batch size that does not divide the tile count leaves a short last batch ...
    t = [x for _, x in tiler(self.data, batch_size=9)]
    self.assertEqual(len(t), 2)
    np.testing.assert_equal(t[0].shape, (9, 10))
    np.testing.assert_equal(t[1].shape, (1, 10))

    # ... which drop_last=True discards
    t = [x for _, x in tiler(self.data, batch_size=9, drop_last=True)]
    self.assertEqual(len(t), 1)
    np.testing.assert_equal(t[0].shape, (9, 10))
def test_callable_data(self):
    """Check the arguments ``Tiler.get_tile`` passes to callable data.

    ``fn`` raises a ``ValueError`` carrying all of its positional arguments.
    The test can therefore read back exactly what the tiler passed: the tile's
    bounding-box start coordinates followed by the tile shape.

    Each (configuration, tile id) pair runs in its own ``subTest``, so a
    failure names the case that broke instead of stopping at the first one.
    """
    def fn(*x):
        raise ValueError(x)

    # (case label, Tiler kwargs, extra kwargs for get_tile_bbox_position)
    cases = [
        ('1D', dict(data_shape=(100, ), tile_shape=(10, )), {}),
        ('2D', dict(data_shape=(100, 100), tile_shape=(10, 20)), {}),
        ('3D', dict(data_shape=(100, 100, 100), tile_shape=(10, 20, 50)), {}),
        # with a channel dimension, the bbox must include the channel axis too
        ('channel dimension',
         dict(data_shape=(100, 100, 3), tile_shape=(10, 20, 3), channel_dimension=2),
         dict(with_channel_dim=True)),
    ]

    for case, tiler_kwargs, bbox_kwargs in cases:
        tiler = Tiler(**tiler_kwargs)
        for i in range(tiler.n_tiles):
            with self.subTest(case=case, tile_id=i):
                with self.assertRaises(ValueError) as cm:
                    tiler.get_tile(fn, i)
                np.testing.assert_equal(
                    cm.exception.args[0],
                    (*tiler.get_tile_bbox_position(i, **bbox_kwargs)[0], *tiler.tile_shape))
def test_add(self):
    """Exercise ``Merger.add`` for plain, logits, irregular and channel-dimension mergers.

    Each merger is first given an invalid input (a bad tile id, or a tile of
    the wrong shape) and must raise. It is then given a valid tile, and that
    tile must come back out of ``merge()`` at the expected position.
    """
    base_tiler = Tiler(data_shape=self.data.shape, tile_shape=(10, ))
    irregular_tiler = Tiler(data_shape=self.data.shape, tile_shape=(12, ), mode='irregular')
    channel_tiler = Tiler(data_shape=(3, ) + self.data.shape,
                          tile_shape=(3, 10),
                          channel_dimension=0)

    plain = Merger(base_tiler)
    with_logits = Merger(base_tiler, logits=3)
    irregular = Merger(irregular_tiler)
    with_channels = Merger(channel_tiler)

    first_tile = base_tiler.get_tile(self.data, 0)
    # the same tile stacked three times: shape (3, 10)
    stacked = np.vstack((first_tile, first_tile, first_tile))
    last_irregular_id = len(irregular_tiler) - 1
    last_irregular_tile = irregular_tiler.get_tile(self.data, last_irregular_id)

    # Tile ids outside [0, len(tiler)) are rejected
    for bad_id in (-1, len(base_tiler)):
        with self.assertRaises(IndexError):
            plain.add(bad_id, np.ones((10, )))

    # A plain merger only accepts tiles shaped like the tiler's tile_shape
    with self.assertRaises(ValueError):
        plain.add(0, np.ones((3, 10)))
    plain.add(0, first_tile)
    np.testing.assert_equal(plain.merge()[:10], first_tile)

    # A logits merger expects an extra leading axis holding the logits
    with self.assertRaises(ValueError):
        with_logits.add(0, np.ones((10, )))
    with_logits.add(0, stacked)
    np.testing.assert_equal(with_logits.merge()[:, :10], stacked)
    # all three logit rows are identical, so the argmax is expected to be 0 everywhere
    np.testing.assert_equal(with_logits.merge(argmax=True)[:10], np.zeros((10, )))

    # An irregular merger rejects a tile longer than tile_shape (13 > 12)
    # and accepts the last tile, which may be shorter
    with self.assertRaises(ValueError):
        irregular.add(0, np.ones((13, )))
    irregular.add(last_irregular_id, last_irregular_tile)
    np.testing.assert_equal(irregular.merge()[-len(last_irregular_tile):], last_irregular_tile)

    # A channel-dimension merger needs the channel axis to be present in the tile
    with self.assertRaises(ValueError):
        with_channels.add(0, np.ones((10, )))
    with_channels.add(0, stacked)
    np.testing.assert_equal(with_channels.merge()[:, :10], stacked)

    # Called only for coverage: per the original note this should merely print a
    # warning, so stderr is silenced to keep the test output clean
    with open(os.devnull, "w") as null, redirect_stderr(null):
        plain.set_window('boxcar')