Esempio n. 1
0
 def test_reference_path(self):
     """Resampling to a reference file must reproduce its shape and affine."""
     ref_image, ref_path = self.get_reference_image_and_path()
     resampled = Resample(ref_path)(self.sample_subject)
     for resampled_image in resampled.values():
         self.assertEqual(ref_image.shape, resampled_image.shape)
         self.assertTensorAlmostEqual(ref_image.affine, resampled_image.affine)
Esempio n. 2
0
 def test_reference_path(self):
     """Every image resampled onto a reference path takes its grid."""
     expected_image, path_to_reference = self.get_reference_image_and_path()
     resample = Resample(path_to_reference)
     output = resample(self.sample)
     for out_image in output.values():
         self.assertEqual(expected_image.shape, out_image.shape)
         assert_array_equal(expected_image.affine, out_image.affine)
Esempio n. 3
0
 def test_spacing(self):
     """A scalar Resample target yields that spacing on all three axes."""
     # Open question: should images of different sizes raise an error here?
     target_spacing = 2
     expected = (target_spacing,) * 3
     output = Resample(target_spacing)(self.sample_subject)
     for img in output.get_images(intensity_only=False):
         self.assertEqual(img.spacing, expected)
Esempio n. 4
0
    def create_transforms(self):
        """Build the preprocessing pipeline applied to every volume.

        Steps, in order: rescale intensities to [-1, 1] (input range taken
        from the 0.5th/99.5th percentiles), resample using the scalar
        ``self.min_size / max(self.input_size)``, then crop or pad to
        ``self.input_size``.

        Side effect: stores the computed scalar in ``self.ratio``.

        Returns:
            Compose: the composed transform.
        """
        # Outlier clipping and ZNormalization were tried before and are
        # currently disabled; only percentile-based rescaling is used.
        intensity_steps = [RescaleIntensity((-1, 1), percentiles=(0.5, 99.5))]

        # Spatial augmentation (RandomAffine / RandomElasticDeformation via
        # OneOf, plus RandomFlip during training) is currently disabled.

        # A scalar Resample target is used as the output spacing.
        self.ratio = self.min_size / np.max(self.input_size)
        spatial_steps = [Resample(self.ratio), CropOrPad(self.input_size)]

        return Compose(intensity_steps + spatial_steps)
Esempio n. 5
0
 def test_transforms(self):
     """Smoke test: push one sample through a long chain of transforms."""
     landmarks_dict = {
         name: np.linspace(0, 100, 13) for name in ('t1', 't2')
     }
     chain = (
         CenterCropOrPad((9, 21, 30)),
         ToCanonical(),
         Resample((1, 1.1, 1.25)),
         RandomFlip(axes=(0, 1, 2), flip_probability=1),
         RandomMotion(proportion_to_augment=1),
         RandomGhosting(proportion_to_augment=1, axes=(0, 1, 2)),
         RandomSpike(),
         RandomNoise(),
         RandomBlur(),
         RandomSwap(patch_size=2, num_iterations=5),
         Lambda(lambda x: 1.5 * x, types_to_apply=INTENSITY),
         RandomBiasField(),
         Rescale((0, 1)),
         ZNormalization(masking_method='label'),
         HistogramStandardization(landmarks_dict=landmarks_dict),
         RandomElasticDeformation(proportion_to_augment=1),
         RandomAffine(),
         Pad((1, 2, 3, 0, 5, 6)),
         Crop((3, 2, 8, 0, 1, 4)),
     )
     sample = self.get_sample()
     for step in chain:
         sample = step(sample)
