コード例 #1
0
    def blocks(self):
        """Predict on a synthetic image made of four constant quadrants.

        Loads slice 78 of the second volume in ``self.vol_path`` and builds an
        image of the same shape, split into a 2x2 grid.  Quadrant ``(i, j)`` is
        filled, per modality, with the mean intensity of the pixels whose
        ground-truth label is ``2 * i + j``.  The first modality of that image
        and the model's arg-max prediction for it are then shown with
        matplotlib.

        Assumes ``gt`` is a 2-D label map aligned with the first two axes of
        ``test_image`` (height, width, modality) -- TODO confirm against
        ``load_vol_brats``.

        Returns:
            None. Results are only displayed.
        """
        test_image, gt = load_vol_brats(self.vol_path[1], slicen=78, pad=8)

        height, width = test_image.shape[:2]

        # Zero-initialised (not np.empty) so a quadrant that is skipped below
        # holds a defined value instead of uninitialised memory.
        intervention_image = np.zeros(test_image.shape)

        # The second span of each axis runs to the edge, so odd-sized images
        # are fully covered as well.
        row_spans = (slice(0, height // 2), slice(height // 2, None))
        col_spans = (slice(0, width // 2), slice(width // 2, None))

        for i, rows in enumerate(row_spans):
            for j, cols in enumerate(col_spans):
                label = 2 * i + j
                try:
                    label_pixels = test_image[gt == label]
                    if label_pixels.shape[0] == 0:
                        # The mean of an empty selection is NaN, which would
                        # poison the model input; leave the quadrant at zero.
                        print("blocks: no pixels with label {}; "
                              "quadrant left at zero".format(label))
                        continue
                    # One mean per modality, broadcast over the quadrant.
                    intervention_image[rows, cols] = np.mean(label_pixels,
                                                             axis=0)
                except (IndexError, ValueError) as e:
                    # Best effort, as before: report a mask/shape mismatch
                    # and carry on with the remaining quadrants.
                    print(e)

        # Use the instance's model, like every other call in this class,
        # rather than relying on a module-level ``model`` being defined.
        prediction_intervention = self.model.predict(
            intervention_image[None, ...])
        plt.imshow(intervention_image[:, :, 0])
        plt.colorbar()
        plt.show()
        plt.imshow(np.argmax(prediction_intervention, axis=-1)[0],
                   vmin=0,
                   vmax=3)
        plt.colorbar()
        plt.show()
コード例 #2
0
    def mean_swap(self,
                  test_path,
                  plot=True,
                  save_path='home/brats/parth/BioExp/results/RCT'):
        """Measure how swapping class mean intensities changes per-class dice.

        For every volume matched by ``test_path`` (slice 78), and for every
        ordered pair of labels ``(i, j)``, the pixels labelled ``j`` in the
        ground truth are shifted by ``mean(class i) - mean(class j)``.  The
        model is re-run on the shifted image and the drop in the dice score
        of label ``j`` is accumulated in cell ``[i, j]``.  The accumulated
        matrix is averaged over all volumes and slices, printed, and saved as
        ``mean_swap_all_images.npy`` inside ``save_path``.

        The number of classes is taken from the distinct labels the model
        predicts on the first matched volume.

        Args:
            test_path: glob pattern selecting the volumes to evaluate.
            plot: when truthy, draw the intervened prediction for every
                ``(i, j)`` pair in an ``n_classes x n_classes`` grid and show
                it at the end.
            save_path: directory for the result file; created if missing.
                NOTE(review): the default has no leading '/', so it resolves
                relative to the current working directory -- confirm this is
                intended.

        Returns:
            None.
        """
        vol_path = glob(test_path)
        test_image, gt = load_vol_brats(vol_path[0], slicen=78, pad=8)

        prediction = np.argmax(self.model.predict(test_image[None, ...]),
                               axis=-1)[0]
        n_classes = len(np.unique(prediction))
        corr = np.zeros((n_classes, n_classes))
        slices = [78]
        class_dict = {0: 'bg', 1: 'core', 2: 'edema', 3: 'enhancing'}

        if plot:
            # Only allocate the figure when something will be drawn on it.
            plt.figure(figsize=(20, 20))

        for vol in vol_path:
            for slicen in slices:
                test_image, gt = load_vol_brats(vol, slicen=slicen, pad=8)

                prediction = np.argmax(
                    self.model.predict(test_image[None, ...]), axis=-1)[0]
                print("Original Dice Whole:", dice_whole_coef(prediction, gt))

                # Masks, class means and baseline dice depend only on the
                # label, so compute them once per slice instead of once per
                # (i, j) pair.
                masks = [gt == label for label in range(n_classes)]
                class_means = [
                    np.mean(test_image[mask], axis=0) for mask in masks
                ]
                baseline_dice = [
                    dice_label_coef(prediction, gt, (label, ))
                    for label in range(n_classes)
                ]

                for i in range(n_classes):
                    for j in range(n_classes):
                        # Move class j's intensities onto class i's mean.
                        test_image_intervention = np.copy(test_image)
                        test_image_intervention[masks[j]] += (
                            class_means[i] - class_means[j])
                        prediction_intervention = np.argmax(
                            self.model.predict(
                                test_image_intervention[None, ...]),
                            axis=-1)[0]

                        if plot:
                            # Row-major position in the n_classes x n_classes
                            # grid (the stride must be n_classes, not 4).
                            plt.subplot(n_classes, n_classes,
                                        1 + n_classes * i + j)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("{} --> {}".format(
                                class_dict.get(j, j), class_dict.get(i, i)))
                            plt.imshow(prediction_intervention,
                                       cmap=plt.cm.RdBu,
                                       vmin=0,
                                       vmax=3)
                            plt.colorbar()

                        corr[i, j] += baseline_dice[j] - dice_label_coef(
                            prediction_intervention, gt, (j, ))

        np.set_printoptions(precision=2, suppress=True)

        intervention_importance = corr / (len(vol_path) * len(slices))
        print(intervention_importance)
        os.makedirs(save_path, exist_ok=True)
        np.save(os.path.join(save_path, 'mean_swap_all_images.npy'),
                intervention_importance)
        if plot:
            plt.show()
コード例 #3
0
 def __init__(self, model, test_path):
     """Keep the model and load one reference slice from the matched volumes.

     Args:
         model: object stored as ``self.model``.
         test_path: glob pattern; its matches are stored as
             ``self.vol_path``.

     Slice 78 of the fourth match (index 3) is loaded with ``pad=0`` and
     stored as ``self.test_image`` / ``self.gt``.  Fewer than four matches
     raises ``IndexError``.
     """
     self.model = model

     matches = glob(test_path)
     self.vol_path = matches

     reference_volume = matches[3]
     image, labels = load_vol_brats(reference_volume, slicen=78, pad=0)
     self.test_image = image
     self.gt = labels
コード例 #4
0
        })

    # Load weights into the model from a checkpoint file (absolute,
    # machine-specific path).
    model.load_weights(
        '/home/parth/Interpretable_ML/saved_models/SimUnet/SimUnet.40_0.060.hdf5'
    )

    # Intervention helper built from the model and the validation glob pattern.
    I = intervention(model, '/media/parth/DATA/datasets/brats_2018/val/**')

    # The same pattern is expanded again here to iterate over the volumes.
    test_path = glob('/media/parth/DATA/datasets/brats_2018/val/**')

    # Collects the results of I.adverserial that pass the filter below.
    # Not reset between epsilon values.
    average_change = []

    for epsilon in [0.7]:  #, 0.07, 0.21, 0.7]:
        for i in tqdm(range(len(test_path))):

            # Slice 78 of each volume, without padding.
            test_image, gt = load_vol_brats(test_path[i], slicen=78, pad=0)
            # Only use slices whose label map holds exactly four distinct
            # values.
            if len(np.unique(gt)) == 4:
                print(len(np.unique(gt)))
                # I.blocks('/home/parth/Interpretable_ML/BioExp/sample_vol/brats/**')
                adv = I.adverserial(epsilon=epsilon,
                                    mode='gradient',
                                    test_image=test_image,
                                    gt=gt)
                # NOTE(review): adv is indexable; the meaning of element 1 is
                # not visible here -- confirm against I.adverserial.
                if adv[1] > 0:
                    average_change.append(adv)
                    print(adv)

        # Element-wise mean over all collected results, printed once per
        # epsilon.
        print(np.mean(average_change, axis=0))

    # I.generate_random_classification(mode='swap')
    # I.mean_swap(plot = False)