def get_train_transform(landmarks_path, resection_params=None):
    """Build the training-time augmentation pipeline.

    Args:
        landmarks_path: Landmarks passed to ``tio.HistogramStandardization``
            for the ``'image'`` key. When ``None`` the standardization step
            is left out.
        resection_params: Parameters forwarded to ``get_simulation_transform``.
            When ``None`` no simulation step is prepended.

    Returns:
        A ``tio.Compose`` applying, in order: the optional simulation, the
        optional histogram standardization, resolution/artifact/bias
        augmentations, z-normalization, noise, the spatial augmentation and
        finally ``get_tight_crop()``.
    """
    # Spatial: affine 90% / elastic 10% of the time, then a random flip.
    spatial_choice = {
        tio.RandomAffine(): 0.9,
        tio.RandomElasticDeformation(): 0.1,
    }
    spatial_transform = tio.Compose((tio.OneOf(spatial_choice), tio.RandomFlip()))

    # Resolution degradation: anisotropy or blur, 75% of the time.
    resolution_candidates = (tio.RandomAnisotropy(), tio.RandomBlur())
    resolution_transform = tio.OneOf(resolution_candidates, p=0.75)

    preprocessing = []
    if resection_params is not None:
        preprocessing.append(get_simulation_transform(resection_params))
    if landmarks_path is not None:
        preprocessing.append(tio.HistogramStandardization({'image': landmarks_path}))

    augmentation = [
        # tio.RandomGamma(p=0.2),
        resolution_transform,
        tio.RandomGhosting(p=0.2),
        tio.RandomSpike(p=0.2),
        tio.RandomMotion(p=0.2),
        tio.RandomBiasField(p=0.5),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        # Noise must come after z-normalization and after any blur.
        tio.RandomNoise(p=0.75),
        spatial_transform,
        get_tight_crop(),
    ]
    return tio.Compose(preprocessing + augmentation)
Exemple #2
0
 def test_transforms(self):
     """Smoke test: a long chain of torchio transforms runs on ``self.sample``.

     Only checks that no exception is raised; the output is not inspected.
     """
     landmarks_dict = {
         modality: np.linspace(0, 100, 13) for modality in ('t1', 't2')
     }
     elastic = torchio.RandomElasticDeformation(max_displacement=1)
     pipeline = torchio.Compose((
         # Geometry and resampling
         torchio.CropOrPad((9, 21, 30)),
         torchio.ToCanonical(),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
         # Artifacts and intensity perturbations
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=2, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         # Intensity normalization
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(masking_method='label'),
         torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
         # Spatial deformations; the same `elastic` instance is also used
         # inside OneOf below.
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1
         }),
         # Explicit padding and cropping
         torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
         torchio.Crop((3, 2, 8, 0, 1, 4)),
     ))
     pipeline(self.sample)
Exemple #3
0
def byol_aug(filename):
    """Load an image and return a randomly augmented version of it.

    BYOL minimizes the distance between the representation of a sample and
    that of a transformed version of the same sample. This function supplies
    such a transformed version, using: crop/pad to a fixed shape,
    z-normalization, blur, Gaussian noise, an affine or elastic spatial
    deformation, a bias field, and one MRI artifact (motion, spikes or
    ghosting).

    Args:
        filename: Path of an image readable by ``tio.ScalarImage``.

    Returns:
        The result of applying the augmentation pipeline to the loaded image.
    """
    image = tio.ScalarImage(filename)
    pipeline = tio.Compose([
        tio.CropOrPad((180, 220, 170)),
        # Zero mean, unit variance over the foreground selected by
        # tio.ZNormalization.mean.
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.RandomBlur(p=0.25),   # blur 25% of the time
        tio.RandomNoise(p=0.25),  # Gaussian noise 25% of the time
        # 80% of images get a spatial transform: affine (weight 0.8) or
        # elastic deformation (weight 0.2).
        tio.OneOf({
            tio.RandomAffine(): 0.8,
            tio.RandomElasticDeformation(): 0.2,
        }, p=0.8),
        tio.RandomBiasField(p=0.3),  # field inhomogeneity 30% of the time
        # 50% of images get one artifact: motion (weight 1), spikes (2) or
        # ghosting (2).
        tio.OneOf({
            tio.RandomMotion(): 1,
            tio.RandomSpike(): 2,
            tio.RandomGhosting(): 2,
        }, p=0.5),
    ])
    return pipeline(image)
Exemple #4
0
 def test_no_bias(self):
     """A bias field with zero coefficients must leave ``t1`` unchanged."""
     no_op_bias = tio.RandomBiasField(coefficients=0.)
     result = no_op_bias(self.sample_subject)
     expected = self.sample_subject.t1.data
     self.assertTensorAlmostEqual(expected, result.t1.data)
