Example #1
    def create_input_data(self):
        total = len(self.list_IDsT1)
        print('Dataset samples: ', total)
        for i in range(total):
            print(i)
            if self.modalities == 2:
                img_t1_tensor = img_loader.load_medical_image(self.list_IDsT1[i], type="T1", resample=self.voxels_space,
                                                              to_canonical=self.to_canonical)
                # second-modality (T2) path: assumed to come from self.list_IDsT2
                img_t2_tensor = img_loader.load_medical_image(self.list_IDsT2[i], type="T2", resample=self.voxels_space,
                                                              to_canonical=self.to_canonical)
            else:
                img_t1_tensor = img_loader.load_medical_image(self.list_IDsT1[i], type="T1", resample=self.voxels_space,
                                                              to_canonical=self.to_canonical)

            # store the tensors of the current volume (as .npy paths or in memory)
            if self.save:
                filename = self.sub_vol_path + 'id_' + str(i) + '_s_' + str(i) + '_'
                if self.modalities == 2:
                    f_t1 = filename + 'T1.npy'
                    f_t2 = filename + 'T2.npy'
                    np.save(f_t1, img_t1_tensor)
                    np.save(f_t2, img_t2_tensor)
                    self.list.append(tuple((f_t1, f_t2)))
                else:
                    f_t1 = filename + 'T1.npy'
                    np.save(f_t1, img_t1_tensor)
                    self.list.append(f_t1)
            else:
                if self.modalities == 2:
                    self.list.append(tuple((img_t1_tensor, img_t2_tensor)))
                else:
                    self.list.append(img_t1_tensor)
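Below is a minimal, self-contained sketch (not part of the original loader) of the save-versus-memory branching used above: each tensor is either written to a .npy file with only the path retained, or kept directly in the sample list. The file names and array shapes are illustrative assumptions.

import numpy as np

save = True
samples = []
img_t1 = np.random.rand(64, 64, 64).astype(np.float32)
img_t2 = np.random.rand(64, 64, 64).astype(np.float32)

if save:
    f_t1, f_t2 = 'id_0_s_0_T1.npy', 'id_0_s_0_T2.npy'
    np.save(f_t1, img_t1)
    np.save(f_t2, img_t2)
    samples.append((f_t1, f_t2))      # keep only the paths, load lazily later
else:
    samples.append((img_t1, img_t2))  # keep the tensors in memory

# A stored path round-trips back to the original array on demand:
if save:
    t1_back = np.load(samples[0][0])
    assert t1_back.shape == (64, 64, 64)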
Example #2
    def get_viz_set(self, test_subject=0):
        """
        Returns total 3d input volumes (t1 and t2 or more) and segmentation maps
        3d total vol shape : torch.Size([1, 144, 192, 256])
        """
        TEST_SUBJECT = test_subject
        path_T1 = self.list_IDsT1[TEST_SUBJECT]
        path_T1ce = self.list_IDsT1ce[TEST_SUBJECT]
        path_T2 = self.list_IDsT2[TEST_SUBJECT]
        path_flair = self.list_IDsFlair[TEST_SUBJECT]
        label_path = self.labels[TEST_SUBJECT]

        segmentation_map = img_loader.load_medical_image(label_path,
                                                         viz3d=True)
        img_t1_tensor = img_loader.load_medical_image(path_T1,
                                                      type="T1",
                                                      viz3d=True)
        img_t1ce_tensor = img_loader.load_medical_image(path_T1ce,
                                                        type="T1ce",
                                                        viz3d=True)
        img_t2_tensor = img_loader.load_medical_image(path_T2,
                                                      type="T2",
                                                      viz3d=True)
        img_flair_tensor = img_loader.load_medical_image(path_flair,
                                                         type="FLAIR",
                                                         viz3d=True)

        ### TODO: save the full volume as numpy
        if self.save:
            self.full_volume = []

            # label volume split into sub-volumes the same way as the image tensors
            # (assumed to use the same find_reshaped_vol helper)
            segmentation_map = self.find_reshaped_vol(segmentation_map)
            img_t1_tensor = self.find_reshaped_vol(img_t1_tensor)
            img_t1ce_tensor = self.find_reshaped_vol(img_t1ce_tensor)
            img_t2_tensor = self.find_reshaped_vol(img_t2_tensor)
            img_flair_tensor = self.find_reshaped_vol(img_flair_tensor)

            self.sub_vol_path = self.root + '/MICCAI_BraTS_2018_Data_Training/generated/visualize/'
            utils.make_dirs(self.sub_vol_path)

            for i in range(len(img_t1_tensor)):
                filename = self.sub_vol_path + 'id_' + str(
                    TEST_SUBJECT) + '_VIZ_' + str(i) + '_'
                f_t1 = filename + 'T1.npy'
                f_t1ce = filename + 'T1CE.npy'
                f_t2 = filename + 'T2.npy'
                f_flair = filename + 'FLAIR.npy'
                f_seg = filename + 'seg.npy'

                np.save(f_t1, img_t1_tensor[i])
                np.save(f_t1ce, img_t1ce_tensor[i])
                np.save(f_t2, img_t2_tensor[i])
                np.save(f_flair, img_flair_tensor[i])
                np.save(f_seg, segmentation_map[i])

                self.full_volume.append(tuple((f_t1, f_t2, f_seg)))
            print("Full validation volume has been generated")
        else:
            self.full_volume = tuple(
                (img_t1_tensor, img_t2_tensor, segmentation_map))
Example #3
    def create_sub_volumes(self):
        total = len(self.list_IDsT1)
        TH = 160  # threshold for rejecting (near-)empty sub-volumes
        print('Mode: ' + self.mode + ' Subvolume samples to generate: ',
              self.samples, ' Volumes: ', total)

        for i in range(self.samples):
            random_index = np.random.randint(total)
            path_T1 = self.list_IDsT1[random_index]
            path_T2 = self.list_IDsT2[random_index]

            while True:
                slices = np.random.randint(self.full_vol_dim[0] -
                                           self.crop_size[0])
                w_crop = np.random.randint(self.full_vol_dim[1] -
                                           self.crop_size[1])
                h_crop = np.random.randint(self.full_vol_dim[2] -
                                           self.crop_size[2])

                if self.labels is not None:
                    label_path = self.labels[random_index]
                    segmentation_map = img_loader.load_medical_image(
                        label_path,
                        crop_size=self.crop_size,
                        crop=(slices, w_crop, h_crop),
                        type='label')
                    if segmentation_map.sum() > TH:
                        img_t1_tensor = img_loader.load_medical_image(
                            path_T1,
                            crop_size=self.crop_size,
                            crop=(slices, w_crop, h_crop),
                            type="T1")
                        img_t2_tensor = img_loader.load_medical_image(
                            path_T2,
                            crop_size=self.crop_size,
                            crop=(slices, w_crop, h_crop),
                            type="T2")
                        segmentation_map = self.fix_seg_map(segmentation_map)
                        break
                    else:
                        continue
                else:
                    segmentation_map = None
                    break

            if self.save:

                filename = self.sub_vol_path + 'id_' + str(
                    random_index) + '_s_' + str(i) + '_'
                f_t1 = filename + 'T1.npy'
                f_t2 = filename + 'T2.npy'
                f_seg = filename + 'seg.npy'
                np.save(f_t1, img_t1_tensor)
                np.save(f_t2, img_t2_tensor)
                np.save(f_seg, segmentation_map)
                self.list.append(tuple((f_t1, f_t2, f_seg)))
            else:
                self.list.append(
                    tuple((img_t1_tensor, img_t2_tensor, segmentation_map)))
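The rejection loop above can be illustrated with a self-contained sketch: crop origins are drawn at random until the cropped label patch contains more than a threshold of foreground voxels. The synthetic label volume, threshold and sizes below are assumptions for illustration only.

import numpy as np

full_vol_dim = (144, 192, 256)
crop_size = (32, 32, 32)
TH = 160                                  # minimum labelled voxels per crop

label_vol = np.zeros(full_vol_dim, dtype=np.uint8)
label_vol[60:100, 80:140, 100:180] = 1    # fake foreground region

while True:
    d0 = np.random.randint(full_vol_dim[0] - crop_size[0])
    d1 = np.random.randint(full_vol_dim[1] - crop_size[1])
    d2 = np.random.randint(full_vol_dim[2] - crop_size[2])
    patch = label_vol[d0:d0 + crop_size[0],
                      d1:d1 + crop_size[1],
                      d2:d2 + crop_size[2]]
    if patch.sum() > TH:                  # enough foreground -> accept the crop
        break

print('accepted crop origin:', (d0, d1, d2), 'foreground voxels:', int(patch.sum()))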
Example #4
def get_all_sub_volumes(*ls,
                        dataset_name,
                        mode,
                        samples,
                        full_vol_dim,
                        crop_size,
                        sub_vol_path,
                        normalization='max_min'):
    # TODO
    # 1.) for every subject, load the image and the target
    # 2.) call generate_non_overlapping_volumes to split image and target into sub-volume patches
    # 3.) save the tensors
    total = len(ls[0])
    assert total != 0, "Problem reading data. Check the data paths."
    modalities = len(ls)
    list = []

    for vol_id in range(total):

        tensor_images = []
        for modality_id in range(modalities - 1):
            img_tensor = img_loader.medical_image_transform(
                img_loader.load_medical_image(ls[modality_id][vol_id],
                                              type="T1"),
                normalization=normalization)

            img_tensor = generate_padded_subvolumes(img_tensor,
                                                    kernel_dim=crop_size)

            tensor_images.append(img_tensor)
        segmentation_map = img_loader.medical_image_transform(
            img_loader.load_medical_image(ls[modalities - 1][vol_id],
                                          viz3d=True,
                                          type='label'))
        segmentation_map = generate_padded_subvolumes(segmentation_map,
                                                      kernel_dim=crop_size)

        filename = sub_vol_path + 'id_' + str(vol_id) + '_s_' + str(
            modality_id) + '_modality_'

        for k in range(len(tensor_images[0])):
            # one tuple per sub-volume: the k-th patch of every modality plus its label patch
            list_saved_paths = []
            for j in range(modalities - 1):
                f_t1 = filename + str(j) + '_sample_{}'.format(
                    str(k).zfill(8)) + '.npy'
                list_saved_paths.append(f_t1)
                np.save(f_t1, tensor_images[j][k])

            f_seg = filename + 'seg_sample_{}'.format(str(k).zfill(8)) + '.npy'
            np.save(f_seg, segmentation_map[k])
            list_saved_paths.append(f_seg)
            list.append(tuple(list_saved_paths))

    return list
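generate_padded_subvolumes itself is not shown here. A rough self-contained stand-in (an assumption, not the repository implementation) pads the volume up to a multiple of the kernel size and cuts it into non-overlapping kernel-sized patches with unfold:

import torch
import torch.nn.functional as F

def padded_subvolumes(vol, kernel=(32, 32, 32)):
    # right-pad each spatial dim to the next multiple of the kernel size
    pad = []
    for dim, k in zip(reversed(vol.shape), reversed(kernel)):
        pad.extend([0, (k - dim % k) % k])    # F.pad expects last dim first
    vol = F.pad(vol, pad)
    # carve out non-overlapping kernel-sized cubes
    patches = (vol
               .unfold(0, kernel[0], kernel[0])
               .unfold(1, kernel[1], kernel[1])
               .unfold(2, kernel[2], kernel[2]))
    return patches.reshape(-1, *kernel)       # (num_patches, k0, k1, k2)

vol = torch.randn(144, 192, 256)
print(padded_subvolumes(vol).shape)           # torch.Size([240, 32, 32, 32])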
Example #5
def create_sub_volumes(*ls,
                       dataset_name,
                       mode,
                       samples,
                       full_vol_dim,
                       crop_size,
                       sub_vol_path,
                       threshold=10):
    total = len(ls[0])
    assert total != 0, "Problem reading data. Check the data paths."
    modalities = len(ls)
    list = []
    print('Mode: ' + mode + ' Subvolume samples to generate: ', samples,
          ' Volumes: ', total)
    for i in range(samples):
        random_index = np.random.randint(total)
        sample_paths = []
        tensor_images = []
        for j in range(modalities):
            sample_paths.append(ls[j][random_index])

        while True:
            crop = find_random_crop_dim(full_vol_dim, crop_size)

            label_path = sample_paths[-1]
            segmentation_map = img_loader.load_medical_image(
                label_path, crop_size=crop_size, crop=crop, type='label')

            segmentation_map = fix_seg_map(segmentation_map, dataset_name)
            if segmentation_map.sum() > threshold:
                for j in range(modalities - 1):
                    img_tensor = img_loader.load_medical_image(
                        sample_paths[j],
                        crop_size=crop_size,
                        crop=crop,
                        type="T1")
                    tensor_images.append(img_tensor)

                break

        filename = sub_vol_path + 'id_' + str(random_index) + '_s_' + str(
            i) + '_modality_'
        list_saved_paths = []
        for j in range(modalities - 1):
            f_t1 = filename + str(j) + '.npy'
            list_saved_paths.append(f_t1)
            np.save(f_t1, tensor_images[j])

        f_seg = filename + 'seg.npy'
        np.save(f_seg, segmentation_map)
        list_saved_paths.append(f_seg)
        list.append(tuple(list_saved_paths))
    return list
Example #6
    def get_viz_set(self, test_subject=0):
        """
        Returns total 3d input volumes (t1 and t2 or more) and segmentation maps
        3d total vol shape : torch.Size([1, 144, 192, 256])
        """
        TEST_SUBJECT = test_subject
        path_T1 = self.list_IDsT1[TEST_SUBJECT]
        path_T2 = self.list_IDsT2[TEST_SUBJECT]
        label_path = self.labels[TEST_SUBJECT]

        segmentation_map = img_loader.load_medical_image(label_path,
                                                         viz3d=True)

        img_t1_tensor = img_loader.load_medical_image(path_T1,
                                                      type="T1",
                                                      viz3d=True)
        img_t2_tensor = img_loader.load_medical_image(path_T2,
                                                      type="T2",
                                                      viz3d=True)
        segmentation_map = self.fix_seg_map(segmentation_map)

        ### TODO: save the full volume as numpy
        if self.save:
            self.full_volume = []

            segmentation_map = segmentation_map.reshape(
                -1, self.crop_size[0], self.crop_size[1], self.crop_size[2])
            img_t1_tensor = img_t1_tensor.reshape(-1, self.crop_size[0],
                                                  self.crop_size[1],
                                                  self.crop_size[2])
            img_t2_tensor = img_t2_tensor.reshape(-1, self.crop_size[0],
                                                  self.crop_size[1],
                                                  self.crop_size[2])
            self.sub_vol_path = self.root + '/iseg_2017/generated/visualize/'
            utils.make_dirs(self.sub_vol_path)

            for i in range(len(img_t1_tensor)):
                filename = self.sub_vol_path + 'id_' + str(
                    TEST_SUBJECT) + '_VIZ_' + str(i) + '_'
                f_t1 = filename + 'T1.npy'
                f_t2 = filename + 'T2.npy'
                f_seg = filename + 'seg.npy'

                np.save(f_t1, img_t1_tensor[i])
                np.save(f_t2, img_t2_tensor[i])

                np.save(f_seg, segmentation_map[i])
                self.full_volume.append(tuple((f_t1, f_t2, f_seg)))
            print("Full validation volume has been generated")
        else:
            self.full_volume = tuple(
                (img_t1_tensor, img_t2_tensor, segmentation_map))
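As a side note, a plain reshape(-1, *crop_size) only re-chunks the flattened voxel buffer: it requires the total element count to be a multiple of the crop volume, and the resulting chunks are not spatially contiguous cubes (for true spatial crops, an unfold-based split such as the sketch after Example #4 is preferable). A minimal illustration with assumed sizes:

import torch

crop_size = (48, 48, 64)
vol = torch.randn(1, 144, 192, 256)        # full volume, as in the docstring

# element count must be divisible by the crop volume for reshape to work at all
assert vol.numel() % (crop_size[0] * crop_size[1] * crop_size[2]) == 0
chunks = vol.reshape(-1, *crop_size)
print(chunks.shape)                        # torch.Size([48, 48, 48, 64])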
Example #7
def get_viz_set(*ls,
                dataset_name,
                test_subject=0,
                save=False,
                sub_vol_path=None):
    """
    Returns total 3d input volumes (t1 and t2 or more) and segmentation maps
    3d total vol shape : torch.Size([1, 144, 192, 256])
    """
    modalities = len(ls)
    total_volumes = []

    for i in range(modalities):
        path_img = ls[i][test_subject]

        img_tensor = img_loader.load_medical_image(path_img, viz3d=True)
        if i == modalities - 1:
            img_tensor = fix_seg_map(img_tensor, dataset=dataset_name)

        total_volumes.append(img_tensor)

    if save:
        total_subvolumes = total_volumes[0].shape[0]
        for i in range(total_subvolumes):
            filename = sub_vol_path + 'id_' + str(
                test_subject) + '_VIZ_' + str(i) + '_modality_'
            for j in range(modalities):
                # per-modality output path
                f_mod = filename + str(j) + '.npy'
                np.save(f_mod, total_volumes[j][i])
    else:
        return torch.stack(total_volumes, dim=0)
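A small self-contained sketch of the non-saving branch: the per-modality volumes and the segmentation map are stacked along a new leading dimension. Shapes follow the docstring; the tensors here are random placeholders.

import torch

t1  = torch.randn(1, 144, 192, 256)
t2  = torch.randn(1, 144, 192, 256)
seg = torch.randint(0, 4, (1, 144, 192, 256)).float()   # label map cast to float for stacking

viz_set = torch.stack([t1, t2, seg], dim=0)
print(viz_set.shape)                                     # torch.Size([3, 1, 144, 192, 256])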
Example #8
    def get_viz_set(self):
        """
        Returns total 3d input volumes(t1 and t2) and segmentation maps
        3d total vol shape : torch.Size([1, 144, 192, 256])
        """
        segmentation_map = img_loader.load_medical_image(self.labels[0],
                                                         type="label",
                                                         viz3d=True)
        img_t1_tensor = img_loader.load_medical_image(self.list_reg_t1[0],
                                                      type="T1",
                                                      viz3d=True)
        img_ir_tensor = img_loader.load_medical_image(self.list_reg_ir[0],
                                                      type="T2",
                                                      viz3d=True)
        img_flair_tensor = img_loader.load_medical_image(self.list_flair[0],
                                                         type="FLAIR",
                                                         viz3d=True)

        if self.classes == 4:
            segmentation_map = self.fix_seg_map(segmentation_map)

        self.full_volume = tuple(
            (img_t1_tensor, img_ir_tensor, img_flair_tensor, segmentation_map))
        print("Full validation volume has been generated")
Example #9
def get_viz_set(*ls,
                dataset_name,
                test_subject=0,
                save=False,
                sub_vol_path=None):
    """
    Returns total 3d input volumes (t1 and t2 or more) and segmentation maps
    3d total vol shape : torch.Size([1, 144, 192, 256])
    """
    modalities = len(ls)
    total_volumes = []
    for i in range(modalities):
        path_img = ls[i][test_subject]
        img_tensor = img_loader.load_medical_image(path_img, viz3d=True)
        if i == modalities - 1:
            img_tensor = fix_seg_map(img_tensor, dataset=dataset_name)

        total_volumes.append(img_tensor)
    return torch.stack(total_volumes, dim=0)
Example #10
import matplotlib.pyplot as plt

import lib.augment3D as augment
from lib.medloaders.medical_image_process import load_medical_image

t1 = load_medical_image(
    '.././datasets/iseg_2017/iSeg-2017-Training/subject-1-T1.hdr').squeeze(
    ).numpy()
label = load_medical_image(
    '.././datasets/iseg_2017/iSeg-2017-Training/subject-1-label.img').squeeze(
    ).numpy()
f, axarr = plt.subplots(4, 1)

axarr[0].imshow(t1[70, :, :])
axarr[1].imshow(label[70, :, :])

c = augment.RandomChoice(transforms=[augment.GaussianNoise(mean=0, std=0.1)])
[t1], label = c([t1], label)

axarr[2].imshow(t1[70, :, :])
axarr[3].imshow(label[70, :, :])

plt.show()
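For readers without the lib.augment3D package, here is a rough stand-alone equivalent of the Gaussian-noise augmentation applied above: noise is added to the image volume only, while the label volume is left untouched. Shapes and parameters are illustrative assumptions.

import numpy as np

def gaussian_noise(img, mean=0.0, std=0.1):
    # additive Gaussian noise cast to the image's own dtype
    return img + np.random.normal(mean, std, size=img.shape).astype(img.dtype)

t1 = np.random.rand(64, 64, 64).astype(np.float32)
label = (np.random.rand(64, 64, 64) > 0.5).astype(np.uint8)

t1_aug = gaussian_noise(t1, std=0.1)       # label stays exactly as it was
print(t1_aug.shape, float(np.abs(t1_aug - t1).mean()))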
Example #11
def create_sub_volumes(*ls,
                       dataset_name,
                       mode,
                       samples,
                       full_vol_dim,
                       crop_size,
                       sub_vol_path,
                       normalization='max_min',
                       th_percent=0.1):
    """

    :param ls: list of modality paths, where the last path is the segmentation map
    :param dataset_name: which dataset is used
    :param mode: train/val
    :param samples: train/val samples to generate
    :param full_vol_dim: full image size
    :param crop_size: train volume size
    :param sub_vol_path: path for the particular patient
    :param th_percent: the % of the croped dim that corresponds to non-zero labels
    :param crop_type:
    :return:
    """
    total = len(ls[0])
    assert total != 0, "Problem reading data. Check the data paths."
    modalities = len(ls)
    list = []

    print('Mode: ' + mode + ' Subvolume samples to generate: ', samples,
          ' Volumes: ', total)
    for i in range(samples):
        # this variant walks the volumes sequentially and always crops from the origin
        random_index = i
        sample_paths = []
        tensor_images = []
        for j in range(modalities):
            sample_paths.append(ls[j][random_index])

        label_path = sample_paths[-1]
        crop = (0, 0, 0)
        full_segmentation_map = img_loader.load_medical_image(
            label_path,
            viz3d=True,
            type='label',
            crop_size=crop_size,
            crop=crop)

        full_segmentation_map = fix_seg_map(full_segmentation_map,
                                            dataset_name)

        segmentation_map = img_loader.load_medical_image(label_path,
                                                         type='label',
                                                         crop_size=crop_size,
                                                         crop=crop)
        segmentation_map = fix_seg_map(segmentation_map, dataset_name)
        for j in range(modalities - 1):
            img_tensor = img_loader.load_medical_image(
                sample_paths[j],
                type="T1",
                normalization=normalization,
                crop_size=crop_size,
                crop=crop)

            tensor_images.append(img_tensor)

        filename = sub_vol_path + 'id_' + str(random_index) + '_s_' + str(
            i) + '_modality_'
        list_saved_paths = []
        for j in range(modalities - 1):
            f_t1 = filename + str(j) + '.npy'
            list_saved_paths.append(f_t1)

            np.save(f_t1, tensor_images[j])

        f_seg = filename + 'seg.npy'

        np.save(f_seg, segmentation_map)
        list_saved_paths.append(f_seg)
        list.append(tuple(list_saved_paths))

    return list
Example #12
    def get_samples(self):
        # threshold for rejecting empty-air sub-volumes that make training unstable
        TH = 10
        total = len(self.labels)
        print('Mode: ' + self.mode + ' Subvolume samples to generate: ',
              self.samples, ' Volumes: ', total)
        for i in range(self.samples):
            random_index = np.random.randint(total)
            path_flair = self.list_flair[random_index]
            path_reg_ir = self.list_reg_ir[random_index]
            path_reg_t1 = self.list_reg_t1[random_index]

            while True:
                w_crop = np.random.randint(self.full_vol_size[0] -
                                           self.crop_dim[0])
                h_crop = np.random.randint(self.full_vol_size[1] -
                                           self.crop_dim[1])
                slices = np.random.randint(self.full_vol_size[2] -
                                           self.crop_dim[2])
                crop = (w_crop, h_crop, slices)

                if self.labels is not None:
                    label_path = self.labels[random_index]
                    segmentation_map = img_loader.load_medical_image(
                        label_path,
                        crop_size=self.crop_dim,
                        crop=crop,
                        type='label')

                    if self.classes == 4:
                        segmentation_map = self.fix_seg_map(segmentation_map)

                    if segmentation_map.sum() > TH:
                        img_t1_tensor = img_loader.load_medical_image(
                            path_reg_t1,
                            crop_size=self.crop_dim,
                            crop=crop,
                            type="T1")
                        img_ir_tensor = img_loader.load_medical_image(
                            path_reg_ir,
                            crop_size=self.crop_dim,
                            crop=crop,
                            type="reg-IR")
                        img_flair_tensor = img_loader.load_medical_image(
                            path_flair,
                            crop_size=self.crop_dim,
                            crop=crop,
                            type="FLAIR")
                        break
                    else:
                        continue
                else:
                    segmentation_map = None
                    break
            if self.save:
                filename = self.sub_vol_path + 'id_' + str(
                    random_index) + '_s_' + str(i) + '_'
                f_t1 = filename + 'T1.npy'
                f_ir = filename + 'IR.npy'
                f_flair = filename + 'FLAIR.npy'
                f_seg = filename + 'seg.npy'

                np.save(f_t1, img_t1_tensor)
                np.save(f_ir, img_ir_tensor)
                np.save(f_flair, img_flair_tensor)
                np.save(f_seg, segmentation_map)

                self.list.append(tuple((f_t1, f_ir, f_flair, f_seg)))
            else:
                self.list.append(
                    tuple((img_t1_tensor, img_ir_tensor, img_flair_tensor,
                           segmentation_map)))
Example #13
def create_sub_volumes(*ls,
                       dataset_name,
                       mode,
                       samples,
                       full_vol_dim,
                       crop_size,
                       sub_vol_path,
                       normalization='max_min',
                       th_percent=0.1):
    """

    :param ls: list of modality paths, where the last path is the segmentation map
    :param dataset_name: which dataset is used
    :param mode: train/val
    :param samples: train/val samples to generate
    :param full_vol_dim: full image size
    :param crop_size: train volume size
    :param sub_vol_path: path for the particular patient
    :param th_percent: the % of the croped dim that corresponds to non-zero labels
    :param crop_type:
    :return:
    """
    total = len(ls[0])
    assert total != 0, "Problem reading data. Check the data paths."
    modalities = len(ls)
    list = []

    # reuse previously generated sub-volumes if they already exist on disk
    if os.path.exists(sub_vol_path):
        img_list = glob.glob(sub_vol_path + '*_0.npy')
        label_list = glob.glob(sub_vol_path + '*seg.npy')
        if img_list:
            for img_np in img_list:
                label_np = img_np.split('_0.npy')[0] + '_seg.npy'
                assert label_np in label_list, '%s absent' % label_np
                list.append(tuple((img_np, label_np)))
                print('Mode: ' + mode + ' Subvolume samples to generate: ',
                      len(list), ' Volumes: ', total)
            return list
    else:
        print('Mode: ' + mode + ' Subvolume samples to generate: ', samples,
              ' Volumes: ', total)
    fg_cnt = 0
    bg_cnt = 0
    for i in range(samples):
        print('%d/%d' % (i, samples))
        random_index = np.random.randint(total)
        sample_paths = []
        tensor_images = []
        for j in range(modalities):
            sample_paths.append(ls[j][random_index])

        label_path = sample_paths[-1]
        img_nii = nib.load(label_path)
        full_vol_dim = img_nii.shape

        # probability gate for the foreground-biased branch (currently always taken)
        if random.uniform(0, 1) < 1:
            cnt = 0
            while True:
                crop = find_random_crop_dim(full_vol_dim, crop_size)
                full_segmentation_map = img_loader.load_medical_image(
                    label_path,
                    viz3d=True,
                    type='label',
                    crop_size=crop_size,
                    crop=crop)
                full_segmentation_map = fix_seg_map(full_segmentation_map,
                                                    dataset_name)
                if find_non_zero_labels_mask(full_segmentation_map, th_percent,
                                             crop_size, crop):
                    segmentation_map = img_loader.load_medical_image(
                        label_path,
                        type='label',
                        crop_size=crop_size,
                        crop=crop)
                    segmentation_map = fix_seg_map(segmentation_map,
                                                   dataset_name)
                    for j in range(modalities - 1):
                        img_tensor = img_loader.load_medical_image(
                            sample_paths[j],
                            type="T1",
                            normalization=normalization,
                            crop_size=crop_size,
                            crop=crop)

                        tensor_images.append(img_tensor.unsqueeze(0))
                    print('got one patch with foreground ratio over %s' %
                          th_percent)
                    fg_cnt += 1
                    break

                if cnt > 100:
                    print('got one background patch')
                    segmentation_map = img_loader.load_medical_image(
                        label_path,
                        type='label',
                        crop_size=crop_size,
                        crop=crop)
                    segmentation_map = fix_seg_map(segmentation_map,
                                                   dataset_name)
                    for j in range(modalities - 1):
                        img_tensor = img_loader.load_medical_image(
                            sample_paths[j],
                            type="T1",
                            normalization=normalization,
                            crop_size=crop_size,
                            crop=crop)

                        tensor_images.append(img_tensor.unsqueeze(0))
                    bg_cnt += 1
                    break
                cnt += 1
        else:
            crop = find_random_crop_dim(full_vol_dim, crop_size)
            full_segmentation_map = img_loader.load_medical_image(
                label_path,
                viz3d=True,
                type='label',
                crop_size=crop_size,
                crop=crop)
            full_segmentation_map = fix_seg_map(full_segmentation_map,
                                                dataset_name)
            print('got one background patch')
            segmentation_map = img_loader.load_medical_image(
                label_path, type='label', crop_size=crop_size, crop=crop)
            segmentation_map = fix_seg_map(segmentation_map, dataset_name)
            for j in range(modalities - 1):
                img_tensor = img_loader.load_medical_image(
                    sample_paths[j],
                    type="T1",
                    normalization=normalization,
                    crop_size=crop_size,
                    crop=crop)
                tensor_images.append(img_tensor.unsqueeze(0))
            bg_cnt += 1

        # if not os.path.exists(sub_vol_path):
        #     os.makedirs(sub_vol_path)
        filename = sub_vol_path + 'id_' + str(random_index) + '_s_' + str(
            i) + '_modality_'
        list_saved_paths = []
        for j in range(modalities - 1):
            f_t1 = filename + str(j) + '.npy'
            list_saved_paths.append(f_t1)

            np.save(f_t1, tensor_images[j])
        f_seg = filename + 'seg.npy'

        h, w, d = segmentation_map.shape
        segmentation_map_new = torch.zeros(2, h, w, d)
        segmentation_map_new[0][segmentation_map == 0] = 1
        segmentation_map_new[1][segmentation_map != 0] = 1
        # segmentation_map_new[2][segmentation_map==2] = 1
        # segmentation_map_new[3][segmentation_map==3] = 1
        # segmentation_map_new[4][segmentation_map==4] = 1
        # segmentation_map_new[5][segmentation_map==5] = 1  #for 6 cls
        np.save(f_seg, segmentation_map_new)
        list_saved_paths.append(f_seg)
        list.append(tuple(list_saved_paths))
    print('fg vs bg : %d vs %d' % (fg_cnt, bg_cnt))
    return list
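The 2-channel label encoding above (background versus merged foreground) can be reproduced with a short self-contained sketch; the label values and shape are illustrative assumptions.

import torch

seg = torch.randint(0, 4, (32, 32, 32))    # raw label map with classes 0..3

h, w, d = seg.shape
seg_2ch = torch.zeros(2, h, w, d)
seg_2ch[0][seg == 0] = 1                   # channel 0: background
seg_2ch[1][seg != 0] = 1                   # channel 1: any foreground class

# every voxel ends up in exactly one of the two channels
assert torch.all(seg_2ch.sum(dim=0) == 1)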