Esempio n. 6
0
def get_brats(
        data_root='/scratch/weina/dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/',
        fold=1,
        seed=None,
        **kwargs):
    """Data iterators for BraTS: build the training and validation sets.

    Args:
        data_root: BraTS 2019 training directory. It must contain
            ``IDH_label/train_fold_<fold>.csv``, ``IDH_label/val_fold_<fold>.csv``
            and the ``all`` image directory.
        fold: Fold number. Selects the split CSVs and the statistics
            returned by ``read_brats_mean``.
        seed: Value reported in the debug log. When ``None`` (the default)
            it is resolved at call time: the distributed rank if
            ``torch.distributed`` is available and initialized, else 0.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        tuple: ``(train, val)`` ``BratsIter`` objects. The training
        iterator shuffles; the validation iterator does not.
    """
    if seed is None:
        # Resolve here, not in the signature. A default-argument expression
        # runs once when the function is defined (at import time). That is
        # typically before the process group exists, so the old default
        # was effectively always 0. It also crashed if torch.distributed
        # was unavailable.
        dist = torch.distributed
        seed = dist.get_rank() if (dist.is_available()
                                   and dist.is_initialized()) else 0
    # NOTE(review): ``seed`` is only logged; nothing here seeds an RNG or
    # BratsIter's shuffling with it.
    logging.debug("BratsIter:: fold = %s, seed = %s", fold, seed)

    # args for transforms
    d_size, h_size, w_size = 155, 240, 240
    input_size = [7, 223, 223]
    # Per-axis target spacing = raw size / desired size. This presumably
    # assumes the raw volumes have 1 mm spacing, so the output has about
    # ``input_size`` voxels — confirm.
    spacing = (d_size / input_size[0], h_size / input_size[1],
               w_size / input_size[2])
    Mean, Std, _ = read_brats_mean(fold, data_root)
    # NOTE(review): ``transforms`` is not visible here; presumably its
    # Normalize is compatible with the Compose below — confirm.
    normalize = transforms.Normalize(mean=Mean, std=Std)
    training_transform = Compose([
        # RescaleIntensity((0, 1)),  # so that there are no negative values for RandomMotion
        # RandomMotion(),
        # HistogramStandardization({MRI: landmarks}),
        RandomBiasField(),
        # ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        ToCanonical(),
        Resample(spacing),
        # CropOrPad((48, 60, 48)),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
        normalize
    ])
    val_transform = Compose([Resample(spacing), normalize])

    split_dir = os.path.join(data_root, 'IDH_label')
    brats_path = os.path.join(data_root, 'all')

    train = BratsIter(csv_file=os.path.join(split_dir,
                                            'train_fold_{}.csv'.format(fold)),
                      brats_path=brats_path,
                      brats_transform=training_transform,
                      shuffle=True)

    val = BratsIter(csv_file=os.path.join(split_dir,
                                          'val_fold_{}.csv'.format(fold)),
                    brats_path=brats_path,
                    brats_transform=val_transform,
                    shuffle=False)
    return train, val
Esempio n. 7
0
 def test_spacing(self):
     """Every resampled image reports the requested isotropic spacing."""
     # Open question: should images of different sizes raise an error?
     target = 2
     result = Resample(target)(self.sample)
     expected_spacing = 3 * (target, )
     for entry in result.values():
         sitk_image = entry.as_sitk()
         self.assertEqual(sitk_image.GetSpacing(), expected_spacing)
Esempio n. 8
0
 def test_reference_name(self):
     """Resampling onto the image named 't1' gives all images its grid."""
     subject = self.get_inconsistent_sample()
     ref_key = 't1'
     output = Resample(ref_key)(subject)
     ref_image_dict = subject[ref_key]
     for current in output.values():
         self.assertEqual(ref_image_dict.shape, current.shape)
         assert_array_equal(ref_image_dict[AFFINE], current[AFFINE])
Esempio n. 9
0
 def test_reference_name(self):
     """Every image resampled onto 't1' must share its shape and affine."""
     subject = self.get_inconsistent_sample()
     target_key = 't1'
     resampled = Resample(target_key)(subject)
     target_image = subject[target_key]
     for img in resampled.get_images(intensity_only=False):
         self.assertEqual(target_image.shape, img.shape)
         assert_array_equal(target_image[AFFINE], img[AFFINE])
Esempio n. 10
0
 def test_reference_name(self):
     """Resampling to the image named 't1' aligns all images with it."""
     inconsistent = self.get_inconsistent_shape_subject()
     ref_name = 't1'
     output = Resample(ref_name)(inconsistent)
     ref = inconsistent[ref_name]
     for img in output.get_images(intensity_only=False):
         self.assertEqual(ref.shape, img.shape)
         self.assertTensorAlmostEqual(ref.affine, img.affine)
