def create_patchQs(train_subs, val_subs, patch_size, patch_qlen, patch_per_vol, inference_strides):
    """Create the training patch queue and the validation grid samplers.

    Args:
        train_subs: Subjects dataset used for training, or ``None`` to skip
            building the training queue.
        val_subs: Indexable collection of validation subjects, or ``None`` to
            skip building the validation samplers.
        patch_size: Patch size handed to the uniform and grid samplers.
        patch_qlen: ``max_length`` of the training queue.
        patch_per_vol: ``samples_per_volume`` of the training queue.
        inference_strides: Comma-separated string with exactly three integer
            strides; the grid overlap is ``patch_size - strides``.

    Returns:
        Tuple ``(train_queue, val_queue, grid_samplers)``. The first two are
        ``None`` when the corresponding input is ``None``; ``grid_samplers``
        is then an empty list.
    """
    train_queue = None
    if train_subs is not None:
        uniform_sampler = tio.data.UniformSampler(patch_size)
        train_queue = tio.Queue(
            subjects_dataset=train_subs,
            max_length=patch_qlen,
            samples_per_volume=patch_per_vol,
            sampler=uniform_sampler,
            num_workers=0,
            start_background=True,
        )

    val_queue = None
    grid_samplers = []
    if val_subs is not None:
        first, second, third = inference_strides.split(',')
        strides = (int(first), int(second), int(third))
        overlap = np.subtract(patch_size, strides)
        # Subjects are fetched by index, exactly one grid sampler per subject.
        grid_samplers = [
            tio.inference.GridSampler(val_subs[index], patch_size, overlap)
            for index in range(len(val_subs))
        ]
        val_queue = torch.utils.data.ConcatDataset(grid_samplers)

    return train_queue, val_queue, grid_samplers
def train_dataloader(self) -> DataLoader:
    """Return a DataLoader that yields shuffled training patches.

    The training subjects are wrapped in a ``torchio.ImagesDataset`` using the
    transform from ``get_train_transforms()``, passed through a
    ``torchio.Queue`` that samples patches of ``self.patch_size`` uniformly,
    and batched with ``self.hparams.batch_size``.
    """
    transform = get_train_transforms()
    subjects = torchio.ImagesDataset(self.training_subjects, transform=transform)

    queue_options = {
        'subjects_dataset': subjects,
        # Upper bound on the number of patches held in the queue. A larger
        # value means fewer refills but more CPU memory.
        'max_length': self.max_queue_length,
        # Patches drawn from each volume. A smaller value gives more variety
        # in the queue but slows training.
        'samples_per_volume': self.samples_per_volume,
        # Sampler that extracts the patches from the volumes.
        'sampler': torchio.sampler.UniformSampler(self.patch_size),
        'num_workers': self.num_workers,
        # Subjects keep their order; only the extracted patches are shuffled
        # once the queue has been filled.
        'shuffle_subjects': False,
        'shuffle_patches': True,
        'verbose': True,
    }
    patch_queue = torchio.Queue(**queue_options)

    loader = DataLoader(patch_queue, batch_size=self.hparams.batch_size)
    print(
        f"{ctime()}: getting number of training subjects {len(loader)}"
    )
    return loader
def plot_batch(sampler):
    """Plot the first batch of 16 patches produced by ``sampler`` on a 4x4 grid.

    Relies on the module-level ``dataset``, ``max_queue_length`` and
    ``patches_per_volume``. The figure title is the sampler's class name.
    """
    patch_queue = tio.Queue(dataset, max_queue_length, patches_per_volume, sampler)
    patch_loader = torch.utils.data.DataLoader(patch_queue, batch_size=16)
    first_batch = tio.utils.get_first_item(patch_loader)

    _, axes_grid = plt.subplots(4, 4, figsize=(12, 10))
    flat_axes = axes_grid.flatten()
    for axis, patch in zip(flat_axes, first_batch['t1']['data']):
        axis.imshow(patch.squeeze(), cmap='gray')

    plt.suptitle(sampler.__class__.__name__)
    plt.tight_layout()
