Example #1
0
 def preprocessing(self, sample, training=True):
     """Resample the sample's channel-last imaging (and segmentation) data to
     ``self.new_spacing``.

     Parameters:
         sample: data object providing ``img_data``, ``seg_data`` and a
             ``details`` mapping containing a "spacing" entry.
         training: when True the segmentation is resampled as well; when False
             the pre-resampling shape is cached for later postprocessing.
     """
     # Access data (channel-last layout: spatial dims ..., channels)
     img_data = sample.img_data
     seg_data = sample.seg_data
     # Identify current spacing. A missing ``details`` attribute raises
     # AttributeError, a missing "spacing" key raises KeyError. The original
     # code caught only AttributeError and merely printed, so execution
     # continued and crashed later with a NameError on ``current_spacing``;
     # fail fast with a clear error instead.
     try:
         current_spacing = sample.details["spacing"]
     except (AttributeError, KeyError) as err:
         raise ValueError("'spacing' is not initialized in sample details!") from err
     # Cache the original shape (with a leading channel axis) for postprocessing
     if not training:
         self.original_shape = (1,) + img_data.shape[0:-1]
     # Calculate spacing ratio (old spacing / new spacing)
     ratio = current_spacing / np.array(self.new_spacing)
     # Calculate the resampled spatial shape
     new_shape = tuple(np.floor(img_data.shape[0:-1] * ratio).astype(int))
     # Transform data from channel-last to channel-first structure
     img_data = np.moveaxis(img_data, -1, 0)
     if training:
         seg_data = np.moveaxis(seg_data, -1, 0)
     # Resample imaging data (cubic for images, linear for segmentations)
     img_data, seg_data = augment_resize(img_data,
                                         seg_data,
                                         new_shape,
                                         order=3,
                                         order_seg=1,
                                         cval_seg=0)
     # Transform data from channel-first back to channel-last structure
     img_data = np.moveaxis(img_data, 0, -1)
     if training:
         seg_data = np.moveaxis(seg_data, 0, -1)
     # Save resampled imaging data to sample
     sample.img_data = img_data
     sample.seg_data = seg_data
Example #2
0
def resize_augment(data):
    """Resize augmentation: resample *data* to a fixed (500, 500) spatial size.

    :param data: channel-first image array to resize
    :return: tuple ``(data_result, seg_result)`` — the resized image and the
        (absent) segmentation, which is ``None`` here
    """
    # BUG FIX: the original call ``augment_resize(data, (500, 500))`` passed the
    # target size positionally as the *seg* argument (signature is
    # ``augment_resize(data, seg, target_size, ...)``), leaving target_size
    # unset. Pass None for seg and name the target size explicitly.
    data_result, seg_result = augment_resize(data, None, target_size=(500, 500))
    return data_result, seg_result
    def test_resize(self):
        """Resizing with scalar target_size=2 keeps the mean and shrinks every axis to 2."""
        resized_img, resized_seg = augment_resize(self.data_3D, self.seg_3D, target_size=2)

        # The mean intensity should be approximately preserved by resizing.
        self.assertAlmostEqual(float(np.mean(self.data_3D)),
                               float(np.mean(resized_img)),
                               places=2)

        # Every axis of both outputs must have length 2.
        dims = len(resized_img.shape)
        self.assertTrue(all(resized_img.shape[i] == 2 and resized_seg.shape[i] == 2
                            for i in range(dims)))
