Beispiel #1
0
 def test_transforms(self):
     """Smoke test: a long torchio pipeline must run on ``self.sample``.

     Only checks that building and applying the Compose does not raise; the
     transformed output is not inspected.
     """
     # Separate (identical) 13-point landmark arrays for the 't1' and 't2' keys.
     landmarks_dict = {key: np.linspace(0, 100, 13) for key in ('t1', 't2')}
     # One instance, used both on its own and as a OneOf candidate below.
     elastic = torchio.RandomElasticDeformation(max_displacement=1)
     pipeline = (
         torchio.CropOrPad((9, 21, 30)),
         torchio.ToCanonical(),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=2, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(masking_method='label'),
         torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({torchio.RandomAffine(): 3, elastic: 1}),
         torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
         torchio.Crop((3, 2, 8, 0, 1, 4)),
     )
     torchio.Compose(pipeline)(self.sample)
Beispiel #2
0
    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``, optionally augmented.

        When ``self.augment`` is truthy, exactly one of a random affine, a
        random elastic deformation or a random flip (weights 0.35 / 0.35 / 0.3)
        is applied jointly to image and mask. The image is returned as a float
        tensor and the mask as an int tensor, both with permuted axes.
        """
        # torchio images expect a 4D tensor, so add a singleton channel dim.
        image = self.CT_partition[idx].unsqueeze(0)
        mask = self.mask_partition[idx].unsqueeze(0)
        if self.augment:
            aug = tio.Compose([
                tio.OneOf({
                    tio.RandomAffine(scales=(0.9, 1.1, 0.9, 1.1, 1, 1),
                                     degrees=(5.0, 5.0, 0)): 0.35,
                    tio.RandomElasticDeformation(
                        num_control_points=9,
                        max_displacement=(0.1, 0.1, 0.1),
                        locked_borders=2,
                        image_interpolation='linear'): 0.35,
                    tio.RandomFlip(axes=(2,)): .3,
                }),
            ])
            # The mask must be a LabelMap, not a ScalarImage: torchio resamples
            # label maps with nearest-neighbour interpolation in spatial
            # transforms, so labels stay discrete instead of being linearly
            # blended and then truncated by the int cast below.
            subject = tio.Subject(ct=tio.ScalarImage(tensor=image),
                                  mask=tio.LabelMap(tensor=mask))
            output = aug(subject)
            image = output['ct'].data
            mask = output['mask'].data
        # The mask holds integer labels; the image is float.
        mask = mask.type(torch.IntTensor)
        image = image.type(torch.FloatTensor)

        # Assumes (TODO confirm against the data source) the tensor given to
        # torchio is C x W x H x D; permute to C x D x H x W, i.e.
        # N x 1 x D x H x W once batched by the DataLoader.
        image = image.permute(0, 3, 2, 1)
        mask = mask.permute(0, 3, 2, 1)

        return image, mask
Beispiel #3
0
 def test_no_sample(self):
     """Transforming a subject whose image file is empty must raise RuntimeError."""
     import os

     # Create an empty file and close it before torchio reads it (an open
     # NamedTemporaryFile cannot be reopened by name on every platform).
     # delete=False keeps it on disk, so it is removed explicitly in `finally`.
     with tempfile.NamedTemporaryFile(delete=False) as f:
         path = f.name
     try:
         input_dict = {'image': tio.ScalarImage(path)}
         subject = tio.Subject(input_dict)
         # NOTE(review): when RuntimeError propagates out of assertWarns,
         # unittest skips the warning check, so the UserWarning expectation
         # is not actually enforced by this nesting.
         with self.assertRaises(RuntimeError):
             with self.assertWarns(UserWarning):
                 tio.RandomFlip()(subject)
     finally:
         os.remove(path)
def get_train_transform(landmarks_path, resection_params=None):
    """Build the training-time torchio pipeline.

    Order: optional resection simulation (when ``resection_params`` is given),
    optional histogram standardization (when ``landmarks_path`` is given),
    resolution degradation, ghosting/spike/motion/bias-field artifacts,
    z-normalization, noise, spatial augmentation, and finally whatever
    ``get_tight_crop()`` returns.
    """
    # Spatial: affine 90% / elastic 10%, followed by a random flip.
    spatial = tio.Compose((
        tio.OneOf({
            tio.RandomAffine(): 0.9,
            tio.RandomElasticDeformation(): 0.1,
        }),
        tio.RandomFlip(),
    ))
    # Resolution: anisotropy or blur, applied with probability 0.75.
    resolution = tio.OneOf((tio.RandomAnisotropy(), tio.RandomBlur()), p=0.75)

    optional_steps = []
    if resection_params is not None:
        optional_steps.append(get_simulation_transform(resection_params))
    if landmarks_path is not None:
        optional_steps.append(
            tio.HistogramStandardization({'image': landmarks_path}))

    # (tio.RandomGamma(p=0.2) was considered here and left disabled.)
    main_steps = [
        resolution,
        tio.RandomGhosting(p=0.2),
        tio.RandomSpike(p=0.2),
        tio.RandomMotion(p=0.2),
        tio.RandomBiasField(p=0.5),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.RandomNoise(p=0.75),  # always after ZNorm and after blur!
        spatial,
        get_tight_crop(),
    ]
    return tio.Compose(optional_steps + main_steps)
Beispiel #5
0
def load_kidney_seg(data_shape, batch=3, workers=4, transform=None,
                    num_cases=210, data_root='data'):
    """Return a DataLoader over the kidney segmentation cases.

    Case ``i`` (``0 <= i < num_cases``) is read from
    ``<data_root>/case_<i:05d>/imaging.nii.gz`` (image) and
    ``<data_root>/case_<i:05d>/segmentation.nii.gz`` (label map).

    Each subject is intensity-clipped to [-80, 300], rescaled with
    ``RescaleIntensity((0, 1))``, resized to ``data_shape``, has its label map
    rounded, and is finally passed through ``transform``.

    Args:
        data_shape: target spatial size handed to ``interpolate``.
        batch: DataLoader batch size.
        workers: number of DataLoader worker processes.
        transform: optional extra torchio transform applied last; defaults to
            an identity placeholder (a RandomFlip with p=0).
        num_cases: number of cases to load (default keeps the original 210).
        data_root: directory containing the ``case_XXXXX`` folders.
    """
    if transform is None:
        # A flip with probability 0 never fires: an identity placeholder.
        transform = tio.RandomFlip(p=0.)
    # Preprocessing shared by every subject.
    # Clip intensity images only (presumably CT HU values — confirm).
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    # No types_to_apply: images and label maps are both resized, using
    # interpolate()'s default mode.
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    subject_list = [
        tio.Subject(
            img=tio.ScalarImage(f"{data_root}/case_{i:05d}/imaging.nii.gz"),
            label=tio.LabelMap(f"{data_root}/case_{i:05d}/segmentation.nii.gz"))
        for i in range(num_cases)
    ]
    dataset = tio.SubjectsDataset(subject_list, transform=transform)
    return DataLoader(dataset,
                      num_workers=workers,
                      batch_size=batch,
                      pin_memory=True)
Beispiel #6
0
 def test_apply_transform_to_file(self):
     """apply_transform_to_file must run a RandomFlip from the input to the output path."""
     flip = tio.RandomFlip()
     source = self.get_image_path('input')
     target = self.get_image_path('output')
     apply_transform_to_file(source, flip, target, verbose=True)
Beispiel #7
0
def load_pretrain_datasets(data_shape, batch=3, workers=4, transform=None,
                           data_path='/home/mitch/Data/MSD/'):
    """Return one DataLoader per task directory found under ``data_path``.

    Every sub-directory of ``data_path`` (in sorted order) is expected to hold
    ``imagesTr/`` and ``labelsTr/``; images and labels are paired by sorted
    file order. All subjects get the shared preprocessing (clip, rescale,
    resize, label rounding, then ``transform``); the tasks at sorted indices
    0 and 4 first have a single channel extracted.

    Args:
        data_shape: target spatial size handed to ``interpolate``.
        batch: DataLoader batch size.
        workers: number of DataLoader worker processes.
        transform: optional extra torchio transform applied last; defaults to
            an identity placeholder (a RandomFlip with p=0).
        data_path: root directory of the task folders (default keeps the
            original hard-coded location).
    """
    import os

    if transform is None:
        # A flip with probability 0 never fires: an identity placeholder.
        transform = tio.RandomFlip(p=0.)
    # Preprocessing shared by every task.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    # Tasks with multi-channel images: keep one channel before preprocessing.
    # NOTE(review): x[:, :, :, k] indexes the last axis of the torchio tensor;
    # confirm it selects the intended modality for the installed torchio.
    braintransform = tio.Compose([
        Lambda(lambda x: torch.unsqueeze(x[:, :, :, 2], dim=0),
               types_to_apply=[tio.INTENSITY]),
        transform,
    ])
    prostatetransform = tio.Compose([
        Lambda(lambda x: torch.unsqueeze(x[:, :, :, 1], dim=0),
               types_to_apply=[tio.INTENSITY]),
        transform,
    ])
    # Assumes (TODO confirm) sorted order puts the brain task at index 0 and
    # the prostate task at index 4.
    special_transforms = {0: braintransform, 4: prostatetransform}

    # Trailing '' makes glob match directories only (paths end with a separator).
    directories = sorted(glob.glob(os.path.join(data_path, '*', '')))
    loaders = []
    for i, directory in enumerate(directories):
        images = sorted(glob.glob(os.path.join(directory, 'imagesTr', '*')))
        segs = sorted(glob.glob(os.path.join(directory, 'labelsTr', '*')))
        # NOTE: zip() silently drops unpaired files if the counts differ.
        subject_list = [
            tio.Subject(img=tio.ScalarImage(image), label=tio.LabelMap(seg))
            for image, seg in zip(images, segs)
        ]
        dataset = tio.SubjectsDataset(
            subject_list, transform=special_transforms.get(i, transform))
        loaders.append(
            DataLoader(dataset,
                       num_workers=workers,
                       batch_size=batch,
                       pin_memory=True))

    return loaders
Beispiel #8
0
    def __init__(self, use_tio_flip=True):
        """Build a flip -> affine -> noise augmentation pipeline.

        Args:
            use_tio_flip: when truthy, flip with ``tio.RandomFlip(p=0.5)``;
                otherwise use ``Flip3D()``.
        """
        # Truthiness rather than `is True`: numpy.bool_(True) or 1 must also
        # select the torchio flip instead of silently falling back to Flip3D.
        self.flip = tio.RandomFlip(p=0.5) if use_tio_flip else Flip3D()
        self.affine = tio.RandomAffine(p=0.5,
                                       scales=0.1,
                                       degrees=5,
                                       translation=0,
                                       image_interpolation="nearest")
        self.random_noise = tio.RandomNoise(
            p=0.5, std=(0, 0.1), include=["x"])  # don't apply noise to mask
        # The composed pipeline only touches the "x" and "y" subject entries.
        self.transform = tio.Compose(
            [self.flip, self.affine, self.random_noise], include=["x", "y"])
Beispiel #9
0
def get_transform(augmentation, landmarks_path):
    """Return the torchio transform for the given augmentation flag.

    Truthy ``augmentation``: ``datasets.get_train_transform(landmarks_path)``.
    Otherwise: ``datasets.get_test_transform(landmarks_path)`` followed by a
    random flip, then a random affine (weight 0.8) or a random elastic
    deformation (weight 0.2).
    """
    import datasets
    import torchio as tio

    if augmentation:
        return datasets.get_train_transform(landmarks_path)

    # NOTE(review): this falsy-`augmentation` branch still applies random
    # spatial transforms after preprocessing — confirm that is intended.
    preprocess = datasets.get_test_transform(landmarks_path)
    flip = tio.RandomFlip()
    deform = tio.OneOf({
        tio.RandomAffine(): 0.8,
        tio.RandomElasticDeformation(): 0.2,
    })
    return tio.Compose((preprocess, tio.Compose((flip, deform))))
Beispiel #10
0
def main(hdf_file, plot_dir):
    """Plot one fixed slice sample without augmentation and with three libraries.

    Creates ``plot_dir`` if needed and calls ``plot_sample`` with the tags
    'none', 'pymia', 'batchgenerators' and 'torchio', in that order.
    """
    os.makedirs(plot_dir, exist_ok=True)

    # Slice-wise data source over the images and labels in the HDF5 file.
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
    dataset = extr.PymiaDatasource(hdf_file, extr.SliceIndexing(), extractor)

    seed = 1
    np.random.seed(seed)
    sample_idx = 55

    # Wrapped around every augmentation: channel dimension first on the way in,
    # channel dimension removed from the labels on the way out.
    pre = [tfm.Permute(permutation=(2, 0, 1))]
    post = [tfm.Squeeze(entries=(defs.KEY_LABELS,))]

    def show(augmentation, tag):
        # Install pre + augmentation + post, fetch the fixed sample, plot it.
        dataset.set_transform(tfm.ComposeTransform(pre + augmentation + post))
        plot_sample(plot_dir, tag, dataset[sample_idx])

    # Baseline without augmentation.
    show([], 'none')

    # pymia
    show([augm.RandomRotation90(axes=(-2, -1)), augm.RandomMirror()], 'pymia')

    # batchgenerators
    show([BatchgeneratorsTransform([
        bg_tfm.spatial_transforms.MirrorTransform(axes=(0, 1), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
        bg_tfm.noise_transforms.GaussianBlurTransform(blur_sigma=(0.2, 1.0), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
    ])], 'batchgenerators')

    # TorchIO
    # NOTE(review): `axes` is the plain string 'LR', not a one-element tuple,
    # and keys=/seed= look like an older TorchIO API (other code here uses
    # include=) — confirm against the pinned TorchIO version.
    show([TorchIOTransform([
        tio.RandomFlip(axes='LR', flip_probability=1.0, keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
        tio.RandomAffine(scales=(0.9, 1.2), degrees=10, isotropic=False, default_pad_value='otsu',
                         image_interpolation='NEAREST', keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
    ])], 'torchio')
Beispiel #11
0
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    """Build a preprocessing (plus optional test-time augmentation) dataset.

    The single subject at ``input_path`` is reoriented to canonical,
    histogram-standardized with fixed landmarks, z-normalized, resampled when
    needed, and cropped so every spatial size is a multiple of 8.

    Args:
        input_path: path of the input image.
        tta_iterations: number of extra randomly flipped + affine-augmented
            copies; when it yields no copies, only the preprocessed dataset
            is returned.
        interpolation: interpolation used for resampling and the random affine.
        tolerance: resample when any spatial voxel size differs from 1
            (presumably mm) by more than this amount.
        mni_transform_path: optional matrix file; when given it is stored on
            the image under ``TO_MNI`` and used as the pre-affine when
            resampling onto the Colin27 T1 template.

    Returns:
        A ``tio.SubjectsDataset`` with the single preprocessed subject, or a
        ``ConcatDataset`` of it followed by ``tta_iterations`` augmented copies.
    """
    if mni_transform_path is None:
        image = tio.ScalarImage(input_path)
    else:
        affine = tio.io.read_matrix(mni_transform_path)
        image = tio.ScalarImage(input_path, **{TO_MNI: affine})
    subject = tio.Subject({IMAGE_NAME: image})
    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.
    ])
    hist_std = tio.HistogramStandardization({IMAGE_NAME: landmarks})
    preprocess_transforms = [
        tio.ToCanonical(),
        hist_std,
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]
    # Only the three spatial voxel sizes matter: a 4D header also reports a
    # fourth (e.g. temporal) zoom, which must not trigger resampling.
    zooms = nib.load(input_path).header.get_zooms()[:3]
    pixdim = np.array(zooms)
    diff_to_1_iso = np.abs(pixdim - 1)
    if np.any(diff_to_1_iso > tolerance) or mni_transform_path is not None:
        kwargs = {'image_interpolation': interpolation}
        if mni_transform_path is not None:
            kwargs['pre_affine_name'] = TO_MNI
            kwargs['target'] = tio.datasets.Colin27().t1.path
        preprocess_transforms.append(tio.Resample(**kwargs))
    preprocess_transforms.append(tio.EnsureShapeMultiple(8, method='crop'))
    preprocess_transform = tio.Compose(preprocess_transforms)
    no_aug_dataset = tio.SubjectsDataset([subject],
                                         transform=preprocess_transform)

    aug_subjects = tta_iterations * [subject]
    if not aug_subjects:
        return no_aug_dataset
    augment_transform = tio.Compose((
        preprocess_transform,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    aug_dataset = tio.SubjectsDataset(aug_subjects,
                                      transform=augment_transform)
    return torch.utils.data.ConcatDataset((no_aug_dataset, aug_dataset))
Beispiel #12
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Return a tio.Compose chaining a broad set of TorchIO transforms.

     ``is_3d=False`` selects an alternative set of shape arguments (a size-1
     last axis for CropOrPad/Resize, only axes 0 and 1 for flips/downsampling).
     When ``labels`` is truthy a RandomLabelsToImage on the 'label' key is
     appended at the end.
     """
     landmarks_dict = {name: np.linspace(0, 100, 13) for name in channels}
     if is_3d:
         disp = 1
         cp_args = (9, 21, 30)
         resize_args = (10, 20, 30)
         flip_axes = (0, 1, 2)
         swap_patch = (2, 3, 4)
         pad_args = (1, 2, 3, 0, 5, 6)
         crop_args = (3, 2, 8, 0, 1, 4)
     else:
         disp = (1, 1, 0.01)
         cp_args = (21, 30, 1)
         resize_args = (10, 20, 1)
         flip_axes = (0, 1)
         swap_patch = (3, 4, 1)
         pad_args = (0, 0, 3, 0, 5, 6)
         crop_args = (0, 0, 8, 0, 1, 4)
     # One instance, used both on its own and as a OneOf candidate.
     elastic = tio.RandomElasticDeformation(max_displacement=disp)
     remapping = {1: 2, 2: 1, 3: 20, 4: 25}
     steps = [
         tio.CropOrPad(cp_args),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(resize_args),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=flip_axes),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=flip_axes, flip_probability=1),
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks_dict),
         elastic,
         tio.RandomAffine(),
         tio.OneOf({tio.RandomAffine(): 3, elastic: 1}),
         tio.RemapLabels(remapping=remapping, masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         tio.Pad(pad_args, padding_mode=3),
         tio.Crop(crop_args),
     ]
     if labels:
         steps.append(tio.RandomLabelsToImage(label_key='label'))
     return tio.Compose(steps)
