def transform(self):
    """Build the training-time torchio transform pipeline.

    Behaviour is driven by the module-level ``hp`` hyper-parameter object:

    * ``hp.mode`` -- must be ``'2d'`` or ``'3d'``. Both modes currently use
      the same pipeline (the original 2d and 3d branches were identical).
    * ``hp.aug`` -- when truthy, random augmentations (bias field, noise,
      flip, affine/elastic) follow the crop/pad and z-normalization steps.
      Otherwise only crop/pad and z-normalization run.
    * ``hp.crop_or_pad_size`` -- passed unchanged to ``CropOrPad``. It may be
      a single int (used for every spatial axis) or a per-axis tuple such as
      ``(256, 256, 1)`` for 2D data.

    Returns:
        A ``Compose`` of the selected transforms.

    Raises:
        ValueError: if ``hp.mode`` is neither ``'2d'`` nor ``'3d'``.
            ``ValueError`` subclasses ``Exception``, so callers that caught
            the previous bare ``Exception`` still catch it.
    """
    if hp.mode not in ('2d', '3d'):
        raise ValueError('no such kind of mode!')

    # Pass the configured size straight through. The old non-aug branch wrapped
    # it as ``(s, s, s)``, which is equivalent for an int but produces an
    # invalid tuple-of-tuples when ``crop_or_pad_size`` is already per-axis.
    crop_or_pad = CropOrPad(hp.crop_or_pad_size, padding_mode='reflect')

    if hp.aug:
        # ToCanonical() and RandomMotion() were disabled in the original;
        # re-insert them here if needed.
        training_transform = Compose([
            crop_or_pad,
            RandomBiasField(),
            ZNormalization(),
            RandomNoise(),
            RandomFlip(axes=(0, )),
            OneOf({
                RandomAffine(): 0.8,
                RandomElasticDeformation(): 0.2,
            }),
        ])
    else:
        training_transform = Compose([
            crop_or_pad,
            ZNormalization(),
        ])
    return training_transform
def mri_artifact(p=1):
    """Return a ``OneOf`` that applies one MRI-artifact transform.

    Motion, ghosting and spike artifacts are chosen with weights
    0.34 / 0.33 / 0.33. The whole ``OneOf`` fires with probability ``p``.
    """
    artifact_weights = {}
    artifact_weights[RandomMotion()] = 0.34
    artifact_weights[RandomGhosting()] = 0.33
    artifact_weights[RandomSpike()] = 0.33
    return OneOf(artifact_weights, p=p)
def test_reproducibility_oneof(self):
    """Replaying a OneOf pipeline from its recorded history reproduces its output.

    The pipeline is applied to the first subject. The transforms and seeds are
    then rebuilt from the result's history and re-applied to the second
    subject, and both image tensors must be equal.
    """
    first_subject, second_subject = self.get_subjects()
    pipeline = Compose([
        OneOf([RandomNoise(p=1.0), RandomSpike(num_spikes=3, p=1.0)]),
        RandomNoise(p=.5),
    ])
    original_output = pipeline(first_subject)
    recorded_history = original_output.history
    replay_transforms, replay_seeds = compose_from_history(history=recorded_history)
    replayed_output = self.apply_transforms(
        second_subject,
        trsfm_list=replay_transforms,
        seeds_list=replay_seeds,
    )
    self.assertTensorEqual(original_output.img.data, replayed_output.img.data)