Esempio n. 11
0
def training_network(landmarks, dataset, subjects):
    """Create the training and validation ``SubjectsDataset`` objects.

    Subjects are split in their given order, without shuffling: the first
    90% go to training and the rest to validation.

    Args:
        landmarks: Histogram-standardization landmarks for the ``'mri'``
            image key.
        dataset: No longer used. The split size now comes from
            ``subjects``, the sequence actually being sliced. Kept so
            existing callers keep working.
        subjects: Sequence of subjects to split.

    Returns:
        tuple: ``(training_set, validation_set)``.
    """

    def spatial_preprocessing():
        # Fresh instances of the reorient/resample/crop steps shared by
        # both pipelines.
        return [
            ToCanonical(),
            Resample(4),
            CropOrPad((48, 60, 48), padding_mode='reflect'),
        ]

    training_transform = Compose(spatial_preprocessing() + [
        RandomMotion(),
        HistogramStandardization({'mri': landmarks}),
        RandomBiasField(),
        ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    # Validation uses the same deterministic preprocessing, without the
    # random augmentations.
    validation_transform = Compose(spatial_preprocessing() + [
        HistogramStandardization({'mri': landmarks}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])

    training_split_ratio = 0.9
    # Size the split from the list being sliced. Using len(dataset) gave a
    # wrong split whenever dataset and subjects differed in length.
    num_training_subjects = int(training_split_ratio * len(subjects))

    training_subjects = subjects[:num_training_subjects]
    validation_subjects = subjects[num_training_subjects:]

    training_set = tio.SubjectsDataset(training_subjects,
                                       transform=training_transform)

    validation_set = tio.SubjectsDataset(validation_subjects,
                                         transform=validation_transform)

    print('Training set:', len(training_set), 'subjects')
    print('Validation set:', len(validation_set), 'subjects')
    return training_set, validation_set
Esempio n. 12
0
 def test_affine(self):
     """Images with a 'pre_affine' entry end with a +10 x-translation.

     All other images keep the identity affine.
     """
     affine_name = 'pre_affine'
     transformed = Resample(1, pre_affine_name=affine_name)(self.sample)
     shifted = np.eye(4)
     shifted[0, 3] = 10
     for image_dict in transformed.values():
         has_pre_affine = affine_name in image_dict.keys()
         expected = shifted if has_pre_affine else np.eye(4)
         assert_array_equal(image_dict[AFFINE], expected)
Esempio n. 13
0
 def test_affine(self):
     """Images with 'pre_affine' end with translation (10, 0, -0.1).

     All other images keep the identity affine.
     """
     affine_name = 'pre_affine'
     transform = Resample(1, pre_affine_name=affine_name)
     output = transform(self.sample_subject)
     for image in output.values():
         if affine_name not in image:
             self.assertTensorEqual(image.affine, np.eye(4))
             continue
         translated = np.eye(4)
         translated[:3, 3] = 10, 0, -0.1
         self.assertTensorAlmostEqual(image.affine, translated)
Esempio n. 14
0
 def test_2d(self):
     """Resampling a 1x2x3x1 image to 0.5 spacing doubles the 2D extent."""
     tensor = torch.rand(1, 2, 3, 1)
     resampled = Resample(0.5)(ScalarImage(tensor=tensor))
     self.assertEqual(resampled.shape, (1, 4, 6, 1))
Esempio n. 15
0
 def test_missing_reference(self):
     """A reference name absent from the subject raises ValueError."""
     resample_to_missing = Resample('missing')
     self.assertRaises(ValueError, resample_to_missing, self.sample_subject)
Esempio n. 16
0
 def test_wrong_target_type(self):
     """Constructing Resample with a None target raises ValueError."""
     self.assertRaises(ValueError, Resample, None)
Esempio n. 17
0
 def test_wrong_spacing_value(self):
     """A spacing of zero is rejected with ValueError."""
     self.assertRaises(ValueError, Resample, 0)
Esempio n. 18
0
 def test_wrong_spacing_length(self):
     """A two-element spacing tuple is rejected with ValueError."""
     self.assertRaises(ValueError, Resample, (1, 2))
Esempio n. 19
0
def generate_dataset(data_path,
                     data_root='',
                     ref_path=None,
                     nb_subjects=5,
                     resampling='mni',
                     masking_method='label'):
    """
    Generate a torchio dataset from a csv file defining paths to subjects.

    Rows whose ``suj`` column is missing are dropped. A fixed random subset
    of ``nb_subjects`` rows is then drawn. Column 1 of each selected row,
    prefixed with ``data_root`` by plain string concatenation, is a
    directory that must contain an ``s*.nii.gz`` image, a ``niw_Mean*``
    mask and an ``aff*.txt`` matrix. The inverse of that matrix is attached
    to the image as ``coregistration``.

    :param data_path: path to a csv file
    :param data_root: string prefix prepended to every subject directory
    :param ref_path: resampling target used when ``resampling == 'mm'``
    :param nb_subjects: number of subjects drawn without replacement
    :param resampling: ``'mni'`` or ``'mm'``; any other value skips resampling
    :param masking_method: forwarded to ``RescaleIntensity``
    :return: ``torchio.ImagesDataset`` over the selected subjects
    """
    table = pd.read_csv(data_path).dropna(subset=['suj'])

    # Reproducible subset. Note this reseeds NumPy's *global* RNG with 0,
    # which is visible to the caller.
    np.random.seed(0)
    chosen = np.random.choice(range(len(table)), nb_subjects, replace=False)
    relative_dirs = table.iloc[chosen, 1].tolist()
    # Lazy on purpose: each directory string is built while iterating.
    subject_dirs = (data_root + rel_dir for rel_dir in relative_dirs)

    subject_list = []
    for subject_dir in subject_dirs:
        img_path = glob.glob(os.path.join(subject_dir, 's*.nii.gz'))[0]
        mask_path = glob.glob(os.path.join(subject_dir, 'niw_Mean*'))[0]
        aff_path = glob.glob(os.path.join(subject_dir, 'aff*.txt'))[0]

        # The stored matrix is inverted before being attached to the image.
        coregistration = np.linalg.inv(np.loadtxt(aff_path, delimiter=' '))

        subject = torchio.Subject(
            t1=torchio.Image(img_path,
                             torchio.INTENSITY,
                             coregistration=coregistration),
            label=torchio.Image(mask_path, torchio.LABEL),
        )
        print('adding img {} \n mask {}\n'.format(img_path, mask_path))
        subject_list.append(subject)

    transforms = [
        RescaleIntensity((0, 1), (0, 99), masking_method=masking_method),
    ]

    # Resampling, when requested, runs before intensity rescaling.
    if resampling == 'mni':
        # NOTE(review): the target 'ref' must name an image in each subject,
        # but only 't1' and 'label' are added above. This branch presumably
        # fails at transform time unless the Resample in use treats it
        # differently — confirm.
        transforms.insert(0, Resample(
            target='ref',
            image_interpolation=Interpolation.BSPLINE,
            coregistration='coregistration'))
    elif resampling == 'mm':
        transforms.insert(0, Resample(
            target=ref_path, image_interpolation=Interpolation.BSPLINE))

    return torchio.ImagesDataset(subject_list, transform=Compose(transforms))
Esempio n. 20
0
 def test_missing_affine(self):
     """Naming a pre-affine key that no image carries raises ValueError."""
     resample = Resample(1, pre_affine_name='missing')
     self.assertRaises(ValueError, resample, self.sample_subject)
Esempio n. 21
0
def ImagesFromDataFrame(dataframe,
                        psize,
                        headers,
                        q_max_length=10,
                        q_samples_per_volume=1,
                        q_num_workers=2,
                        q_verbose=False,
                        sampler='label',
                        train=True,
                        augmentations=None,
                        preprocessing=None,
                        in_memory=False):
    """Build a torchio dataset (or patch queue) from a subjects DataFrame.

    Each row describes one subject: an ID column, one image per channel
    column, an optional label column and optional prediction-value
    columns, all selected through ``headers``.

    Args:
        dataframe: pandas DataFrame of subjects. Its columns and index are
            relabeled *in place* to ``0..n-1``. The header entries are
            therefore presumably positional column indices — confirm with
            callers.
        psize: Patch size. It is deep-copied and item-assigned, so it must
            be mutable (e.g. a list).
        headers: Dict with keys ``'channelHeaders'``, ``'labelHeader'``
            (may be ``None``), ``'predictionHeaders'`` and
            ``'subjectIDHeader'``.
        q_max_length, q_samples_per_volume, q_num_workers, q_verbose:
            Forwarded to ``torchio.Queue`` (training only).
        sampler: Sampler name, lower-cased and looked up in
            ``global_sampler_dict``. Names containing ``'label'`` or
            ``'weight'`` also pad each subject symmetrically by half the
            patch size.
        train: When false, the ``SubjectsDataset`` is returned directly,
            and augmentations and zero-plane cropping are skipped.
        augmentations: Optional dict of augmentation configs keyed by names
            in ``global_augs_dict``. May be modified in place (``flip``
            with ``axes_to_flip``).
        preprocessing: Optional dict of preprocessing configs. May be
            modified in place (a ``'resample'`` entry derived from
            ``'resize'``).
        in_memory: Read images eagerly with SimpleITK instead of by path.

    Returns:
        ``torchio.SubjectsDataset`` when ``train`` is false, otherwise a
        ``torchio.Queue`` of patches.
    """
    # Finding the dimension of the dataframe for computational purposes later
    num_row, num_col = dataframe.shape
    # num_channels = num_col - 1 # for non-segmentation tasks, this might be different
    # changing the column indices to make it easier (mutates the caller's DataFrame)
    dataframe.columns = range(0, num_col)
    dataframe.index = range(0, num_row)
    # This list will later contain the list of subjects
    subjects_list = []

    channelHeaders = headers['channelHeaders']
    labelHeader = headers['labelHeader']
    predictionHeaders = headers['predictionHeaders']
    subjectIDHeader = headers['subjectIDHeader']

    sampler = sampler.lower()  # for easier parsing

    # define the control points and swap axes for augmentation:
    # one tenth of each patch dimension, rounded
    augmentation_patchAxesPoints = copy.deepcopy(psize)
    for i in range(len(augmentation_patchAxesPoints)):
        augmentation_patchAxesPoints[i] = max(
            round(augmentation_patchAxesPoints[i] / 10),
            1)  # always at least have 1

    # iterating through the dataframe
    resizeCheck = False
    for patient in range(num_row):
        # We need this dict for storing the meta data for each subject
        # such as different image modalities, labels, any other data
        subject_dict = {}
        subject_dict['subject_id'] = dataframe[subjectIDHeader][patient]
        # iterating through the channels/modalities/timepoints of the subject
        for channel in channelHeaders:
            # assigning the dict key to the channel
            if not in_memory:
                subject_dict[str(channel)] = Image(str(
                    dataframe[channel][patient]),
                                                   type=torchio.INTENSITY)
            else:
                img = sitk.ReadImage(str(dataframe[channel][patient]))
                array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
                subject_dict[str(channel)] = Image(
                    tensor=array,
                    type=torchio.INTENSITY,
                    path=dataframe[channel][patient])

            # if resize has been defined but resample is not (or is none),
            # derive a resample resolution from the first image seen
            if not resizeCheck:
                if preprocessing is not None and 'resize' in preprocessing:
                    if preprocessing['resize'] is not None:
                        resizeCheck = True
                        if 'resample' not in preprocessing:
                            preprocessing['resample'] = {}
                            if 'resolution' not in preprocessing['resample']:
                                preprocessing['resample'][
                                    'resolution'] = resize_image_resolution(
                                        subject_dict[str(channel)].as_sitk(),
                                        preprocessing['resize'])
                        else:
                            print(
                                'WARNING: \'resize\' is ignored as \'resample\' is defined under \'data_processing\', this will be skipped',
                                file=sys.stderr)
                else:
                    resizeCheck = True

        # # for regression
        # if predictionHeaders:
        #     # get the mask
        #     if (subject_dict['label'] is None) and (class_list is not None):
        #         sys.exit('The \'class_list\' parameter has been defined but a label file is not present for patient: ', patient)

        if labelHeader is not None:
            if not in_memory:
                subject_dict['label'] = Image(str(
                    dataframe[labelHeader][patient]),
                                              type=torchio.LABEL)
            else:
                img = sitk.ReadImage(str(dataframe[labelHeader][patient]))
                array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
                subject_dict['label'] = Image(
                    tensor=array,
                    type=torchio.LABEL,
                    path=dataframe[labelHeader][patient])

            subject_dict['path_to_metadata'] = str(
                dataframe[labelHeader][patient])
        else:
            subject_dict['label'] = "NA"
            # `channel` is the last channel header from the loop above
            subject_dict['path_to_metadata'] = str(dataframe[channel][patient])

        # iterating through the values to predict of the subject
        valueCounter = 0
        for values in predictionHeaders:
            # assigning the dict key to the channel
            subject_dict['value_' + str(valueCounter)] = np.array(
                dataframe[values][patient])
            valueCounter = valueCounter + 1

        # Initializing the subject object using the dict
        subject = Subject(subject_dict)

        # padding image, but only for label sampler, because we don't want to pad for uniform
        if 'label' in sampler or 'weight' in sampler:
            psize_pad = list(
                np.asarray(np.round(np.divide(psize, 2)), dtype=int))
            padder = Pad(
                psize_pad, padding_mode='symmetric'
            )  # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            subject = padder(subject)

        # Appending this subject to the list of subjects
        subjects_list.append(subject)

    augmentation_list = []

    # first, we want to do thresholding, followed by clipping, if it is present - required for inference as well
    if preprocessing is not None:
        if train:  # we want the crop to only happen during training
            if 'crop_external_zero_planes' in preprocessing:
                augmentation_list.append(
                    global_preprocessing_dict['crop_external_zero_planes'](
                        psize))
        for key in ['threshold', 'clip']:
            if key in preprocessing:
                augmentation_list.append(global_preprocessing_dict[key](
                    min=preprocessing[key]['min'],
                    max=preprocessing[key]['max']))

        # then the resampling, if it is present - required for inference as well
        if 'resample' in preprocessing:
            if 'resolution' in preprocessing['resample']:
                # `np.float` was removed in NumPy 1.24; builtin float is
                # the same float64 dtype.
                resample_values = tuple(
                    np.array(preprocessing['resample']['resolution']).astype(
                        float))
                # 2D resolutions get a unit spacing on the third axis
                if len(resample_values) == 2:
                    resample_values = tuple(np.append(resample_values, 1))
                augmentation_list.append(Resample(resample_values))

        # next, we want to do the intensity normalize - required for inference as well
        if 'normalize' in preprocessing:
            augmentation_list.append(global_preprocessing_dict['normalize'])
        elif 'normalize_nonZero' in preprocessing:
            augmentation_list.append(
                global_preprocessing_dict['normalize_nonZero'])
        elif 'normalize_nonZero_masked' in preprocessing:
            augmentation_list.append(
                global_preprocessing_dict['normalize_nonZero_masked'])

    # other augmentations should only happen for training - and also setting the probabilities
    # for the augmentations
    if train and augmentations is not None:
        for aug in augmentations:
            if aug != 'default_probability':
                actual_function = None

                if aug == 'flip':
                    if ('axes_to_flip' in augmentations[aug]):
                        print(
                            'WARNING: \'flip\' augmentation needs the key \'axis\' instead of \'axes_to_flip\'',
                            file=sys.stderr)
                        augmentations[aug]['axis'] = augmentations[aug][
                            'axes_to_flip']
                    actual_function = global_augs_dict[aug](
                        axes=augmentations[aug]['axis'],
                        p=augmentations[aug]['probability'])
                elif aug in ['rotate_90', 'rotate_180']:
                    # one transform per requested axis
                    for axis in augmentations[aug]['axis']:
                        augmentation_list.append(global_augs_dict[aug](
                            axis=axis, p=augmentations[aug]['probability']))
                elif aug in ['swap', 'elastic']:
                    actual_function = global_augs_dict[aug](
                        patch_size=augmentation_patchAxesPoints,
                        p=augmentations[aug]['probability'])
                elif aug == 'blur':
                    actual_function = global_augs_dict[aug](
                        std=augmentations[aug]['std'],
                        p=augmentations[aug]['probability'])
                elif aug == 'noise':
                    actual_function = global_augs_dict[aug](
                        mean=augmentations[aug]['mean'],
                        std=augmentations[aug]['std'],
                        p=augmentations[aug]['probability'])
                elif aug == 'anisotropic':
                    actual_function = global_augs_dict[aug](
                        axes=augmentations[aug]['axis'],
                        downsampling=augmentations[aug]['downsampling'],
                        p=augmentations[aug]['probability'])
                else:
                    actual_function = global_augs_dict[aug](
                        p=augmentations[aug]['probability'])
                if actual_function is not None:
                    augmentation_list.append(actual_function)

    if augmentation_list:
        transform = Compose(augmentation_list)
    else:
        transform = None
    subjects_dataset = torchio.SubjectsDataset(subjects_list,
                                               transform=transform)
    if not train:
        return subjects_dataset
    if sampler in ('weighted', 'weightedsampler', 'weightedsample'):
        sampler = global_sampler_dict[sampler](psize, probability_map='label')
    else:
        sampler = global_sampler_dict[sampler](psize)
    # all of these need to be read from model.yaml
    patches_queue = torchio.Queue(subjects_dataset,
                                  max_length=q_max_length,
                                  samples_per_volume=q_samples_per_volume,
                                  sampler=sampler,
                                  num_workers=q_num_workers,
                                  shuffle_subjects=True,
                                  shuffle_patches=True,
                                  verbose=q_verbose)
    return patches_queue
Esempio n. 22
0
 def reverse_resample(self, min_value=-1):
     """Build a transform that maps processed volumes back toward ``origshape``.

     Resamples with the scalar ``1 / self.ratio``, then crops or pads to
     ``self.opt.origshape``. ``min_value`` is passed as CropOrPad's
     ``padding_mode`` (presumably a constant fill value).
     """
     # NOTE(review): a scalar Resample target is an output spacing, so
     # 1 / self.ratio restores the original grid only under specific
     # assumptions about the input spacing — confirm.
     undo_resample = Resample(1 / self.ratio)
     restore_shape = CropOrPad(self.opt.origshape, padding_mode=min_value)
     return Compose([undo_resample, restore_shape])
Esempio n. 23
0
def pre_transform() -> Compose:
    """Return a pipeline whose only step resamples to a spacing of 1.0."""
    return Compose([Resample(1.0)])
Esempio n. 24
0
def get_transforms_for_preprocessing(parameters, current_transformations,
                                     train_mode, apply_zero_crop):
    """
    Append the configured pre-processing transforms and compose them.

    Entries of ``parameters["data_preprocessing"]`` are visited in their
    stored order, with keys matched case-insensitively. Most entries are
    appended to ``current_transformations`` as they are visited. The first
    normalization entry (``"stain_normalizer"`` or any key containing
    ``"normalize"``) is instead held back and appended last.
    ``crop_external_zero_planes`` is only added when ``train_mode`` or
    ``apply_zero_crop`` is true. Unrecognized keys are ignored.

    Args:
        parameters (dict): The parameters dictionary. Must contain
            ``"data_preprocessing"`` (may be ``None``), and ``"patch_size"``
            when zero-plane cropping is applied.
        current_transformations (list): Transformations collected so far;
            extended in place.
        train_mode (bool): Whether the data is in train mode or not.
        apply_zero_crop (bool): Whether to apply zero crop or not.

    Returns:
        Compose or None: The composed transformations, or ``None`` when
        there is nothing to apply.
    """
    preprocessing = parameters["data_preprocessing"]
    deferred_normalize = None

    if preprocessing is not None:
        for key in preprocessing:
            name = key.lower()

            if name == "resize_patch":
                current_transformations.append(
                    Resize(generic_3d_check(preprocessing[key])))
            elif name == "resample":
                if "resolution" in preprocessing[key]:
                    current_transformations.append(Resample(
                        generic_3d_check(preprocessing[key]["resolution"])))
            elif name in ("resample_minimum", "resample_min"):
                if "resolution" in preprocessing[key]:
                    current_transformations.append(Resample_Minimum(
                        generic_3d_check(preprocessing[key]["resolution"])))
            elif name == "histogram_matching":
                # An explicit False disables histogram matching.
                if preprocessing[key] is not False:
                    current_transformations.append(
                        global_preprocessing_dict[name](preprocessing[key]))
            elif name == "stain_normalizer":
                # Must be tested before the generic substring check below,
                # because "stain_normalizer" also contains "normalize".
                if deferred_normalize is None:
                    deferred_normalize = global_preprocessing_dict[name](
                        preprocessing[key])
            elif "normalize" in name:
                if deferred_normalize is None:
                    deferred_normalize = global_preprocessing_dict[name]
            elif name == "crop_external_zero_planes":
                if train_mode or apply_zero_crop:
                    current_transformations.append(
                        global_preprocessing_dict["crop_external_zero_planes"](
                            patch_size=parameters["patch_size"]))
            elif name in global_preprocessing_dict:
                current_transformations.append(
                    global_preprocessing_dict[name](preprocessing[key]))

    # Normalization always runs after every other preprocessing step.
    if deferred_normalize is not None:
        current_transformations.append(deferred_normalize)

    return Compose(current_transformations) if current_transformations else None
Esempio n. 25
0
    ToCanonical,
    ZNormalization,
    CropOrPad,
    HistogramStandardization,
    OneOf,
    Compose,
)

# Histogram-standardization landmarks, read from 'landmarks.npy' in the
# current working directory at import time. The 'mri' key below presumably
# has to match the image name used in the subjects — confirm.
landmarks = np.load('landmarks.npy')

# Module-level preprocessing pipeline: rescale intensities to [0, 1],
# histogram-standardize with the loaded landmarks, z-normalize (masking
# with ZNormalization.mean), apply ToCanonical, resample to (1, 1, 1)
# spacing, then crop/pad to 224x224x224.
transform = Compose([
    RescaleIntensity((0, 1)),
    HistogramStandardization({'mri': landmarks}),
    ZNormalization(masking_method=ZNormalization.mean),
    ToCanonical(),
    Resample((1, 1, 1)),
    CropOrPad((224, 224, 224)),
])

# First CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def create_paths(datapath):
    """Return the path of every file found under ``datapath``, recursively.

    The tree is walked bottom-up (``topdown=False``), so files in a
    subdirectory are listed before files in its parent. A missing or
    unreadable ``datapath`` yields an empty list, because ``os.walk``
    ignores errors by default.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(datapath, topdown=False)
        for filename in filenames
    ]

Esempio n. 26
0
        Compose,
    )

    # Raw BraTS volume size (depth, height, width) and the size to resample to.
    d_size, h_size, w_size = 155, 240, 240
    input_size = [7, 223, 223]
    # Per-axis target spacing = raw size / desired size. This presumably
    # assumes 1 mm input spacing, so the output has about input_size voxels
    # — confirm.
    spacing = (d_size / input_size[0], h_size / input_size[1],
               w_size / input_size[2])
    # Training augmentation pipeline; several steps are disabled (commented out).
    training_transform = Compose([
        # RescaleIntensity((0, 1)),  # so that there are no negative values for RandomMotion
        # RandomMotion(),
        # HistogramStandardization({MRI: landmarks}),
        RandomBiasField(),
        # ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        ToCanonical(),
        Resample(spacing),
        # CropOrPad((48, 60, 48)),
        RandomFlip(axes=(0, )),
        # Either RandomAffine (weight 0.8) or RandomElasticDeformation (0.2).
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    # Cross-validation fold and dataset location (relative path).
    fold = 1
    data_root = '../../dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/'

    # Seed torch's CPU and CUDA RNGs with 0.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    logging.getLogger().setLevel(logging.DEBUG)
Esempio n. 27
0
def define_transform(transform,
                     p,
                     blur_std=4,
                     motion_trans=10,
                     motion_deg=10,
                     motion_num=2,
                     biascoeff=0.5,
                     noise_std=0.25,
                     affine_trans=10,
                     affine_deg=10,
                     elastic_disp=7.5,
                     resample_size=1,
                     target_shape=0):
    """Build a ``Compose`` wrapping the single augmentation named ``transform``.

    Args:
        transform: One of ``'blur'``, ``'motion'``, ``'biasfield'``,
            ``'noise'``, ``'affine'``, ``'elastic'`` or ``'resample'``.
        p: Probability passed to the chosen transform.
        blur_std: Blur std, passed as the range ``(blur_std, blur_std)``.
        motion_trans, motion_deg, motion_num: ``RandomMotion`` translation,
            degrees and number of transforms.
        biascoeff: ``RandomBiasField`` coefficients (order 3).
        noise_std: Noise std, passed as the range ``(noise_std, noise_std)``.
        affine_trans: Currently unused (see the note in the affine branch).
        affine_deg: ``RandomAffine`` degrees.
        elastic_disp: Passed as ``num_control_points`` of
            ``RandomElasticDeformation``.
        resample_size: Target passed to ``Resample``.
        target_shape: Shape passed to ``CropOrPad`` after resampling.

    Returns:
        Compose: the composed transform.

    Raises:
        ValueError: If ``transform`` is not a supported name. Previously
            this surfaced as an ``UnboundLocalError``.
    """
    if transform == 'blur':
        ### (1) try with different blur
        transforms = [RandomBlur(std=(blur_std, blur_std), p=p, seed=None)]
    elif transform == 'motion':
        ### (2) try with different motion artifacts
        transforms = [
            RandomMotion(degrees=motion_deg,
                         translation=motion_trans,
                         num_transforms=motion_num,
                         image_interpolation=Interpolation.LINEAR,
                         p=p,
                         seed=None),
        ]
    elif transform == 'biasfield':
        ### (3) with random bias fields
        transforms = [
            RandomBiasField(coefficients=biascoeff, order=3, p=p, seed=None)
        ]
    elif transform == 'noise':
        ### (4) try with different noise artifacts
        transforms = [
            RandomNoise(mean=0, std=(noise_std, noise_std), p=p, seed=None)
        ]
    elif transform == 'affine':
        ### (5) try with different warp (affine transformations)
        # NOTE(review): ``affine_trans`` is never passed here, so no
        # translation is applied. Confirm whether translation=affine_trans
        # was intended.
        transforms = [
            RandomAffine(scales=(1, 1),
                         degrees=affine_deg,
                         isotropic=False,
                         default_pad_value='otsu',
                         image_interpolation=Interpolation.LINEAR,
                         p=p,
                         seed=None)
        ]
    elif transform == 'elastic':
        ### (6) try with different warp (elastic transformations)
        # NOTE(review): ``elastic_disp`` (default 7.5) is used as
        # num_control_points while max_displacement is fixed at 20. The
        # argument name suggests these may be swapped — confirm.
        transforms = [
            RandomElasticDeformation(num_control_points=elastic_disp,
                                     max_displacement=20,
                                     locked_borders=2,
                                     image_interpolation=Interpolation.LINEAR,
                                     p=p,
                                     seed=None),
        ]
    elif transform == 'resample':
        transforms = [
            Resample(target=resample_size,
                     image_interpolation=Interpolation.LINEAR,
                     p=p),
            CropOrPad(target_shape=target_shape, p=1)
        ]
    else:
        raise ValueError(
            "Unknown transform {!r}; expected one of 'blur', 'motion', "
            "'biasfield', 'noise', 'affine', 'elastic', 'resample'".format(
                transform))

    return Compose(transforms)