Example #4
0
    def __call__(self, **data_dict):
        """Resize the batch stored under ``self.data_key`` (and ``self.label_key``, if present)."""
        image = data_dict.get(self.data_key)
        labels = data_dict.get(self.label_key)

        resized = augment_resize(data=image, seg=labels,
                                 target_size=self.target_size, order=self.order)

        # Write the resized image back; the segmentation only if one was supplied.
        data_dict[self.data_key] = resized[0]
        if labels is not None:
            data_dict[self.label_key] = resized[1]
        return data_dict
    def __call__(self, **data_dict):
        """Resize the batch under ``self.data_key`` (and ``self.label_key``) with the
        configured interpolation orders and segmentation fill value."""
        image = data_dict.get(self.data_key)
        labels = data_dict.get(self.label_key)

        resized = augment_resize(data=image, seg=labels,
                                 target_size=self.target_size, order=self.order,
                                 order_seg=self.order_seg, cval_seg=self.cval_seg,
                                 concatenate_list=self.concatenate_list)

        # Write the resized image back; the segmentation only if one was supplied.
        data_dict[self.data_key] = resized[0]
        if labels is not None:
            data_dict[self.label_key] = resized[1]
        return data_dict
    def test_resize2(self):
        """Resizing to an explicit (7, 5, 6) target keeps the mean and yields that spatial shape."""
        expected = (7, 5, 6)
        resized_img, resized_seg = augment_resize(self.data_3D,
                                                  self.seg_3D,
                                                  target_size=expected)

        # Mean intensity should survive the interpolation (to 2 decimal places).
        self.assertAlmostEqual(float(np.mean(self.data_3D)),
                               float(np.mean(resized_img)),
                               places=2)

        # Spatial dimensions (everything after the channel axis) must match the target.
        self.assertTrue(all(i == j for i, j in zip(resized_img.shape[1:], expected)))
        self.assertTrue(all(i == j for i, j in zip(resized_seg.shape[1:], expected)))
 def generate_train_batch(self):
     """Assemble and return the next training batch.

     Returns a dict with keys "data" (stacked CT arrays), "vector" (stacked
     generator vectors, or None when no vector generator is configured) and
     "label" (stacked labels). Raises StopIteration once the configured number
     of epochs has been consumed.
     """
     if not self.was_initialized:
         self.initialize()
     # End of the dataset: rewind and count the finished epoch.
     if self._current_position >= len(self._data):
         self._reset()
         self._current_epoch += 1
         if 0 < self.epochs <= self._current_epoch:
             raise StopIteration
     data_batch = []
     vector_batch = None if self.vector_generator is None else []
     label_batch = []
     # Cache decoded volumes so several slices of the same CT file are read once.
     loaded_cts = {}
     for i in range(self.batch_size):
         index = self._current_position + i
         if index < len(self._data):
             label, data, segmentation_file = self._data[index]
             # In volumetric mode each entry is just the file; otherwise it is
             # a (file, slice index) pair.
             data_file = data if self.volumetric else data[0]
             # Renamed from `slice`, which shadowed the builtin.
             slice_index = None if self.volumetric else data[1]
             if data_file in loaded_cts:
                 data_np = loaded_cts[data_file]
             else:
                 data_sitk = sitk.ReadImage(data_file)
                 # Transpose the array and prepend a channel axis.
                 data_np = np.expand_dims(sitk.GetArrayFromImage(data_sitk).transpose(), axis=0)
                 if self.normalization_range is not None:
                     data_np = normalize(data_np, self.normalization_range, [0, 1])
                 loaded_cts[data_file] = data_np
             if slice_index is not None:
                 data_np = data_np[:, :, :, slice_index]
                 # NOTE(review): visualize_data on every slice looks like leftover
                 # debugging — confirm whether it should remain in the loader.
                 visualize_data(data_np)
             vector_gen_args = {"input_shape_pre": data_np.shape[-3:]}
             if self.mode == DataLoader.Mode.RESIZE:
                 # Resample the volume to the fixed network input shape.
                 data_np = augment_resize(data_np, None, self.input_shape)[0].squeeze(0)
                 data_np = np.expand_dims(data_np, 0)
                 vector_gen_args["input_shape_post"] = self.input_shape
             elif self.mode == DataLoader.Mode.SAMPLE:
                 # Draw a random crop of the network input shape instead.
                 data_np = np.expand_dims(data_np, 0)
                 with RNGContext(self.seed):
                     data_np = random_crop(data_np, crop_size=self.input_shape)[0].squeeze(0)
                 vector_gen_args["input_shape_post"] = self.input_shape
             if self.vector_generator is not None:
                 vector_batch.append(self.vector_generator(**vector_gen_args))
             data_batch.append(data_np)
             label_batch.append(label)
     batch = {"data": np.stack(data_batch),
              "vector": None if vector_batch is None else np.stack(vector_batch),
              "label": np.stack(label_batch)}
     # Advance by a full multithreaded stride so parallel loaders don't overlap.
     self._current_position += self.number_of_threads_in_multithreaded * self.batch_size
     return batch
 def __call__(self, **data_dict):
     """Random zoom-in augmentation: crop a smaller window, then resize it back up.

     A single zoom factor in [1, max_zoom) is drawn per batch; each sample is
     then independently augmented with probability ``self.p_per_sample``.
     """
     data = data_dict["data"]
     # Number of spatial dimensions (batch and channel axes excluded).
     spatial_dims = len(data.shape) - 2
     full_size = np.array(data.shape[-spatial_dims:], dtype=int)
     # One shared zoom factor for the whole batch, drawn before the per-sample
     # loop so the RNG call order matches the original implementation.
     zoom_factor = 1 + (np.random.random() * (self.max_zoom - 1))
     augmented = []
     for single in np.split(data, data.shape[0], 0):
         if np.random.random() < self.p_per_sample:
             # Crop a random window shrunk by the zoom factor ...
             cropped = crop(single,
                            crop_size=(full_size / zoom_factor).astype(int),
                            crop_type="random")[0]
             # ... and stretch it back to the full spatial size.
             single = augment_resize(cropped.squeeze(0),
                                     sample_seg=None,
                                     target_size=full_size.tolist())[0]
         else:
             single = single.squeeze(0)
         augmented.append(single)
     data_dict["data"] = np.stack(augmented)
     return data_dict
