Esempio n. 1
0
    def test_crop_size_larger_than_image(self):
        """A crop size exceeding the image's spatial extent must still yield the requested shape.

        The 64x56 input is cropped to 96x96, so the result cannot be filled from
        the image alone. The segmentation (all ones) is therefore expected to
        contain zeros from padding.
        """
        data = np.random.random((8, 4, 64, 56))
        seg = np.ones(data.shape)

        d, s = random_crop(data, seg, 96, 0)

        # Compare full shapes: a zip-based element check would silently accept
        # results with missing trailing dimensions.
        self.assertEqual(tuple(d.shape), (8, 4, 96, 96), "data has unexpected return shape")
        self.assertEqual(tuple(s.shape), (8, 4, 96, 96), "seg has unexpected return shape")

        self.assertNotEqual(np.sum(s == 0), 0, "seg was not padded properly")
Esempio n. 2
0
    def __call__(self, **data_dict):
        """Run ``random_crop`` on the entries stored under the configured keys.

        ``self.data_key`` and ``self.label_key`` are looked up with ``dict.get``,
        so a missing key is passed to ``random_crop`` as None. The cropped data
        always replaces the data entry. The label entry is only overwritten when
        ``random_crop`` returns a non-None segmentation.

        :return: the same (mutated) ``data_dict``.
        """
        cropped_data, cropped_seg = random_crop(
            data_dict.get(self.data_key),
            data_dict.get(self.label_key),
            self.crop_size,
            self.margins,
        )

        data_dict[self.data_key] = cropped_data
        if cropped_seg is not None:
            data_dict[self.label_key] = cropped_seg
        return data_dict
    def test_random_crop_2D_from_List(self):
        """Cropping a list of differently sized 2D samples yields one stacked batch.

        Every sample is at least 64x56, so a 32x32 crop with zero margins should
        need no padding. Zeros in the (random-valued) seg would indicate padding.
        """
        data = [np.random.random((4, 64+i, 56+i)) for i in range(32)]
        seg = [np.random.random((4, 64+i, 56+i)) for i in range(32)]

        d, s = random_crop(data, seg, 32, 0)

        # Exact shape comparison; zip would stop at the shorter shape and hide
        # rank mismatches.
        self.assertEqual(tuple(d.shape), (32, 4, 32, 32), "data has unexpected return shape")
        self.assertEqual(tuple(s.shape), (32, 4, 32, 32), "seg has unexpected return shape")

        self.assertEqual(np.sum(s == 0), 0, "Zeros encountered in seg meaning that we did padding which should not have"
                                            " happened here!")
    def test_random_crop_2D(self):
        """A 32x32 crop from a 64x56 batch keeps batch/channel axes and needs no padding.

        The segmentation is all ones, so any zero in the result would indicate padding.
        """
        data = np.random.random((32, 4, 64, 56))
        seg = np.ones(data.shape)

        d, s = random_crop(data, seg, 32, 0)

        # Exact shape comparison; zip would stop at the shorter shape and hide
        # rank mismatches.
        self.assertEqual(tuple(d.shape), (32, 4, 32, 32), "data has unexpected return shape")
        self.assertEqual(tuple(s.shape), (32, 4, 32, 32), "seg has unexpected return shape")

        self.assertEqual(np.sum(s == 0), 0, "Zeros encountered in seg meaning that we did padding which should not have"
                                            " happened here!")
    def __call__(self, **data_dict):
        """Crop the configured data entry (and label entry, if any) via ``random_crop``.

        Keys absent from ``data_dict`` reach ``random_crop`` as None because they
        are read with ``dict.get``. The label entry is written back only when
        ``random_crop`` returns a non-None segmentation.

        :return: ``data_dict``, updated in place.
        """
        data_key, label_key = self.data_key, self.label_key
        source_data = data_dict.get(data_key)
        source_seg = data_dict.get(label_key)

        result_data, result_seg = random_crop(source_data, source_seg, self.crop_size, self.margins)

        data_dict[data_key] = result_data
        if result_seg is None:
            return data_dict
        data_dict[label_key] = result_seg
        return data_dict
Esempio n. 6
0
    def __call__(self, **data_dict):
        """Apply ``random_crop`` to the ``"data"`` and ``"seg"`` entries of ``data_dict``.

        Missing entries are passed to ``random_crop`` as None (read via ``dict.get``).
        ``"data"`` is always replaced. ``"seg"`` is replaced only when the
        returned segmentation is not None.

        :return: the same ``data_dict`` object, mutated in place.
        """
        new_data, new_seg = random_crop(
            data_dict.get("data"),
            data_dict.get("seg"),
            self.crop_size,
            self.margins,
        )

        data_dict["data"] = new_data
        if new_seg is not None:
            data_dict["seg"] = new_seg

        return data_dict