def get_context(device, variables, augmentation_mode, **kwargs):
    """Extend the base context with the selected training augmentation.

    Element 1 of the training ``tio.Compose`` (the augmentation slot) is
    removed for ``'no_augmentation'``. For ``'standard'``,
    ``'dwi_reconstruction'`` or ``'combined'`` it is replaced by the
    corresponding augmentation.

    Raises:
        ValueError: if ``augmentation_mode`` is none of the above.
    """
    context = base_config.get_context(device, variables, **kwargs)
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'augmentation_mode': augmentation_mode})

    # The training transform is a tio.Compose whose second element is the
    # augmentation.
    dataset_definition = context.get_component_definition("dataset")
    training_transform = dataset_definition['params']['transforms']['training']

    dwi_augmentation = ReconstructMeanDWI(num_dwis=(1, 7),
                                          num_directions=(1, 3),
                                          directionality=(4, 10))

    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    standard_augmentations = tio.Compose(
        [
            tio.RandomFlip(axes=(0, 1, 2)),
            tio.RandomElasticDeformation(p=0.5,
                                         num_control_points=(7, 7, 4),
                                         locked_borders=1,
                                         image_interpolation='bspline',
                                         exclude="full_dwi"),
            tio.RandomBiasField(p=0.5),
            tio.RescaleIntensity((0, 1), (0.01, 99.9)),
            tio.RandomGamma(p=0.8),
            tio.RescaleIntensity((-1, 1)),
            # Blur and noise, applied in one of the two orders.
            tio.OneOf([
                tio.Compose([blur, noise]),
                tio.Compose([noise, blur]),
            ]),
        ],
        exclude="full_dwi",
    )

    if augmentation_mode == 'no_augmentation':
        training_transform.transforms.pop(1)
        return context

    if augmentation_mode == 'standard':
        replacement = standard_augmentations
    elif augmentation_mode == 'dwi_reconstruction':
        replacement = dwi_augmentation
    elif augmentation_mode == 'combined':
        replacement = tio.Compose([dwi_augmentation, standard_augmentations])
    else:
        raise ValueError(f"Invalid augmentation mode {augmentation_mode}")
    training_transform.transforms[1] = replacement
    return context
