Code Example #1
File: utils.py Project: jizongFox/DeepNormalize
def save_rebuilt_images(current_epoch, save_folder, datasets, img_input, img_norm, img_seg, img_gt):
    if not os.path.exists(os.path.join(save_folder, "reconstructed_images")):
        os.makedirs(os.path.join(save_folder, "reconstructed_images"))

    for dataset in datasets:
        transform_img_norm = Compose(
            [ToNifti1Image(), NiftiToDisk(os.path.join(save_folder, "reconstructed_images",
                                                       "Reconstructed_Normalized_{}_Image_{}.nii.gz".format(
                                                           dataset, str(current_epoch))))])
        transform_img_seg = Compose(
            [ToNifti1Image(), NiftiToDisk(os.path.join(save_folder, "reconstructed_images",
                                                       "Reconstructed_Segmented_{}_Image_{}.nii.gz".format(
                                                           dataset, str(current_epoch))))])
        transform_img_gt = Compose(
            [ToNifti1Image(), NiftiToDisk(os.path.join(save_folder, "reconstructed_images",
                                                       "Reconstructed_Ground_Truth_{}_Image_{}.nii.gz".format(
                                                           dataset, str(current_epoch))))])
        transform_img_input = Compose(
            [ToNifti1Image(), NiftiToDisk(os.path.join(save_folder, "reconstructed_images",
                                                       "Reconstructed_Input_{}_Image_{}.nii.gz".format(
                                                           dataset, str(current_epoch))))])

        transform_img_norm(img_norm[dataset])
        transform_img_seg(img_seg[dataset])
        transform_img_gt(img_gt[dataset])
        transform_img_input(img_input[dataset])
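
A minimal sketch of how this helper might be invoked, assuming each img_* argument maps dataset names to numpy volumes (the shapes and dataset names below are illustrative, not taken from the source):

import numpy as np

datasets = ["iSEG", "MRBrainS"]
volumes = {d: np.zeros((64, 64, 64), dtype=np.float32) for d in datasets}

# Writes four NIfTI files per dataset under outputs/reconstructed_images/.
save_rebuilt_images(current_epoch=10, save_folder="outputs",
                    datasets=datasets, img_input=volumes, img_norm=volumes,
                    img_seg=volumes, img_gt=volumes)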
Code Example #2
File: utils.py Project: sami-ets/DeepNormalize
def save_augmented_rebuilt_images(current_epoch, save_folder, datasets,
                                  img_augmented_input,
                                  img_augmented_normalized):
    if not os.path.exists(os.path.join(save_folder, "reconstructed_images")):
        os.makedirs(os.path.join(save_folder, "reconstructed_images"))

    for dataset in datasets:
        transform_img_input_augmented = Compose([
            ToNifti1Image(),
            NiftiToDisk(
                os.path.join(
                    save_folder, "reconstructed_images",
                    "Reconstructed_Augmented_Input_{}_Image_{}.nii.gz".format(
                        dataset, str(current_epoch))))
        ])
        transform_img_normalized_augmented = Compose([
            ToNifti1Image(),
            NiftiToDisk(
                os.path.join(
                    save_folder, "reconstructed_images",
                    "Reconstructed_Augmented_Normalized_{}_Image_{}.nii.gz".
                    format(dataset, str(current_epoch))))
        ])

        transform_img_input_augmented(img_augmented_input[dataset])
        transform_img_normalized_augmented(img_augmented_normalized[dataset])
Code Example #3
File: utils.py Project: sami-ets/DeepNormalize
def save_rebuilt_image(current_epoch, save_folder, datasets, image,
                       image_type):
    if not os.path.exists(os.path.join(save_folder, "reconstructed_images")):
        os.makedirs(os.path.join(save_folder, "reconstructed_images"))

    for dataset in datasets:
        if image[dataset].shape[0] == 2:
            transform_img = Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    os.path.join(
                        save_folder, "reconstructed_images",
                        "Reconstructed_{}_T1_{}_Image_{}.nii.gz".format(
                            image_type, dataset, str(current_epoch))))
            ])
            transform_img(image[dataset][0])
            transform_img = Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    os.path.join(
                        save_folder, "reconstructed_images",
                        "Reconstructed_{}_T2_{}_Image_{}.nii.gz".format(
                            image_type, dataset, str(current_epoch))))
            ])
            transform_img(image[dataset][1])
        else:
            transform_img = Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    os.path.join(
                        save_folder, "reconstructed_images",
                        "Reconstructed_{}_{}_Image_{}.nii.gz".format(
                            image_type, dataset, str(current_epoch))))
            ])
            transform_img(image[dataset])
Code Example #4
File: scale.py Project: sami-ets/DeepNormalize
    def run(self, prefix=""):
        images_np = list()
        headers = list()
        file_names = list()
        root_dirs = list()

        for root, dirs, files in os.walk(self._root_dir):
            root_dir_number = os.path.basename(os.path.normpath(root))
            for file in files:
                if not os.path.exists(
                        os.path.join(self._output_dir, root_dir_number)):
                    os.makedirs(os.path.join(self._output_dir,
                                             root_dir_number))

                try:
                    self.LOGGER.info("Processing: {}".format(file))
                    file_names.append(file)
                    root_dirs.append(root_dir_number)
                    images_np.append(self._transforms(os.path.join(root,
                                                                   file)))
                    headers.append(
                        self._get_image_header(os.path.join(root, file)))

                except Exception as e:
                    self.LOGGER.warning(e)

        for i, header in enumerate(headers):
            transforms_ = transforms.Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    os.path.join(self._output_dir, root_dirs[i],
                                 prefix + file_names[i]))
            ])

            transforms_(images_np[i])
Code Example #5
File: pipelinev2.py Project: jizongFox/DeepNormalize
    def _extract_labels(self, input_dir, output_dir):
        self._mri_binarize(os.path.join(input_dir, "aparc+aseg.mgz"),
                           os.path.join(output_dir, "wm_mask.mgz"), "wm")

        self._mri_binarize(os.path.join(input_dir, "aparc+aseg.mgz"),
                           os.path.join(output_dir, "gm_mask.mgz"), "gm")
        self._mri_binarize(os.path.join(input_dir, "aparc+aseg.mgz"),
                           os.path.join(output_dir, "csf_mask.mgz"), "csf")

        csf_labels = ToNumpyArray()(os.path.join(output_dir, "csf_mask.mgz"))
        gm_labels = self._remap_labels(os.path.join(output_dir, "gm_mask.mgz"),
                                       1, 2)
        wm_labels = self._remap_labels(os.path.join(output_dir, "wm_mask.mgz"),
                                       1, 3)

        # merged = self._merge_volumes(gm_labels, wm_labels, csf_labels)
        #
        # brainmask = ToNumpyArray()(os.path.join(input_dir, "brainmask.mgz"))
        # T1 = ApplyMask(merged)(brainmask)
        #
        # csf = brainmask - T1
        # csf[csf != 0] = 1
        #
        # csf_labels = csf_labels + csf

        merged = self._merge_volumes(gm_labels, wm_labels, csf_labels)

        transform_ = transforms.Compose([
            ToNifti1Image(),
            NiftiToDisk(os.path.join(output_dir, "labels.nii.gz"))
        ])
        transform_(merged)
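
The _remap_labels and _merge_volumes helpers are not shown (the _mri_binarize calls presumably wrap FreeSurfer's mri_binarize tool on the aparc+aseg.mgz segmentation). A minimal sketch of what they plausibly do, assuming binary input masks and a highest-label-wins merge; this is an assumption, not the project's confirmed logic:

import numpy as np

def remap_labels(mask: np.ndarray, old_value: int, new_value: int) -> np.ndarray:
    # Replace one label value with another, leaving other voxels untouched.
    remapped = mask.copy()
    remapped[remapped == old_value] = new_value
    return remapped

def merge_volumes(*volumes: np.ndarray) -> np.ndarray:
    # Highest label wins where masks overlap: WM (3) > GM (2) > CSF (1).
    return np.maximum.reduce(volumes)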
Code Example #6
File: scale.py Project: sami-ets/DeepNormalize
    def run(self, prefix="scaled_"):
        images_np = list()
        headers = list()
        file_names = list()
        root_dirs = list()

        for root, dirs, files in os.walk(self._root_dir):
            root_dir_number = os.path.basename(os.path.normpath(root))
            images = list(filter(re.compile(r"^T.*\.nii").search, files))
            for file in images:
                if not os.path.exists(
                        os.path.join(self._output_dir, root_dir_number)):
                    os.makedirs(os.path.join(self._output_dir,
                                             root_dir_number))

                try:
                    self.LOGGER.info("Processing: {}".format(file))
                    file_names.append(file)
                    root_dirs.append(root_dir_number)
                    images_np.append(self._transforms(os.path.join(root,
                                                                   file)))
                    headers.append(
                        self._get_image_header(os.path.join(root, file)))

                except Exception as e:
                    self.LOGGER.warning(e)

        images = np.array(images_np).astype(np.float32)
        images_shape = images.shape
        images = images.reshape(
            images_shape[0], images_shape[1] * images_shape[2] *
            images_shape[3] * images_shape[4])
        transform_ = transforms.Compose([self._scaler])

        transformed_images = transform_(images).reshape(images_shape)

        for i, header in enumerate(headers):
            transforms_ = transforms.Compose([
                ToNifti1Image(header),
                NiftiToDisk(
                    os.path.join(self._output_dir, root_dirs[i],
                                 prefix + file_names[i]))
            ])

            transforms_(transformed_images[i])

        for root, dirs, files in os.walk(self._root_dir):
            root_dir_end = os.path.basename(os.path.normpath(root))

            images = list(
                filter(re.compile(r"^LabelsFor.*\.nii").search, files))

            for file in images:
                if not os.path.exists(
                        os.path.join(self._output_dir, root_dir_end)):
                    os.makedirs(os.path.join(self._output_dir, root_dir_end))
                shutil.copy(os.path.join(root, file),
                            os.path.join(self._output_dir, root_dir_end))
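
The reshape turns the (N, C, D, H, W) stack into one flattened row per volume before handing it to self._scaler. The scaler itself is not shown; a toy stand-in with a compatible __call__ interface, assuming simple batch-wide min-max scaling:

import numpy as np

class MinMaxScaler:
    # Illustrative stand-in for self._scaler: maps the whole batch to [0, 1].
    def __call__(self, images: np.ndarray) -> np.ndarray:
        minimum, maximum = images.min(), images.max()
        return (images - minimum) / (maximum - minimum + 1e-8)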
Code Example #7
File: pipelinev2.py Project: sami-ets/DeepNormalize
    def _write_image(self, image, subject, modality):
        if not os.path.exists(os.path.join(self._output_dir, subject,
                                           modality)):
            os.makedirs(os.path.join(self._output_dir, subject, modality))

        transform_ = transforms.Compose([
            ToNifti1Image(),
            NiftiToDisk(
                os.path.join(self._output_dir, subject, modality,
                             modality + ".nii.gz"))
        ])

        transform_(image)
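
The ToNifti1Image -> NiftiToDisk pair recurs throughout these examples. A rough equivalent in plain nibabel, assuming ToNifti1Image wraps a numpy array in a Nifti1Image and NiftiToDisk saves it (the real transforms may also carry header and affine information):

import os

import nibabel as nib
import numpy as np

def write_image(output_dir: str, subject: str, modality: str,
                image: np.ndarray) -> None:
    # Mirrors _write_image using nibabel directly; identity affine assumed.
    target_dir = os.path.join(output_dir, subject, modality)
    os.makedirs(target_dir, exist_ok=True)
    nifti = nib.Nifti1Image(image, affine=np.eye(4))
    nib.save(nifti, os.path.join(target_dir, modality + ".nii.gz"))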
Code Example #8
File: pipelinev2.py Project: sami-ets/DeepNormalize
    def _extract_patches(self, image, subject, modality, patch_size, step):
        transforms_ = transforms.Compose(
            [PadToPatchShape(patch_size=patch_size, step=step)])
        transformed_image = transforms_(image)

        patches = ABIDEPreprocessingPipeline.get_patches(
            transformed_image, patch_size, step)

        for i, patch in enumerate(patches):
            x = transformed_image.x[tuple(patch.slice)]
            transform_ = transforms.Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    os.path.join(self._output_dir, subject, "mri",
                                 "patches", modality, str(i) + ".nii.gz"))
            ])
            transform_(x)
Code Example #9
File: pipelinev2.py Project: sami-ets/DeepNormalize
    def _extract_patches(self, image, subject, modality, patch_size, step):
        transforms_ = transforms.Compose(
            [PadToPatchShape(patch_size=patch_size, step=step)])
        transformed_image = transforms_(image)

        patches = iSEGPipeline.get_patches(transformed_image, patch_size, step)

        if not os.path.exists(os.path.join(self._output_dir, subject,
                                           modality)):
            os.makedirs(os.path.join(self._output_dir, subject, modality))

        for i, patch in enumerate(patches):
            x = transformed_image[tuple(patch.slice)]
            transforms_ = transforms.Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    os.path.join(self._output_dir, subject, modality,
                                 str(i) + ".nii.gz"))
            ])
            transforms_(x)
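
get_patches is defined elsewhere in the pipeline; the loop only relies on each patch exposing a .slice tuple that indexes the padded volume. A hypothetical sliding-window implementation consistent with that usage (names and behavior here are assumptions, not the project's code):

import itertools
from collections import namedtuple

Patch = namedtuple("Patch", ["slice"])

def get_patches(volume, patch_size, step):
    # Yield slice tuples for a sliding window over an N-D numpy volume.
    starts = [range(0, dim - size + 1, stride)
              for dim, size, stride in zip(volume.shape, patch_size, step)]
    for corner in itertools.product(*starts):
        yield Patch(slice=tuple(slice(start, start + size)
                                for start, size in zip(corner, patch_size)))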
Code Example #10
File: pipelinev2.py Project: sami-ets/DeepNormalize
    def _extract_labels(self, input_dir, output_dir):
        self._mri_binarize(os.path.join(input_dir, "aparc+aseg.mgz"),
                           os.path.join(output_dir, "wm_mask.mgz"), "wm")

        self._mri_binarize(os.path.join(input_dir, "aparc+aseg.mgz"),
                           os.path.join(output_dir, "gm_mask.mgz"), "gm")
        self._mri_binarize(os.path.join(input_dir, "aparc+aseg.mgz"),
                           os.path.join(output_dir, "csf_mask.mgz"), "csf")

        csf_labels = ToNumpyArray()(os.path.join(output_dir, "csf_mask.mgz"))
        gm_labels = self._remap_labels(os.path.join(output_dir, "gm_mask.mgz"),
                                       1, 2)
        wm_labels = self._remap_labels(os.path.join(output_dir, "wm_mask.mgz"),
                                       1, 3)

        merged = self._merge_volumes(gm_labels, wm_labels, csf_labels)

        transform_ = transforms.Compose([
            ToNifti1Image(),
            NiftiToDisk(os.path.join(output_dir, "labels.nii.gz"))
        ])
        transform_(merged)
Code Example #11
File: histogram.py Project: sami-ets/DeepNormalize
    def run(self, prefix="scaled_"):
        images_np = list()
        density_histograms = list()
        histograms = list()
        bins = list()
        headers = list()
        file_names = list()
        root_dirs = list()
        EXCLUDED = ["ROI", "label", "Normalized"]

        for root, dirs, files in os.walk(self._root_dir):
            if os.path.basename(os.path.normpath(root)) in EXCLUDED:
                continue

            root_dir_number = os.path.basename(os.path.normpath(root))
            images = list(filter(re.compile(r".*T.*\.nii").search, files))
            for file in images:
                if not os.path.exists(os.path.join(self._output_dir, root_dir_number)):
                    os.makedirs(os.path.join(self._output_dir, root_dir_number))

                try:
                    self.LOGGER.info("Processing: {}".format(file))
                    file_names.append(file)
                    root_dirs.append(root_dir_number)
                    image = self._transforms(os.path.join(root, file))
                    images_np.append(image)  # keep the volume for scaling below
                    flattened = image.flatten()
                    density_histogram, bin_edges = np.histogram(
                        flattened[np.nonzero(flattened)], bins=1024, density=True)
                    histogram, bin_edges = np.histogram(
                        flattened[np.nonzero(flattened)], bins=1024)
                    density_histograms.append(density_histogram)
                    histograms.append(histogram)
                    bins.append(bin_edges)
                    headers.append(self._get_image_header(os.path.join(root, file)))

                except Exception as e:
                    self.LOGGER.warning(e)

        density_histograms, histograms, bins = np.array(density_histograms), np.array(histograms), np.array(bins)

        for histogram, density_histogram, bin_edges in zip(histograms, density_histograms, bins):
            # Per-image summary statistics; computed here but not used below.
            median_x = np.searchsorted(density_histogram.cumsum(), 0.50)
            mean_x = bin_edges[np.searchsorted(histogram, np.mean(histogram))]

        # Stack the collected volumes and flatten each one into a single row
        # so the scaler sees one sample per image.
        images = np.array(images_np).astype(np.float32)
        images_shape = images.shape
        images = images.reshape(images_shape[0], -1)
        histogram, bins = np.histogram(images.flatten(), 256)
        cdf = histogram.cumsum()

        transform_ = transforms.Compose([self._scaler(self._params)])

        transformed_images = transform_(images).reshape(images_shape)

        for i, header in enumerate(headers):
            transforms_ = transforms.Compose([ToNifti1Image(header),
                                              NiftiToDisk(
                                                  os.path.join(self._output_dir, root_dirs[i],
                                                               prefix + file_names[i]))])

            transforms_(transformed_images[i])

        for root, dirs, files in os.walk(self._root_dir):
            root_dir_end = os.path.basename(os.path.normpath(root))
            if "ROI" in root_dir_end or "label" in root_dir_end:
                for file in files:
                    if not os.path.exists(os.path.join(self._output_dir, root_dir_end)):
                        os.makedirs(os.path.join(self._output_dir, root_dir_end))
                    shutil.copy(os.path.join(root, file), os.path.join(self._output_dir, root_dir_end))
Code Example #12
File: scale.py Project: sami-ets/DeepNormalize
    def run(self, prefix="standardize_"):
        images_np = list()
        headers = list()
        file_names = list()
        root_dirs = list()
        root_dirs_number = list()
        EXCLUDED = ["ROI", "label", "Normalized"]

        for root, dirs, files in os.walk(self._root_dir_mrbrains):
            root_dir_number = os.path.basename(os.path.normpath(root))
            images = list(filter(re.compile(r"^T.*\.nii").search, files))
            for file in images:
                try:
                    self.LOGGER.info("Processing: {}".format(file))
                    file_names.append(file)
                    root_dirs.append(root)
                    root_dirs_number.append(root_dir_number)
                    images_np.append(self._transforms(os.path.join(root,
                                                                   file)))
                    headers.append(
                        self._get_image_header(os.path.join(root, file)))

                except Exception as e:
                    self.LOGGER.warning(e)

        for root, dirs, files in os.walk(self._root_dir_iseg):
            if os.path.basename(os.path.normpath(root)) in EXCLUDED:
                continue

            root_dir_number = os.path.basename(os.path.normpath(root))
            images = list(filter(re.compile(r".*T.*\.nii").search, files))

            for file in images:
                try:
                    self.LOGGER.info("Processing: {}".format(file))
                    file_names.append(file)
                    root_dirs.append(root)
                    root_dirs_number.append(root_dir_number)
                    images_np.append(self._transforms(os.path.join(root,
                                                                   file)))
                    headers.append(
                        self._get_image_header(os.path.join(root, file)))

                except Exception as e:
                    self.LOGGER.warning(e)

        images = np.array(images_np).astype(np.float32)
        transformed_images = np.subtract(images,
                                         np.mean(images)) / np.std(images)

        for i in range(transformed_images.shape[0]):
            if "MRBrainS" in root_dirs[i]:
                root_dir_number = os.path.basename(
                    os.path.normpath(root_dirs[i]))
                out_dir = os.path.join(self._output_dir,
                                       "MRBrainS/Dual_Standardized",
                                       root_dir_number)
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)
                transforms_ = transforms.Compose([
                    ToNifti1Image(),
                    NiftiToDisk(os.path.join(out_dir, prefix + file_names[i]))
                ])
                transforms_(transformed_images[i])
            elif "iSEG" in root_dirs[i]:
                root_dir_number = os.path.basename(
                    os.path.normpath(root_dirs[i]))
                out_dir = os.path.join(self._output_dir,
                                       "iSEG/Dual_Standardized",
                                       root_dir_number)
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)
                transforms_ = transforms.Compose([
                    ToNifti1Image(),
                    NiftiToDisk(os.path.join(out_dir, prefix + file_names[i]))
                ])

                transforms_(transformed_images[i])

        for root, dirs, files in os.walk(self._root_dir_mrbrains):
            root_dir_end = os.path.basename(os.path.normpath(root))

            images = list(
                filter(re.compile(r"^LabelsFor.*\.nii").search, files))

            for file in images:
                out_dir = os.path.join(self._output_dir,
                                       "MRBrainS/Dual_Standardized",
                                       root_dir_end)
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)

                transforms_ = transforms.Compose([
                    ToNumpyArray(),
                    CropToContent(),
                    PadToShape(self._normalized_shape),
                    ToNifti1Image(),
                    NiftiToDisk(os.path.join(out_dir, file))
                ])

                transforms_(os.path.join(root, file))

        for root, dirs, files in os.walk(self._root_dir_iseg):
            root_dir_end = os.path.basename(os.path.normpath(root))
            if "ROI" in root_dir_end or "label" in root_dir_end:
                for file in files:
                    out_dir = os.path.join(self._output_dir,
                                           "iSEG/Dual_Standardized",
                                           root_dir_end)
                    if not os.path.exists(out_dir):
                        os.makedirs(out_dir)
                    transforms_ = transforms.Compose([
                        ToNumpyArray(),
                        CropToContent(),
                        PadToShape(self._normalized_shape),
                        ToNifti1Image(),
                        NiftiToDisk(os.path.join(out_dir, file))
                    ])

                    transforms_(os.path.join(root, file))
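
CropToContent and PadToShape bring every label volume to the shared self._normalized_shape. A rough sketch of a PadToShape-style transform, assuming centered zero-padding (the project's exact padding policy is not shown):

import numpy as np

def pad_to_shape(volume: np.ndarray, target_shape) -> np.ndarray:
    # Center the volume inside a zero array of the target shape.
    padded = np.zeros(target_shape, dtype=volume.dtype)
    offsets = [(t - s) // 2 for t, s in zip(target_shape, volume.shape)]
    slices = tuple(slice(o, o + s) for o, s in zip(offsets, volume.shape))
    padded[slices] = volume
    return padded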
Code Example #13
    def reconstruct_from_patches_3d(self):
        datasets = list(
            map(
                lambda image: SingleNDArrayDataset(image,
                                                   patch_size=self._patch_size,
                                                   step=self._step,
                                                   prob_bias=self._prob_bias,
                                                   prob_noise=self._prob_noise,
                                                   alpha=self._alpha,
                                                   snr=self._snr),
                self._images))

        data_loaders = list(
            map(
                lambda dataset: DataLoader(dataset,
                                           batch_size=self._batch_size,
                                           num_workers=0,
                                           drop_last=False,
                                           shuffle=False,
                                           pin_memory=False,
                                           collate_fn=self.custom_collate),
                datasets))

        if len(datasets) == 2:
            reconstructed_image = [
                np.zeros(datasets[0].image_shape),
                np.zeros(datasets[1].image_shape)
            ]

            for idx, (iseg_inputs, mrbrains_inputs) in enumerate(
                    zip(data_loaders[ISEG_ID], data_loaders[MRBRAINS_ID])):
                inputs = torch.cat(
                    (iseg_inputs[PATCH], mrbrains_inputs[PATCH]))
                slices = [iseg_inputs[SLICE], mrbrains_inputs[SLICE]]

                if self._do_normalize:
                    patches = torch.sigmoid(
                        self._models[GENERATOR](inputs.to(DEVICE)))

                elif self._do_normalize_and_segment:
                    normalized_patches = torch.sigmoid(
                        self._models[GENERATOR](inputs.to(DEVICE)))
                    patches = torch.argmax(torch.nn.functional.softmax(
                        self._models[SEGMENTER](normalized_patches), dim=1),
                                           dim=1,
                                           keepdim=True)
                else:
                    patches = inputs

                for pred_patch, slice in zip(patches[0:self._batch_size],
                                             slices[ISEG_ID]):
                    reconstructed_image[ISEG_ID][slice] = reconstructed_image[ISEG_ID][slice] + \
                                                          pred_patch.data.cpu().numpy()

                for pred_patch, slice in zip(
                        patches[self._batch_size:self._batch_size * 2],
                        slices[MRBRAINS_ID]):
                    reconstructed_image[MRBRAINS_ID][slice] = reconstructed_image[MRBRAINS_ID][slice] + \
                                                              pred_patch.data.cpu().numpy()

            if self._do_normalize_and_segment or self._is_ground_truth:
                reconstructed_image[ISEG_ID] = np.clip(
                    np.round(reconstructed_image[ISEG_ID] *
                             self._overlap_maps[ISEG_ID]),
                    a_min=0,
                    a_max=3)
                reconstructed_image[MRBRAINS_ID] = np.clip(
                    np.round(reconstructed_image[MRBRAINS_ID] *
                             self._overlap_maps[MRBRAINS_ID]),
                    a_min=0,
                    a_max=3)
            else:
                reconstructed_image[ISEG_ID] = reconstructed_image[
                    ISEG_ID] * self._overlap_maps[ISEG_ID]
                reconstructed_image[MRBRAINS_ID] = reconstructed_image[
                    MRBRAINS_ID] * self._overlap_maps[MRBRAINS_ID]

            transforms_ = transforms.Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    "reconstructed_iseg_image_generated_noise_{}_alpha_{}.nii.gz"
                    .format(self._snr, self._alpha))
            ])
            transforms_(reconstructed_image[ISEG_ID])
            transforms_ = transforms.Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    "reconstructed_mrbrains_image_generated_noise_{}_alpha_{}.nii.gz"
                    .format(self._snr, self._alpha))
            ])
            transforms_(reconstructed_image[MRBRAINS_ID])

        if len(datasets) == 3:
            reconstructed_image = [
                np.zeros(datasets[0].image_shape),
                np.zeros(datasets[1].image_shape),
                np.zeros(datasets[2].image_shape)
            ]

            for idx, (iseg_inputs, mrbrains_inputs, abide_inputs) in enumerate(
                    zip(data_loaders[ISEG_ID], data_loaders[MRBRAINS_ID],
                        data_loaders[ABIDE_ID])):
                inputs = torch.cat((iseg_inputs[PATCH], mrbrains_inputs[PATCH],
                                    abide_inputs[PATCH]))
                slices = [
                    iseg_inputs[SLICE], mrbrains_inputs[SLICE],
                    abide_inputs[SLICE]
                ]

                if self._do_normalize:
                    # Move inputs to DEVICE here as well, matching the
                    # two-dataset branch above.
                    patches = torch.sigmoid(
                        self._models[GENERATOR](inputs.to(DEVICE)))

                elif self._do_normalize_and_segment:
                    normalized_patches = torch.sigmoid(
                        self._models[GENERATOR](inputs.to(DEVICE)))
                    patches = torch.argmax(torch.nn.functional.softmax(
                        self._models[SEGMENTER](normalized_patches), dim=1),
                                           dim=1,
                                           keepdim=True)
                else:
                    patches = inputs

                for pred_patch, slice in zip(patches[0:self._batch_size],
                                             slices[ISEG_ID]):
                    reconstructed_image[ISEG_ID][slice] = reconstructed_image[ISEG_ID][slice] + \
                                                          pred_patch.data.cpu().numpy()

                for pred_patch, slice in zip(
                        patches[self._batch_size:self._batch_size * 2],
                        slices[MRBRAINS_ID]):
                    reconstructed_image[MRBRAINS_ID][slice] = reconstructed_image[MRBRAINS_ID][slice] + \
                                                              pred_patch.data.cpu().numpy()

                for pred_patch, slice in zip(
                        patches[self._batch_size * 2:self._batch_size * 3],
                        slices[ABIDE_ID]):
                    # Accumulate ABIDE patches into the ABIDE volume.
                    reconstructed_image[ABIDE_ID][slice] = reconstructed_image[ABIDE_ID][slice] + \
                                                           pred_patch.data.cpu().numpy()

            reconstructed_image[ISEG_ID] = reconstructed_image[
                ISEG_ID] * self._overlap_maps[ISEG_ID]
            reconstructed_image[MRBRAINS_ID] = reconstructed_image[
                MRBRAINS_ID] * self._overlap_maps[MRBRAINS_ID]
            reconstructed_image[ABIDE_ID] = reconstructed_image[
                ABIDE_ID] * self._overlap_maps[ABIDE_ID]

            if self._do_normalize_and_segment:
                reconstructed_image[ISEG_ID] = np.clip(np.round(
                    reconstructed_image[ISEG_ID]),
                                                       a_min=0,
                                                       a_max=3)
                reconstructed_image[MRBRAINS_ID] = np.clip(np.round(
                    reconstructed_image[MRBRAINS_ID]),
                                                           a_min=0,
                                                           a_max=3)
                reconstructed_image[ABIDE_ID] = np.clip(np.round(
                    reconstructed_image[ABIDE_ID]),
                                                        a_min=0,
                                                        a_max=3)

            transforms_ = transforms.Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    "reconstructed_iseg_image_generated_noise_{}_alpha_{}.nii.gz"
                    .format(self._snr, self._alpha))
            ])
            transforms_(reconstructed_image[ISEG_ID])
            transforms_ = transforms.Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    "reconstructed_mrbrains_image_generated_noise_{}_alpha_{}.nii.gz"
                    .format(self._snr, self._alpha))
            ])
            transforms_(reconstructed_image[MRBRAINS_ID])
            transforms_ = transforms.Compose([
                ToNifti1Image(),
                NiftiToDisk(
                    "reconstructed_abide_image_generated_noise_{}_alpha_{}.nii.gz"
                    .format(self._snr, self._alpha))
            ])
            transforms_(reconstructed_image[ABIDE_ID])
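
Multiplying the accumulated predictions by self._overlap_maps suggests each map holds the reciprocal of the per-voxel patch count, so that summed patch predictions are averaged where patches overlap. A sketch of how such a map could be built; this definition is an assumption, not taken from the project:

import numpy as np

def build_overlap_map(image_shape, slices):
    # Count how many patches cover each voxel, then invert the counts.
    counts = np.zeros(image_shape, dtype=np.float32)
    for patch_slice in slices:
        counts[patch_slice] += 1.0
    counts[counts == 0] = 1.0  # leave uncovered voxels unscaled
    return 1.0 / counts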