def gsd_pCT_valid_transform(self, seed=None):
    """Build the validation transform that standardizes the pCT input.

    The steps run in this order: ``ToTensor``, ``Pad`` to ``self.scale_size``,
    ``TypeCast`` to float/float, ``StandardizeImage``, ``ChannelsFirst`` and
    a final ``TypeCast`` to float/float.

    :param seed: accepted for interface compatibility; not used here.
    :return: the composed ``ts.Compose`` pipeline.
    """
    float_pair = ['float', 'float']
    steps = [
        ts.ToTensor(),
        ts.Pad(size=self.scale_size),
        ts.TypeCast(float_pair),
        StandardizeImage(norm_flag=[True, True, True, False]),
        ts.ChannelsFirst(),
        # A fresh list so the two TypeCast steps never share an argument object.
        ts.TypeCast(list(float_pair)),
    ]
    return ts.Compose(steps)
def gsd_pCT_valid_transform(self, seed=None):
    """Build the validation transform that normalises with ``NormalizeMedic``.

    The steps run in this order: ``ToTensor``, ``Pad`` to ``self.scale_size``,
    ``ChannelsFirst``, ``TypeCast`` to float/float, ``NormalizeMedic`` with
    ``norm_flag=(True, False)`` and a final ``TypeCast`` to float/long.
    No crop is applied.

    :param seed: accepted for interface compatibility; not used here.
    :return: the composed ``ts.Compose`` pipeline.
    """
    steps = []
    steps.append(ts.ToTensor())
    steps.append(ts.Pad(size=self.scale_size))
    steps.append(ts.ChannelsFirst())
    steps.append(ts.TypeCast(['float', 'float']))
    steps.append(ts.NormalizeMedic(norm_flag=(True, False)))
    steps.append(ts.TypeCast(['float', 'long']))
    return ts.Compose(steps)
def cmr_3d_sax_transform(self):
    """Build the train / valid / image pipelines for 3D SAX volumes.

    * ``train`` adds a random flip, a random affine and a random crop.
    * ``valid`` uses the same normalisation with ``SpecialCrop`` (crop_type 0).
    * ``image`` only pads, converts and crops; it has no casting or
      normalisation step.

    :return: dict with the keys ``'train'``, ``'valid'`` and ``'image'``.
    """
    pad_size = self.scale_size
    crop_size = self.patch_size

    train_steps = [
        ts.PadNumpy(size=pad_size),
        ts.ToTensor(),
        ts.ChannelsFirst(),
        ts.TypeCast(['float', 'float']),
        ts.RandomFlip(h=True, v=True, p=self.random_flip_prob),
        ts.RandomAffine(
            rotation_range=self.rotate_val,
            translation_range=self.shift_val,
            zoom_range=self.scale_val,
            interp=('bilinear', 'nearest'),
        ),
        ts.NormalizeMedic(norm_flag=(True, False)),
        ts.ChannelsLast(),
        ts.AddChannel(axis=0),
        ts.RandomCrop(size=crop_size),
        ts.TypeCast(['float', 'long']),
    ]
    pipelines = {'train': ts.Compose(train_steps)}

    valid_steps = [
        ts.PadNumpy(size=pad_size),
        ts.ToTensor(),
        ts.ChannelsFirst(),
        ts.TypeCast(['float', 'float']),
        ts.NormalizeMedic(norm_flag=(True, False)),
        ts.ChannelsLast(),
        ts.AddChannel(axis=0),
        ts.SpecialCrop(size=crop_size, crop_type=0),
        ts.TypeCast(['float', 'long']),
    ]
    pipelines['valid'] = ts.Compose(valid_steps)

    image_steps = [
        ts.PadNumpy(size=pad_size),
        ts.ToTensor(),
        ts.ChannelsFirst(),
        ts.ChannelsLast(),
        ts.AddChannel(axis=0),
        ts.SpecialCrop(size=crop_size, crop_type=0),
    ]
    pipelines['image'] = ts.Compose(image_steps)

    return pipelines