def get_brats(
        data_root='/scratch/weina/dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/',
        fold=1,
        seed=None,
        **kwargs):
    """Build the BraTS training and validation iterators for one CV fold.

    Args:
        data_root: BraTS root directory. Images are read from
            ``<data_root>/all`` and the split CSVs from
            ``<data_root>/IDH_label/{train,val}_fold_<fold>.csv``.
        fold: Cross-validation fold. It selects the split CSVs and is passed
            to ``read_brats_mean``.
        seed: Only logged in this function. ``None`` (the default) resolves,
            at call time, to the ``torch.distributed`` rank when a process
            group is initialized, and to ``0`` otherwise.
        **kwargs: Accepted for call-site compatibility; unused here.

    Returns:
        ``(train, val)``: a shuffled ``BratsIter`` with random augmentations
        and an unshuffled ``BratsIter`` with resampling and normalization only.
    """
    if seed is None:
        # Resolve the rank when called. The previous default expression was
        # evaluated once at import time, typically before init_process_group,
        # so every rank saw 0.
        seed = (torch.distributed.get_rank()
                if torch.distributed.is_initialized() else 0)
    logging.debug("BratsIter:: fold = %s, seed = %s", fold, seed)

    # Spacing for Resample = raw volume extent / target grid size per axis.
    # NOTE(review): this assumes ~1 mm raw voxels and that the axis order
    # matches torchio's image axes -- confirm against the data.
    d_size, h_size, w_size = 155, 240, 240
    input_size = [7, 223, 223]
    spacing = (d_size / input_size[0],
               h_size / input_size[1],
               w_size / input_size[2])

    mean, std, _max = read_brats_mean(fold, data_root)
    normalize = transforms.Normalize(mean=mean, std=std)

    # Disabled in the original: RescaleIntensity((0, 1)) (kept RandomMotion
    # inputs non-negative), RandomMotion, HistogramStandardization,
    # ZNormalization, and CropOrPad((48, 60, 48)).
    training_transform = Compose([
        RandomBiasField(),
        RandomNoise(),
        ToCanonical(),
        Resample(spacing),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
        normalize,
    ])
    val_transform = Compose([Resample(spacing), normalize])

    brats_path = os.path.join(data_root, 'all')
    label_dir = os.path.join(data_root, 'IDH_label')
    train = BratsIter(csv_file=os.path.join(label_dir, 'train_fold_{}.csv'.format(fold)),
                      brats_path=brats_path,
                      brats_transform=training_transform,
                      shuffle=True)
    val = BratsIter(csv_file=os.path.join(label_dir, 'val_fold_{}.csv'.format(fold)),
                    brats_path=brats_path,
                    brats_transform=val_transform,
                    shuffle=False)
    return train, val
def random_augment(x):
    """Randomly augment input data.

    One of flip, elastic deformation, affine, noise or blur is picked with
    equal weight. The chosen augmentation is applied with probability 0.95.

    Returns:
        Randomly augmented input
    """
    candidates = dict.fromkeys(
        (
            RandomFlip(),
            RandomElasticDeformation(),
            RandomAffine(),
            RandomNoise(),
            RandomBlur(),
        ),
        1,
    )
    return augment(x, OneOf(candidates, p=0.95))
def training_network(landmarks, dataset, subjects):
    """Split ``subjects`` 90/10 into torchio training and validation datasets.

    Both sets share the same deterministic preprocessing: canonical
    orientation, ``Resample(4)``, reflect-padded crop/pad to (48, 60, 48),
    histogram standardization with ``landmarks`` under key ``'mri'``, and
    mean-masked z-normalization. The training set also gets random motion,
    bias field, noise, flip, and an affine/elastic ``OneOf``.

    NOTE(review): the split index is computed from ``len(dataset)`` but
    applied to ``subjects``. This presumably assumes both have the same
    length; confirm with the caller.

    Returns:
        ``(training_set, validation_set)`` as ``tio.SubjectsDataset`` objects.
    """
    train_pipeline = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        RandomMotion(),
        HistogramStandardization({'mri': landmarks}),
        RandomBiasField(),
        ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])
    val_pipeline = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        HistogramStandardization({'mri': landmarks}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])

    train_fraction = 0.9
    cut = int(train_fraction * len(dataset))
    train_subjects, val_subjects = subjects[:cut], subjects[cut:]

    training_set = tio.SubjectsDataset(train_subjects, transform=train_pipeline)
    validation_set = tio.SubjectsDataset(val_subjects, transform=val_pipeline)

    print('Training set:', len(training_set), 'subjects')
    print('Validation set:', len(validation_set), 'subjects')
    return training_set, validation_set
def test_wrong_input_type(self):
    """OneOf must raise ValueError when given a bare integer."""
    not_a_mapping = 1
    with self.assertRaises(ValueError):
        OneOf(not_a_mapping)
def test_not_transform(self):
    """OneOf must raise ValueError when its keys are transform classes, not instances."""
    with self.assertRaises(ValueError):
        classes_as_keys = {RandomAffine: 1, RandomElasticDeformation: 2}
        OneOf(classes_as_keys)
