예제 #1
0
    def on_test_epoch_end(self):
        """Publish the end-of-test-epoch images, metrics and tables into ``self.custom_variables``.

        Every 10th epoch (``self.epoch % 10 == 0``) the full volumes are rebuilt from the test patches, saved,
        sliced for display and scored (see ``_reconstruct_and_score_volumes``). On every epoch this also:

        * fills blank image panels for the datasets that are not configured,
        * records the runtime and the four confusion matrices,
        * records the patch-level Dice / Hausdorff metrics and the reconstruction gauges,
        * refreshes the per-dataset table only when ``_valid_dice_gauge`` beats the previous best, then resets
          ``_valid_dice_gauge``.
        """
        if self.epoch % 10 == 0:
            self._reconstruct_and_score_volumes()

        # Blank panels so every dataset's image slot exists even when that dataset is not configured.
        for dataset in ("ABIDE", "iSEG", "MRBrainS"):
            if dataset not in self._dataset_configs.keys():
                self.custom_variables["Reconstructed Segmented {} Image".format(dataset)] = np.zeros((224, 192))
                self.custom_variables["Reconstructed Ground Truth {} Image".format(dataset)] = np.zeros((224, 192))
                self.custom_variables["Reconstructed Input {} Image".format(dataset)] = np.zeros((224, 192))

        self.custom_variables["Runtime"] = to_html_time(timedelta(seconds=time.time() - self._start_time))

        for name, confusion_matrix_gauge in (("Confusion Matrix", self._general_confusion_matrix_gauge),
                                             ("iSEG Confusion Matrix", self._iSEG_confusion_matrix_gauge),
                                             ("MRBrainS Confusion Matrix", self._MRBrainS_confusion_matrix_gauge),
                                             ("ABIDE Confusion Matrix", self._ABIDE_confusion_matrix_gauge)):
            # `_num_examples` is read directly because computing an empty confusion matrix is not attempted.
            if confusion_matrix_gauge._num_examples != 0:
                self.custom_variables[name] = np.array(
                    np.fliplr(confusion_matrix_gauge.compute().cpu().detach().numpy()))
            else:
                self.custom_variables[name] = np.zeros((4, 4))

        self.custom_variables["Metric Table"] = to_html(["CSF", "Grey Matter", "White Matter"],
                                                        ["DSC", "HD"],
                                                        [
                                                            self._computed_or_zeros(
                                                                self._class_dice_gauge_on_patches),
                                                            self._computed_or_zeros(
                                                                self._class_hausdorff_distance_gauge)
                                                        ])

        for name, gauge in (
                ("Dice score per class per epoch", self._class_dice_gauge_on_patches),
                ("Dice score per class per epoch on reconstructed image",
                 self._class_dice_gauge_on_reconstructed_images),
                ("Dice score per class per epoch on reconstructed iSEG image",
                 self._class_dice_gauge_on_reconstructed_iseg_images),
                ("Dice score per class per epoch on reconstructed MRBrainS image",
                 self._class_dice_gauge_on_reconstructed_mrbrains_images),
                ("Dice score per class per epoch on reconstructed ABIDE image",
                 self._class_dice_gauge_on_reconstructed_abide_images),
                ("Hausdorff Distance per class per epoch on reconstructed iSEG image",
                 self._hausdorff_distance_gauge_on_reconstructed_iseg_images),
                ("Hausdorff Distance per class per epoch on reconstructed MRBrainS image",
                 self._hausdorff_distance_gauge_on_reconstructed_mrbrains_images),
                ("Hausdorff Distance per class per epoch on reconstructed ABIDE image",
                 self._hausdorff_distance_gauge_on_reconstructed_abide_images)):
            self.custom_variables[name] = self._computed_or_zeros(gauge)

        # The per-dataset table is only rebuilt when the validation Dice improves; otherwise the last best is shown.
        valid_mean_dice = self._valid_dice_gauge.compute()
        if valid_mean_dice > self._previous_mean_dice:
            new_table = to_html_per_dataset(
                ["CSF", "Grey Matter", "White Matter"],
                ["DSC", "HD"],
                [
                    [self._computed_or_zeros(self._iSEG_dice_gauge),
                     self._computed_or_zeros(self._iSEG_hausdorff_gauge)],
                    [self._computed_or_zeros(self._MRBrainS_dice_gauge),
                     self._computed_or_zeros(self._MRBrainS_hausdorff_gauge)],
                    [self._computed_or_zeros(self._ABIDE_dice_gauge),
                     self._computed_or_zeros(self._ABIDE_hausdorff_gauge)]],
                ["iSEG", "MRBrainS", "ABIDE"])

            self.custom_variables["Per-Dataset Metric Table"] = new_table
            self._previous_mean_dice = valid_mean_dice
            self._previous_per_dataset_table = new_table
        else:
            self.custom_variables["Per-Dataset Metric Table"] = self._previous_per_dataset_table
        self._valid_dice_gauge.reset()

        self.custom_variables["Mean Hausdorff Distance"] = [
            self._class_hausdorff_distance_gauge.compute().mean()
            if self._class_hausdorff_distance_gauge.has_been_updated() else np.array([0.0])]

        self.custom_variables[
            "Per Dataset Mean Hausdorff Distance"] = self._per_dataset_hausdorff_distance_gauge.compute()

    @staticmethod
    def _computed_or_zeros(gauge):
        """Return ``gauge.compute()`` if the gauge has been updated, else three zeros (CSF, GM, WM)."""
        return gauge.compute() if gauge.has_been_updated() else np.array([0.0, 0.0, 0.0])

    def _score_reconstructed_volume(self, seg_volume, gt_volume):
        """Score one reconstructed segmentation volume against its ground-truth volume.

        Returns:
            ``(dice, hausdorff)``: the ``"Dice"`` entry of ``self._model_trainers[0].compute_metrics`` and the last
            three entries of ``mean_hausdorff_distance`` on the 4-class one-hot label volumes.

        NOTE(review): callers reuse this result instead of re-scoring, which assumes ``compute_metrics`` and
        ``mean_hausdorff_distance`` are pure (no accumulated state between calls) — confirm.
        """
        hausdorff = mean_hausdorff_distance(
            to_onehot(torch.tensor(gt_volume, dtype=torch.long), num_classes=4),
            to_onehot(torch.tensor(seg_volume, dtype=torch.long), num_classes=4))[-3:]
        metric = self._model_trainers[0].compute_metrics(
            to_onehot(torch.tensor(seg_volume).unsqueeze(0).long(), num_classes=4),
            torch.tensor(gt_volume).unsqueeze(0).long())
        return metric["Dice"], hausdorff

    def _reconstruct_and_score_volumes(self):
        """Rebuild the test volumes from patches, save and display them, and update the reconstruction gauges.

        Each configured dataset's volume is scored exactly once. Those scores feed the all-dataset gauges and
        the iSEG / MRBrainS / ABIDE gauges; datasets missing from the reconstruction get zeros in the latter.
        """
        # NOTE(review): `_class_dice_gauge_on_reconstructed_images` is not reset here, unlike the gauges below,
        # so it may accumulate across reconstruction epochs — confirm it is reset elsewhere.
        for gauge in (self._per_dataset_hausdorff_distance_gauge,
                      self._class_dice_gauge_on_reconstructed_iseg_images,
                      self._class_dice_gauge_on_reconstructed_mrbrains_images,
                      self._class_dice_gauge_on_reconstructed_abide_images,
                      self._hausdorff_distance_gauge_on_reconstructed_iseg_images,
                      self._hausdorff_distance_gauge_on_reconstructed_mrbrains_images,
                      self._hausdorff_distance_gauge_on_reconstructed_abide_images):
            gauge.reset()

        img_input = self._input_reconstructor.reconstruct_from_patches_3d()
        img_gt = self._gt_reconstructor.reconstruct_from_patches_3d()
        img_seg = self._segmentation_reconstructor.reconstruct_from_patches_3d()

        save_rebuilt_image(self._current_epoch, self._save_folder, self._dataset_configs.keys(), img_input,
                           "Input")
        save_rebuilt_image(self._current_epoch, self._save_folder, self._dataset_configs.keys(), img_gt,
                           "Ground_Truth")
        save_rebuilt_image(self._current_epoch, self._save_folder, self._dataset_configs.keys(), img_seg,
                           "Segmented")

        build_augmented_images = self._training_config.build_augmented_images
        if build_augmented_images:
            img_augmented_input = self._augmented_input_reconstructor.reconstruct_from_patches_3d()
            img_augmented_normalized = self._augmented_normalized_reconstructor.reconstruct_from_patches_3d()
            save_augmented_rebuilt_images(self._current_epoch, self._save_folder, self._dataset_configs.keys(),
                                          img_augmented_input, img_augmented_normalized)

        scores = {}
        mean_mhd = []
        for dataset in self._dataset_configs.keys():
            # Axial slice 160 of each volume is shown.
            self.custom_variables[
                "Reconstructed Segmented {} Image".format(dataset)] = self._seg_slicer.get_colored_slice(
                SliceType.AXIAL, np.expand_dims(img_seg[dataset], 0), 160).squeeze(0)
            self.custom_variables[
                "Reconstructed Ground Truth {} Image".format(dataset)] = self._seg_slicer.get_colored_slice(
                SliceType.AXIAL, np.expand_dims(img_gt[dataset], 0), 160).squeeze(0)
            self.custom_variables[
                "Reconstructed Input {} Image".format(dataset)] = self._slicer.get_slice(
                SliceType.AXIAL, np.expand_dims(img_input[dataset], 0), 160)

            if build_augmented_images:
                self.custom_variables[
                    "Reconstructed Augmented Input {} Image".format(dataset)] = self._slicer.get_slice(
                    SliceType.AXIAL, np.expand_dims(np.expand_dims(img_augmented_input[dataset], 0), 0), 160)
                self.custom_variables[
                    "Reconstructed Augmented {} After Normalization".format(
                        dataset)] = self._seg_slicer.get_colored_slice(
                    SliceType.AXIAL,
                    np.expand_dims(np.expand_dims(img_augmented_normalized[dataset], 0), 0), 160).squeeze(0)
            else:
                self.custom_variables["Reconstructed Augmented Input {} Image".format(
                    dataset)] = np.zeros((224, 192))
                self.custom_variables[
                    "Reconstructed Initial Noise {} Image".format(
                        dataset)] = np.zeros((224, 192))
                self.custom_variables[
                    "Reconstructed Augmented {} After Normalization".format(
                        dataset)] = np.zeros((224, 192))

            dice, hausdorff = self._score_reconstructed_volume(img_seg[dataset], img_gt[dataset])
            scores[dataset] = (dice, hausdorff)
            mean_mhd.append(hausdorff.mean())
            # A fresh array per update so no two gauges ever share the same buffer.
            self._class_dice_gauge_on_reconstructed_images.update(np.array(dice))

        self._per_dataset_hausdorff_distance_gauge.update(np.array(mean_mhd))

        for dataset, dice_gauge, hausdorff_gauge in (
                ("iSEG", self._class_dice_gauge_on_reconstructed_iseg_images,
                 self._hausdorff_distance_gauge_on_reconstructed_iseg_images),
                ("MRBrainS", self._class_dice_gauge_on_reconstructed_mrbrains_images,
                 self._hausdorff_distance_gauge_on_reconstructed_mrbrains_images),
                ("ABIDE", self._class_dice_gauge_on_reconstructed_abide_images,
                 self._hausdorff_distance_gauge_on_reconstructed_abide_images)):
            if dataset in img_seg:
                # Reuse the scores from the loop above; only volumes outside the configured datasets are scored here.
                if dataset in scores:
                    dice, hausdorff = scores[dataset]
                else:
                    dice, hausdorff = self._score_reconstructed_volume(img_seg[dataset], img_gt[dataset])
                dice_gauge.update(np.array(dice))
                hausdorff_gauge.update(hausdorff)
            else:
                dice_gauge.update(np.array([0.0, 0.0, 0.0]))
                hausdorff_gauge.update(np.array([0.0, 0.0, 0.0]))
