Ejemplo n.º 1
0
 def test_transforms(self):
     """Smoke test: apply a long chain of torchio transforms to the sample.

     Only checks that building and running the composed pipeline does not
     raise; the output itself is not inspected.
     """
     landmarks = {name: np.linspace(0, 100, 13) for name in ('t1', 't2')}
     # One deformation object, used both standalone and inside OneOf.
     deformation = torchio.RandomElasticDeformation(max_displacement=1)
     pipeline = torchio.Compose((
         torchio.CropOrPad((9, 21, 30)),
         torchio.ToCanonical(),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=2, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(masking_method='label'),
         torchio.HistogramStandardization(landmarks_dict=landmarks),
         deformation,
         torchio.RandomAffine(),
         torchio.OneOf({torchio.RandomAffine(): 3, deformation: 1}),
         torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
         torchio.Crop((3, 2, 8, 0, 1, 4)),
     ))
     pipeline(self.sample)
Ejemplo n.º 2
0
def get_context(device, variables, augmentation_mode, **kwargs):
    """Return the base context with the training augmentation slot configured.

    The training transform is a ``tio.Compose`` whose second entry
    (index 1) is the augmentation slot. ``augmentation_mode`` selects what
    ends up there:

    * ``'no_augmentation'``: the slot is removed;
    * ``'standard'``: flip, elastic deformation, bias field, gamma and a
      randomly ordered blur/noise pair;
    * ``'dwi_reconstruction'``: ``ReconstructMeanDWI``;
    * ``'combined'``: ``ReconstructMeanDWI`` followed by the standard set.

    Raises:
        ValueError: for any other ``augmentation_mode``.
    """
    context = base_config.get_context(device, variables, **kwargs)
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'augmentation_mode': augmentation_mode})

    dataset_definition = context.get_component_definition("dataset")
    training_pipeline = dataset_definition['params']['transforms']['training']

    dwi_augmentation = ReconstructMeanDWI(
        num_dwis=(1, 7),
        num_directions=(1, 3),
        directionality=(4, 10),
    )

    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    standard_steps = [
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.RandomElasticDeformation(
            p=0.5,
            num_control_points=(7, 7, 4),
            locked_borders=1,
            image_interpolation='bspline',
            exclude="full_dwi",
        ),
        tio.RandomBiasField(p=0.5),
        # Rescale to [0, 1] using the (0.01, 99.9) percentiles.
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        # Blur and noise, applied in one of the two possible orders.
        tio.OneOf([
            tio.Compose([blur, noise]),
            tio.Compose([noise, blur]),
        ]),
    ]
    standard_augmentations = tio.Compose(standard_steps, exclude="full_dwi")

    known_modes = ('no_augmentation', 'standard', 'dwi_reconstruction',
                   'combined')
    if augmentation_mode not in known_modes:
        raise ValueError(f"Invalid augmentation mode {augmentation_mode}")

    if augmentation_mode == 'no_augmentation':
        training_pipeline.transforms.pop(1)
        return context

    if augmentation_mode == 'standard':
        replacement = standard_augmentations
    elif augmentation_mode == 'dwi_reconstruction':
        replacement = dwi_augmentation
    else:  # 'combined'
        replacement = tio.Compose([dwi_augmentation, standard_augmentations])
    training_pipeline.transforms[1] = replacement
    return context
Ejemplo n.º 3
0
 def test_percentiles(self):
     """Voxels outside the 5th-95th percentile band saturate at 0 and 1."""
     source = self.sample_subject.t1.data
     low_cut, high_cut = (np.percentile(source, q) for q in (5, 95))
     below = (source < low_cut).nonzero(as_tuple=True)
     above = (source > high_cut).nonzero(as_tuple=True)
     rescale = tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(5, 95))
     result = rescale(self.sample_subject).t1.data
     assert (result[below] == 0).all()
     assert (result[above] == 1).all()
Ejemplo n.º 4
0
 def test_rescale_to_same_intentisy(self):
     """Rescaling to the image's own min/max range leaves the data unchanged."""
     data = self.sample_subject.t1.data
     own_range = (float(data.min()), float(data.max()))
     rescale = tio.RescaleIntensity(out_min_max=own_range)
     result = rescale(self.sample_subject)
     assert np.allclose(result.t1.data, self.sample_subject.t1.data,
                        rtol=0, atol=1e-05)