Esempio n. 7
0
    def test_randomness_2(self):
        """Fifty crops of the same random volume should all have distinct sums.

        Mixed negative/positive margins are used, and every one of the 50 crop
        sums is required to be unique.
        """
        volume = np.random.random((1, 1, 30, 30, 30))
        crop_size = (16, 18, 7)
        margin = (-4, -6, 5)

        crop_sums = []  # expected to be pairwise different
        for _trial in range(50):
            cropped, _unused_seg = random_crop(volume, crop_size=crop_size, margins=margin)
            crop_sums.append(np.sum(cropped))

        assert len(np.unique(crop_sums)) == 50
Esempio n. 8
0
    def test_randomness_1(self):
        """Random crops of an all-ones volume should not always land in the same place.

        The probed line ``[0, 0, 8, 8, :]`` of each 16^3 crop must sum to between
        12 and 16. With negative margins this is presumably because part of the
        crop may fall outside the image and be zero-padded. Across 50 crops,
        more than one distinct sum is expected.
        """
        data = np.ones((1, 2, 30, 30, 30))
        crop_size = (16, 16, 16)
        margin = (-4, -4, -4)

        sums = []  # should vary across crops (not necessarily pairwise distinct)
        for _ in range(50):
            data_cropped, _ = random_crop(data, crop_size=crop_size, margins=margin)
            s = np.sum(data_cropped[0, 0, 8, 8, :])
            assert 12 <= s <= 16
            sums.append(s)

        # A non-empty list always has at least one unique value, so require more
        # than one to show that the crop position actually varies.
        assert len(np.unique(sums)) > 1
Esempio n. 9
0
    def test_random_crop_with_cropsize_larger_image(self):
        '''
        Cropping with margins that leave no room should fall back to a center crop.

        With crop_size 32 and margins 32 (assuming margins apply on both sides of
        each axis: 32 + 2 * 32 = 96), none of the images here (spatial sizes
        64..95 x 56..87) can host the margin-restricted crop box. random_crop
        should then fall back to a center crop, which needs no padding.
        :return:
        '''
        data = [np.random.random((4, 64+i, 56+i)) for i in range(32)]
        seg = [np.random.random((4, 64+i, 56+i)) for i in range(32)]

        d, s = random_crop(data, seg, 32, 32)

        # Exact shape comparison; zip would stop at the shorter shape and hide
        # rank mismatches.
        self.assertEqual(tuple(d.shape), (32, 4, 32, 32), "data has unexpected return shape")
        self.assertEqual(tuple(s.shape), (32, 4, 32, 32), "seg has unexpected return shape")

        self.assertEqual(np.sum(s == 0), 0, "Zeros encountered in seg meaning that we did padding which should not have"
                                            " happened here!")
 def generate_train_batch(self):
     """Load, preprocess and stack the next batch of samples.

     Each entry of ``self._data`` is unpacked as ``(label, data, segmentation_file)``.
     In volumetric mode ``data`` is the image path. Otherwise it is indexed as
     ``(path, slice_index)`` and the slice is taken along the last axis of the
     loaded volume. Images are read with ``sitk.ReadImage``, the array is
     transposed, a leading axis is added, and it is optionally normalized from
     ``self.normalization_range`` to ``[0, 1]``. Depending on ``self.mode`` the
     sample is then resized (RESIZE) or randomly cropped (SAMPLE) to
     ``self.input_shape``.

     :return: dict with ``"data"`` (stacked samples), ``"vector"`` (stacked
         ``self.vector_generator`` outputs, or None when no generator is set)
         and ``"label"`` (stacked labels).
     :raises StopIteration: once ``self.epochs`` (when > 0) epochs have been served.
     """
     if not self.was_initialized:
         self.initialize()
     if self._current_position >= len(self._data):
         self._reset()
         self._current_epoch += 1
         if 0 < self.epochs <= self._current_epoch:
             raise StopIteration
     data_batch = []
     vector_batch = None if self.vector_generator is None else []
     label_batch = []
     # Volumes already read during this call, keyed by file path, so several
     # slices of the same file are only read from disk once per batch.
     loaded_cts = {}
     # Only visit indices that exist; the final batch of an epoch may be short.
     stop = min(self._current_position + self.batch_size, len(self._data))
     for index in range(self._current_position, stop):
         label, data, _segmentation_file = self._data[index]
         data_file = data if self.volumetric else data[0]
         slice_index = None if self.volumetric else data[1]
         if data_file in loaded_cts:
             data_np = loaded_cts[data_file]
         else:
             data_sitk = sitk.ReadImage(data_file)
             data_np = np.expand_dims(sitk.GetArrayFromImage(data_sitk).transpose(), axis=0)
             if self.normalization_range is not None:
                 data_np = normalize(data_np, self.normalization_range, [0, 1])
             loaded_cts[data_file] = data_np
         if slice_index is not None:
             data_np = data_np[:, :, :, slice_index]
         vector_gen_args = {"input_shape_pre": data_np.shape[-3:]}
         if self.mode == DataLoader.Mode.RESIZE:
             data_np = augment_resize(data_np, None, self.input_shape)[0].squeeze(0)
             data_np = np.expand_dims(data_np, 0)
             vector_gen_args["input_shape_post"] = self.input_shape
         elif self.mode == DataLoader.Mode.SAMPLE:
             # Add a leading axis for random_crop and drop it again afterwards.
             data_np = np.expand_dims(data_np, 0)
             # NOTE(review): RNGContext is entered with the same self.seed for
             # every sample; if it reseeds the RNG, every crop lands at the same
             # offset. Confirm this is intended.
             with RNGContext(self.seed):
                 data_np = random_crop(data_np, crop_size=self.input_shape)[0].squeeze(0)
             vector_gen_args["input_shape_post"] = self.input_shape
         if self.vector_generator is not None:
             vector_batch.append(self.vector_generator(**vector_gen_args))
         data_batch.append(data_np)
         label_batch.append(label)
     batch = {"data": np.stack(data_batch),
              "vector": None if vector_batch is None else np.stack(vector_batch),
              "label": np.stack(label_batch)}
     # Advance past the batches handled by the other worker threads (presumably
     # each worker starts at its own offset; confirm against _reset).
     self._current_position += self.number_of_threads_in_multithreaded * self.batch_size
     return batch
