示例#1
0
def main(model_path, exp_config, do_plots=False):
    """Score a trained segvae model on the test split with GED and NCC.

    For every test image the network is fed the same image 50 times (one
    batch), giving 50 sampled segmentations from the ``'best_ged'``
    checkpoint. These are compared against all annotations of that image.
    Mean and standard deviation of both scores are logged, and the per-image
    arrays are saved with ``np.savez`` into ``model_path``.

    Args:
        model_path: directory holding the checkpoint; also the output directory.
        exp_config: experiment configuration (``data_identifier``,
            ``image_size``, ``nlabels`` are read here).
        do_plots: unused; kept for interface compatibility.
    """
    num_draws = 50
    checkpoint_kind = 'best_ged'

    # Model and data
    model = segvae(exp_config=exp_config)
    model.load_weights(model_path, type=checkpoint_kind)

    load_dataset = data_switch(exp_config.data_identifier)
    dataset = load_dataset(exp_config)

    num_images = dataset.test.images.shape[0]
    n_labels = exp_config.nlabels
    single_image_shape = [1] + list(exp_config.image_size)

    ged_scores = []
    ncc_scores = []

    for idx in range(num_images):
        if idx % 10 == 0:
            logging.info("Progress: %d" % idx)

        image = dataset.test.images[idx, ...].reshape(single_image_shape)
        annotations = dataset.test.labels[idx, ...]

        # Repeat the image along the batch axis so one run yields num_draws samples.
        batch = np.tile(image, [num_draws, 1, 1, 1])
        softmax_samples = model.sess.run(
            model.s_out_eval_sm,
            feed_dict={model.training_pl: False, model.x_inp: batch})
        label_samples = np.argmax(softmax_samples, axis=-1)

        gt_maps = annotations.transpose((2, 0, 1))  # num gts x X x Y
        gt_onehot = utils.convert_batch_to_onehot(gt_maps, n_labels)  # num gts x X x Y x nlabels

        # GED ignores label 0 (label_range starts at 1).
        ged_scores.append(utils.generalised_energy_distance(
            label_samples, gt_maps,
            nlabels=n_labels - 1, label_range=range(1, n_labels)))
        ncc_scores.append(utils.variance_ncc_dist(softmax_samples, gt_onehot))

    ged_arr = np.asarray(ged_scores)
    ncc_arr = np.asarray(ncc_scores)

    for title, scores in (('-- GED: --', ged_arr), ('-- NCC: --', ncc_arr)):
        logging.info(title)
        logging.info(np.mean(scores))
        logging.info(np.std(scores))

    np.savez(os.path.join(model_path, 'ged%s_%s.npz' % (str(num_draws), checkpoint_kind)), ged_arr)
    np.savez(os.path.join(model_path, 'ncc%s_%s.npz' % (str(num_draws), checkpoint_kind)), ncc_arr)
示例#2
0
def test_ncc(lidc_data):
    """NCC of a ground-truth annotation set compared with itself must be 1."""
    idx = 99
    annotations = lidc_data.test.labels[idx, ...]

    # The image tensor is built (exercising image loading) but not checked further.
    image = lidc_data.test.images[idx, ...]
    image_tensor = torch.tensor(image, dtype=torch.float32).to('cpu')  # noqa: F841

    assert annotations.shape == (128, 128, 4)
    # HWC -> CHW: one 128x128 mask per annotator.
    masks = torch.tensor(annotations, dtype=torch.float32).to('cpu').permute(2, 0, 1)
    assert masks.shape == (4, 128, 128)

    onehot = utils.convert_batch_to_onehot(masks.unsqueeze(dim=1), nlabels=2)
    ncc = utils.variance_ncc_dist(onehot, onehot)

    assert math.isclose(ncc[0], 1.0)