Ejemplo n.º 5
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Build a Compose that exercises a broad set of tio transforms.

     Args:
         channels: image names; each gets histogram-standardization
             landmarks, and the first is the CopyAffine reference.
         is_3d: choose 3D arguments, or 2D-compatible ones (singleton
             last dimension) when False.
         labels: append a RandomLabelsToImage driven by the 'label' map.
     """
     if is_3d:
         displacement = 1
         crop_or_pad_shape = (9, 21, 30)
         resize_shape = (10, 20, 30)
         spatial_axes = (0, 1, 2)
         swap_patch_size = (2, 3, 4)
         padding = (1, 2, 3, 0, 5, 6)
         cropping = (3, 2, 8, 0, 1, 4)
     else:
         displacement = (1, 1, 0.01)
         crop_or_pad_shape = (21, 30, 1)
         resize_shape = (10, 20, 1)
         spatial_axes = (0, 1)
         swap_patch_size = (3, 4, 1)
         padding = (0, 0, 3, 0, 5, 6)
         cropping = (0, 0, 8, 0, 1, 4)
     landmarks = {name: np.linspace(0, 100, 13) for name in channels}
     # Used twice: once on its own and once inside OneOf.
     deformation = tio.RandomElasticDeformation(max_displacement=displacement)
     steps = [
         tio.CropOrPad(crop_or_pad_shape),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(resize_shape),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=spatial_axes),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=spatial_axes, flip_probability=1),
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch_size, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks),
         deformation,
         tio.RandomAffine(),
         tio.OneOf({tio.RandomAffine(): 3, deformation: 1}),
         tio.RemapLabels(remapping={1: 2, 2: 1, 3: 20, 4: 25},
                         masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         tio.Pad(padding, padding_mode=3),
         tio.Crop(cropping),
     ]
     extra = [tio.RandomLabelsToImage(label_key='label')] if labels else []
     return tio.Compose(steps + extra)
Ejemplo n.º 6
0
 def test_masking_using_label(self):
     """Percentiles taken inside the label mask drive saturation at 0 and 1."""
     rescale = tio.RescaleIntensity(
         out_min_max=(0, 1), percentiles=(5, 95), masking_method='label')
     result = rescale(self.sample_subject).t1.data
     source = self.sample_subject.t1.data
     inside = self.sample_subject.label.data > 0
     masked = source[inside]
     low_cut = np.percentile(masked, 5)
     high_cut = np.percentile(masked, 95)
     below = (source < low_cut).nonzero(as_tuple=True)
     above = (source > high_cut).nonzero(as_tuple=True)
     self.assertEqual(result.min(), 0)
     self.assertEqual(result.max(), 1)
     assert (result[below] == 0).all()
     assert (result[above] == 1).all()
Ejemplo n.º 7
0
 def test_ct(self):
     """Values outside in_min_max map beyond out_min_max (no clipping)."""
     lowest_hu, highest_hu = -2000, 1500
     volume = (torch.rand(1, 30, 30, 30) * (highest_hu - lowest_hu)
               + lowest_hu)
     image = tio.ScalarImage(tensor=volume)
     air_hu, bone_hu = -1000, 1000
     transform = tio.RescaleIntensity(
         out_min_max=(-1, 1),
         in_min_max=(air_hu, bone_hu),
     )
     output = transform(image).data
     assert output.min() < -1
     assert output.max() > 1
Ejemplo n.º 8
0
    def test_get_subjects(self):
        """get_subjects discovers exactly one subject in a freshly written tree.

        The fixture tree is cleared before the test and removed in a
        ``finally`` block. A failing save or assertion therefore no longer
        leaves files behind that would change the subject count of the next
        run.
        """
        ct = tio.ScalarImage(tensor=torch.rand(1, 3, 3, 2))
        structure = tio.LabelMap(
            tensor=torch.ones((1, 3, 3, 2), dtype=torch.uint8))

        subject_1_dir = "tests/test_data/subjects/subject_1"
        subjects_dir = os.path.dirname(subject_1_dir)

        # Start from a clean tree in case an earlier run left files behind.
        shutil.rmtree(subjects_dir, ignore_errors=True)
        os.makedirs(subject_1_dir, exist_ok=True)
        try:
            ct.save(os.path.join(subject_1_dir, "ct.nii"))
            structure.save(os.path.join(subject_1_dir, "structure.nii"))
            transform = tio.Compose(
                [tio.ToCanonical(),
                 tio.RescaleIntensity(1, (1, 99.0))])
            subject_dataset = get_subjects(subjects_dir,
                                           structures=["structure"],
                                           transform=transform)
            self.assertEqual(len(subject_dataset), 1)
        finally:
            shutil.rmtree(subjects_dir, ignore_errors=True)
Ejemplo n.º 9
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Build a Compose that exercises a broad set of torchio transforms.

     Args:
         channels: image names that get histogram-standardization landmarks.
         is_3d: choose 3D arguments, or 2D-compatible ones (singleton
             last dimension) when False.
         labels: append a RandomLabelsToImage driven by the 'label' map.
     """
     if is_3d:
         displacement = 1
         crop_or_pad_shape = (9, 21, 30)
         spatial_axes = (0, 1, 2)
         swap_patch_size = (2, 3, 4)
         padding = (1, 2, 3, 0, 5, 6)
         cropping = (3, 2, 8, 0, 1, 4)
     else:
         displacement = (1, 1, 0.01)
         crop_or_pad_shape = (21, 30, 1)
         spatial_axes = (0, 1)
         swap_patch_size = (3, 4, 1)
         padding = (0, 0, 3, 0, 5, 6)
         cropping = (0, 0, 8, 0, 1, 4)
     landmarks = {name: np.linspace(0, 100, 13) for name in channels}
     # Used twice: once on its own and once inside OneOf.
     deformation = torchio.RandomElasticDeformation(
         max_displacement=displacement)
     steps = [
         torchio.CropOrPad(crop_or_pad_shape),
         torchio.ToCanonical(),
         torchio.RandomDownsample(downsampling=(1.75, 2), axes=spatial_axes),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=spatial_axes, flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=swap_patch_size, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(),
         torchio.HistogramStandardization(landmarks),
         deformation,
         torchio.RandomAffine(),
         torchio.OneOf({torchio.RandomAffine(): 3, deformation: 1}),
         torchio.Pad(padding, padding_mode=3),
         torchio.Crop(cropping),
     ]
     extra = [torchio.RandomLabelsToImage(label_key='label')] if labels else []
     return torchio.Compose(steps + extra)
Ejemplo n.º 10
0
# Shape of one stored sample: a leading axis of 1 followed by the
# (1, 7, 1024, 1024) volume built below. Change if the inputs differ.
_SAMPLE_SHAPE = (1, 1, 7, 1024, 1024)


def _to_sample(volume):
    """Return a NEW float32 array of ``_SAMPLE_SHAPE`` holding *volume* at [0].

    A fresh array is allocated on every call. Re-filling one shared buffer
    and appending it repeatedly made every stored copy alias the last
    value written into it.
    """
    sample = np.zeros(shape=_SAMPLE_SHAPE, dtype=np.float32)
    sample[0] = volume
    return sample


