Example 1
    def gsd_pCT_valid_transform(self, seed=None):
        valid_transform = ts.Compose([
            ts.ToTensor(),
            ts.Pad(size=self.scale_size),
            ts.TypeCast(['float', 'float']),
            StandardizeImage(norm_flag=[True, True, True, False]),
            ts.ChannelsFirst(),
            ts.TypeCast(['float', 'float'])
        ])

        return valid_transform
Example 2
    def gsd_pCT_valid_transform(self, seed=None):
        valid_transform = ts.Compose([
            ts.ToTensor(),
            ts.Pad(size=self.scale_size),
            ts.ChannelsFirst(),
            ts.TypeCast(['float', 'float']),
            # ts.NormalizeMedicPercentile(norm_flag=(True, False)),
            ts.NormalizeMedic(norm_flag=(True, False)),
            # ts.ChannelsLast(),
            # ts.SpecialCrop(size=self.patch_size, crop_type=0),
            ts.TypeCast(['float', 'long'])
        ])

        return valid_transform
Example 3
    def cmr_3d_sax_transform(self):

        train_transform = ts.Compose([
            ts.PadNumpy(size=self.scale_size),
            ts.ToTensor(),
            ts.ChannelsFirst(),
            ts.TypeCast(['float', 'float']),
            ts.RandomFlip(h=True, v=True, p=self.random_flip_prob),
            ts.RandomAffine(rotation_range=self.rotate_val,
                            translation_range=self.shift_val,
                            zoom_range=self.scale_val,
                            interp=('bilinear', 'nearest')),
            #ts.NormalizeMedicPercentile(norm_flag=(True, False)),
            ts.NormalizeMedic(norm_flag=(True, False)),
            ts.ChannelsLast(),
            ts.AddChannel(axis=0),
            ts.RandomCrop(size=self.patch_size),
            ts.TypeCast(['float', 'long'])
        ])

        valid_transform = ts.Compose([
            ts.PadNumpy(size=self.scale_size),
            ts.ToTensor(),
            ts.ChannelsFirst(),
            ts.TypeCast(['float', 'float']),
            #ts.NormalizeMedicPercentile(norm_flag=(True, False)),
            ts.NormalizeMedic(norm_flag=(True, False)),
            ts.ChannelsLast(),
            ts.AddChannel(axis=0),
            ts.SpecialCrop(size=self.patch_size, crop_type=0),
            ts.TypeCast(['float', 'long'])
        ])
        image_transform = ts.Compose([
            ts.PadNumpy(size=self.scale_size),
            ts.ToTensor(),
            ts.ChannelsFirst(),
            ts.ChannelsLast(),
            ts.AddChannel(axis=0),
            ts.SpecialCrop(size=self.patch_size, crop_type=0)
        ])

        return {
            'train': train_transform,
            'valid': valid_transform,
            'image': image_transform
        }
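A minimal, self-contained sketch of how a pipeline dictionary like the one returned above is typically applied, assuming ts is torchsample.transforms and that a Compose is called jointly on an (image, label) pair (as the paired TypeCast dtypes and norm_flag values suggest). Shapes and values below are illustrative assumptions, not values from the original code.

import numpy as np
import torchsample.transforms as ts  # assumption: ts is torchsample.transforms

# Joint pipeline: the first dtype applies to the image, the second to the label.
valid_transform = ts.Compose([
    ts.ToTensor(),                   # numpy array -> torch tensor, for every input
    ts.ChannelsFirst(),              # (H, W, D, C) -> (C, H, W, D)
    ts.TypeCast(['float', 'long']),  # image -> float, label -> long
])

image = np.random.rand(96, 96, 64, 1).astype(np.float32)                # dummy volume
label = np.random.randint(0, 4, size=(96, 96, 64, 1)).astype(np.uint8)  # dummy segmentation mask

img_t, lbl_t = valid_transform(image, label)  # one output per input
print(img_t.shape, img_t.dtype, lbl_t.dtype)  # (1, 96, 96, 64), torch.float32, torch.int64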
Example 4
    def isles2018_valid_transform(self, seed=None):
        valid_transform = ts.Compose([
            ts.ToTensor(),
            ts.Pad(size=self.scale_size),
            ts.ChannelsFirst(),
            ts.TypeCast(['float', 'long'])
        ])
        return valid_transform
Example 5
    def test_3d_sax_transform(self):
        test_transform = ts.Compose([
            ts.PadFactorNumpy(factor=self.division_factor),
            ts.ToTensor(),
            ts.ChannelsFirst(),
            ts.TypeCast(['float']),
            #ts.NormalizeMedicPercentile(norm_flag=True),
            ts.NormalizeMedic(norm_flag=True),
            ts.ChannelsLast(),
            ts.AddChannel(axis=0),
        ])

        return {'test': test_transform}
Example 6
    def ultrasound_transform(self):

        train_transform = ts.Compose([ts.ToTensor(),
                                      ts.TypeCast(['float']),
                                      ts.AddChannel(axis=0),
                                      ts.SpecialCrop(self.patch_size,0),
                                      ts.RandomFlip(h=True, v=False, p=self.random_flip_prob),
                                      ts.RandomAffine(rotation_range=self.rotate_val,
                                                      translation_range=self.shift_val,
                                                      zoom_range=self.scale_val,
                                                      interp=('bilinear')),
                                      ts.StdNormalize(),
                                ])

        valid_transform = ts.Compose([ts.ToTensor(),
                                      ts.TypeCast(['float']),
                                      ts.AddChannel(axis=0),
                                      ts.SpecialCrop(self.patch_size,0),
                                      ts.StdNormalize(),
                                ])

        return {'train': train_transform, 'valid': valid_transform}
Example 7
    print("random seed: {}".format(c))
    comment = "20/200 with parameter per slice and all slices"
    print(comment)

    patient_id_G_test = random.sample(patient_id_G, 20)
    patient_id_H_test = random.sample(patient_id_H, 200)

    patient_id_G_train = list(patient_id_G.difference(patient_id_G_test))
    patient_id_H_train = list(patient_id_H.difference(patient_id_H_test))

    transform_pipeline_train = tr.Compose([
        AddGaussian(),
        AddGaussian(ismulti=False),
        tr.ToTensor(),
        tr.AddChannel(axis=0),
        tr.TypeCast('float'),
        # Attenuation((-.001, .1)),
        # tr.RangeNormalize(0,1),
        tr.RandomBrightness(-.2, .2),
        tr.RandomGamma(.9, 1.1),
        tr.RandomFlip(),
        tr.RandomAffine(rotation_range=5,
                        translation_range=0.2,
                        zoom_range=(0.9, 1.1))
    ])

    transform_pipeline_test = tr.Compose([
        tr.ToTensor(),
        tr.AddChannel(axis=0),
        tr.TypeCast('float')
        # tr.RangeNormalize(0, 1)
    ])
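For a single-input pipeline such as transform_pipeline_test above (note the single 'float' passed to TypeCast), a minimal sketch of applying it to one 2D slice; the import alias and the slice shape are illustrative assumptions.

import numpy as np
import torchsample.transforms as tr  # assumption: tr is torchsample.transforms

single_input_pipeline = tr.Compose([
    tr.ToTensor(),
    tr.AddChannel(axis=0),  # (H, W) -> (1, H, W)
    tr.TypeCast('float'),
])

slice_2d = np.random.rand(96, 288).astype(np.float32)  # dummy slice; (96, 288) mirrors the resize_dim used in Example 12
slice_t = single_input_pipeline(slice_2d)
print(slice_t.shape)  # torch.Size([1, 96, 288])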
Example 8
    def __init__(self,
                 root_dir,
                 split,
                 transform=None,
                 preload_data=True,
                 modalities=['7T_T2']):
        super(CMR3DDataset_MultiClass_MultiProj_infer, self).__init__()

        # Type of modalities
        self.TypeOfModal = modalities

        # For now we assume all projections are axial - no coronal projections
        image_dir = join(root_dir, split, 'image')

        self.image_filenames = []
        for mod in self.TypeOfModal:
            if mod == '7T_T2_cor':
                mod = '7T_T2'
            tmp_filenames = [
                join(image_dir, x) for x in listdir(image_dir)
                if (is_image_file(x) and x.find(mod) != -1)
            ]
            self.image_filenames.append(sorted(tmp_filenames))

            # # - DEBUG -
            # tmp_filenames = [1, 2]
            # self.image_filenames = [['/home/udall-raid2/DBS_collaborators/DBS_for_orens/DiseaseClassification/inference_db_stn_net/test/image/_7T_T2_Coronal_P0.nii.gz', # Good
            #                         '/home/udall-raid2/DBS_collaborators/DBS_for_orens/DiseaseClassification/inference_db_stn_net/test/image/_7T_T2_Coronal_P222.nii.gz', # Bad
            #                         '/home/udall-raid2/DBS_collaborators/DBS_for_orens/DiseaseClassification/inference_db_stn_net/test/image/_7T_T2_Coronal_P234.nii.gz']] # Bad
            # # - DEBUG -

            # TODO: if mod == '7T_T2'
            if mod == '7T_T2' or '3T' in mod:  # This is the reference scan (all patients must have it) - use it to determine how many patients we have
                self.patient_len = len(tmp_filenames)

        # Assume we always start from 7T_T2 Axial scans
        tmp_data = load_nifti_img(self.image_filenames[0][0], dtype=np.int16)
        self.image_dims = tmp_data[0].shape

        # report the number of images in the dataset
        print('Number of {0} images: {1} Patients'.format(
            split, self.__len__()))

        # data augmentation
        # NOTE: in this case, disable the add dimension transform!
        #self.transform = transform
        self.transform = ts.Compose(
            [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

        # data load into the ram memory
        self.t2_headers = []
        self.preload_data = preload_data
        if self.preload_data:
            print('Preloading the {0} dataset ...'.format(split))
            #self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames] # Output is a list

            # Concatenate the raw data along the channels dimension
            self.raw_images = []
            for jj in range(len(self.image_filenames[0])):  # For each patient
                internal_cntr = 0
                for ii in range(len(self.image_filenames)):  # Go over all modalities
                    #print('File: {}'.format(self.image_filenames[ii][jj])) # Only for DEBUG
                    if internal_cntr == 0:  # First time - should always be T2 axial
                        q_dat, tmp_header, _ = load_nifti_img(
                            self.image_filenames[ii][jj], dtype=np.float32
                        )  # normalize values to [0,1] range

                        # if self.TypeOfModal[0] == '7T_T2_cor':
                        #     # Only for coronal slices
                        #     q_dat = np.transpose(q_dat, (0, 2, 1))

                        tmp_data = np.expand_dims(q_dat /
                                                  np.max(q_dat.reshape(-1)),
                                                  axis=0)
                        tmp_name = self.image_filenames[ii][
                            jj]  # For the header file - identification in the multi GPU case
                    else:  # Concatenate additional channels
                        q_dat, _, _ = load_nifti_img(
                            self.image_filenames[ii][jj], dtype=np.float32
                        )  # normalize values to [0,1] range

                        # if self.TypeOfModal[0] == '7T_T2_cor':
                        #     # Only for coronal slices
                        #     q_dat = np.transpose(q_dat, (0, 2, 1))

                        concat_data = np.expand_dims(q_dat /
                                                     np.max(q_dat.reshape(-1)),
                                                     axis=0)
                        tmp_data = np.concatenate((tmp_data, concat_data),
                                                  axis=0)
                    internal_cntr += 1

                # Add the concatenated multichannel data to the list
                self.raw_images.append(tmp_data)
                tmp_header['db_name'] = re.search(
                    '_P(.*).nii.gz', tmp_name).group(1)  # Data identifier
                self.t2_headers.append(tmp_header)

            print('Loading is done\n')
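The __len__/__getitem__ side of this dataset is not included in the example. Below is a hypothetical sketch of the pattern that usually completes such a preloaded dataset, based only on the attributes set in the __init__ above; it is not the original implementation.

import numpy as np
import torch.utils.data as data
import torchsample.transforms as ts  # assumption: ts is torchsample.transforms


class PreloadedVolumeDataset(data.Dataset):
    """Hypothetical counterpart to the __init__ above: volumes preloaded into RAM,
    transform applied on read."""

    def __init__(self, volumes):
        self.raw_images = volumes  # list of (C, H, W, D) numpy arrays, analogous to self.raw_images above
        self.patient_len = len(volumes)
        self.transform = ts.Compose([ts.ToTensor(), ts.TypeCast('float')])

    def __len__(self):
        return self.patient_len

    def __getitem__(self, index):
        return self.transform(self.raw_images[index])


# Usage with a dummy volume
ds = PreloadedVolumeDataset([np.random.rand(1, 32, 32, 32).astype(np.float32)])
print(len(ds), ds[0].shape)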
Example 9
    def __init__(self,
                 root_dir,
                 split,
                 transform=None,
                 preload_data=False,
                 rank=0,
                 world_size=1):
        super(CMR3DDataset_MultiClass_MultiProj_V2, self).__init__()

        # TODO: make this an external parameter?
        internal_hist_augmentation_flag = 0

        # TODO: make this an external parameter?
        #self.TypeOfModal = ['7T_T2', '7T_T1', '7T_DTI_FA']
        self.TypeOfModal = ['7T_T2']
        self.TypeOfProj = ['Axial', 'Coronal']

        # For now we assume all projections are axial - no coronal projections
        image_dir = join(root_dir, split, 'image')
        target_dir = join(root_dir, split, 'label')

        self.image_filenames = []
        for mod in self.TypeOfModal:
            tmp_list = []
            for prj in self.TypeOfProj:
                tmp_str = [
                    join(image_dir, x) for x in listdir(image_dir)
                    if (is_image_file(x) and x.find(mod) != -1
                        and x.find(prj) != -1)
                ]
                tmp_list.append(sorted(tmp_str))

                # if mod == '7T_T2' and prj == 'Axial':
                #     self.patient_len = len(tmp_str)

            self.image_filenames.append(tmp_list)

        # Assume we always start from 7T_T2 Axial scans
        tmp_data = load_nifti_img(self.image_filenames[0][0][0],
                                  dtype=np.int16)
        self.image_dims = tmp_data[0].shape

        self.target_filenames = sorted([
            join(target_dir, x) for x in listdir(target_dir)
            if is_image_file(x)
        ])
        #assert len(self.image_filenames) == len(self.target_filenames)

        # Divide data to each rank
        grp_size = math.ceil(len(self.target_filenames) / world_size)
        self.image_filenames[0][0] = self.image_filenames[0][0][rank * grp_size:(rank + 1) * grp_size]
        self.image_filenames[0][1] = self.image_filenames[0][1][rank * grp_size:(rank + 1) * grp_size]
        self.target_filenames = self.target_filenames[rank * grp_size:(rank + 1) * grp_size]

        self.patient_len = len(self.target_filenames)

        # print("len(self.target_filenames (rank {}) = {})".format(rank, len(self.target_filenames )))
        # print(self.image_filenames[0][1])

        # report the number of images in the dataset
        print('Number of {0} images: {1} Patients'.format(
            split, self.__len__()))

        # data augmentation
        # NOTE: in this case, disable the add dimension transform!
        #self.transform = transform
        #self.transform = ts.TypeCast(['float', 'long'])
        self.transform = ts.Compose(
            [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

        # data load into the ram memory
        self.preload_data = preload_data
        if self.preload_data:
            print('Preloading the {0} dataset ...'.format(split))
            #self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames] # Output is a list

            # Concatenate the raw data along the channels dimension
            self.raw_images = []  # This will be a list of lists
            # Per each patient, go over all modalities and projections
            for jj in range(len(self.image_filenames[0][0])):
                #for jj in range(len([1])):
                tmp_data_list = []
                # Go over all projections
                for kk in range(len(self.image_filenames[0])):
                    internal_cntr = 0
                    # Go over all modalities (T2, T1, ...) of similar projection
                    for ii in range(len(self.image_filenames)):
                        # NOTE: Coronal data is already permuted in the correct directions
                        if self.image_filenames[ii][
                                kk] != []:  # Check if the data exists
                            #print(self.image_filenames[ii][kk][jj]) # For DEBUG
                            if internal_cntr == 0:  # First time
                                q_dat = load_nifti_img(
                                    self.image_filenames[ii][kk][jj],
                                    dtype=np.float32)[
                                        0]  # normalize values to [0,1] range
                                q_dat = q_dat / np.max(
                                    q_dat.ravel())  # Normalize

                                # Do image histogram augmentation
                                if internal_hist_augmentation_flag == 1:
                                    q_dat = adaptive_hist_aug(q_dat)

                                ### TEST
                                #q_dat = q_dat[96-32:96+32, 96-32:96+32, 96-32:96+32]
                                ##########

                                tmp_data = np.expand_dims(q_dat, axis=0)
                            else:  # Concatenate additional channels
                                q_dat = load_nifti_img(
                                    self.image_filenames[ii][kk][jj],
                                    dtype=np.float32)[
                                        0]  # normalize values to [0,1] range
                                q_dat = q_dat / np.max(q_dat.ravel())  # Normalize

                                # Do image histogram augmentation
                                if internal_hist_augmentation_flag == 1:
                                    q_dat = adaptive_hist_aug(q_dat)

                                ### TEST
                                #q_dat = q_dat[96-32:96+32, 96-32:96+32, 96-32:96+32]
                                ##########

                                concat_data = np.expand_dims(q_dat, axis=0)
                                tmp_data = np.concatenate(
                                    (tmp_data, concat_data), axis=0)
                            internal_cntr += 1

                    # Append for all modalities per same projection
                    tmp_data_list.append(tmp_data)

                # Add the concatenated multichannel data to the list
                # [0] - Axial, [1] - Coronal
                self.raw_images.append(tmp_data_list)

            self.raw_labels = [
                load_nifti_img(ii, dtype=np.uint8)[0]
                for ii in self.target_filenames
            ]
            # ### TEST
            # self.raw_labels = []
            # for ii in self.target_filenames:
            #     tmp_tmp = load_nifti_img(ii, dtype=np.uint8)[0]
            #     self.raw_labels.append(tmp_tmp[96-32:96+32, 96-32:96+32, 96-32:96+32])
            # ##########

            print('Loading is done\n')
Example 10
    def __init__(self, root_dir, split, transform=None, preload_data=True):
        super(CMR3DDataset_MultiClass_MultiProj_unreg, self).__init__()

        # TODO: make this a parameter
        self.TypeOfModal = ['7T_T2', '7T_T1', '7T_DTI_FA']
        #self.TypeOfModal = ['7T_T2']

        # For now we assume all projections are axial - no coronal projections
        image_dir = join(root_dir, split, 'image')
        target_dir = join(root_dir, split, 'label')

        self.image_filenames = []
        for mod in self.TypeOfModal:
            tmp_filenames = [
                join(image_dir, x) for x in listdir(image_dir)
                if (is_image_file(x) and x.find(mod) != -1)
            ]
            self.image_filenames.append(sorted(tmp_filenames))
            # TODO: if mod == '7T_T2'
            if mod == '7T_T2':  # This is the reference scan (all patients must have it) - use it to determine how many patients we have
                self.patient_len = len(tmp_filenames)

        self.target_filenames = sorted([
            join(target_dir, x) for x in listdir(target_dir)
            if is_image_file(x)
        ])
        #assert len(self.image_filenames) == len(self.target_filenames)

        # Assume we always start from 7T_T2 Axial scans
        tmp_data, meta = load_nifti_img(self.image_filenames[0][0],
                                        dtype=np.int16)
        #self.image_dims = tmp_data[0].shape
        self.image_dims = tmp_data.shape

        # report the number of images in the dataset
        print('Number of {0} images: {1} Patients'.format(
            split, self.__len__()))

        # data augmentation
        # NOTE: in this case, disable the add dimension transform!
        #self.transform = transform
        self.transform = ts.Compose(
            [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

        # data load into the ram memory
        self.preload_data = preload_data
        if self.preload_data:
            print('Preloading the {0} dataset ...'.format(split))
            #self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames] # Output is a list

            # Concatenate the raw data along the channels dimension
            self.raw_images = []
            for jj in range(len(self.image_filenames[0])):  # For each patient
                tmp_data = []
                for ii in range(len(self.image_filenames)):  # Go over all modalities
                    #print('File: {}'.format(self.image_filenames[ii][jj])) # Only for DEBUG
                    q_dat = load_nifti_img(
                        self.image_filenames[ii][jj],
                        dtype=np.float32)[0]  # normalize values to [0,1] range
                    tmp_data.append(
                        np.expand_dims(q_dat / np.max(q_dat.reshape(-1)),
                                       axis=0))

                # Add the concatenated multichannel data to the list
                self.raw_images.append(tmp_data)

            self.raw_labels = [
                load_nifti_img(ii, dtype=np.uint8)[0]
                for ii in self.target_filenames
            ]
            print('Loading is done\n')
Example 11
    def __init__(self, root_dir, split, transform=None, preload_data=False):
        super(CMR3DDataset_t2_reg, self).__init__()

        # TODO: make this a parameter
        self.TypeOfModal = ['7T_T2']
        self.TypeOfProj = ['Axial', 'Coronal']

        # For now we assume all projections are axial - no coronal projections
        image_dir = join(root_dir, split, 'image')
        target_dir = join(root_dir, split, 'label')

        self.image_filenames = []
        for mod in self.TypeOfModal:
            tmp_list = []
            for prj in self.TypeOfProj:
                tmp_str = [
                    join(image_dir, x) for x in listdir(image_dir)
                    if (is_image_file(x) and x.find(mod) != -1
                        and x.find(prj) != -1)
                ]
                tmp_list.append(sorted(tmp_str))

                if mod == '7T_T2' and prj == 'Axial':
                    self.patient_len = len(tmp_str)

            self.image_filenames.append(tmp_list)

        # Assume we always start from 7T_T2 Axial scans
        tmp_data = load_nifti_img(self.image_filenames[0][0][0],
                                  dtype=np.int16)
        self.image_dims = tmp_data[0].shape

        self.target_filenames = []  # No labels for this project

        # report the number of images in the dataset
        print('Number of {0} images: {1} Patients'.format(
            split, self.__len__()))

        # data augmentation
        # NOTE: in this case, disable the add dimension transform!
        #self.transform = transform
        #self.transform = ts.TypeCast(['float', 'long'])
        self.transform = ts.Compose(
            [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

        # data load into the ram memory
        self.preload_data = preload_data
        if self.preload_data:
            print('Preloading the {0} dataset ...'.format(split))
            #self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames] # Output is a list

            # Concatenate the raw data along the channels dimension
            self.raw_images = []  # This will be a list of lists
            # Per each patient, go over all modalities and projections
            #for jj in range(len(self.image_filenames[0][0])): # REAL
            for jj in range(len([0, 1])):  # DEBUG
                tmp_data_list = []
                # Go over all projections
                for kk in range(len(self.image_filenames[0])):
                    internal_cntr = 0
                    # Go over all modalities (T2, T1, ...) of similar projection
                    for ii in range(len(self.image_filenames)):
                        # NOTE: Coronal data is already permuted in the correct directions
                        if self.image_filenames[ii][
                                kk] != []:  # Check if the data exists
                            #print(self.image_filenames[ii][kk][jj]) # For DEBUG
                            if internal_cntr == 0:  # First time
                                q_dat = load_nifti_img(
                                    self.image_filenames[ii][kk][jj],
                                    dtype=np.float32)[
                                        0]  # normalize values to [0,1] range
                                q_dat = q_dat / np.max(
                                    q_dat.reshape(-1))  # Normalize
                                q_dat = self.zero_pad(q_dat)  # zero pad
                                tmp_data = np.expand_dims(q_dat, axis=0)
                            else:  # Concatenate additional channels
                                q_dat = load_nifti_img(
                                    self.image_filenames[ii][kk][jj],
                                    dtype=np.float32)[
                                        0]  # normalize values to [0,1] range
                                q_dat = q_dat / np.max(q_dat.reshape(-1))
                                q_dat = self.zero_pad(q_dat)
                                concat_data = np.expand_dims(q_dat, axis=0)
                                tmp_data = np.concatenate(
                                    (tmp_data, concat_data), axis=0)
                            internal_cntr += 1

                    # Append for all modalities per same projection
                    tmp_data_list.append(tmp_data)

                # Add the concatenated multichannel data to the list
                # [0] - Axial, [1] - Coronal
                self.raw_images.append(tmp_data_list)

            self.raw_labels = [
                f for f in range(len(self.image_filenames) + 1)
            ]  # Dummy: no labels for this project
            print('Loading is done\n')
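The zero_pad helper called above is not shown in this example. A minimal sketch of a symmetric zero-padding helper of the kind such code typically relies on; the function name matches the call site, but the target shape and the implementation are assumptions, not the original.

import numpy as np


def zero_pad(volume, target_shape=(192, 192, 192)):
    # Hypothetical helper: symmetrically zero-pad a 3D volume up to target_shape.
    # target_shape is an illustrative assumption, not a value from the original code.
    pad_width = []
    for dim, target in zip(volume.shape, target_shape):
        total = max(target - dim, 0)
        pad_width.append((total // 2, total - total // 2))
    return np.pad(volume, pad_width, mode='constant', constant_values=0)


padded = zero_pad(np.random.rand(160, 160, 100).astype(np.float32))
print(padded.shape)  # (192, 192, 192)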
Example 12
        for epoch in range(1, nb_epoch + 1):
            seg_train(epoch, unet_model, seg_train_loader, criterion, optimizer)
            # print("Test AUC:", auc_cal(model, testloader))
            # test(model, testloader)
        torch.save(unet_model, os.path.join(model_repo_dir, 'unet.pt'))

    #Training UNET END


    #Training Diag network START
    transform_pipeline_train = tr.Compose(
        [
         # AddGaussian(),
         # AddGaussian(ismulti=False),
         tr.ToTensor(), tr.AddChannel(axis=0), tr.TypeCast('float'),
         # Attenuation((-.001, .1)),
         # tr.RangeNormalize(0,1),
         tr.RandomBrightness(-.2, .2),
         tr.RandomGamma(.9, 1.1),
         tr.RandomFlip(),
         tr.RandomAffine(rotation_range=5, translation_range=0.2
                         # zoom_range=(0.9, 1.1)
                         )])

    transform_pipeline_test = tr.Compose([tr.ToTensor(), tr.AddChannel(axis=0), tr.TypeCast('float')
                                          # tr.RangeNormalize(0, 1)
                                          ])

    transformed_images = Beijing_diag_dataset(root_dir, patient_id_G_train, patient_id_H_train,
                                              resize_dim=(96, 288),
                                              transform=transform_pipeline_train)
Example 13
    def gsd_pCT_transform(self):
        '''
        Data augmentation transformations for the Geneva Stroke dataset (pCT maps)
        :return:
        '''

        train_transform = ts.Compose([
            ts.ToTensor(),
            ts.Pad(size=self.scale_size),
            ts.TypeCast(['float', 'float']),
            ts.RandomFlip(h=True, v=True, p=self.random_flip_prob),
            # Todo Random Affine doesn't support channels --> try newer version of torchsample or torchvision
            # ts.RandomAffine(rotation_range=self.rotate_val, translation_range=self.shift_val,
            #                 zoom_range=self.scale_val, interp=('bilinear', 'nearest')),
            ts.ChannelsFirst(),
            #ts.NormalizeMedicPercentile(norm_flag=(True, False)),
            # Todo apply channel wise normalisation
            ts.NormalizeMedic(norm_flag=(True, False)),
            # Todo fork torchsample and fix the Random Crop bug
            # ts.ChannelsLast(), # seems to be needed for crop
            # ts.RandomCrop(size=self.patch_size),
            ts.TypeCast(['float', 'long'])
        ])

        valid_transform = ts.Compose([
            ts.ToTensor(),
            ts.Pad(size=self.scale_size),
            ts.ChannelsFirst(),
            ts.TypeCast(['float', 'float']),
            #ts.NormalizeMedicPercentile(norm_flag=(True, False)),
            ts.NormalizeMedic(norm_flag=(True, False)),
            # ts.ChannelsLast(),
            # ts.SpecialCrop(size=self.patch_size, crop_type=0),
            ts.TypeCast(['float', 'long'])
        ])

        # train_transform = ts.Compose([
        #     ts.ToTensor(),
        #     ts.Pad(size=self.scale_size),
        #                               ts.ChannelsFirst(),
        #                               ts.TypeCast(['float', 'long'])
        # ])
        # valid_transform = ts.Compose([
        #                               ts.ToTensor(),
        #     ts.Pad(size=self.scale_size),
        #     ts.ChannelsFirst(),
        #     ts.TypeCast(['float', 'long'])
        #
        # ])

        # train_transform = tf.Compose([
        #     tf.Pad(1),
        #     tf.Lambda(lambda a: a.permute(3, 0, 1, 2)),
        #     tf.Lambda(lambda a: a.float()),
        # ])
        # valid_transform = tf.Compose([
        #     tf.Pad(1),
        #     tf.Lambda(lambda a: a.permute(3, 0, 1, 2)),
        #     tf.Lambda(lambda a: a.float()),
        #
        # ])

        return {'train': train_transform, 'valid': valid_transform}
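One of the TODOs above asks for channel-wise normalisation. Below is a minimal sketch of a custom channel-wise z-score transform that could stand in for ts.NormalizeMedic in a pipeline like this; it is an assumption about the intended behaviour, not part of the original code. It normalises only the image and passes the label through, mirroring the norm_flag=(True, False) convention.

import torch


class ChannelWiseZScore(object):
    """Hypothetical transform: z-score each channel of the image independently
    and pass the label through unchanged."""

    def __init__(self, eps=1e-8):
        self.eps = eps

    def __call__(self, image, label):
        # image: (C, H, W, D) tensor, e.g. after ts.ChannelsFirst()
        flat = image.reshape(image.shape[0], -1)
        mean = flat.mean(dim=1).view(-1, 1, 1, 1)
        std = flat.std(dim=1).view(-1, 1, 1, 1)
        return (image - mean) / (std + self.eps), label


# Usage with dummy tensors
img = torch.rand(4, 16, 16, 8)             # e.g. 4 pCT channels (illustrative)
lbl = torch.randint(0, 2, (1, 16, 16, 8))
norm_img, same_lbl = ChannelWiseZScore()(img, lbl)
print(norm_img.mean(dim=(1, 2, 3)))        # approximately zero per channel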
Example 14
    def __init__(self,
                 root_dir,
                 split,
                 transform=None,
                 preload_data=True,
                 modalities=['7T_T2'],
                 rank=0):
        super(CMR3DDataset_MultiClass_MultiProj, self).__init__()

        # TODO: make this a parameter
        #self.TypeOfModal = ['7T_DTI_B0', '7T_DTI_FA'] # If we use B0 as well for the Thalamus seg

        ### ---
        #self.TypeOfModal = ['7T_T2']  # For T2 axial
        #self.TypeOfModal = ['7T_T2_cor'] # For T2 coronal
        #self.TypeOfModal = ['7T_SWI']
        self.TypeOfModal = modalities
        if rank == 0:
            print("Modalities: {}".format(self.TypeOfModal))

        # For now we assume all projections are axial - no coronal projections
        image_dir = join(root_dir, split, 'image')
        target_dir = join(root_dir, split, 'label')

        self.image_filenames = []
        for mod in self.TypeOfModal:
            if mod == '7T_T2_cor':
                mod = '7T_T2'
            tmp_filenames = [
                join(image_dir, x) for x in listdir(image_dir)
                if (is_image_file(x) and x.find(mod) != -1)
            ]
            self.image_filenames.append(sorted(tmp_filenames))
            # TODO: if mod == '7T_T2'
            if mod == '7T_T2' or mod == '7T_SWI':  # This is the reference scan (all patients must have it) - use it to determine how many patients we have
                self.patient_len = len(tmp_filenames)
            elif mod == '7T_T1':
                # Secondary priority
                self.patient_len = len(tmp_filenames)
            elif mod == '7T_DTI_B0':
                # Tertiary priority
                self.patient_len = len(tmp_filenames)
            elif mod == '3T_T2':
                # Fourth priority
                self.patient_len = len(tmp_filenames)

        self.target_filenames = sorted([
            join(target_dir, x) for x in listdir(target_dir)
            if is_image_file(x)
        ])
        #assert len(self.image_filenames) == len(self.target_filenames)

        if rank == 0:
            print("\n".join(self.target_filenames))

        # Assume we always start from 7T_T2 Axial scans
        tmp_data = load_nifti_img(self.image_filenames[0][0], dtype=np.int16)
        self.image_dims = tmp_data[0].shape

        # report the number of images in the dataset
        if rank == 0:
            print('Number of {0} images: {1} Patients'.format(
                split, self.__len__()))

        # data augmentation
        # NOTE: in this case, disable the add dimension transform!
        #self.transform = transform
        self.transform = ts.Compose(
            [ts.ToTensor(), ts.TypeCast(['float', 'long'])])

        # data load into the ram memory
        self.t2_headers = []
        self.preload_data = preload_data
        if self.preload_data:
            if rank == 0:
                print('Preloading the {0} dataset ...'.format(split))
            #self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames] # Output is a list

            # Concatenate the raw data along the channels dimension
            self.raw_images = []
            for jj in range(len(self.image_filenames[0])):  # For each patient
                internal_cntr = 0
                for ii in range(len(self.image_filenames)):  # Go over all modalities
                    #print('File: {}'.format(self.image_filenames[ii][jj])) # Only for DEBUG
                    if internal_cntr == 0:  # First time
                        q_dat, tmp_header, _ = load_nifti_img(
                            self.image_filenames[ii][jj], dtype=np.float32
                        )  # normalize values to [0,1] range

                        # if self.TypeOfModal[0] == '7T_T2_cor':
                        #     # Only for coronal slices
                        #     q_dat = np.transpose(q_dat, (0, 2, 1))

                        tmp_data = np.expand_dims(q_dat /
                                                  np.max(q_dat.reshape(-1)),
                                                  axis=0)
                        tmp_name = self.image_filenames[ii][
                            jj]  # For the header file - identification in the multi GPU case
                    else:  # Concatenate additional channels
                        q_dat = load_nifti_img(
                            self.image_filenames[ii][jj], dtype=np.float32)[
                                0]  # normalize values to [0,1] range

                        # if self.TypeOfModal[0] == '7T_T2_cor':
                        #     # Only for coronal slices
                        #     q_dat = np.transpose(q_dat, (0, 2, 1))

                        concat_data = np.expand_dims(q_dat /
                                                     np.max(q_dat.reshape(-1)),
                                                     axis=0)
                        tmp_data = np.concatenate((tmp_data, concat_data),
                                                  axis=0)
                    internal_cntr += 1

                # Add the concatenated multichannel data to the list
                self.raw_images.append(tmp_data)
                tmp_header['db_name'] = re.search(
                    '_P(.*).nii.gz', tmp_name).group(1)  # Data identifier
                self.t2_headers.append(tmp_header)

            # Load labels
            #self.raw_labels = [load_nifti_img(ii, dtype=np.uint8)[0] for ii in self.target_filenames]
            self.raw_labels = []
            for ii in self.target_filenames:
                label_tmp = load_nifti_img(ii, dtype=np.uint8)[0]

                # if self.TypeOfModal[0] == '7T_T2_cor':
                #     # Only for coronal slices
                #     label_tmp = np.transpose(label_tmp, (0, 2, 1))

                self.raw_labels.append(label_tmp)

            if rank == 0:
                print('Loading is done\n')