Exemple #5
0
    def __call__(self, sample, metadata=None):
        """Apply a random torchio bias field to ``sample`` with probability ``self.p``.

        Args:
            sample: Input data. The output is cast back to ``sample.dtype``,
                so ``sample`` must expose ``.dtype`` (presumably a numpy
                array — confirm against callers).
            metadata: Item-assignable container indexed by
                ``MetadataKW.BIAS_FIELD``. Despite the ``None`` default, a
                real container must be supplied; ``None`` fails on item
                assignment.

        Returns:
            ``(data, metadata)``. ``metadata[MetadataKW.BIAS_FIELD]`` is set to
            ``[history]`` (as returned by ``tio_transform``) when the field
            was applied, and to ``[None]`` otherwise.
        """
        if np.random.random() >= self.p:
            metadata[MetadataKW.BIAS_FIELD] = [None]
            return sample, metadata

        # The draw above already decides *whether* to apply the bias field,
        # so the inner torchio transform must always fire (p=1). Passing
        # p=self.p here as well would apply the field with probability p**2.
        random_bias_field = tio.Compose([
            tio.RandomBiasField(coefficients=self.coefficients,
                                order=self.order,
                                p=1)
        ])

        data_out, history = tio_transform(x=sample,
                                          transform=random_bias_field)

        # Keep the input data type.
        data_out = data_out.astype(sample.dtype)

        # Record what was applied so it can be inspected/undone later.
        metadata[MetadataKW.BIAS_FIELD] = [history]
        return data_out, metadata
