def preprocessing(self, sample, training=True):
    """Resample a sample's image (and segmentation) to ``self.new_spacing``.

    The current voxel spacing is read from ``sample.details["spacing"]``. The
    spatial shape (all image axes except the last) is scaled by
    ``current_spacing / self.new_spacing`` and floored. Image and segmentation
    are then resized with ``augment_resize`` (order 3 for the image; order 1
    and fill value 0 for the segmentation).

    Args:
        sample: object exposing ``img_data``, ``seg_data`` and ``details``.
            ``img_data`` is handled as channel-last: the last axis is moved to
            the front for resizing and moved back afterwards.
        training: if True, ``seg_data`` is moved between channel-last and
            channel-first like the image. If False, ``seg_data`` is passed to
            ``augment_resize`` unchanged (presumably None at inference --
            confirm) and ``(1,) + spatial_shape`` is cached in
            ``self.original_shape`` for later postprocessing.

    Raises:
        ValueError: if ``sample.details`` is missing/None or has no
            ``"spacing"`` entry.
    """
    # Access data
    img_data = sample.img_data
    seg_data = sample.seg_data
    # Identify current spacing. A missing key raises KeyError, and a missing
    # or None ``details`` raises AttributeError/TypeError. Previously only
    # AttributeError was caught, and it was merely printed, so execution went
    # on and failed later with a NameError on ``current_spacing``.
    try:
        current_spacing = sample.details["spacing"]
    except (AttributeError, KeyError, TypeError) as err:
        raise ValueError("'spacing' is not initialized in sample details!") from err
    # Cache the original shape for later postprocessing
    if not training:
        self.original_shape = (1, ) + img_data.shape[0:-1]
    # Calculate spacing ratio
    ratio = current_spacing / np.array(self.new_spacing)
    # Calculate new shape
    new_shape = tuple(np.floor(img_data.shape[0:-1] * ratio).astype(int))
    # Transform data from channel-last to channel-first structure
    img_data = np.moveaxis(img_data, -1, 0)
    if training:
        seg_data = np.moveaxis(seg_data, -1, 0)
    # Resample imaging data
    img_data, seg_data = augment_resize(img_data, seg_data, new_shape,
                                        order=3, order_seg=1, cval_seg=0)
    # Transform data from channel-first back to channel-last structure
    img_data = np.moveaxis(img_data, 0, -1)
    if training:
        seg_data = np.moveaxis(seg_data, 0, -1)
    # Save resampled imaging data to sample
    sample.img_data = img_data
    sample.seg_data = seg_data
def resize_augment(data, seg=None, target_size=(500, 500)):
    '''
    Size-transformation augmentation: resize ``data`` (and optionally ``seg``).

    :param data: sample passed to ``augment_resize`` as its first argument
        (presumably channel-first, e.g. (c, x, y) -- confirm against callers).
    :param seg: optional segmentation passed as the second argument; None
        means no segmentation is resized.
    :param target_size: spatial output size, (500, 500) by default.
    :return: tuple ``(data_result, seg_result)`` as returned by
        ``augment_resize`` (``seg_result`` presumably None when ``seg`` is
        None -- confirm).
    '''
    # augment_resize's second positional parameter is the segmentation. The
    # target size previously landed there and target_size itself was missing.
    data_result, seg_result = augment_resize(data, seg, target_size=target_size)
    return data_result, seg_result
def test_resize(self):
    """augment_resize with a scalar target_size resizes every spatial axis.

    Checks that the data mean is preserved to 2 decimal places, that every
    spatial axis (all axes after the first) of data and seg becomes 2, and
    that the leading channel axis is not resized.
    """
    data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=2)
    mean_resized = float(np.mean(data_resized))
    mean_original = float(np.mean(self.data_3D))
    self.assertAlmostEqual(mean_original, mean_resized, places=2)
    # Axis 0 is the channel axis and is not resized. The old check compared
    # it against the target size too, which only passed when the fixture
    # happened to have exactly 2 channels.
    self.assertEqual(data_resized.shape[1:], (2,) * (len(data_resized.shape) - 1))
    self.assertEqual(seg_resized.shape[1:], (2,) * (len(seg_resized.shape) - 1))
    self.assertEqual(data_resized.shape[0], np.shape(self.data_3D)[0])
    self.assertEqual(seg_resized.shape[0], np.shape(self.seg_3D)[0])
def __call__(self, **data_dict):
    """Resize the data (and segmentation, if present) to ``self.target_size``.

    Only ``target_size`` and ``order`` are forwarded to ``augment_resize``;
    no other resize option is passed.

    :return: ``data_dict`` with the resized data stored under
        ``self.data_key`` and, when a segmentation was supplied, the resized
        segmentation under ``self.label_key``.
    """
    data, seg = data_dict.get(self.data_key), data_dict.get(self.label_key)
    outputs = augment_resize(data=data, seg=seg, target_size=self.target_size, order=self.order)
    data_dict[self.data_key] = outputs[0]
    # Leave the label entry untouched when there was no segmentation.
    if seg is not None:
        data_dict[self.label_key] = outputs[1]
    return data_dict