def __init__(self, root_dir, img_range=(0, 0)):
    """Build the subject list and the patch queue for one image range.

    Args:
        root_dir: Dataset root containing ``TRAIN_DIR`` and ``LABEL_DIR``.
            A trailing ``'/'`` is appended to ``self.root_dir`` if missing.
        img_range: Inclusive ``(first, last)`` positions into the list of
            four-digit image ids, sorted numerically.

    Directory entries in ``TRAIN_DIR`` whose name contains no four-digit run
    (for example hidden or metadata files) are ignored instead of aborting
    construction with an ``IndexError``.
    """
    self.root_dir = root_dir
    self.img_range = img_range

    if self.root_dir[-1] != '/':
        self.root_dir += '/'

    # Records whether a label directory exists; label paths are still built
    # for every subject below, as before.
    self.is_labeled = os.path.isdir(self.root_dir + LABEL_DIR)

    # The first four-digit run of each file name is that image's id.
    id_pattern = re.compile('[0-9]{4}')
    id_matches = (
        id_pattern.search(filename)
        for filename in os.listdir(self.root_dir + TRAIN_DIR)
    )
    self.files = sorted(
        (match.group(0) for match in id_matches if match is not None),
        key=int,
    )

    # Store all subjects of the requested (inclusive) range in the list.
    subject_lists = []
    for img_num in range(img_range[0], img_range[1] + 1):
        file_id = self.files[img_num]
        img_file = os.path.join(
            self.root_dir, TRAIN_DIR, IMG_PREFIX + file_id + EXT)
        label_file = os.path.join(
            self.root_dir, LABEL_DIR, LABEL_PREFIX + file_id + EXT)
        subject = torchio.Subject(
            torchio.Image('t1', img_file, torchio.INTENSITY),
            torchio.Image('label', label_file, torchio.LABEL)
        )
        subject_lists.append(subject)
        print(img_file)
        print(label_file)

    # Transforms for data normalization and augmentation. Only the
    # normalization is active; RandomNoise(std_range=(0, 0.25)) and
    # RandomFlip(axes=(0,)) were disabled in the original code.
    mtransforms = (
        ZNormalization(),
    )

    self.subjects = torchio.ImagesDataset(
        subject_lists, transform=transforms.Compose(mtransforms))
    self.dataset = torchio.Queue(
        subjects_dataset=self.subjects,
        max_length=2,
        samples_per_volume=675,
        sampler_class=torchio.sampler.ImageSampler,
        patch_size=(240, 240, 3),
        num_workers=4,
        shuffle_subjects=False,
        shuffle_patches=True
    )

    print("Dataset details\n Images: {}".format(
        self.img_range[1] - self.img_range[0] + 1))
def get_data_loader(self, dataset: tio.SubjectsDataset, batch_size: int, num_workers: int):
    """Wrap ``dataset`` in a patch queue and return a DataLoader over it.

    The queue takes ``max_length``, ``samples_per_volume`` and ``sampler``
    from the instance. Batches are assembled with the ``no_op`` collate
    function.
    """
    queue_options = dict(
        max_length=self.max_length,
        samples_per_volume=self.samples_per_volume,
        sampler=self.sampler,
        num_workers=num_workers,
    )
    patch_queue = tio.Queue(dataset, **queue_options)
    return DataLoader(dataset=patch_queue, batch_size=batch_size, collate_fn=no_op)
#%%
# Patch-queue configuration for the training and validation sets.
num_workers = 16
print(f'num_workers : {num_workers}')

patch_size, max_queue_length, samples_per_volume, batch_size = 96, 1024, 8, 1
sampler = tio.data.UniformSampler(patch_size)

# Options common to both queues; only the dataset and shuffling differ.
_shared_queue_options = dict(
    max_length=max_queue_length,
    samples_per_volume=samples_per_volume,
    sampler=sampler,
    num_workers=num_workers,
)

# Training: subjects and patches are both shuffled.
patches_training_set = tio.Queue(
    subjects_dataset=training_set,
    shuffle_subjects=True,
    shuffle_patches=True,
    **_shared_queue_options,
)

# Validation: subject and patch order is left untouched.
patches_validation_set = tio.Queue(
    subjects_dataset=validation_set,
    shuffle_subjects=False,
    shuffle_patches=False,
    **_shared_queue_options,
)
def prepare_dataload(patches=True):
    """Split the module-level ``dataset`` and build its data loaders.

    The dataset is randomly split into 245 training, 70 validation and 35
    test subjects. Both patch-based loaders (via ``tio.Queue``) and
    whole-volume loaders are constructed.

    Args:
        patches: When true, return the patch-based loaders; otherwise return
            the whole-volume loaders.

    Returns:
        Tuple ``(training_loader, validation_loader, test_loader)``.
    """
    training_batch_size = 32
    validation_batch_size = 2 * training_batch_size
    patch_size = 25
    samples_per_volume = 10
    max_queue_length = 300
    sampler = tio.data.UniformSampler(patch_size)

    num_subjects = len(dataset)  # evaluated as in the original; not used further
    split_sizes = (245, 70, 35)  # training, validation, test
    training_subjects, validation_subjects, test_subjects = (
        torch.utils.data.random_split(dataset, split_sizes))

    training_set = tio.SubjectsDataset(
        training_subjects, transform=training_transform)
    validation_set = tio.SubjectsDataset(
        validation_subjects, transform=validation_transform)
    test_set = tio.SubjectsDataset(
        test_subjects, transform=validation_transform)

    def make_queue(subjects, n_samples, shuffle):
        # One patch queue; `shuffle` toggles subject and patch shuffling together.
        return tio.Queue(
            subjects_dataset=subjects,
            max_length=max_queue_length,
            samples_per_volume=n_samples,
            sampler=sampler,
            num_workers=2,
            shuffle_subjects=shuffle,
            shuffle_patches=shuffle,
        )

    patches_training_set = make_queue(training_set, samples_per_volume, True)
    patches_validation_set = make_queue(
        validation_set, samples_per_volume * 2, False)
    patches_test_set = make_queue(test_set, samples_per_volume * 2, False)

    make_loader = torch.utils.data.DataLoader
    training_loader_patches = make_loader(
        patches_training_set, batch_size=training_batch_size)
    validation_loader_patches = make_loader(
        patches_validation_set, batch_size=validation_batch_size)
    test_loader_patches = make_loader(
        patches_test_set, batch_size=validation_batch_size)

    training_loader = make_loader(training_set, batch_size=2)
    validation_loader = make_loader(validation_set, batch_size=1)
    test_loader = make_loader(test_set, batch_size=1)

    if not patches:
        return training_loader, validation_loader, test_loader
    return training_loader_patches, validation_loader_patches, test_loader_patches