def dataloader(handles, mode = 'train'):
    """Load, or build and cache, the image/label/gap arrays for *mode*.

    The result is cached at ``../inputs/flpickles/<mode>.pickle``. When that
    file can be loaded, its content is returned directly. Otherwise every
    row of *handles* is read, optionally augmented, and the result is
    pickled to that path.

    Args:
        handles: object with ``iterrows()`` whose rows provide ``'Image'``
            and ``'ID'``, plus ``'Label'`` and ``'Gap'`` paths unless
            ``mode == 'test'``.
        mode: ``'train'`` stores the original plus 12 augmented copies.
            ``'test'`` stores images only. Any other mode stores the
            original image, label and gap without augmentation.

    Returns:
        dict with keys ``'Image'``, ``'Label'``, ``'Gap'``, ``'ID'``. Each
        entry of the first three is an array stacking the original sample
        followed by its augmented copies, in the same order for all three.
    """
    pickle_path = '../inputs/flpickles/' + mode + '.pickle'

    # If a cached pickle exists, load it.
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at a trusted local cache.
    try:
        with open(pickle_path, 'rb') as f:
            return pickle.load(f)
    except Exception:
        # Missing or unreadable cache: rebuild it below. This was a bare
        # ``except``, which also swallowed KeyboardInterrupt/SystemExit.
        pass

    images = {'Image': [], 'Label': [], 'Gap': [], 'ID': []}

    # Data augmentations
    random_flip = tio.RandomFlip(axes=1)
    random_flip2 = tio.RandomFlip(axes=2)
    random_affine = tio.RandomAffine(seed=0, scales=(3, 3))
    random_elastic = tio.RandomElasticDeformation(
        max_displacement=(0, 20, 40),
        num_control_points=20,
        seed=0,
    )
    rescale = tio.RescaleIntensity((-1, 1), percentiles=(1, 99))
    standardize_foreground = tio.ZNormalization(masking_method=lambda x: x > x.mean())
    blur = tio.RandomBlur(seed=0)
    standardize = tio.ZNormalization()
    add_noise = tio.RandomNoise(std=0.5, seed=42)
    add_spike = tio.RandomSpike(seed=42)
    add_ghosts = tio.RandomGhosting(intensity=1.5, seed=42)
    add_motion = tio.RandomMotion(num_transforms=6, image_interpolation='nearest', seed=42)
    swap = tio.RandomSwap(patch_size = 7)

    # The flips are unseeded. They are applied ONCE to image, label and gap
    # stacked along the channel axis, so all three are flipped identically
    # (separate calls could flip the image but not its label).
    flips = (random_flip, random_flip2)
    # These are built with seed=0. In the torchio versions that accept
    # ``seed``, each call reuses that seed, so separate calls on image,
    # label and gap draw the same parameters.
    # TODO(review): confirm for the installed torchio version.
    seeded_spatial = (random_affine, random_elastic)
    # Image-only augmentations: the label and gap are stored unchanged.
    image_only = (
        rescale,
        standardize_foreground,
        blur,
        lambda x: add_noise(standardize(x)),
        add_spike,
        add_ghosts,
        add_motion,
        swap,
    )

    # For each image
    for _, row in handles.iterrows():
        images['ID'].append(row['ID'])
        im = io.imread(row['Image'])
        im = im / 255  # Normalization
        im = np.expand_dims(im, axis=0)
        im_aug = [_to_sample(im)]

        if mode == 'test':
            images['Image'].append(np.array(im_aug))
            continue

        lb = np.expand_dims(label_converter(io.imread(row['Label'])), axis=0)
        gap = np.expand_dims(io.imread(row['Gap']), axis=0)
        lb_aug = [_to_sample(lb)]
        gap_aug = [_to_sample(gap)]

        if mode == 'train':
            # Assumes image, label and gap share the same spatial shape.
            stacked = np.concatenate([im, lb, gap], axis=0)
            for flip in flips:
                flipped = flip(stacked)
                im_aug.append(_to_sample(flipped[0:1]))
                lb_aug.append(_to_sample(flipped[1:2]))
                gap_aug.append(_to_sample(flipped[2:3]))
            for spatial in seeded_spatial:
                im_aug.append(_to_sample(spatial(im)))
                lb_aug.append(_to_sample(spatial(lb)))
                gap_aug.append(_to_sample(spatial(gap)))
            for augment in image_only:
                im_aug.append(_to_sample(augment(im)))
                lb_aug.append(_to_sample(lb))
                gap_aug.append(_to_sample(gap))

        images['Image'].append(np.array(im_aug))
        images['Label'].append(np.array(lb_aug))
        images['Gap'].append(np.array(gap_aug))

    # Save images for the next call. The in-memory dict is returned
    # directly; re-loading the file just written is unnecessary.
    with open(pickle_path, 'wb') as f:
        pickle.dump(images, f)
    return images
def patch_sampler(img_filenames,
                  labelmap_filenames,
                  patch_size,
                  sampler_type,
                  out_dir,
                  max_patches=None,
                  voxel_spacing=(),
                  patch_overlap=(0, 0, 0),
                  min_labeled_voxels=1.0,
                  label_prob=0.8,
                  save_patches=False,
                  batch_size=None,
                  prepare_batches=False,
                  inference=False):
    """Extract image/label patches from 3D volumes and their label maps.

    Each image is rescaled to [0, 1], and optionally resampled to
    ``voxel_spacing``, before patches are sampled from it.

    Parameters
    ----------
    img_filenames : list of str
        Paths to the images to extract patches from.
    labelmap_filenames : list of str
        Paths to the label maps, one per image (required).
    patch_size : tuple of ints (patch_x, patch_y, patch_z)
        Dimensions of one patch. ``patch_x == 1`` yields 2D patches.
    sampler_type : str
        ``'grid'`` uses ``tio.data.GridSampler``. Any other value uses
        ``tio.data.LabelSampler`` weighted by ``label_prob``.
    out_dir : str
        Forwarded to ``save`` when ``save_patches`` is set.
    max_patches : int, optional
        Total patch budget, split evenly across the images.
    voxel_spacing : tuple, optional
        If non-empty, ``util.update_affine`` is called and the subject is
        resampled to this spacing.
    patch_overlap : tuple of ints
        Overlap between neighbouring grid patches (grid sampler only).
    min_labeled_voxels : float between 0 and 1, or None
        Grid sampler only: minimum fraction of non-zero label voxels for a
        patch to be kept. A patch whose centre voxel is labelled is always
        kept. With ``None``, only the centre-voxel criterion is used.
    label_prob : float
        LabelSampler probability for label 1; label 0 gets the remainder.
    save_patches, batch_size, prepare_batches, inference
        Control on-the-fly saving through ``save``. The ``save_size``
        passed to it is ``batch_size`` when ``prepare_batches`` is set,
        otherwise 1.

    Returns
    -------
    img_ids, label_ids
        When ``save_patches`` is set, as accumulated through ``save``.
    img_patches, label_patches : arrays
        Otherwise, both of shape ``(n_patches, patch_x, patch_y, patch_z, 1)``,
        or ``(n_patches, patch_y, patch_z, 1)`` when ``patch_x == 1``.

    Raises
    ------
    ValueError
        If there are images but ``labelmap_filenames`` is empty.
    """
    if max_patches is not None and img_filenames:
        # Split the overall patch budget evenly across the images.
        max_patches = int(max_patches / len(img_filenames))
    img_patches = []
    label_patches = []
    patch_counter = 0
    save_counter = 0
    img_ids = []
    label_ids = []
    save_size = batch_size if prepare_batches else 1

    # Grid-sampler acceptance criteria, computed once for all patches.
    if min_labeled_voxels is None:
        min_labeled_count = None  # rely on the centre voxel only
    else:
        min_labeled_count = (patch_size[0] * patch_size[1] * patch_size[2] *
                             min_labeled_voxels)
    center_index = (0, int(patch_size[0] / 2), int(patch_size[1] / 2),
                    int(patch_size[2] / 2))

    print(f'\nExtracting patches from: {img_filenames}\n')
    for i in tqdm(range(len(img_filenames)), leave=False):
        if not labelmap_filenames:
            # Both samplers read ``patch.labelmap``. Without label maps this
            # used to fail later with an unbound ``subject``.
            raise ValueError('patch_sampler requires a labelmap for every image')
        if voxel_spacing:
            util.update_affine(img_filenames[i], labelmap_filenames[i])
        subject = tio.Subject(img=tio.Image(img_filenames[i],
                                            type=tio.INTENSITY),
                              labelmap=tio.LabelMap(labelmap_filenames[i]))
        # Apply transformations
        transformed = tio.RescaleIntensity((0, 1))(subject)
        if voxel_spacing:
            transformed = tio.Resample(voxel_spacing)(transformed)
        num_img_patches = 0
        if sampler_type == 'grid':
            sampler = tio.data.GridSampler(transformed, patch_size,
                                           patch_overlap)
            for patch in sampler:
                img_patch = np.array(patch.img.data)
                label_patch = np.array(patch.labelmap.data)
                enough_labels = (
                    min_labeled_count is not None and
                    torch.count_nonzero(patch.labelmap.data) >= min_labeled_count)
                center_labeled = label_patch[center_index] != 0
                if enough_labels or center_labeled:
                    img_patches.append(img_patch)
                    label_patches.append(label_patch)
                    patch_counter += 1
                    num_img_patches += 1
                if save_patches:
                    (img_patches, label_patches, img_ids, label_ids,
                     save_counter, patch_counter) = save(
                         img_patches, label_patches, img_ids, label_ids,
                         save_counter, patch_counter, save_size, patch_size,
                         inference, out_dir)
                # Stop once this image's budget is used. ``>=`` because
                # ``>`` kept one patch more than max_patches.
                if max_patches is not None and num_img_patches >= max_patches:
                    break
        else:
            # Define sampler
            label_probabilities = {0: 1.0 - label_prob, 1: label_prob}
            sampler = tio.data.LabelSampler(
                patch_size, label_probabilities=label_probabilities)
            if max_patches is None:
                generator = sampler(transformed)
            else:
                generator = sampler(transformed, max_patches)
            for patch in generator:
                img_patches.append(np.array(patch.img.data))
                label_patches.append(np.array(patch.labelmap.data))
                patch_counter += 1
                if save_patches:
                    (img_patches, label_patches, img_ids, label_ids,
                     save_counter, patch_counter) = save(
                         img_patches, label_patches, img_ids, label_ids,
                         save_counter, patch_counter, save_size, patch_size,
                         inference, out_dir)
    print('Finished extracting patches.')
    if save_patches:
        return img_ids, label_ids
    # patch_size[0] == 1 means 2D patches: that singleton axis is dropped.
    if patch_size[0] == 1:
        out_shape = (patch_size[1], patch_size[2], 1)
    else:
        out_shape = (patch_size[0], patch_size[1], patch_size[2], 1)
    # Image and label patches share one shape. The label reshape used to
    # omit patch_size[0], which fails for 3D patches.
    return (np.array(img_patches).reshape(len(img_patches), *out_shape),
            np.array(label_patches).reshape(len(label_patches), *out_shape))
