def get_viz_set(*ls,
                dataset_name,
                test_subject=0,
                save=False,
                sub_vol_path=None):
    """
    Returns total 3d input volumes (t1 and t2 or more) and segmentation maps
    3d total vol shape : torch.Size([1, 144, 192, 256])
    """
    modalities = len(ls)
    total_volumes = []

    for i in range(modalities):
        path_img = ls[i][test_subject]

        img_tensor = img_loader.load_medical_image(path_img, viz3d=True)
        # The last list in ls holds the segmentation map: remap its labels for
        # the given dataset.
        if i == modalities - 1:
            img_tensor = fix_seg_map(img_tensor, dataset=dataset_name)

        total_volumes.append(img_tensor)

    if save:
        total_subvolumes = total_volumes[0].shape[0]
        for i in range(total_subvolumes):
            filename = (sub_vol_path + "id_" + str(test_subject) + "_VIZ_" +
                        str(i) + "_modality_")
            for j in range(modalities):
                # Build a fresh file name per modality; appending to the same
                # string would produce names like "..._modality_0.npy1.npy".
                modality_file = filename + str(j) + ".npy"
                np.save(modality_file, total_volumes[j][i])
    else:
        return torch.stack(total_volumes, dim=0)
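

# Usage sketch (illustrative only): the path lists and dataset name below are
# hypothetical placeholders, not files or values from this repository. One
# list of paths is passed per modality, with the segmentation-map paths last,
# matching how get_viz_set indexes *ls above.
#
#   t1_paths = ["/data/subject-1-T1.hdr", "/data/subject-2-T1.hdr"]
#   t2_paths = ["/data/subject-1-T2.hdr", "/data/subject-2-T2.hdr"]
#   seg_paths = ["/data/subject-1-label.img", "/data/subject-2-label.img"]
#   full_volume = get_viz_set(t1_paths, t2_paths, seg_paths,
#                             dataset_name="iseg2017", test_subject=0)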
    def create_input_data(self):
        total = len(self.list_IDsT1)
        print('Dataset samples: ', total)
        for i in range(total):
            print(i)
            if self.modalities == 2:
                img_t1_tensor = img_loader.load_medical_image(
                    self.list_IDsT1[i],
                    type="T1",
                    resample=self.voxels_space,
                    to_canonical=self.to_canonical)
                img_t2_tensor = img_loader.load_medical_image(
                    # The original passed the T1 path here as well; assuming
                    # the owner class also exposes list_IDsT2, load the T2 path.
                    self.list_IDsT2[i],
                    type="T2",
                    resample=self.voxels_space,
                    to_canonical=self.to_canonical)
            else:
                img_t1_tensor = img_loader.load_medical_image(
                    self.list_IDsT1[i],
                    type="T1",
                    resample=self.voxels_space,
                    to_canonical=self.to_canonical)

            # Save/append inside the loop; this block previously sat outside
            # the for-loop, so only the last subject was kept.
            if self.save:
                filename = self.sub_vol_path + 'id_' + str(i) + '_s_' + str(
                    i) + '_'
                if self.modalities == 2:
                    f_t1 = filename + 'T1.npy'
                    f_t2 = filename + 'T2.npy'
                    np.save(f_t1, img_t1_tensor)
                    np.save(f_t2, img_t2_tensor)
                    self.list.append(tuple((f_t1, f_t2)))
                else:
                    f_t1 = filename + 'T1.npy'
                    np.save(f_t1, img_t1_tensor)
                    self.list.append(f_t1)
            else:
                if self.modalities == 2:
                    self.list.append(tuple((img_t1_tensor, img_t2_tensor)))
                else:
                    self.list.append(img_t1_tensor)
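
    # Attributes assumed on the owner class (names inferred from this method;
    # the class definition itself is not shown here):
    #   list_IDsT1 / list_IDsT2    : per-subject image paths for each modality
    #   modalities                 : 1 or 2
    #   voxels_space, to_canonical : resampling / orientation options forwarded
    #                                to img_loader.load_medical_image
    #   save, sub_vol_path         : whether (and where) to persist .npy tensors
    #   list                       : collects saved file paths or in-memory tensors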
import matplotlib.pyplot as plt
import torch

import medzoo.lib.augment3D as augment
from medzoo.lib.medloaders.medical_image_process import load_medical_image

size = 32

# Load a T1 volume and its label map from the iSeg-2017 training set.
t1 = load_medical_image(
    '.././datasets/iseg_2017/iSeg-2017-Training/subject-1-T1.hdr').squeeze(
    ).numpy()
label = load_medical_image(
    '.././datasets/iseg_2017/iSeg-2017-Training/subject-1-label.img').squeeze(
    ).numpy()
# Show a middle slice (index 70) of the original volume and its label.
f, axarr = plt.subplots(4, 1)

axarr[0].imshow(t1[70, :, :])
axarr[1].imshow(label[70, :, :])

# Randomly pick one of the listed transforms (here only Gaussian noise),
# apply it to the modality list and label, then show the augmented slice.
c = augment.RandomChoice(transforms=[augment.GaussianNoise(mean=0, std=0.1)])
[t1], label = c([t1], label)

axarr[2].imshow(t1[70, :, :])
axarr[3].imshow(label[70, :, :])

plt.show()
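
# Quick shape check on synthetic data: the transform interface demonstrated
# above takes a list of modality volumes plus a label volume, and for Gaussian
# noise the array shapes should come back unchanged. Illustrative only; this
# uses random arrays, not a real scan.
noisy_list, noisy_label = c([torch.randn(size, size, size).numpy()],
                            torch.zeros(size, size, size).numpy())
print("augmented modality shape:", noisy_list[0].shape)
print("label shape:", noisy_label.shape)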
def create_sub_volumes(
    *ls,
    dataset_name,
    mode,
    samples,
    full_vol_dim,
    crop_size,
    sub_vol_path,
    normalization="max_min",
    th_percent=0.1,
):
    """

    :param ls: list of modality paths, where the last path is the segmentation map
    :param dataset_name: which dataset is used
    :param mode: train/val
    :param samples: train/val samples to generate
    :param full_vol_dim: full image size
    :param crop_size: train volume size
    :param sub_vol_path: path for the particular patient
    :param th_percent: the % of the croped dim that corresponds to non-zero labels
    :param crop_type:
    :return:
    """
    total = len(ls[0])
    assert total != 0, "Problem reading data. Check the data paths."
    modalities = len(ls)
    volume_list = []

    print(f"Mode: {mode} Subvolume samples to generate: {samples} Volumes: {total}")
    for i in range(samples):
        # Pick a random subject and gather its modality + label paths.
        random_index = np.random.randint(total)
        sample_paths = []
        tensor_images = []
        for j in range(modalities):
            sample_paths.append(ls[j][random_index])

        # Rejection sampling: draw random crops until one contains enough
        # labelled voxels (see th_percent). This loops indefinitely if no crop
        # of the requested size can satisfy the threshold.
        while True:
            label_path = sample_paths[-1]
            crop = find_random_crop_dim(full_vol_dim, crop_size)
            full_segmentation_map = img_loader.load_medical_image(
                label_path,
                viz3d=True,
                type="label",
                crop_size=crop_size,
                crop=crop)

            full_segmentation_map = fix_seg_map(full_segmentation_map,
                                                dataset_name)
            if find_non_zero_labels_mask(full_segmentation_map, th_percent,
                                         crop_size, crop):
                segmentation_map = img_loader.load_medical_image(
                    label_path, type="label", crop_size=crop_size, crop=crop)
                segmentation_map = fix_seg_map(segmentation_map, dataset_name)
                for j in range(modalities - 1):
                    img_tensor = img_loader.load_medical_image(
                        sample_paths[j],
                        type="T1",
                        normalization=normalization,
                        crop_size=crop_size,
                        crop=crop,
                    )

                    tensor_images.append(img_tensor)

                break

        filename = (sub_vol_path + "id_" + str(random_index) + "_s_" + str(i) +
                    "_modality_")
        list_saved_paths = []
        for j in range(modalities - 1):
            f_t1 = filename + str(j) + ".npy"
            list_saved_paths.append(f_t1)

            np.save(f_t1, tensor_images[j])

        f_seg = filename + "seg.npy"

        np.save(f_seg, segmentation_map)
        list_saved_paths.append(f_seg)
        volume_list.append(tuple(list_saved_paths))

    return volume_list
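

# Illustration only (not called by the loaders): a minimal sketch of the
# acceptance test the rejection loop above relies on, i.e. a crop is kept once
# at least th_percent of its voxels carry a non-zero label. The real
# find_non_zero_labels_mask may differ in signature; this only demonstrates
# the criterion on a synthetic numpy crop.
def _label_threshold_demo(th_percent=0.1):
    label_crop = np.zeros((8, 8, 8))
    label_crop[:2] = 1  # 2 of 8 slices labelled -> 25% of the voxels
    non_zero_fraction = np.count_nonzero(label_crop) / label_crop.size
    return non_zero_fraction >= th_percent  # True for this synthetic crop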
def get_all_sub_volumes(
    *ls,
    dataset_name,
    mode,
    samples,
    full_vol_dim,
    crop_size,
    sub_vol_path,
    normalization="max_min",
):
    """
    Splits every subject's volume(s) into padded sub-volume patches and saves
    them as .npy files.

    TODO:
    1.) for every subject, load the image(s) and the target
    2.) call generate_non_overlapping_volumes to turn the image and target into sub-volume patches
    3.) save the tensors
    """
    total = len(ls[0])
    assert total != 0, "Problem reading data. Check the data paths."
    modalities = len(ls)
    saved_paths = []  # avoid shadowing the built-in name "list"

    for vol_id in range(total):

        tensor_images = []
        for modality_id in range(modalities - 1):
            img_tensor = img_loader.medical_image_transform(
                img_loader.load_medical_image(ls[modality_id][vol_id],
                                              type="T1"),
                normalization=normalization,
            )

            img_tensor = generate_padded_subvolumes(img_tensor,
                                                    kernel_dim=crop_size)

            tensor_images.append(img_tensor)

        segmentation_map = img_loader.medical_image_transform(
            img_loader.load_medical_image(ls[modalities - 1][vol_id],
                                          viz3d=True,
                                          type="label"))
        segmentation_map = generate_padded_subvolumes(segmentation_map,
                                                      kernel_dim=crop_size)

        filename = (sub_vol_path + "id_" + str(vol_id) + "_s_" +
                    str(modality_id) + "_modality_")

        # One tuple of file paths per sub-volume sample; reset the path list
        # for every sample k so each tuple holds only that sample's files.
        for k in range(len(tensor_images[0])):
            list_saved_paths = []
            for j in range(modalities - 1):
                f_t1 = filename + str(j) + "_sample_{}".format(
                    str(k).zfill(8)) + ".npy"
                list_saved_paths.append(f_t1)
                # Save the k-th patch of modality j, not the whole patch list.
                np.save(f_t1, tensor_images[j][k])

            f_seg = filename + "seg_sample_{}".format(str(k).zfill(8)) + ".npy"
            np.save(f_seg, segmentation_map[k])
            list_saved_paths.append(f_seg)
            saved_paths.append(tuple(list_saved_paths))

    return saved_paths
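

# Rough patch-count sketch (assumption): if generate_padded_subvolumes pads
# the volume and tiles it with non-overlapping crop_size patches, each full
# volume yields prod(ceil(dim / kernel)) sub-volumes. Illustrative arithmetic
# only; the real helper may behave differently.
def _expected_patch_count(full_vol_dim=(144, 192, 256), crop_size=(32, 32, 32)):
    from math import ceil
    counts = [ceil(d / k) for d, k in zip(full_vol_dim, crop_size)]
    return counts[0] * counts[1] * counts[2]  # 5 * 6 * 8 = 240 patches here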
def create_non_overlapping_sub_volumes(
    *ls,
    dataset_name,
    mode,
    samples,
    full_vol_dim,
    crop_size,
    sub_vol_path,
    normalization="max_min",
    th_percent=0.1,
):
    """returns list of non overlapping subvolumes. Also this codebase is hella shitty -_-

    :param ls: list of modality paths, where the last path is the segmentation map
    :param dataset_name: which dataset is used
    :param mode: train/val
    :param samples: train/val samples to generate
    :param full_vol_dim: full image size
    :param crop_size: train volume size
    :param sub_vol_path: path for the particular patient
    :param th_percent: the % of the croped dim that corresponds to non-zero labels
    :param crop_type:
    :return:
    """
    total = len(ls[0])  # expected to be 1: each modality list holds a single subject's path
    assert total != 0, "Problem reading data. Check the data paths."
    modalities = len(ls)
    subvolume_list = []

    # make range be a multiple of crop_size[i]
    x_range = ceil(full_vol_dim[0] / crop_size[0]) * crop_size[0]
    y_range = ceil(full_vol_dim[1] / crop_size[1]) * crop_size[1]
    z_range = ceil(full_vol_dim[2] / crop_size[2]) * crop_size[2]

    print(f"Mode: {mode} Subvolume samples to generate: {samples} Volumes: {total}")
    # x, y, z are min coordinates of crop volume
    for x in range(0, x_range, crop_size[0]):
        for y in range(0, y_range, crop_size[1]):
            for z in range(0, z_range, crop_size[2]):

                crop_x = min(full_vol_dim[0] - crop_size[0], x)
                crop_y = min(full_vol_dim[1] - crop_size[1], y)
                crop_z = min(full_vol_dim[2] - crop_size[2], z)

                crop = (crop_x, crop_y, crop_z)  # (int, int, int)

                # generate the sub-volume at this grid position
                tensor_images = []
                sample_paths = [ls[i][0] for i in range(modalities)
                                ]  # list of paths for one subject

                label_path = sample_paths[-1]

                full_segmentation_map = img_loader.load_medical_image(
                    label_path,
                    viz3d=True,
                    type="label",
                    crop_size=crop_size,
                    crop=crop)

                full_segmentation_map = fix_seg_map(full_segmentation_map,
                                                    dataset_name)

                # Skip grid positions whose crop does not contain enough
                # labelled voxels; saving below with an empty tensor_images
                # list would otherwise fail (or reuse a stale segmentation map).
                if not find_non_zero_labels_mask(full_segmentation_map,
                                                 th_percent, crop_size, crop):
                    continue

                segmentation_map = img_loader.load_medical_image(
                    label_path,
                    type="label",
                    crop_size=crop_size,
                    crop=crop)
                segmentation_map = fix_seg_map(segmentation_map, dataset_name)
                for j in range(modalities - 1):
                    img_tensor = img_loader.load_medical_image(
                        sample_paths[j],
                        type="T1",
                        normalization=normalization,
                        crop_size=crop_size,
                        crop=crop,
                    )

                    tensor_images.append(img_tensor)

                # save subvolume (note: the subject id "9" is hard-coded here)
                filename = f"{sub_vol_path}id_9_s_{x}_{y}_{z}_modality_"
                list_saved_paths = []
                for j in range(modalities - 1):
                    f_t1 = f"{filename}{j}.npy"
                    np.save(f_t1, tensor_images[j])
                    list_saved_paths.append(f_t1)

                f_seg = f"{filename}seg.npy"
                np.save(f_seg, segmentation_map)
                list_saved_paths.append(f_seg)

                subvolume_list.append(tuple(list_saved_paths))

    return subvolume_list
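

# Standalone sketch of the crop-origin enumeration used above: origins are
# placed every crop_size voxels along an axis, and the last origin is clamped
# so the crop still fits inside the volume (border crops may overlap their
# neighbours). Values are illustrative.
def _crop_origins_1d(full_dim=100, crop=32):
    from math import ceil
    axis_range = ceil(full_dim / crop) * crop  # 128 for these defaults
    origins = [min(full_dim - crop, x) for x in range(0, axis_range, crop)]
    return origins  # [0, 32, 64, 68] for these defaults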