def test_should_compute_jensen_shannon_divergence_of_same_distribution(
        self):
    """A distribution compared against copies of itself has zero JS divergence.

    Three identical copies of ``PROB_DIST1`` are stacked into one batch and
    ``F.js_div`` is expected to return ``tensor([0.0])`` for that batch.
    """
    same_distribution = [self.PROB_DIST1] * 3
    batch = torch.tensor([same_distribution])
    expected = torch.tensor([0.0])

    # NOTE(review): this is an exact floating-point comparison; if js_div
    # averages the distributions internally, rounding could make it flaky.
    divergence = F.js_div(batch)
    assert_that(divergence, equal_to(expected))
Beispiel #2
0
    def test_step(self, inputs, target):
        """Run one evaluation step and update the test-phase gauges.

        Depending on ``_should_activate_autoencoder()`` and
        ``_should_activate_segmentation()``, the generator and segmenter are
        evaluated on ``inputs[NON_AUGMENTED_INPUTS]`` and/or on the generator
        output for ``inputs[AUGMENTED_INPUTS]``.

        When segmentation is active, the following gauges are updated:

        * per-dataset (iSEG, MRBrainS, ABIDE) Dice, Hausdorff and
          confusion-matrix gauges;
        * the global Hausdorff and confusion-matrix gauges;
        * two Jensen-Shannon divergence gauges, computed over intensity
          histograms of the augmented inputs and of the generator output.

        Every 100 test steps the histograms and image plots are refreshed.

        Args:
            inputs: Indexable collection of input batches, addressed with
                ``AUGMENTED_INPUTS`` / ``NON_AUGMENTED_INPUTS``. The
                augmented batch is indexed up to ``shape[4]``, so it is
                assumed to be at least 5-D (presumably batch, c, d, h, w).
            target: Indexable collection holding the ground-truth masks
                (``IMAGE_TARGET``) and per-sample dataset ids
                (``DATASET_ID``).
        """
        # Bound up front so the periodic plotting below cannot hit an
        # UnboundLocalError when neither branch is active.
        gen_pred = None
        seg_pred = None

        if self._should_activate_autoencoder():
            gen_pred = self._test_g(self._model_trainers[GENERATOR],
                                    inputs[NON_AUGMENTED_INPUTS])

            seg_pred, _ = self._test_s(self._model_trainers[SEGMENTER],
                                       inputs[NON_AUGMENTED_INPUTS],
                                       target[IMAGE_TARGET],
                                       self._class_dice_gauge_on_patches)

        if self._should_activate_segmentation():
            gen_pred = self._test_g(self._model_trainers[GENERATOR],
                                    inputs[AUGMENTED_INPUTS])

            seg_pred, _ = self._test_s(self._model_trainers[SEGMENTER],
                                       gen_pred, target[IMAGE_TARGET],
                                       self._class_dice_gauge_on_patches)

            self._update_dataset_segmentation_gauges(
                seg_pred, target, ISEG_ID, self._iSEG_dice_gauge,
                self._iSEG_hausdorff_gauge,
                self._iSEG_confusion_matrix_gauge,
                zero_fill_if_absent=True)
            self._update_dataset_segmentation_gauges(
                seg_pred, target, MRBRAINS_ID, self._MRBrainS_dice_gauge,
                self._MRBrainS_hausdorff_gauge,
                self._MRBrainS_confusion_matrix_gauge,
                zero_fill_if_absent=True)
            # NOTE(review): unlike iSEG/MRBrainS, the ABIDE gauges are not
            # zero-filled when a batch has no ABIDE sample. This asymmetry is
            # kept as-is; confirm that it is intended.
            self._update_dataset_segmentation_gauges(
                seg_pred, target, ABIDE_ID, self._ABIDE_dice_gauge,
                self._ABIDE_hausdorff_gauge,
                self._ABIDE_confusion_matrix_gauge,
                zero_fill_if_absent=False)

            # Whole-batch metrics: one-hot hard predictions vs. labels.
            probabilities = torch.nn.functional.softmax(seg_pred, dim=1)
            predicted_onehot = to_onehot(torch.argmax(probabilities, dim=1),
                                         num_classes=4)
            labels = torch.squeeze(target[IMAGE_TARGET].long(), dim=1)

            # [-3:] keeps the last three per-class distances (presumably
            # dropping the background class; TODO confirm class order).
            self._class_hausdorff_distance_gauge.update(
                mean_hausdorff_distance(predicted_onehot,
                                        to_onehot(labels,
                                                  num_classes=4))[-3:])
            self._general_confusion_matrix_gauge.update(
                (predicted_onehot, labels))

            augmented = inputs[AUGMENTED_INPUTS]
            c, d, h, w = (augmented.shape[1], augmented.shape[2],
                          augmented.shape[3], augmented.shape[4])
            num_voxels = c * d * h * w

            # Each histogram iterates the batch of the tensor it is built
            # from. Previously the input histogram used inputs[0]'s batch
            # size while indexing inputs[AUGMENTED_INPUTS].
            hist_inputs = self._softmax_intensity_histograms(augmented,
                                                             num_voxels)
            hist_gen = self._softmax_intensity_histograms(gen_pred,
                                                          num_voxels)

            self._js_div_inputs_gauge.update(js_div(hist_inputs).item())
            self._js_div_gen_gauge.update(js_div(hist_gen).item())

        # Skip plotting when no branch produced predictions. Previously this
        # raised UnboundLocalError in that situation.
        if (self.current_test_step % 100 == 0 and gen_pred is not None
                and seg_pred is not None):
            self._update_histograms(inputs[NON_AUGMENTED_INPUTS], target,
                                    gen_pred)
            self._update_image_plots(
                self.phase, inputs[NON_AUGMENTED_INPUTS].cpu().detach(),
                gen_pred.cpu().detach(),
                seg_pred.cpu().detach(), target[IMAGE_TARGET].cpu().detach(),
                target[DATASET_ID].cpu().detach())

    def _update_dataset_segmentation_gauges(self, seg_pred, target,
                                            dataset_id, dice_gauge,
                                            hausdorff_gauge,
                                            confusion_matrix_gauge,
                                            zero_fill_if_absent):
        """Update one dataset's Dice, Hausdorff and confusion-matrix gauges.

        Only the batch samples whose ``target[DATASET_ID]`` equals
        ``dataset_id`` are scored. If there are none, the Dice and Hausdorff
        gauges each receive ``np.zeros((3, ))`` when ``zero_fill_if_absent``
        is true. Otherwise nothing is updated.
        """
        # Compute the selection mask once instead of once per metric.
        selection = torch.where(target[DATASET_ID] == dataset_id)
        dataset_pred = seg_pred[selection]

        if dataset_pred.shape[0] == 0:
            if zero_fill_if_absent:
                dice_gauge.update(np.zeros((3, )))
                hausdorff_gauge.update(np.zeros((3, )))
            return

        probabilities = torch.nn.functional.softmax(dataset_pred, dim=1)
        predicted_onehot = to_onehot(torch.argmax(probabilities, dim=1),
                                     num_classes=4)
        labels = torch.squeeze(target[IMAGE_TARGET][selection],
                               dim=1).long()

        dice_gauge.update(
            np.array(self._model_trainers[SEGMENTER].compute_metrics(
                probabilities, labels)["Dice"].numpy()))
        # [-3:] keeps the last three per-class distances (presumably
        # dropping the background class; TODO confirm class order).
        hausdorff_gauge.update(
            mean_hausdorff_distance(predicted_onehot,
                                    to_onehot(labels, num_classes=4))[-3:])
        confusion_matrix_gauge.update((predicted_onehot, labels))

    @staticmethod
    def _softmax_intensity_histograms(volumes, num_voxels):
        """Build softmax-normalised 256-bin intensity histograms per sample.

        Each sample of ``volumes`` is flattened to ``num_voxels`` values and
        histogrammed over [0, 1] with 256 bins. The counts are divided by
        ``num_voxels``, then a softmax is applied over the bin axis.

        Returns:
            A tensor of shape ``(1, batch, 256)``.
        """
        histograms = torch.cat([
            torch.histc(volume.view(1, num_voxels), bins=256, min=0,
                        max=1).unsqueeze(0)
            for volume in volumes
        ]).unsqueeze(0)
        histograms = histograms / num_voxels
        return torch.nn.Softmax(dim=2)(histograms)
Beispiel #3
0
def compute_js_divergence(samples):
    """
    Return the Jensen-Shannon divergence of ``samples`` as computed by ``js_div``.

    This is a thin wrapper: ``samples`` is forwarded unchanged to ``js_div``
    and its result is returned as-is.

    NOTE(review): ``js_div`` defines the expected layout of ``samples``
    (presumably a stack of probability distributions) and whether the
    divergence is restricted to their shared support. Confirm against its
    implementation.
    """
    divergence = js_div(samples)
    return divergence