Example #1
def create_patchQs(train_subs, val_subs, patch_size, patch_qlen, patch_per_vol,
                   inference_strides):
    train_queue = None
    val_queue = None
    grid_samplers = []

    if train_subs is not None:
        sampler = tio.data.UniformSampler(patch_size)
        train_queue = tio.Queue(subjects_dataset=train_subs,
                                max_length=patch_qlen,
                                samples_per_volume=patch_per_vol,
                                sampler=sampler,
                                num_workers=0,
                                start_background=True)

    if val_subs is not None:
        stride_length, stride_width, stride_depth = inference_strides.split(
            ',')
        overlap = np.subtract(
            patch_size,
            (int(stride_length), int(stride_width), int(stride_depth)))
        for i in range(len(val_subs)):
            grid_sampler = tio.inference.GridSampler(val_subs[i], patch_size,
                                                     overlap)
            grid_samplers.append(grid_sampler)
        val_queue = torch.utils.data.ConcatDataset(grid_samplers)

    return train_queue, val_queue, grid_samplers
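The grid samplers returned here are meant for patch-wise inference. A minimal consumption sketch, assuming a trained `model` and a hypothetical 'img' image key on the subjects, which stitches per-patch predictions back into full volumes with torchio's GridAggregator:

for grid_sampler in grid_samplers:
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
    aggregator = tio.inference.GridAggregator(grid_sampler)
    with torch.no_grad():
        for patches_batch in patch_loader:
            inputs = patches_batch['img'][tio.DATA]  # 'img' is a hypothetical key
            locations = patches_batch[tio.LOCATION]
            aggregator.add_batch(model(inputs), locations)
    prediction = aggregator.get_output_tensor()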
Example #2
    def train_dataloader(self) -> DataLoader:
        training_transform = get_train_transforms()
        train_imageDataset = torchio.ImagesDataset(
            self.training_subjects, transform=training_transform)

        patches_training_set = torchio.Queue(
            subjects_dataset=train_imageDataset,
            # Maximum number of patches that can be stored in the queue.
            # Using a large number means that the queue needs to be filled less often,
            # but more CPU memory is needed to store the patches.
            max_length=self.max_queue_length,
            # Number of patches to extract from each volume.
            # A small number of patches ensures a large variability in the queue,
            # but training will be slower.
            samples_per_volume=self.samples_per_volume,
            # A sampler used to extract patches from the volumes.
            sampler=torchio.sampler.UniformSampler(self.patch_size),
            num_workers=self.num_workers,
            # If True, the subjects dataset is shuffled at the beginning of each epoch,
            # i.e. when all patches from all subjects have been processed
            shuffle_subjects=False,
            # If True, patches are shuffled after filling the queue.
            shuffle_patches=True,
            verbose=True,
        )

        training_loader = DataLoader(patches_training_set,
                                     batch_size=self.hparams.batch_size)

        print(
            f"{ctime()}: training loader contains {len(training_loader)} batches"
        )
        return training_loader
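A matching val_dataloader is not shown; a minimal sketch, assuming the module also holds self.validation_subjects, would disable shuffling so validation stays deterministic:

    def val_dataloader(self) -> DataLoader:
        val_imageDataset = torchio.ImagesDataset(self.validation_subjects)
        patches_validation_set = torchio.Queue(
            subjects_dataset=val_imageDataset,
            max_length=self.max_queue_length,
            samples_per_volume=self.samples_per_volume,
            sampler=torchio.sampler.UniformSampler(self.patch_size),
            num_workers=self.num_workers,
            shuffle_subjects=False,  # keep validation order fixed
            shuffle_patches=False,
        )
        return DataLoader(patches_validation_set, batch_size=self.hparams.batch_size)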
Example #3
def plot_batch(sampler):
    queue = tio.Queue(dataset, max_queue_length, patches_per_volume, sampler)
    loader = torch.utils.data.DataLoader(queue, batch_size=16)
    batch = tio.utils.get_first_item(loader)

    fig, axes = plt.subplots(4, 4, figsize=(12, 10))
    for ax, im in zip(axes.flatten(), batch['t1']['data']):
        ax.imshow(im.squeeze(), cmap='gray')
    plt.suptitle(sampler.__class__.__name__)
    plt.tight_layout()
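A plausible call site, sketched under two assumptions: `dataset`, `max_queue_length`, and `patches_per_volume` are module-level names in the original, and the patches are pseudo-2D (depth 1) so imshow can render them:

patch_size = 64, 64, 1  # one-voxel depth so each patch plots as a 2D slice
max_queue_length, patches_per_volume = 300, 16
dataset = tio.SubjectsDataset(subjects)  # `subjects` built elsewhere

for sampler in (
        tio.data.UniformSampler(patch_size),
        tio.data.LabelSampler(patch_size, label_name='label'),
):
    plot_batch(sampler)
plt.show()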
Example #4
    def __init__(self, root_dir, img_range=(0,0)):
        self.root_dir = root_dir
        self.img_range = img_range


        subject_lists = []

        # check whether a label directory exists
        if self.root_dir[-1] != '/':
            self.root_dir += '/'

        self.is_labeled = os.path.isdir(self.root_dir + LABEL_DIR)

        self.files = [re.findall('[0-9]{4}', filename)[0] for filename in os.listdir(self.root_dir + TRAIN_DIR)]
        self.files = sorted(self.files, key=lambda f: int(f))

        # store all subjects in the list
        for img_num in range(img_range[0], img_range[1]+1):
            img_file = os.path.join(self.root_dir, TRAIN_DIR, IMG_PREFIX + self.files[img_num] + EXT)
            label_file = os.path.join(self.root_dir, LABEL_DIR, LABEL_PREFIX + self.files[img_num] + EXT)

            subject = torchio.Subject(
                torchio.Image('t1', img_file, torchio.INTENSITY),
                torchio.Image('label', label_file, torchio.LABEL)
            )

            subject_lists.append(subject)

            print(img_file)
            print(label_file)

        # Define transforms for data normalization and augmentation
        mtransforms = (
            ZNormalization(),
            #transforms.RandomNoise(std_range=(0, 0.25)),
            #transforms.RandomFlip(axes=(0,)),
        )

        self.subjects = torchio.ImagesDataset(subject_lists, transform=transforms.Compose(mtransforms))

        self.dataset = torchio.Queue(
            subjects_dataset=self.subjects,
            max_length=2,
            samples_per_volume=675,
            sampler_class=torchio.sampler.ImageSampler,
            patch_size=(240, 240, 3),
            num_workers=4,
            shuffle_subjects=False,
            shuffle_patches=True
        )

        print("Dataset details\n  Images: {}".format(self.img_range[1] - self.img_range[0] + 1))
Example #5
    def get_data_loader(self, dataset: tio.SubjectsDataset, batch_size: int,
                        num_workers: int):
        queue = tio.Queue(
            dataset,
            max_length=self.max_length,
            samples_per_volume=self.samples_per_volume,
            sampler=self.sampler,
            num_workers=num_workers,
        )
        dataloader = DataLoader(dataset=queue,
                                batch_size=batch_size,
                                collate_fn=no_op)

        return dataloader
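no_op is defined elsewhere in that repository; a minimal stand-in consistent with its use as a collate function would simply pass the sampled patches through unchanged:

def no_op(batch):
    # identity collate: return the list of samples as-is
    return batch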
Example #6
#%%
num_workers = 16
print('num_workers : ' + str(num_workers))
patch_size = 96
max_queue_length = 1024
samples_per_volume = 8
batch_size = 1

sampler = tio.data.UniformSampler(patch_size)

patches_training_set = tio.Queue(
    subjects_dataset=training_set,
    max_length=max_queue_length,
    samples_per_volume=samples_per_volume,
    sampler=sampler,
    num_workers=num_workers,
    shuffle_subjects=True,
    shuffle_patches=True,
)

patches_validation_set = tio.Queue(
    subjects_dataset=validation_set,
    max_length=max_queue_length,
    samples_per_volume=samples_per_volume,
    sampler=sampler,
    num_workers=num_workers,
    shuffle_subjects=False,
    shuffle_patches=False,
)
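A sketch of wiring these queues into loaders; the DataLoader keeps its default num_workers=0, because the Queue already spawns its own worker processes:

training_loader = torch.utils.data.DataLoader(patches_training_set, batch_size=batch_size)
validation_loader = torch.utils.data.DataLoader(patches_validation_set, batch_size=batch_size)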
Example #7
def prepare_dataload(patches=True):

    training_batch_size = 32
    validation_batch_size = 2 * training_batch_size

    patch_size = 25
    samples_per_volume = 10
    max_queue_length = 300
    sampler = tio.data.UniformSampler(patch_size)

    num_subjects = len(dataset)
    num_training_subjects = 245
    num_validation_subjects = 70
    num_test_subjects = 35

    num_split_subjects = num_training_subjects, num_validation_subjects, num_test_subjects
    training_subjects, validation_subjects, test_subjects = torch.utils.data.random_split(
        dataset, num_split_subjects)

    training_set = tio.SubjectsDataset(training_subjects,
                                       transform=training_transform)

    validation_set = tio.SubjectsDataset(validation_subjects,
                                         transform=validation_transform)

    test_set = tio.SubjectsDataset(test_subjects,
                                   transform=validation_transform)

    patches_training_set = tio.Queue(
        subjects_dataset=training_set,
        max_length=max_queue_length,
        samples_per_volume=samples_per_volume,
        sampler=sampler,
        num_workers=2,
        shuffle_subjects=True,
        shuffle_patches=True,
    )

    patches_validation_set = tio.Queue(
        subjects_dataset=validation_set,
        max_length=max_queue_length,
        samples_per_volume=samples_per_volume * 2,
        sampler=sampler,
        num_workers=2,
        shuffle_subjects=False,
        shuffle_patches=False,
    )

    patches_test_set = tio.Queue(
        subjects_dataset=test_set,
        max_length=max_queue_length,
        samples_per_volume=samples_per_volume * 2,
        sampler=sampler,
        num_workers=2,
        shuffle_subjects=False,
        shuffle_patches=False,
    )

    training_loader_patches = torch.utils.data.DataLoader(
        patches_training_set, batch_size=training_batch_size)

    validation_loader_patches = torch.utils.data.DataLoader(
        patches_validation_set, batch_size=validation_batch_size)

    test_loader_patches = torch.utils.data.DataLoader(
        patches_test_set, batch_size=validation_batch_size)

    training_loader = torch.utils.data.DataLoader(training_set, batch_size=2)

    validation_loader = torch.utils.data.DataLoader(validation_set,
                                                    batch_size=1)

    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1)

    if patches:
        return training_loader_patches, validation_loader_patches, test_loader_patches
    else:
        return training_loader, validation_loader, test_loader
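A hedged usage sketch; the 't1' image key is an assumption about how the subjects in `dataset` were built:

training_loader, validation_loader, test_loader = prepare_dataload(patches=True)
batch = next(iter(training_loader))
print(batch['t1'][tio.DATA].shape)  # e.g. [32, 1, 25, 25, 25] for single-channel images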
Example #8
def ImagesFromDataFrame(dataframe,
                        psize,
                        headers,
                        q_max_length=10,
                        q_samples_per_volume=1,
                        q_num_workers=2,
                        q_verbose=False,
                        sampler='label',
                        train=True,
                        augmentations=None,
                        preprocessing=None,
                        in_memory=False):
    # Finding the dimension of the dataframe for computational purposes later
    num_row, num_col = dataframe.shape
    # num_channels = num_col - 1 # for non-segmentation tasks, this might be different
    # changing the column indices to make it easier
    dataframe.columns = range(0, num_col)
    dataframe.index = range(0, num_row)
    # This list will later contain the list of subjects
    subjects_list = []

    channelHeaders = headers['channelHeaders']
    labelHeader = headers['labelHeader']
    predictionHeaders = headers['predictionHeaders']
    subjectIDHeader = headers['subjectIDHeader']

    sampler = sampler.lower()  # for easier parsing

    # define the control points and swap axes for augmentation
    augmentation_patchAxesPoints = copy.deepcopy(psize)
    for i in range(len(augmentation_patchAxesPoints)):
        augmentation_patchAxesPoints[i] = max(
            round(augmentation_patchAxesPoints[i] / 10),
            1)  # always at least have 1

    # iterating through the dataframe
    resizeCheck = False
    for patient in range(num_row):
        # We need this dict for storing the meta data for each subject
        # such as different image modalities, labels, any other data
        subject_dict = {}
        subject_dict['subject_id'] = dataframe[subjectIDHeader][patient]
        # iterating through the channels/modalities/timepoints of the subject
        for channel in channelHeaders:
            # assigning the dict key to the channel
            if not in_memory:
                subject_dict[str(channel)] = Image(str(
                    dataframe[channel][patient]),
                                                   type=torchio.INTENSITY)
            else:
                img = sitk.ReadImage(str(dataframe[channel][patient]))
                array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
                subject_dict[str(channel)] = Image(
                    tensor=array,
                    type=torchio.INTENSITY,
                    path=dataframe[channel][patient])

            # if resize has been defined but resample is not (or is none)
            if not resizeCheck:
                if not (preprocessing is None) and ('resize' in preprocessing):
                    if (preprocessing['resize'] is not None):
                        resizeCheck = True
                        if not ('resample' in preprocessing):
                            preprocessing['resample'] = {}
                            if not ('resolution' in preprocessing['resample']):
                                preprocessing['resample'][
                                    'resolution'] = resize_image_resolution(
                                        subject_dict[str(channel)].as_sitk(),
                                        preprocessing['resize'])
                        else:
                            print(
                                'WARNING: \'resize\' is ignored because \'resample\' is already defined under \'data_processing\'',
                                file=sys.stderr)
                else:
                    resizeCheck = True

        # # for regression
        # if predictionHeaders:
        #     # get the mask
        #     if (subject_dict['label'] is None) and (class_list is not None):
        #         sys.exit('The \'class_list\' parameter has been defined but a label file is not present for patient: ', patient)

        if labelHeader is not None:
            if not in_memory:
                subject_dict['label'] = Image(str(
                    dataframe[labelHeader][patient]),
                                              type=torchio.LABEL)
            else:
                img = sitk.ReadImage(str(dataframe[labelHeader][patient]))
                array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
                subject_dict['label'] = Image(
                    tensor=array,
                    type=torchio.LABEL,
                    path=dataframe[labelHeader][patient])

            subject_dict['path_to_metadata'] = str(
                dataframe[labelHeader][patient])
        else:
            subject_dict['label'] = "NA"
            subject_dict['path_to_metadata'] = str(dataframe[channel][patient])

        # iterating through the values to predict of the subject
        valueCounter = 0
        for values in predictionHeaders:
            # assigning the dict key to the channel
            subject_dict['value_' + str(valueCounter)] = np.array(
                dataframe[values][patient])
            valueCounter = valueCounter + 1

        # Initializing the subject object using the dict
        subject = Subject(subject_dict)

        # padding image, but only for label sampler, because we don't want to pad for uniform
        if 'label' in sampler or 'weight' in sampler:
            psize_pad = list(
                np.asarray(np.round(np.divide(psize, 2)), dtype=int))
            padder = Pad(
                psize_pad, padding_mode='symmetric'
            )  # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            subject = padder(subject)

        # Appending this subject to the list of subjects
        subjects_list.append(subject)

    augmentation_list = []

    # first, we want to do thresholding, followed by clipping, if it is present - required for inference as well
    if preprocessing is not None:
        if train:  # we want the crop to only happen during training
            if 'crop_external_zero_planes' in preprocessing:
                augmentation_list.append(
                    global_preprocessing_dict['crop_external_zero_planes'](
                        psize))
        for key in ['threshold', 'clip']:
            if key in preprocessing:
                augmentation_list.append(global_preprocessing_dict[key](
                    min=preprocessing[key]['min'],
                    max=preprocessing[key]['max']))

        # first, we want to do the resampling, if it is present - required for inference as well
        if 'resample' in preprocessing:
            if 'resolution' in preprocessing['resample']:
                # resample_split = str(aug).split(':')
                resample_values = tuple(
                    np.array(preprocessing['resample']['resolution']).astype(
                        float))  # np.float was removed from newer numpy
                if len(resample_values) == 2:
                    resample_values = tuple(np.append(resample_values, 1))
                augmentation_list.append(Resample(resample_values))

        # next, we want to do the intensity normalize - required for inference as well
        if 'normalize' in preprocessing:
            augmentation_list.append(global_preprocessing_dict['normalize'])
        elif 'normalize_nonZero' in preprocessing:
            augmentation_list.append(
                global_preprocessing_dict['normalize_nonZero'])
        elif 'normalize_nonZero_masked' in preprocessing:
            augmentation_list.append(
                global_preprocessing_dict['normalize_nonZero_masked'])

    # other augmentations should only happen for training - and also setting the probabilities
    # for the augmentations
    if train and augmentations is not None:
        for aug in augmentations:
            if aug != 'default_probability':
                actual_function = None

                if aug == 'flip':
                    if ('axes_to_flip' in augmentations[aug]):
                        print(
                            'WARNING: \'flip\' augmentation needs the key \'axis\' instead of \'axes_to_flip\'',
                            file=sys.stderr)
                        augmentations[aug]['axis'] = augmentations[aug][
                            'axes_to_flip']
                    actual_function = global_augs_dict[aug](
                        axes=augmentations[aug]['axis'],
                        p=augmentations[aug]['probability'])
                elif aug in ['rotate_90', 'rotate_180']:
                    for axis in augmentations[aug]['axis']:
                        augmentation_list.append(global_augs_dict[aug](
                            axis=axis, p=augmentations[aug]['probability']))
                elif aug in ['swap', 'elastic']:
                    actual_function = global_augs_dict[aug](
                        patch_size=augmentation_patchAxesPoints,
                        p=augmentations[aug]['probability'])
                elif aug == 'blur':
                    actual_function = global_augs_dict[aug](
                        std=augmentations[aug]['std'],
                        p=augmentations[aug]['probability'])
                elif aug == 'noise':
                    actual_function = global_augs_dict[aug](
                        mean=augmentations[aug]['mean'],
                        std=augmentations[aug]['std'],
                        p=augmentations[aug]['probability'])
                elif aug == 'anisotropic':
                    actual_function = global_augs_dict[aug](
                        axes=augmentations[aug]['axis'],
                        downsampling=augmentations[aug]['downsampling'],
                        p=augmentations[aug]['probability'])
                else:
                    actual_function = global_augs_dict[aug](
                        p=augmentations[aug]['probability'])
                if actual_function is not None:
                    augmentation_list.append(actual_function)

    if augmentation_list:
        transform = Compose(augmentation_list)
    else:
        transform = None
    subjects_dataset = torchio.SubjectsDataset(subjects_list,
                                               transform=transform)
    if not train:
        return subjects_dataset
    if sampler in ('weighted', 'weightedsampler', 'weightedsample'):
        sampler = global_sampler_dict[sampler](psize, probability_map='label')
    else:
        sampler = global_sampler_dict[sampler](psize)
    # all of these need to be read from model.yaml
    patches_queue = torchio.Queue(subjects_dataset,
                                  max_length=q_max_length,
                                  samples_per_volume=q_samples_per_volume,
                                  sampler=sampler,
                                  num_workers=q_num_workers,
                                  shuffle_subjects=True,
                                  shuffle_patches=True,
                                  verbose=q_verbose)
    return patches_queue
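For reference, the 'label' entry of global_sampler_dict presumably resolves to torchio's LabelSampler, which draws patch centres from the label map; a standalone sketch with a hypothetical 80/20 foreground bias:

sampler = torchio.data.LabelSampler(
    psize,
    label_name='label',
    label_probabilities={0: 0.2, 1: 0.8},  # hypothetical class weighting
)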
Example #9
training_transform = Compose([RescaleIntensity((0, 1)), RandomNoise(p=0.05)])
validation_transform = Compose([RescaleIntensity((0, 1))])
test_transform = Compose([RescaleIntensity((0, 1))])

training_dataset = tio.SubjectsDataset(training_subjects,
                                       transform=training_transform)
validation_dataset = tio.SubjectsDataset(validation_subjects,
                                         transform=validation_transform)
test_dataset = tio.SubjectsDataset(test_subjects, transform=test_transform)
'''Patching'''

patches_training_set = tio.Queue(
    subjects_dataset=training_dataset,
    max_length=max_queue_length,
    samples_per_volume=samples_per_volume,
    sampler=tio.sampler.UniformSampler(patch_size),
    # shuffle_subjects=True,
    # shuffle_patches=True,
)

patches_validation_set = tio.Queue(
    subjects_dataset=validation_dataset,
    max_length=max_queue_length,
    samples_per_volume=samples_per_volume * 2,
    sampler=tio.sampler.UniformSampler(patch_size),
    # shuffle_subjects=False,
    # shuffle_patches=False,
)

training_loader = torch.utils.data.DataLoader(patches_training_set,
                                              batch_size=training_batch_size)
Example #10
def get_loaders(data, cv_split, training_transform = False,
        validation_transform = False, patch_size = 64,
        patches = False, samples_per_volume = 6,
        max_queue_length = 180, training_batch_size = 1,
        validation_batch_size = 1, mask = False, input_type = 'T1'):
    
    """
    Function creates dataloaders 
    
    Arguments:
        * data (data_processor.DataMriSegmentation): torchio dataset
        * cv_split (list): list of two arrays, one with train indexes, other with test indexes
        * training_transform (bool/torchio.transforms): whether or not to use transform for training images
        * validation_transform (bool/torchio.transforms): whether or not to use  transform for validation images
        * patch_size (int): size of patches
        * patches (bool): if True, than patch-based training will be applied
        https://niftynet.readthedocs.io/en/dev/window_sizes.html - about patch based training
        * samples_per_volume (int): number of patches to extract from each volume
        * max_queue_length (int): maximum number of patches that can be stored in the queue
        * training_batch_size (int): size of batches for training
        * validation_batch_size (int): size of batches for validation
        * mask (bool): if True, than masked images will be used 
    
    Output:
        * training_loader (torch.utils.data.DataLoader): loader for train
        * validation_loader (torch.utils.data.DataLoader): loader for test
    """
    
    training_idx, validation_idx = cv_split
    mask = data.mask  # note: the mask argument is overridden by the dataset's own attribute
    
    print('Training set:', len(training_idx), 'subjects')
    print('Validation set:', len(validation_idx), 'subjects')
    print(f'Input type is {input_type}')
    
    if input_type == 'T1':
        training_set = get_torchio_dataset(
            list(data.img_files[training_idx].values), 
            list(data.img_seg[training_idx].values),
            training_transform)

        validation_set = get_torchio_dataset(
            list(data.img_files[validation_idx].values), 
            list(data.img_seg[validation_idx].values),
            validation_transform)

        if mask in ['bb', 'combined']:
            print(f'Mask type is {mask}')
            # if using masked data for training
            training_set = get_torchio_dataset(
                list(data.img_files[training_idx].values), 
                list(data.img_mask[training_idx].values),
                training_transform)

            validation_set = get_torchio_dataset(
                list(data.img_files[validation_idx].values), 
                list(data.img_mask[validation_idx].values),
                validation_transform)

        training_loader = torch.utils.data.DataLoader(
            training_set, batch_size = training_batch_size)

        validation_loader = torch.utils.data.DataLoader(
            validation_set, batch_size = validation_batch_size)
         
    if input_type == 'seg':
        if mask in ['bb', 'combined']:
            print(f'Mask type is {mask}')
            # if using masked data for training
            training_set = get_torchio_dataset(
                list(data.img_seg[training_idx].values), 
                list(data.img_mask[training_idx].values),
                training_transform)

            validation_set = get_torchio_dataset(
                list(data.img_seg[validation_idx].values), 
                list(data.img_mask[validation_idx].values),
                validation_transform)

        training_loader = torch.utils.data.DataLoader(
            training_set, batch_size = training_batch_size)

        validation_loader = torch.utils.data.DataLoader(
            validation_set, batch_size = validation_batch_size)
    
    if patches:
        # https://niftynet.readthedocs.io/en/dev/window_sizes.html - about patch based training
        # https://torchio.readthedocs.io/data/patch_training.html - about Queue
        patches_training_set = torchio.Queue(
            subjects_dataset = training_set,
            max_length = max_queue_length,
            samples_per_volume = samples_per_volume,
            patch_size = patch_size,
            sampler_class = torchio.sampler.ImageSampler,
            num_workers = multiprocessing.cpu_count(),
            shuffle_subjects = True,
            shuffle_patches = True,
        )

        patches_validation_set = torchio.Queue(
            subjects_dataset = validation_set,
            max_length = max_queue_length,
            samples_per_volume = samples_per_volume,
            patch_size = patch_size,
            sampler_class = torchio.sampler.ImageSampler,
            num_workers = multiprocessing.cpu_count(),
            shuffle_subjects = False,
            shuffle_patches = False,
        )

        training_loader = torch.utils.data.DataLoader(
            patches_training_set, batch_size = training_batch_size)

        validation_loader = torch.utils.data.DataLoader(
            patches_validation_set, batch_size = validation_batch_size)
        
        print('Patches mode')
        print('Training loader length:', len(training_loader))
        print('Validation loader length:', len(validation_loader))
    
    return training_loader, validation_loader
Example #11
import torchio as tio
import pytorch_lightning as pl  # needed for the Lightning modules below
from torch.utils.data import DataLoader
import resource
import time

n_subjects = 16
max_length = 40
samples_per_volume = 5
num_workers = 8
patch_size = 128
batch_size = 2

sampler = tio.data.UniformSampler(patch_size)
subject = tio.datasets.Colin27()
dataset = tio.SubjectsDataset(n_subjects * [subject])
queue = tio.Queue(dataset, max_length, samples_per_volume, sampler,
                  num_workers)


class DummyDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(queue, batch_size=batch_size)


class DummyModule(pl.LightningModule):
    def configure_optimizers(self):
        pass

    def training_step(self, *args, **kwargs):
        #pdb.set_trace()  # Use inspect_mem() here.
        time.sleep(0.1)
        # ru_maxrss is reported in kilobytes on Linux, so this is roughly MB
        main_memory = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1000
Example #12
        RandomElasticDeformation(): 0.2,
    }, p=0.5),  # Changed from p=0.75 24/6/20
])
# Create the datasets
training_dataset = torchio.ImagesDataset(
    [train_subject], transform=training_transform)

validation_dataset = torchio.ImagesDataset(
    [valid_subject])
# Define the queue of sampled patches for training and validation
sampler = torchio.data.UniformSampler(PATCH_SIZE)
patches_training_set = torchio.Queue(
    subjects_dataset=training_dataset,
    max_length=MAX_QUEUE_LENGTH,
    samples_per_volume=TRAIN_PATCHES,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    shuffle_subjects=False,
    shuffle_patches=True,
)

patches_validation_set = torchio.Queue(
    subjects_dataset=validation_dataset,
    max_length=MAX_QUEUE_LENGTH,
    samples_per_volume=VALID_PATCHES,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    shuffle_subjects=False,
    shuffle_patches=False,
)
Example #13
def ImagesFromDataFrame(
    dataframe, parameters, train, apply_zero_crop=False, loader_type=""
):
    """
    Reads the pandas dataframe and gives the dataloader to use for training/validation/testing

    Parameters
    ----------
    dataframe : pandas.DataFrame
        The main input dataframe which is calculated after splitting the data CSV
    parameters : dict
        The parameters dictionary
    train : bool
        Whether the dataloader is for training. For training, the patching infrastructure and data augmentation are applied.
    apply_zero_crop : bool
        If enabled, the external zero-plane crop is applied.
    loader_type : str
        Type of loader, used for printing.

    Returns
    -------
    subjects_dataset: torchio.SubjectsDataset
        This is the output for validation/testing, where patching and data augmentation are disregarded
    patches_queue: torchio.Queue
        This is the output for training: the subjects_dataset queue after patching and data augmentation are taken into account
    """
    # store in previous variable names
    patch_size = parameters["patch_size"]
    headers = parameters["headers"]
    q_max_length = parameters["q_max_length"]
    q_samples_per_volume = parameters["q_samples_per_volume"]
    q_num_workers = parameters["q_num_workers"]
    q_verbose = parameters["q_verbose"]
    sampler = parameters["patch_sampler"]
    augmentations = parameters["data_augmentation"]
    preprocessing = parameters["data_preprocessing"]
    in_memory = parameters["in_memory"]
    enable_padding = parameters["enable_padding"]

    # Finding the dimension of the dataframe for computational purposes later
    num_row, num_col = dataframe.shape
    # changing the column indices to make it easier
    dataframe.columns = range(0, num_col)
    dataframe.index = range(0, num_row)
    # This list will later contain the list of subjects
    subjects_list = []
    subjects_with_error = []

    channelHeaders = headers["channelHeaders"]
    labelHeader = headers["labelHeader"]
    predictionHeaders = headers["predictionHeaders"]
    subjectIDHeader = headers["subjectIDHeader"]

    # this basically means that label sampler is selected with padding
    if isinstance(sampler, dict):
        sampler_padding = sampler["label"]["padding_type"]
        sampler = "label"
    else:
        sampler = sampler.lower()  # for easier parsing
        sampler_padding = "symmetric"

    resize_images_flag = False
    # if resize has been defined but resample is not (or is none)
    if preprocessing is not None:
        for key in preprocessing.keys():
            # check for different resizing keys
            if key in ["resize", "resize_image", "resize_images"]:
                if not (preprocessing[key] is None):
                    resize_images_flag = True
                    preprocessing["resize_image"] = preprocessing[key]
                    break

    # iterating through the dataframe
    for patient in tqdm(
        range(num_row), desc="Constructing queue for " + loader_type + " data"
    ):
        # We need this dict for storing the meta data for each subject
        # such as different image modalities, labels, any other data
        subject_dict = {}
        subject_dict["subject_id"] = str(dataframe[subjectIDHeader][patient])
        skip_subject = False
        # iterating through the channels/modalities/timepoints of the subject
        for channel in channelHeaders:
            # sanity check for malformed csv
            if not os.path.isfile(str(dataframe[channel][patient])):
                skip_subject = True

            subject_dict[str(channel)] = torchio.ScalarImage(
                dataframe[channel][patient]
            )

            # store image spacing information if not present
            if "spacing" not in subject_dict:
                file_reader = sitk.ImageFileReader()
                file_reader.SetFileName(dataframe[channel][patient])
                file_reader.ReadImageInformation()
                subject_dict["spacing"] = torch.Tensor(file_reader.GetSpacing())

            # if resize_image is requested, then perform a per-image resize with the appropriate interpolator
            if resize_images_flag:
                img_resized = resize_image(
                    subject_dict[str(channel)].as_sitk(), preprocessing["resize_image"]
                )
                # always ensure resized image spacing is used
                subject_dict["spacing"] = torch.Tensor(img_resized.GetSpacing())
                subject_dict[str(channel)] = torchio.ScalarImage.from_sitk(img_resized)

        # # for regression -- this logic needs to be thought through
        # if predictionHeaders:
        #     # get the mask
        #     if (subject_dict['label'] is None) and (class_list is not None):
        #         sys.exit('The \'class_list\' parameter has been defined but a label file is not present for patient: ', patient)

        if labelHeader is not None:
            if not os.path.isfile(str(dataframe[labelHeader][patient])):
                skip_subject = True

            subject_dict["label"] = torchio.LabelMap(dataframe[labelHeader][patient])
            subject_dict["path_to_metadata"] = str(dataframe[labelHeader][patient])

            # if resize is requested, then perform a per-image resize with the appropriate interpolator
            if resize_images_flag:
                img_resized = resize_image(
                    subject_dict["label"].as_sitk(),
                    preprocessing["resize_image"],
                    sitk.sitkNearestNeighbor,
                )
                subject_dict["label"] = torchio.LabelMap.from_sitk(img_resized)

        else:
            subject_dict["label"] = "NA"
            subject_dict["path_to_metadata"] = str(dataframe[channel][patient])

        # iterating through the values to predict of the subject
        valueCounter = 0
        for values in predictionHeaders:
            # assigning the dict key to the channel
            subject_dict["value_" + str(valueCounter)] = np.array(
                dataframe[values][patient]
            )
            valueCounter += 1

        # skip the subject if the condition was tripped
        if not skip_subject:
            # Initializing the subject object using the dict
            subject = torchio.Subject(subject_dict)
            # https://github.com/fepegar/torchio/discussions/587#discussioncomment-928834
            # this is causing memory usage to explode, see https://github.com/CBICA/GaNDLF/issues/128
            if parameters["verbose"]:
                print(
                    "Checking consistency of images in subject '"
                    + subject["subject_id"]
                    + "'"
                )
            try:
                perform_sanity_check_on_subject(subject, parameters)
            except Exception:
                subjects_with_error.append(subject["subject_id"])

            # padding image, but only for label sampler, because we don't want to pad for uniform
            if "label" in sampler or "weight" in sampler:
                if enable_padding:
                    psize_pad = list(
                        np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int)
                    )
                    # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
                    padder = Pad(psize_pad, padding_mode=sampler_padding)
                    subject = padder(subject)

            # load subject into memory: https://github.com/fepegar/torchio/discussions/568#discussioncomment-859027
            if in_memory:
                subject.load()

            # Appending this subject to the list of subjects
            subjects_list.append(subject)

    if subjects_with_error:
        raise ValueError(
            "The following subjects could not be loaded, please recheck or remove and retry:",
            subjects_with_error,
        )

    transformations_list = []

    # augmentations are applied to the training set only
    if train and augmentations is not None:
        for aug in augmentations:
            aug_lower = aug.lower()
            if aug_lower in global_augs_dict:
                transformations_list.append(
                    global_augs_dict[aug_lower](augmentations[aug])
                )

    transform = get_transforms_for_preprocessing(
        parameters, transformations_list, train, apply_zero_crop
    )

    subjects_dataset = torchio.SubjectsDataset(subjects_list, transform=transform)
    if not train:
        return subjects_dataset
    if sampler in ("weighted", "weightedsampler", "weightedsample"):
        sampler = global_sampler_dict[sampler](patch_size, probability_map="label")
    else:
        sampler = global_sampler_dict[sampler](patch_size)
    # all of these need to be read from model.yaml
    patches_queue = torchio.Queue(
        subjects_dataset,
        max_length=q_max_length,
        samples_per_volume=q_samples_per_volume,
        sampler=sampler,
        num_workers=q_num_workers,
        shuffle_subjects=True,
        shuffle_patches=True,
        verbose=q_verbose,
    )
    return patches_queue
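The dict handling near the top implies that parameters["patch_sampler"] accepts two shapes; both, for reference:

parameters["patch_sampler"] = "uniform"  # plain string form
parameters["patch_sampler"] = {"label": {"padding_type": "constant"}}  # label sampler plus padding mode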
Example #14
def get_loaders(data, cv_split,
        training_transform = False,
        validation_transform = False,
        patch_size = 64,
        patches = False,
        samples_per_volume = 6,
        max_queue_length = 180,
        training_batch_size = 1,
        validation_batch_size = 1,
        mask = False):
    
    """
    The function creates dataloaders 
    
        weights_stem (str): ['full_size', 'patches'] #sizes of training objects
        transform (bool): False # data augmentation
        batch_size (int): 1 # batch sizes for training
        
    """
    
    training_idx, validation_idx = cv_split
    
    print('Training set:', len(training_idx), 'subjects')
    print('Validation set:', len(validation_idx), 'subjects')
    
    training_set = get_torchio_dataset(
        list(data.img_files[training_idx].values), 
        list(data.img_seg[training_idx].values),
        training_transform)
    
    validation_set = get_torchio_dataset(
        list(data.img_files[validation_idx].values), 
        list(data.img_seg[validation_idx].values),
        validation_transform)
    
    if mask:
        # if using masked data for training
        training_set = get_torchio_dataset(
            list(data.img_files[training_idx].values), 
            list(data.img_mask[training_idx].values),
            training_transform)
        
        validation_set = get_torchio_dataset(
            list(data.img_files[validation_idx].values), 
            list(data.img_mask[validation_idx].values),
            validation_transform)
    
    training_loader = torch.utils.data.DataLoader(
        training_set, batch_size=training_batch_size)

    validation_loader = torch.utils.data.DataLoader(
        validation_set, batch_size=validation_batch_size)
    
    if patches:

        patches_training_set = torchio.Queue(
            subjects_dataset=training_set,
            max_length=max_queue_length,
            samples_per_volume=samples_per_volume,
            patch_size=patch_size,
            sampler_class=torchio.sampler.ImageSampler,
            num_workers=multiprocessing.cpu_count(),
            shuffle_subjects=True,
            shuffle_patches=True,
        )

        patches_validation_set = torchio.Queue(
            subjects_dataset=validation_set,
            max_length=max_queue_length,
            samples_per_volume=samples_per_volume,
            patch_size=patch_size,
            sampler_class=torchio.sampler.ImageSampler,
            num_workers=multiprocessing.cpu_count(),
            shuffle_subjects=False,
            shuffle_patches=False,
        )

        training_loader = torch.utils.data.DataLoader(
            patches_training_set, batch_size=training_batch_size)

        validation_loader = torch.utils.data.DataLoader(
            patches_validation_set, batch_size=validation_batch_size)
        
        print('Training loader length:', len(training_loader))
        print('Validation loader length:', len(validation_loader))
    
    return training_loader, validation_loader