Example #9
0
 def preprocessing(self, sample, training=True):
     """Resize the sample's channel-last image (and segmentation) to ``self.new_shape``."""
     # Work on local references to the channel-last arrays.
     image = sample.img_data
     segmentation = sample.seg_data
     # Remember the pre-resize shape so postprocessing can undo the resize.
     if not training:
         sample.extended["original_shape"] = image.shape[0:-1]
     # augment_resize expects channel-first data.
     image = np.moveaxis(image, -1, 0)
     if training:
         segmentation = np.moveaxis(segmentation, -1, 0)
     # Resize with the configured interpolation orders.
     image, segmentation = augment_resize(image, segmentation, self.new_shape,
                                          order=self.order_img,
                                          order_seg=self.order_seg,
                                          cval_seg=0)
     # Restore the channel-last layout.
     image = np.moveaxis(image, 0, -1)
     if training:
         segmentation = np.moveaxis(segmentation, 0, -1)
     # Write the resized arrays back onto the sample.
     sample.img_data = image
     sample.seg_data = segmentation
Example #10
0
    def __call__(self, **data_dict):
        """Resize every sample in the batch individually.

        ``data``/``seg`` may be a stacked ndarray or a list of per-sample
        arrays. ndarray inputs are always re-stacked after resizing; list
        inputs are re-stacked only when ``self.concatenate_list`` is set,
        otherwise a list of resized samples is returned.
        """
        data = data_dict.get(self.data_key)
        seg = data_dict.get(self.label_key)

        # ndarray batches are always concatenated back; lists only on request.
        if isinstance(data, np.ndarray):
            concatenate = True
        else:
            concatenate = self.concatenate_list

        if seg is not None:
            if isinstance(seg, np.ndarray):
                concatenate_seg = True
            else:
                concatenate_seg = self.concatenate_list
        else:
            concatenate_seg = None

        # Resize sample by sample (augment_resize operates on single samples).
        results = []
        for b in range(len(data)):
            sample_seg = None
            if seg is not None:
                sample_seg = seg[b]
            res_data, res_seg = augment_resize(data[b], sample_seg,
                                               self.target_size, self.order,
                                               self.order_seg, self.cval_seg)
            results.append((res_data, res_seg))

        if concatenate:
            data = np.vstack([i[0][None] for i in results])
        else:
            # BUG FIX: previously the resized samples were silently discarded
            # in this branch and the original (un-resized) list was returned.
            data = [i[0] for i in results]

        if concatenate_seg is not None and concatenate_seg:
            seg = np.vstack([i[1][None] for i in results])
        elif concatenate_seg is not None:
            # Same fix for the segmentation list.
            seg = [i[1] for i in results]

        data_dict[self.data_key] = data
        if seg is not None:
            data_dict[self.label_key] = seg
        return data_dict