示例#3
0
    def _do_validation(self, data):
        """Checkpoint the model, log batch losses and evaluate the validation set.

        Steps, in order:
          1. Save a regular checkpoint (``model.ckpt``) at the current global step.
          2. Evaluate every loss in ``self.loss_dict`` on one validation batch and
             one training batch and log them side by side.
          3. For each validation image, feed it ``exp_config.validation_samples``
             times and compute:
             - per-label Dice of the mean softmax against one randomly chosen
               annotator;
             - the total loss (``self.loss_tot``);
             - GED and NCC against all annotations.
          4. Save ``model_best_{dice,loss,ged,ncc}.ckpt`` whenever the respective
             score improves or ties.
          5. Write validation and training summaries via ``self.summary_writer``.

        Args:
            data: dataset with ``validation`` / ``train`` splits providing
                ``next_batch(batch_size)``. ``validation.images`` and
                ``validation.labels`` are indexed per image.
        """
        global_step = self.sess.run(self.global_step) - 1

        checkpoint_file = os.path.join(self.log_dir, 'model.ckpt')
        self.saver.save(self.sess, checkpoint_file, global_step=global_step)

        # --- Losses on one validation batch and one training batch ---
        val_x, val_s = data.validation.next_batch(self.exp_config.batch_size)
        val_losses_out = self.sess.run(list(self.loss_dict.values()),
                                       feed_dict={
                                           self.x_inp: val_x,
                                           self.s_inp: val_s,
                                           self.training_pl: False
                                       })

        train_x, train_s = data.train.next_batch(self.exp_config.batch_size)
        train_losses_out = self.sess.run(list(self.loss_dict.values()),
                                         feed_dict={
                                             self.x_inp: train_x,
                                             self.s_inp: train_s,
                                             self.training_pl: False
                                         })

        # sess.run returns the losses in the same order as loss_dict, so the
        # i-th output belongs to the i-th key.
        logging.info('----- Step: %d ------' % global_step)
        logging.info('BATCH VALIDATION:')
        for ii, loss_name in enumerate(self.loss_dict.keys()):
            logging.info('%s | training: %f | validation: %f' %
                         (loss_name, train_losses_out[ii], val_losses_out[ii]))

        # --- Full pass over the validation set ---
        start_dice_val = time.time()

        dice_list = []
        elbo_list = []
        ged_list = []
        ncc_list = []

        N = data.validation.images.shape[0]

        for ii in range(N):

            x = data.validation.images[ii, ...].reshape(
                [1] + list(self.exp_config.image_size))
            s_gt_arr = data.validation.labels[ii, ...]
            # Dice reference: one randomly chosen annotator.
            s = s_gt_arr[:, :,
                         np.random.choice(self.exp_config.annotator_range)]

            # Repeat image and reference mask along the batch axis to draw
            # several samples in one run.
            x_b = np.tile(x, [self.exp_config.validation_samples, 1, 1, 1])
            s_b = np.tile(s, [self.exp_config.validation_samples, 1, 1])

            feed_dict = {}
            feed_dict[self.training_pl] = False
            feed_dict[self.x_inp] = x_b
            feed_dict[self.s_inp] = s_b

            s_pred_sm_arr, elbo = self.sess.run(
                [self.s_out_eval_sm, self.loss_tot], feed_dict=feed_dict)

            s_pred_sm_mean_ = np.mean(s_pred_sm_arr, axis=0)

            s_pred_arr = np.argmax(s_pred_sm_arr, axis=-1)
            s_gt_arr_r = s_gt_arr.transpose((2, 0, 1))  # num gts x X x Y

            s_gt_arr_r_sm = utils.convert_batch_to_onehot(
                s_gt_arr_r,
                self.exp_config.nlabels)  # num gts x X x Y x nlabels

            # GED ignores label 0 (label_range starts at 1).
            ged = utils.generalised_energy_distance(
                s_pred_arr,
                s_gt_arr_r,
                nlabels=self.exp_config.nlabels - 1,
                label_range=range(1, self.exp_config.nlabels))

            ncc = utils.variance_ncc_dist(s_pred_sm_arr, s_gt_arr_r_sm)

            # Point prediction: argmax of the mean softmax over samples.
            s_ = np.argmax(s_pred_sm_mean_, axis=-1)

            per_lbl_dice = []
            for lbl in range(self.exp_config.nlabels):
                binary_pred = (s_ == lbl) * 1
                binary_gt = (s == lbl) * 1
                pred_sum = np.sum(binary_pred)
                gt_sum = np.sum(binary_gt)

                if pred_sum == 0 and gt_sum == 0:
                    # Label absent from both: count as perfect agreement.
                    per_lbl_dice.append(1)
                elif pred_sum == 0 or gt_sum == 0:
                    # Label present in exactly one of prediction / reference.
                    per_lbl_dice.append(0)
                else:
                    per_lbl_dice.append(dc(binary_pred, binary_gt))

            dice_list.append(per_lbl_dice)
            elbo_list.append(elbo)
            ged_list.append(ged)
            ncc_list.append(ncc)

        dice_arr = np.asarray(dice_list)
        per_structure_dice = dice_arr.mean(axis=0)
        # NOTE(review): this averages over ALL labels, including label 0, even
        # though it is logged as "foreground" dice. Confirm whether label 0 is
        # background and should be excluded.
        mean_dice = np.mean(per_structure_dice)

        avg_dice = np.mean(dice_arr)
        avg_elbo = utils.list_mean(elbo_list)
        avg_ged = utils.list_mean(ged_list)
        avg_ncc = utils.list_mean(ncc_list)

        logging.info('FULL VALIDATION (%d images):' % N)
        logging.info(' - Mean foreground dice: %.4f' % mean_dice)
        logging.info(' - Mean (neg.) ELBO: %.4f' % avg_elbo)
        logging.info(' - Mean GED: %.4f' % avg_ged)
        # avg_ncc may carry singleton dimensions (it is squeezed for the summary
        # below too); squeeze so %-formatting sees a scalar.
        logging.info(' - Mean NCC: %.4f' % np.squeeze(avg_ncc))

        logging.info('@ Running through validation set took: %.2f secs' %
                     (time.time() - start_dice_val))

        # --- Best-model checkpoints (ties also trigger a save) ---
        if mean_dice >= self.best_dice:
            self.best_dice = mean_dice
            logging.info('New best validation Dice! (%.3f)' % self.best_dice)
            best_file = os.path.join(self.log_dir, 'model_best_dice.ckpt')
            self.saver_best_dice.save(self.sess,
                                      best_file,
                                      global_step=global_step)

        if avg_elbo <= self.best_loss:
            self.best_loss = avg_elbo
            logging.info('New best validation loss! (%.3f)' % self.best_loss)
            best_file = os.path.join(self.log_dir, 'model_best_loss.ckpt')
            self.saver_best_loss.save(self.sess,
                                      best_file,
                                      global_step=global_step)

        if avg_ged <= self.best_ged:
            self.best_ged = avg_ged
            logging.info('New best GED score! (%.3f)' % self.best_ged)
            best_file = os.path.join(self.log_dir, 'model_best_ged.ckpt')
            self.saver_best_ged.save(self.sess,
                                     best_file,
                                     global_step=global_step)

        if avg_ncc >= self.best_ncc:
            self.best_ncc = avg_ncc
            logging.info('New best NCC score! (%.3f)' % self.best_ncc)
            best_file = os.path.join(self.log_dir, 'model_best_ncc.ckpt')
            self.saver_best_ncc.save(self.sess,
                                     best_file,
                                     global_step=global_step)

        # --- Validation summary ---
        # Prior samples for the generated-image summaries.
        z_prior_list = self.generate_prior_samples(x_in=val_x)
        val_summary_feed_dict = {
            i: d
            for i, d in zip(self.z_list_gen, z_prior_list)
        }
        val_summary_feed_dict[self.x_for_gen] = val_x

        # val_losses_out is ordered like loss_dict, so zip over its keys.
        for loss_key, loss_val in zip(self.loss_dict.keys(), val_losses_out):
            loss_pl = self.validation_loss_pl_dict[loss_key]
            val_summary_feed_dict[loss_pl] = loss_val

        val_summary_feed_dict[self.val_tot_dice_score] = avg_dice
        val_summary_feed_dict[self.val_mean_dice_score] = mean_dice
        val_summary_feed_dict[self.val_elbo] = avg_elbo
        val_summary_feed_dict[self.val_ged] = avg_ged
        val_summary_feed_dict[self.val_ncc] = np.squeeze(avg_ncc)

        for ii in range(self.exp_config.nlabels):
            val_summary_feed_dict[
                self.val_lbl_dice_scores[ii]] = per_structure_dice[ii]

        val_summary_feed_dict[self.x_inp] = val_x
        val_summary_feed_dict[self.s_inp] = val_s
        val_summary_feed_dict[self.training_pl] = False

        val_summary_msg = self.sess.run(self.val_summary,
                                        feed_dict=val_summary_feed_dict)
        self.summary_writer.add_summary(val_summary_msg, global_step)

        # --- Training summary (train_losses_out is ordered like loss_dict) ---
        train_summary_feed_dict = {}
        for loss_key, loss_val in zip(self.loss_dict.keys(), train_losses_out):
            loss_pl = self.train_loss_pl_dict[loss_key]
            train_summary_feed_dict[loss_pl] = loss_val
        train_summary_feed_dict[self.training_pl] = False

        train_summary_msg = self.sess.run(self.train_summary,
                                          feed_dict=train_summary_feed_dict)
        self.summary_writer.add_summary(train_summary_msg, global_step)
