Example No. 1
def main(model_path, exp_config, do_plots=False):

    n_samples = 50
    model_selection = 'best_ged'

    # Build model and restore weights
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type=model_selection)

    # Get data
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    N = data.test.images.shape[0]

    ged_list = []
    ncc_list = []

    for ii in range(N):

        if ii % 10 == 0:
            logging.info("Progress: %d" % ii)

        x_b = data.test.images[ii, ...].reshape([1] + list(exp_config.image_size))
        s_b = data.test.labels[ii, ...]

        x_b_stacked = np.tile(x_b, [n_samples, 1, 1, 1])

        feed_dict = {}
        feed_dict[segvae_model.training_pl] = False
        feed_dict[segvae_model.x_inp] = x_b_stacked


        s_arr_sm = segvae_model.sess.run(segvae_model.s_out_eval_sm, feed_dict=feed_dict)
        s_arr = np.argmax(s_arr_sm, axis=-1)

        # s_arr: num samples x X x Y
        s_b_r = s_b.transpose((2,0,1)) # num gts x X x Y
        s_b_r_sm = utils.convert_batch_to_onehot(s_b_r, exp_config.nlabels)  # num gts x X x Y x nlabels

        ged = utils.generalised_energy_distance(s_arr, s_b_r, nlabels=exp_config.nlabels-1, label_range=range(1,exp_config.nlabels))
        ged_list.append(ged)

        ncc = utils.variance_ncc_dist(s_arr_sm, s_b_r_sm)
        ncc_list.append(ncc)



    ged_arr = np.asarray(ged_list)
    ncc_arr = np.asarray(ncc_list)

    logging.info('-- GED: --')
    logging.info(np.mean(ged_arr))
    logging.info(np.std(ged_arr))

    logging.info('-- NCC: --')
    logging.info(np.mean(ncc_arr))
    logging.info(np.std(ncc_arr))

    np.savez(os.path.join(model_path, 'ged%s_%s.npz' % (str(n_samples), model_selection)), ged_arr)
    np.savez(os.path.join(model_path, 'ncc%s_%s.npz' % (str(n_samples), model_selection)), ncc_arr)
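All of the examples on this page revolve around utils.convert_batch_to_onehot, whose implementation is not shown here. The following is a minimal NumPy sketch of the assumed behaviour, consistent with the shape comments in Example No. 1 (num gts x X x Y in, num gts x X x Y x nlabels out); the function name is illustrative, not the library's.

import numpy as np

def convert_batch_to_onehot_sketch(label_batch, nlabels):
    # (N, X, Y) integer label maps -> (N, X, Y, nlabels) one-hot floats.
    # Indexing an identity matrix with the label array does the encoding in one step.
    return np.eye(nlabels, dtype=np.float32)[label_batch.astype(np.int64)]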
Example No. 2
    def forward(self, patch, segm=None, training_prior=False, z_list=None):
        if segm is not None:

            with torch.no_grad():
                segm_one_hot = utils.convert_batch_to_onehot(segm, nlabels=2)\
                    .to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

                segm_one_hot = segm_one_hot.float()
            patch = torch.cat([patch, torch.add(segm_one_hot, -0.5)], dim=1)

        blocks = []
        z = [None] * self.latent_levels # contains all hidden z
        sigma = [None] * self.latent_levels
        mu = [None] * self.latent_levels

        x = patch
        for i, down in enumerate(self.contracting_path):
            x = down(x)
            if i != len(self.contracting_path) - 1:
                blocks.append(x)

        pre_conv = x
        for i, sample_z in enumerate(self.sample_z_path):
            if i != 0:
                pre_conv = self.upsampling_path[i-1](z[-i], blocks[-i])
            mu[-i-1], sigma[-i-1], z[-i-1] = self.sample_z_path[i](pre_conv)
            if training_prior:
                z[-i-1] = z_list[-i-1]

        del blocks

        return z, mu, sigma
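The sample_z_path blocks used above are not shown; judging by the unpacking mu, sigma, z = block(features), each presumably predicts a per-pixel Gaussian and draws z with the reparameterisation trick. A hypothetical sketch (the class name, layer choices and kernel sizes are assumptions):

import torch
import torch.nn as nn

class SampleZBlock(nn.Module):
    # Hypothetical sketch of one sample_z_path element: two 1x1 convolutions
    # predict per-pixel mu and log(sigma), and z is drawn differentiably.
    def __init__(self, in_channels, z_channels):
        super().__init__()
        self.mu_conv = nn.Conv2d(in_channels, z_channels, kernel_size=1)
        self.log_sigma_conv = nn.Conv2d(in_channels, z_channels, kernel_size=1)

    def forward(self, features):
        mu = self.mu_conv(features)
        sigma = torch.exp(self.log_sigma_conv(features))
        z = mu + sigma * torch.randn_like(mu)  # reparameterisation trick
        return mu, sigma, z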
Example No. 3
    def forward(self, input, segm=None):
        if segm is not None:
            with torch.no_grad():
                segm_one_hot = utils.convert_batch_to_onehot(segm, nlabels=2) \
                    .to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

                segm_one_hot = segm_one_hot.float()
            input = torch.cat([input, torch.add(segm_one_hot, -0.5)], dim=1)

        encoding = self.encoder(input)

        # We only want the mean of the resulting hxw image
        encoding = torch.mean(encoding, dim=2, keepdim=True)
        encoding = torch.mean(encoding, dim=3, keepdim=True)

        # Convert encoding to 2 x latent dim and split up for mu and log_sigma
        mu_log_sigma = self.conv_layer(encoding)

        # Squeeze the two spatial dims explicitly (dim=2 twice); a plain squeeze()
        # would also drop the batch dimension when the batch size is 1
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)

        mu = mu_log_sigma[:, :self.latent_dim]
        log_sigma = mu_log_sigma[:, self.latent_dim:]

        # This is a multivariate normal with diagonal covariance diag(sigma^2)
        # https://github.com/pytorch/pytorch/pull/11178
        dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
        return dist
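The Independent wrapper turns the batch of univariate Normals into one distribution per batch element, which makes rsample() and analytic KL terms convenient downstream. A self-contained usage sketch (the tensors stand in for the mu/log_sigma produced above; the shapes are arbitrary):

import torch
from torch.distributions import Normal, Independent, kl_divergence

# Stand-ins for the outputs of two encoder passes (e.g. posterior and prior).
mu_q, log_sigma_q = torch.zeros(4, 6), torch.zeros(4, 6)
mu_p, log_sigma_p = torch.randn(4, 6), torch.randn(4, 6)

posterior = Independent(Normal(loc=mu_q, scale=torch.exp(log_sigma_q)), 1)
prior = Independent(Normal(loc=mu_p, scale=torch.exp(log_sigma_p)), 1)

z = posterior.rsample()               # differentiable (reparameterised) sample
kl = kl_divergence(posterior, prior)  # analytic KL, one value per batch element
print(z.shape, kl.shape)              # torch.Size([4, 6]) torch.Size([4])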
Example No. 4
def test_ncc(lidc_data):
    random_index = 99
    s_gt_arr = lidc_data.test.labels[random_index, ...]

    x_b = lidc_data.test.images[random_index, ...]
    patch = torch.tensor(x_b, dtype=torch.float32).to('cpu')

    assert s_gt_arr.shape == (128, 128, 4)
    val_masks = torch.tensor(s_gt_arr, dtype=torch.float32).to('cpu')  # HWC
    val_masks = val_masks.transpose(0, 2).transpose(1, 2)
    assert val_masks.shape == (4, 128, 128)

    s_gt_arr_r = val_masks.unsqueeze(dim=1)

    ground_truth_arrangement_one_hot = utils.convert_batch_to_onehot(
        s_gt_arr_r, nlabels=2)

    ncc = utils.variance_ncc_dist(ground_truth_arrangement_one_hot,
                                  ground_truth_arrangement_one_hot)

    assert math.isclose(ncc[0], 1.0)
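The test checks that variance_ncc_dist of a set with itself is exactly 1.0. The internals of that helper are not shown; plain normalised cross-correlation, its presumable building block, already has this fixed point, as the sketch below illustrates:

import numpy as np

def ncc_sketch(a, b, eps=1e-8):
    # Plain normalised cross-correlation of two equally shaped arrays.
    # utils.variance_ncc_dist presumably aggregates an NCC-based statistic over
    # sample/ground-truth sets, hence the expected value of 1.0 above.
    a = a - a.mean()
    b = b - b.mean()
    return float((a * b).sum() / (np.sqrt((a * a).sum() * (b * b).sum()) + eps))

x = np.random.rand(128, 128)
assert np.isclose(ncc_sketch(x, x), 1.0)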
Example No. 5
def main(model_path, exp_config):

    # Make and restore the PHiSeg model
    model_selection = 'best_ged'  # not defined in this snippet; assumed module-level, as in Example No. 1
    phiseg_model = phiseg(exp_config=exp_config)
    phiseg_model.load_weights(model_path, type=model_selection)

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    N = data.test.images.shape[0]

    n_images = 16
    n_samples = 16

    # indices = np.arange(N)
    # sample_inds = np.random.choice(indices, n_images)
    sample_inds = [165, 280, 213]  # <-- prostate
    # sample_inds = [1551] #[907, 1296, 1551]  # <-- LIDC

    for ii in sample_inds:

        print('------- Processing image %d -------' % ii)

        outfolder = os.path.join(model_path, 'samples_%s' % model_selection,
                                 str(ii))
        utils.makefolder(outfolder)

        x_b = data.test.images[ii,
                               ...].reshape([1] + list(exp_config.image_size))
        s_b = data.test.labels[ii, ...]

        if np.sum(s_b) < 10:
            print('WARNING: skipping case %d (almost no foreground structures)' % ii)
            continue

        s_b_r = utils.convert_batch_to_onehot(s_b.transpose((2, 0, 1)),
                                              exp_config.nlabels)

        print('Plotting input image')
        plt.figure()
        x_b_d = preproc_image(x_b)
        plt.imshow(x_b_d, cmap='gray')
        plt.axis('off')
        plt.savefig(os.path.join(outfolder, 'input_img_%d.png' % ii),
                    bbox_inches='tight')

        print('Generating 100 samples')
        s_p_list = []
        for kk in range(100):
            s_p_list.append(
                phiseg_model.predict_segmentation_sample(x_b,
                                                         return_softmax=True))
        s_p_arr = np.squeeze(np.asarray(s_p_list))

        print('Plotting %d of those samples' % n_samples)
        for jj in range(n_samples):

            s_p_sm = s_p_arr[jj, ...]
            s_p_am = np.argmax(s_p_sm, axis=-1)

            plt.figure()
            s_p_d = preproc_image(s_p_am, nlabels=exp_config.nlabels)
            plt.imshow(s_p_d, cmap='gray')
            plt.axis('off')
            plt.savefig(os.path.join(outfolder,
                                     'sample_img_%d_samp_%d.png' % (ii, jj)),
                        bbox_inches='tight')

        print('Plotting ground-truth masks')
        for jj in range(s_b_r.shape[0]):

            s_b_sm = s_b_r[jj, ...]
            s_b_am = np.argmax(s_b_sm, axis=-1)

            plt.figure()
            s_b_d = preproc_image(s_b_am, nlabels=exp_config.nlabels)
            plt.imshow(s_b_d, cmap='gray')
            plt.axis('off')
            plt.savefig(os.path.join(outfolder,
                                     'gt_img_%d_samp_%d.png' % (ii, jj)),
                        bbox_inches='tight')

        print('Generating error masks')
        E_ss, E_sy_avg, E_yy_avg = generate_error_maps(s_p_arr, s_b_r)

        print('Plotting them')
        plt.figure()
        plt.imshow(preproc_image(E_ss))
        plt.axis('off')
        plt.savefig(os.path.join(outfolder, 'E_ss_%d.png' % ii),
                    bbox_inches='tight')

        print('Plotting the log-scale version')
        plt.figure()
        plt.imshow(preproc_image(np.log(E_ss)))
        plt.axis('off')
        plt.savefig(os.path.join(outfolder, 'log_E_ss_%d.png' % ii),
                    bbox_inches='tight')

        plt.figure()
        plt.imshow(preproc_image(E_sy_avg))
        plt.axis('off')
        plt.savefig(os.path.join(outfolder, 'E_sy_avg_%d_.png' % ii),
                    bbox_inches='tight')

        plt.figure()
        plt.imshow(preproc_image(E_yy_avg))
        plt.axis('off')
        plt.savefig(os.path.join(outfolder, 'E_yy_avg_%d_.png' % ii),
                    bbox_inches='tight')

        plt.close('all')
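preproc_image is called here both on grey-value images and, with nlabels set, on argmax label maps. Its implementation is not included; a hypothetical stand-in consistent with that usage (squeeze, rescale for plt.imshow) could be:

import numpy as np

def preproc_image_sketch(img, nlabels=None):
    # Hypothetical helper: drop singleton dims and rescale to [0, 1] for imshow.
    # With nlabels given, scale integer label maps by a fixed maximum so grey
    # levels stay comparable across figures.
    img = np.squeeze(img).astype(np.float32)
    if nlabels is not None:
        return img / max(nlabels - 1, 1)
    rng = img.max() - img.min()
    return (img - img.min()) / rng if rng > 0 else img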
Example No. 6
    def _do_validation(self, data):

        global_step = self.sess.run(self.global_step) - 1

        checkpoint_file = os.path.join(self.log_dir, 'model.ckpt')
        self.saver.save(self.sess, checkpoint_file, global_step=global_step)

        val_x, val_s = data.validation.next_batch(self.exp_config.batch_size)
        val_losses_out = self.sess.run(list(self.loss_dict.values()),
                                       feed_dict={
                                           self.x_inp: val_x,
                                           self.s_inp: val_s,
                                           self.training_pl: False
                                       })

        # Note that val_losses_out is now sorted in the same way as loss_dict.
        tot_loss_index = list(self.loss_dict.keys()).index('total_loss')
        val_loss_tot = val_losses_out[tot_loss_index]

        train_x, train_s = data.train.next_batch(self.exp_config.batch_size)
        train_losses_out = self.sess.run(list(self.loss_dict.values()),
                                         feed_dict={
                                             self.x_inp: train_x,
                                             self.s_inp: train_s,
                                             self.training_pl: False
                                         })

        logging.info('----- Step: %d ------' % global_step)
        logging.info('BATCH VALIDATION:')
        for ii, loss_name in enumerate(self.loss_dict.keys()):
            logging.info('%s | training: %f | validation: %f' %
                         (loss_name, train_losses_out[ii], val_losses_out[ii]))

        # Evaluate validation Dice:

        start_dice_val = time.time()
        num_batches = 0

        dice_list = []
        elbo_list = []
        ged_list = []
        ncc_list = []

        N = data.validation.images.shape[0]

        for ii in range(N):

            # logging.info(ii)

            x = data.validation.images[ii, ...].reshape(
                [1] + list(self.exp_config.image_size))
            s_gt_arr = data.validation.labels[ii, ...]
            s = s_gt_arr[:, :,
                         np.random.choice(self.exp_config.annotator_range)]

            x_b = np.tile(x, [self.exp_config.validation_samples, 1, 1, 1])
            s_b = np.tile(s, [self.exp_config.validation_samples, 1, 1])

            feed_dict = {}
            feed_dict[self.training_pl] = False
            feed_dict[self.x_inp] = x_b
            feed_dict[self.s_inp] = s_b

            s_pred_sm_arr, elbo = self.sess.run(
                [self.s_out_eval_sm, self.loss_tot], feed_dict=feed_dict)

            s_pred_sm_mean_ = np.mean(s_pred_sm_arr, axis=0)

            s_pred_arr = np.argmax(s_pred_sm_arr, axis=-1)
            s_gt_arr_r = s_gt_arr.transpose((2, 0, 1))  # num gts x X x Y

            s_gt_arr_r_sm = utils.convert_batch_to_onehot(
                s_gt_arr_r,
                self.exp_config.nlabels)  # num gts x X x Y x nlabels

            ged = utils.generalised_energy_distance(
                s_pred_arr,
                s_gt_arr_r,
                nlabels=self.exp_config.nlabels - 1,
                label_range=range(1, self.exp_config.nlabels))

            ncc = utils.variance_ncc_dist(s_pred_sm_arr, s_gt_arr_r_sm)

            s_ = np.argmax(s_pred_sm_mean_, axis=-1)

            # Write losses to list
            per_lbl_dice = []
            for lbl in range(self.exp_config.nlabels):
                binary_pred = (s_ == lbl) * 1
                binary_gt = (s == lbl) * 1

                if np.sum(binary_gt) == 0 and np.sum(binary_pred) == 0:
                    per_lbl_dice.append(1)
                elif np.sum(binary_pred) > 0 and np.sum(
                        binary_gt) == 0 or np.sum(
                            binary_pred) == 0 and np.sum(binary_gt) > 0:
                    per_lbl_dice.append(0)
                else:
                    per_lbl_dice.append(dc(binary_pred, binary_gt))

            num_batches += 1

            dice_list.append(per_lbl_dice)
            elbo_list.append(elbo)
            ged_list.append(ged)
            ncc_list.append(ncc)

        dice_arr = np.asarray(dice_list)
        per_structure_dice = dice_arr.mean(axis=0)

        avg_dice = np.mean(dice_arr)
        avg_elbo = utils.list_mean(elbo_list)
        avg_ged = utils.list_mean(ged_list)
        avg_ncc = utils.list_mean(ncc_list)

        logging.info('FULL VALIDATION (%d images):' % N)
        logging.info(' - Mean foreground dice: %.4f' %
                     np.mean(per_structure_dice))
        logging.info(' - Mean (neg.) ELBO: %.4f' % avg_elbo)
        logging.info(' - Mean GED: %.4f' % avg_ged)
        logging.info(' - Mean NCC: %.4f' % avg_ncc)

        logging.info('@ Running through validation set took: %.2f secs' %
                     (time.time() - start_dice_val))

        if np.mean(per_structure_dice) >= self.best_dice:
            self.best_dice = np.mean(per_structure_dice)
            logging.info('New best validation Dice! (%.3f)' % self.best_dice)
            best_file = os.path.join(self.log_dir, 'model_best_dice.ckpt')
            self.saver_best_dice.save(self.sess,
                                      best_file,
                                      global_step=global_step)

        if avg_elbo <= self.best_loss:
            self.best_loss = avg_elbo
            logging.info('New best validation loss! (%.3f)' % self.best_loss)
            best_file = os.path.join(self.log_dir, 'model_best_loss.ckpt')
            self.saver_best_loss.save(self.sess,
                                      best_file,
                                      global_step=global_step)

        if avg_ged <= self.best_ged:
            self.best_ged = avg_ged
            logging.info('New best GED score! (%.3f)' % self.best_ged)
            best_file = os.path.join(self.log_dir, 'model_best_ged.ckpt')
            self.saver_best_ged.save(self.sess,
                                     best_file,
                                     global_step=global_step)

        if avg_ncc >= self.best_ncc:
            self.best_ncc = avg_ncc
            logging.info('New best NCC score! (%.3f)' % self.best_ncc)
            best_file = os.path.join(self.log_dir, 'model_best_ncc.ckpt')
            self.saver_best_ncc.save(self.sess,
                                     best_file,
                                     global_step=global_step)

        # Create Validation Summary feed dict
        z_prior_list = self.generate_prior_samples(x_in=val_x)
        val_summary_feed_dict = {
            i: d
            for i, d in zip(self.z_list_gen, z_prior_list)
        }  # this is for prior samples
        val_summary_feed_dict[self.x_for_gen] = val_x

        # Fill placeholders for all losses
        for loss_key, loss_val in zip(self.loss_dict.keys(), val_losses_out):
            # The detour over loss_dict.keys() is necessary because val_losses_out is sorted in the same
            # way as loss_dict. Same for the training below.
            loss_pl = self.validation_loss_pl_dict[loss_key]
            val_summary_feed_dict[loss_pl] = loss_val

        # Fill placeholders for validation Dice
        val_summary_feed_dict[self.val_tot_dice_score] = avg_dice
        val_summary_feed_dict[self.val_mean_dice_score] = np.mean(
            per_structure_dice)
        val_summary_feed_dict[self.val_elbo] = avg_elbo
        val_summary_feed_dict[self.val_ged] = avg_ged
        val_summary_feed_dict[self.val_ncc] = np.squeeze(avg_ncc)

        for ii in range(self.exp_config.nlabels):
            val_summary_feed_dict[
                self.val_lbl_dice_scores[ii]] = per_structure_dice[ii]

        val_summary_feed_dict[self.x_inp] = val_x
        val_summary_feed_dict[self.s_inp] = val_s
        val_summary_feed_dict[self.training_pl] = False

        val_summary_msg = self.sess.run(self.val_summary,
                                        feed_dict=val_summary_feed_dict)
        self.summary_writer.add_summary(val_summary_msg, global_step)

        # Create train Summary feed dict
        train_summary_feed_dict = {}
        for loss_key, loss_val in zip(self.loss_dict.keys(), train_losses_out):
            loss_pl = self.train_loss_pl_dict[loss_key]
            train_summary_feed_dict[loss_pl] = loss_val
        train_summary_feed_dict[self.training_pl] = False

        train_summary_msg = self.sess.run(self.train_summary,
                                          feed_dict=train_summary_feed_dict)
        self.summary_writer.add_summary(train_summary_msg, global_step)
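The per-label Dice loop above delegates the overlapping case to dc(), typically medpy.metric.binary.dc. For reference, a plain NumPy equivalent of the Sørensen-Dice coefficient (the empty-mask corner cases are handled explicitly by the caller, as in the loop above):

import numpy as np

def dice_coefficient(pred, gt, eps=1e-8):
    # Sorensen-Dice coefficient: 2 |A n B| / (|A| + |B|) for binary masks.
    pred = np.asarray(pred).astype(bool)
    gt = np.asarray(gt).astype(bool)
    return 2.0 * np.logical_and(pred, gt).sum() / (pred.sum() + gt.sum() + eps)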
Example No. 7
    def test(self, data, sys_config):
        self.net.eval()
        with torch.no_grad():

            model_selection = self.exp_config.experiment_name + '_best_loss.pth'
            self.logger.info('Testing {}'.format(model_selection))

            self.logger.info('Loading pretrained model {}'.format(model_selection))

            model_path = os.path.join(
                sys_config.log_root,
                self.exp_config.log_dir_name,
                self.exp_config.experiment_name,
                model_selection)

            if os.path.exists(model_path):
                self.net.load_state_dict(torch.load(model_path))
            else:
                self.logger.info('The file {} does not exist. Aborting test function.'.format(model_path))
                return

            ged_list = []
            dice_list = []
            ncc_list = []

            time_ = time.time()

            end_dice = 0.0
            end_ged = 0.0
            end_ncc = 0.0

            for i in range(10):
                self.logger.info('Doing iteration {}'.format(i))
                n_samples = 10

                for ii in range(data.test.images.shape[0]):

                    s_gt_arr = data.test.labels[ii, ...]

                    # from HW to NCHW
                    x_b = data.test.images[ii, ...]
                    patch = torch.tensor(x_b, dtype=torch.float32).to(self.device)
                    val_patch = patch.unsqueeze(dim=0).unsqueeze(dim=1)

                    s_b = s_gt_arr[:, :, np.random.choice(self.exp_config.annotator_range)]
                    mask = torch.tensor(s_b, dtype=torch.float32).to(self.device)
                    val_mask = mask.unsqueeze(dim=0).unsqueeze(dim=1)
                    val_masks = torch.tensor(s_gt_arr, dtype=torch.float32).to(self.device)  # HWC
                    val_masks = val_masks.transpose(0, 2).transpose(1, 2)  # CHW

                    patch_arrangement = val_patch.repeat((n_samples, 1, 1, 1))

                    mask_arrangement = val_mask.repeat((n_samples, 1, 1, 1))

                    self.mask = mask_arrangement
                    self.patch = patch_arrangement

                    # training=False: sample from the prior only (training=True would also construct the posterior)
                    s_out_eval_list = self.net.forward(patch_arrangement, mask_arrangement, training=False)
                    s_prediction_softmax_arrangement = self.net.accumulate_output(s_out_eval_list, use_softmax=True)

                    s_prediction_softmax_mean = torch.mean(s_prediction_softmax_arrangement, axis=0)
                    s_prediction_arrangement = torch.argmax(s_prediction_softmax_arrangement, dim=1)

                    ground_truth_arrangement = val_masks  # nlabels, H, W
                    ged = utils.generalised_energy_distance(s_prediction_arrangement, ground_truth_arrangement,
                                                            nlabels=self.exp_config.n_classes - 1,
                                                            label_range=range(1, self.exp_config.n_classes))

                    # num_gts, nlabels, H, W
                    s_gt_arr_r = val_masks.unsqueeze(dim=1)
                    ground_truth_arrangement_one_hot = utils.convert_batch_to_onehot(s_gt_arr_r,
                                                                                     nlabels=self.exp_config.n_classes)
                    ncc = utils.variance_ncc_dist(s_prediction_softmax_arrangement, ground_truth_arrangement_one_hot)

                    s_ = torch.argmax(s_prediction_softmax_mean, dim=0)  # HW
                    s = val_mask.view(val_mask.shape[-2], val_mask.shape[-1])  # HW

                    # Write losses to list
                    per_lbl_dice = []
                    for lbl in range(self.exp_config.n_classes):
                        binary_pred = (s_ == lbl) * 1
                        binary_gt = (s == lbl) * 1

                        if torch.sum(binary_gt) == 0 and torch.sum(binary_pred) == 0:
                            per_lbl_dice.append(1.0)
                        elif torch.sum(binary_pred) > 0 and torch.sum(binary_gt) == 0 or torch.sum(
                                binary_pred) == 0 and torch.sum(
                                binary_gt) > 0:
                            per_lbl_dice.append(0.0)
                        else:
                            per_lbl_dice.append(dc(binary_pred.detach().cpu().numpy(), binary_gt.detach().cpu().numpy()))
                    dice_list.append(per_lbl_dice)

                    ged_list.append(ged)
                    ncc_list.append(ncc)

                    if ii % 100 == 0:
                        self.logger.info(' - Mean GED: %.4f' % torch.mean(torch.tensor(ged_list)))
                        self.logger.info(' - Mean NCC: %.4f' % torch.mean(torch.tensor(ncc_list)))


                dice_tensor = torch.tensor(dice_list)
                per_structure_dice = dice_tensor.mean(dim=0)

                ged_tensor = torch.tensor(ged_list)
                ncc_tensor = torch.tensor(ncc_list)

                model_path = os.path.join(
                    sys_config.log_root,
                    self.exp_config.log_dir_name,
                    self.exp_config.experiment_name)

                np.savez(os.path.join(model_path, 'ged%s_%s_2.npz' % (str(n_samples), model_selection)), ged_tensor.numpy())
                np.savez(os.path.join(model_path, 'ncc%s_%s_2.npz' % (str(n_samples), model_selection)), ncc_tensor.numpy())

                self.avg_dice = torch.mean(dice_tensor)
                self.foreground_dice = torch.mean(dice_tensor, dim=0)[1]

                self.avg_ged = torch.mean(ged_tensor)
                self.avg_ncc = torch.mean(ncc_tensor)

                logging.info('-- GED: --')
                logging.info(torch.mean(ged_tensor))
                logging.info(torch.std(ged_tensor))

                logging.info('-- NCC: --')
                logging.info(torch.mean(ncc_tensor))
                logging.info(torch.std(ncc_tensor))

                self.logger.info(' - Foreground dice: %.4f' % torch.mean(self.foreground_dice))
                if hasattr(self, 'val_elbo'):  # set by validate(); test() itself does not compute an ELBO
                    self.logger.info(' - Mean (neg.) ELBO: %.4f' % self.val_elbo)
                self.logger.info(' - Mean GED: %.4f' % self.avg_ged)
                self.logger.info(' - Mean NCC: %.4f' % self.avg_ncc)

                self.logger.info('Testing took {} seconds'.format(time.time() - time_))

                end_dice += self.avg_dice
                end_ged += self.avg_ged
                end_ncc += self.avg_ncc
            self.logger.info('Mean dice: {}'.format(end_dice/10))
            self.logger.info('Mean ged: {}'.format(end_ged / 10))
            self.logger.info('Mean ncc: {}'.format(end_ncc / 10))
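utils.generalised_energy_distance is called throughout with a sample set, a ground-truth set, nlabels and a label_range. Below is a sketch of the standard definition, GED^2 = 2 E[d(s, y)] - E[d(s, s')] - E[d(y, y')] with d = 1 - IoU; it is an assumption that the library helper matches it, and the nlabels argument is folded into label_range here.

import numpy as np

def iou_dist(a, b, label_range):
    # d(a, b) = 1 - mean IoU over the given labels; two empty masks count as IoU 1.
    ious = []
    for lbl in label_range:
        m_a, m_b = (a == lbl), (b == lbl)
        union = np.logical_or(m_a, m_b).sum()
        ious.append(1.0 if union == 0 else np.logical_and(m_a, m_b).sum() / union)
    return 1.0 - float(np.mean(ious))

def ged_sketch(samples, ground_truths, label_range):
    # GED^2 = 2 E[d(s, y)] - E[d(s, s')] - E[d(y, y')]
    d_sy = np.mean([iou_dist(s, y, label_range) for s in samples for y in ground_truths])
    d_ss = np.mean([iou_dist(s, s2, label_range) for s in samples for s2 in samples])
    d_yy = np.mean([iou_dist(y, y2, label_range) for y in ground_truths for y2 in ground_truths])
    return 2 * d_sy - d_ss - d_yy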
Example No. 8
    def validate(self, data):
        self.net.eval()
        with torch.no_grad():
            self.logger.info('Validation for step {}'.format(self.iteration))

            self.logger.info('Checkpointing model.')
            self.save_model('validation_ckpt')
            if self.device == torch.device('cuda'):
                allocated_memory = torch.cuda.max_memory_allocated(self.device)

                self.logger.info('Memory allocated at iteration {}: {}'.format(self.iteration, allocated_memory))

            ged_list = []
            dice_list = []
            ncc_list = []
            elbo_list = []
            kl_list = []
            recon_list = []

            time_ = time.time()

            validation_set_size = data.validation.images.shape[0]\
                if self.exp_config.num_validation_images == 'all' else self.exp_config.num_validation_images

            for ii in range(validation_set_size):

                s_gt_arr = data.validation.labels[ii, ...]

                # from HW to NCHW
                x_b = data.validation.images[ii, ...]
                patch = torch.tensor(x_b, dtype=torch.float32).to(self.device)
                val_patch = patch.unsqueeze(dim=0).unsqueeze(dim=1)

                s_b = s_gt_arr[:, :, np.random.choice(self.exp_config.annotator_range)]
                mask = torch.tensor(s_b, dtype=torch.float32).to(self.device)
                val_mask = mask.unsqueeze(dim=0).unsqueeze(dim=1)
                val_masks = torch.tensor(s_gt_arr, dtype=torch.float32).to(self.device)  # HWC
                val_masks = val_masks.transpose(0, 2).transpose(1, 2)  # CHW

                patch_arrangement = val_patch.repeat((self.exp_config.validation_samples, 1, 1, 1))

                mask_arrangement = val_mask.repeat((self.exp_config.validation_samples, 1, 1, 1))

                self.mask = mask_arrangement
                self.patch = patch_arrangement

                # training=False: sample from the prior only (training=True would also construct the posterior)
                s_out_eval_list = self.net.forward(patch_arrangement, mask_arrangement, training=False)
                s_prediction_softmax_arrangement = self.net.accumulate_output(s_out_eval_list, use_softmax=True)

                # evaluate the ELBO (and its KL / reconstruction components) over the repeated batch
                self.val_loss = self.net.loss(mask_arrangement)
                elbo = self.val_loss
                kl = self.net.kl_divergence_loss
                recon = self.net.reconstruction_loss

                s_prediction_softmax_mean = torch.mean(s_prediction_softmax_arrangement, axis=0)
                s_prediction_arrangement = torch.argmax(s_prediction_softmax_arrangement, dim=1)

                ground_truth_arrangement = val_masks  # nlabels, H, W
                ged = utils.generalised_energy_distance(s_prediction_arrangement, ground_truth_arrangement,
                                                        nlabels=self.exp_config.n_classes - 1,
                                                        label_range=range(1, self.exp_config.n_classes))

                # num_gts, nlabels, H, W
                s_gt_arr_r = val_masks.unsqueeze(dim=1)
                ground_truth_arrangement_one_hot = utils.convert_batch_to_onehot(s_gt_arr_r, nlabels=self.exp_config.n_classes)
                ncc = utils.variance_ncc_dist(s_prediction_softmax_arrangement, ground_truth_arrangement_one_hot)

                s_ = torch.argmax(s_prediction_softmax_mean, dim=0) # HW
                s = val_mask.view(val_mask.shape[-2], val_mask.shape[-1]) #HW

                # Write losses to list
                per_lbl_dice = []
                for lbl in range(self.exp_config.n_classes):
                    binary_pred = (s_ == lbl) * 1
                    binary_gt = (s == lbl) * 1

                    if torch.sum(binary_gt) == 0 and torch.sum(binary_pred) == 0:
                        per_lbl_dice.append(1.0)
                    elif torch.sum(binary_pred) > 0 and torch.sum(binary_gt) == 0 or torch.sum(binary_pred) == 0 and torch.sum(
                            binary_gt) > 0:
                        per_lbl_dice.append(0.0)
                    else:
                        per_lbl_dice.append(dc(binary_pred.detach().cpu().numpy(), binary_gt.detach().cpu().numpy()))

                dice_list.append(per_lbl_dice)
                elbo_list.append(elbo)
                kl_list.append(kl)
                recon_list.append(recon)

                ged_list.append(ged)
                ncc_list.append(ncc)

            dice_tensor = torch.tensor(dice_list)
            per_structure_dice = dice_tensor.mean(dim=0)

            elbo_tensor = torch.tensor(elbo_list)
            kl_tensor = torch.tensor(kl_list)
            recon_tensor = torch.tensor(recon_list)

            ged_tensor = torch.tensor(ged_list)
            ncc_tensor = torch.tensor(ncc_list)

            self.avg_dice = torch.mean(dice_tensor)
            self.foreground_dice = torch.mean(dice_tensor, dim=0)[1]
            self.val_elbo = torch.mean(elbo_tensor)
            self.val_recon_loss = torch.mean(recon_tensor)
            self.val_kl_loss = torch.mean(kl_tensor)

            self.avg_ged = torch.mean(ged_tensor)
            self.avg_ncc = torch.mean(ncc_tensor)

            self.logger.info(' - Foreground dice: %.4f' % torch.mean(self.foreground_dice))
            self.logger.info(' - Mean (neg.) ELBO: %.4f' % self.val_elbo)
            self.logger.info(' - Mean GED: %.4f' % self.avg_ged)
            self.logger.info(' - Mean NCC: %.4f' % self.avg_ncc)

            if torch.mean(per_structure_dice) >= self.best_dice:
                self.best_dice = torch.mean(per_structure_dice)
                self.logger.info('New best validation Dice! (%.3f)' % self.best_dice)
                self.save_model(savename='best_dice')
            if self.val_elbo <= self.best_loss:
                self.best_loss = self.val_elbo
                self.logger.info('New best validation loss! (%.3f)' % self.best_loss)
                self.save_model(savename='best_loss')
            if self.avg_ged <= self.best_ged:
                self.best_ged = self.avg_ged
                self.logger.info('New best GED score! (%.3f)' % self.best_ged)
                self.save_model(savename='best_ged')
            if self.avg_ncc >= self.best_ncc:
                self.best_ncc = self.avg_ncc
                self.logger.info('New best NCC score! (%.3f)' % self.best_ncc)
                self.save_model(savename='best_ncc')

            self.logger.info('Validation took {} seconds'.format(time.time()-time_))

        self.net.train()