def __call__(self, **data_dict):
    """Resize the entries under ``self.data_key`` / ``self.label_key``.

    Every resize option configured on the transform (target size,
    interpolation orders, segmentation fill value, list concatenation) is
    forwarded to ``augment_resize``.

    :return: ``data_dict`` with the resized data and, only when a
        segmentation was supplied, the resized segmentation.
    """
    data = data_dict.get(self.data_key)
    seg = data_dict.get(self.label_key)
    resize_kwargs = dict(
        data=data,
        seg=seg,
        target_size=self.target_size,
        order=self.order,
        order_seg=self.order_seg,
        cval_seg=self.cval_seg,
        concatenate_list=self.concatenate_list,
    )
    resized = augment_resize(**resize_kwargs)
    data_dict[self.data_key] = resized[0]
    if seg is not None:
        data_dict[self.label_key] = resized[1]
    return data_dict
def test_resize2(self):
    """augment_resize with a per-axis target_size yields exactly that shape.

    Checks that the data mean is preserved to 2 decimal places and that the
    spatial axes (all axes after the first) of data and seg equal (7, 5, 6).
    """
    target = (7, 5, 6)
    data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=target)
    mean_resized = float(np.mean(data_resized))
    mean_original = float(np.mean(self.data_3D))
    self.assertAlmostEqual(mean_original, mean_resized, places=2)
    # Compare whole tuples: the previous zip()-based check stopped at the
    # shorter sequence, so an output with too few spatial axes still passed.
    self.assertEqual(tuple(data_resized.shape[1:]), target)
    self.assertEqual(tuple(seg_resized.shape[1:]), target)
def generate_train_batch(self):
    """Assemble the next batch of CT volumes (or slices) with their labels.

    Initializes the loader on first use. When ``self._current_position`` has
    passed the end of ``self._data``, the loader is reset and the epoch
    counter advanced; StopIteration is raised once ``self.epochs`` (if > 0)
    epochs have been produced. Then up to ``self.batch_size`` entries starting
    at ``self._current_position`` are loaded. The batch is smaller when the
    data runs out.

    Returns:
        dict with ``"data"`` (stacked volumes), ``"vector"`` (stacked outputs
        of ``self.vector_generator``, or None when there is no generator) and
        ``"label"`` (stacked labels).

    Raises:
        StopIteration: when the configured number of epochs is reached.
    """
    if not self.was_initialized:
        self.initialize()
    if self._current_position >= len(self._data):
        self._reset()
        self._current_epoch += 1
        if 0 < self.epochs <= self._current_epoch:
            raise StopIteration

    data_batch = []
    vector_batch = None if self.vector_generator is None else []
    label_batch = []
    # Per-call cache: a volume referenced by several (slice) entries of this
    # batch is read and normalized only once.
    loaded_cts = {}

    for i in range(self.batch_size):
        index = self._current_position + i
        if index >= len(self._data):
            continue  # past the end of the data: the batch ends up smaller
        label, data, _ = self._data[index]  # third field (segmentation file) is unused here
        data_file = data if self.volumetric else data[0]
        # Renamed from ``slice`` so the builtin is not shadowed.
        slice_index = None if self.volumetric else data[1]

        if data_file in loaded_cts:
            data_np = loaded_cts[data_file]
        else:
            data_sitk = sitk.ReadImage(data_file)
            # transpose() reverses the axis order of the SimpleITK array; a
            # leading axis of size 1 is then added.
            data_np = np.expand_dims(sitk.GetArrayFromImage(data_sitk).transpose(), axis=0)
            if self.normalization_range is not None:
                data_np = normalize(data_np, self.normalization_range, [0, 1])
            loaded_cts[data_file] = data_np

        if slice_index is not None:
            data_np = data_np[:, :, :, slice_index]

        # NOTE(review): this runs for every sample of every batch and looks
        # like leftover debugging -- confirm it is intended.
        visualize_data(data_np)

        vector_gen_args = {"input_shape_pre": data_np.shape[-3:]}
        if self.mode == DataLoader.Mode.RESIZE:
            data_np = augment_resize(data_np, None, self.input_shape)[0].squeeze(0)
            data_np = np.expand_dims(data_np, 0)
            vector_gen_args["input_shape_post"] = self.input_shape
        elif self.mode == DataLoader.Mode.SAMPLE:
            # Add a leading axis before cropping (presumably the batch axis
            # random_crop expects -- confirm).
            data_np = np.expand_dims(data_np, 0)
            # NOTE(review): the RNG is re-seeded with the same self.seed for
            # every sample -- confirm identical crop positions are intended.
            with RNGContext(self.seed):
                data_np = random_crop(data_np, crop_size=self.input_shape)[0].squeeze(0)
            vector_gen_args["input_shape_post"] = self.input_shape

        if self.vector_generator is not None:
            vector_batch.append(self.vector_generator(**vector_gen_args))
        data_batch.append(data_np)
        label_batch.append(label)

    batch = {
        "data": np.stack(data_batch),
        "vector": None if vector_batch is None else np.stack(vector_batch),
        "label": np.stack(label_batch),
    }
    # Skip ahead by one batch per worker thread (presumably so that parallel
    # workers interleave their batches -- confirm with _reset()).
    self._current_position += self.number_of_threads_in_multithreaded * self.batch_size
    return batch