Ejemplo n.º 12
0
def create_train_and_test_data_loaders(df, count_train):
    """Build the training and validation data loaders from a QA dataframe.

    Rows whose ``exists`` attribute is falsy are skipped; rows without an
    ``exists`` column are assumed to exist. The first ``count_train`` kept
    rows form the training set, and every remaining kept row forms the
    validation set.

    Each subject's ``info`` is ``[overall_qa_assessment, *artifact values]``.
    Artifact values are read from the columns named in the module-level
    ``artifacts`` and pass through ``convert_bool_to_int``.

    Args:
        df: dataframe with ``file_path``, ``overall_qa_assessment`` and the
            ``artifacts`` columns; ``exists`` and ``dimensions`` are optional.
        count_train: number of kept rows used for training.

    Returns:
        ``(train_loader, val_loader, class_weights, sizes)``:

        * ``class_weights[i]`` is ``count0 / (count0 + count1)`` over the
          training targets at column ``i + regression_count``. Values other
          than 0/1 are ignored.
        * ``sizes`` counts kept rows per ``dimensions`` value; it is empty
          if there is no such column.
    """
    images = []
    regression_targets = []
    sizes = {}
    # itertuples() puts the index first, hence the +1 column offset.
    artifact_column_indices = [
        df.columns.get_loc(c) + 1 for c in artifacts if c in df
    ]
    for row in df.itertuples():
        # Assume a row exists unless an ``exists`` column says otherwise.
        if not getattr(row, 'exists', True):
            continue
        images.append(row.file_path)

        row_targets = [row.overall_qa_assessment]
        for i in range(len(artifacts)):
            artifact_value = row[artifact_column_indices[i]]
            row_targets.append(convert_bool_to_int(artifact_value))
        regression_targets.append(row_targets)

        try:
            size = row.dimensions
        except AttributeError:
            pass  # no ``dimensions`` column
        else:
            sizes[size] = sizes.get(size, 0) + 1

    ground_truth = np.asarray(regression_targets)
    train_files = [
        torchio.Subject({
            'img': torchio.ScalarImage(img),
            'info': info
        })
        for img, info in zip(images[:count_train], ground_truth[:count_train])
    ]
    # Validation is every kept row after the training split. The previous
    # ``images[-count_val:]`` with ``count_val = df.shape[0] - count_train``
    # counted skipped rows too. When count_val was 0, ``images[-0:]``
    # returned the WHOLE dataset.
    val_files = [
        torchio.Subject({
            'img': torchio.ScalarImage(img),
            'info': info
        })
        for img, info in zip(images[count_train:], ground_truth[count_train:])
    ]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # calculate class weights
    class_count = len(artifacts)
    count0 = [0] * class_count
    count1 = [0] * class_count
    for s in range(count_train):
        for i in range(class_count):
            if regression_targets[s][i + regression_count] == 0:
                count0[i] += 1
            elif regression_targets[s][i + regression_count] == 1:
                count1[i] += 1
            # else ignore the missing data

    weights_array = np.zeros(class_count)
    for i in range(class_count):
        weights_array[i] = count0[i] / (count0[i] + count1[i])
    logger.info(f'weights_array: {weights_array}')
    class_weights = torch.tensor(weights_array, dtype=torch.float).to(device)

    rescale = torchio.RescaleIntensity(out_min_max=(0, 1))
    ghosting = CustomGhosting(p=0.2, intensity=(0.2, 0.8))
    motion = CustomMotion(p=0.15,
                          degrees=5.0,
                          translation=5.0,
                          num_transforms=1)
    inhomogeneity = CustomBiasField(p=0.05)
    spike = CustomSpike(p=0.03, num_spikes=(1, 1))
    # gamma = CustomGamma(p=0.1)  # after quick experimentation: gamma does not appear to help
    noise = CustomNoise(p=0.05)

    transforms = torchio.Compose(
        [rescale, ghosting, motion, inhomogeneity, spike, noise])

    # create a training data loader
    train_ds = torchio.SubjectsDataset(train_files, transform=transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader (rescaling only, no augmentation)
    val_ds = torchio.SubjectsDataset(val_files, transform=rescale)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    return train_loader, val_loader, class_weights, sizes
Ejemplo n.º 13
0
import torchio as tio
import matplotlib.pyplot as plt

# ``pprint`` and ``torch`` are used below but were never imported.
import pprint

import torch

# Seed the global RNG so the random transforms below are reproducible.
torch.manual_seed(0)

batch_size = 4
# FPG sample subject, without its segmentation image.
subject = tio.datasets.FPG()
subject.remove_image('seg')
# Four references to the same subject object.
subjects = 4 * [subject]

transform = tio.Compose((
    tio.ToCanonical(),
    tio.RandomGamma(p=0.75),
    tio.RandomBlur(p=0.5),
    tio.RandomFlip(),
    tio.RescaleIntensity(out_min_max=(-1, 1)),
))

dataset = tio.SubjectsDataset(subjects, transform=transform)

transformed = dataset[0]
print('Applied transforms:')  # noqa: T001
pprint.pprint(transformed.history)  # noqa: T003
print('\nComposed transform to reproduce history:')  # noqa: T001
print(transformed.get_composed_history())  # noqa: T001
print('\nComposed transform to invert applied transforms when possible:')  # noqa: T001, E501
print(transformed.get_inverse_transform(ignore_intensity=False))  # noqa: T001

loader = torch.utils.data.DataLoader(
    dataset,
Ejemplo n.º 14
0
def get_context(
    device,
    variables,
    fold=0,
    predict_hbt=False,
    training_batch_size=4,
):
    """Build the training context for dMRI hippocampus segmentation.

    Registers the ``dataset``, ``model``, ``optimizer``, ``criterion`` and
    ``trainer`` components on a ``TorchContext`` named ``"dmri-hippo"``.

    Args:
        device: Device handed to ``TorchContext``.
        variables: Variables handed to ``TorchContext``. The dataset root is
            given as ``'$DATASET_PATH'``, presumably resolved from these
            (TODO confirm).
        fold: Cross-validation fold used for validation. Subjects whose
            ``fold`` attribute equals this value form ``cbbrain_validation``
            and are excluded from ``training``.
        predict_hbt: If True the target ``y`` is ``hbt_roi`` and the model
            has 4 output channels; otherwise ``y`` is ``whole_roi`` and the
            model has 2 output channels.
        training_batch_size: Batch size passed to the trainer.

    Returns:
        The populated ``TorchContext``.
    """
    context = TorchContext(device, name="dmri-hippo", variables=variables)
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'fold': fold})

    # Scalar images concatenated into the network input ``X`` further below.
    input_images = ["mean_dwi", "md", "fa"]

    subject_loader = ComposeLoaders([
        ImageLoader(glob_pattern="mean_dwi.*",
                    image_name='mean_dwi',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="md.*",
                    image_name='md',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="fa.*",
                    image_name='fa',
                    image_constructor=tio.ScalarImage),
        # ImageLoader(glob_pattern="full_dwi.*", image_name='full_dwi', image_constructor=tio.ScalarImage),
        # TensorLoader(glob_pattern="full_dwi_grad.b", tensor_name="grad", belongs_to="full_dwi"),
        ImageLoader(glob_pattern="whole_roi.*",
                    image_name="whole_roi",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_whole": 1,
                        "right_whole": 2
                    }),
        ImageLoader(glob_pattern="whole_roi_alt.*",
                    image_name="whole_roi_alt",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_whole": 1,
                        "right_whole": 2
                    }),
        ImageLoader(glob_pattern="hbt_roi.*",
                    image_name="hbt_roi",
                    image_constructor=tio.LabelMap,
                    label_values={
                        "left_head": 1,
                        "left_body": 2,
                        "left_tail": 3,
                        "right_head": 4,
                        "right_body": 5,
                        "right_tail": 6
                    }),
        ImageLoader(glob_pattern="../../atlas/whole_roi_union.*",
                    image_name="whole_roi_union",
                    image_constructor=tio.LabelMap,
                    uniform=True),
        AttributeLoader(glob_pattern='attributes.*'),
        AttributeLoader(
            glob_pattern='../../attributes/cross_validation_split.json',
            multi_subject=True,
            uniform=True),
        AttributeLoader(
            glob_pattern='../../attributes/ab300_validation_subjects.json',
            multi_subject=True,
            uniform=True),
        AttributeLoader(
            glob_pattern='../../attributes/cbbrain_test_subjects.json',
            multi_subject=True,
            uniform=True),
    ])

    # Named subject subsets, built from attribute filters.
    cohorts = {}
    cohorts['all'] = RequireAttributes(input_images)
    cohorts['cross_validation'] = RequireAttributes(['fold'])
    # Training: cross-validation subjects NOT in the selected fold.
    cohorts['training'] = ComposeFilters(
        [cohorts['cross_validation'],
         ForbidAttributes({"fold": fold})])
    # Validation: cross-validation subjects in the selected fold.
    cohorts['cbbrain_validation'] = ComposeFilters(
        [cohorts['cross_validation'],
         RequireAttributes({"fold": fold})])
    cohorts['cbbrain_test'] = RequireAttributes({'cbbrain_test': True})
    cohorts['ab300_validation'] = RequireAttributes({'ab300_validation': True})
    cohorts['ab300_validation_plot'] = ComposeFilters(
        [cohorts['ab300_validation'],
         RandomSelectFilter(num_subjects=20)])
    cohorts['cbbrain'] = RequireAttributes({"protocol": "cbbrain"})
    cohorts['ab300'] = RequireAttributes({"protocol": "ab300"})
    cohorts['rescans'] = ForbidAttributes({"rescan_id": "None"})
    cohorts['fasd'] = RequireAttributes({"pathologies": "FASD"})
    cohorts['inter_rater'] = RequireAttributes(["whole_roi_alt"])

    # Crop/pad around the atlas union mask, then remap right-hemisphere
    # label values onto the left-hemisphere values (presumably so both
    # sagittal halves share one label set; TODO confirm CustomRemapLabels).
    common_transforms_1 = tio.Compose([
        tio.CropOrPad((96, 88, 24),
                      padding_mode='minimum',
                      mask_name='whole_roi_union'),
        CustomRemapLabels(remapping=[("right_whole", 2, 1)],
                          masking_method="Right",
                          include=["whole_roi"]),
        CustomRemapLabels(remapping=[("right_head", 4, 1),
                                     ("right_body", 5, 2),
                                     ("right_tail", 6, 3)],
                          masking_method="Right",
                          include=["hbt_roi"]),
    ])

    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    # Training-only augmentations. ``exclude`` is a sequence of image names,
    # matching the ``exclude=["full_dwi"]`` used for the elastic deformation.
    standard_augmentations = tio.Compose([
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.RandomElasticDeformation(p=0.5,
                                     num_control_points=(7, 7, 4),
                                     locked_borders=1,
                                     image_interpolation='bspline',
                                     exclude=["full_dwi"]),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        # Apply blur and noise in one of the two possible orders.
        tio.OneOf([
            tio.Compose([blur, noise]),
            tio.Compose([noise, blur]),
        ]),
    ], exclude=["full_dwi"])

    # Final normalization, stacking of the inputs into ``X`` and one-hot
    # encoding of the target ``y``.
    common_transforms_2 = tio.Compose([
        tio.RescaleIntensity((-1., 1.), (0.5, 99.5)),
        ConcatenateImages(image_names=["mean_dwi", "md", "fa"],
                          image_channels=[1, 1, 1],
                          new_image_name="X"),
        RenameProperty(old_name="hbt_roi" if predict_hbt else "whole_roi",
                       new_name="y"),
        CustomOneHot(include=["y"])
    ])

    transforms = {
        'default':
        tio.Compose([common_transforms_1, common_transforms_2]),
        'training':
        tio.Compose(
            [common_transforms_1, standard_augmentations,
             common_transforms_2]),
    }

    context.add_component("dataset",
                          SubjectFolder,
                          root='$DATASET_PATH',
                          subject_path="subjects",
                          subject_loader=subject_loader,
                          cohorts=cohorts,
                          transforms=transforms)
    context.add_component("model",
                          NestedResUNet,
                          input_channels=3,
                          output_channels=4 if predict_hbt else 2,
                          filters=40,
                          dropout_p=0.2)
    context.add_component("optimizer",
                          Adam,
                          params="self.model.parameters()",
                          lr=0.0002)
    context.add_component("criterion", HybridLogisticDiceLoss)

    training_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            'y_pred_eval', 'y_eval'),
                            log_name='training_segmentation_eval',
                            interval=10),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "Axial",
            'mean_dwi',
            'y_pred_eval',
            'y_eval',
            slice_id=12,
            legend=True,
            ncol=2,
            split_subjects=False),
                            log_name="contour_image_training",
                            interval=50),
    ]

    # Per-label coefficient arrays passed to LabelMapEvaluator together with
    # ``curve_attribute='age'`` (presumably an expected-volume-vs-age curve;
    # TODO confirm the polynomial convention).
    curve_params = {
        "left_whole":
        np.array([-1.96312119e-01, 9.46668029e+00, 2.33635173e+03]),
        "right_whole":
        np.array([-2.68467331e-01, 1.67925603e+01, 2.07224236e+03])
    }

    validation_evaluators = [
        ScheduledEvaluation(evaluator=LabelMapEvaluator(
            'y_pred_eval',
            curve_params=curve_params,
            curve_attribute='age',
            stats_to_output=('volume', 'error', 'absolute_error',
                             'squared_error', 'percent_diff')),
                            log_name="predicted_label_eval",
                            cohorts=['cbbrain_validation', 'ab300_validation'],
                            interval=50),
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            "y_pred_eval", "y_eval"),
                            log_name="segmentation_eval",
                            cohorts=['cbbrain_validation'],
                            interval=50),
        ScheduledEvaluation(
            evaluator=ContourImageEvaluator("Axial",
                                            "mean_dwi",
                                            "y_pred_eval",
                                            "y_eval",
                                            slice_id=10,
                                            legend=True,
                                            ncol=5,
                                            split_subjects=False),
            log_name="contour_image_axial",
            cohorts=['cbbrain_validation', 'ab300_validation_plot'],
            interval=250),
        ScheduledEvaluation(
            evaluator=ContourImageEvaluator("Coronal",
                                            "mean_dwi",
                                            "y_pred_eval",
                                            "y_eval",
                                            slice_id=44,
                                            legend=True,
                                            ncol=2,
                                            split_subjects=False),
            log_name="contour_image_coronal",
            cohorts=['cbbrain_validation', 'ab300_validation_plot'],
            interval=250),
    ]

    def scoring_function(evaluation_dict):
        """Return the label-averaged mean Dice on ``cbbrain_validation``."""
        # Grab the output of the SegmentationEvaluator
        seg_eval_cbbrain = evaluation_dict['segmentation_eval'][
            'cbbrain_validation']["summary_stats"]

        # Get the mean dice for each label (the mean is across subjects)
        cbbrain_dice = seg_eval_cbbrain['mean', :, 'dice']

        # Now take the mean across all labels
        cbbrain_dice = cbbrain_dice.mean()
        score = cbbrain_dice
        return score

    train_predictor = StandardPredict(sagittal_split=True,
                                      image_names=['X', 'y'])
    validation_predictor = StandardPredict(sagittal_split=True,
                                           image_names=['X'])

    train_dataloader_factory = StandardDataLoader(sampler=RandomSampler)
    validation_dataloader_factory = StandardDataLoader(
        sampler=SequentialSampler)

    context.add_component(
        "trainer",
        SegmentationTrainer,
        training_batch_size=training_batch_size,
        save_rate=100,
        scoring_interval=50,
        scoring_function=scoring_function,
        one_time_evaluators=[],
        training_evaluators=training_evaluators,
        validation_evaluators=validation_evaluators,
        max_iterations_with_no_improvement=2000,
        train_predictor=train_predictor,
        validation_predictor=validation_predictor,
        train_dataloader_factory=train_dataloader_factory,
        validation_dataloader_factory=validation_dataloader_factory)

    return context
