def get_train_transform(landmarks_path, resection_params=None):
    spatial_transform = tio.Compose((
        tio.OneOf({
            tio.RandomAffine(): 0.9,
            tio.RandomElasticDeformation(): 0.1,
        }),
        tio.RandomFlip(),
    ))
    resolution_transform = tio.OneOf(
        (
            tio.RandomAnisotropy(),
            tio.RandomBlur(),
        ),
        p=0.75,
    )
    transforms = []
    if resection_params is not None:
        transforms.append(get_simulation_transform(resection_params))
    if landmarks_path is not None:
        transforms.append(
            tio.HistogramStandardization({'image': landmarks_path}))
    transforms.extend([
        # tio.RandomGamma(p=0.2),
        resolution_transform,
        tio.RandomGhosting(p=0.2),
        tio.RandomSpike(p=0.2),
        tio.RandomMotion(p=0.2),
        tio.RandomBiasField(p=0.5),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.RandomNoise(p=0.75),  # always after ZNorm and after blur!
        spatial_transform,
        get_tight_crop(),
    ])
    return tio.Compose(transforms)
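# A minimal usage sketch for the training pipeline above. The subject name and
# file paths are illustrative placeholders; get_simulation_transform() and
# get_tight_crop() are assumed to be defined elsewhere in this module.
transform = get_train_transform('landmarks.npy')
subject = tio.Subject(image=tio.ScalarImage('t1.nii.gz'))
dataset = tio.SubjectsDataset([subject], transform=transform)
augmented = dataset[0]  # the transform is applied lazily, per sample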
def byol_aug(filename):
    """
    BYOL minimizes the distance between representations of each sample and
    a transformation of that sample. Examples of transformations include:
    translation, rotation, blurring, color inversion, color jitter,
    Gaussian noise.

    Returns an augmented image built from a composition of the
    transformations listed above. Used during training.
    """
    image = tio.ScalarImage(filename)
    get_foreground = tio.ZNormalization.mean
    training_transform = tio.Compose([
        tio.CropOrPad((180, 220, 170)),
        # zero mean, unit variance of foreground
        tio.ZNormalization(masking_method=get_foreground),
        tio.RandomBlur(p=0.25),  # blur 25% of times
        tio.RandomNoise(p=0.25),  # Gaussian noise 25% of times
        tio.OneOf({  # either
            tio.RandomAffine(): 0.8,  # random affine
            tio.RandomElasticDeformation(): 0.2,  # or random elastic deformation
        }, p=0.8),  # applied to 80% of images
        tio.RandomBiasField(p=0.3),  # magnetic field inhomogeneity 30% of times
        tio.OneOf({  # either
            tio.RandomMotion(): 1,  # random motion artifact
            tio.RandomSpike(): 2,  # or spikes
            tio.RandomGhosting(): 2,  # or ghosts
        }, p=0.5),  # applied to 50% of images
    ])
    tfs_image = training_transform(image)
    return tfs_image
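# BYOL compares two independently augmented views of the same sample, so
# byol_aug would typically be called twice per scan. A minimal sketch;
# 'scan.nii.gz' is a placeholder path.
view_1 = byol_aug('scan.nii.gz')  # first randomly augmented view
view_2 = byol_aug('scan.nii.gz')  # second, independently sampled view
# Each view is a tio.ScalarImage; .data is the (C, W, H, D) tensor fed to
# the online and target networks, respectively.
x1, x2 = view_1.data, view_2.data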
def test_transforms(self):
    landmarks_dict = dict(
        t1=np.linspace(0, 100, 13),
        t2=np.linspace(0, 100, 13),
    )
    elastic = torchio.RandomElasticDeformation(max_displacement=1)
    transforms = (
        torchio.CropOrPad((9, 21, 30)),
        torchio.ToCanonical(),
        torchio.Resample((1, 1.1, 1.25)),
        torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
        torchio.RandomMotion(),
        torchio.RandomGhosting(axes=(0, 1, 2)),
        torchio.RandomSpike(),
        torchio.RandomNoise(),
        torchio.RandomBlur(),
        torchio.RandomSwap(patch_size=2, num_iterations=5),
        torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
        torchio.RandomBiasField(),
        torchio.RescaleIntensity((0, 1)),
        torchio.ZNormalization(masking_method='label'),
        torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
        elastic,
        torchio.RandomAffine(),
        torchio.OneOf({
            torchio.RandomAffine(): 3,
            elastic: 1,
        }),
        torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
        torchio.Crop((3, 2, 8, 0, 1, 4)),
    )
    transform = torchio.Compose(transforms)
    transform(self.sample)
def getContrastAugs(p=0.75):
    aug_dict = {
        AdjustSigmoid(): 0.30,
        AdjustLog(): 0.30,
        AdjustGamma(): 0.30,
        AdaptiveHistogramEqualization(): 0.10,
    }
    return torchvision.transforms.RandomApply([tio.OneOf(aug_dict)], p=p)
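# The returned pipeline is an ordinary callable, so it can be applied directly
# to a TorchIO subject. Illustrative usage under the assumption that the
# contrast transforms above accept subjects; 't1.nii.gz' is a placeholder path.
contrast_augs = getContrastAugs(p=0.75)  # augments ~75% of calls
subject = tio.Subject(image=tio.ScalarImage('t1.nii.gz'))
augmented = contrast_augs(subject)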
def __getitem__(self, index):
    img_path = os.path.join(self.root_dir, self.df.iloc[index, 1])
    subject = tio.Subject(img=Image(img_path, type=tio.INTENSITY))
    if self.transform:  # in training phase
        transformations = (
            tio.ZNormalization(),
            tio.Resample(target=2, pre_affine_name="affine"),  # preprocessing
            tio.OneOf(transforms_dict),
            tio.OneOf(transforms_dict2),
        )
    else:  # in validation and testing phase
        transformations = (
            tio.ZNormalization(),
            tio.Resample(target=2, pre_affine_name="affine"),  # preprocessing
        )
    transformations = Compose(transformations)
    transformed_image = transformations(subject)
    get_image = transformed_image.img
    tensor_resampled_image = get_image.data
    tensor_resampled_image = tensor_resampled_image.unsqueeze(dim=0)  # add batch dimension
    resampled_image = torch.nn.functional.interpolate(
        input=tensor_resampled_image,
        size=(256, 256, 166),
        mode='trilinear',
    )  # trilinear because the input is 5D (mini-batch x channels x height x width x depth)
    resampled_image = np.reshape(resampled_image, (1, 256, 256, 166))
    y_label = 0.0 if self.df.iloc[index, 2] == 'AD' else 1.0
    y_label = torch.tensor(y_label, dtype=torch.float)
    return resampled_image, y_label
def get_transform(augmentation, landmarks_path):
    import datasets
    import torchio as tio
    if augmentation:
        return datasets.get_train_transform(landmarks_path)
    else:
        preprocess = datasets.get_test_transform(landmarks_path)
        augment = tio.Compose((
            tio.RandomFlip(),
            tio.OneOf({
                tio.RandomAffine(): 0.8,
                tio.RandomElasticDeformation(): 0.2,
            }),
        ))
        return tio.Compose((preprocess, augment))
def get_transform(self, channels, is_3d=True, labels=True):
    landmarks_dict = {
        channel: np.linspace(0, 100, 13) for channel in channels
    }
    disp = 1 if is_3d else (1, 1, 0.01)
    elastic = tio.RandomElasticDeformation(max_displacement=disp)
    cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
    resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
    flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
    swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
    pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
    crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
    remapping = {1: 2, 2: 1, 3: 20, 4: 25}
    transforms = [
        tio.CropOrPad(cp_args),
        tio.EnsureShapeMultiple(2, method='crop'),
        tio.Resize(resize_args),
        tio.ToCanonical(),
        tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
        tio.CopyAffine(channels[0]),
        tio.Resample((1, 1.1, 1.25)),
        tio.RandomFlip(axes=flip_axes, flip_probability=1),
        tio.RandomMotion(),
        tio.RandomGhosting(axes=(0, 1, 2)),
        tio.RandomSpike(),
        tio.RandomNoise(),
        tio.RandomBlur(),
        tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
        tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
        tio.RandomBiasField(),
        tio.RescaleIntensity(out_min_max=(0, 1)),
        tio.ZNormalization(),
        tio.HistogramStandardization(landmarks_dict),
        elastic,
        tio.RandomAffine(),
        tio.OneOf({
            tio.RandomAffine(): 3,
            elastic: 1,
        }),
        tio.RemapLabels(remapping=remapping, masking_method='Left'),
        tio.RemoveLabels([1, 3]),
        tio.SequentialLabels(),
        tio.Pad(pad_args, padding_mode=3),
        tio.Crop(crop_args),
    ]
    if labels:
        transforms.append(tio.RandomLabelsToImage(label_key='label'))
    return tio.Compose(transforms)
def get_context(device, variables, augmentation_mode, **kwargs):
    context = base_config.get_context(device, variables, **kwargs)
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'augmentation_mode': augmentation_mode})

    # training_transform is a tio.Compose where the second transform is the augmentation
    dataset_defn = context.get_component_definition("dataset")
    training_transform = dataset_defn['params']['transforms']['training']

    dwi_augmentation = ReconstructMeanDWI(num_dwis=(1, 7),
                                          num_directions=(1, 3),
                                          directionality=(4, 10))

    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    standard_augmentations = tio.Compose([
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.RandomElasticDeformation(p=0.5,
                                     num_control_points=(7, 7, 4),
                                     locked_borders=1,
                                     image_interpolation='bspline',
                                     exclude="full_dwi"),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        tio.OneOf([
            tio.Compose([blur, noise]),
            tio.Compose([noise, blur]),
        ]),
    ], exclude="full_dwi")

    if augmentation_mode == 'no_augmentation':
        training_transform.transforms.pop(1)
    elif augmentation_mode == 'standard':
        training_transform.transforms[1] = standard_augmentations
    elif augmentation_mode == 'dwi_reconstruction':
        training_transform.transforms[1] = dwi_augmentation
    elif augmentation_mode == 'combined':
        training_transform.transforms[1] = tio.Compose(
            [dwi_augmentation, standard_augmentations])
    else:
        raise ValueError(f"Invalid augmentation mode {augmentation_mode}")

    return context
def get_transform(self, channels, is_3d=True, labels=True):
    landmarks_dict = {
        channel: np.linspace(0, 100, 13) for channel in channels
    }
    disp = 1 if is_3d else (1, 1, 0.01)
    elastic = torchio.RandomElasticDeformation(max_displacement=disp)
    cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
    flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
    swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
    pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
    crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
    transforms = [
        torchio.CropOrPad(cp_args),
        torchio.ToCanonical(),
        torchio.RandomDownsample(downsampling=(1.75, 2),
                                 axes=axes_downsample),
        torchio.Resample((1, 1.1, 1.25)),
        torchio.RandomFlip(axes=flip_axes, flip_probability=1),
        torchio.RandomMotion(),
        torchio.RandomGhosting(axes=(0, 1, 2)),
        torchio.RandomSpike(),
        torchio.RandomNoise(),
        torchio.RandomBlur(),
        torchio.RandomSwap(patch_size=swap_patch, num_iterations=5),
        torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
        torchio.RandomBiasField(),
        torchio.RescaleIntensity((0, 1)),
        torchio.ZNormalization(),
        torchio.HistogramStandardization(landmarks_dict),
        elastic,
        torchio.RandomAffine(),
        torchio.OneOf({
            torchio.RandomAffine(): 3,
            elastic: 1,
        }),
        torchio.Pad(pad_args, padding_mode=3),
        torchio.Crop(crop_args),
    ]
    if labels:
        transforms.append(torchio.RandomLabelsToImage(label_key='label'))
    return torchio.Compose(transforms)
import os

import pandas as pd
import torch
import torchio as tio
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchio import Image
from torchvision.transforms import Compose

transforms_dict = {
    tio.RandomAffine(): 0.55,
    tio.RandomElasticDeformation(): 0.25,
}
transforms_dict2 = {tio.RandomBlur(): 0.25, tio.RandomMotion(): 0.25}

# for augmentation
transform_flip = tio.OneOf(transforms_dict)


class ADNIDataloaderAllData(Dataset):

    def __init__(self, df, root_dir, transform):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.df.iloc[index, 1])
        subject = tio.Subject(img=Image(img_path, type=tio.INTENSITY))
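# A minimal usage sketch for the dataset class above. The CSV path, column
# layout, and batch size are illustrative placeholders.
df = pd.read_csv('labels.csv')  # column 1: image path, column 2: diagnosis
dataset = ADNIDataloaderAllData(df=df,
                                root_dir='ADNI_images/',
                                transform=True)  # True -> training-phase augmentation
loader = DataLoader(dataset, batch_size=2, shuffle=True)
images, labels = next(iter(loader))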
def get_context(device, variables, fold=0, **kwargs):
    context = TorchContext(device, name="msseg2", variables=variables)
    context.file_paths.append(os.path.abspath(__file__))
    context.config = config = {'fold': fold, 'patch_size': 96}

    input_images = ["flair_time01", "flair_time02"]

    subject_loader = ComposeLoaders([
        ImageLoader(glob_pattern="flair_time01*", image_name='flair_time01',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="flair_time02*", image_name='flair_time02',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="brain_mask.*", image_name='brain_mask',
                    image_constructor=tio.LabelMap,
                    label_values={"brain": 1}),
        ImageLoader(glob_pattern="ground_truth.*", image_name="ground_truth",
                    image_constructor=tio.LabelMap,
                    label_values={"lesion": 1}),
    ])

    cohorts = {}
    cohorts['all'] = RequireAttributes(input_images)
    cohorts['validation'] = RandomFoldFilter(num_folds=5, selection=fold,
                                             seed=0xDEADBEEF)
    cohorts['training'] = NegateFilter(cohorts['validation'])

    common_transforms_1 = tio.Compose([
        SetDataType(torch.float),
        EnforceConsistentAffine(source_image_name='flair_time01'),
        TargetResample(target_spacing=1, tolerance=0.11),
        CropToMask('brain_mask'),
        MinSizePad(config['patch_size']),
    ])

    augmentations = tio.Compose([
        RandomPermuteDimensions(),
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.OneOf(
            {
                tio.RandomElasticDeformation(): 0.2,
                tio.RandomAffine(scales=0.2, degrees=45,
                                 default_pad_value='otsu'): 0.8,
            },
            p=0.75),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        tio.RandomBlur((0, 1), p=0.2),
        tio.RandomNoise(std=0.1, p=0.35),
    ])

    common_transforms_2 = tio.Compose([
        tio.RescaleIntensity((-1, 1.), (0.05, 99.5)),
        ConcatenateImages(image_names=["flair_time01", "flair_time02"],
                          image_channels=[1, 1],
                          new_image_name="X"),
        RenameProperty(old_name='ground_truth', new_name='y'),
        CustomOneHot(include="y"),
    ])

    transforms = {
        'default': tio.Compose([common_transforms_1, common_transforms_2]),
        'training': tio.Compose([
            common_transforms_1,
            augmentations,
            common_transforms_2,
            ImageFromLabels(new_image_name="patch_probability",
                            label_weights=[('brain_mask', 'brain', 1),
                                           ('y', 'lesion', 100)]),
        ]),
    }

    context.add_component("dataset", SubjectFolder,
                          root='$DATASET_PATH',
                          subject_path="",
                          subject_loader=subject_loader,
                          cohorts=cohorts,
                          transforms=transforms)
    context.add_component("model", ModularUNet,
                          in_channels=2,
                          out_channels=2,
                          filters=[40, 40, 80, 80, 120, 120],
                          depth=6,
                          block_params={'residual': True},
                          downsample_class=BlurConv3d,
                          downsample_params={'kernel_size': 3, 'stride': 2,
                                             'padding': 1},
                          upsample_class=BlurConvTranspose3d,
                          upsample_params={'kernel_size': 3, 'stride': 2,
                                           'padding': 1, 'output_padding': 0})
    context.add_component("optimizer", SGD,
                          params="self.model.parameters()",
                          lr=0.001,
                          momentum=0.95)
    context.add_component("criterion", HybridLogisticDiceLoss,
                          logistic_class_weights=[1, 100])

    training_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator('y_pred_eval',
                                                            'y_eval'),
                            log_name='training_segmentation_eval',
                            interval=15),
        ScheduledEvaluation(evaluator=LabelMapEvaluator('y_pred_eval'),
                            log_name='training_label_eval',
                            interval=15),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
                                "random", 'flair_time02',
                                'y_pred_eval', 'y_eval',
                                slice_id=0, legend=True, ncol=2,
                                interesting_slice=True,
                                split_subjects=False),
                            log_name="contour_image",
                            interval=15),
    ]

    validation_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator("y_pred_eval",
                                                            "y_eval"),
                            log_name="segmentation_eval",
                            cohorts=["validation"],
                            interval=50),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
                                "interesting", 'flair_time02',
                                'y_pred_eval', 'y_eval',
                                slice_id=0, legend=True, ncol=1,
                                interesting_slice=True,
                                split_subjects=True),
                            log_name="contour_image",
                            cohorts=["validation"],
                            interval=50),
    ]

    def scoring_function(evaluation_dict):
        # Grab the output of the SegmentationEvaluator
        seg_eval = evaluation_dict['segmentation_eval']['validation']

        # Take the mean Dice score while accounting for subjects that have no
        # lesions. Dice is 0/0 = nan when the model correctly outputs no
        # lesions; this is counted as a score of 1.0. Dice is (>0)/0 = posinf
        # when the model incorrectly predicts lesions when there are none;
        # this is counted as a score of 0.0.
        dice = torch.tensor(seg_eval["subject_stats"]['dice.lesion'])
        dice = dice.nan_to_num(nan=1.0, posinf=0.0)
        score = dice.mean()
        return score

    train_predictor = StandardPredict(image_names=['X', 'y'])
    validation_predictor = PatchPredict(patch_batch_size=32,
                                        patch_size=config['patch_size'],
                                        patch_overlap=(config['patch_size'] // 8),
                                        padding_mode=None,
                                        overlap_mode='average',
                                        image_names=['X'])

    patch_sampler = tio.WeightedSampler(patch_size=config['patch_size'],
                                        probability_map='patch_probability')
    train_dataloader_factory = PatchDataLoader(max_length=100,
                                               samples_per_volume=1,
                                               sampler=patch_sampler)
    validation_dataloader_factory = StandardDataLoader(sampler=SequentialSampler)

    context.add_component("trainer", SegmentationTrainer,
                          training_batch_size=4,
                          save_rate=100,
                          scoring_interval=50,
                          scoring_function=scoring_function,
                          one_time_evaluators=[],
                          training_evaluators=training_evaluators,
                          validation_evaluators=validation_evaluators,
                          max_iterations_with_no_improvement=2000,
                          train_predictor=train_predictor,
                          validation_predictor=validation_predictor,
                          train_dataloader_factory=train_dataloader_factory,
                          validation_dataloader_factory=validation_dataloader_factory)

    return context
training_transform = tio.Compose([
    tio.ToCanonical(),
    tio.RandomBlur(std=(0, 1), seed=seed, p=0.1),  # blur 10% of times
    tio.RandomNoise(std=5, seed=1, p=0.5),  # Gaussian noise 50% of times
    tio.OneOf(
        {  # either
            tio.RandomAffine(scales=(0.95, 1.05), degrees=5,
                             seed=seed): 0.75,  # random affine
            tio.RandomElasticDeformation(max_displacement=(5, 5, 5),
                                         seed=seed): 0.25,  # or random elastic deformation
        },
        p=0.8),  # applied to 80% of images
])

for one_subject in dataset:
    image0 = one_subject.mri
    plt.imshow(image0.data[0, int(image0.shape[1] / 2), :, :])
    plt.show()
    break

dataset_augmented = SubjectsDataset(subjects, transform=training_transform)

for one_subject in dataset_augmented:
    image = one_subject.mri
    plt.imshow(image.data[0, int(image.shape[1] / 2), :, :])
if data == 'dhcp':
    normalization = tio.ZNormalization(masking_method='label')
    spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75)
elif data == 'equinus':
    normalization = tio.ZNormalization(masking_method=tio.ZNormalization.mean)
    # spatial = tio.RandomAffine(scales=0.1, degrees=(20, 0, 0), translation=0, p=0.75)
    spatial = tio.OneOf(
        {
            tio.RandomAffine(scales=0.1, degrees=(20, 0, 0), translation=0): 0.8,
            tio.RandomElasticDeformation(): 0.2,
        },
        p=0.75,
    )

transforms = [flip, spatial, bias, normalization, noise, onehot]
training_transform = tio.Compose(transforms)
validation_transform = tio.Compose([normalization, onehot])

#%%
seed = 42  # for reproducibility
num_subjects = len(dataset)
num_training_subjects = int(training_split_ratio * num_subjects)
num_validation_subjects = num_subjects - num_training_subjects
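# The subject counts above would typically feed a seeded random split so the
# training/validation partition is reproducible. A sketch; the use of
# torch.utils.data.random_split is an assumption about the surrounding code.
generator = torch.Generator().manual_seed(seed)
training_subjects, validation_subjects = torch.utils.data.random_split(
    dataset,
    [num_training_subjects, num_validation_subjects],
    generator=generator,
)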
def get_context(
    device,
    variables,
    fold=0,
    predict_hbt=False,
    training_batch_size=4,
):
    context = TorchContext(device, name="dmri-hippo", variables=variables)
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'fold': fold})

    input_images = ["mean_dwi", "md", "fa"]
    output_labels = ["whole_roi", "hbt_roi"]

    subject_loader = ComposeLoaders([
        ImageLoader(glob_pattern="mean_dwi.*", image_name='mean_dwi',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="md.*", image_name='md',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="fa.*", image_name='fa',
                    image_constructor=tio.ScalarImage),
        # ImageLoader(glob_pattern="full_dwi.*", image_name='full_dwi',
        #             image_constructor=tio.ScalarImage),
        # TensorLoader(glob_pattern="full_dwi_grad.b", tensor_name="grad",
        #              belongs_to="full_dwi"),
        ImageLoader(glob_pattern="whole_roi.*", image_name="whole_roi",
                    image_constructor=tio.LabelMap,
                    label_values={"left_whole": 1, "right_whole": 2}),
        ImageLoader(glob_pattern="whole_roi_alt.*", image_name="whole_roi_alt",
                    image_constructor=tio.LabelMap,
                    label_values={"left_whole": 1, "right_whole": 2}),
        ImageLoader(glob_pattern="hbt_roi.*", image_name="hbt_roi",
                    image_constructor=tio.LabelMap,
                    label_values={"left_head": 1, "left_body": 2,
                                  "left_tail": 3, "right_head": 4,
                                  "right_body": 5, "right_tail": 6}),
        ImageLoader(glob_pattern="../../atlas/whole_roi_union.*",
                    image_name="whole_roi_union",
                    image_constructor=tio.LabelMap,
                    uniform=True),
        AttributeLoader(glob_pattern='attributes.*'),
        AttributeLoader(glob_pattern='../../attributes/cross_validation_split.json',
                        multi_subject=True, uniform=True),
        AttributeLoader(glob_pattern='../../attributes/ab300_validation_subjects.json',
                        multi_subject=True, uniform=True),
        AttributeLoader(glob_pattern='../../attributes/cbbrain_test_subjects.json',
                        multi_subject=True, uniform=True),
    ])

    cohorts = {}
    cohorts['all'] = RequireAttributes(input_images)
    cohorts['cross_validation'] = RequireAttributes(['fold'])
    cohorts['training'] = ComposeFilters(
        [cohorts['cross_validation'], ForbidAttributes({"fold": fold})])
    cohorts['cbbrain_validation'] = ComposeFilters(
        [cohorts['cross_validation'], RequireAttributes({"fold": fold})])
    cohorts['cbbrain_test'] = RequireAttributes({'cbbrain_test': True})
    cohorts['ab300_validation'] = RequireAttributes({'ab300_validation': True})
    cohorts['ab300_validation_plot'] = ComposeFilters(
        [cohorts['ab300_validation'], RandomSelectFilter(num_subjects=20)])
    cohorts['cbbrain'] = RequireAttributes({"protocol": "cbbrain"})
    cohorts['ab300'] = RequireAttributes({"protocol": "ab300"})
    cohorts['rescans'] = ForbidAttributes({"rescan_id": "None"})
    cohorts['fasd'] = RequireAttributes({"pathologies": "FASD"})
    cohorts['inter_rater'] = RequireAttributes(["whole_roi_alt"])

    common_transforms_1 = tio.Compose([
        tio.CropOrPad((96, 88, 24), padding_mode='minimum',
                      mask_name='whole_roi_union'),
        CustomRemapLabels(remapping=[("right_whole", 2, 1)],
                          masking_method="Right",
                          include=["whole_roi"]),
        CustomRemapLabels(remapping=[("right_head", 4, 1),
                                     ("right_body", 5, 2),
                                     ("right_tail", 6, 3)],
                          masking_method="Right",
                          include=["hbt_roi"]),
    ])

    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    standard_augmentations = tio.Compose([
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.RandomElasticDeformation(p=0.5,
                                     num_control_points=(7, 7, 4),
                                     locked_borders=1,
                                     image_interpolation='bspline',
                                     exclude=["full_dwi"]),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        tio.OneOf([
            tio.Compose([blur, noise]),
            tio.Compose([noise, blur]),
        ]),
    ], exclude="full_dwi")

    common_transforms_2 = tio.Compose([
        tio.RescaleIntensity((-1., 1.), (0.5, 99.5)),
        ConcatenateImages(image_names=["mean_dwi", "md", "fa"],
                          image_channels=[1, 1, 1],
                          new_image_name="X"),
        RenameProperty(old_name="hbt_roi" if predict_hbt else "whole_roi",
                       new_name="y"),
        CustomOneHot(include=["y"]),
    ])

    transforms = {
        'default': tio.Compose([common_transforms_1, common_transforms_2]),
        'training': tio.Compose([common_transforms_1,
                                 standard_augmentations,
                                 common_transforms_2]),
    }

    context.add_component("dataset", SubjectFolder,
                          root='$DATASET_PATH',
                          subject_path="subjects",
                          subject_loader=subject_loader,
                          cohorts=cohorts,
                          transforms=transforms)
    context.add_component("model", NestedResUNet,
                          input_channels=3,
                          output_channels=4 if predict_hbt else 2,
                          filters=40,
                          dropout_p=0.2)
    context.add_component("optimizer", Adam,
                          params="self.model.parameters()",
                          lr=0.0002)
    context.add_component("criterion", HybridLogisticDiceLoss)

    training_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator('y_pred_eval',
                                                            'y_eval'),
                            log_name='training_segmentation_eval',
                            interval=10),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
                                "Axial", 'mean_dwi', 'y_pred_eval', 'y_eval',
                                slice_id=12, legend=True, ncol=2,
                                split_subjects=False),
                            log_name="contour_image_training",
                            interval=50),
    ]

    curve_params = {
        "left_whole": np.array([-1.96312119e-01, 9.46668029e+00, 2.33635173e+03]),
        "right_whole": np.array([-2.68467331e-01, 1.67925603e+01, 2.07224236e+03]),
    }
    validation_evaluators = [
        ScheduledEvaluation(evaluator=LabelMapEvaluator(
                                'y_pred_eval',
                                curve_params=curve_params,
                                curve_attribute='age',
                                stats_to_output=('volume', 'error',
                                                 'absolute_error',
                                                 'squared_error',
                                                 'percent_diff')),
                            log_name="predicted_label_eval",
                            cohorts=['cbbrain_validation', 'ab300_validation'],
                            interval=50),
        ScheduledEvaluation(evaluator=SegmentationEvaluator("y_pred_eval",
                                                            "y_eval"),
                            log_name="segmentation_eval",
                            cohorts=['cbbrain_validation'],
                            interval=50),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
                                "Axial", "mean_dwi", "y_pred_eval", "y_eval",
                                slice_id=10, legend=True, ncol=5,
                                split_subjects=False),
                            log_name="contour_image_axial",
                            cohorts=['cbbrain_validation', 'ab300_validation_plot'],
                            interval=250),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
                                "Coronal", "mean_dwi", "y_pred_eval", "y_eval",
                                slice_id=44, legend=True, ncol=2,
                                split_subjects=False),
                            log_name="contour_image_coronal",
                            cohorts=['cbbrain_validation', 'ab300_validation_plot'],
                            interval=250),
    ]

    def scoring_function(evaluation_dict):
        # Grab the output of the SegmentationEvaluator
        seg_eval_cbbrain = evaluation_dict['segmentation_eval'][
            'cbbrain_validation']["summary_stats"]

        # Get the mean dice for each label (the mean is across subjects)
        cbbrain_dice = seg_eval_cbbrain['mean', :, 'dice']

        # Now take the mean across all labels
        cbbrain_dice = cbbrain_dice.mean()
        score = cbbrain_dice
        return score

    train_predictor = StandardPredict(sagittal_split=True,
                                      image_names=['X', 'y'])
    validation_predictor = StandardPredict(sagittal_split=True,
                                           image_names=['X'])

    train_dataloader_factory = StandardDataLoader(sampler=RandomSampler)
    validation_dataloader_factory = StandardDataLoader(sampler=SequentialSampler)

    context.add_component("trainer", SegmentationTrainer,
                          training_batch_size=training_batch_size,
                          save_rate=100,
                          scoring_interval=50,
                          scoring_function=scoring_function,
                          one_time_evaluators=[],
                          training_evaluators=training_evaluators,
                          validation_evaluators=validation_evaluators,
                          max_iterations_with_no_improvement=2000,
                          train_predictor=train_predictor,
                          validation_predictor=validation_predictor,
                          train_dataloader_factory=train_dataloader_factory,
                          validation_dataloader_factory=validation_dataloader_factory)

    return context