def isles2018_valid_transform(self, seed=None):
    """Build the ISLES 2018 validation transform.

    The pipeline is ``ToTensor``, ``Pad`` to ``self.scale_size``,
    ``ChannelsFirst`` and ``TypeCast`` to float/long. It has no
    normalisation step.

    :param seed: accepted for interface compatibility; not used here.
    :return: the composed ``ts.Compose`` pipeline.
    """
    steps = (
        ts.ToTensor(),
        ts.Pad(size=self.scale_size),
        ts.ChannelsFirst(),
        ts.TypeCast(['float', 'long']),
    )
    # Compose is handed a list, exactly as before.
    return ts.Compose(list(steps))
def test_3d_sax_transform(self):
    """Build the test-time pipeline for 3D SAX volumes.

    The input is padded with ``PadFactorNumpy(factor=self.division_factor)``
    rather than to a fixed size, then converted, cast to float, normalised
    with ``NormalizeMedic`` and given a leading channel axis. No crop is
    applied.

    :return: dict with the single key ``'test'``.
    """
    steps = [
        ts.PadFactorNumpy(factor=self.division_factor),
        ts.ToTensor(),
        ts.ChannelsFirst(),
        ts.TypeCast(['float']),
        ts.NormalizeMedic(norm_flag=True),
        ts.ChannelsLast(),
        ts.AddChannel(axis=0),
    ]
    return dict(test=ts.Compose(steps))
def ultrasound_transform(self):
    """Build the train / valid pipelines for ultrasound images.

    Both pipelines convert to a float tensor, add a leading channel axis,
    apply ``SpecialCrop(self.patch_size, 0)`` and finish with
    ``StdNormalize``. The training pipeline adds a horizontal-only random
    flip and a random affine between the crop and the normalisation.

    :return: dict with the keys ``'train'`` and ``'valid'``.
    """
    train_steps = [
        ts.ToTensor(),
        ts.TypeCast(['float']),
        ts.AddChannel(axis=0),
        ts.SpecialCrop(self.patch_size, 0),
        ts.RandomFlip(h=True, v=False, p=self.random_flip_prob),
        ts.RandomAffine(
            rotation_range=self.rotate_val,
            translation_range=self.shift_val,
            zoom_range=self.scale_val,
            # A plain string: the original's parentheses did not make a tuple.
            interp='bilinear',
        ),
        ts.StdNormalize(),
    ]
    train_transform = ts.Compose(train_steps)

    valid_steps = [
        ts.ToTensor(),
        ts.TypeCast(['float']),
        ts.AddChannel(axis=0),
        ts.SpecialCrop(self.patch_size, 0),
        ts.StdNormalize(),
    ]
    valid_transform = ts.Compose(valid_steps)

    return dict(train=train_transform, valid=valid_transform)