Ejemplo n.º 15
0
 def test_empty_mask(self):
     """RescaleIntensity warns when the masking label map is all zeros."""
     empty_subject = copy.deepcopy(self.sample_subject)
     zeroed_label = empty_subject.label.data * 0
     empty_subject.label.set_data(zeroed_label)
     masked_rescale = tio.RescaleIntensity(masking_method='label')
     with self.assertWarns(RuntimeWarning):
         masked_rescale(empty_subject)
Ejemplo n.º 16
0
 def test_too_many_values_for_percentiles(self):
     """A percentiles sequence with three entries is rejected."""
     three_percentiles = (1, 2, 3)
     with self.assertRaises(ValueError):
         tio.RescaleIntensity(
             out_min_max=(0, 1),
             percentiles=three_percentiles,
         )
Ejemplo n.º 17
0
# Doubling the maximum displacement exaggerates the elastic warp.
doubled_displacement = 2 * np.array(max_displacement)
random_elastic = tio.RandomElasticDeformation(
    max_displacement=doubled_displacement,
    seed=0,
)
slice_large_displacement = random_elastic(slice_grid)
to_pil(slice_large_displacement)

# Intensity transforms
# Intensity transforms modify only scalar images; label maps are left
# unchanged.

# Preprocessing (normalization)
# Rescale intensity
# RescaleIntensity maps the intensity range of the images onto a target
# interval, e.g. [0, 1] or [-1, 1].