Beispiel #14
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Return a torchio.Compose chaining a broad set of transforms.

     ``is_3d=False`` swaps in alternative shape arguments (size-1 last axis
     for CropOrPad, axes 0 and 1 only for flips/downsampling). When ``labels``
     is truthy a RandomLabelsToImage on the 'label' key is appended.
     """
     landmarks_dict = {}
     for channel in channels:
         landmarks_dict[channel] = np.linspace(0, 100, 13)

     if is_3d:
         disp, cp_args, axes = 1, (9, 21, 30), (0, 1, 2)
         swap_patch = (2, 3, 4)
         pad_args, crop_args = (1, 2, 3, 0, 5, 6), (3, 2, 8, 0, 1, 4)
     else:
         disp, cp_args, axes = (1, 1, 0.01), (21, 30, 1), (0, 1)
         swap_patch = (3, 4, 1)
         pad_args, crop_args = (0, 0, 3, 0, 5, 6), (0, 0, 8, 0, 1, 4)

     # Shared instance: used directly and inside OneOf.
     elastic = torchio.RandomElasticDeformation(max_displacement=disp)
     transforms = [
         torchio.CropOrPad(cp_args),
         torchio.ToCanonical(),
         torchio.RandomDownsample(downsampling=(1.75, 2), axes=axes),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=axes, flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(),
         torchio.HistogramStandardization(landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({torchio.RandomAffine(): 3, elastic: 1}),
         torchio.Pad(pad_args, padding_mode=3),
         torchio.Crop(crop_args),
     ]
     if labels:
         transforms.append(torchio.RandomLabelsToImage(label_key='label'))
     return torchio.Compose(transforms)
Beispiel #15
0
    subjects.append(subject)

# Build the torchio SubjectsDataset from the patient subjects.
dataset = tio.SubjectsDataset(subjects)

# Collapse the mask to a single class (`one` is defined elsewhere).
if config.to_one_class:
    for subject in dataset.dry_iter():
        subject['mask'] = one(subject['mask'])

training_transform = tio.Compose([
    tio.Resample(4),
    # Everyone on the torchio forums recommended this normalization.
    tio.ZNormalization(
        masking_method=tio.ZNormalization.mean
    ),
    tio.RandomFlip(p=0.25),
    tio.RandomNoise(p=0.25),
    # !!! Tensors have to be cast to float explicitly.
    tio.Lambda(to_float)
])

# Validation uses the same deterministic preprocessing as training but no
# random augmentation (no RandomNoise), so validation metrics are
# reproducible across epochs.
validation_transform = tio.Compose([
    tio.Resample(4),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    tio.Lambda(to_float)
])


def prepare_dataload(patches=True):
Beispiel #16
0
    def __getitem__(self, index):
        """Load, augment and return the OCT volume at row ``index`` of ``self.df``.

        Column 0 of the row is the ``.npy`` path and column 1 the class label.
        The volume is depth-subsampled, then (in about 19 of 20 calls) randomly
        flipped, vertically shifted and noised before resampling with torchio.
        It is then optionally resized and augmented slice by slice with imgaug,
        and scaled to [0, 1].

        Returns:
            ``(array, label)``: a float32 array shaped (D, H, W, 1), or
            (1, D, H, W) when ``self.channel_first`` is set, and an int label.

        Raises:
            FileNotFoundError: if the ``.npy`` file does not exist.
        """
        file_npy = self.df.iloc[index, 0]
        # Explicit check rather than assert, so it survives `python -O`.
        if not os.path.exists(file_npy):
            raise FileNotFoundError(f'npy file {file_npy} does not exists')
        array_npy = np.load(file_npy)  # expected (D, H, W)
        if array_npy.ndim > 3:
            array_npy = np.squeeze(array_npy)
        array_npy = np.expand_dims(array_npy, axis=0)  # (C, D, H, W)

        # Keep every depth_interval-th slice from a uniformly random offset,
        # e.g. depth_interval == 2: (128, 128, 128) -> (64, 128, 128).
        depth_start = random.randrange(self.depth_interval)
        array_npy = array_npy[:, depth_start::self.depth_interval, :, :]

        subject = tio.Subject(oct=tio.ScalarImage(tensor=array_npy))

        # Crop crop_h rows at the start of H (per the original (d, h, w)
        # reading of the bounds), then pad pad_h_a + pad_h_b == crop_h rows
        # back, so H is unchanged and the content is shifted vertically.
        crop_h = random.randint(0, self.random_crop_h)
        pad_h_a = random.randint(0, crop_h)
        pad_h_b = crop_h - pad_h_a

        if random.randint(1, 20) == 5:
            # About 1 call in 20: resampling only, no augmentation.
            transform = tio.Compose([tio.Resample(self.resample_ratio)])
        else:
            transform = tio.Compose([
                tio.RandomFlip(axes=2, flip_probability=0.5),
                tio.Crop(cropping=(0, 0, crop_h, 0, 0, 0)),
                tio.Pad(padding=(0, 0, pad_h_a, pad_h_b, 0, 0)),
                tio.RandomNoise(std=(0, self.random_noise)),
                tio.Resample(self.resample_ratio),
            ])

        subjects_dataset = tio.SubjectsDataset([subject], transform=transform)
        inputs = subjects_dataset[0]['oct'][tio.DATA]
        array_3d = np.squeeze(inputs.cpu().numpy()).astype(np.uint8)  # (D, H, W)

        iaa = self.imgaug_iaa
        target_hw = (None if self.image_shape is None
                     else tuple(self.image_shape[0:2]))
        if iaa is None and (target_hw is None
                            or tuple(array_3d.shape[1:3]) == target_hw):
            # Nothing to resize or augment per slice.
            array_4d = np.expand_dims(array_3d, axis=-1)  # (D, H, W, C)
        else:
            if iaa is not None:
                # Intended to give every slice of this volume the same
                # augmentation. NOTE(review): confirm that samples still differ
                # across calls with imgaug's `deterministic` flag; the usual
                # per-sample idiom is `iaa.to_deterministic()`.
                iaa.deterministic = True
            try:
                slices = []
                for img in array_3d:  # each img is one (H, W) slice
                    if target_hw is not None and img.shape[0:2] != target_hw:
                        # cv2.resize takes dsize as (width, height).
                        img = cv2.resize(img, (target_hw[1], target_hw[0]))
                    # cvtColor does not support float64.
                    img = cv2.cvtColor(img.astype(np.float32),
                                       cv2.COLOR_GRAY2BGR)
                    # Otherwise imgaug's MultiplyBrightness errors.
                    img = img.astype(np.uint8)
                    if iaa is not None:
                        img = iaa(image=img)
                    slices.append(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
                array_4d = np.expand_dims(np.array(slices), axis=-1)  # (D, H, W, C)
            finally:
                # Restore the augmenter even if a slice fails.
                if iaa is not None:
                    iaa.deterministic = False

        if self.channel_first:
            array_4d = np.transpose(array_4d, (3, 0, 1, 2))  # (D,H,W,C)->(C,D,H,W)

        array_4d = array_4d.astype(np.float32) / 255.

        # Returns numpy rather than (CUDA) tensors: PyTorch advises against
        # returning CUDA tensors from multi-process data loading.
        # https://pytorch.org/docs/stable/data.html
        label = int(self.df.iloc[index, 1])

        return array_4d, label
Beispiel #17
0
# Apply the affine transform defined earlier (presumably zoom-only, per its
# name) to the grid slice and display it.
slice_affine = random_affine_zoom(slice_grid)
to_pil(slice_affine)

# Random flip
# Flipping images is a very cheap way to perform data augmentation.
# In medical images, it is very common to flip the images horizontally.
# We can specify the dimension indices when instantiating a RandomFlip transform.
# However, if we don't know the image orientation, we can't know which dimension
# corresponds to the lateral axis. In TorchIO, you can use anatomical
# labels instead, so you don't need to figure out the image orientation
# to know which axis you would like to flip. To make sure the transform modifies
# the image, we use the inferior-superior (longitudinal) axis and a flip
# probability of 1. If the flip happened along any other axis, we might not
# notice it in this visualization.

random_flip = tio.RandomFlip(axes=['inferior-superior'], flip_probability=1)
fpg_flipped = random_flip(fpg_ras)
show_fpg(fpg_flipped)

# Random elastic deformation
# To simulate anatomical variations in our images, we can apply a non-linear
# deformation using RandomElasticDeformation.

max_displacement = 10, 10, 0  # in x, y and z directions
# NOTE(review): the `seed` argument is not accepted by every TorchIO version —
# confirm against the installed release.
random_elastic = tio.RandomElasticDeformation(max_displacement=max_displacement, seed=0)
slice_elastic = random_elastic(slice_grid)
to_pil(slice_elastic)

# As explained in the documentation, the number of grid control points can be
# changed to set the smoothness of the deformation.

random_elastic = tio.RandomElasticDeformation(
Beispiel #18
0
 def test_tensor_flip(self):
     """RandomFlip must accept a bare 4D torch tensor (C, X, Y, Z)."""
     ones = torch.ones((4, 30, 30, 30))
     flip = tio.RandomFlip()
     flip(ones)
Beispiel #19
0
def _build_augmentation(aug_type):
    """Return the ``tio.Compose`` augmentation pipeline named by ``aug_type``.

    Raises:
        ValueError: if ``aug_type`` is not 'aug0', 'aug1', 'aug2', 'aug4'
            or 'aug5'.
    """
    if aug_type == 'aug0':
        return tio.Compose([])
    if aug_type == 'aug1':
        return tio.Compose([tio.RandomFlip(axes=(0, 1, 2))])
    if aug_type == 'aug2':
        return tio.Compose([
            tio.RandomFlip(axes=(0, 1, 2)),
            tio.RandomNoise(mean=0, std=0.1),
            tio.RandomBlur(std=(2.5, 2.5, 0.0))
        ])
    if aug_type == 'aug4':
        return tio.Compose([
            tio.RandomAffine(degrees=0,
                             scales=(0.15, 0.15, 0),
                             translation=(40, 40, 0),
                             default_pad_value='minimum',
                             image_interpolation='linear'),
            tio.RandomFlip(axes=(0, 1, 2)),
            tio.RandomNoise(mean=0, std=0.1),
            tio.RandomBlur(std=(2.5, 2.5, 0.0))
        ])
    if aug_type == 'aug5':
        return tio.Compose([
            tio.RandomFlip(axes=(0, 1, 2)),
            tio.RandomAffine(degrees=(0, 0, 0, 0, -10, 10),
                             scales=0,
                             translation=0,
                             center='image',
                             default_pad_value='minimum',
                             image_interpolation='linear')
        ])
    raise ValueError(
        f"Unknown aug_type {aug_type!r}; expected one of "
        "'aug0', 'aug1', 'aug2', 'aug4', 'aug5'")


def _load_subjects(data_folder, limit=None):
    """Return one ``tio.Subject`` (key ``image``) per ``*.mha`` file in ``<data_folder>/images``.

    Files are taken in sorted path order; ``limit`` keeps only the first
    ``limit`` files (``None`` keeps all of them).
    """
    images_dir = Path(os.path.join(data_folder, 'images'))
    image_paths = sorted(images_dir.glob('*.mha'))[:limit]
    return [tio.Subject(image=tio.ScalarImage(path)) for path in image_paths]


def data_loader(train_data_folder=None,
                validation_data_folder=None,
                test_data_folder=None,
                debug_data_folder=None,
                num_workers=1,
                aug_type='aug0'):
    """Build DataLoaders for the train / validation / test / debug folders.

    Each folder holds its ``.mha`` images under ``images/``; a ``None`` folder
    yields a ``None`` loader. The ``aug_type`` augmentation is applied to the
    training and debug sets only, and the debug set holds just the first
    (sorted) image of its folder. All loaders use ``batch_size=1``; only the
    training loader shuffles.

    Returns:
        ``(train_loader, validation_loader, test_loader, debug_loader)``.

    Raises:
        FileNotFoundError: if the training or debug folder has no ``.mha`` image.
        ValueError: if ``aug_type`` is unknown and a training or debug folder
            is given.
    """
    train_loader = validation_loader = test_loader = debug_loader = None

    if train_data_folder is not None:
        train_subjects = _load_subjects(train_data_folder)
        if not train_subjects:
            raise FileNotFoundError(
                f"No .mha images found in {os.path.join(train_data_folder, 'images')}")
        # Accessing .data loads the first image from disk.
        print("Shape of training image before loading is: " +
              str(train_subjects[0].image.data.shape))
        train_set = tio.SubjectsDataset(train_subjects,
                                        transform=_build_augmentation(aug_type))
        print('Training set:', len(train_set), 'subjects')
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=1,
                                                   shuffle=True,
                                                   num_workers=num_workers)

    if validation_data_folder is not None:
        validation_set = tio.SubjectsDataset(
            _load_subjects(validation_data_folder), transform=tio.Compose([]))
        print('Validation set:', len(validation_set), 'subjects')
        validation_loader = torch.utils.data.DataLoader(
            validation_set, batch_size=1, num_workers=num_workers)

    if test_data_folder is not None:
        test_set = tio.SubjectsDataset(_load_subjects(test_data_folder),
                                       transform=tio.Compose([]))
        print('Test set:', len(test_set), 'subjects')
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=1,
                                                  num_workers=num_workers)

    if debug_data_folder is not None:
        # Only the first image (in sorted order) is used for debugging.
        debug_subjects = _load_subjects(debug_data_folder, limit=1)
        if not debug_subjects:
            raise FileNotFoundError(
                f"No .mha images found in {os.path.join(debug_data_folder, 'images')}")
        print("Shape of debug image before loading is: " +
              str(debug_subjects[0].image.data.shape))
        debug_set = tio.SubjectsDataset(debug_subjects,
                                        transform=_build_augmentation(aug_type))
        print('Debug set:', len(debug_set), 'subjects')
        debug_loader = torch.utils.data.DataLoader(debug_set,
                                                   batch_size=1,
                                                   num_workers=num_workers)

    return train_loader, validation_loader, test_loader, debug_loader
Beispiel #20
0
def get_context(device, variables, fold=0, **kwargs):
    """Build the training context for the "msseg2" FLAIR lesion-segmentation task.

    Registers on a ``TorchContext`` named ``"msseg2"``:
    - a ``SubjectFolder`` dataset;
    - a ``ModularUNet`` (2 input channels, 2 output channels);
    - an SGD optimizer;
    - a ``HybridLogisticDiceLoss`` criterion;
    - a ``SegmentationTrainer``.

    Args:
        device: forwarded unchanged to ``TorchContext``.
        variables: forwarded unchanged to ``TorchContext`` (presumably used to
            resolve placeholders such as ``'$DATASET_PATH'`` — confirm).
        fold: which of the 5 seeded folds is held out as the ``'validation'``
            cohort; every other subject forms the ``'training'`` cohort.
        **kwargs: accepted for interface compatibility; not used.

    Returns:
        The populated ``TorchContext``.
    """
    context = TorchContext(device, name="msseg2", variables=variables)
    # NOTE(review): presumably lets the framework archive this config file
    # together with the run — confirm against TorchContext.
    context.file_paths.append(os.path.abspath(__file__))
    # Local alias kept because patch_size is reused for padding, patch
    # prediction and patch sampling below.
    context.config = config = {'fold': fold, 'patch_size': 96}

    # Network inputs: the two FLAIR time points.
    input_images = ["flair_time01", "flair_time02"]

    subject_loader = ComposeLoaders([
        ImageLoader(glob_pattern="flair_time01*",
                    image_name='flair_time01',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="flair_time02*",
                    image_name='flair_time02',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="brain_mask.*",
                    image_name='brain_mask',
                    image_constructor=tio.LabelMap,
                    label_values={"brain": 1}),
        ImageLoader(glob_pattern="ground_truth.*",
                    image_name="ground_truth",
                    image_constructor=tio.LabelMap,
                    label_values={"lesion": 1}),
    ])

    cohorts = {}
    # Only subjects providing both FLAIR time points are used at all.
    cohorts['all'] = RequireAttributes(input_images)
    # Seeded 5-fold split: fold `fold` is validation, the rest is training.
    cohorts['validation'] = RandomFoldFilter(num_folds=5,
                                             selection=fold,
                                             seed=0xDEADBEEF)
    cohorts['training'] = NegateFilter(cohorts['validation'])

    # Pre-processing shared by every cohort (runs before augmentation).
    common_transforms_1 = tio.Compose([
        SetDataType(torch.float),
        EnforceConsistentAffine(source_image_name='flair_time01'),
        TargetResample(target_spacing=1, tolerance=0.11),
        CropToMask('brain_mask'),
        # Presumably ensures each volume is at least one patch in size —
        # confirm against MinSizePad.
        MinSizePad(config['patch_size'])
    ])

    # Training-only augmentations.
    augmentations = tio.Compose([
        RandomPermuteDimensions(),
        tio.RandomFlip(axes=(0, 1, 2)),
        # With probability 0.75 apply one spatial warp: elastic (weight 0.2)
        # or affine (weight 0.8).
        tio.OneOf(
            {
                tio.RandomElasticDeformation(): 0.2,
                tio.RandomAffine(scales=0.2,
                                 degrees=45,
                                 default_pad_value='otsu'): 0.8,
            },
            p=0.75),
        tio.RandomBiasField(p=0.5),
        # Gamma is applied on [0, 1] intensities (presumably because gamma
        # needs non-negative input), then mapped back to [-1, 1].
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        tio.RandomBlur((0, 1), p=0.2),
        tio.RandomNoise(std=0.1, p=0.35)
    ])

    # Post-processing shared by every cohort: build the 2-channel input "X"
    # and the one-hot target "y".
    common_transforms_2 = tio.Compose([
        tio.RescaleIntensity((-1, 1.), (0.05, 99.5)),
        ConcatenateImages(image_names=["flair_time01", "flair_time02"],
                          image_channels=[1, 1],
                          new_image_name="X"),
        RenameProperty(old_name='ground_truth', new_name='y'),
        CustomOneHot(include="y"),
    ])

    transforms = {
        'default':
        tio.Compose([common_transforms_1, common_transforms_2]),
        'training':
        tio.Compose([
            common_transforms_1, augmentations, common_transforms_2,
            # Map consumed by the WeightedSampler below: brain voxels weigh 1,
            # lesion voxels 100 (presumably to oversample lesion patches).
            ImageFromLabels(new_image_name="patch_probability",
                            label_weights=[('brain_mask', 'brain', 1),
                                           ('y', 'lesion', 100)])
        ]),
    }

    context.add_component("dataset",
                          SubjectFolder,
                          root='$DATASET_PATH',
                          subject_path="",
                          subject_loader=subject_loader,
                          cohorts=cohorts,
                          transforms=transforms)
    context.add_component("model",
                          ModularUNet,
                          in_channels=2,
                          out_channels=2,
                          filters=[40, 40, 80, 80, 120, 120],
                          depth=6,
                          block_params={'residual': True},
                          downsample_class=BlurConv3d,
                          downsample_params={
                              'kernel_size': 3,
                              'stride': 2,
                              'padding': 1
                          },
                          upsample_class=BlurConvTranspose3d,
                          upsample_params={
                              'kernel_size': 3,
                              'stride': 2,
                              'padding': 1,
                              'output_padding': 0
                          })
    context.add_component("optimizer",
                          SGD,
                          params="self.model.parameters()",
                          lr=0.001,
                          momentum=0.95)
    context.add_component("criterion",
                          HybridLogisticDiceLoss,
                          logistic_class_weights=[1, 100])

    training_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            'y_pred_eval', 'y_eval'),
                            log_name='training_segmentation_eval',
                            interval=15),
        ScheduledEvaluation(evaluator=LabelMapEvaluator('y_pred_eval'),
                            log_name='training_label_eval',
                            interval=15),
        # NOTE(review): shares log_name "contour_image" with the validation
        # contour evaluator below; confirm the trainer keeps them apart.
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "random",
            'flair_time02',
            'y_pred_eval',
            'y_eval',
            slice_id=0,
            legend=True,
            ncol=2,
            interesting_slice=True,
            split_subjects=False),
                            log_name="contour_image",
                            interval=15),
    ]

    validation_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            "y_pred_eval", "y_eval"),
                            log_name="segmentation_eval",
                            cohorts=["validation"],
                            interval=50),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "interesting",
            'flair_time02',
            'y_pred_eval',
            'y_eval',
            slice_id=0,
            legend=True,
            ncol=1,
            interesting_slice=True,
            split_subjects=True),
                            log_name="contour_image",
                            cohorts=["validation"],
                            interval=50),
    ]

    def scoring_function(evaluation_dict):
        """Return the mean per-subject lesion Dice on the validation cohort."""
        # Grab the output of the SegmentationEvaluator
        seg_eval = evaluation_dict['segmentation_eval']['validation']

        # Take mean dice, while accounting for subjects which have no lesions.
        # Dice is 0/0 = nan when the model correctly outputs no lesions. This is counted as a score of 1.0.
        # Dice is (>0)/0 = posinf when the model incorrectly predicts lesions when there are none.
        # This is counted as a score of 0.0.
        dice = torch.tensor(seg_eval["subject_stats"]['dice.lesion'])
        dice = dice.nan_to_num(nan=1.0, posinf=0.0)
        score = dice.mean()

        return score

    train_predictor = StandardPredict(image_names=['X', 'y'])
    # Validation volumes are predicted patch-wise with 1/8-patch overlap,
    # averaging overlapping predictions.
    validation_predictor = PatchPredict(patch_batch_size=32,
                                        patch_size=config['patch_size'],
                                        patch_overlap=config['patch_size'] // 8,
                                        padding_mode=None,
                                        overlap_mode='average',
                                        image_names=['X'])

    patch_sampler = tio.WeightedSampler(patch_size=config['patch_size'],
                                        probability_map='patch_probability')
    train_dataloader_factory = PatchDataLoader(max_length=100,
                                               samples_per_volume=1,
                                               sampler=patch_sampler)
    validation_dataloader_factory = StandardDataLoader(
        sampler=SequentialSampler)

    context.add_component(
        "trainer",
        SegmentationTrainer,
        training_batch_size=4,
        save_rate=100,
        scoring_interval=50,
        scoring_function=scoring_function,
        one_time_evaluators=[],
        training_evaluators=training_evaluators,
        validation_evaluators=validation_evaluators,
        max_iterations_with_no_improvement=2000,
        train_predictor=train_predictor,
        validation_predictor=validation_predictor,
        train_dataloader_factory=train_dataloader_factory,
        validation_dataloader_factory=validation_dataloader_factory)

    return context
Beispiel #21
0
def _as_patch(array, patch_shape):
    """Copy *array* into a freshly allocated float32 buffer of shape ``(1, 1, *patch_shape)``.

    A new buffer is allocated on every call so results can be collected in a
    list without aliasing. ``array`` is written into ``buffer[0]``, so it is
    cast/broadcast exactly as a NumPy slice assignment would do.
    """
    buffer = np.zeros(shape=(1, 1, *patch_shape), dtype=np.float32)
    buffer[0] = array
    return buffer


def dataloader(handles, mode='train', patch_shape=(7, 1024, 1024)):
    """Load the image/label/gap arrays for *mode*, building and caching them on first use.

    If ``../inputs/flpickles/<mode>.pickle`` can be read, its contents are
    returned directly. Otherwise every row of ``handles`` is read from disk.
    In ``'train'`` mode it is also augmented. The result is pickled to that
    path and returned.

    Augmentation (``mode == 'train'`` only): 4 spatial transforms (two flips,
    affine, elastic) are applied ONCE to the channel-stacked image, label and
    gap. Random draws such as "flip or not" are therefore shared and the
    label stays aligned with the image. 8 intensity transforms are then
    applied to the image only; the label and gap are repeated unchanged for
    them. Each subject yields 13 samples in train mode and 1 sample otherwise.

    Args:
        handles: table iterated with ``iterrows()`` (presumably a pandas
            DataFrame). Each row needs ``'Image'`` and ``'ID'``, plus
            ``'Label'`` and ``'Gap'`` unless ``mode == 'test'``.
        mode: split name. ``'train'`` enables augmentation. ``'test'`` skips
            labels and gaps, and ``images['Label']`` / ``images['Gap']`` then
            stay empty.
        patch_shape: spatial shape every sample is written into (default
            ``(7, 1024, 1024)``). The cache file name does not depend on it,
            so delete stale caches after changing it.

    Returns:
        dict with keys ``'Image'``, ``'Label'``, ``'Gap'`` and ``'ID'``. The
        first three hold one float32 array per subject of shape
        ``(n_samples, 1, 1, *patch_shape)``.
    """
    import os

    cache_path = '../inputs/flpickles/' + mode + '.pickle'

    # Fast path: reuse a cache built by an earlier call.
    # NOTE(review): pickle.load can execute arbitrary code; only ever point
    # this at caches written by this function.
    try:
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError):
        # No usable cache yet: rebuild it below.
        pass

    # Spatial augmentations: applied jointly to image, label and gap.
    random_flip = tio.RandomFlip(axes=1)
    random_flip2 = tio.RandomFlip(axes=2)
    random_affine = tio.RandomAffine(seed=0, scales=(3, 3))
    random_elastic = tio.RandomElasticDeformation(
        max_displacement=(0, 20, 40),
        num_control_points=20,
        seed=0,
    )
    spatial_transforms = (random_flip, random_flip2, random_affine,
                          random_elastic)

    # Intensity augmentations: applied to the image only.
    rescale = tio.RescaleIntensity((-1, 1), percentiles=(1, 99))
    standardize_foreground = tio.ZNormalization(
        masking_method=lambda x: x > x.mean())
    blur = tio.RandomBlur(seed=0)
    standardize = tio.ZNormalization()
    add_noise = tio.RandomNoise(std=0.5, seed=42)
    add_spike = tio.RandomSpike(seed=42)
    add_ghosts = tio.RandomGhosting(intensity=1.5, seed=42)
    add_motion = tio.RandomMotion(num_transforms=6,
                                  image_interpolation='nearest', seed=42)
    swap = tio.RandomSwap(patch_size=7)
    intensity_transforms = (
        rescale,
        standardize_foreground,
        blur,
        lambda x: add_noise(standardize(x)),
        add_spike,
        add_ghosts,
        add_motion,
        swap,
    )

    images = {'Image': [], 'Label': [], 'Gap': [], 'ID': []}
    with_targets = mode != 'test'

    for _, row in handles.iterrows():
        im = io.imread(row['Image'])
        im = im / 255  # Normalization
        im = np.expand_dims(im, axis=0)  # leading channel axis
        images['ID'].append(row['ID'])

        # Every sample gets its own buffer (the original reused one array,
        # so all collected samples ended up equal to the last one written).
        im_aug = [_as_patch(im, patch_shape)]
        lb_aug = []
        gap_aug = []
        if with_targets:
            lb = np.expand_dims(label_converter(io.imread(row['Label'])),
                                axis=0)
            gap = np.expand_dims(io.imread(row['Gap']), axis=0)
            lb_aug.append(_as_patch(lb, patch_shape))
            gap_aug.append(_as_patch(gap, patch_shape))

        if mode == 'train':
            # Stack as channels so one call draws one set of random
            # parameters for image, label and gap alike.
            n_im, n_lb = im.shape[0], lb.shape[0]
            stacked = np.concatenate([im, lb, gap], axis=0)
            for transform in spatial_transforms:
                out = transform(stacked)
                im_aug.append(_as_patch(out[:n_im], patch_shape))
                lb_aug.append(_as_patch(out[n_im:n_im + n_lb], patch_shape))
                gap_aug.append(_as_patch(out[n_im + n_lb:], patch_shape))
            for transform in intensity_transforms:
                im_aug.append(_as_patch(transform(im), patch_shape))
                lb_aug.append(_as_patch(lb, patch_shape))
                gap_aug.append(_as_patch(gap, patch_shape))

        images['Image'].append(np.array(im_aug))
        if with_targets:
            images['Label'].append(np.array(lb_aug))
            images['Gap'].append(np.array(gap_aug))

    # Save images for the fast path of later calls.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(images, f)

    return images
Beispiel #22
0
 def test_history(self):
     """Applying a single RandomFlip must record exactly one history entry."""
     transformed = tio.RandomFlip()(self.sample_subject)
     # assertEqual, not assertIs: int identity only holds by accident of
     # CPython's small-int cache, so assertIs checks the wrong thing.
     self.assertEqual(len(transformed.history), 1)
Beispiel #23
0
    if in_channels == 3:
        subject = tio.Subject(
            hr=tio.ScalarImage(t2_file),
            lr_1=tio.ScalarImage(t2_file),
            lr_2=tio.ScalarImage(t2_file),
            lr_3=tio.ScalarImage(t2_file),
        )

    subjects.append(subject)

print('DHCP Dataset size:', len(subjects), 'subjects')

# ---------------------------------------------------------------------------
# DATA AUGMENTATION
# ---------------------------------------------------------------------------
# Subject-level transforms.
normalization = tio.ZNormalization()
spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75)
flip = tio.RandomFlip(axes=('LR', ), flip_probability=0.5)
tocanonical = tio.ToCanonical()

# Degradation of 'lr_1': blur (b1), downsample (d1), then resample back onto
# the 'hr' grid (u1). The blur/coarse spacing is along the third axis.
b1, d1, u1 = (
    tio.Blur(std=(0.001, 0.001, 1), include='lr_1'),
    tio.Resample((0.8, 0.8, 2), include='lr_1'),
    tio.Resample(target='hr', include='lr_1'),
)

if in_channels == 3:
    # Same degradation for 'lr_2' (second axis) ...
    b2, d2, u2 = (
        tio.Blur(std=(0.001, 1, 0.001), include='lr_2'),
        tio.Resample((0.8, 2, 0.8), include='lr_2'),
        tio.Resample(target='hr', include='lr_2'),
    )
    # ... and for 'lr_3' (first axis).
    b3, d3, u3 = (
        tio.Blur(std=(1, 0.001, 0.001), include='lr_3'),
        tio.Resample((2, 0.8, 0.8), include='lr_3'),
        tio.Resample(target='hr', include='lr_3'),
    )
Beispiel #24
0
import torch
import torchio as tio
import matplotlib.pyplot as plt

# pprint is used below but was never imported (NameError at runtime).
import pprint

# Fix the RNG so the random transforms applied below are reproducible.
torch.manual_seed(0)

# NOTE(review): not used in the visible code; presumably consumed by the
# DataLoader that follows this script.
batch_size = 4
subject = tio.datasets.FPG()
subject.remove_image('seg')
# Four references to the same Subject object.
subjects = 4 * [subject]

transform = tio.Compose((
    tio.ToCanonical(),
    tio.RandomGamma(p=0.75),
    tio.RandomBlur(p=0.5),
    tio.RandomFlip(),
    tio.RescaleIntensity((-1, 1)),
))

dataset = tio.SubjectsDataset(subjects, transform=transform)

# Inspect what was applied to one subject, how to replay it, and how to
# undo it where possible.
transformed = dataset[0]
print('Applied transforms:')  # noqa: T001
pprint.pprint(transformed.history)  # noqa: T003
print('\nComposed transform to reproduce history:')  # noqa: T001
print(transformed.get_composed_history())  # noqa: T001
print('\nComposed transform to invert applied transforms when possible:')  # noqa: T001, E501
print(transformed.get_inverse_transform(ignore_intensity=False))  # noqa: T001
loader = torch.utils.data.DataLoader(
def get_context(
    device,
    variables,
    fold=0,
    predict_hbt=False,
    training_batch_size=4,
):
    """Build the training context for the "dmri-hippo" hippocampus segmentation task.

    Registers on a ``TorchContext`` named ``"dmri-hippo"``:
    - a ``SubjectFolder`` dataset;
    - a ``NestedResUNet`` with 3 input channels (mean_dwi, md, fa);
    - an Adam optimizer;
    - a ``HybridLogisticDiceLoss`` criterion;
    - a ``SegmentationTrainer``.

    Args:
        device: forwarded unchanged to ``TorchContext``.
        variables: forwarded unchanged to ``TorchContext`` (presumably used to
            resolve placeholders such as ``'$DATASET_PATH'`` — confirm).
        fold: cross-validation fold held out as ``'cbbrain_validation'``.
            Subjects with any other ``fold`` attribute form ``'training'``.
        predict_hbt: if True the target is ``hbt_roi`` (head/body/tail) and
            the model has 4 output channels. Otherwise the target is
            ``whole_roi`` with 2 output channels.
        training_batch_size: batch size forwarded to the trainer.

    Returns:
        The populated ``TorchContext``.
    """
    context = TorchContext(device, name="dmri-hippo", variables=variables)
    # NOTE(review): presumably lets the framework archive this config file
    # together with the run — confirm against TorchContext.
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'fold': fold})

    # Images concatenated into the network input "X".
    input_images = ["mean_dwi", "md", "fa"]

    subject_loader = ComposeLoaders([
        ImageLoader(glob_pattern="mean_dwi.*",
                    image_name='mean_dwi',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="md.*",
                    image_name='md',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="fa.*",
                    image_name='fa',
                    image_constructor=tio.ScalarImage),
        # ImageLoader(glob_pattern="full_dwi.*", image_name='full_dwi', image_constructor=tio.ScalarImage),
        # TensorLoader(glob_pattern="full_dwi_grad.b", tensor_name="grad", belongs_to="full_dwi"),
        ImageLoader(glob_pattern="whole_roi.*",
                    image_name="whole_roi",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_whole": 1,
                        "right_whole": 2
                    }),
        ImageLoader(glob_pattern="whole_roi_alt.*",
                    image_name="whole_roi_alt",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_whole": 1,
                        "right_whole": 2
                    }),
        ImageLoader(glob_pattern="hbt_roi.*",
                    image_name="hbt_roi",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_head": 1,
                        "left_body": 2,
                        "left_tail": 3,
                        "right_head": 4,
                        "right_body": 5,
                        "right_tail": 6
                    }),
        ImageLoader(glob_pattern="../../atlas/whole_roi_union.*",
                    image_name="whole_roi_union",
                    image_constructor=tio.LabelMap,
                    uniform=True),
        AttributeLoader(glob_pattern='attributes.*'),
        AttributeLoader(
            glob_pattern='../../attributes/cross_validation_split.json',
            multi_subject=True,
            uniform=True),
        AttributeLoader(
            glob_pattern='../../attributes/ab300_validation_subjects.json',
            multi_subject=True,
            uniform=True),
        AttributeLoader(
            glob_pattern='../../attributes/cbbrain_test_subjects.json',
            multi_subject=True,
            uniform=True),
    ])

    cohorts = {}
    cohorts['all'] = RequireAttributes(input_images)
    # Cross-validation: subjects with a 'fold' attribute; fold `fold` is
    # held out for validation, the others are used for training.
    cohorts['cross_validation'] = RequireAttributes(['fold'])
    cohorts['training'] = ComposeFilters(
        [cohorts['cross_validation'],
         ForbidAttributes({"fold": fold})])
    cohorts['cbbrain_validation'] = ComposeFilters(
        [cohorts['cross_validation'],
         RequireAttributes({"fold": fold})])
    cohorts['cbbrain_test'] = RequireAttributes({'cbbrain_test': True})
    cohorts['ab300_validation'] = RequireAttributes({'ab300_validation': True})
    # A 20-subject subset of ab300_validation used only for contour plots.
    cohorts['ab300_validation_plot'] = ComposeFilters(
        [cohorts['ab300_validation'],
         RandomSelectFilter(num_subjects=20)])
    cohorts['cbbrain'] = RequireAttributes({"protocol": "cbbrain"})
    cohorts['ab300'] = RequireAttributes({"protocol": "ab300"})
    cohorts['rescans'] = ForbidAttributes({"rescan_id": "None"})
    cohorts['fasd'] = RequireAttributes({"pathologies": "FASD"})
    cohorts['inter_rater'] = RequireAttributes(["whole_roi_alt"])

    # Pre-processing shared by every cohort.
    common_transforms_1 = tio.Compose([
        tio.CropOrPad((96, 88, 24),
                      padding_mode='minimum',
                      mask_name='whole_roi_union'),
        # Right-side labels are remapped onto the left-side values
        # (presumably restricted to the right half by masking_method="Right"
        # — confirm with CustomRemapLabels).
        CustomRemapLabels(remapping=[("right_whole", 2, 1)],
                          masking_method="Right",
                          include=["whole_roi"]),
        CustomRemapLabels(remapping=[("right_head", 4, 1),
                                     ("right_body", 5, 2),
                                     ("right_tail", 6, 3)],
                          masking_method="Right",
                          include=["hbt_roi"]),
    ])

    # Training-only augmentations; nothing here touches "full_dwi".
    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    standard_augmentations = tio.Compose(
        [
            tio.RandomFlip(axes=(0, 1, 2)),
            tio.RandomElasticDeformation(p=0.5,
                                         num_control_points=(7, 7, 4),
                                         locked_borders=1,
                                         image_interpolation='bspline',
                                         exclude=["full_dwi"]),
            tio.RandomBiasField(p=0.5),
            tio.RescaleIntensity((0, 1), (0.01, 99.9)),
            tio.RandomGamma(p=0.8),
            tio.RescaleIntensity((-1, 1)),
            # Blur and noise (each itself random) in a random order.
            tio.OneOf([
                tio.Compose([blur, noise]),
                tio.Compose([noise, blur]),
            ]),
        ],
        exclude="full_dwi")

    # Post-processing shared by every cohort: build the 3-channel input "X"
    # and the one-hot target "y".
    common_transforms_2 = tio.Compose([
        tio.RescaleIntensity((-1., 1.), (0.5, 99.5)),
        ConcatenateImages(image_names=["mean_dwi", "md", "fa"],
                          image_channels=[1, 1, 1],
                          new_image_name="X"),
        RenameProperty(old_name="hbt_roi" if predict_hbt else "whole_roi",
                       new_name="y"),
        CustomOneHot(include=["y"])
    ])

    transforms = {
        'default':
        tio.Compose([common_transforms_1, common_transforms_2]),
        'training':
        tio.Compose(
            [common_transforms_1, standard_augmentations,
             common_transforms_2]),
    }

    context.add_component("dataset",
                          SubjectFolder,
                          root='$DATASET_PATH',
                          subject_path="subjects",
                          subject_loader=subject_loader,
                          cohorts=cohorts,
                          transforms=transforms)
    context.add_component("model",
                          NestedResUNet,
                          input_channels=3,
                          output_channels=4 if predict_hbt else 2,
                          filters=40,
                          dropout_p=0.2)
    context.add_component("optimizer",
                          Adam,
                          params="self.model.parameters()",
                          lr=0.0002)
    context.add_component("criterion", HybridLogisticDiceLoss)

    training_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            'y_pred_eval', 'y_eval'),
                            log_name='training_segmentation_eval',
                            interval=10),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "Axial",
            'mean_dwi',
            'y_pred_eval',
            'y_eval',
            slice_id=12,
            legend=True,
            ncol=2,
            split_subjects=False),
                            log_name="contour_image_training",
                            interval=50),
    ]

    # Per-label coefficients handed to LabelMapEvaluator together with the
    # 'age' attribute (presumably a normative volume-vs-age curve — confirm).
    curve_params = {
        "left_whole":
        np.array([-1.96312119e-01, 9.46668029e+00, 2.33635173e+03]),
        "right_whole":
        np.array([-2.68467331e-01, 1.67925603e+01, 2.07224236e+03])
    }

    validation_evaluators = [
        ScheduledEvaluation(evaluator=LabelMapEvaluator(
            'y_pred_eval',
            curve_params=curve_params,
            curve_attribute='age',
            stats_to_output=('volume', 'error', 'absolute_error',
                             'squared_error', 'percent_diff')),
                            log_name="predicted_label_eval",
                            cohorts=['cbbrain_validation', 'ab300_validation'],
                            interval=50),
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            "y_pred_eval", "y_eval"),
                            log_name="segmentation_eval",
                            cohorts=['cbbrain_validation'],
                            interval=50),
        ScheduledEvaluation(
            evaluator=ContourImageEvaluator("Axial",
                                            "mean_dwi",
                                            "y_pred_eval",
                                            "y_eval",
                                            slice_id=10,
                                            legend=True,
                                            ncol=5,
                                            split_subjects=False),
            log_name="contour_image_axial",
            cohorts=['cbbrain_validation', 'ab300_validation_plot'],
            interval=250),
        ScheduledEvaluation(
            evaluator=ContourImageEvaluator("Coronal",
                                            "mean_dwi",
                                            "y_pred_eval",
                                            "y_eval",
                                            slice_id=44,
                                            legend=True,
                                            ncol=2,
                                            split_subjects=False),
            log_name="contour_image_coronal",
            cohorts=['cbbrain_validation', 'ab300_validation_plot'],
            interval=250),
    ]

    def scoring_function(evaluation_dict):
        """Return the label-averaged mean Dice on the cbbrain validation cohort."""
        # Grab the output of the SegmentationEvaluator
        seg_eval_cbbrain = evaluation_dict['segmentation_eval'][
            'cbbrain_validation']["summary_stats"]

        # Get the mean dice for each label (the mean is across subjects)
        cbbrain_dice = seg_eval_cbbrain['mean', :, 'dice']

        # Now take the mean across all labels
        cbbrain_dice = cbbrain_dice.mean()
        score = cbbrain_dice
        return score

    train_predictor = StandardPredict(sagittal_split=True,
                                      image_names=['X', 'y'])
    validation_predictor = StandardPredict(sagittal_split=True,
                                           image_names=['X'])

    train_dataloader_factory = StandardDataLoader(sampler=RandomSampler)
    validation_dataloader_factory = StandardDataLoader(
        sampler=SequentialSampler)

    context.add_component(
        "trainer",
        SegmentationTrainer,
        training_batch_size=training_batch_size,
        save_rate=100,
        scoring_interval=50,
        scoring_function=scoring_function,
        one_time_evaluators=[],
        training_evaluators=training_evaluators,
        validation_evaluators=validation_evaluators,
        max_iterations_with_no_improvement=2000,
        train_predictor=train_predictor,
        validation_predictor=validation_predictor,
        train_dataloader_factory=train_dataloader_factory,
        validation_dataloader_factory=validation_dataloader_factory)

    return context