    def transform(self):

        if hp.mode == '3d':
            if hp.aug:
                training_transform = Compose([
                    # ToCanonical(),
                    CropOrPad(hp.crop_or_pad_size, padding_mode='reflect'),
                    # RandomMotion(),
                    RandomBiasField(),
                    ZNormalization(),
                    RandomNoise(),
                    RandomFlip(axes=(0, )),
                    OneOf({
                        RandomAffine(): 0.8,
                        RandomElasticDeformation(): 0.2,
                    }),
                ])
            else:
                training_transform = Compose([
                    CropOrPad((hp.crop_or_pad_size, hp.crop_or_pad_size,
                               hp.crop_or_pad_size),
                              padding_mode='reflect'),
                    ZNormalization(),
                ])
        elif hp.mode == '2d':
            if hp.aug:
                training_transform = Compose([
                    CropOrPad(hp.crop_or_pad_size, padding_mode='reflect'),
                    # RandomMotion(),
                    RandomBiasField(),
                    ZNormalization(),
                    RandomNoise(),
                    RandomFlip(axes=(0, )),
                    OneOf({
                        RandomAffine(): 0.8,
                        RandomElasticDeformation(): 0.2,
                    }),
                ])
            else:
                training_transform = Compose([
                    CropOrPad((hp.crop_or_pad_size, hp.crop_or_pad_size,
                               hp.crop_or_pad_size),
                              padding_mode='reflect'),
                    ZNormalization(),
                ])

        else:
            raise ValueError(f"Unknown mode: {hp.mode!r}; expected '2d' or '3d'")

        return training_transform
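For context, a minimal usage sketch (assuming a recent torchio version; the file path is hypothetical and training_transform is the Compose returned above): the transform is typically attached to a SubjectsDataset and applied lazily when a subject is indexed.

import torchio as tio

subject = tio.Subject(mri=tio.ScalarImage('image.nii.gz'))  # hypothetical file
dataset = tio.SubjectsDataset([subject], transform=training_transform)
augmented = dataset[0]  # the Compose pipeline runs here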
Example #2
    def create_transforms(self):
        transforms = []

        # clipping to remove outliers (if any)
        # clip_intensity = Lambda(VolumeDataset.clip_image, types_to_apply=[torchio.INTENSITY])
        # transforms.append(clip_intensity)

        rescale = RescaleIntensity((-1, 1), percentiles=(0.5, 99.5))
        # alternative: normalize with mu = 0 and sigma = 1/3 so that most values fall within [-1, 1]
        # ZNormalization()

        transforms.append(rescale)

        # transforms = [rescale]
        # # As RandomAffine is faster than RandomElasticDeformation, we choose to
        # # apply RandomAffine 80% of the time and RandomElasticDeformation the rest
        # # Also, there is a 25% chance that none of them will be applied
        # if self.opt.isTrain:
        #     spatial = OneOf(
        #         {RandomAffine(translation=5): 0.8, RandomElasticDeformation(): 0.2},
        #         p=0.75,
        #     )
        #     transforms += [RandomFlip(axes=(0, 2), p=0.8), spatial]

        self.ratio = self.min_size / np.max(self.input_size)
        transforms.append(Resample(self.ratio))
        transforms.append(CropOrPad(self.input_size))
        transform = Compose(transforms)
        return transform
Example #3
    def transform(self):

        training_transform = Compose([
            ZNormalization(),
        ])

        return training_transform
Example #4
def get_brats(
        data_root='/scratch/weina/dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/',
        fold=1,
        seed=torch.distributed.get_rank() if torch.distributed.is_initialized() else 0,
        **kwargs):
    """ data iter for brats
    """
    logging.debug("BratsIter:: fold = {}, seed = {}".format(fold, seed))
    # args for transforms
    d_size, h_size, w_size = 155, 240, 240
    input_size = [7, 223, 223]
    spacing = (d_size / input_size[0], h_size / input_size[1],
               w_size / input_size[2])
    Mean, Std, Max = read_brats_mean(fold, data_root)
    normalize = transforms.Normalize(mean=Mean, std=Std)
    training_transform = Compose([
        # RescaleIntensity((0, 1)),  # so that there are no negative values for RandomMotion
        # RandomMotion(),
        # HistogramStandardization({MRI: landmarks}),
        RandomBiasField(),
        # ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        ToCanonical(),
        Resample(spacing),
        # CropOrPad((48, 60, 48)),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
        normalize
    ])
    val_transform = Compose([Resample(spacing), normalize])

    train = BratsIter(csv_file=os.path.join(data_root, 'IDH_label',
                                            'train_fold_{}.csv'.format(fold)),
                      brats_path=os.path.join(data_root, 'all'),
                      brats_transform=training_transform,
                      shuffle=True)

    val = BratsIter(csv_file=os.path.join(data_root, 'IDH_label',
                                          'val_fold_{}.csv'.format(fold)),
                    brats_path=os.path.join(data_root, 'all'),
                    brats_transform=val_transform,
                    shuffle=False)
    return train, val
Example #5
    def build(self):
        SEED = 42
        data = pd.read_csv(self.data)
        ab = data.label

        ############################################
        transforms = [
            RescaleIntensity((0, 1)),
            RandomAffine(),
            transformss.ToTensor(),
        ]
        transform = Compose(transforms)
        #############################################

        dataset_dir = self.dataset_dir
        dataset_dir = Path(dataset_dir)

        images_dir = dataset_dir
        labels_dir = dataset_dir
        image_paths = sorted(images_dir.glob('**/*.nii'))
        label_paths = sorted(labels_dir.glob('**/*.nii'))
        assert len(image_paths) == len(label_paths)

        # These two names are arbitrary
        MRI = 'features'
        BRAIN = 'targets'

        #split dataset into training and validation
        from catalyst.utils import split_dataframe_train_test

        train_image_paths, valid_image_paths = split_dataframe_train_test(
            image_paths, test_size=0.2, random_state=SEED)

        #training data
        subjects = []
        i = 0
        for (image_path, label_path) in zip(train_image_paths, label_paths):
            subject_dict = {
                MRI: torchio.Image(image_path, torchio.INTENSITY),
                BRAIN: ab[i],
            }
            i = i + 1
            subject = torchio.Subject(subject_dict)
            subjects.append(subject)
        train_data = torchio.ImagesDataset(subjects)

        #validation data
        subjects = []
        for (image_path, label_path) in zip(valid_image_paths, label_paths):
            subject_dict = {
                MRI: torchio.Image(image_path, torchio.INTENSITY),
                BRAIN: ab[i],
            }
            i = i + 1
            subject = torchio.Subject(subject_dict)
            subjects.append(subject)
        test_data = torchio.ImagesDataset(subjects)
        return train_data, test_data
Example #6
 def __call__(self, img):
     transform1 = self.transform1(self.m1, self.p1)
     transform2 = self.transform2(self.m2, self.p2)
     transform = Compose([transform1, transform2])
     print("Policy Selected: ('{}', {}, {}, '{}', {}, {})".format(
         self.t1_input, self.m1_input, self.p1,
         self.t2_input, self.m2_input, self.p2))
     return transform(img)
Example #7
    def get_transformations(self, idx):
        from torchio.transforms import Compose
        import torchio.transforms

        row = self.get_row(idx)
        trsfms_order = row[
            "history"]  #[r for r in row["transfo_order"].split("_") if r != ""]
        trsfm_composition = Compose(trsfms_order)
        return trsfm_composition
Example #8
 def test_reproducibility_compose(self):
     trsfm = Compose([RandomNoise(p=0.0), RandomSpike(num_spikes=3, p=1.0)])
     subject1, subject2 = self.get_subjects()
     transformed1 = trsfm(subject1)
     history1 = transformed1.history
     trsfm_hist, seeds_hist = compose_from_history(history=history1)
     transformed2 = self.apply_transforms(subject2,
                                          trsfm_list=trsfm_hist,
                                          seeds_list=seeds_hist)
     data1, data2 = transformed1.img.data, transformed2.img.data
     self.assertTensorEqual(data1, data2)
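A simpler way to get a comparable effect, assuming a torchio version that draws its random parameters from torch's global RNG (a sketch, not the project's compose_from_history helper):

import torch

torch.manual_seed(42)
out1 = trsfm(subject1)
torch.manual_seed(42)
out2 = trsfm(subject2)  # the same random parameters are sampled for both subjects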
Example #9
def training_network(landmarks, dataset, subjects):
    training_transform = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        RandomMotion(),
        HistogramStandardization({'mri': landmarks}),
        RandomBiasField(),
        ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    validation_transform = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        HistogramStandardization({'mri': landmarks}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])

    training_split_ratio = 0.9
    num_subjects = len(dataset)
    num_training_subjects = int(training_split_ratio * num_subjects)

    training_subjects = subjects[:num_training_subjects]
    validation_subjects = subjects[num_training_subjects:]

    training_set = tio.SubjectsDataset(training_subjects,
                                       transform=training_transform)

    validation_set = tio.SubjectsDataset(validation_subjects,
                                         transform=validation_transform)

    print('Training set:', len(training_set), 'subjects')
    print('Validation set:', len(validation_set), 'subjects')
    return training_set, validation_set
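The returned SubjectsDataset objects can be wrapped in a standard PyTorch DataLoader; a minimal sketch (batch size is illustrative; the image key 'mri' matches the transforms above):

import torch
import torchio as tio

train_loader = torch.utils.data.DataLoader(training_set, batch_size=4, num_workers=2)
for batch in train_loader:
    inputs = batch['mri'][tio.DATA]  # tensor of shape (batch, channels, x, y, z)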
Example #10
 def _get_default_transforms(self):
     io_transforms = Compose([
         RandomMotion(),
         RandomFlip(axes=(1, )),
         RandomAffine(scales=(0.9, 1.2),
                      degrees=(10),
                      isotropic=False,
                      default_pad_value='otsu',
                      image_interpolation='bspline'),
         RescaleIntensity((0, 1))
     ])
     return io_transforms
Example #11
    def get_transformations(self, idx):
        from torchio.transforms import Compose
        import torchio.transforms

        row = self.get_row(idx)
        trsfms_order = [r for r in row["transfo_order"].split("_") if r != ""]
        trsfm_list = []
        trsfm_seeds = []
        for trsfm_name in trsfms_order:
            if trsfm_name not in ["OneOf"]:
                trsfm_history = default_json_str_to_eval_python(
                    row["T_" + trsfm_name])
                trsfm = getattr(torchio.transforms, trsfm_name)
                #trsfm_seed = trsfm_history["seed"] if "seed" in trsfm_history.keys() else None
                if trsfm_name == "RandomMotionFromTimeCourse":
                    trsfm_seeds.append(trsfm_history["seed"])
                    del trsfm_history["seed"]
                    init_args = inspect.getfullargspec(trsfm.__init__).args
                    print(init_args)
                    trsfm_history = {
                        hist_key: self.trsfm_arg_eval(hist_val)
                        for hist_key, hist_val in trsfm_history.items()
                        if hist_key in init_args and hist_key not in
                        ['metrics', 'fitpars', "read_func"]
                    }
                else:
                    trsfm_seeds.append(None)

                if trsfm_name == "RescaleIntensity":
                    trsfm_history[
                        "masking_method"] = None  #self.trsfm_arg_eval(str(trsfm_history["masking_method"]))
                #if "seed" in trsfm_history.keys():
                #    del trsfm_history["seed"]
                print(f"Found transform: {trsfm_name}\n{trsfm_history}")

                trsfm_history = {
                    k: v
                    for k, v in trsfm_history.items()
                    if k not in ["probability"]
                }
                trsfm = trsfm(**trsfm_history)
                #init_args = inspect.getfullargspec(trsfm.__init__).args
                """
                hist_kwargs_init = {hist_key: self.trsfm_arg_eval(hist_val)
                                    for hist_key, hist_val in trsfm_history.items()
                                    if hist_key in init_args and hist_key not in ['metrics', 'fitpars', "read_func"]}

                trsfm = trsfm(**hist_kwargs_init)
                """
                trsfm_list.append(trsfm)
        trsfm_composition = Compose(trsfm_list)
        return trsfm_composition, trsfm_seeds
Example #12
        OneOf,
        Compose,
    )

    d_size, h_size, w_size = 155, 240, 240
    input_size = [7, 223, 223]
    spacing = (d_size / input_size[0], h_size / input_size[1],
               w_size / input_size[2])
    training_transform = Compose([
        # RescaleIntensity((0, 1)),  # so that there are no negative values for RandomMotion
        # RandomMotion(),
        # HistogramStandardization({MRI: landmarks}),
        RandomBiasField(),
        # ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        ToCanonical(),
        Resample(spacing),
        # CropOrPad((48, 60, 48)),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    fold = 1
    data_root = '../../dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/'

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    logging.getLogger().setLevel(logging.DEBUG)
Example #13
fig, ax = plt.subplots(dpi=100)
plot_histogram(ax, znormed.mri.data, label='Z-normed', alpha=1)
ax.set_title('Intensity values of one sample after z-normalization')
ax.set_xlabel('Intensity')
ax.grid()

training_transform = Compose([
    ToCanonical(),
    #  Resample(4),
    CropOrPad((112, 112, 48), padding_mode=0),  #reflect , original 112,112,48
    RandomMotion(num_transforms=6, image_interpolation='nearest', p=0.2),
    HistogramStandardization({'mri': landmarks}),
    RandomBiasField(p=0.2),
    RandomBlur(p=0.2),
    ZNormalization(masking_method=ZNormalization.mean),
    RandomFlip(axes=['inferior-superior'], flip_probability=0.2),
    #  RandomNoise(std=0.5, p=0.2),
    RandomGhosting(intensity=1.8, p=0.2),
    #  RandomNoise(),
    #  RandomFlip(axes=(0,)),
    #  OneOf({
    #      RandomAffine(): 0.8,
    #      RandomElasticDeformation(): 0.2,
    #  }),
])

validation_transform = Compose([
    ToCanonical(),
    #  Resample(4),
    CropOrPad((112, 112, 48), padding_mode=0),  #original 112,112,48
    #  RandomMotion(num_transforms=6, image_interpolation='nearest', p = 0.2),
Example #14
def get_transforms_for_preprocessing(parameters, current_transformations,
                                     train_mode, apply_zero_crop):
    """
    This function gets the pre-processing transformations from the parameters.

    Args:
        parameters (dict): The parameters dictionary.
        current_transformations (list): The current transformations list.
        train_mode (bool): Whether the data is in train mode or not.
        apply_zero_crop (bool): Whether to apply zero crop or not.

    Returns:
        Compose or None: The composed pre-processing transformations, or None if none are defined.
    """

    preprocessing_params_dict = parameters["data_preprocessing"]
    # first, we want to do thresholding, followed by clipping, if it is present - required for inference as well
    normalize_to_apply = None
    if preprocessing_params_dict is not None:
        # go through preprocessing in the order they are specified
        for preprocess in preprocessing_params_dict:
            preprocess_lower = preprocess.lower()
            # special check for resize and resample
            if preprocess_lower == "resize_patch":
                resize_values = generic_3d_check(
                    preprocessing_params_dict[preprocess])
                current_transformations.append(Resize(resize_values))
            elif preprocess_lower == "resample":
                if "resolution" in preprocessing_params_dict[preprocess]:
                    # Need to take a look here
                    resample_values = generic_3d_check(
                        preprocessing_params_dict[preprocess]["resolution"])
                    current_transformations.append(Resample(resample_values))
            elif preprocess_lower in ["resample_minimum", "resample_min"]:
                if "resolution" in preprocessing_params_dict[preprocess]:
                    resample_values = generic_3d_check(
                        preprocessing_params_dict[preprocess]["resolution"])
                    current_transformations.append(
                        Resample_Minimum(resample_values))
            # special check for histogram_matching
            elif preprocess_lower == "histogram_matching":
                if preprocessing_params_dict[preprocess] is not False:
                    current_transformations.append(
                        global_preprocessing_dict[preprocess_lower](
                            preprocessing_params_dict[preprocess]))
            # special check for stain_normalizer
            elif preprocess_lower == "stain_normalizer":
                if normalize_to_apply is None:
                    normalize_to_apply = global_preprocessing_dict[
                        preprocess_lower](
                            preprocessing_params_dict[preprocess])
            # normalize should be applied at the end
            elif "normalize" in preprocess_lower:
                if normalize_to_apply is None:
                    normalize_to_apply = global_preprocessing_dict[
                        preprocess_lower]
            # preprocessing routines that we only want for training
            elif preprocess_lower in ["crop_external_zero_planes"]:
                if train_mode or apply_zero_crop:
                    current_transformations.append(
                        global_preprocessing_dict["crop_external_zero_planes"](
                            patch_size=parameters["patch_size"]))
            # everything else is taken in the order passed by user
            elif preprocess_lower in global_preprocessing_dict:
                current_transformations.append(
                    global_preprocessing_dict[preprocess_lower](
                        preprocessing_params_dict[preprocess]))

    # normalization type is applied at the end
    if normalize_to_apply is not None:
        current_transformations.append(normalize_to_apply)

    # compose the transformations
    transforms_to_apply = None
    if current_transformations:
        transforms_to_apply = Compose(current_transformations)

    return transforms_to_apply
Example #15
    RescaleIntensity,
    Resample,
    ToCanonical,
    ZNormalization,
    CropOrPad,
    HistogramStandardization,
    OneOf,
    Compose,
)

landmarks = np.load('landmarks.npy')

transform = Compose([
    RescaleIntensity((0, 1)),
    HistogramStandardization({'mri': landmarks}),
    ZNormalization(masking_method=ZNormalization.mean),
    ToCanonical(),
    Resample((1, 1, 1)),
    CropOrPad((224, 224, 224)),
])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def create_paths(datapath):
    """Create paths to all nested images under datapath."""
    imagepaths = []
    for root, dirs, files in os.walk(datapath, topdown=False):
        for name in files:
            imagepaths.append(os.path.join(root, name))
    return imagepaths
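A sketch of how these pieces might be combined (the directory path is hypothetical; the image key must be 'mri' to match the HistogramStandardization mapping above):

import torchio as tio

nii_paths = [p for p in create_paths('/data/images') if p.endswith('.nii.gz')]
subjects = [tio.Subject(mri=tio.ScalarImage(p)) for p in nii_paths]
dataset = tio.SubjectsDataset(subjects, transform=transform)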
Example #16
def compose_transforms() -> Compose:
    print(f"{ctime()}:  Setting up transformations...")
    """
    # Our Preprocessing Options available in TorchIO are:

    * Intensity
        - NormalizationTransform
        - RescaleIntensity
        - ZNormalization
        - HistogramStandardization
    * Spatial
        - CropOrPad
        - Crop
        - Pad
        - Resample
        - ToCanonical

    We should read and experiment with these, but for now will just use a bunch with
    the default values.

    """

    preprocessors = [
        ToCanonical(p=1),
        ZNormalization(masking_method=None,
                       p=1),  # alternately, use RescaleIntensity
    ]
    """
    # Our Augmentation Options available in TorchIO are:

    * Spatial
        - RandomFlip
        - RandomAffine
        - RandomElasticDeformation

    * Intensity
        - RandomMotion
        - RandomGhosting
        - RandomSpike
        - RandomBiasField
        - RandomBlur
        - RandomNoise
        - RandomSwap



    We should read and experiment with these, but for now will just use a bunch with
    the default values.

    """
    augments = [
        RandomFlip(axes=(0, 1, 2), flip_probability=0.5),
        RandomAffine(image_interpolation="linear",
                     p=0.8),  # default, compromise on speed + quality
        # this will be most processing intensive, leave out for now, see results
        # RandomElasticDeformation(p=1),
        RandomMotion(),
        RandomSpike(),
        RandomBiasField(),
        RandomBlur(),
        RandomNoise(),
    ]
    transform = Compose(preprocessors + augments)
    print(f"{ctime()}:  Transformations registered.")
    return transform
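A usage sketch for the returned Compose (subject is assumed to be a torchio.Subject built elsewhere):

transform = compose_transforms()
augmented = transform(subject)  # preprocessing then augmentation, in the order composed above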
Example #17
training_batch_size = 12
validation_batch_size = 6
patch_size = 32
samples_per_volume = 20
max_queue_length = 80

training_name = "denseNet3D_torchIO_patch_{}_samples_{}_ADAMOptim_{}Epochs_BS{}_GlorotWeights_SSIM_1511".format(
    patch_size, samples_per_volume, Epochs, training_batch_size)
train_writer = SummaryWriter(
    os.path.join("runs", "Densenets", training_name + "_training"))
validation_writer = SummaryWriter(
    os.path.join("runs", "Densenets", training_name + "_validation"))

training_subjects, test_subjects, validation_subjects = train_test_val_split()

training_transform = Compose([RescaleIntensity((0, 1)), RandomNoise(p=0.05)])
validation_transform = Compose([RescaleIntensity((0, 1))])
test_transform = Compose([RescaleIntensity((0, 1))])

training_dataset = tio.SubjectsDataset(training_subjects,
                                       transform=training_transform)
validation_dataset = tio.SubjectsDataset(validation_subjects,
                                         transform=validation_transform)
test_dataset = tio.SubjectsDataset(test_subjects, transform=test_transform)
'''Patching'''

patches_training_set = tio.Queue(
    subjects_dataset=training_dataset,
    max_length=max_queue_length,
    samples_per_volume=samples_per_volume,
    sampler=tio.sampler.UniformSampler(patch_size),
Example #18
# Each element of subjects_list is an instance of torchio.Subject:
# subject = Subject(
#     one_image=torchio.Image(path_to_one_image, torchio.INTENSITY),
#     another_image=torchio.Image(path_to_another_image, torchio.INTENSITY),
#     a_label=torchio.Image(path_to_a_label, torchio.LABEL),
# )

# Define transforms for data normalization and augmentation
transforms = (
    ZNormalization(),
    RandomNoise(std=(0, 0.25)),
    RandomAffine(scales=(0.9, 1.1), degrees=10),
    RandomFlip(axes=(0,)),
)
transform = Compose(transforms)
subjects_dataset = ImagesDataset(subjects_list, transform)


# Run a benchmark for different numbers of workers
workers = range(mp.cpu_count() + 1)
for num_workers in workers:
    print('Number of workers:', num_workers)


    # Define the dataset as a queue of patches
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
Example #19
state_dict = torch.load(
    "Models/DenseNet_3x3Conv_no Scale Aug/denseNet3D_torchIO_patch_32_samples_20_ADAMOptim_50Epochs_BS6_GlorotWeights_SSIM_3X3.pth"
)
model = DenseNetModel.DenseNet(num_init_features=4,
                               growth_rate=6,
                               block_config=(6, 6, 6)).to("cuda")
model.load_state_dict(state_dict["model_state_dict"])
ground_truths = Path("IXI-T1/Actual_Images")
ground_paths = sorted(ground_truths.glob('*.nii.gz'))
compressed_dirs = [
    sorted(Path((os.path.join("IXI-T1", comp))).glob('*.nii.gz'))
    for comp in os.listdir("IXI-T1") if "Compressed" in comp
]
validation_batch_size = 12
test_transform = Compose([RescaleIntensity((0, 1))])


def test_network(sample):
    patch_size = 48, 48, 48
    patch_overlap = 4, 4, 4
    model.eval()
    grid_sampler = tio.inference.GridSampler(sample, patch_size, patch_overlap)
    patch_loader = torch.utils.data.DataLoader(grid_sampler,
                                               int(validation_batch_size / 4))
    aggregator = tio.inference.GridAggregator(grid_sampler,
                                              overlap_mode="average")
    with torch.no_grad():
        for batch in patch_loader:
            inputs = batch["compressed"][DATA].to("cuda")
            logits = model(inputs)
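The snippet above is cut off inside the inference loop. In the general GridSampler/GridAggregator pattern, each batch of logits is added back at its patch locations and the stitched volume is read out at the end; a generic sketch of that tail (not the continuation of this particular file):

import torch
import torchio as tio

with torch.no_grad():
    for batch in patch_loader:
        inputs = batch['compressed'][tio.DATA].to('cuda')
        logits = model(inputs)
        aggregator.add_batch(logits, batch[tio.LOCATION])
output_tensor = aggregator.get_output_tensor()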
Example #20
def define_transform(transform,
                     p,
                     blur_std=4,
                     motion_trans=10,
                     motion_deg=10,
                     motion_num=2,
                     biascoeff=0.5,
                     noise_std=0.25,
                     affine_trans=10,
                     affine_deg=10,
                     elastic_disp=7.5,
                     resample_size=1,
                     target_shape=0):
    ### (1) try with different blur
    if transform == 'blur':
        transforms = [RandomBlur(std=(blur_std, blur_std), p=p, seed=None)]
        transforms = Compose(transforms)

    ### (2) try with different motion artifacts
    if transform == 'motion':
        transforms = [
            RandomMotion(degrees=motion_deg,
                         translation=motion_trans,
                         num_transforms=motion_num,
                         image_interpolation=Interpolation.LINEAR,
                         p=p,
                         seed=None),
        ]
        transforms = Compose(transforms)
    ### (3) with random bias fields
    if transform == 'biasfield':
        transforms = [
            RandomBiasField(coefficients=biascoeff, order=3, p=p, seed=None)
        ]
        transforms = Compose(transforms)

    ### (4) try with different noise artifacts
    if transform == 'noise':
        transforms = [
            RandomNoise(mean=0, std=(noise_std, noise_std), p=p, seed=None)
        ]
        transforms = Compose(transforms)

    ### (5) try with different warp (affine transformations)
    if transform == 'affine':
        transforms = [
            RandomAffine(scales=(1, 1),
                         degrees=(affine_deg),
                         isotropic=False,
                         default_pad_value='otsu',
                         image_interpolation=Interpolation.LINEAR,
                         p=p,
                         seed=None)
        ]
        transforms = Compose(transforms)

    ### (6) try with different warp (elastic transformations)
    if transform == 'elastic':
        transforms = [
            RandomElasticDeformation(num_control_points=elastic_disp,
                                     max_displacement=20,
                                     locked_borders=2,
                                     image_interpolation=Interpolation.LINEAR,
                                     p=p,
                                     seed=None),
        ]
        transforms = Compose(transforms)

    if transform == 'resample':
        transforms = [
            Resample(target=resample_size,
                     image_interpolation=Interpolation.LINEAR,
                     p=p),
            CropOrPad(target_shape=target_shape, p=1)
        ]

        transforms = Compose(transforms)

    return transforms
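A hypothetical call, just to show the shape of the API (argument values are illustrative):

noise_tf = define_transform('noise', p=1.0, noise_std=0.1)
noisy_subject = noise_tf(subject)  # subject: a torchio Subject defined elsewhere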
Example #21
def get_data_loader(cfg: DictConfig, _) -> dict:
    log = logging.getLogger(__name__)

    transform = Compose([
        RandomMotion(),
        RandomBiasField(),
        RandomNoise(),
        RandomFlip(axes=(0, )),
    ])

    log.info(f"Data loader selected: {cfg['dataset']}")
    try:
        log.info("Attempting to use defined data loader")
        dataset = getattr(datasets, cfg["dataset"])(cfg, transform)
    except ImportError:
        log.info(
            "Not a defined data loader... Attempting to use torchio loader")
        dataset = getattr(torchio.datasets,
                          cfg["dataset"])(root=cfg["base_path"],
                                          transform=transform,
                                          download=True)

    for subject in random.sample(dataset._subjects, cfg["plot_number"]):
        plot_subject(
            subject,
            os.path.join(os.environ["OUTPUT_PATH"], cfg["save_plot_dir"],
                         subject["subject_id"]),
        )

    sampler = GridSampler(patch_size=cfg["patch_size"])
    samples_per_volume = len(sampler._compute_locations(
        dataset[0]))  # type: ignore

    with open_dict(cfg):
        cfg["size"] = dataset[0].spatial_shape

    val_size = max(1, int(0.2 * len(dataset)))
    test_set, train_set, val_set = split_dataset(
        dataset, [21, len(dataset) - val_size - 21, val_size])

    train_loader = __create_data_loader(
        train_set,
        queue_max_length=samples_per_volume * cfg["queue_length"],
        queue_samples_per_volume=samples_per_volume,
        sampler=sampler,
        verbose=log.level > 0,
        batch_size=cfg["batch"],
    )

    val_loader = __create_data_loader(
        val_set,
        queue_max_length=samples_per_volume * cfg["queue_length"],
        queue_samples_per_volume=samples_per_volume,
        sampler=sampler,
        verbose=log.level > 0,
        batch_size=cfg["batch"],
    )

    test_loader = __create_data_loader(
        test_set,
        queue_max_length=samples_per_volume * cfg["queue_length"],
        queue_samples_per_volume=samples_per_volume,
        sampler=sampler,
        verbose=log.level > 0,
        batch_size=cfg["batch"],
    )

    return {
        "data_loader_train": train_loader,
        "data_loader_val": val_loader,
        "data_loader_test": test_loader,
    }
Example #22
def ImagesFromDataFrame(dataframe,
                        psize,
                        headers,
                        q_max_length=10,
                        q_samples_per_volume=1,
                        q_num_workers=2,
                        q_verbose=False,
                        sampler='label',
                        train=True,
                        augmentations=None,
                        preprocessing=None,
                        in_memory=False):
    # Finding the dimension of the dataframe for computational purposes later
    num_row, num_col = dataframe.shape
    # num_channels = num_col - 1 # for non-segmentation tasks, this might be different
    # changing the column indices to make it easier
    dataframe.columns = range(0, num_col)
    dataframe.index = range(0, num_row)
    # This list will later contain the list of subjects
    subjects_list = []

    channelHeaders = headers['channelHeaders']
    labelHeader = headers['labelHeader']
    predictionHeaders = headers['predictionHeaders']
    subjectIDHeader = headers['subjectIDHeader']

    sampler = sampler.lower()  # for easier parsing

    # define the control points and swap axes for augmentation
    augmentation_patchAxesPoints = copy.deepcopy(psize)
    for i in range(len(augmentation_patchAxesPoints)):
        augmentation_patchAxesPoints[i] = max(
            round(augmentation_patchAxesPoints[i] / 10),
            1)  # always at least have 1

    # iterating through the dataframe
    resizeCheck = False
    for patient in range(num_row):
        # We need this dict for storing the meta data for each subject
        # such as different image modalities, labels, any other data
        subject_dict = {}
        subject_dict['subject_id'] = dataframe[subjectIDHeader][patient]
        # iterating through the channels/modalities/timepoints of the subject
        for channel in channelHeaders:
            # assigning the dict key to the channel
            if not in_memory:
                subject_dict[str(channel)] = Image(str(
                    dataframe[channel][patient]),
                                                   type=torchio.INTENSITY)
            else:
                img = sitk.ReadImage(str(dataframe[channel][patient]))
                array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
                subject_dict[str(channel)] = Image(
                    tensor=array,
                    type=torchio.INTENSITY,
                    path=dataframe[channel][patient])

            # if resize has been defined but resample is not (or is none)
            if not resizeCheck:
                if not (preprocessing is None) and ('resize' in preprocessing):
                    if (preprocessing['resize'] is not None):
                        resizeCheck = True
                        if not ('resample' in preprocessing):
                            preprocessing['resample'] = {}
                            if not ('resolution' in preprocessing['resample']):
                                preprocessing['resample'][
                                    'resolution'] = resize_image_resolution(
                                        subject_dict[str(channel)].as_sitk(),
                                        preprocessing['resize'])
                        else:
                            print(
                                'WARNING: \'resize\' is ignored as \'resample\' is defined under \'data_processing\', this will be skipped',
                                file=sys.stderr)
                else:
                    resizeCheck = True

        # # for regression
        # if predictionHeaders:
        #     # get the mask
        #     if (subject_dict['label'] is None) and (class_list is not None):
        #         sys.exit('The \'class_list\' parameter has been defined but a label file is not present for patient: ', patient)

        if labelHeader is not None:
            if not in_memory:
                subject_dict['label'] = Image(str(
                    dataframe[labelHeader][patient]),
                                              type=torchio.LABEL)
            else:
                img = sitk.ReadImage(str(dataframe[labelHeader][patient]))
                array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
                subject_dict['label'] = Image(
                    tensor=array,
                    type=torchio.LABEL,
                    path=dataframe[labelHeader][patient])

            subject_dict['path_to_metadata'] = str(
                dataframe[labelHeader][patient])
        else:
            subject_dict['label'] = "NA"
            subject_dict['path_to_metadata'] = str(dataframe[channel][patient])

        # iterating through the values to predict of the subject
        valueCounter = 0
        for values in predictionHeaders:
            # assigning the dict key to the channel
            subject_dict['value_' + str(valueCounter)] = np.array(
                dataframe[values][patient])
            valueCounter = valueCounter + 1

        # Initializing the subject object using the dict
        subject = Subject(subject_dict)

        # padding image, but only for label sampler, because we don't want to pad for uniform
        if 'label' in sampler or 'weight' in sampler:
            psize_pad = list(
                np.asarray(np.round(np.divide(psize, 2)), dtype=int))
            padder = Pad(
                psize_pad, padding_mode='symmetric'
            )  # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            subject = padder(subject)

        # Appending this subject to the list of subjects
        subjects_list.append(subject)

    augmentation_list = []

    # first, we want to do thresholding, followed by clipping, if it is present - required for inference as well
    if not (preprocessing is None):
        if train:  # we want the crop to only happen during training
            if 'crop_external_zero_planes' in preprocessing:
                augmentation_list.append(
                    global_preprocessing_dict['crop_external_zero_planes'](
                        psize))
        for key in ['threshold', 'clip']:
            if key in preprocessing:
                augmentation_list.append(global_preprocessing_dict[key](
                    min=preprocessing[key]['min'],
                    max=preprocessing[key]['max']))

        # first, we want to do the resampling, if it is present - required for inference as well
        if 'resample' in preprocessing:
            if 'resolution' in preprocessing['resample']:
                # resample_split = str(aug).split(':')
                resample_values = tuple(
                    np.array(preprocessing['resample']['resolution']).astype(float))
                if len(resample_values) == 2:
                    resample_values = tuple(np.append(resample_values, 1))
                augmentation_list.append(Resample(resample_values))

        # next, we want to do the intensity normalize - required for inference as well
        if 'normalize' in preprocessing:
            augmentation_list.append(global_preprocessing_dict['normalize'])
        elif 'normalize_nonZero' in preprocessing:
            augmentation_list.append(
                global_preprocessing_dict['normalize_nonZero'])
        elif 'normalize_nonZero_masked' in preprocessing:
            augmentation_list.append(
                global_preprocessing_dict['normalize_nonZero_masked'])

    # other augmentations should only happen for training - and also setting the probabilities
    # for the augmentations
    if train and augmentations is not None:
        for aug in augmentations:
            if aug != 'default_probability':
                actual_function = None

                if aug == 'flip':
                    if ('axes_to_flip' in augmentations[aug]):
                        print(
                            'WARNING: \'flip\' augmentation needs the key \'axis\' instead of \'axes_to_flip\'',
                            file=sys.stderr)
                        augmentations[aug]['axis'] = augmentations[aug][
                            'axes_to_flip']
                    actual_function = global_augs_dict[aug](
                        axes=augmentations[aug]['axis'],
                        p=augmentations[aug]['probability'])
                elif aug in ['rotate_90', 'rotate_180']:
                    for axis in augmentations[aug]['axis']:
                        augmentation_list.append(global_augs_dict[aug](
                            axis=axis, p=augmentations[aug]['probability']))
                elif aug in ['swap', 'elastic']:
                    actual_function = global_augs_dict[aug](
                        patch_size=augmentation_patchAxesPoints,
                        p=augmentations[aug]['probability'])
                elif aug == 'blur':
                    actual_function = global_augs_dict[aug](
                        std=augmentations[aug]['std'],
                        p=augmentations[aug]['probability'])
                elif aug == 'noise':
                    actual_function = global_augs_dict[aug](
                        mean=augmentations[aug]['mean'],
                        std=augmentations[aug]['std'],
                        p=augmentations[aug]['probability'])
                elif aug == 'anisotropic':
                    actual_function = global_augs_dict[aug](
                        axes=augmentations[aug]['axis'],
                        downsampling=augmentations[aug]['downsampling'],
                        p=augmentations[aug]['probability'])
                else:
                    actual_function = global_augs_dict[aug](
                        p=augmentations[aug]['probability'])
                if actual_function is not None:
                    augmentation_list.append(actual_function)

    if augmentation_list:
        transform = Compose(augmentation_list)
    else:
        transform = None
    subjects_dataset = torchio.SubjectsDataset(subjects_list,
                                               transform=transform)
    if not train:
        return subjects_dataset
    if sampler in ('weighted', 'weightedsampler', 'weightedsample'):
        sampler = global_sampler_dict[sampler](psize, probability_map='label')
    else:
        sampler = global_sampler_dict[sampler](psize)
    # all of these need to be read from model.yaml
    patches_queue = torchio.Queue(subjects_dataset,
                                  max_length=q_max_length,
                                  samples_per_volume=q_samples_per_volume,
                                  sampler=sampler,
                                  num_workers=q_num_workers,
                                  shuffle_subjects=True,
                                  shuffle_patches=True,
                                  verbose=q_verbose)
    return patches_queue
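The torchio.Queue returned here is consumed with a regular PyTorch DataLoader; per the torchio documentation the loader's own num_workers must stay at 0, because the Queue spawns its own worker processes. A minimal sketch (the channel key '1' is hypothetical and depends on the dataframe columns):

import torch

loader = torch.utils.data.DataLoader(patches_queue, batch_size=4, num_workers=0)
for batch in loader:
    images = batch['1']['data']      # channel columns are keyed by their column index above
    labels = batch['label']['data']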
Example #23
def pre_transform() -> Compose:
    transform = Compose([
        Resample(1.0),
    ])
    return transform
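Passing a single number to Resample requests that isotropic output spacing in millimetres, so this pipeline resamples every image to 1 mm isotropic; a trivial usage sketch:

transform = pre_transform()
resampled = transform(subject)  # subject: a torchio Subject with arbitrary spacing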
Example #24
                       label=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(train_seg),
                        label=torchio.LABEL),
)
valid_subject = torchio.Subject(
    data=torchio.Image(tensor=torch.from_numpy(valid_data),
                       label=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(valid_seg),
                        label=torchio.LABEL),
)
# Define the transforms for the set of training patches
training_transform = Compose([
    RandomNoise(p=0.2),
    RandomFlip(axes=(0, 1, 2)),
    RandomBlur(p=0.2),
    OneOf({
        RandomAffine(): 0.8,
        RandomElasticDeformation(): 0.2,
    }, p=0.5),  # Changed from p=0.75 24/6/20
])
# Create the datasets
training_dataset = torchio.ImagesDataset(
    [train_subject], transform=training_transform)

validation_dataset = torchio.ImagesDataset(
    [valid_subject])
# Define the queue of sampled patches for training and validation
sampler = torchio.data.UniformSampler(PATCH_SIZE)
patches_training_set = torchio.Queue(
    subjects_dataset=training_dataset,
    max_length=MAX_QUEUE_LENGTH,
Example #25
                if len(label.shape) <= 3:
                    label = np.expand_dims(label, axis=0)

            self.data_list.append(data)
            self.label_list.append(label)


if __name__ == "__main__":
    slice_n = 99

    spatial = RandomAffine(scales=1,
                           degrees=3,
                           isotropic=False,
                           default_pad_value='otsu',
                           image_interpolation='bspline')
    tmp_transform = Compose([spatial, ZNormalization()])

    dataset = HaNOarsDataset(f'./data/{"HaN_OAR"}_shrink{2}x_padded160', 10)
    dataset.filter_labels([
        OARS_LABELS.EYE_L, OARS_LABELS.EYE_R, OARS_LABELS.LENS_L,
        OARS_LABELS.LENS_R
    ], False)
    dataset.to_numpy()

    tmp_data, tmp_label = dataset[0]

    # tmp_label2 = dataset.label_list[0]
    # unique, counts = np.unique(tmp_label[tmp_label > 0], return_counts=True)
    # print(np.asarray((unique, counts)).T)
    # unique, counts = np.unique(tmp_label2[tmp_label2 > 0], return_counts=True)
    # print(np.asarray((unique, counts)).T)
Example #26
 def transform(self):
     test_transform = Compose([
         ZNormalization(),
     ])
     return test_transform
Example #27
 def reverse_resample(self, min_value=-1):
     transforms = [Resample(1 / self.ratio)]
     return Compose(transforms + [CropOrPad(self.opt.origshape, padding_mode=min_value)])
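Assuming the forward pipeline applied Resample(self.ratio) followed by CropOrPad(self.input_size), as in Example #2 above, this roughly restores the original spacing and field of view; a hypothetical round trip (dataset is an instance of the class this method belongs to):

restore = dataset.reverse_resample(min_value=-1)
restored = restore(transformed_subject)  # transformed_subject produced by the forward pipeline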