Example 1
0
            # (commented-out variant: joint optimisation of three losses /
            # three segmentation outputs in a single sess.run)
            # _, train_loss, unet, __, train_loss2, segs, ___, train_loss3,segs2 = sess.run([optimizer,
            #                                                                     loss, unet_seg, optimizer2, loss2, seg, optimizer3, loss3, seg_prior],
            #                                                                  feed_dict={images: batch,
                                                                                        # y_true: y_train_batch[step]})

            # NOTE(review): `segs`, `unet` and `train_loss` must be produced by
            # an active sess.run above this excerpt -- confirm in the full file.
            segs_t.append(segs)
            # segs_t2.append(segs2)
            u_t.append(unet)

            tl.append(train_loss)
            # cet.append(ce_t)
            # klt.append(kl_t)

        # average loss across this epoch's batches
        mean_train = utils.list_mean(tl)
        train_loss_.append(mean_train)

        # mean_ce_train = utils.list_mean(cet)
        # mean_kl_train = utils.list_mean(klt)

        # validation: runs every epoch, but the frequency could be changed
        if (epoch+1) % 1 == 0:
            for step, batch in enumerate(X_val_batch):
                # Evaluate the loss and both segmentation outputs on this
                # validation batch (no optimizer op -> no weight update).
                val_loss, segs, unet = sess.run([loss, seg, unet_seg],
                                                feed_dict={images: batch,
                                                           y_true: y_val_batch[step]})
                # (commented-out variant with the three-loss model)
                # val_loss, unet, val_loss2, segs, val_loss3, segs2 = sess.run([loss, unet_seg, loss2, seg, loss3, seg_prior],
                #                            feed_dict={images: batch,
                #                                            y_true: y_val_batch[step]})