rescale = tio.RescaleIntensity((-1, 1))
rescaled = rescale(fpg)
fig, axes = plt.subplots(2, 1)
histogram_panels = (
    (fpg.mri.data, 'Original histogram'),
    (rescaled.mri.data, 'Intensity rescaling'),
)
for hist_ax, (hist_data, hist_title) in zip(axes, histogram_panels):
    sns.distplot(hist_data, ax=hist_ax, kde=False)
    hist_ax.set_title(hist_title)
    hist_ax.set_ylim((0, 1e6))
plt.tight_layout()

# There seem to be some outliers with very high intensity.
# Mapping the 1st and 99th percentiles (rather than the min and max) onto
# the output range reduces their influence.
# NOTE(review): this rescales `fpg_ras` while the histograms above use
# `fpg`; confirm which subject is intended.

rescale = tio.RescaleIntensity((-1, 1), percentiles=(1, 99))
rescaled = rescale(fpg_ras)
Ejemplo n.º 18
0
 def test_min_percentile_higher_than_max_percentile(self):
     """Percentiles given in decreasing order are rejected."""
     reversed_percentiles = (1, 0)
     with self.assertRaises(ValueError):
         tio.RescaleIntensity(
             out_min_max=(0, 1),
             percentiles=reversed_percentiles,
         )