Exemple #6
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Return a ``tio.Compose`` chaining a wide range of torchio transforms.

     Args:
         channels: Image keys. Each gets a 13-point landmark array for
             ``HistogramStandardization``; ``channels[0]`` is the reference
             image for ``CopyAffine``.
         is_3d: When False, shape-dependent arguments are chosen for images
             whose last spatial dimension is 1.
         labels: When True, ``RandomLabelsToImage(label_key='label')`` is
             appended as the final step.
     """
     landmarks_dict = {name: np.linspace(0, 100, 13) for name in channels}
     if is_3d:
         displacement = 1
         crop_or_pad_shape = (9, 21, 30)
         target_shape = (10, 20, 30)
         spatial_axes = (0, 1, 2)
         swap_patch_size = (2, 3, 4)
         padding = (1, 2, 3, 0, 5, 6)
         cropping = (3, 2, 8, 0, 1, 4)
     else:
         displacement = (1, 1, 0.01)
         crop_or_pad_shape = (21, 30, 1)
         target_shape = (10, 20, 1)
         spatial_axes = (0, 1)
         swap_patch_size = (3, 4, 1)
         padding = (0, 0, 3, 0, 5, 6)
         cropping = (0, 0, 8, 0, 1, 4)
     elastic = tio.RandomElasticDeformation(max_displacement=displacement)
     label_remapping = {1: 2, 2: 1, 3: 20, 4: 25}

     pipeline = [
         # Shape, orientation and spacing
         tio.CropOrPad(crop_or_pad_shape),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(target_shape),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=spatial_axes),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=spatial_axes, flip_probability=1),
         # Artifacts and intensity perturbations
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch_size, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         # Intensity normalization
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks_dict),
         # Spatial deformations (the elastic instance is reused in OneOf)
         elastic,
         tio.RandomAffine(),
         tio.OneOf({
             tio.RandomAffine(): 3,
             elastic: 1,
         }),
         # Label manipulation
         tio.RemapLabels(remapping=label_remapping, masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         # Padding and cropping
         tio.Pad(padding, padding_mode=3),
         tio.Crop(cropping),
     ]
     if labels:
         pipeline.append(tio.RandomLabelsToImage(label_key='label'))
     return tio.Compose(pipeline)
def get_context(device, variables, augmentation_mode, **kwargs):
    """Extend the base configuration with a selectable augmentation strategy.

    Args:
        device: Forwarded to ``base_config.get_context``.
        variables: Forwarded to ``base_config.get_context``.
        augmentation_mode: Selects what fills the augmentation slot (index 1)
            of the training ``tio.Compose``:
            ``'no_augmentation'`` removes the slot, ``'standard'`` uses
            generic intensity/spatial augmentations, ``'dwi_reconstruction'``
            uses ``ReconstructMeanDWI``, and ``'combined'`` runs the DWI
            reconstruction followed by the standard augmentations.
        **kwargs: Forwarded to ``base_config.get_context``.

    Returns:
        The base context with this file appended to ``file_paths`` and
        ``augmentation_mode`` recorded in ``config``.

    Raises:
        ValueError: If ``augmentation_mode`` is not one of the modes above.
    """
    context = base_config.get_context(device, variables, **kwargs)
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'augmentation_mode': augmentation_mode})

    # training_transform is a tio.Compose whose second transform is the
    # augmentation; it is modified in place below.
    dataset_defn = context.get_component_definition("dataset")
    training_transform = dataset_defn['params']['transforms']['training']

    dwi_augmentation = ReconstructMeanDWI(num_dwis=(1, 7),
                                          num_directions=(1, 3),
                                          directionality=(4, 10))

    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    # `exclude` must be a sequence of image names: torchio rejects a bare
    # string for include/exclude.
    standard_augmentations = tio.Compose([
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.RandomElasticDeformation(p=0.5,
                                     num_control_points=(7, 7, 4),
                                     locked_borders=1,
                                     image_interpolation='bspline',
                                     exclude=["full_dwi"]),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        # Apply blur and noise in a random order.
        tio.OneOf([
            tio.Compose([blur, noise]),
            tio.Compose([noise, blur]),
        ])
    ],
                                         exclude=["full_dwi"])

    if augmentation_mode == 'no_augmentation':
        training_transform.transforms.pop(1)
    elif augmentation_mode == 'standard':
        training_transform.transforms[1] = standard_augmentations
    elif augmentation_mode == 'dwi_reconstruction':
        training_transform.transforms[1] = dwi_augmentation
    elif augmentation_mode == 'combined':
        training_transform.transforms[1] = tio.Compose(
            [dwi_augmentation, standard_augmentations])
    else:
        raise ValueError(f"Invalid augmentation mode {augmentation_mode}")

    return context
 def get_transform(self, channels, is_3d=True, labels=True):
     """Return a ``torchio.Compose`` chaining a wide range of torchio transforms.

     Args:
         channels: Image keys; each gets a 13-point landmark array for
             ``HistogramStandardization``.
         is_3d: When False, shape-dependent arguments are chosen for images
             whose last spatial dimension is 1.
         labels: When True, ``RandomLabelsToImage(label_key='label')`` is
             appended as the final step.
     """
     landmarks_dict = {name: np.linspace(0, 100, 13) for name in channels}
     if is_3d:
         displacement = 1
         crop_or_pad_shape = (9, 21, 30)
         spatial_axes = (0, 1, 2)
         swap_patch_size = (2, 3, 4)
         padding = (1, 2, 3, 0, 5, 6)
         cropping = (3, 2, 8, 0, 1, 4)
     else:
         displacement = (1, 1, 0.01)
         crop_or_pad_shape = (21, 30, 1)
         spatial_axes = (0, 1)
         swap_patch_size = (3, 4, 1)
         padding = (0, 0, 3, 0, 5, 6)
         cropping = (0, 0, 8, 0, 1, 4)
     elastic = torchio.RandomElasticDeformation(max_displacement=displacement)

     pipeline = [
         # Shape, orientation and spacing
         torchio.CropOrPad(crop_or_pad_shape),
         torchio.ToCanonical(),
         # NOTE(review): newer torchio appears to expose this transform as
         # RandomAnisotropy with the same arguments; confirm the pinned
         # torchio version before upgrading.
         torchio.RandomDownsample(downsampling=(1.75, 2),
                                  axes=spatial_axes),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=spatial_axes, flip_probability=1),
         # Artifacts and intensity perturbations
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=swap_patch_size, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         # Intensity normalization
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(),
         torchio.HistogramStandardization(landmarks_dict),
         # Spatial deformations (the elastic instance is reused in OneOf)
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1,
         }),
         # Padding and cropping
         torchio.Pad(padding, padding_mode=3),
         torchio.Crop(cropping),
     ]
     if labels:
         pipeline.append(torchio.RandomLabelsToImage(label_key='label'))
     return torchio.Compose(pipeline)
Exemple #9
0
# Add Gaussian noise (std=0.5, fixed seed) to the standardized image and show it.
# NOTE(review): `standardize`, `fpg_ras` and `show_fpg` are defined outside
# this excerpt.
add_noise = tio.RandomNoise(std=0.5, seed=42)
standard = standardize(fpg_ras)
noisy = add_noise(standard)
show_fpg(noisy)

# MRI-specific transforms
# TorchIO includes some transforms that simulate image artifacts specific to
# MRI.

# Random bias field
# NOTE: this example did not work as expected and needs to be reworked.
# Magnetic field inhomogeneities in the MRI scanner produce low-frequency
# intensity distortions in the images, which are typically corrected using
# algorithms such as N4ITK. To simulate this artifact, we can use
# RandomBiasField.

# For this example, we use an image that has been preprocessed, so it is
# expected to be free of bias.

add_bias = tio.RandomBiasField(coefficients=1, seed=0)
mni_bias = add_bias(fpg_ras)
# Also expose the `heart` image as the `seg` attribute (presumably what
# `show_fpg` reads for the overlay — confirm).
mni_bias.seg = mni_bias.heart
show_fpg(mni_bias)

# k-space transforms
#MR images are generated by computing the inverse Fourier transform of the k-space, which is the signal received by the coils
#in the scanner. If the k-space is altered, an artifact will be created in the image. These artifacts are typically
# accidental, but we can use transforms to simulate them.

#Random spike
#Sometimes, signal peaks can appear in k-space. If one adds a high-energy component at e.g. 440 Hz in the spectrum of
# an audio signal, a tone of that frequency will be audible in the time domain. Similarly, spikes in k-space manifest
# as stripes in image space. They can be simulated using RandomSpike. The number of spikes doesn't affect the transform
# run time, so try adding more!
Exemple #10
0
def get_context(device, variables, fold=0, **kwargs):
    """Build the ``msseg2`` training context (two FLAIR time points, lesion labels).

    Registers the dataset, model, optimizer, criterion and trainer components
    on a ``TorchContext`` named ``"msseg2"``.

    Args:
        device: Forwarded to ``TorchContext``.
        variables: Forwarded to ``TorchContext``.
        fold: Fold selected as the validation cohort by ``RandomFoldFilter``
            (5 folds, fixed seed). All other subjects form the training
            cohort.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        The configured ``TorchContext``.
    """
    context = TorchContext(device, name="msseg2", variables=variables)
    context.file_paths.append(os.path.abspath(__file__))
    context.config = config = {'fold': fold, 'patch_size': 96}

    input_images = ["flair_time01", "flair_time02"]

    subject_loader = ComposeLoaders([
        ImageLoader(glob_pattern="flair_time01*",
                    image_name='flair_time01',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="flair_time02*",
                    image_name='flair_time02',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="brain_mask.*",
                    image_name='brain_mask',
                    image_constructor=tio.LabelMap,
                    label_values={"brain": 1}),
        ImageLoader(glob_pattern="ground_truth.*",
                    image_name="ground_truth",
                    image_constructor=tio.LabelMap,
                    label_values={"lesion": 1}),
    ])

    # A fixed-seed 5-fold split: `fold` picks the validation subjects and the
    # training cohort is its complement.
    cohorts = {}
    cohorts['all'] = RequireAttributes(input_images)
    cohorts['validation'] = RandomFoldFilter(num_folds=5,
                                             selection=fold,
                                             seed=0xDEADBEEF)
    cohorts['training'] = NegateFilter(cohorts['validation'])

    # Applied before augmentation: float dtype, a consistent affine taken from
    # flair_time01, resampling to spacing 1, crop to the brain mask, then
    # padding presumably up to at least one patch in each dimension.
    common_transforms_1 = tio.Compose([
        SetDataType(torch.float),
        EnforceConsistentAffine(source_image_name='flair_time01'),
        TargetResample(target_spacing=1, tolerance=0.11),
        CropToMask('brain_mask'),
        MinSizePad(config['patch_size'])
    ])

    augmentations = tio.Compose([
        RandomPermuteDimensions(),
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.OneOf(
            {
                tio.RandomElasticDeformation():
                0.2,
                tio.RandomAffine(scales=0.2,
                                 degrees=45,
                                 default_pad_value='otsu'):
                0.8,
            },
            p=0.75),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        tio.RandomBlur((0, 1), p=0.2),
        tio.RandomNoise(std=0.1, p=0.35)
    ])

    # Applied after augmentation: final intensity scaling, stack both FLAIR
    # images into the 2-channel input "X", and one-hot encode the target "y".
    common_transforms_2 = tio.Compose([
        tio.RescaleIntensity((-1, 1.), (0.05, 99.5)),
        ConcatenateImages(image_names=["flair_time01", "flair_time02"],
                          image_channels=[1, 1],
                          new_image_name="X"),
        RenameProperty(old_name='ground_truth', new_name='y'),
        CustomOneHot(include="y"),
    ])

    transforms = {
        'default':
        tio.Compose([common_transforms_1, common_transforms_2]),
        'training':
        tio.Compose([
            common_transforms_1, augmentations, common_transforms_2,
            # Map used by the WeightedSampler below: lesion voxels carry
            # weight 100 versus 1 for brain voxels, presumably so patches
            # are drawn around lesions far more often.
            ImageFromLabels(new_image_name="patch_probability",
                            label_weights=[('brain_mask', 'brain', 1),
                                           ('y', 'lesion', 100)])
        ]),
    }

    context.add_component("dataset",
                          SubjectFolder,
                          root='$DATASET_PATH',
                          subject_path="",
                          subject_loader=subject_loader,
                          cohorts=cohorts,
                          transforms=transforms)
    context.add_component("model",
                          ModularUNet,
                          in_channels=2,
                          out_channels=2,
                          filters=[40, 40, 80, 80, 120, 120],
                          depth=6,
                          block_params={'residual': True},
                          downsample_class=BlurConv3d,
                          downsample_params={
                              'kernel_size': 3,
                              'stride': 2,
                              'padding': 1
                          },
                          upsample_class=BlurConvTranspose3d,
                          upsample_params={
                              'kernel_size': 3,
                              'stride': 2,
                              'padding': 1,
                              'output_padding': 0
                          })
    context.add_component("optimizer",
                          SGD,
                          params="self.model.parameters()",
                          lr=0.001,
                          momentum=0.95)
    context.add_component("criterion",
                          HybridLogisticDiceLoss,
                          logistic_class_weights=[1, 100])

    training_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            'y_pred_eval', 'y_eval'),
                            log_name='training_segmentation_eval',
                            interval=15),
        ScheduledEvaluation(evaluator=LabelMapEvaluator('y_pred_eval'),
                            log_name='training_label_eval',
                            interval=15),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "random",
            'flair_time02',
            'y_pred_eval',
            'y_eval',
            slice_id=0,
            legend=True,
            ncol=2,
            interesting_slice=True,
            split_subjects=False),
                            log_name="contour_image",
                            interval=15),
    ]

    validation_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            "y_pred_eval", "y_eval"),
                            log_name="segmentation_eval",
                            cohorts=["validation"],
                            interval=50),
        # NOTE(review): same log_name as the training contour evaluator above;
        # confirm the two do not overwrite each other in the logs.
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "interesting",
            'flair_time02',
            'y_pred_eval',
            'y_eval',
            slice_id=0,
            legend=True,
            ncol=1,
            interesting_slice=True,
            split_subjects=True),
                            log_name="contour_image",
                            cohorts=["validation"],
                            interval=50),
    ]

    def scoring_function(evaluation_dict):
        """Return the mean validation lesion Dice, treating empty cases specially."""
        # Grab the output of the SegmentationEvaluator
        seg_eval = evaluation_dict['segmentation_eval']['validation']

        # Take mean dice, while accounting for subjects which have no lesions.
        # Dice is 0/0 = nan when the model correctly outputs no lesions. This is counted as a score of 1.0.
        # Dice is (>0)/0 = posinf when the model incorrectly predicts lesions when there are none.
        # This is counted as a score of 0.0.
        dice = torch.tensor(seg_eval["subject_stats"]['dice.lesion'])
        dice = dice.nan_to_num(nan=1.0, posinf=0.0)
        score = dice.mean()

        return score

    train_predictor = StandardPredict(image_names=['X', 'y'])
    validation_predictor = PatchPredict(patch_batch_size=32,
                                        patch_size=config['patch_size'],
                                        patch_overlap=(config['patch_size'] //
                                                       8),
                                        padding_mode=None,
                                        overlap_mode='average',
                                        image_names=['X'])

    # Training draws one weighted patch per volume from "patch_probability".
    patch_sampler = tio.WeightedSampler(patch_size=config['patch_size'],
                                        probability_map='patch_probability')
    train_dataloader_factory = PatchDataLoader(max_length=100,
                                               samples_per_volume=1,
                                               sampler=patch_sampler)
    validation_dataloader_factory = StandardDataLoader(
        sampler=SequentialSampler)

    context.add_component(
        "trainer",
        SegmentationTrainer,
        training_batch_size=4,
        save_rate=100,
        scoring_interval=50,
        scoring_function=scoring_function,
        one_time_evaluators=[],
        training_evaluators=training_evaluators,
        validation_evaluators=validation_evaluators,
        max_iterations_with_no_improvement=2000,
        train_predictor=train_predictor,
        validation_predictor=validation_predictor,
        train_dataloader_factory=train_dataloader_factory,
        validation_dataloader_factory=validation_dataloader_factory)

    return context
Exemple #11
0
 def test_wrong_coefficient_type(self):
     """A non-numeric ``coefficients`` argument must raise ValueError."""
     self.assertRaises(ValueError, tio.RandomBiasField, coefficients='wrong')
Exemple #12
0
# Shared preprocessing: z-normalization using the tio.ZNormalization.mean
# foreground mask, followed by one-hot encoding.
normalization = tio.ZNormalization(masking_method=tio.ZNormalization.mean)
onehot = tio.OneHot()

# Tag the experiment name with the augmentations used.
# NOTE(review): `prefix` must already be defined before this line. The suffix
# mentions "elastic", but the active spatial transform below is RandomAffine
# only (the OneOf with RandomElasticDeformation is commented out) — confirm.
prefix += '_bias_flip_elastic_noise'

# Alternative spatial augmentation (disabled):
# spatial = tio.OneOf({
#     tio.RandomAffine(scales=0.1,degrees=30): 0.8,
#     tio.RandomElasticDeformation(): 0.2,
#   },
#   p=0.75,
# )

# Mild affine augmentation, applied 75% of the time.
spatial = tio.RandomAffine(scales=0.05,degrees=10,p=0.75)

bias = tio.RandomBiasField(p=0.3)
# Left-right flip with probability 0.5.
flip = tio.RandomFlip(axes=('LR',), flip_probability=0.5)
noise = tio.RandomNoise(std=0.25, p=0.25)

# Training order: bias field, normalization, flip, affine, noise, one-hot.
transforms = [bias, normalization, flip, spatial, noise, onehot]#, bias, normalization, flip, spatial, noise, onehot]

training_transform = tio.Compose(transforms)
# Validation only normalizes and one-hot encodes (no augmentation).
validation_transform = tio.Compose([normalization, onehot])

#subject = dataset[0]
#transformed_subject = training_transform(subject)
#transformed_subject.plot()


#%%
Exemple #13
0
 def test_small_image(self):
     """Regression test: a bias field must run on a tiny 1x2x3x4 tensor.

     See https://github.com/fepegar/torchio/issues/300
     """
     bias_field = tio.RandomBiasField()
     bias_field(torch.rand(1, 2, 3, 4))
Exemple #14
0
 def test_wrong_order_type(self):
     """A non-integer ``order`` argument must raise TypeError."""
     self.assertRaises(TypeError, tio.RandomBiasField, order='wrong')
Exemple #15
0
 def test_negative_order(self):
     """A negative ``order`` must raise ValueError."""
     self.assertRaises(ValueError, tio.RandomBiasField, order=-1)
                zt1=ScannerInfoT1[site],
                zt2=ScannerInfoT2[site],
                zpd=ScannerInfoPD[site],
            )
            subjects.append(subject)

    subjects = subjects[:max_subjects]
    dataset = tio.SubjectsDataset(subjects)
    print('Dataset size:', len(dataset), 'subjects')

    #%%
    normalization = tio.ZNormalization()

    spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75)

    bias = tio.RandomBiasField(coefficients=0.5, p=0.3)
    flip = tio.RandomFlip(axes=('LR', ), flip_probability=0.5)
    noise = tio.RandomNoise(std=0.1, p=0.25)

    transforms = [flip, spatial, normalization]

    training_transform = tio.Compose(transforms)
    validation_transform = tio.Compose([normalization])

    #%%
    seed = 42  # for reproducibility

    num_subjects = len(dataset)
    num_training_subjects = int(training_split_ratio * num_subjects)
    num_validation_subjects = num_subjects - num_training_subjects
def get_context(
    device,
    variables,
    fold=0,
    predict_hbt=False,
    training_batch_size=4,
):
    """Build the ``dmri-hippo`` training context (hippocampus segmentation).

    Inputs are the ``mean_dwi``, ``md`` and ``fa`` maps, concatenated into a
    3-channel image ``X``.

    Args:
        device: Forwarded to ``TorchContext``.
        variables: Forwarded to ``TorchContext``.
        fold: Cross-validation fold held out for validation. Subjects whose
            ``fold`` attribute equals this value form ``cbbrain_validation``;
            the other cross-validation subjects form ``training``.
        predict_hbt: When True, the target ``y`` is ``hbt_roi`` and the model
            has 4 output channels; otherwise the target is ``whole_roi`` and
            the model has 2 output channels.
        training_batch_size: Batch size passed to ``SegmentationTrainer``.

    Returns:
        The configured ``TorchContext``.
    """
    context = TorchContext(device, name="dmri-hippo", variables=variables)
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'fold': fold})

    input_images = ["mean_dwi", "md", "fa"]

    subject_loader = ComposeLoaders([
        ImageLoader(glob_pattern="mean_dwi.*",
                    image_name='mean_dwi',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="md.*",
                    image_name='md',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="fa.*",
                    image_name='fa',
                    image_constructor=tio.ScalarImage),
        # ImageLoader(glob_pattern="full_dwi.*", image_name='full_dwi', image_constructor=tio.ScalarImage),
        # TensorLoader(glob_pattern="full_dwi_grad.b", tensor_name="grad", belongs_to="full_dwi"),
        ImageLoader(glob_pattern="whole_roi.*",
                    image_name="whole_roi",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_whole": 1,
                        "right_whole": 2
                    }),
        ImageLoader(glob_pattern="whole_roi_alt.*",
                    image_name="whole_roi_alt",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_whole": 1,
                        "right_whole": 2
                    }),
        ImageLoader(glob_pattern="hbt_roi.*",
                    image_name="hbt_roi",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_head": 1,
                        "left_body": 2,
                        "left_tail": 3,
                        "right_head": 4,
                        "right_body": 5,
                        "right_tail": 6
                    }),
        ImageLoader(glob_pattern="../../atlas/whole_roi_union.*",
                    image_name="whole_roi_union",
                    image_constructor=tio.LabelMap,
                    uniform=True),
        AttributeLoader(glob_pattern='attributes.*'),
        AttributeLoader(
            glob_pattern='../../attributes/cross_validation_split.json',
            multi_subject=True,
            uniform=True),
        AttributeLoader(
            glob_pattern='../../attributes/ab300_validation_subjects.json',
            multi_subject=True,
            uniform=True),
        AttributeLoader(
            glob_pattern='../../attributes/cbbrain_test_subjects.json',
            multi_subject=True,
            uniform=True),
    ])

    cohorts = {}
    cohorts['all'] = RequireAttributes(input_images)
    cohorts['cross_validation'] = RequireAttributes(['fold'])
    # Cross-validation subjects outside / inside the selected fold.
    cohorts['training'] = ComposeFilters(
        [cohorts['cross_validation'],
         ForbidAttributes({"fold": fold})])
    cohorts['cbbrain_validation'] = ComposeFilters(
        [cohorts['cross_validation'],
         RequireAttributes({"fold": fold})])
    cohorts['cbbrain_test'] = RequireAttributes({'cbbrain_test': True})
    cohorts['ab300_validation'] = RequireAttributes({'ab300_validation': True})
    cohorts['ab300_validation_plot'] = ComposeFilters(
        [cohorts['ab300_validation'],
         RandomSelectFilter(num_subjects=20)])
    cohorts['cbbrain'] = RequireAttributes({"protocol": "cbbrain"})
    cohorts['ab300'] = RequireAttributes({"protocol": "ab300"})
    cohorts['rescans'] = ForbidAttributes({"rescan_id": "None"})
    cohorts['fasd'] = RequireAttributes({"pathologies": "FASD"})
    cohorts['inter_rater'] = RequireAttributes(["whole_roi_alt"])

    # Crop/pad to a fixed shape (masked by the atlas union), then remap the
    # right-side labels onto the left-side values, presumably inside the
    # "Right" region, so both hemispheres share one label set.
    common_transforms_1 = tio.Compose([
        tio.CropOrPad((96, 88, 24),
                      padding_mode='minimum',
                      mask_name='whole_roi_union'),
        CustomRemapLabels(remapping=[("right_whole", 2, 1)],
                          masking_method="Right",
                          include=["whole_roi"]),
        CustomRemapLabels(remapping=[("right_head", 4, 1),
                                     ("right_body", 5, 2),
                                     ("right_tail", 6, 3)],
                          masking_method="Right",
                          include=["hbt_roi"]),
    ])

    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    # `exclude` must be a sequence of image names: torchio rejects a bare
    # string for include/exclude.
    standard_augmentations = tio.Compose([
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.RandomElasticDeformation(p=0.5,
                                     num_control_points=(7, 7, 4),
                                     locked_borders=1,
                                     image_interpolation='bspline',
                                     exclude=["full_dwi"]),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        # Apply blur and noise in a random order.
        tio.OneOf([
            tio.Compose([blur, noise]),
            tio.Compose([noise, blur]),
        ])
    ],
                                         exclude=["full_dwi"])

    # Final intensity scaling, stacking of the three maps into "X", and
    # selection + one-hot encoding of the target "y".
    common_transforms_2 = tio.Compose([
        tio.RescaleIntensity((-1., 1.), (0.5, 99.5)),
        ConcatenateImages(image_names=["mean_dwi", "md", "fa"],
                          image_channels=[1, 1, 1],
                          new_image_name="X"),
        RenameProperty(old_name="hbt_roi" if predict_hbt else "whole_roi",
                       new_name="y"),
        CustomOneHot(include=["y"])
    ])

    transforms = {
        'default':
        tio.Compose([common_transforms_1, common_transforms_2]),
        'training':
        tio.Compose(
            [common_transforms_1, standard_augmentations,
             common_transforms_2]),
    }

    context.add_component("dataset",
                          SubjectFolder,
                          root='$DATASET_PATH',
                          subject_path="subjects",
                          subject_loader=subject_loader,
                          cohorts=cohorts,
                          transforms=transforms)
    context.add_component("model",
                          NestedResUNet,
                          input_channels=3,
                          output_channels=4 if predict_hbt else 2,
                          filters=40,
                          dropout_p=0.2)
    context.add_component("optimizer",
                          Adam,
                          params="self.model.parameters()",
                          lr=0.0002)
    context.add_component("criterion", HybridLogisticDiceLoss)

    training_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            'y_pred_eval', 'y_eval'),
                            log_name='training_segmentation_eval',
                            interval=10),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "Axial",
            'mean_dwi',
            'y_pred_eval',
            'y_eval',
            slice_id=12,
            legend=True,
            ncol=2,
            split_subjects=False),
                            log_name="contour_image_training",
                            interval=50),
    ]

    # Per-label coefficient arrays handed to LabelMapEvaluator together with
    # curve_attribute='age' (presumably polynomial fits of volume vs. age).
    curve_params = {
        "left_whole":
        np.array([-1.96312119e-01, 9.46668029e+00, 2.33635173e+03]),
        "right_whole":
        np.array([-2.68467331e-01, 1.67925603e+01, 2.07224236e+03])
    }

    validation_evaluators = [
        ScheduledEvaluation(evaluator=LabelMapEvaluator(
            'y_pred_eval',
            curve_params=curve_params,
            curve_attribute='age',
            stats_to_output=('volume', 'error', 'absolute_error',
                             'squared_error', 'percent_diff')),
                            log_name="predicted_label_eval",
                            cohorts=['cbbrain_validation', 'ab300_validation'],
                            interval=50),
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            "y_pred_eval", "y_eval"),
                            log_name="segmentation_eval",
                            cohorts=['cbbrain_validation'],
                            interval=50),
        ScheduledEvaluation(
            evaluator=ContourImageEvaluator("Axial",
                                            "mean_dwi",
                                            "y_pred_eval",
                                            "y_eval",
                                            slice_id=10,
                                            legend=True,
                                            ncol=5,
                                            split_subjects=False),
            log_name="contour_image_axial",
            cohorts=['cbbrain_validation', 'ab300_validation_plot'],
            interval=250),
        ScheduledEvaluation(
            evaluator=ContourImageEvaluator("Coronal",
                                            "mean_dwi",
                                            "y_pred_eval",
                                            "y_eval",
                                            slice_id=44,
                                            legend=True,
                                            ncol=2,
                                            split_subjects=False),
            log_name="contour_image_coronal",
            cohorts=['cbbrain_validation', 'ab300_validation_plot'],
            interval=250),
    ]

    def scoring_function(evaluation_dict):
        """Return the mean Dice over labels of the cbbrain validation cohort."""
        # Grab the output of the SegmentationEvaluator
        seg_eval_cbbrain = evaluation_dict['segmentation_eval'][
            'cbbrain_validation']["summary_stats"]

        # Get the mean dice for each label (the mean is across subjects)
        cbbrain_dice = seg_eval_cbbrain['mean', :, 'dice']

        # Now take the mean across all labels
        cbbrain_dice = cbbrain_dice.mean()
        score = cbbrain_dice
        return score

    train_predictor = StandardPredict(sagittal_split=True,
                                      image_names=['X', 'y'])
    validation_predictor = StandardPredict(sagittal_split=True,
                                           image_names=['X'])

    train_dataloader_factory = StandardDataLoader(sampler=RandomSampler)
    validation_dataloader_factory = StandardDataLoader(
        sampler=SequentialSampler)

    context.add_component(
        "trainer",
        SegmentationTrainer,
        training_batch_size=training_batch_size,
        save_rate=100,
        scoring_interval=50,
        scoring_function=scoring_function,
        one_time_evaluators=[],
        training_evaluators=training_evaluators,
        validation_evaluators=validation_evaluators,
        max_iterations_with_no_improvement=2000,
        train_predictor=train_predictor,
        validation_predictor=validation_predictor,
        train_dataloader_factory=train_dataloader_factory,
        validation_dataloader_factory=validation_dataloader_factory)

    return context