def test_zero_probabilities(self):
    """OneOf must raise ValueError when every weight is zero."""
    with self.assertRaises(ValueError):
        all_zero = {RandomAffine(): 0, RandomElasticDeformation(): 0}
        OneOf(all_zero)
def test_negative_probabilities(self):
    """OneOf must raise ValueError when any weight is negative."""
    with self.assertRaises(ValueError):
        has_negative = {RandomAffine(): -1, RandomElasticDeformation(): 1}
        OneOf(has_negative)
input_size = [7, 223, 223] spacing = (d_size / input_size[0], h_size / input_size[1], w_size / input_size[2]) training_transform = Compose([ # RescaleIntensity((0, 1)), # so that there are no negative values for RandomMotion # RandomMotion(), # HistogramStandardization({MRI: landmarks}), RandomBiasField(), # ZNormalization(masking_method=ZNormalization.mean), RandomNoise(), ToCanonical(), Resample(spacing), # CropOrPad((48, 60, 48)), RandomFlip(axes=(0, )), OneOf({ RandomAffine(): 0.8, RandomElasticDeformation(): 0.2, }), ]) fold = 1 data_root = '../../dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/' torch.manual_seed(0) torch.cuda.manual_seed(0) logging.getLogger().setLevel(logging.DEBUG) logging.info("Testing BratsIter without transformer [not torch wapper]") # has memory (no metter random or not, it will trival all none overlapped clips) train_dataset = BratsIter(csv_file=os.path.join(
def test_one_of(self):
    """Smoke test: a weighted affine/elastic OneOf runs on the sample subject."""
    weights = {
        RandomAffine(): 0.2,
        RandomElasticDeformation(max_displacement=0.5): 0.8,
    }
    one_of = OneOf(weights)
    one_of(self.sample_subject)
label=torchio.Image(tensor=torch.from_numpy(train_seg), label=torchio.LABEL), ) valid_subject = torchio.Subject( data=torchio.Image(tensor=torch.from_numpy(valid_data), label=torchio.INTENSITY), label=torchio.Image(tensor=torch.from_numpy(valid_seg), label=torchio.LABEL), ) # Define the transforms for the set of training patches training_transform = Compose([ RandomNoise(p=0.2), RandomFlip(axes=(0, 1, 2)), RandomBlur(p=0.2), OneOf({ RandomAffine(): 0.8, RandomElasticDeformation(): 0.2, }, p=0.5), # Changed from p=0.75 24/6/20 ]) # Create the datasets training_dataset = torchio.ImagesDataset( [train_subject], transform=training_transform) validation_dataset = torchio.ImagesDataset( [valid_subject]) # Define the queue of sampled patches for training and validation sampler = torchio.data.UniformSampler(PATCH_SIZE) patches_training_set = torchio.Queue( subjects_dataset=training_dataset, max_length=MAX_QUEUE_LENGTH, samples_per_volume=TRAIN_PATCHES, sampler=sampler,
def mri_artifact(parameters):
    """Return a ``OneOf`` choosing ghosting or spike artifacts with equal weight.

    The ``OneOf`` fires with probability ``parameters["probability"]``.
    """
    choices = {RandomGhosting(): 0.5, RandomSpike(): 0.5}
    probability = parameters["probability"]
    return OneOf(choices, p=probability)
def __init__(
    self,
    generator,
    zoom_range=None,
    filter_range=None,
    flip=None,
    transpose=False,
    noise=None,
    normalization=False,
    minmax=False,
    affine=None,
    window_width=None,
    window_level=None,
    window_vmax=1.0,
    window_vmin=0.0,
    intensity_shift=0.0,
    rotate=False,
    rotate3d=False,
    **kwargs,
):
    """Wrap ``generator`` and assemble the ordered list of augmentations.

    Each enabled option appends one callable to ``self.methods``. Each callable
    takes the ``data`` dict (it reads ``'image'`` and often ``'label'``),
    modifies it and returns it. They are appended in this order:
    normalization, minmax, affine, float cast (always), window, noise, zoom,
    flip, transpose, intensity shift, rotate, rotate3d.

    Args:
        generator: Wrapped data source. Its ``shapes`` attribute is copied to
            ``self.shapes``.
        zoom_range: If not None, passed to ``random_factor`` to draw a zoom
            factor applied to every array in ``data``.
        filter_range: Unused (the Gaussian-filter augmentation is deprecated).
        flip: Sequence of three per-axis flip probabilities in [0, 1].
        transpose: If true, swap axes 0 and 1 of every array with
            probability 0.5.
        noise: If not None, a positive float. Additive Gaussian noise with
            this standard deviation is applied to the image.
        normalization: If true, z-score the image (std 0 is treated as 1).
        minmax: If true, clip the foreground to its 0.2-99.8 percentiles,
            rescale to [0, 1], and zero the background.
        affine: If not None, apply a torchio affine/elastic ``OneOf``
            (p=0.75). ``'strong'`` selects larger affine parameters.
        window_width, window_level: If either is set, apply ``window``. A
            missing value defaults to 100 / 50. A tuple or list is sampled
            via ``random_factor``.
        window_vmax, window_vmin: Output range passed to ``window``.
        intensity_shift: If > 0, add a uniform offset in
            [-intensity_shift, intensity_shift) to the image.
        rotate: If true, rotate image and label 0-3 quarter turns in the
            (0, 1) plane.
        rotate3d: If true, rotate image and label 0-3 quarter turns in a
            random axis plane. The image must be a cube.
        **kwargs: Forwarded to the base class ``__init__``.
    """
    super().__init__(**kwargs)
    self.generator = generator
    self.shapes = generator.shapes

    # augmenting methods, applied in list order
    self.methods = []

    if normalization:

        def _norm(data):
            mean = np.mean(data['image'])
            std = np.std(data['image'])
            std = 1.0 if std == 0. else std
            data['image'] = (data['image'] - mean) / std
            return data

        self.methods.append(_norm)

    if minmax:

        def _minmax(data):
            # Plain floats: the original ``0.2,`` was a 1-tuple, which made
            # np.percentile return a shape-(1,) array instead of a scalar.
            lower_percentile = 0.2
            upper_percentile = 99.8
            # Voxels equal to the value at index (0, 0, ...) count as background.
            foreground = data['image'] != data['image'][
                (0, ) * len(data['image'].shape)]
            min_val = np.percentile(data['image'][foreground].ravel(),
                                    lower_percentile)
            max_val = np.percentile(data['image'][foreground].ravel(),
                                    upper_percentile)
            data['image'][data['image'] > max_val] = max_val
            data['image'][data['image'] < min_val] = min_val
            # Mirror _norm's zero-std guard so a constant foreground does not
            # divide by zero.
            value_range = max_val - min_val
            value_range = 1.0 if value_range == 0. else value_range
            data['image'] = (data['image'] - min_val) / value_range
            data['image'][~foreground] = 0
            return data

        self.methods.append(_minmax)

    # affine
    if affine is not None:
        import torchio
        from torchio.transforms import (
            RandomAffine,
            RandomElasticDeformation,
            OneOf,
        )
        if affine == 'strong':
            transform = OneOf(
                {
                    RandomAffine(translation=10,
                                 degrees=10,
                                 scales=(0.9, 1.1),
                                 default_pad_value='otsu',
                                 image_interpolation='bspline'): 0.5,
                    RandomElasticDeformation(): 0.5
                },
                p=0.75,
            )
        else:
            transform = OneOf(
                {
                    RandomAffine(translation=10): 0.5,
                    RandomElasticDeformation(): 0.5
                },
                p=0.75,
            )

        def _affine(data):
            for key in data:
                data[key] = torch.Tensor(data[key])
            subjs = {
                'label': torchio.Image(tensor=data['label'], type=torchio.LABEL)
            }
            shape = data['image'].shape
            # We need to separate out the case of a 4D image: the last axis is
            # split into one torchio image per channel.
            if len(shape) == 4:
                n_channels = shape[-1]
                for i in range(n_channels):
                    subjs.update({
                        f'ch{i}': torchio.Image(tensor=data['image'][..., i],
                                                type=torchio.INTENSITY)
                    })
            else:
                assert len(shape) == 3
                subjs.update({
                    'image': torchio.Image(tensor=data['image'],
                                           type=torchio.INTENSITY)
                })
            transformed = transform(torchio.Subject(**subjs))
            if 'image' in subjs:
                data['image'] = transformed.image.numpy()
            else:
                # if image contains multiple channels,
                # then aggregate the transformed results into one
                data['image'] = np.stack(tuple(
                    getattr(transformed, ch).numpy()
                    for ch in subjs.keys() if 'ch' in ch), axis=-1)
            data['label'] = transformed.label.numpy()
            for key in data:
                data[key] = data[key].squeeze()
            return data

        self.methods.append(_affine)

    # convert image to float
    def _to_float(data):
        # np.float was an alias of the builtin float (float64). It was removed
        # in NumPy 1.24, so use the explicit dtype.
        data['image'] = data['image'].astype(np.float64)
        return data

    self.methods.append(_to_float)

    # adjust contrast/window
    if window_width or window_level:
        from ..preprocessings import window
        window_width = window_width if window_width else 100
        window_level = window_level if window_level else 50

        def _window(data):
            # random_factor is defined outside this view; presumably it samples
            # a value from the given range -- confirm.
            if isinstance(window_width, (tuple, list)):
                _window_width = random_factor(window_width)
            else:
                _window_width = window_width
            if isinstance(window_level, (tuple, list)):
                _window_level = random_factor(window_level)
            else:
                _window_level = window_level
            data['image'] = window(
                data['image'],
                width=_window_width,
                level=_window_level,
                vmin=window_vmin,
                vmax=window_vmax,
            )
            return data

        self.methods.append(_window)

    if noise is not None:
        assert isinstance(noise, float)
        assert noise > 0.

        def _noise(data):
            data['image'] += np.random.normal(loc=0.0,
                                              scale=noise,
                                              size=data['image'].shape)
            return data

        self.methods.append(_noise)

    if zoom_range is not None:
        from ..preprocessings import zoom

        def _zoom(data):
            zoom_factor = random_factor(zoom_range)
            for key in data:
                data[key] = zoom(data[key], zoom_factor)
            return data

        self.methods.append(_zoom)

    # filter_range: the Gaussian-filter augmentation (scipy.ndimage) is
    # deprecated and intentionally not registered.

    if flip is not None:
        assert isinstance(flip, (list, tuple))
        for f in flip:
            assert f >= 0 and f <= 1

        def flip_img(img, flip_x, flip_y, flip_z):
            if flip_x:
                img = img[::-1, :, :, ...]
            if flip_y:
                img = img[:, ::-1, :, ...]
            if flip_z:
                img = img[:, :, ::-1, ...]
            return img

        def _flip(data):
            to_flip_x = random.random() < flip[0]
            to_flip_y = random.random() < flip[1]
            to_flip_z = random.random() < flip[2]
            for key in data:
                data[key] = flip_img(data[key],
                                     flip_x=to_flip_x,
                                     flip_y=to_flip_y,
                                     flip_z=to_flip_z)
            return data

        self.methods.append(_flip)

    if transpose:

        def _transpose(data):
            if random.random() > 0.5:
                for key in data:
                    data[key] = np.moveaxis(data[key], 0, 1)
            return data

        self.methods.append(_transpose)

    if intensity_shift > 0.:

        def _shift(data):
            data['image'] += (np.random.rand() * 2.0 - 1.0) * intensity_shift
            return data

        self.methods.append(_shift)

    if rotate:

        def _rotate(data):
            # randomly rotate 0~3 times about the z-axis by 90 degrees
            times = np.random.randint(0, 4)
            if times > 0:
                for key in ['image', 'label']:
                    data[key] = np.rot90(data[key], times, (0, 1))
            return data

        self.methods.append(_rotate)

    if rotate3d:

        def _rotate3d(data):
            # check isotropic
            assert data['image'].shape[0] == data['image'].shape[1]
            assert data['image'].shape[0] == data['image'].shape[2]

            # randomly select the plane spanned by the axes: (0, 1), (1, 2), (0, 2)
            the_axes = [0, 1, 2]
            the_axes.remove(np.random.randint(0, 3))
            the_axes = tuple(the_axes)

            # randomly rotate 0~3 times about the axis by 90 degrees
            times = np.random.randint(0, 4)
            if times > 0:
                for key in ['image', 'label']:
                    data[key] = np.rot90(data[key], times, the_axes)
            return data

        self.methods.append(_rotate3d)