Example 2
0
    def _do_validation(self, data):
        """Run a full validation pass: checkpoint, batch losses, per-image metrics.

        Saves a regular checkpoint, computes every loss in ``self.loss_dict``
        on one validation batch and one training batch (both in eval mode),
        then evaluates Dice, (neg.) ELBO, GED and NCC over the entire
        validation set.  Whenever a metric beats its previous best, a
        metric-specific "best" checkpoint is saved.  Finally, validation and
        training summaries are written to TensorBoard.

        Args:
            data: dataset handle exposing ``validation`` / ``train`` splits
                with ``next_batch``, ``images`` and ``labels``.  Labels are
                assumed to be laid out X x Y x num_annotators per image
                (inferred from the ``[:, :, choice]`` indexing below) --
                TODO confirm against the data loader.
        """

        # global_step was already incremented past the step just run;
        # subtract 1 so checkpoints/summaries are tagged with that step.
        global_step = self.sess.run(self.global_step) - 1

        # Regular (non-best) checkpoint at every validation call.
        checkpoint_file = os.path.join(self.log_dir, 'model.ckpt')
        self.saver.save(self.sess, checkpoint_file, global_step=global_step)

        # --- Single-batch losses on validation data (eval mode) ---
        val_x, val_s = data.validation.next_batch(self.exp_config.batch_size)
        val_losses_out = self.sess.run(list(self.loss_dict.values()),
                                       feed_dict={
                                           self.x_inp: val_x,
                                           self.s_inp: val_s,
                                           self.training_pl: False
                                       })

        # Note that val_losses_out are now sorted in the same way as loss_dict,
        tot_loss_index = list(self.loss_dict.keys()).index('total_loss')
        val_loss_tot = val_losses_out[tot_loss_index]

        # --- Single-batch losses on training data (eval mode) for comparison ---
        train_x, train_s = data.train.next_batch(self.exp_config.batch_size)
        train_losses_out = self.sess.run(list(self.loss_dict.values()),
                                         feed_dict={
                                             self.x_inp: train_x,
                                             self.s_inp: train_s,
                                             self.training_pl: False
                                         })

        logging.info('----- Step: %d ------' % global_step)
        logging.info('BATCH VALIDATION:')
        for ii, loss_name in enumerate(self.loss_dict.keys()):
            logging.info('%s | training: %f | validation: %f' %
                         (loss_name, train_losses_out[ii], val_losses_out[ii]))

        # Evaluate validation Dice:

        start_dice_val = time.time()
        num_batches = 0

        dice_list = []
        elbo_list = []
        ged_list = []
        ncc_list = []

        # One pass over every validation image (not batched).
        N = data.validation.images.shape[0]

        for ii in range(N):

            # logging.info(ii)

            # Single image with a batch dimension of 1 prepended.
            x = data.validation.images[ii, ...].reshape(
                [1] + list(self.exp_config.image_size))
            s_gt_arr = data.validation.labels[ii, ...]
            # Randomly pick one annotator's ground truth for the ELBO term.
            s = s_gt_arr[:, :,
                         np.random.choice(self.exp_config.annotator_range)]

            # Tile image/label so a single sess.run draws
            # `validation_samples` independent predictions.
            x_b = np.tile(x, [self.exp_config.validation_samples, 1, 1, 1])
            s_b = np.tile(s, [self.exp_config.validation_samples, 1, 1])

            feed_dict = {}
            feed_dict[self.training_pl] = False
            feed_dict[self.x_inp] = x_b
            feed_dict[self.s_inp] = s_b

            s_pred_sm_arr, elbo = self.sess.run(
                [self.s_out_eval_sm, self.loss_tot], feed_dict=feed_dict)

            # Average softmax over the drawn samples.
            s_pred_sm_mean_ = np.mean(s_pred_sm_arr, axis=0)

            # Hard labels per sample; ground truths reordered to
            # annotator-first for the distance metrics below.
            s_pred_arr = np.argmax(s_pred_sm_arr, axis=-1)
            s_gt_arr_r = s_gt_arr.transpose((2, 0, 1))  # num gts x X x Y

            s_gt_arr_r_sm = utils.convert_batch_to_onehot(
                s_gt_arr_r,
                self.exp_config.nlabels)  # num gts x X x Y x nlabels

            # GED over foreground labels only (label 0 excluded via
            # label_range starting at 1).
            ged = utils.generalised_energy_distance(
                s_pred_arr,
                s_gt_arr_r,
                nlabels=self.exp_config.nlabels - 1,
                label_range=range(1, self.exp_config.nlabels))

            ncc = utils.variance_ncc_dist(s_pred_sm_arr, s_gt_arr_r_sm)

            # Hard segmentation from the sample-averaged softmax.
            s_ = np.argmax(s_pred_sm_mean_, axis=-1)

            # Write losses to list
            per_lbl_dice = []
            for lbl in range(self.exp_config.nlabels):
                binary_pred = (s_ == lbl) * 1
                binary_gt = (s == lbl) * 1

                # Dice convention: both masks empty -> 1, exactly one
                # empty -> 0, otherwise the actual Dice coefficient.
                if np.sum(binary_gt) == 0 and np.sum(binary_pred) == 0:
                    per_lbl_dice.append(1)
                elif np.sum(binary_pred) > 0 and np.sum(
                        binary_gt) == 0 or np.sum(
                            binary_pred) == 0 and np.sum(binary_gt) > 0:
                    per_lbl_dice.append(0)
                else:
                    per_lbl_dice.append(dc(binary_pred, binary_gt))

            num_batches += 1

            dice_list.append(per_lbl_dice)
            elbo_list.append(elbo)
            ged_list.append(ged)
            ncc_list.append(ncc)

        # Aggregate per-image metrics over the whole validation set.
        dice_arr = np.asarray(dice_list)
        per_structure_dice = dice_arr.mean(axis=0)  # mean Dice per label

        avg_dice = np.mean(dice_arr)
        avg_elbo = utils.list_mean(elbo_list)
        avg_ged = utils.list_mean(ged_list)
        avg_ncc = utils.list_mean(ncc_list)

        logging.info('FULL VALIDATION (%d images):' % N)
        # NOTE(review): per_structure_dice averages over ALL labels including
        # label 0, yet is logged as "foreground" dice -- verify intent.
        logging.info(' - Mean foreground dice: %.4f' %
                     np.mean(per_structure_dice))
        logging.info(' - Mean (neg.) ELBO: %.4f' % avg_elbo)
        logging.info(' - Mean GED: %.4f' % avg_ged)
        logging.info(' - Mean NCC: %.4f' % avg_ncc)

        logging.info('@ Running through validation set took: %.2f secs' %
                     (time.time() - start_dice_val))

        # Best-metric tracking: Dice and NCC are higher-is-better,
        # loss (ELBO) and GED are lower-is-better.  Each gets its own
        # "best" checkpoint file/saver.
        if np.mean(per_structure_dice) >= self.best_dice:
            self.best_dice = np.mean(per_structure_dice)
            logging.info('New best validation Dice! (%.3f)' % self.best_dice)
            best_file = os.path.join(self.log_dir, 'model_best_dice.ckpt')
            self.saver_best_dice.save(self.sess,
                                      best_file,
                                      global_step=global_step)

        if avg_elbo <= self.best_loss:
            self.best_loss = avg_elbo
            logging.info('New best validation loss! (%.3f)' % self.best_loss)
            best_file = os.path.join(self.log_dir, 'model_best_loss.ckpt')
            self.saver_best_loss.save(self.sess,
                                      best_file,
                                      global_step=global_step)

        if avg_ged <= self.best_ged:
            self.best_ged = avg_ged
            logging.info('New best GED score! (%.3f)' % self.best_ged)
            best_file = os.path.join(self.log_dir, 'model_best_ged.ckpt')
            self.saver_best_ged.save(self.sess,
                                     best_file,
                                     global_step=global_step)

        if avg_ncc >= self.best_ncc:
            self.best_ncc = avg_ncc
            logging.info('New best NCC score! (%.3f)' % self.best_ncc)
            best_file = os.path.join(self.log_dir, 'model_best_ncc.ckpt')
            self.saver_best_ncc.save(self.sess,
                                     best_file,
                                     global_step=global_step)

        # Create Validation Summary feed dict
        z_prior_list = self.generate_prior_samples(x_in=val_x)
        val_summary_feed_dict = {
            i: d
            for i, d in zip(self.z_list_gen, z_prior_list)
        }  # this is for prior samples
        val_summary_feed_dict[self.x_for_gen] = val_x

        # Fill placeholders for all losses
        for loss_key, loss_val in zip(self.loss_dict.keys(), val_losses_out):
            # The detour over loss_dict.keys() is necessary because val_losses_out is sorted in the same
            # way as loss_dict. Same for the training below.
            loss_pl = self.validation_loss_pl_dict[loss_key]
            val_summary_feed_dict[loss_pl] = loss_val

        # Fill placeholders for validation Dice
        val_summary_feed_dict[self.val_tot_dice_score] = avg_dice
        val_summary_feed_dict[self.val_mean_dice_score] = np.mean(
            per_structure_dice)
        val_summary_feed_dict[self.val_elbo] = avg_elbo
        val_summary_feed_dict[self.val_ged] = avg_ged
        val_summary_feed_dict[self.val_ncc] = np.squeeze(avg_ncc)

        # One scalar summary per label's Dice.
        for ii in range(self.exp_config.nlabels):
            val_summary_feed_dict[
                self.val_lbl_dice_scores[ii]] = per_structure_dice[ii]

        val_summary_feed_dict[self.x_inp] = val_x
        val_summary_feed_dict[self.s_inp] = val_s
        val_summary_feed_dict[self.training_pl] = False

        val_summary_msg = self.sess.run(self.val_summary,
                                        feed_dict=val_summary_feed_dict)
        self.summary_writer.add_summary(val_summary_msg, global_step)

        # Create train Summary feed dict
        train_summary_feed_dict = {}
        for loss_key, loss_val in zip(self.loss_dict.keys(), train_losses_out):
            loss_pl = self.train_loss_pl_dict[loss_key]
            train_summary_feed_dict[loss_pl] = loss_val
        train_summary_feed_dict[self.training_pl] = False

        train_summary_msg = self.sess.run(self.train_summary,
                                          feed_dict=train_summary_feed_dict)
        self.summary_writer.add_summary(train_summary_msg, global_step)