def ImagesFromDataFrame(dataframe, psize, headers, q_max_length=10,
                        q_samples_per_volume=1, q_num_workers=2,
                        q_verbose=False, sampler='label', train=True,
                        augmentations=None, preprocessing=None,
                        in_memory=False):
    """Build a torchio subjects dataset, or a patch queue, from a dataframe.

    Args:
        dataframe: Table with one row per subject. Its columns and index are
            re-labelled positionally IN PLACE.
        psize: Patch size; must support item assignment after ``deepcopy``.
        headers: Dict with keys ``'channelHeaders'``, ``'labelHeader'``,
            ``'predictionHeaders'`` and ``'subjectIDHeader'`` giving the
            positional columns to read.
        q_max_length, q_samples_per_volume, q_num_workers, q_verbose:
            Forwarded to ``torchio.Queue``.
        sampler: Key into ``global_sampler_dict`` (lower-cased). Names
            containing ``'label'`` or ``'weight'`` also trigger symmetric
            padding of every subject by half the patch size.
        train: When false, the plain ``SubjectsDataset`` is returned and
            neither cropping nor augmentations are applied.
        augmentations: Optional dict of augmentation settings keyed by name.
            A legacy ``'axes_to_flip'`` entry is copied to ``'axis'`` IN PLACE.
        preprocessing: Optional dict of preprocessing settings. When
            ``'resize'`` is set and ``'resample'`` is absent, a ``'resample'``
            entry is added IN PLACE.
        in_memory: When true, images are read eagerly with SimpleITK and
            passed to torchio as tensors.

    Returns:
        ``torchio.SubjectsDataset`` when ``train`` is false, otherwise a
        ``torchio.Queue`` over that dataset.
    """
    # Finding the dimension of the dataframe for computational purposes later
    num_row, num_col = dataframe.shape
    # changing the column indices to make it easier
    dataframe.columns = range(0, num_col)
    dataframe.index = range(0, num_row)

    # This list will later contain the list of subjects
    subjects_list = []

    channelHeaders = headers['channelHeaders']
    labelHeader = headers['labelHeader']
    predictionHeaders = headers['predictionHeaders']
    subjectIDHeader = headers['subjectIDHeader']

    sampler = sampler.lower()  # for easier parsing

    # define the control points and swap axes for augmentation:
    # a tenth of the patch size per axis, but always at least 1
    augmentation_patchAxesPoints = copy.deepcopy(psize)
    for i, extent in enumerate(augmentation_patchAxesPoints):
        augmentation_patchAxesPoints[i] = max(round(extent / 10), 1)

    # iterating through the dataframe
    resizeCheck = False  # becomes True once the resize settings were examined
    for patient in range(num_row):
        # We need this dict for storing the meta data for each subject
        # such as different image modalities, labels, any other data
        subject_dict = {}
        subject_dict['subject_id'] = dataframe[subjectIDHeader][patient]

        # iterating through the channels/modalities/timepoints of the subject
        for channel in channelHeaders:
            # assigning the dict key to the channel
            if not in_memory:
                subject_dict[str(channel)] = Image(
                    str(dataframe[channel][patient]), type=torchio.INTENSITY)
            else:
                img = sitk.ReadImage(str(dataframe[channel][patient]))
                array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
                subject_dict[str(channel)] = Image(
                    tensor=array, type=torchio.INTENSITY,
                    path=dataframe[channel][patient])

            # if resize has been defined but resample is not (or is none),
            # derive the resample resolution from the first image seen.
            # NOTE(review): nesting reconstructed from whitespace-collapsed
            # source — confirm against the original file.
            if not resizeCheck:
                if preprocessing is not None and 'resize' in preprocessing:
                    if preprocessing['resize'] is not None:
                        resizeCheck = True
                        if 'resample' not in preprocessing:
                            preprocessing['resample'] = {}
                            if 'resolution' not in preprocessing['resample']:
                                preprocessing['resample'][
                                    'resolution'] = resize_image_resolution(
                                        subject_dict[str(channel)].as_sitk(),
                                        preprocessing['resize'])
                        else:
                            print(
                                'WARNING: \'resize\' is ignored as \'resample\' is defined under \'data_processing\', this will be skipped',
                                file=sys.stderr)
                else:
                    resizeCheck = True

        if labelHeader is not None:
            if not in_memory:
                subject_dict['label'] = Image(
                    str(dataframe[labelHeader][patient]), type=torchio.LABEL)
            else:
                img = sitk.ReadImage(str(dataframe[labelHeader][patient]))
                array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
                subject_dict['label'] = Image(
                    tensor=array, type=torchio.LABEL,
                    path=dataframe[labelHeader][patient])
            subject_dict['path_to_metadata'] = str(
                dataframe[labelHeader][patient])
        else:
            subject_dict['label'] = "NA"
            # without a label, the last channel's path serves as metadata path
            subject_dict['path_to_metadata'] = str(dataframe[channel][patient])

        # iterating through the values to predict of the subject
        for valueCounter, values in enumerate(predictionHeaders):
            subject_dict['value_' + str(valueCounter)] = np.array(
                dataframe[values][patient])

        # Initializing the subject object using the dict
        subject = Subject(subject_dict)

        # padding image, but only for label sampler, because we don't want to pad for uniform
        if 'label' in sampler or 'weight' in sampler:
            psize_pad = list(
                np.asarray(np.round(np.divide(psize, 2)), dtype=int))
            # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            padder = Pad(psize_pad, padding_mode='symmetric')
            subject = padder(subject)

        # Appending this subject to the list of subjects
        subjects_list.append(subject)

    augmentation_list = []

    # first, we want to do thresholding, followed by clipping, if it is present - required for inference as well
    if preprocessing is not None:
        if train:  # we want the crop to only happen during training
            if 'crop_external_zero_planes' in preprocessing:
                augmentation_list.append(
                    global_preprocessing_dict['crop_external_zero_planes'](
                        psize))
        for key in ['threshold', 'clip']:
            if key in preprocessing:
                augmentation_list.append(global_preprocessing_dict[key](
                    min=preprocessing[key]['min'],
                    max=preprocessing[key]['max']))

        # next, we want to do the resampling, if it is present - required for inference as well
        if 'resample' in preprocessing:
            if 'resolution' in preprocessing['resample']:
                # `np.float` was only an alias of the builtin `float` and has
                # been removed from NumPy; the builtin gives the same dtype.
                resample_values = tuple(
                    np.array(preprocessing['resample']['resolution']).astype(
                        float))
                if len(resample_values) == 2:
                    # 2D resolution: use unit spacing along the third axis
                    resample_values = tuple(np.append(resample_values, 1))
                augmentation_list.append(Resample(resample_values))

        # next, we want to do the intensity normalize - required for inference as well
        if 'normalize' in preprocessing:
            augmentation_list.append(global_preprocessing_dict['normalize'])
        elif 'normalize_nonZero' in preprocessing:
            augmentation_list.append(
                global_preprocessing_dict['normalize_nonZero'])
        elif 'normalize_nonZero_masked' in preprocessing:
            augmentation_list.append(
                global_preprocessing_dict['normalize_nonZero_masked'])

    # other augmentations should only happen for training - and also setting the probabilities
    # for the augmentations
    if train and augmentations is not None:
        for aug in augmentations:
            if aug != 'default_probability':
                actual_function = None

                if aug == 'flip':
                    if 'axes_to_flip' in augmentations[aug]:
                        print(
                            'WARNING: \'flip\' augmentation needs the key \'axis\' instead of \'axes_to_flip\'',
                            file=sys.stderr)
                        augmentations[aug]['axis'] = augmentations[aug][
                            'axes_to_flip']
                    actual_function = global_augs_dict[aug](
                        axes=augmentations[aug]['axis'],
                        p=augmentations[aug]['probability'])
                elif aug in ['rotate_90', 'rotate_180']:
                    # one transform per requested axis, appended directly
                    for axis in augmentations[aug]['axis']:
                        augmentation_list.append(global_augs_dict[aug](
                            axis=axis, p=augmentations[aug]['probability']))
                elif aug in ['swap', 'elastic']:
                    actual_function = global_augs_dict[aug](
                        patch_size=augmentation_patchAxesPoints,
                        p=augmentations[aug]['probability'])
                elif aug == 'blur':
                    actual_function = global_augs_dict[aug](
                        std=augmentations[aug]['std'],
                        p=augmentations[aug]['probability'])
                elif aug == 'noise':
                    actual_function = global_augs_dict[aug](
                        mean=augmentations[aug]['mean'],
                        std=augmentations[aug]['std'],
                        p=augmentations[aug]['probability'])
                elif aug == 'anisotropic':
                    actual_function = global_augs_dict[aug](
                        axes=augmentations[aug]['axis'],
                        downsampling=augmentations[aug]['downsampling'],
                        p=augmentations[aug]['probability'])
                else:
                    actual_function = global_augs_dict[aug](
                        p=augmentations[aug]['probability'])

                if actual_function is not None:
                    augmentation_list.append(actual_function)

    if augmentation_list:
        transform = Compose(augmentation_list)
    else:
        transform = None

    subjects_dataset = torchio.SubjectsDataset(subjects_list,
                                               transform=transform)

    if not train:
        return subjects_dataset

    if sampler in ('weighted', 'weightedsampler', 'weightedsample'):
        sampler = global_sampler_dict[sampler](psize, probability_map='label')
    else:
        sampler = global_sampler_dict[sampler](psize)

    # all of these need to be read from model.yaml
    patches_queue = torchio.Queue(subjects_dataset,
                                  max_length=q_max_length,
                                  samples_per_volume=q_samples_per_volume,
                                  sampler=sampler,
                                  num_workers=q_num_workers,
                                  shuffle_subjects=True,
                                  shuffle_patches=True,
                                  verbose=q_verbose)
    return patches_queue