예제 #2
0
    def test_step(self, inputs, target):
        """Evaluate the generator and segmenter on one test batch and update the test gauges.

        Autoencoder phase: the generator and the segmenter both run on the non-augmented inputs.

        Segmentation phase: the generator runs on the augmented inputs and the segmenter runs on the generator
        output. The following are then updated:

        * the per-dataset (iSEG / MRBrainS / ABIDE) Dice, Hausdorff and confusion-matrix gauges,
        * the batch-level Hausdorff and confusion-matrix gauges,
        * the Jensen-Shannon divergence gauges of the input and generated intensity histograms.

        Every 100th test step the histogram and image plots are refreshed, provided a prediction was made.
        """
        # Initialised so the plotting branch cannot hit an unbound name when neither phase is active.
        gen_pred = seg_pred = None

        if self._should_activate_autoencoder():
            gen_pred = self._test_g(self._model_trainers[GENERATOR],
                                    inputs[NON_AUGMENTED_INPUTS])

            seg_pred, _ = self._test_s(self._model_trainers[SEGMENTER],
                                       inputs[NON_AUGMENTED_INPUTS],
                                       target[IMAGE_TARGET],
                                       self._class_dice_gauge_on_patches)

        if self._should_activate_segmentation():
            gen_pred = self._test_g(self._model_trainers[GENERATOR],
                                    inputs[AUGMENTED_INPUTS])

            seg_pred, _ = self._test_s(self._model_trainers[SEGMENTER],
                                       gen_pred, target[IMAGE_TARGET],
                                       self._class_dice_gauge_on_patches)

            # iSEG and MRBrainS record zeros for batches without any of their samples; ABIDE records nothing then.
            for dataset_id, dice_gauge, hausdorff_gauge, confusion_matrix_gauge, zero_fill_when_absent in (
                    (ISEG_ID, self._iSEG_dice_gauge, self._iSEG_hausdorff_gauge,
                     self._iSEG_confusion_matrix_gauge, True),
                    (MRBRAINS_ID, self._MRBrainS_dice_gauge, self._MRBrainS_hausdorff_gauge,
                     self._MRBrainS_confusion_matrix_gauge, True),
                    (ABIDE_ID, self._ABIDE_dice_gauge, self._ABIDE_hausdorff_gauge,
                     self._ABIDE_confusion_matrix_gauge, False)):
                self._update_dataset_gauges(seg_pred, target, dataset_id, dice_gauge, hausdorff_gauge,
                                            confusion_matrix_gauge, zero_fill_when_absent)

            # The batch-level prediction one-hot and labels are shared by the Hausdorff and confusion gauges.
            predicted_onehot = to_onehot(
                torch.argmax(torch.nn.functional.softmax(seg_pred, dim=1), dim=1), num_classes=4)
            labels = torch.squeeze(target[IMAGE_TARGET], dim=1).long()
            self._class_hausdorff_distance_gauge.update(
                mean_hausdorff_distance(predicted_onehot, to_onehot(labels, num_classes=4))[-3:])
            self._general_confusion_matrix_gauge.update((predicted_onehot, labels))

            augmented_inputs = inputs[AUGMENTED_INPUTS]
            c, d, h, w = augmented_inputs.shape[1:5]
            voxels_per_sample = c * d * h * w
            # The input histograms are sized by the augmented batch itself, not by `inputs[0]`.
            hist_inputs = self._softmax_histograms(augmented_inputs, voxels_per_sample)
            hist_gen = self._softmax_histograms(gen_pred, voxels_per_sample)

            self._js_div_inputs_gauge.update(js_div(hist_inputs).item())
            self._js_div_gen_gauge.update(js_div(hist_gen).item())

        if self.current_test_step % 100 == 0 and seg_pred is not None:
            self._update_histograms(inputs[NON_AUGMENTED_INPUTS], target,
                                    gen_pred)
            self._update_image_plots(
                self.phase, inputs[NON_AUGMENTED_INPUTS].cpu().detach(),
                gen_pred.cpu().detach(),
                seg_pred.cpu().detach(), target[IMAGE_TARGET].cpu().detach(),
                target[DATASET_ID].cpu().detach())

    def _update_dataset_gauges(self, seg_pred, target, dataset_id, dice_gauge, hausdorff_gauge,
                               confusion_matrix_gauge, zero_fill_when_absent):
        """Update one dataset's Dice, Hausdorff and confusion-matrix gauges from its samples in the batch.

        Args:
            seg_pred: Segmenter output for the whole batch; class scores lie along dim 1.
            target: Target container indexed by ``IMAGE_TARGET`` and ``DATASET_ID``.
            dataset_id: Value of ``target[DATASET_ID]`` that selects this dataset's samples.
            dice_gauge: Gauge updated with the Dice from ``compute_metrics``.
            hausdorff_gauge: Gauge updated with the last three per-class Hausdorff distances.
            confusion_matrix_gauge: Gauge updated with ``(prediction one-hot, labels)``.
            zero_fill_when_absent: When the batch has no sample of this dataset, update the Dice and Hausdorff
                gauges with three zeros if True; leave all gauges untouched if False.
        """
        selection = torch.where(target[DATASET_ID] == dataset_id)
        dataset_pred = seg_pred[selection]
        if dataset_pred.shape[0] == 0:
            if zero_fill_when_absent:
                dice_gauge.update(np.zeros((3, )))
                hausdorff_gauge.update(np.zeros((3, )))
            return

        # Softmax, argmax and one-hot are computed once and shared by the three gauges.
        probabilities = torch.nn.functional.softmax(dataset_pred, dim=1)
        predicted_onehot = to_onehot(torch.argmax(probabilities, dim=1), num_classes=4)
        labels = torch.squeeze(target[IMAGE_TARGET][selection], dim=1).long()

        dice_gauge.update(
            np.array(self._model_trainers[SEGMENTER].compute_metrics(probabilities, labels)["Dice"].numpy()))
        hausdorff_gauge.update(
            mean_hausdorff_distance(predicted_onehot, to_onehot(labels, num_classes=4))[-3:])
        confusion_matrix_gauge.update((predicted_onehot, labels))

    @staticmethod
    def _softmax_histograms(batch, voxels_per_sample):
        """Return per-sample 256-bin intensity histograms over [0, 1], shaped ``(1, batch_size, 256)``.

        Each sample is flattened to ``voxels_per_sample`` values. Its histogram is divided by that count and
        then soft-maxed over the bin axis.
        """
        histograms = torch.cat([
            torch.histc(sample.view(1, voxels_per_sample), bins=256, min=0, max=1).unsqueeze(0)
            for sample in batch
        ]).unsqueeze(0)
        return torch.nn.Softmax(dim=2)(histograms / voxels_per_sample)