def __call__(self, **data_dict):
    """Randomly zoom into samples of ``data_dict["data"]``.

    One zoom factor ``1 + u * (self.max_zoom - 1)`` (u uniform in [0, 1)) is
    drawn per call and shared by all samples. Each sample is, with
    probability ``self.p_per_sample``, randomly cropped to
    ``spatial_size / zoom`` (truncated to int) and resized back to the
    original spatial size. The remaining samples pass through unchanged.

    :return: ``data_dict`` with ``"data"`` replaced by the re-stacked batch.
    """
    batch = data_dict["data"]
    n_spatial = len(batch.shape) - 2
    spatial_size = np.array(batch.shape[-n_spatial:], dtype=int)
    zoom_factor = 1 + (np.random.random() * (self.max_zoom - 1))

    augmented = []
    # np.split keeps a leading axis of length 1 on every piece.
    for piece in np.split(batch, batch.shape[0], 0):
        if np.random.random() < self.p_per_sample:
            crop_shape = (spatial_size / zoom_factor).astype(int)
            cropped = crop(piece, crop_size=crop_shape, crop_type="random")[0]
            zoomed = augment_resize(cropped.squeeze(0), sample_seg=None,
                                    target_size=spatial_size.tolist())[0]
            augmented.append(zoomed)
        else:
            augmented.append(piece.squeeze(0))

    data_dict["data"] = np.stack(augmented)
    return data_dict
def preprocessing(self, sample, training=True):
    """Resize a sample's image (and segmentation) to ``self.new_shape``.

    The image is handled as channel-last: its last axis is moved to the front,
    the array is resized with ``augment_resize`` (image order
    ``self.order_img``, segmentation order ``self.order_seg``, segmentation
    fill value 0), and the axis is moved back.

    Args:
        sample: object exposing ``img_data``, ``seg_data`` and ``extended``.
        training: if True, ``seg_data`` gets the same axis moves as the image.
            If False, ``seg_data`` is passed through to ``augment_resize``
            as-is, and the pre-resize spatial shape is stored in
            ``sample.extended["original_shape"]`` for later postprocessing.
    """
    image = sample.img_data
    segmentation = sample.seg_data

    if not training:
        sample.extended["original_shape"] = image.shape[0:-1]

    # channel-last -> channel-first
    image = np.moveaxis(image, -1, 0)
    if training:
        segmentation = np.moveaxis(segmentation, -1, 0)

    image, segmentation = augment_resize(
        image,
        segmentation,
        self.new_shape,
        order=self.order_img,
        order_seg=self.order_seg,
        cval_seg=0,
    )

    # channel-first -> channel-last
    image = np.moveaxis(image, 0, -1)
    if training:
        segmentation = np.moveaxis(segmentation, 0, -1)

    sample.img_data = image
    sample.seg_data = segmentation
def __call__(self, **data_dict):
    """Resize every sample of the batch to ``self.target_size``.

    ``data`` (and ``seg``, if present) are resized one sample at a time with
    ``augment_resize(sample, sample_seg, self.target_size, self.order,
    self.order_seg, self.cval_seg)``. numpy array inputs are always re-stacked
    into one array. Other inputs (e.g. lists) are stacked only when
    ``self.concatenate_list`` is truthy; otherwise they are returned as a list
    of the resized samples.

    :return: ``data_dict`` with the resized data, and the resized
        segmentation when one was supplied.
    """
    data = data_dict.get(self.data_key)
    seg = data_dict.get(self.label_key)

    concatenate = isinstance(data, np.ndarray) or self.concatenate_list
    if seg is None:
        concatenate_seg = None
    else:
        concatenate_seg = isinstance(seg, np.ndarray) or self.concatenate_list

    resized_data = []
    resized_seg = []
    for b, sample_data in enumerate(data):
        sample_seg = None if seg is None else seg[b]
        res_data, res_seg = augment_resize(sample_data, sample_seg, self.target_size,
                                           self.order, self.order_seg, self.cval_seg)
        resized_data.append(res_data)
        resized_seg.append(res_seg)

    # Previously the non-concatenated case fell through and returned the
    # ORIGINAL, un-resized samples, discarding every result computed above.
    if concatenate:
        data = np.vstack([d[None] for d in resized_data])
    else:
        data = resized_data
    if seg is not None:
        if concatenate_seg:
            seg = np.vstack([s[None] for s in resized_seg])
        else:
            seg = resized_seg

    data_dict[self.data_key] = data
    if seg is not None:
        data_dict[self.label_key] = seg
    return data_dict