training_transform = Compose([RescaleIntensity((0, 1)), RandomNoise(p=0.05)]) validation_transform = Compose([RescaleIntensity((0, 1))]) test_transform = Compose([RescaleIntensity((0, 1))]) training_dataset = tio.SubjectsDataset(training_subjects, transform=training_transform) validation_dataset = tio.SubjectsDataset(validation_subjects, transform=validation_transform) test_dataset = tio.SubjectsDataset(test_subjects, transform=test_transform) '''Patching''' patches_training_set = tio.Queue( subjects_dataset=training_dataset, max_length=max_queue_length, samples_per_volume=samples_per_volume, sampler=tio.sampler.UniformSampler(patch_size), # shuffle_subjects=True, # shuffle_patches=True, ) patches_validation_set = tio.Queue( subjects_dataset=validation_dataset, max_length=max_queue_length, samples_per_volume=samples_per_volume * 2, sampler=tio.sampler.UniformSampler(patch_size), # shuffle_subjects=False, # shuffle_patches=False, ) training_loader = torch.utils.data.DataLoader(patches_training_set, batch_size=training_batch_size,
def get_loaders(data, cv_split,
                training_transform = False,
                validation_transform = False,
                patch_size = 64,
                patches = False,
                samples_per_volume = 6,
                max_queue_length = 180,
                training_batch_size = 1,
                validation_batch_size = 1,
                mask = False,
                input_type = 'T1'):
    """
    Function creates dataloaders

    Arguments:
        * data (data_processor.DataMriSegmentation): torchio dataset
        * cv_split (list): list of two arrays, one with train indexes, other with test indexes
        * training_transform (bool/torchio.transforms): whether or not to use transform for training images
        * validation_transform (bool/torchio.transforms): whether or not to use transform for validation images
        * patch_size (int): size of patches
        * patches (bool): if True, then patch-based training will be applied
          https://niftynet.readthedocs.io/en/dev/window_sizes.html - about patch based training
        * samples_per_volume (int): number of patches to extract from each volume
        * max_queue_length (int): maximum number of patches that can be stored in the queue
        * training_batch_size (int): size of batches for training
        * validation_batch_size (int): size of batches for validation
        * mask (bool): kept for backward compatibility only; the mask mode
          actually used is always read from ``data.mask``
        * input_type (str): 'T1' to use ``data.img_files`` as input, or 'seg'
          to use ``data.img_seg`` as input (requires a 'bb'/'combined' mask)

    Output:
        * training_loader (torch.utils.data.DataLoader): loader for train
        * validation_loader (torch.utils.data.DataLoader): loader for test

    Raises:
        * ValueError: if ``input_type`` is not 'T1' or 'seg', or if it is
          'seg' while ``data.mask`` is neither 'bb' nor 'combined'. These
          combinations previously failed later with an UnboundLocalError.
    """
    training_idx, validation_idx = cv_split
    mask = data.mask

    print('Training set:', len(training_idx), 'subjects')
    print('Validation set:', len(validation_idx), 'subjects')
    print(f'Input type is {input_type}')

    use_mask = mask in ['bb', 'combined']

    # Pick the input and target columns once so the datasets are built once.
    if input_type == 'T1':
        inputs = data.img_files
        targets = data.img_mask if use_mask else data.img_seg
    elif input_type == 'seg':
        if not use_mask:
            raise ValueError(
                "input_type 'seg' requires data.mask to be 'bb' or "
                f"'combined', got {mask!r}")
        inputs = data.img_seg
        targets = data.img_mask
    else:
        raise ValueError(
            f"Unsupported input_type {input_type!r}; expected 'T1' or 'seg'")

    if use_mask:
        print(f'Mask type is {mask}')

    training_set = get_torchio_dataset(
        list(inputs[training_idx].values),
        list(targets[training_idx].values),
        training_transform)
    validation_set = get_torchio_dataset(
        list(inputs[validation_idx].values),
        list(targets[validation_idx].values),
        validation_transform)

    if patches:
        # https://niftynet.readthedocs.io/en/dev/window_sizes.html - about patch based training
        # https://torchio.readthedocs.io/data/patch_training.html - about Queue
        training_source = torchio.Queue(
            subjects_dataset = training_set,
            max_length = max_queue_length,
            samples_per_volume = samples_per_volume,
            patch_size = patch_size,
            sampler_class = torchio.sampler.ImageSampler,
            num_workers = multiprocessing.cpu_count(),
            shuffle_subjects = True,
            shuffle_patches = True,
        )
        validation_source = torchio.Queue(
            subjects_dataset = validation_set,
            max_length = max_queue_length,
            samples_per_volume = samples_per_volume,
            patch_size = patch_size,
            sampler_class = torchio.sampler.ImageSampler,
            num_workers = multiprocessing.cpu_count(),
            shuffle_subjects = False,
            shuffle_patches = False,
        )
    else:
        training_source = training_set
        validation_source = validation_set

    training_loader = torch.utils.data.DataLoader(
        training_source, batch_size = training_batch_size)
    validation_loader = torch.utils.data.DataLoader(
        validation_source, batch_size = validation_batch_size)

    if patches:
        print('Patches mode')

    print('Training loader length:', len(training_loader))
    print('Validation loader length:', len(validation_loader))

    return training_loader, validation_loader
# Script that builds a torchio patch queue over copies of one subject and
# exposes it through dummy Lightning modules, recording peak memory per step.
# NOTE(review): `pl` is used below but not imported in this excerpt —
# presumably pytorch_lightning; confirm against the full file.
import torchio as tio
from torch.utils.data import DataLoader
import resource
import time

# Queue / loader configuration.
n_subjects = 16
max_length = 40
samples_per_volume = 5
num_workers = 8
patch_size = 128
batch_size = 2

sampler = tio.data.UniformSampler(patch_size)
subject = tio.datasets.Colin27()
# The same subject object is repeated n_subjects times.
dataset = tio.SubjectsDataset(n_subjects * [subject])
queue = tio.Queue(dataset, max_length, samples_per_volume, sampler, num_workers)


class DummyDataModule(pl.LightningDataModule):
    """Data module whose training loader reads from the module-level queue."""

    def train_dataloader(self):
        """Return a DataLoader over the patch queue."""
        return DataLoader(queue, batch_size=batch_size)


class DummyModule(pl.LightningModule):
    """Module that performs no learning; each step only sleeps and reads memory."""

    def configure_optimizers(self):
        """No optimizers are configured."""
        pass

    def training_step(self, *args, **kwargs):
        """Sleep briefly and read the process's peak resident set size."""
        #pdb.set_trace()
        # Use inspect_mem() here.
        time.sleep(0.1)
        # ru_maxrss divided by 1000; presumably kilobytes -> megabytes on
        # Linux (the unit is platform dependent) — TODO confirm.
        main_memory = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1000
RandomElasticDeformation(): 0.2, }, p=0.5), # Changed from p=0.75 24/6/20 ]) # Create the datasets training_dataset = torchio.ImagesDataset( [train_subject], transform=training_transform) validation_dataset = torchio.ImagesDataset( [valid_subject]) # Define the queue of sampled patches for training and validation sampler = torchio.data.UniformSampler(PATCH_SIZE) patches_training_set = torchio.Queue( subjects_dataset=training_dataset, max_length=MAX_QUEUE_LENGTH, samples_per_volume=TRAIN_PATCHES, sampler=sampler, num_workers=NUM_WORKERS, shuffle_subjects=False, shuffle_patches=True, ) patches_validation_set = torchio.Queue( subjects_dataset=validation_dataset, max_length=MAX_QUEUE_LENGTH, samples_per_volume=VALID_PATCHES, sampler=sampler, num_workers=NUM_WORKERS, shuffle_subjects=False, shuffle_patches=False, )
def ImagesFromDataFrame(
    dataframe, parameters, train, apply_zero_crop=False, loader_type=""
):
    """
    Reads the pandas dataframe and gives the dataloader to use for training/validation/testing

    Note that ``dataframe`` is modified in place (its columns and index are
    replaced by positional ranges), and ``parameters["data_preprocessing"]``
    gets a ``"resize_image"`` entry when one of the resize keys is set.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        The main input dataframe which is calculated after splitting the data CSV
    parameters : dict
        The parameters dictionary
    train : bool
        If the dataloader is for training or not. For training, the patching infrastructure and data augmentation is applied.
    apply_zero_crop : bool
        If enabled, the crop_external_zero_plane is applied.
    loader_type : str
        Type of loader for printing.

    Returns
    -------
    subjects_dataset: torchio.SubjectsDataset
        This is the output for validation/testing, where patching and data augmentation is disregarded
    patches_queue: torchio.Queue
        This is the output for training, which is the subjects_dataset queue after patching and data augmentation is taken into account

    Raises
    ------
    ValueError
        If the sanity check failed for at least one subject; the offending
        subject ids are included in the exception arguments.
    """
    # store in previous variable names
    patch_size = parameters["patch_size"]
    headers = parameters["headers"]
    q_max_length = parameters["q_max_length"]
    q_samples_per_volume = parameters["q_samples_per_volume"]
    q_num_workers = parameters["q_num_workers"]
    q_verbose = parameters["q_verbose"]
    sampler = parameters["patch_sampler"]
    augmentations = parameters["data_augmentation"]
    preprocessing = parameters["data_preprocessing"]
    in_memory = parameters["in_memory"]
    enable_padding = parameters["enable_padding"]

    # Finding the dimension of the dataframe for computational purposes later
    num_row, num_col = dataframe.shape
    # changing the column indices to make it easier
    dataframe.columns = range(0, num_col)
    dataframe.index = range(0, num_row)
    # This list will later contain the list of subjects
    subjects_list = []
    # ids of subjects whose sanity check raised; reported together at the end
    subjects_with_error = []

    channelHeaders = headers["channelHeaders"]
    labelHeader = headers["labelHeader"]
    predictionHeaders = headers["predictionHeaders"]
    subjectIDHeader = headers["subjectIDHeader"]

    # this basically means that label sampler is selected with padding
    if isinstance(sampler, dict):
        sampler_padding = sampler["label"]["padding_type"]
        sampler = "label"
    else:
        sampler = sampler.lower()  # for easier parsing
        sampler_padding = "symmetric"

    resize_images_flag = False
    # if resize has been defined but resample is not (or is none)
    if not (preprocessing is None):
        for key in preprocessing.keys():
            # check for different resizing keys
            if key in ["resize", "resize_image", "resize_images"]:
                if not (preprocessing[key] is None):
                    resize_images_flag = True
                    # normalise the key name used by the per-image resize below
                    preprocessing["resize_image"] = preprocessing[key]
                # NOTE(review): `break` placement reconstructed from
                # whitespace-collapsed source — confirm against the original.
                break

    # iterating through the dataframe
    for patient in tqdm(
        range(num_row), desc="Constructing queue for " + loader_type + " data"
    ):
        # We need this dict for storing the meta data for each subject
        # such as different image modalities, labels, any other data
        subject_dict = {}
        subject_dict["subject_id"] = str(dataframe[subjectIDHeader][patient])
        skip_subject = False
        # iterating through the channels/modalities/timepoints of the subject
        for channel in channelHeaders:
            # sanity check for malformed csv
            if not os.path.isfile(str(dataframe[channel][patient])):
                skip_subject = True

            # NOTE(review): the image object and the header read below are
            # still attempted when the file is missing — confirm that the
            # reader tolerates this.
            subject_dict[str(channel)] = torchio.ScalarImage(
                dataframe[channel][patient]
            )

            # store image spacing information if not present
            if "spacing" not in subject_dict:
                file_reader = sitk.ImageFileReader()
                file_reader.SetFileName(dataframe[channel][patient])
                file_reader.ReadImageInformation()
                subject_dict["spacing"] = torch.Tensor(file_reader.GetSpacing())

            # if resize_image is requested, the perform per-image resize with appropriate interpolator
            if resize_images_flag:
                img_resized = resize_image(
                    subject_dict[str(channel)].as_sitk(), preprocessing["resize_image"]
                )
                # always ensure resized image spacing is used
                subject_dict["spacing"] = torch.Tensor(img_resized.GetSpacing())
                subject_dict[str(channel)] = torchio.ScalarImage.from_sitk(img_resized)

        # # for regression -- this logic needs to be thought through
        # if predictionHeaders:
        #     # get the mask
        #     if (subject_dict['label'] is None) and (class_list is not None):
        #         sys.exit('The \'class_list\' parameter has been defined but a label file is not present for patient: ', patient)

        if labelHeader is not None:
            if not os.path.isfile(str(dataframe[labelHeader][patient])):
                skip_subject = True

            subject_dict["label"] = torchio.LabelMap(dataframe[labelHeader][patient])
            subject_dict["path_to_metadata"] = str(dataframe[labelHeader][patient])

            # if resize is requested, the perform per-image resize with appropriate interpolator
            if resize_images_flag:
                # nearest-neighbour interpolation is requested for the label
                img_resized = resize_image(
                    subject_dict["label"].as_sitk(),
                    preprocessing["resize_image"],
                    sitk.sitkNearestNeighbor,
                )
                subject_dict["label"] = torchio.LabelMap.from_sitk(img_resized)

        else:
            subject_dict["label"] = "NA"
            # without a label, the last channel's path serves as metadata path
            subject_dict["path_to_metadata"] = str(dataframe[channel][patient])

        # iterating through the values to predict of the subject
        valueCounter = 0
        for values in predictionHeaders:
            # assigning the dict key to the channel
            subject_dict["value_" + str(valueCounter)] = np.array(
                dataframe[values][patient]
            )
            valueCounter += 1

        # skip subject the condition was tripped
        if not skip_subject:
            # Initializing the subject object using the dict
            subject = torchio.Subject(subject_dict)

            # https://github.com/fepegar/torchio/discussions/587#discussioncomment-928834
            # this is causing memory usage to explode, see https://github.com/CBICA/GaNDLF/issues/128
            if parameters["verbose"]:
                print(
                    "Checking consistency of images in subject '"
                    + subject["subject_id"]
                    + "'"
                )
            # Any failure is recorded rather than raised immediately so that
            # all problematic subjects can be reported in one error below.
            try:
                perform_sanity_check_on_subject(subject, parameters)
            except Exception as e:
                subjects_with_error.append(subject["subject_id"])

            # # padding image, but only for label sampler, because we don't want to pad for uniform
            if "label" in sampler or "weight" in sampler:
                if enable_padding:
                    # pad by half the patch size, rounded up
                    psize_pad = list(
                        np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int)
                    )
                    # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
                    padder = Pad(psize_pad, padding_mode=sampler_padding)
                    subject = padder(subject)

            # load subject into memory: https://github.com/fepegar/torchio/discussions/568#discussioncomment-859027
            if in_memory:
                subject.load()

            # Appending this subject to the list of subjects
            subjects_list.append(subject)

    if subjects_with_error:
        raise ValueError(
            "The following subjects could not be loaded, please recheck or remove and retry:",
            subjects_with_error,
        )

    transformations_list = []

    # augmentations are applied to the training set only
    if train and not (augmentations == None):
        for aug in augmentations:
            aug_lower = aug.lower()
            # augmentation names missing from global_augs_dict are ignored
            if aug_lower in global_augs_dict:
                transformations_list.append(
                    global_augs_dict[aug_lower](augmentations[aug])
                )

    transform = get_transforms_for_preprocessing(
        parameters, transformations_list, train, apply_zero_crop
    )

    subjects_dataset = torchio.SubjectsDataset(subjects_list, transform=transform)
    # validation/testing: return the plain dataset without patching
    if not train:
        return subjects_dataset
    if sampler in ("weighted", "weightedsampler", "weightedsample"):
        sampler = global_sampler_dict[sampler](patch_size, probability_map="label")
    else:
        sampler = global_sampler_dict[sampler](patch_size)
    # all of these need to be read from model.yaml
    patches_queue = torchio.Queue(
        subjects_dataset,
        max_length=q_max_length,
        samples_per_volume=q_samples_per_volume,
        sampler=sampler,
        num_workers=q_num_workers,
        shuffle_subjects=True,
        shuffle_patches=True,
        verbose=q_verbose,
    )
    return patches_queue