print("random seed: {}".format(c)) comment = "20/200 with parameter per slice and all slices" print(comment) patient_id_G_test = random.sample(patient_id_G, 20) patient_id_H_test = random.sample(patient_id_H, 200) patient_id_G_train = list(patient_id_G.difference(patient_id_G_test)) patient_id_H_train = list(patient_id_H.difference(patient_id_H_test)) transform_pipeline_train = tr.Compose([ AddGaussian(), AddGaussian(ismulti=False), tr.ToTensor(), tr.AddChannel(axis=0), tr.TypeCast('float'), # Attenuation((-.001, .1)), # tr.RangeNormalize(0,1), tr.RandomBrightness(-.2, .2), tr.RandomGamma(.9, 1.1), tr.RandomFlip(), tr.RandomAffine(rotation_range=5, translation_range=0.2, zoom_range=(0.9, 1.1)) ]) transform_pipeline_test = tr.Compose([ tr.ToTensor(), tr.AddChannel(axis=0), tr.TypeCast('float') # tr.RangeNormalize(0, 1)
def __init__(self, root_dir, split, transform=None, preload_data=True, modalities=['7T_T2']):
    """Index (and optionally preload) the inference volumes of one split.

    Args:
        root_dir: Dataset root; images are read from
            ``<root_dir>/<split>/image``.
        split: Name of the split sub-directory.
        transform: Accepted for interface compatibility but ignored; the
            pipeline is always ``ToTensor`` followed by ``TypeCast``.
        preload_data: When true, load every volume into memory up front,
            one channel per modality, each divided by its own maximum.
        modalities: Substrings selecting the files of each modality.
            ``'7T_T2_cor'`` is matched against ``'7T_T2'`` file names.
    """
    super(CMR3DDataset_MultiClass_MultiProj_infer, self).__init__()

    # Copy so neither the shared default list nor the caller's list is aliased.
    self.TypeOfModal = list(modalities)

    image_dir = join(root_dir, split, 'image')
    image_names = [x for x in listdir(image_dir) if is_image_file(x)]

    self.image_filenames = []
    reference_len = None
    last_len = None
    for mod in self.TypeOfModal:
        if mod == '7T_T2_cor':
            mod = '7T_T2'
        tmp_filenames = sorted(join(image_dir, x) for x in image_names if mod in x)
        self.image_filenames.append(tmp_filenames)
        last_len = len(tmp_filenames)

        # The original test was `mod == '7T_T2' or '3T'`, which is always
        # true because a non-empty string is truthy.
        # NOTE(review): '3T' is read here as "any 3T modality" - confirm.
        if mod == '7T_T2' or mod.startswith('3T'):
            # Reference scan: it determines how many patients we have.
            reference_len = last_len

    # Without a reference modality, keep the previous effective behaviour:
    # the last listed modality decides the patient count.
    self.patient_len = last_len if reference_len is None else reference_len

    # The first file of the first modality defines the volume dimensions.
    tmp_data = load_nifti_img(self.image_filenames[0][0], dtype=np.int16)
    self.image_dims = tmp_data[0].shape

    # report the number of images in the dataset
    print('Number of {0} images: {1} Patients'.format(split, len(self)))

    # NOTE: the add-dimension transform is deliberately not applied here.
    self.transform = ts.Compose(
        [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

    self.t2_headers = []
    self.preload_data = preload_data
    if self.preload_data:
        print('Preloading the {0} dataset ...'.format(split))
        self.raw_images = []
        for jj, ref_name in enumerate(self.image_filenames[0]):
            channels = []
            ref_header = None
            for ii, mod_files in enumerate(self.image_filenames):
                loaded = load_nifti_img(mod_files[jj], dtype=np.float32)
                q_dat = loaded[0]
                if ii == 0:
                    # The header of the first modality identifies the patient.
                    ref_header = loaded[1]
                # Scale by the volume's own maximum, then add a channel axis.
                channels.append(np.expand_dims(q_dat / np.max(q_dat), axis=0))

            # One multi-channel array per patient.
            self.raw_images.append(np.concatenate(channels, axis=0))

            # Data identifier, taken from the first modality's file name
            # (used to identify samples in the multi-GPU case).
            ref_header['db_name'] = re.search(
                r'_P(.*)\.nii\.gz', ref_name).group(1)
            self.t2_headers.append(ref_header)

        print('Loading is done\n')
def __init__(self, root_dir, split, transform=None, preload_data=False, rank=0, world_size=1):
    """Index (and optionally preload) multi-projection volumes for one rank.

    Args:
        root_dir: Dataset root; images are read from
            ``<root_dir>/<split>/image`` and labels from
            ``<root_dir>/<split>/label``.
        split: Name of the split sub-directory.
        transform: Accepted for interface compatibility but ignored; the
            pipeline is always ``ToTensor`` followed by ``TypeCast``.
        preload_data: When true, load this rank's volumes and labels into
            memory up front.
        rank: Index of this process; selects which slice of files it owns.
        world_size: Number of processes the file lists are divided among.
    """
    super(CMR3DDataset_MultiClass_MultiProj_V2, self).__init__()

    # TODO: make this an external parameter?
    internal_hist_augmentation_flag = 0

    # TODO: make this an external parameter?
    self.TypeOfModal = ['7T_T2']
    self.TypeOfProj = ['Axial', 'Coronal']

    image_dir = join(root_dir, split, 'image')
    target_dir = join(root_dir, split, 'label')

    # image_filenames[modality][projection] -> sorted list of file paths
    image_names = [x for x in listdir(image_dir) if is_image_file(x)]
    self.image_filenames = [
        [
            sorted(join(image_dir, x) for x in image_names
                   if mod in x and prj in x)
            for prj in self.TypeOfProj
        ]
        for mod in self.TypeOfModal
    ]

    # The first file of the first modality/projection defines the dimensions.
    tmp_data = load_nifti_img(self.image_filenames[0][0][0], dtype=np.int16)
    self.image_dims = tmp_data[0].shape

    self.target_filenames = sorted([
        join(target_dir, x) for x in listdir(target_dir)
        if is_image_file(x)
    ])

    # Divide data to each rank. The same slice is applied to every
    # modality/projection list so they stay aligned with the labels
    # (the original only sliced [0][0] and [0][1]).
    grp_size = math.ceil(len(self.target_filenames) / world_size)
    shard = slice(rank * grp_size, (rank + 1) * grp_size)
    self.image_filenames = [
        [files[shard] for files in per_proj]
        for per_proj in self.image_filenames
    ]
    self.target_filenames = self.target_filenames[shard]
    self.patient_len = len(self.target_filenames)

    # report the number of images in the dataset
    print('Number of {0} images: {1} Patients'.format(split, len(self)))

    # NOTE: the add-dimension transform is deliberately not applied here.
    self.transform = ts.Compose(
        [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

    # data load into the ram memory
    self.preload_data = preload_data
    if self.preload_data:
        print('Preloading the {0} dataset ...'.format(split))

        def load_channel(path):
            """Load one volume, scale it by its maximum, add a channel axis."""
            q_dat = load_nifti_img(path, dtype=np.float32)[0]
            # The original discarded this division for every channel after
            # the first, leaving those channels unnormalised.
            q_dat = q_dat / np.max(q_dat)
            if internal_hist_augmentation_flag == 1:
                q_dat = adaptive_hist_aug(q_dat)
            return np.expand_dims(q_dat, axis=0)

        # raw_images[patient][projection] -> multi-channel array
        self.raw_images = []
        for jj in range(len(self.image_filenames[0][0])):
            tmp_data_list = []
            for kk in range(len(self.image_filenames[0])):
                # Coronal data is already permuted in the correct directions.
                channels = [
                    load_channel(per_mod[kk][jj])
                    for per_mod in self.image_filenames
                    if per_mod[kk] != []  # skip modalities without this projection
                ]
                if channels:
                    tmp_data = np.concatenate(channels, axis=0)
                # NOTE(review): as in the original, when a projection has no
                # files the previously built array is appended again -
                # confirm this is wanted.
                tmp_data_list.append(tmp_data)

            # [0] - Axial, [1] - Coronal
            self.raw_images.append(tmp_data_list)

        self.raw_labels = [
            load_nifti_img(ii, dtype=np.uint8)[0]
            for ii in self.target_filenames
        ]
        print('Loading is done\n')
def __init__(self, root_dir, split, transform=None, preload_data=True):
    """Index (and optionally preload) the unregistered multi-modal volumes.

    Args:
        root_dir: Dataset root; images are read from
            ``<root_dir>/<split>/image`` and labels from
            ``<root_dir>/<split>/label``.
        split: Name of the split sub-directory.
        transform: Accepted for interface compatibility but ignored; the
            pipeline is always ``ToTensor`` followed by ``TypeCast``.
        preload_data: When true, load every volume and label into memory.
            Each patient is stored as a list of per-modality arrays; the
            modalities are not concatenated.
    """
    super(CMR3DDataset_MultiClass_MultiProj_unreg, self).__init__()

    # TODO: make this a parameter
    self.TypeOfModal = ['7T_T2', '7T_T1', '7T_DTI_FA']

    image_dir = join(root_dir, split, 'image')
    target_dir = join(root_dir, split, 'label')

    image_names = [x for x in listdir(image_dir) if is_image_file(x)]
    self.image_filenames = []
    for mod in self.TypeOfModal:
        tmp_filenames = sorted(join(image_dir, x) for x in image_names if mod in x)
        self.image_filenames.append(tmp_filenames)

        if mod == '7T_T2':
            # Reference scan (all patients must have it): it determines
            # how many patients we have.
            self.patient_len = len(tmp_filenames)

    self.target_filenames = sorted([
        join(target_dir, x) for x in listdir(target_dir)
        if is_image_file(x)
    ])

    # Index the loader result instead of unpacking it into exactly two
    # names: other callers in this code base unpack three values from
    # load_nifti_img, and indexing works for either arity.
    # NOTE(review): confirm the return arity of load_nifti_img used here.
    tmp_data = load_nifti_img(self.image_filenames[0][0], dtype=np.int16)[0]
    self.image_dims = tmp_data.shape

    # report the number of images in the dataset
    print('Number of {0} images: {1} Patients'.format(split, len(self)))

    # NOTE: the add-dimension transform is deliberately not applied here.
    self.transform = ts.Compose(
        [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

    # data load into the ram memory
    self.preload_data = preload_data
    if self.preload_data:
        print('Preloading the {0} dataset ...'.format(split))

        def load_channel(path):
            """Load one volume, scale it by its maximum, add a channel axis."""
            q_dat = load_nifti_img(path, dtype=np.float32)[0]
            return np.expand_dims(q_dat / np.max(q_dat), axis=0)

        # raw_images[patient] -> list with one array per modality
        self.raw_images = [
            [load_channel(mod_files[jj]) for mod_files in self.image_filenames]
            for jj in range(len(self.image_filenames[0]))
        ]

        self.raw_labels = [
            load_nifti_img(ii, dtype=np.uint8)[0]
            for ii in self.target_filenames
        ]
        print('Loading is done\n')
def __init__(self, root_dir, split, transform=None, preload_data=False):
    """Index (and optionally preload) the T2 axial/coronal volumes.

    This dataset has no labels: ``target_filenames`` is empty and, when
    preloading, ``raw_labels`` holds one dummy integer per preloaded patient.

    Args:
        root_dir: Dataset root; images are read from
            ``<root_dir>/<split>/image``.
        split: Name of the split sub-directory.
        transform: Accepted for interface compatibility but ignored; the
            pipeline is always ``ToTensor`` followed by ``TypeCast``.
        preload_data: When true, load, normalise and zero-pad every
            volume into memory up front.
    """
    super(CMR3DDataset_t2_reg, self).__init__()

    # TODO: make this a parameter
    self.TypeOfModal = ['7T_T2']
    self.TypeOfProj = ['Axial', 'Coronal']

    image_dir = join(root_dir, split, 'image')
    target_dir = join(root_dir, split, 'label')

    # image_filenames[modality][projection] -> sorted list of file paths
    image_names = [x for x in listdir(image_dir) if is_image_file(x)]
    self.image_filenames = []
    for mod in self.TypeOfModal:
        tmp_list = []
        for prj in self.TypeOfProj:
            tmp_str = sorted(join(image_dir, x) for x in image_names
                             if mod in x and prj in x)
            tmp_list.append(tmp_str)
            if mod == '7T_T2' and prj == 'Axial':
                # Reference scan: it determines how many patients we have.
                self.patient_len = len(tmp_str)
        self.image_filenames.append(tmp_list)

    # The first 7T_T2 axial file defines the volume dimensions.
    tmp_data = load_nifti_img(self.image_filenames[0][0][0], dtype=np.int16)
    self.image_dims = tmp_data[0].shape

    self.target_filenames = []  # No labels for this project

    # report the number of images in the dataset
    print('Number of {0} images: {1} Patients'.format(split, len(self)))

    # NOTE: the add-dimension transform is deliberately not applied here.
    self.transform = ts.Compose(
        [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

    # data load into the ram memory
    self.preload_data = preload_data
    if self.preload_data:
        print('Preloading the {0} dataset ...'.format(split))

        def load_channel(path):
            """Load a volume, scale by its max, zero-pad, add a channel axis."""
            q_dat = load_nifti_img(path, dtype=np.float32)[0]
            q_dat = q_dat / np.max(q_dat)
            q_dat = self.zero_pad(q_dat)
            return np.expand_dims(q_dat, axis=0)

        # raw_images[patient][projection] -> multi-channel array.
        # Iterate over every patient: the original kept a DEBUG loop that
        # only loaded the first two and failed with fewer than two files.
        self.raw_images = []
        for jj in range(len(self.image_filenames[0][0])):
            tmp_data_list = []
            for kk in range(len(self.image_filenames[0])):
                # Coronal data is already permuted in the correct directions.
                channels = [
                    load_channel(per_mod[kk][jj])
                    for per_mod in self.image_filenames
                    if per_mod[kk] != []  # skip modalities without this projection
                ]
                if channels:
                    tmp_data = np.concatenate(channels, axis=0)
                # NOTE(review): as in the original, when a projection has no
                # files the previously built array is appended again -
                # confirm this is wanted.
                tmp_data_list.append(tmp_data)

            # [0] - Axial, [1] - Coronal
            self.raw_images.append(tmp_data_list)

        # Dummy labels (no labels for this project): one placeholder per
        # preloaded patient, so the list length matches raw_images.
        self.raw_labels = list(range(len(self.raw_images)))
        print('Loading is done\n')
for epoch in range(1, nb_epoch + 1): seg_train(epoch, unet_model, seg_train_loader, criterion, optimizer) # print("Test AUC:", auc_cal(model, testloader)) # test(model, testloader) torch.save(unet_model, os.path.join(model_repo_dir, 'unet.pt')) #Training UNET END #Training Diag network START transform_pipeline_train = tr.Compose( [ # AddGaussian(), # AddGaussian(ismulti=False), tr.ToTensor(), tr.AddChannel(axis=0), tr.TypeCast('float'), # Attenuation((-.001, .1)), # tr.RangeNormalize(0,1), tr.RandomBrightness(-.2, .2), tr.RandomGamma(.9, 1.1), tr.RandomFlip(), tr.RandomAffine(rotation_range=5, translation_range=0.2 # zoom_range=(0.9, 1.1) )]) transform_pipeline_test = tr.Compose([tr.ToTensor(), tr.AddChannel(axis=0), tr.TypeCast('float') # tr.RangeNormalize(0, 1) ]) transformed_images = Beijing_diag_dataset(root_dir, patient_id_G_train, patient_id_H_train, resize_dim= (96,288), transform=transform_pipeline_train)
def gsd_pCT_transform(self):
    """Build the train / valid pipelines for the Geneva Stroke dataset (pCT maps).

    Both pipelines pad to ``self.scale_size``, normalise with
    ``NormalizeMedic(norm_flag=(True, False))`` and end with a float/long
    ``TypeCast``. The training pipeline adds a random flip, which runs
    before ``ChannelsFirst``.

    :return: dict with the keys ``'train'`` and ``'valid'``.
    """
    pad_size = self.scale_size

    train_steps = [
        ts.ToTensor(),
        ts.Pad(size=pad_size),
        ts.TypeCast(['float', 'float']),
        ts.RandomFlip(h=True, v=True, p=self.random_flip_prob),
        # TODO: RandomAffine doesn't support channels --> try a newer version
        # of torchsample or torchvision.
        ts.ChannelsFirst(),
        # TODO: apply channel-wise normalisation.
        ts.NormalizeMedic(norm_flag=(True, False)),
        # TODO: fork torchsample and fix the RandomCrop bug (ChannelsLast
        # seems to be needed for the crop).
        ts.TypeCast(['float', 'long']),
    ]
    train_transform = ts.Compose(train_steps)

    valid_steps = [
        ts.ToTensor(),
        ts.Pad(size=pad_size),
        ts.ChannelsFirst(),
        ts.TypeCast(['float', 'float']),
        ts.NormalizeMedic(norm_flag=(True, False)),
        ts.TypeCast(['float', 'long']),
    ]
    valid_transform = ts.Compose(valid_steps)

    return dict(train=train_transform, valid=valid_transform)
def __init__(self, root_dir, split, transform=None, preload_data=True, modalities=['7T_T2'], rank=0):
    """Index (and optionally preload) the multi-modal volumes and labels.

    Args:
        root_dir: Dataset root; images are read from
            ``<root_dir>/<split>/image`` and labels from
            ``<root_dir>/<split>/label``.
        split: Name of the split sub-directory.
        transform: Accepted for interface compatibility but ignored; the
            pipeline is always ``ToTensor`` followed by ``TypeCast``.
        preload_data: When true, load every volume (one channel per
            modality, each divided by its own maximum) and every label
            into memory up front.
        modalities: Substrings selecting the files of each modality.
            ``'7T_T2_cor'`` is matched against ``'7T_T2'`` file names.
        rank: Process index; only rank 0 prints progress messages.
    """
    super(CMR3DDataset_MultiClass_MultiProj, self).__init__()

    # Copy so neither the shared default list nor the caller's list is aliased.
    self.TypeOfModal = list(modalities)
    if rank == 0:
        print("Modalities: {}".format(self.TypeOfModal))

    image_dir = join(root_dir, split, 'image')
    target_dir = join(root_dir, split, 'label')

    # Modalities that may define the patient count; lower value = higher
    # priority. The original if/elif chain assigned in every branch, so the
    # last listed modality won regardless of the documented priority.
    reference_priority = {
        '7T_T2': 0,
        '7T_SWI': 0,
        '7T_T1': 1,      # Secondary priority
        '7T_DTI_B0': 2,  # Tertiary priority
        '3T_T2': 3,      # Fourth priority
    }
    best_priority = None

    image_names = [x for x in listdir(image_dir) if is_image_file(x)]
    self.image_filenames = []
    for mod in self.TypeOfModal:
        if mod == '7T_T2_cor':
            mod = '7T_T2'
        tmp_filenames = sorted(join(image_dir, x) for x in image_names if mod in x)
        self.image_filenames.append(tmp_filenames)

        priority = reference_priority.get(mod)
        # `<=` keeps the old last-one-wins rule among equal priorities.
        if priority is not None and (best_priority is None
                                     or priority <= best_priority):
            best_priority = priority
            self.patient_len = len(tmp_filenames)
    # As before, patient_len stays unset when no reference modality is listed.

    self.target_filenames = sorted([
        join(target_dir, x) for x in listdir(target_dir)
        if is_image_file(x)
    ])
    if rank == 0:
        print("\n".join(self.target_filenames))

    # The first file of the first modality defines the volume dimensions.
    tmp_data = load_nifti_img(self.image_filenames[0][0], dtype=np.int16)
    self.image_dims = tmp_data[0].shape

    # report the number of images in the dataset
    if rank == 0:
        print('Number of {0} images: {1} Patients'.format(split, len(self)))

    # NOTE: the add-dimension transform is deliberately not applied here.
    self.transform = ts.Compose(
        [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

    # data load into the ram memory
    self.t2_headers = []
    self.preload_data = preload_data
    if self.preload_data:
        if rank == 0:
            print('Preloading the {0} dataset ...'.format(split))

        self.raw_images = []
        for jj, ref_name in enumerate(self.image_filenames[0]):
            channels = []
            ref_header = None
            for ii, mod_files in enumerate(self.image_filenames):
                loaded = load_nifti_img(mod_files[jj], dtype=np.float32)
                q_dat = loaded[0]
                if ii == 0:
                    # The header of the first modality identifies the patient.
                    ref_header = loaded[1]
                # Scale by the volume's own maximum, then add a channel axis.
                channels.append(np.expand_dims(q_dat / np.max(q_dat), axis=0))

            # One multi-channel array per patient.
            self.raw_images.append(np.concatenate(channels, axis=0))

            # Data identifier, taken from the first modality's file name
            # (used to identify samples in the multi-GPU case).
            ref_header['db_name'] = re.search(
                r'_P(.*)\.nii\.gz', ref_name).group(1)
            self.t2_headers.append(ref_header)

        # Load labels
        self.raw_labels = [
            load_nifti_img(ii, dtype=np.uint8)[0]
            for ii in self.target_filenames
        ]
        if rank == 0:
            print('Loading is done\n')