Ejemplo n.º 19
0
 def test_wrong_out_min_max_type(self):
     """A string is not accepted as the output intensity range."""
     bad_range = 'wrong'
     with self.assertRaises(ValueError):
         tio.RescaleIntensity(out_min_max=bad_range)
Ejemplo n.º 20
0
 def test_out_min_higher_than_out_max(self):
     """An output range whose minimum exceeds its maximum is rejected."""
     inverted_range = (1, 0)
     with self.assertRaises(ValueError):
         tio.RescaleIntensity(out_min_max=inverted_range)
Ejemplo n.º 21
0
        '--exlcude_to_softmax',
        type=int,
        default=0,
        help='nb volume to exclude (from the end) from the softmax activation')
    parser.add_argument('-c',
                        '--CropOrPad',
                        type=str,
                        default='None',
                        help='tuple of target dim')

    args = parser.parse_args()
    nb_vol_exclude = int(args.exlcude_to_softmax)
    vol_crop_pad = eval(args.CropOrPad)

    volume = torchio.ScalarImage(path=args.volume)
    tscale = torchio.RescaleIntensity(percentiles=(0, 99))
    volume = tscale(volume)

    if vol_crop_pad:
        tpad = torchio.CropOrPad(target_shape=vol_crop_pad)
        volume = tpad(volume)

    model_struct = {
        'module': args.model_module,
        'name': args.model_name,
        'last_one': False,
        'path': args.model,
        'device': args.device
    }
    config = Config(None, None, save_files=False)
    model_struct = config.parse_model_file(model_struct)
    if args.device.startswith("cuda"):
        if torch.cuda.is_available():
            device = torch.device(args.device)
        else:
            device = torch.device("cpu")
            print("cuda not available, switched to cpu")
    else:
        device = torch.device(args.device)
    print("using device", device)

    context = Context(device, file_name=args.model_path, variables=dict(DATASET_FOLDER=args.dataset_path))

    # Fix torchio deprecating something...
    fixed_transform = tio.Compose([
        tio.Crop((62, 62, 70, 58, 0, 0)),
        tio.RescaleIntensity((-1, 1), (0.5, 99.5)),
        tio.Pad((0, 0, 0, 0, 2, 2)),
        tio.ZNormalization(),
    ])
    context.dataset.subjects_dataset.subject_dataset.set_transform(fixed_transform)

    if args.out_folder != "" and not os.path.exists(args.out_folder):
        print(args.out_folder, "does not exist. Creating it.")
        os.makedirs(args.out_folder)

    total = len(context.dataset) // 2
    pbar = tqdm(total=total)
    context.model.eval()
    for i in range(total):
        out_folder = args.out_folder
        if out_folder == "":