def get_loaders(data, cv_split,
                training_transform = False,
                validation_transform = False,
                patch_size = 64,
                patches = False,
                samples_per_volume = 6,
                max_queue_length = 180,
                training_batch_size = 1,
                validation_batch_size = 1,
                mask = False):
    """
    Create the training and validation dataloaders for one CV split.

    Arguments:
        * data: object exposing ``img_files``, ``img_seg`` and ``img_mask``,
          each indexable by the split indices
        * cv_split: pair ``(training_idx, validation_idx)``
        * training_transform / validation_transform: forwarded to
          ``get_torchio_dataset``
        * patch_size, samples_per_volume, max_queue_length: patch queue
          settings, used only when ``patches`` is true
        * patches (bool): if true, loaders iterate over ``torchio.Queue``
          patches instead of whole volumes
        * training_batch_size / validation_batch_size: loader batch sizes
        * mask: if truthy, ``img_mask`` replaces ``img_seg`` as the target

    Output:
        * (training_loader, validation_loader)
    """
    training_idx, validation_idx = cv_split
    print('Training set:', len(training_idx), 'subjects')
    print('Validation set:', len(validation_idx), 'subjects')

    def build_dataset(target_attr, indices, transform):
        # Pair the image files with the requested target column.
        image_paths = list(data.img_files[indices].values)
        target_paths = list(getattr(data, target_attr)[indices].values)
        return get_torchio_dataset(image_paths, target_paths, transform)

    def build_queue(subjects, shuffle):
        # One patch queue; `shuffle` toggles subject and patch shuffling together.
        return torchio.Queue(
            subjects_dataset=subjects,
            max_length=max_queue_length,
            samples_per_volume=samples_per_volume,
            patch_size=patch_size,
            sampler_class=torchio.sampler.ImageSampler,
            num_workers=multiprocessing.cpu_count(),
            shuffle_subjects=shuffle,
            shuffle_patches=shuffle,
        )

    training_set = build_dataset('img_seg', training_idx, training_transform)
    validation_set = build_dataset('img_seg', validation_idx, validation_transform)

    if mask:
        # if using masked data for training, the datasets are rebuilt with masks
        training_set = build_dataset('img_mask', training_idx, training_transform)
        validation_set = build_dataset('img_mask', validation_idx, validation_transform)

    make_loader = torch.utils.data.DataLoader
    training_loader = make_loader(training_set, batch_size=training_batch_size)
    validation_loader = make_loader(validation_set, batch_size=validation_batch_size)

    if patches:
        patches_training_set = build_queue(training_set, True)
        patches_validation_set = build_queue(validation_set, False)
        training_loader = make_loader(
            patches_training_set, batch_size=training_batch_size)
        validation_loader = make_loader(
            patches_validation_set, batch_size=validation_batch_size)

    print('Training loader length:', len(training_loader))
    print('Validation loader length:', len(validation_loader))
    return training_loader, validation_loader