示例#4
0
    def test(self, data, sys_config, num_runs=10, n_samples=10):
        """Evaluate the ``<experiment_name>_best_loss.pth`` checkpoint on the test split.

        The test set is evaluated ``num_runs`` times. Results vary between runs
        because the network is sampled and the Dice reference annotator is drawn
        at random. After each run, that run's own metrics are logged. At the end,
        the means over all runs are logged and stored on ``self`` (``avg_dice``,
        ``foreground_dice``, ``avg_ged``, ``avg_ncc``).

        After every run, the per-image GED / NCC values of all runs so far are
        written to ``ged<n_samples>_<checkpoint>_2.npz`` and
        ``ncc<n_samples>_<checkpoint>_2.npz`` in the experiment log directory.

        Args:
            data: dataset exposing ``test.images`` (per-image HW arrays) and
                ``test.labels`` (per-image HWC arrays, one channel per annotator).
            sys_config: system config providing ``log_root``.
            num_runs: number of passes over the test set (default 10).
            n_samples: number of network samples drawn per image (default 10).

        Raises:
            ValueError: if ``num_runs`` is smaller than 1.
        """
        if num_runs < 1:
            raise ValueError('num_runs must be >= 1, got {}'.format(num_runs))

        self.net.eval()
        with torch.no_grad():

            model_selection = self.exp_config.experiment_name + '_best_loss.pth'
            self.logger.info('Testing {}'.format(model_selection))

            self.logger.info('Loading pretrained model {}'.format(model_selection))

            log_dir = os.path.join(
                sys_config.log_root,
                self.exp_config.log_dir_name,
                self.exp_config.experiment_name)
            model_path = os.path.join(log_dir, model_selection)

            if not os.path.exists(model_path):
                self.logger.info('The file {} does not exist. Aborting test function.'.format(model_path))
                return
            # map_location lets a checkpoint saved on another device (e.g. GPU)
            # be loaded onto self.device.
            self.net.load_state_dict(torch.load(model_path, map_location=self.device))

            num_images = data.test.images.shape[0]
            if num_images == 0:
                self.logger.info('The test set is empty. Aborting test function.')
                return

            n_classes = self.exp_config.n_classes

            # Values accumulated over all runs (written to disk / final means).
            all_ged, all_ncc, all_dice = [], [], []

            start_time = time.time()

            for run in range(num_runs):
                self.logger.info('Doing iteration {}'.format(run))
                run_start = time.time()

                # Reset per run so per-run statistics only cover this run.
                ged_list, dice_list, ncc_list, elbo_list = [], [], [], []

                for ii in range(num_images):

                    s_gt_arr = data.test.labels[ii, ...]

                    # from HW to NCHW
                    x_b = data.test.images[ii, ...]
                    patch = torch.tensor(x_b, dtype=torch.float32).to(self.device)
                    val_patch = patch.unsqueeze(dim=0).unsqueeze(dim=1)

                    # Dice reference: one randomly chosen annotator.
                    s_b = s_gt_arr[:, :, np.random.choice(self.exp_config.annotator_range)]
                    mask = torch.tensor(s_b, dtype=torch.float32).to(self.device)
                    val_mask = mask.unsqueeze(dim=0).unsqueeze(dim=1)
                    val_masks = torch.tensor(s_gt_arr, dtype=torch.float32).to(self.device)  # HWC
                    val_masks = val_masks.transpose(0, 2).transpose(1, 2)  # CHW

                    patch_arrangement = val_patch.repeat((n_samples, 1, 1, 1))
                    mask_arrangement = val_mask.repeat((n_samples, 1, 1, 1))

                    self.mask = mask_arrangement
                    self.patch = patch_arrangement

                    # Evaluation forward pass (training=False), n_samples copies of the image.
                    s_out_eval_list = self.net.forward(patch_arrangement, mask_arrangement, training=False)
                    s_prediction_softmax_arrangement = self.net.accumulate_output(s_out_eval_list, use_softmax=True)

                    # Loss evaluated the same way as in validate(), so the logged
                    # ELBO refers to the test data rather than a stale value.
                    elbo_list.append(self.net.loss(mask_arrangement))

                    s_prediction_softmax_mean = torch.mean(s_prediction_softmax_arrangement, axis=0)
                    s_prediction_arrangement = torch.argmax(s_prediction_softmax_arrangement, dim=1)

                    ground_truth_arrangement = val_masks  # one mask per annotator, H, W
                    # GED ignores label 0 (label_range starts at 1).
                    ged = utils.generalised_energy_distance(s_prediction_arrangement, ground_truth_arrangement,
                                                            nlabels=n_classes - 1,
                                                            label_range=range(1, n_classes))

                    # num_gts, 1, H, W
                    s_gt_arr_r = val_masks.unsqueeze(dim=1)
                    ground_truth_arrangement_one_hot = utils.convert_batch_to_onehot(s_gt_arr_r,
                                                                                     nlabels=n_classes)
                    ncc = utils.variance_ncc_dist(s_prediction_softmax_arrangement, ground_truth_arrangement_one_hot)

                    s_ = torch.argmax(s_prediction_softmax_mean, dim=0)  # HW
                    s = val_mask.view(val_mask.shape[-2], val_mask.shape[-1])  # HW

                    per_lbl_dice = []
                    for lbl in range(n_classes):
                        binary_pred = (s_ == lbl) * 1
                        binary_gt = (s == lbl) * 1
                        pred_sum = torch.sum(binary_pred)
                        gt_sum = torch.sum(binary_gt)

                        if pred_sum == 0 and gt_sum == 0:
                            # Label absent from both: count as perfect agreement.
                            per_lbl_dice.append(1.0)
                        elif pred_sum == 0 or gt_sum == 0:
                            # Label present in exactly one of prediction / reference.
                            per_lbl_dice.append(0.0)
                        else:
                            per_lbl_dice.append(dc(binary_pred.detach().cpu().numpy(),
                                                   binary_gt.detach().cpu().numpy()))
                    dice_list.append(per_lbl_dice)

                    ged_list.append(ged)
                    ncc_list.append(ncc)

                    if ii % 100 == 0:
                        self.logger.info(' - Mean GED: %.4f' % torch.mean(torch.tensor(ged_list)))
                        self.logger.info(' - Mean NCC: %.4f' % torch.mean(torch.tensor(ncc_list)))

                all_ged.extend(ged_list)
                all_ncc.extend(ncc_list)
                all_dice.extend(dice_list)

                # Per-image scores of all runs so far (file is overwritten each run).
                np.savez(os.path.join(log_dir, 'ged%s_%s_2.npz' % (str(n_samples), model_selection)),
                         torch.tensor(all_ged).numpy())
                np.savez(os.path.join(log_dir, 'ncc%s_%s_2.npz' % (str(n_samples), model_selection)),
                         torch.tensor(all_ncc).numpy())

                # Statistics of this run only.
                dice_tensor = torch.tensor(dice_list)
                ged_tensor = torch.tensor(ged_list)
                ncc_tensor = torch.tensor(ncc_list)
                run_elbo = torch.mean(torch.tensor(elbo_list))

                self.logger.info('-- GED: --')
                self.logger.info(torch.mean(ged_tensor))
                self.logger.info(torch.std(ged_tensor))

                self.logger.info('-- NCC: --')
                self.logger.info(torch.mean(ncc_tensor))
                self.logger.info(torch.std(ncc_tensor))

                # Column 1 of the per-label Dice table (label 1).
                self.logger.info(' - Foreground dice: %.4f' % torch.mean(dice_tensor, dim=0)[1])
                self.logger.info(' - Mean (neg.) ELBO: %.4f' % run_elbo)
                self.logger.info(' - Mean GED: %.4f' % torch.mean(ged_tensor))
                self.logger.info(' - Mean NCC: %.4f' % torch.mean(ncc_tensor))

                self.logger.info('Testing took {} seconds'.format(time.time() - run_start))

            # Means over all runs. Every run covers the same images, so this
            # equals the mean of the per-run means.
            all_dice_tensor = torch.tensor(all_dice)
            self.avg_dice = torch.mean(all_dice_tensor)
            self.foreground_dice = torch.mean(all_dice_tensor, dim=0)[1]
            self.avg_ged = torch.mean(torch.tensor(all_ged))
            self.avg_ncc = torch.mean(torch.tensor(all_ncc))

            self.logger.info('Mean dice: {}'.format(self.avg_dice))
            self.logger.info('Mean ged: {}'.format(self.avg_ged))
            self.logger.info('Mean ncc: {}'.format(self.avg_ncc))
            self.logger.info('Testing took {} seconds in total'.format(time.time() - start_time))