Ejemplo n.º 23
0
def get_context(device, variables, fold=0, **kwargs):
    """Build the training context for lesion segmentation on two FLAIR scans.

    Registers the ``dataset``, ``model``, ``optimizer``, ``criterion`` and
    ``trainer`` components on a ``TorchContext`` named ``"msseg2"``. Training
    is patch-based: patches of ``patch_size`` voxels are sampled with a
    probability map weighting lesion voxels 100x more than brain voxels.

    Args:
        device: Device handed to ``TorchContext``.
        variables: Variables handed to ``TorchContext``. The dataset root is
            given as ``'$DATASET_PATH'``, presumably resolved from these
            (TODO confirm).
        fold: Index of the ``RandomFoldFilter`` fold (out of 5, fixed seed)
            used as the validation cohort; the remaining subjects train.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The populated ``TorchContext``.
    """
    context = TorchContext(device, name="msseg2", variables=variables)
    context.file_paths.append(os.path.abspath(__file__))
    context.config = config = {'fold': fold, 'patch_size': 96}

    input_images = ["flair_time01", "flair_time02"]

    subject_loader = ComposeLoaders([
        ImageLoader(glob_pattern="flair_time01*",
                    image_name='flair_time01',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="flair_time02*",
                    image_name='flair_time02',
                    image_constructor=tio.ScalarImage),
        ImageLoader(glob_pattern="brain_mask.*",
                    image_name='brain_mask',
                    image_constructor=tio.LabelMap,
                    label_values={"brain": 1}),
        ImageLoader(glob_pattern="ground_truth.*",
                    image_name="ground_truth",
                    image_constructor=tio.LabelMap,
                    label_values={"lesion": 1}),
    ])

    # Validation is one of 5 seeded random folds; training is its complement.
    cohorts = {}
    cohorts['all'] = RequireAttributes(input_images)
    cohorts['validation'] = RandomFoldFilter(num_folds=5,
                                             selection=fold,
                                             seed=0xDEADBEEF)
    cohorts['training'] = NegateFilter(cohorts['validation'])

    # Resample, crop to the brain mask and pad up to at least one patch.
    common_transforms_1 = tio.Compose([
        SetDataType(torch.float),
        EnforceConsistentAffine(source_image_name='flair_time01'),
        TargetResample(target_spacing=1, tolerance=0.11),
        CropToMask('brain_mask'),
        MinSizePad(config['patch_size'])
    ])

    augmentations = tio.Compose([
        RandomPermuteDimensions(),
        tio.RandomFlip(axes=(0, 1, 2)),
        # With probability 0.75, one spatial warp: elastic (weight 0.2) or
        # affine (weight 0.8).
        tio.OneOf(
            {
                tio.RandomElasticDeformation():
                0.2,
                tio.RandomAffine(scales=0.2,
                                 degrees=45,
                                 default_pad_value='otsu'):
                0.8,
            },
            p=0.75),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        tio.RandomBlur((0, 1), p=0.2),
        tio.RandomNoise(std=0.1, p=0.35)
    ])

    # Normalize, stack both time points into ``X`` and one-hot encode ``y``.
    # ``include`` is a sequence of image names, not a bare string.
    common_transforms_2 = tio.Compose([
        tio.RescaleIntensity((-1, 1.), (0.05, 99.5)),
        ConcatenateImages(image_names=["flair_time01", "flair_time02"],
                          image_channels=[1, 1],
                          new_image_name="X"),
        RenameProperty(old_name='ground_truth', new_name='y'),
        CustomOneHot(include=["y"]),
    ])

    transforms = {
        'default':
        tio.Compose([common_transforms_1, common_transforms_2]),
        'training':
        tio.Compose([
            common_transforms_1, augmentations, common_transforms_2,
            # Sampling weights for the WeightedSampler below: brain voxels
            # weight 1, lesion voxels weight 100.
            ImageFromLabels(new_image_name="patch_probability",
                            label_weights=[('brain_mask', 'brain', 1),
                                           ('y', 'lesion', 100)])
        ]),
    }

    context.add_component("dataset",
                          SubjectFolder,
                          root='$DATASET_PATH',
                          subject_path="",
                          subject_loader=subject_loader,
                          cohorts=cohorts,
                          transforms=transforms)
    context.add_component("model",
                          ModularUNet,
                          in_channels=2,
                          out_channels=2,
                          filters=[40, 40, 80, 80, 120, 120],
                          depth=6,
                          block_params={'residual': True},
                          downsample_class=BlurConv3d,
                          downsample_params={
                              'kernel_size': 3,
                              'stride': 2,
                              'padding': 1
                          },
                          upsample_class=BlurConvTranspose3d,
                          upsample_params={
                              'kernel_size': 3,
                              'stride': 2,
                              'padding': 1,
                              'output_padding': 0
                          })
    context.add_component("optimizer",
                          SGD,
                          params="self.model.parameters()",
                          lr=0.001,
                          momentum=0.95)
    context.add_component("criterion",
                          HybridLogisticDiceLoss,
                          logistic_class_weights=[1, 100])

    training_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            'y_pred_eval', 'y_eval'),
                            log_name='training_segmentation_eval',
                            interval=15),
        ScheduledEvaluation(evaluator=LabelMapEvaluator('y_pred_eval'),
                            log_name='training_label_eval',
                            interval=15),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "random",
            'flair_time02',
            'y_pred_eval',
            'y_eval',
            slice_id=0,
            legend=True,
            ncol=2,
            interesting_slice=True,
            split_subjects=False),
                            log_name="contour_image",
                            interval=15),
    ]

    validation_evaluators = [
        ScheduledEvaluation(evaluator=SegmentationEvaluator(
            "y_pred_eval", "y_eval"),
                            log_name="segmentation_eval",
                            cohorts=["validation"],
                            interval=50),
        ScheduledEvaluation(evaluator=ContourImageEvaluator(
            "interesting",
            'flair_time02',
            'y_pred_eval',
            'y_eval',
            slice_id=0,
            legend=True,
            ncol=1,
            interesting_slice=True,
            split_subjects=True),
                            log_name="contour_image",
                            cohorts=["validation"],
                            interval=50),
    ]

    def scoring_function(evaluation_dict):
        """Return the mean per-subject lesion Dice on the validation cohort."""
        # Grab the output of the SegmentationEvaluator
        seg_eval = evaluation_dict['segmentation_eval']['validation']

        # Take mean dice, while accounting for subjects which have no lesions.
        # Dice is 0/0 = nan when the model correctly outputs no lesions. This is counted as a score of 1.0.
        # Dice is (>0)/0 = posinf when the model incorrectly predicts lesions when there are none.
        # This is counted as a score of 0.0.
        dice = torch.tensor(seg_eval["subject_stats"]['dice.lesion'])
        dice = dice.nan_to_num(nan=1.0, posinf=0.0)
        score = dice.mean()

        return score

    train_predictor = StandardPredict(image_names=['X', 'y'])
    validation_predictor = PatchPredict(patch_batch_size=32,
                                        patch_size=config['patch_size'],
                                        patch_overlap=(config['patch_size'] //
                                                       8),
                                        padding_mode=None,
                                        overlap_mode='average',
                                        image_names=['X'])

    patch_sampler = tio.WeightedSampler(patch_size=config['patch_size'],
                                        probability_map='patch_probability')
    train_dataloader_factory = PatchDataLoader(max_length=100,
                                               samples_per_volume=1,
                                               sampler=patch_sampler)
    validation_dataloader_factory = StandardDataLoader(
        sampler=SequentialSampler)

    context.add_component(
        "trainer",
        SegmentationTrainer,
        training_batch_size=4,
        save_rate=100,
        scoring_interval=50,
        scoring_function=scoring_function,
        one_time_evaluators=[],
        training_evaluators=training_evaluators,
        validation_evaluators=validation_evaluators,
        max_iterations_with_no_improvement=2000,
        train_predictor=train_predictor,
        validation_predictor=validation_predictor,
        train_dataloader_factory=train_dataloader_factory,
        validation_dataloader_factory=validation_dataloader_factory)

    return context
Ejemplo n.º 24
0
 def test_min_max(self):
     """Rescaling to (0, 1) maps the t1 extrema exactly onto 0 and 1."""
     rescale_unit = tio.RescaleIntensity(out_min_max=(0, 1))
     t1_values = rescale_unit(self.sample_subject).t1.data
     self.assertEqual(t1_values.min(), 0)
     self.assertEqual(t1_values.max(), 1)
Ejemplo n.º 25
0
# Demo: apply random transforms to a subject, then print the recorded
# transform history, a transform reproducing it, and its inverse.
import pprint

import torch
import torchio as tio
import matplotlib.pyplot as plt

# Seed the global torch RNG so the random transforms below are reproducible.
torch.manual_seed(0)

batch_size = 4
subject = tio.datasets.FPG()
subject.remove_image('seg')
# Four references to the same subject object.
subjects = 4 * [subject]

transform = tio.Compose((
    tio.ToCanonical(),
    tio.RandomGamma(p=0.75),
    tio.RandomBlur(p=0.5),
    tio.RandomFlip(),
    tio.RescaleIntensity((-1, 1)),
))

dataset = tio.SubjectsDataset(subjects, transform=transform)

transformed = dataset[0]
print('Applied transforms:')  # noqa: T001
pprint.pprint(transformed.history)  # noqa: T003
print('\nComposed transform to reproduce history:')  # noqa: T001
print(transformed.get_composed_history())  # noqa: T001
# Kept on one physical line so the noqa applies to the line flake8 reports.
print('\nComposed transform to invert applied transforms when possible:')  # noqa: T001, E501
print(transformed.get_inverse_transform(ignore_intensity=False))  # noqa: T001

loader = torch.utils.data.DataLoader(
    dataset,
Ejemplo n.º 26
0
 def test_wrong_percentiles_type(self):
     """A string is not accepted for the percentiles argument."""
     bad_percentiles = 'wrong'
     with self.assertRaises(ValueError):
         tio.RescaleIntensity(
             out_min_max=(0, 1),
             percentiles=bad_percentiles,
         )