예제 #3
0
    def test_step(self, inputs, target):
        """Segment one sampled test batch and update the patch-level and per-dataset test gauges.

        The batch first goes through ``self._sampler`` and only its augmented targets are kept. The first model
        trainer then segments the augmented inputs. The per-dataset (iSEG / MRBrainS / ABIDE) Dice, Hausdorff
        and confusion-matrix gauges and the batch-level Hausdorff and confusion-matrix gauges are updated.
        Every 100th test step the histogram and image plots are refreshed.
        """
        inputs, target = self._sampler(inputs, target)
        target = target[AUGMENTED_TARGETS]

        seg_pred, _ = self._test_s(self._model_trainers[0], inputs[AUGMENTED_INPUTS], target[IMAGE_TARGET],
                                   self._class_dice_gauge_on_patches)

        if self.current_test_step % 100 == 0:
            self._update_histograms(inputs[AUGMENTED_INPUTS], target)
            self._update_image_plots(self.phase, inputs[AUGMENTED_INPUTS].cpu().detach(),
                                     seg_pred.cpu().detach(),
                                     target[IMAGE_TARGET].cpu().detach(),
                                     target[DATASET_ID].cpu().detach())

        # iSEG and MRBrainS record zeros for batches without any of their samples; ABIDE records nothing then.
        for dataset_id, dice_gauge, hausdorff_gauge, confusion_matrix_gauge, zero_fill_when_absent in (
                (ISEG_ID, self._iSEG_dice_gauge, self._iSEG_hausdorff_gauge,
                 self._iSEG_confusion_matrix_gauge, True),
                (MRBRAINS_ID, self._MRBrainS_dice_gauge, self._MRBrainS_hausdorff_gauge,
                 self._MRBrainS_confusion_matrix_gauge, True),
                (ABIDE_ID, self._ABIDE_dice_gauge, self._ABIDE_hausdorff_gauge,
                 self._ABIDE_confusion_matrix_gauge, False)):
            self._update_dataset_metric_gauges(seg_pred, target, dataset_id, dice_gauge, hausdorff_gauge,
                                               confusion_matrix_gauge, zero_fill_when_absent)

        # The batch-level prediction one-hot and labels are shared by the Hausdorff and confusion gauges.
        predicted_onehot = to_onehot(torch.argmax(torch.nn.functional.softmax(seg_pred, dim=1), dim=1),
                                     num_classes=4)
        labels = torch.squeeze(target[IMAGE_TARGET], dim=1).long()
        self._class_hausdorff_distance_gauge.update(
            mean_hausdorff_distance(predicted_onehot, to_onehot(labels, num_classes=4))[-3:])
        self._general_confusion_matrix_gauge.update((predicted_onehot, labels))

    def _update_dataset_metric_gauges(self, seg_pred, target, dataset_id, dice_gauge, hausdorff_gauge,
                                      confusion_matrix_gauge, zero_fill_when_absent):
        """Update one dataset's Dice, Hausdorff and confusion-matrix gauges from its samples in the batch.

        Args:
            seg_pred: Segmenter output for the whole batch; class scores lie along dim 1.
            target: Target container indexed by ``IMAGE_TARGET`` and ``DATASET_ID``.
            dataset_id: Value of ``target[DATASET_ID]`` that selects this dataset's samples.
            dice_gauge: Gauge updated with the Dice from ``self._model_trainers[0].compute_metrics``.
            hausdorff_gauge: Gauge updated with the last three per-class Hausdorff distances.
            confusion_matrix_gauge: Gauge updated with ``(prediction one-hot, labels)``.
            zero_fill_when_absent: When the batch has no sample of this dataset, update the Dice and Hausdorff
                gauges with three zeros if True; leave all gauges untouched if False.
        """
        selection = torch.where(target[DATASET_ID] == dataset_id)
        dataset_pred = seg_pred[selection]
        if dataset_pred.shape[0] == 0:
            if zero_fill_when_absent:
                dice_gauge.update(np.zeros((3,)))
                hausdorff_gauge.update(np.zeros((3,)))
            return

        # Softmax, argmax and one-hot are computed once and shared by the three gauges.
        probabilities = torch.nn.functional.softmax(dataset_pred, dim=1)
        predicted_onehot = to_onehot(torch.argmax(probabilities, dim=1), num_classes=4)
        labels = torch.squeeze(target[IMAGE_TARGET][selection], dim=1).long()

        dice_gauge.update(np.array(self._model_trainers[0].compute_metrics(probabilities, labels)["Dice"].numpy()))
        hausdorff_gauge.update(mean_hausdorff_distance(predicted_onehot, to_onehot(labels, num_classes=4))[-3:])
        confusion_matrix_gauge.update((predicted_onehot, labels))