示例#5
0
    def validate(self, data):
        """Evaluate the network on the validation split and keep the best checkpoints.

        A ``'validation_ckpt'`` checkpoint is saved first. Then, for each
        validation image, ``exp_config.validation_samples`` predictions are
        drawn and the following are computed:
          - the loss (logged as negative ELBO) with its KL and reconstruction parts;
          - GED and NCC against all annotations;
          - per-label Dice of the mean softmax against one randomly chosen annotator.

        Means are stored on ``self`` (``avg_dice``, ``foreground_dice``,
        ``val_elbo``, ``val_recon_loss``, ``val_kl_loss``, ``avg_ged``,
        ``avg_ncc``). The ``best_dice`` / ``best_loss`` / ``best_ged`` /
        ``best_ncc`` checkpoints are saved when the corresponding score improves
        or ties.

        The network is put back into training mode on exit, including when an
        exception is raised.

        Args:
            data: dataset exposing ``validation.images`` (per-image HW arrays) and
                ``validation.labels`` (per-image HWC arrays, one channel per annotator).
        """
        self.net.eval()
        try:
            with torch.no_grad():
                self.logger.info('Validation for step {}'.format(self.iteration))

                self.logger.info('Checkpointing model.')
                self.save_model('validation_ckpt')
                # Compare the device *type* so that e.g. 'cuda:0' is recognised too;
                # torch.device() also accepts a plain string here.
                if torch.device(self.device).type == 'cuda':
                    allocated_memory = torch.cuda.max_memory_allocated(self.device)
                    self.logger.info('Peak memory allocated at iteration {}: {} bytes'.format(
                        self.iteration, allocated_memory))

                ged_list = []
                dice_list = []
                ncc_list = []
                elbo_list = []
                kl_list = []
                recon_list = []

                time_ = time.time()

                num_available = data.validation.images.shape[0]
                if self.exp_config.num_validation_images == 'all':
                    validation_set_size = num_available
                else:
                    # Never index past the end of the validation set.
                    validation_set_size = min(self.exp_config.num_validation_images, num_available)

                if validation_set_size == 0:
                    self.logger.info('Validation set is empty; skipping validation.')
                    return

                n_classes = self.exp_config.n_classes

                for ii in range(validation_set_size):

                    s_gt_arr = data.validation.labels[ii, ...]

                    # from HW to NCHW
                    x_b = data.validation.images[ii, ...]
                    patch = torch.tensor(x_b, dtype=torch.float32).to(self.device)
                    val_patch = patch.unsqueeze(dim=0).unsqueeze(dim=1)

                    # Dice / loss reference: one randomly chosen annotator.
                    s_b = s_gt_arr[:, :, np.random.choice(self.exp_config.annotator_range)]
                    mask = torch.tensor(s_b, dtype=torch.float32).to(self.device)
                    val_mask = mask.unsqueeze(dim=0).unsqueeze(dim=1)
                    val_masks = torch.tensor(s_gt_arr, dtype=torch.float32).to(self.device)  # HWC
                    val_masks = val_masks.transpose(0, 2).transpose(1, 2)  # CHW

                    patch_arrangement = val_patch.repeat((self.exp_config.validation_samples, 1, 1, 1))
                    mask_arrangement = val_mask.repeat((self.exp_config.validation_samples, 1, 1, 1))

                    self.mask = mask_arrangement
                    self.patch = patch_arrangement

                    # Evaluation forward pass (training=False) on the repeated image.
                    s_out_eval_list = self.net.forward(patch_arrangement, mask_arrangement, training=False)
                    s_prediction_softmax_arrangement = self.net.accumulate_output(s_out_eval_list, use_softmax=True)

                    self.val_loss = self.net.loss(mask_arrangement)
                    elbo = self.val_loss
                    kl = self.net.kl_divergence_loss
                    recon = self.net.reconstruction_loss

                    s_prediction_softmax_mean = torch.mean(s_prediction_softmax_arrangement, axis=0)
                    s_prediction_arrangement = torch.argmax(s_prediction_softmax_arrangement, dim=1)

                    ground_truth_arrangement = val_masks  # one mask per annotator, H, W
                    # GED ignores label 0 (label_range starts at 1).
                    ged = utils.generalised_energy_distance(s_prediction_arrangement, ground_truth_arrangement,
                                                            nlabels=n_classes - 1,
                                                            label_range=range(1, n_classes))

                    # num_gts, 1, H, W
                    s_gt_arr_r = val_masks.unsqueeze(dim=1)
                    ground_truth_arrangement_one_hot = utils.convert_batch_to_onehot(s_gt_arr_r, nlabels=n_classes)
                    ncc = utils.variance_ncc_dist(s_prediction_softmax_arrangement, ground_truth_arrangement_one_hot)

                    s_ = torch.argmax(s_prediction_softmax_mean, dim=0)  # HW
                    s = val_mask.view(val_mask.shape[-2], val_mask.shape[-1])  # HW

                    per_lbl_dice = []
                    for lbl in range(n_classes):
                        binary_pred = (s_ == lbl) * 1
                        binary_gt = (s == lbl) * 1
                        pred_sum = torch.sum(binary_pred)
                        gt_sum = torch.sum(binary_gt)

                        if pred_sum == 0 and gt_sum == 0:
                            # Label absent from both: count as perfect agreement.
                            per_lbl_dice.append(1.0)
                        elif pred_sum == 0 or gt_sum == 0:
                            # Label present in exactly one of prediction / reference.
                            per_lbl_dice.append(0.0)
                        else:
                            per_lbl_dice.append(dc(binary_pred.detach().cpu().numpy(),
                                                   binary_gt.detach().cpu().numpy()))

                    dice_list.append(per_lbl_dice)
                    elbo_list.append(elbo)
                    kl_list.append(kl)
                    recon_list.append(recon)

                    ged_list.append(ged)
                    ncc_list.append(ncc)

                dice_tensor = torch.tensor(dice_list)
                per_structure_dice = dice_tensor.mean(dim=0)

                elbo_tensor = torch.tensor(elbo_list)
                kl_tensor = torch.tensor(kl_list)
                recon_tensor = torch.tensor(recon_list)

                ged_tensor = torch.tensor(ged_list)
                ncc_tensor = torch.tensor(ncc_list)

                self.avg_dice = torch.mean(dice_tensor)
                # Column 1 of the per-label Dice table (label 1).
                self.foreground_dice = torch.mean(dice_tensor, dim=0)[1]
                self.val_elbo = torch.mean(elbo_tensor)
                self.val_recon_loss = torch.mean(recon_tensor)
                self.val_kl_loss = torch.mean(kl_tensor)

                self.avg_ged = torch.mean(ged_tensor)
                self.avg_ncc = torch.mean(ncc_tensor)

                self.logger.info(' - Foreground dice: %.4f' % torch.mean(self.foreground_dice))
                self.logger.info(' - Mean (neg.) ELBO: %.4f' % self.val_elbo)
                self.logger.info(' - Mean GED: %.4f' % self.avg_ged)
                self.logger.info(' - Mean NCC: %.4f' % self.avg_ncc)

                # Ties also trigger a save.
                mean_dice = torch.mean(per_structure_dice)
                if mean_dice >= self.best_dice:
                    self.best_dice = mean_dice
                    self.logger.info('New best validation Dice! (%.3f)' % self.best_dice)
                    self.save_model(savename='best_dice')
                if self.val_elbo <= self.best_loss:
                    self.best_loss = self.val_elbo
                    self.logger.info('New best validation loss! (%.3f)' % self.best_loss)
                    self.save_model(savename='best_loss')
                if self.avg_ged <= self.best_ged:
                    self.best_ged = self.avg_ged
                    self.logger.info('New best GED score! (%.3f)' % self.best_ged)
                    self.save_model(savename='best_ged')
                if self.avg_ncc >= self.best_ncc:
                    self.best_ncc = self.avg_ncc
                    self.logger.info('New best NCC score! (%.3f)' % self.best_ncc)
                    self.save_model(savename='best_ncc')

                self.logger.info('Validation took {} seconds'.format(time.time() - time_))
        finally:
            self.net.train()