def random_crop_generator(generator, crop_size=128, margins=(0, 0, 0)):
    '''
    yields a random crop of size crop_size, crop_size may be a tuple with one entry for each dimension of your data (2D/3D)

    Deprecated: a deprecation Warning is emitted when iteration starts.

    :param generator: iterable of dicts; each must contain a 'data' key and may contain a 'seg' key
    :param crop_size: crop size passed through to random_crop
    :param margins: allows to give cropping margins measured symmetrically from the image boundaries, which
    restrict the 'box' from which to randomly crop
    :return: yields each input dict with 'data' (and 'seg', if random_crop returns one) replaced by the cropped
    arrays; the dict is mutated in place
    :raises AssertionError: if a dict from the generator has no 'data' key
    '''
    # The docstring must come first; placed after this call it would be a dead
    # string expression and __doc__ would be None.
    warn("using deprecated generator random_crop_generator", Warning)
    for data_dict in generator:
        assert "data" in data_dict, \
            "your data generator needs to return a python dictionary with at least a 'data' key value pair"
        data = data_dict["data"]
        seg = data_dict.get("seg")
        data, seg = random_crop(data, seg, crop_size, margins)
        data_dict["data"] = data
        if seg is not None:
            data_dict["seg"] = seg
        yield data_dict
Esempio n. 12
0
def random_crop_generator(generator, crop_size=128, margins=(0, 0, 0)):
    '''
    yields a random crop of size crop_size, crop_size may be a tuple with one entry for each dimension of your data (2D/3D)

    Deprecated: a deprecation Warning is emitted when iteration starts.

    :param generator: iterable of dicts; each must contain a 'data' key and may contain a 'seg' key
    :param crop_size: crop size passed through to random_crop
    :param margins: allows to give cropping margins measured symmetrically from the image boundaries, which
    restrict the 'box' from which to randomly crop
    :return: yields each input dict with 'data' (and 'seg', if random_crop returns one) replaced by the cropped
    arrays; the dict is mutated in place
    :raises AssertionError: if a dict from the generator has no 'data' key
    '''
    # The docstring must come first; placed after this call it would be a dead
    # string expression and __doc__ would be None.
    warn("using deprecated generator random_crop_generator", Warning)
    for data_dict in generator:
        assert (
            "data" in data_dict
        ), "your data generator needs to return a python dictionary with at least a 'data' key value pair"
        data = data_dict["data"]
        seg = data_dict.get("seg")
        data, seg = random_crop(data, seg, crop_size, margins)
        data_dict["data"] = data
        if seg is not None:
            data_dict["seg"] = seg
        yield data_dict