예제 #4
0
    def on_test_epoch_end(self):
        """Publish end-of-test-epoch images and metrics into ``self.custom_variables``.

        Every 20 epochs (``self.epoch % 20 == 0``) the full volumes are rebuilt from the
        reconstruction datasets, written to ``self._save_folder``, sliced for display and
        scored (Dice and mean Hausdorff distance) overall and per named dataset (iSEG,
        MRBrainS, ABIDE). On every call it then fills zero placeholder images for datasets
        that are not configured and publishes the runtime, confusion matrices, metric
        tables, Jensen-Shannon divergences and Hausdorff summaries held by the gauges.
        """
        dataset_names = self._dataset_configs.keys()

        def computed_or_zeros(gauge, size=3):
            # A gauge that received no update is reported as a zero vector with ``size``
            # entries instead of being computed.
            return gauge.compute() if gauge.has_been_updated() else np.array([0.0] * size)

        def add_two_leading_axes(volume):
            # The slicers receive the volume with two extra leading singleton axes
            # (presumably batch and channel -- confirm against the slicer API).
            return np.expand_dims(np.expand_dims(volume, 0), 0)

        if self.epoch % 20 == 0:
            self._per_dataset_hausdorff_distance_gauge.reset()
            self._class_dice_gauge_on_reconstructed_iseg_images.reset()
            self._class_dice_gauge_on_reconstructed_mrbrains_images.reset()
            self._class_dice_gauge_on_reconstructed_abide_images.reset()
            self._hausdorff_distance_gauge_on_reconstructed_iseg_images.reset()
            self._hausdorff_distance_gauge_on_reconstructed_mrbrains_images.reset()
            self._hausdorff_distance_gauge_on_reconstructed_abide_images.reset()
            # NOTE(review): unlike the per-dataset gauges above,
            # ``_class_dice_gauge_on_reconstructed_images`` is not reset here, so it keeps
            # accumulating across reconstruction epochs unless it is reset elsewhere -- confirm.

            all_patches, ground_truth_patches = get_all_patches(self._reconstruction_datasets, self._is_sliced)

            img_input = rebuild_image(dataset_names, all_patches, self._input_reconstructors)
            img_gt = rebuild_image(dataset_names, ground_truth_patches, self._gt_reconstructors)
            img_norm = rebuild_image(dataset_names, all_patches, self._normalize_reconstructors)
            img_seg = rebuild_image(dataset_names, all_patches, self._segmentation_reconstructors)

            for image, image_type in ((img_input, "Input"), (img_gt, "Ground_Truth"),
                                      (img_norm, "Normalized"), (img_seg, "Segmented")):
                save_rebuilt_image(self._current_epoch, self._save_folder, dataset_names, image, image_type)

            if self._training_config.build_augmented_images:
                img_augmented = rebuild_image(dataset_names, all_patches, self._augmented_reconstructors)
                augmented_minus_inputs, normalized_minus_inputs = rebuild_augmented_images(
                    img_augmented, img_input, img_gt, img_norm, img_seg)
                save_augmented_rebuilt_images(self._current_epoch, self._save_folder, dataset_names,
                                              img_augmented, augmented_minus_inputs, normalized_minus_inputs)

            mean_mhd = []
            # Full-volume Dice / Hausdorff results per dataset, kept so the named-dataset
            # gauges below reuse them instead of recomputing the same (expensive) metrics.
            dice_per_dataset = {}
            hausdorff_per_dataset = {}
            for dataset in dataset_names:
                self.custom_variables["Reconstructed Normalized {} Image".format(dataset)] = self._slicer.get_slice(
                    SliceType.AXIAL, add_two_leading_axes(img_norm[dataset]), 160)
                self.custom_variables[
                    "Reconstructed Segmented {} Image".format(dataset)] = self._seg_slicer.get_colored_slice(
                    SliceType.AXIAL, add_two_leading_axes(img_seg[dataset]), 160).squeeze(0)
                self.custom_variables[
                    "Reconstructed Ground Truth {} Image".format(dataset)] = self._seg_slicer.get_colored_slice(
                    SliceType.AXIAL, add_two_leading_axes(img_gt[dataset]), 160).squeeze(0)
                self.custom_variables["Reconstructed Input {} Image".format(dataset)] = self._slicer.get_slice(
                    SliceType.AXIAL, add_two_leading_axes(img_input[dataset]), 160)

                if self._training_config.build_augmented_images:
                    self.custom_variables[
                        "Reconstructed Augmented Input {} Image".format(dataset)] = self._slicer.get_slice(
                        SliceType.AXIAL, add_two_leading_axes(img_augmented[dataset]), 160)
                    self.custom_variables[
                        "Reconstructed Initial Noise {} Image".format(dataset)] = self._seg_slicer.get_colored_slice(
                        SliceType.AXIAL, add_two_leading_axes(augmented_minus_inputs[dataset]), 160).squeeze(0)
                    self.custom_variables[
                        "Reconstructed Noise {} After Normalization".format(
                            dataset)] = self._seg_slicer.get_colored_slice(
                        SliceType.AXIAL, add_two_leading_axes(normalized_minus_inputs[dataset]), 160).squeeze(0)
                else:
                    self.custom_variables["Reconstructed Augmented Input {} Image".format(dataset)] = np.zeros(
                        (224, 192))
                    self.custom_variables["Reconstructed Initial Noise {} Image".format(dataset)] = np.zeros(
                        (224, 192))
                    self.custom_variables["Reconstructed Noise {} After Normalization".format(dataset)] = np.zeros(
                        (224, 192))

                # Only the last three entries are kept (presumably the CSF / GM / WM classes).
                hausdorff = mean_hausdorff_distance(
                    to_onehot(torch.tensor(img_gt[dataset], dtype=torch.long), num_classes=4),
                    to_onehot(torch.tensor(img_seg[dataset], dtype=torch.long), num_classes=4))[-3:]
                mean_mhd.append(hausdorff.mean())

                metric = self._model_trainers[SEGMENTER].compute_metrics(
                    to_onehot(torch.tensor(img_seg[dataset]).unsqueeze(0).long(), num_classes=4),
                    torch.tensor(img_gt[dataset]).unsqueeze(0).long())
                self._class_dice_gauge_on_reconstructed_images.update(np.array(metric["Dice"]))

                dice_per_dataset[dataset] = metric["Dice"]
                hausdorff_per_dataset[dataset] = hausdorff

            self._per_dataset_hausdorff_distance_gauge.update(np.array(mean_mhd))

            named_dataset_gauges = (
                ("iSEG", self._class_dice_gauge_on_reconstructed_iseg_images,
                 self._hausdorff_distance_gauge_on_reconstructed_iseg_images),
                ("MRBrainS", self._class_dice_gauge_on_reconstructed_mrbrains_images,
                 self._hausdorff_distance_gauge_on_reconstructed_mrbrains_images),
                ("ABIDE", self._class_dice_gauge_on_reconstructed_abide_images,
                 self._hausdorff_distance_gauge_on_reconstructed_abide_images))
            for dataset, dice_gauge, hausdorff_gauge in named_dataset_gauges:
                if dataset in dice_per_dataset:
                    # np.array copies, so each gauge gets its own array.
                    dice_gauge.update(np.array(dice_per_dataset[dataset]))
                    hausdorff_gauge.update(hausdorff_per_dataset[dataset])
                else:
                    # Datasets that were not reconstructed still receive a zero update.
                    dice_gauge.update(np.array([0.0, 0.0, 0.0]))
                    hausdorff_gauge.update(np.array([0.0, 0.0, 0.0]))

            # NOTE(review): the two-dataset branch assumes the configured datasets are iSEG and
            # MRBrainS; any other pair raises KeyError here.
            if len(img_input) == 3:
                self.custom_variables["Reconstructed Images Histograms"] = cv2.imread(
                    construct_triple_histrogram(img_norm["iSEG"], img_input["iSEG"],
                                                img_norm["MRBrainS"], img_input["MRBrainS"],
                                                img_norm["ABIDE"], img_input["ABIDE"])).transpose((2, 0, 1))
            elif len(img_input) == 2:
                self.custom_variables["Reconstructed Images Histograms"] = cv2.imread(
                    construct_double_histrogram(img_norm["iSEG"], img_input["iSEG"],
                                                img_norm["MRBrainS"], img_input["MRBrainS"])).transpose((2, 0, 1))
            elif len(img_input) == 1:
                only_dataset = list(dataset_names)[0]
                self.custom_variables["Reconstructed Images Histograms"] = cv2.imread(
                    construct_single_histogram(img_norm[only_dataset], img_input[only_dataset])).transpose((2, 0, 1))

        # Datasets that are not configured get blank (224, 192) placeholder images.
        placeholder_templates = ("Reconstructed Normalized {} Image",
                                 "Reconstructed Segmented {} Image",
                                 "Reconstructed Ground Truth {} Image",
                                 "Reconstructed Input {} Image",
                                 "Reconstructed Initial Noise {} Image",
                                 "Reconstructed Noise {} After Normalization")
        for dataset in ("ABIDE", "iSEG", "MRBrainS"):
            if dataset not in dataset_names:
                for template in placeholder_templates:
                    self.custom_variables[template.format(dataset)] = np.zeros((224, 192))

        self.custom_variables["Runtime"] = to_html_time(timedelta(seconds=time.time() - self._start_time))

        # Confusion matrices are shown left-right flipped; empty gauges become a 4x4 zero matrix.
        for variable_name, confusion_gauge in (("Confusion Matrix", self._general_confusion_matrix_gauge),
                                               ("iSEG Confusion Matrix", self._iSEG_confusion_matrix_gauge),
                                               ("MRBrainS Confusion Matrix", self._MRBrainS_confusion_matrix_gauge),
                                               ("ABIDE Confusion Matrix", self._ABIDE_confusion_matrix_gauge)):
            if confusion_gauge._num_examples != 0:
                self.custom_variables[variable_name] = np.array(
                    np.fliplr(confusion_gauge.compute().cpu().detach().numpy()))
            else:
                self.custom_variables[variable_name] = np.zeros((4, 4))

        self.custom_variables["Metric Table"] = to_html(["CSF", "Grey Matter", "White Matter"],
                                                        ["DSC", "HD"],
                                                        [computed_or_zeros(self._class_dice_gauge_on_patches),
                                                         computed_or_zeros(self._class_hausdorff_distance_gauge)])

        self.custom_variables["Dice score per class per epoch"] = computed_or_zeros(
            self._class_dice_gauge_on_patches)
        self.custom_variables["Dice score per class per epoch on reconstructed image"] = computed_or_zeros(
            self._class_dice_gauge_on_reconstructed_images)
        self.custom_variables["Dice score per class per epoch on reconstructed iSEG image"] = computed_or_zeros(
            self._class_dice_gauge_on_reconstructed_iseg_images)
        self.custom_variables["Dice score per class per epoch on reconstructed MRBrainS image"] = computed_or_zeros(
            self._class_dice_gauge_on_reconstructed_mrbrains_images)
        self.custom_variables["Dice score per class per epoch on reconstructed ABIDE image"] = computed_or_zeros(
            self._class_dice_gauge_on_reconstructed_abide_images)
        self.custom_variables[
            "Hausdorff Distance per class per epoch on reconstructed iSEG image"] = computed_or_zeros(
            self._hausdorff_distance_gauge_on_reconstructed_iseg_images)
        self.custom_variables[
            "Hausdorff Distance per class per epoch on reconstructed MRBrainS image"] = computed_or_zeros(
            self._hausdorff_distance_gauge_on_reconstructed_mrbrains_images)
        self.custom_variables[
            "Hausdorff Distance per class per epoch on reconstructed ABIDE image"] = computed_or_zeros(
            self._hausdorff_distance_gauge_on_reconstructed_abide_images)

        # The per-dataset table is only rebuilt when the validation Dice improves; otherwise the
        # previously published table is shown again.
        if self._valid_dice_gauge.compute() > self._previous_mean_dice:
            new_table = to_html_per_dataset(
                ["CSF", "Grey Matter", "White Matter"],
                ["DSC", "HD"],
                [
                    [computed_or_zeros(self._iSEG_dice_gauge),
                     computed_or_zeros(self._iSEG_hausdorff_gauge)],
                    [computed_or_zeros(self._MRBrainS_dice_gauge),
                     computed_or_zeros(self._MRBrainS_hausdorff_gauge)],
                    [computed_or_zeros(self._ABIDE_dice_gauge),
                     computed_or_zeros(self._ABIDE_hausdorff_gauge)]],
                ["iSEG", "MRBrainS", "ABIDE"])

            self.custom_variables["Per-Dataset Metric Table"] = new_table
            self._previous_mean_dice = self._valid_dice_gauge.compute()
            self._previous_per_dataset_table = new_table
        else:
            self.custom_variables["Per-Dataset Metric Table"] = self._previous_per_dataset_table
        self._valid_dice_gauge.reset()

        # Each JS entry is guarded by its own gauge's update state.
        self.custom_variables["Jensen-Shannon Table"] = to_html_JS(["Input data", "Generated Data"],
                                                                   ["JS Divergence"],
                                                                   [computed_or_zeros(self._js_div_inputs_gauge, 1),
                                                                    computed_or_zeros(self._js_div_gen_gauge, 1)])
        self.custom_variables["Jensen-Shannon Divergence"] = [
            self._js_div_inputs_gauge.compute(),
            self._js_div_gen_gauge.compute()]
        self.custom_variables["Mean Hausdorff Distance"] = [
            self._class_hausdorff_distance_gauge.compute().mean()
            if self._class_hausdorff_distance_gauge.has_been_updated() else np.array([0.0])]
        self.custom_variables[
            "Per Dataset Mean Hausdorff Distance"] = self._per_dataset_hausdorff_distance_gauge.compute()