示例#6
0
def test(model_path, exp_config, model_selection='latest', num_samples=100, overwrite=False, mode=False):
    """Evaluate a PHiSeg checkpoint on the test split and pickle per-image metrics.

    The metrics dict (keys: 'dsc', 'presence', 'ged', 'ncc', 'entropy',
    'diversity', 'sample_dsc', 'ece', 'unweighted_ece', 'loglikelihood'; each
    converted to a numpy array) is pickled to ``get_output_path(...) + '.pickle'``.
    If that file already exists and ``overwrite`` is False, nothing is done.
    Sample images are written through an ``ImageSaver`` rooted at
    ``<model_path>/samples``. The saver is closed even if evaluation fails.

    Args:
        model_path: directory holding the checkpoint.
        exp_config: experiment configuration.
        model_selection: checkpoint type passed to ``load_weights``.
        num_samples: samples drawn per image. Forced to 1 when
            ``exp_config.likelihood`` is ``likelihoods.det_unet2D``. Note that
            the output path is computed from the value passed in, before this
            override.
        overwrite: recompute even if the output file exists.
        mode: if True, the point prediction is the rounded pixel-wise mean of
            the sampled label maps. Otherwise it is the argmax of the summed
            probabilities, or, for 'proposed' experiments, the argmax of
            ``dist_eval.loc``.
    """
    output_path = get_output_path(model_path, num_samples, model_selection, mode) + '.pickle'
    if os.path.exists(output_path) and not overwrite:
        return
    image_saver = ImageSaver(os.path.join(model_path, 'samples'))
    try:
        tf.reset_default_graph()
        phiseg_model = phiseg(exp_config=exp_config)
        phiseg_model.load_weights(model_path, type=model_selection)

        data_loader = data_switch(exp_config.data_identifier)
        data = data_loader(exp_config)

        metrics = {key: [] for key in
                   ['dsc', 'presence', 'ged', 'ncc', 'entropy', 'diversity', 'sample_dsc', 'ece', 'unweighted_ece',
                    'loglikelihood']}

        num_samples = 1 if exp_config.likelihood is likelihoods.det_unet2D else num_samples

        for ii in tqdm(range(data.test.images.shape[0])):
            image = data.test.images[ii, ...].reshape([1] + list(exp_config.image_size))
            targets = data.test.labels[ii, ...].transpose((2, 0, 1))  # presumably annotators x H x W

            # The image is repeated along the batch axis to draw num_samples samples at once.
            feed_dict = {phiseg_model.training_pl: False,
                         phiseg_model.x_inp: np.tile(image, [num_samples, 1, 1, 1])}

            prob_maps = phiseg_model.sess.run(phiseg_model.s_out_eval_sm, feed_dict=feed_dict)
            samples = np.argmax(prob_maps, axis=-1)
            prob_map = np.mean(prob_maps, axis=0)
            probability = prob_map + 1e-10  # epsilon avoids log(0)
            metrics['entropy'].append(float(np.sum(-probability * np.log(probability))))
            if mode:
                prediction = np.round(np.mean(samples, axis=0)).astype(np.int64)
            else:
                if 'proposed' not in exp_config.experiment_name:
                    prediction = np.argmax(np.sum(prob_maps, axis=0), axis=-1)
                else:
                    mean = phiseg_model.sess.run(phiseg_model.dist_eval.loc, feed_dict=feed_dict)[0]
                    # NOTE(review): hard-coded 2 classes, and image.shape[:-1] keeps the
                    # leading batch axis of 1, so prediction is (1, H, W) here but
                    # (H, W) in the other branches. Confirm intended shape / nlabels.
                    mean = np.reshape(mean, image.shape[:-1] + (2,))
                    prediction = np.argmax(mean, axis=-1)

            metrics['loglikelihood'].append(calculate_log_likelihood(targets, prob_maps))
            # calculate DSC per expert
            metrics['dsc'].append(
                [[calc_dsc(target == i, prediction == i) for i in range(exp_config.nlabels)] for target in targets])
            metrics['presence'].append([[np.any(target == i) for i in range(exp_config.nlabels)] for target in targets])

            metrics['sample_dsc'].append([[[calc_dsc(target == i, sample == i) for i in range(exp_config.nlabels)]
                                           for target in targets] for sample in samples])

            # GED and diversity (label 0 excluded via the label range starting at 1)
            ged_, diversity_ = utils.generalised_energy_distance(samples, targets, exp_config.nlabels - 1,
                                                                 range(1, exp_config.nlabels))
            metrics['ged'].append(ged_)
            metrics['diversity'].append(diversity_)
            # NCC
            targets_one_hot = utils.to_one_hot(targets, exp_config.nlabels)
            metrics['ncc'].append(utils.variance_ncc_dist(prob_maps, targets_one_hot)[0])
            # NOTE(review): the literal 2 looks like a class count and 10 like a bin
            # count; confirm against calc_class_wise_expected_calibration_error's
            # signature before using with nlabels != 2.
            ece, unweighted_ece = calc_class_wise_expected_calibration_error(targets, prob_map, 2, 10)
            metrics['ece'].append(ece)
            metrics['unweighted_ece'].append(unweighted_ece)
            image_saver(str(ii) + '/', image[0, ..., 0], targets, prediction, samples)

        metrics = {key: np.array(metric) for key, metric in metrics.items()}
        with open(output_path, 'wb') as f:
            pickle.dump(metrics, f)
    finally:
        image_saver.close()