Example 3
0
        # Training pass over all mini-batches of this epoch.
        for step, batch in enumerate(X_train_batch):

            # One optimisation step; also fetch the individual loss terms
            # (KL, cross-entropy) and the Dice coefficient for monitoring.
            _, train_loss, unet, sout, kl_t, ce_t, train_dc = sess.run(
                [optimizer, reg_loss, unet_seg, seg, kl, ce, dice_coef],
                feed_dict={
                    images: batch,
                    y_true: y_train_batch[step]
                })

            # Collect per-batch metrics for epoch-level averaging below.
            tl.append(train_loss)
            tdc.append(train_dc)
            cet.append(ce_t)
            klt.append(kl_t)

        # Epoch means across batches.
        mean_train = utils.list_mean(tl)
        mean_dc_train = utils.list_mean(tdc)

        train_loss_.append(mean_train)
        train_dc_.append(mean_dc_train)

        mean_ce_train = utils.list_mean(cet)
        mean_kl_train = utils.list_mean(klt)

        # Validation every epoch (`% 1` keeps the frequency easy to change).
        if (epoch + 1) % 1 == 0:
            for step, batch in enumerate(X_val_batch):
                # Evaluation only: no optimizer op, so no weight update.
                val_loss, kl_v, ce_v, unet, sout, val_dc = sess.run(
                    [reg_loss, kl, ce, unet_seg, seg, dice_coef],
                    feed_dict={
                        images: batch,
                        y_true: y_val_batch[step]