Example #1
    def train(self):
        self.sess.run(tf.local_variables_initializer())
        self.best_validation_loss = 0.0928
        self.best_mean_IOU = 0.56
        if self.conf.reload_step > 0:
            self.reload(self.conf.reload_step)
            print('----> Continue Training from step #{}'.format(
                self.conf.reload_step))
        else:
            print('----> Start Training')
        if self.conf.data == 'ct':
            from DataLoaders.Data_Loader_2D import DataLoader
        elif self.conf.data == 'camvid':
            from DataLoaders.CamVid_loader import DataLoader
        else:
            raise ValueError('wrong data name: {}'.format(self.conf.data))
        self.data_reader = DataLoader(self.conf)

        self.numValid = self.data_reader.count_num_samples(mode='valid')
        self.num_val_batch = int(self.numValid / self.conf.val_batch_size)
        for train_step in range(self.conf.reload_step, self.conf.reload_step +
                                self.conf.max_step + 1):
            x_batch, y_batch = self.data_reader.next_batch(mode='train')
            feed_dict = {
                self.inputs_pl: x_batch,
                self.labels_pl: y_batch,
                self.is_training_pl: True,
                self.with_dropout_pl: True,
                self.keep_prob_pl: self.conf.keep_prob
            }
            if train_step % self.conf.SUMMARY_FREQ == 0:
                _, _, _, summary = self.sess.run(
                    [self.train_op, self.mean_loss_op, self.mean_accuracy_op,
                     self.merged_summary],
                    feed_dict=feed_dict)
                loss, acc = self.sess.run([self.mean_loss, self.mean_accuracy])
                print('step: {0:<6}, train_loss={1:.4f}, train_acc={2:.01%}'.format(
                    train_step, loss, acc))
                self.save_summary(summary, train_step, is_train=True)
            else:
                self.sess.run(
                    [self.train_op, self.mean_loss_op, self.mean_accuracy_op],
                    feed_dict=feed_dict)
            if train_step % self.conf.VAL_FREQ == 0:
                print('-' * 25 + 'Validation' + '-' * 25)
                self.normal_evaluate(dataset='valid', train_step=train_step)
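
All of these examples pair tf.local_variables_initializer() with tf.metrics streaming ops: tf.metrics.mean keeps its running total and count in local (non-trainable) variables, so re-running the initializer resets the averages between training and evaluation phases. A minimal standalone sketch of that pattern (TF 1.x; the tensors are illustrative):

    import tensorflow as tf

    values = tf.placeholder(tf.float32, [None])
    # tf.metrics.mean returns (current_mean, update_op); its accumulators
    # (total and count) live in the LOCAL_VARIABLES collection
    mean_val, update_op = tf.metrics.mean(values)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())  # create/zero the accumulators
        sess.run(update_op, feed_dict={values: [1., 3.]})
        sess.run(update_op, feed_dict={values: [5.]})
        print(sess.run(mean_val))  # 3.0 -- streaming mean over both batches
        sess.run(tf.local_variables_initializer())  # reset before the next phase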
Example #2
    def test(self, step_num):
        self.sess.run(tf.local_variables_initializer())

        print('loading the model.......')
        self.reload(step_num)
        if self.conf.data == 'ct':
            from DataLoaders.Data_Loader_2D import DataLoader
        elif self.conf.data == 'camvid':
            from DataLoaders.CamVid_loader import DataLoader
        else:
            raise ValueError('wrong data name: {}'.format(self.conf.data))

        self.data_reader = DataLoader(self.conf)
        self.numTest = self.data_reader.count_num_samples(mode='test')
        self.num_test_batch = int(self.numTest / self.conf.val_batch_size)

        print('-' * 25 + 'Test' + '-' * 25)
        if not self.conf.bayes:
            self.normal_evaluate(dataset='test', train_step=step_num)
        else:
            self.MC_evaluate(dataset='test', train_step=step_num)
Example #3

import os
import numpy as np
import tensorflow as tf
from tqdm import tqdm

# repo-local helpers assumed importable in this project: cross_entropy,
# weighted_cross_entropy, dice_coeff, get_hist, compute_iou, mutual_info,
# get_uncertainty_precision, plot_save_preds_2d, and the CityscapesDataset loader

class BaseModel(object):
    def __init__(self, sess, conf):
        self.sess = sess
        self.conf = conf
        self.input_shape = [None, None, None, self.conf.channel]  # NHWC; batch and spatial dims stay dynamic
        self.output_shape = [None, None, None]  # per-pixel integer class labels
        self.create_placeholders()

    def create_placeholders(self):
        with tf.name_scope('Input'):
            self.inputs_pl = tf.placeholder(tf.float32,
                                            self.input_shape,
                                            name='input')
            self.labels_pl = tf.placeholder(tf.int64,
                                            self.output_shape,
                                            name='annotation')
            self.is_training_pl = tf.placeholder(tf.bool, name="is_training")
            self.with_dropout_pl = tf.placeholder(tf.bool, name="with_dropout")
            self.keep_prob_pl = tf.placeholder(tf.float32, name='keep_prob')

    def loss_func(self):
        with tf.name_scope('Loss'):
            self.y_prob = tf.nn.softmax(self.logits, axis=-1)
            y_one_hot = tf.one_hot(self.labels_pl,
                                   depth=self.conf.num_cls,
                                   axis=3,
                                   name='y_one_hot')
            if self.conf.weighted_loss:
                loss = weighted_cross_entropy(y_one_hot,
                                              self.logits,
                                              self.conf.num_cls,
                                              data=self.conf.data)
            else:
                if self.conf.loss_type == 'cross-entropy':
                    with tf.name_scope('cross_entropy'):
                        loss = cross_entropy(y_one_hot, self.logits,
                                             self.conf.num_cls)
                elif self.conf.loss_type == 'dice':
                    with tf.name_scope('dice_coefficient'):
                        loss = dice_coeff(y_one_hot, self.logits)
                else:
                    raise ValueError(
                        'unknown loss_type: {}'.format(self.conf.loss_type))
            with tf.name_scope('total'):
                if self.conf.use_reg:
                    with tf.name_scope('L2_loss'):
                        l2_loss = tf.reduce_sum(self.conf.lmbda * tf.stack([
                            tf.nn.l2_loss(v)
                            for v in tf.get_collection('weights')
                        ]))
                        self.total_loss = loss + l2_loss
                else:
                    self.total_loss = loss
                self.mean_loss, self.mean_loss_op = tf.metrics.mean(
                    self.total_loss)

    def accuracy_func(self):
        with tf.name_scope('Accuracy'):
            self.y_pred = tf.argmax(self.logits, axis=3, name='decode_pred')
            correct_prediction = tf.equal(self.labels_pl,
                                          self.y_pred,
                                          name='correct_pred')
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                                      name='accuracy_op')
            self.mean_accuracy, self.mean_accuracy_op = tf.metrics.mean(
                accuracy)

    def configure_network(self):
        self.loss_func()
        self.accuracy_func()
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        learning_rate = tf.train.exponential_decay(self.conf.init_lr,
                                                   global_step,
                                                   decay_steps=2000,
                                                   decay_rate=0.99,
                                                   staircase=True)
        self.learning_rate = tf.maximum(learning_rate, self.conf.lr_min)
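        # staircase schedule: lr(step) = max(init_lr * 0.99 ** (step // 2000), lr_min),
        # i.e. the rate is cut by 1% every 2000 steps until it reaches the floor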
        with tf.name_scope('Optimizer'):
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_op = optimizer.minimize(self.total_loss,
                                                   global_step=global_step)
        self.sess.run(tf.global_variables_initializer())
        trainable_vars = tf.trainable_variables()
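        # note: only trainable variables are checkpointed; non-trainable state
        # such as batch-norm moving averages is not saved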
        self.saver = tf.train.Saver(var_list=trainable_vars, max_to_keep=1000)
        self.train_writer = tf.summary.FileWriter(
            self.conf.logdir + self.conf.run_name + '/train/', self.sess.graph)
        self.valid_writer = tf.summary.FileWriter(self.conf.logdir +
                                                  self.conf.run_name +
                                                  '/valid/')
        self.configure_summary()
        print('*' * 50)
        print('Total number of trainable parameters: {}'.format(
            np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))
        print('*' * 50)

    def configure_summary(self):
        summary_list = [
            tf.summary.scalar('learning_rate', self.learning_rate),
            tf.summary.scalar('loss', self.mean_loss),
            tf.summary.scalar('accuracy', self.mean_accuracy),
            tf.summary.image('train/original_image',
                             self.inputs_pl,
                             max_outputs=5),
            tf.summary.image('train/prediction_mask',
                             tf.cast(tf.expand_dims(self.y_pred, -1),
                                     tf.float32),
                             max_outputs=5),
            tf.summary.image('train/original_mask',
                             tf.cast(tf.expand_dims(self.labels_pl, -1),
                                     tf.float32),
                             max_outputs=5)
        ]
        self.merged_summary = tf.summary.merge(summary_list)

    def save_summary(self, summary, step, is_train):
        # print('----> Summarizing at step {}'.format(step))
        if is_train:
            self.train_writer.add_summary(summary, step)
        else:
            self.valid_writer.add_summary(summary, step)
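        # re-initializing the local variables clears the tf.metrics accumulators,
        # so each logged value averages only the interval since the last summary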
        self.sess.run(tf.local_variables_initializer())

    def train(self):
        self.sess.run(tf.local_variables_initializer())
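        # best-so-far thresholds; normal_evaluate() only saves a checkpoint
        # when validation loss or mean IOU improves on these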
        self.best_validation_loss = 0.446
        self.best_mean_IOU = 0.16
        if self.conf.reload_step > 0:
            self.reload(self.conf.reload_step)
            print('----> Continue Training from step #{}'.format(
                self.conf.reload_step))
        else:
            print('----> Start Training')

        if self.conf.data == 'cityscapes':
            self.trainiter = CityscapesDataset(which_set='train',
                                               batch_size=self.conf.batch_size,
                                               seq_per_subset=0,
                                               seq_length=0,
                                               data_augm_kwargs={
                                                   'crop_size':
                                                   self.conf.crop_size,
                                                   'horizontal_flip': 0.5
                                               },
                                               return_one_hot=False,
                                               return_01c=True,
                                               use_threads=True,
                                               return_list=True,
                                               nthreads=8)
        else:
            raise ValueError('wrong data name: {}'.format(self.conf.data))

        for train_step in range(self.conf.reload_step, self.conf.reload_step +
                                self.conf.max_step + 1):
            train_data = self.trainiter.__next__()
            feed_dict = {
                self.inputs_pl: train_data[0],
                self.labels_pl: train_data[1],
                self.is_training_pl: True,
                self.with_dropout_pl: True,
                self.keep_prob_pl: self.conf.keep_prob
            }
            if train_step % self.conf.SUMMARY_FREQ == 0:
                _, _, _, summary = self.sess.run(
                    [self.train_op, self.mean_loss_op, self.mean_accuracy_op,
                     self.merged_summary],
                    feed_dict=feed_dict)
                loss, acc = self.sess.run([self.mean_loss, self.mean_accuracy])
                print('step: {0:<6}, train_loss={1:.4f}, train_acc={2:.01%}'.format(
                    train_step, loss, acc))
                self.save_summary(summary, train_step, is_train=True)
            else:
                self.sess.run(
                    [self.train_op, self.mean_loss_op, self.mean_accuracy_op],
                    feed_dict=feed_dict)
            if train_step % self.conf.VAL_FREQ == 0 and train_step:
                print('-' * 25 + 'Validation' + '-' * 25)
                self.normal_evaluate(dataset='valid', train_step=train_step)

    def test(self, step_num):
        self.sess.run(tf.local_variables_initializer())

        print('loading the model.......')
        self.reload(step_num)
        if self.conf.data == 'ct':
            from DataLoaders.Data_Loader_2D import DataLoader
        elif self.conf.data == 'camvid':
            from DataLoaders.CamVid_loader import DataLoader
        else:
            raise ValueError('wrong data name: {}'.format(self.conf.data))

        self.data_reader = DataLoader(self.conf)
        self.numTest = self.data_reader.count_num_samples(mode='test')
        self.num_test_batch = int(self.numTest / self.conf.val_batch_size)

        print('-' * 25 + 'Test' + '-' * 25)
        if not self.conf.bayes:
            self.normal_evaluate(dataset='test', train_step=step_num)
        else:
            self.MC_evaluate(dataset='test', train_step=step_num)
        # self.visualize(num_samples=20, train_step=step_num, mode='test')

    def save(self, step):
        print('----> Saving the model at step #{0}'.format(step))
        checkpoint_path = os.path.join(self.conf.modeldir + self.conf.run_name,
                                       self.conf.model_name)
        self.saver.save(self.sess, checkpoint_path, global_step=step)

    def reload(self, step):
        checkpoint_path = os.path.join(self.conf.modeldir + self.conf.run_name,
                                       self.conf.model_name)
        model_path = checkpoint_path + '-' + str(step)
        if not os.path.exists(model_path + '.meta'):
            print('----> No such checkpoint found', model_path)
            return
        print('----> Restoring the model...')
        self.saver.restore(self.sess, model_path)
        print('----> Model successfully restored')

    def normal_evaluate(self, dataset='valid', train_step=None):
        valiter = CityscapesDataset(which_set='val',
                                    batch_size=self.conf.val_batch_size,
                                    seq_per_subset=0,
                                    seq_length=0,
                                    return_one_hot=False,
                                    return_01c=True,
                                    use_threads=True,
                                    return_list=True,
                                    nthreads=8,
                                    infinite_iterator=False)
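        # 500 matches the number of images in the Cityscapes validation set
        # (assuming the loader yields one image per batch)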
        num_batch = self.num_test_batch if dataset == 'test' else 500
        self.sess.run(tf.local_variables_initializer())
        hist = np.zeros((self.conf.num_cls, self.conf.num_cls))
        plot_inputs = np.zeros(
            (0, self.conf.height, self.conf.width, self.conf.channel))
        plot_mask = np.zeros((0, self.conf.height, self.conf.width))
        plot_mask_pred = np.zeros((0, self.conf.height, self.conf.width))
        for step in range(num_batch):
            valid_data = valiter.__next__()
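            # dropout is switched off via with_dropout/keep_prob below, while
            # is_training is deliberately left True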
            feed_dict = {
                self.inputs_pl: valid_data[0],
                self.labels_pl: valid_data[1],
                self.is_training_pl: True,
                self.with_dropout_pl: False,
                self.keep_prob_pl: 1
            }
            self.sess.run([self.mean_loss_op, self.mean_accuracy_op],
                          feed_dict=feed_dict)
            mask_pred = self.sess.run(self.y_pred, feed_dict=feed_dict)
            hist += get_hist(mask_pred.flatten(),
                             valid_data[1].flatten(),
                             num_cls=self.conf.num_cls)
            if plot_inputs.shape[0] < 20:  # keep a few slices to plot and save
                # idx = np.random.randint(self.conf.batch_size)
                plot_inputs = np.concatenate(
                    (plot_inputs,
                     valid_data[0].reshape(-1, self.conf.height, self.conf.width,
                                           self.conf.channel)),
                    axis=0)
                plot_mask = np.concatenate(
                    (plot_mask,
                     valid_data[1].reshape(-1, self.conf.height, self.conf.width)),
                    axis=0)
                plot_mask_pred = np.concatenate(
                    (plot_mask_pred,
                     mask_pred.reshape(-1, self.conf.height, self.conf.width)),
                    axis=0)
                # self.visualize(plot_inputs, plot_mask, plot_mask_pred, train_step=train_step, mode='valid')

        IOU, ACC = compute_iou(hist)
        mean_IOU = np.mean(IOU)
        loss, acc = self.sess.run([self.mean_loss, self.mean_accuracy])
        if dataset == "valid":  # save the summaries and improved model in validation mode
            summary_valid = self.sess.run(self.merged_summary,
                                          feed_dict=feed_dict)
            self.save_summary(summary_valid, train_step, is_train=False)
            if loss < self.best_validation_loss:
                self.best_validation_loss = loss
                if mean_IOU > self.best_mean_IOU:
                    self.best_mean_IOU = mean_IOU
                    print('>>>>>>>> Both validation loss and mean IOU improved; saving the model......')
                else:
                    print('>>>>>>>> Validation loss improved; saving the model......')
                self.save(train_step)
            elif mean_IOU > self.best_mean_IOU:
                self.best_mean_IOU = mean_IOU
                print('>>>>>>>> Mean IOU improved; saving the model......')
                self.save(train_step)

        print('****** IoU & ACC ******')
        print('Mean IoU = {0:.01%}, valid_loss = {1:.4f}'.format(
            mean_IOU, loss))
        for ii in range(self.conf.num_cls):
            print('     - {0:<15}: IoU={1:<5.01%}, ACC={2:<5.01%}'.format(
                self.conf.label_name[ii], IOU[ii], ACC[ii]))
        print('-' * 20)
        self.visualize(plot_inputs,
                       plot_mask,
                       plot_mask_pred,
                       train_step=train_step,
                       mode='valid')

    def MC_evaluate(self, dataset='valid', train_step=None):
        num_batch = self.num_test_batch if dataset == 'test' else self.num_val_batch
        hist = np.zeros((self.conf.num_cls, self.conf.num_cls))
        self.sess.run(tf.local_variables_initializer())
        all_inputs = np.zeros(
            (0, self.conf.height, self.conf.width, self.conf.channel))
        all_mask = np.zeros((0, self.conf.height, self.conf.width))
        all_pred = np.zeros((0, self.conf.height, self.conf.width))
        all_var = np.zeros((0, self.conf.height, self.conf.width))
        cls_uncertainty = np.zeros(
            (0, self.conf.height, self.conf.width, self.conf.num_cls))
        for step in tqdm(range(num_batch)):
            start = self.conf.val_batch_size * step
            end = self.conf.val_batch_size * (step + 1)
            data_x, data_y = self.data_reader.next_batch(start=start,
                                                         end=end,
                                                         mode=dataset)
            mask_pred_mc = [
                np.zeros((self.conf.val_batch_size, self.conf.height,
                          self.conf.width))
                for _ in range(self.conf.monte_carlo_simulations)
            ]
            mask_prob_mc = [
                np.zeros((self.conf.val_batch_size, self.conf.height,
                          self.conf.width, self.conf.num_cls))
                for _ in range(self.conf.monte_carlo_simulations)
            ]
            feed_dict = {
                self.inputs_pl: data_x,
                self.labels_pl: data_y,
                self.is_training_pl: True,
                self.with_dropout_pl: True,
                self.keep_prob_pl: self.conf.keep_prob
            }
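            # Monte Carlo dropout: dropout stays active at test time, so each of
            # the monte_carlo_simulations forward passes samples a different mask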
            for mc_iter in range(self.conf.monte_carlo_simulations):
                inputs, mask, mask_prob, mask_pred = self.sess.run(
                    [self.inputs_pl, self.labels_pl, self.y_prob, self.y_pred],
                    feed_dict=feed_dict)
                mask_prob_mc[mc_iter] = mask_prob
                mask_pred_mc[mc_iter] = mask_pred

            prob_mean = np.nanmean(mask_prob_mc, axis=0)
            prob_variance = np.var(mask_prob_mc, axis=0)
            pred = np.argmax(prob_mean, axis=-1)
            # var_one = np.nanmean(prob_variance, axis=-1)
            # var_one = var_calculate_2d(pred, prob_variance)
            # var_one = predictive_entropy(prob_mean)
            var_one = mutual_info(prob_mean, mask_prob_mc)
            hist += get_hist(pred.flatten(),
                             mask.flatten(),
                             num_cls=self.conf.num_cls)

            # if all_inputs.shape[0] < 6:
            # ii = np.random.randint(self.conf.val_batch_size)
            # ii = 1
            all_inputs = np.concatenate(
                (all_inputs,
                 inputs.reshape(-1, self.conf.height, self.conf.width,
                                self.conf.channel)),
                axis=0)
            all_mask = np.concatenate(
                (all_mask, mask.reshape(-1, self.conf.height, self.conf.width)),
                axis=0)
            all_pred = np.concatenate(
                (all_pred, pred.reshape(-1, self.conf.height, self.conf.width)),
                axis=0)
            all_var = np.concatenate(
                (all_var, var_one.reshape(-1, self.conf.height, self.conf.width)),
                axis=0)
            cls_uncertainty = np.concatenate(
                (cls_uncertainty,
                 prob_variance.reshape(-1, self.conf.height, self.conf.width,
                                       self.conf.num_cls)),
                axis=0)
        self.visualize(all_inputs,
                       all_mask,
                       all_pred,
                       all_var,
                       cls_uncertainty,
                       train_step=train_step,
                       mode='test')
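        # dump the raw MC outputs (inputs, labels, predictions, uncertainty maps)
        # to HDF5 for offline analysis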
        import h5py
        h5f = h5py.File(self.conf.run_name + '_MI.h5', 'w')
        h5f.create_dataset('x', data=all_inputs)
        h5f.create_dataset('y', data=all_mask)
        h5f.create_dataset('y_pred', data=all_pred)
        h5f.create_dataset('y_var', data=all_var)
        h5f.create_dataset('cls_uncertainty', data=cls_uncertainty)
        h5f.close()

        # h5f = h5py.File(self.conf.run_name + '_bayes.h5', 'r')
        # all_mask = h5f['y'][:]
        # all_pred = h5f['y_pred'][:]
        # all_var = h5f['y_var'][:]
        # h5f.close()
        uncertainty_measure = get_uncertainty_precision(
            all_mask, all_pred, all_var)

        # break
        IOU, ACC = compute_iou(hist)
        mean_IOU = np.mean(IOU)
        print('****** IoU & ACC ******')
        print('Uncertainty Quality Measure = {}'.format(uncertainty_measure))
        print('Mean IoU = {0:.01%}'.format(mean_IOU))
        for ii in range(self.conf.num_cls):
            print('     - {0} class: IoU={1:.01%}, ACC={2:.01%}'.format(
                self.conf.label_name[ii], IOU[ii], ACC[ii]))
        print('-' * 20)

    def visualize(self,
                  x,
                  y,
                  y_pred,
                  var=None,
                  cls_uncertainty=None,
                  train_step=None,
                  mode='valid'):
        # all of shape (#images, 512, 512)
        if mode == 'valid':
            dest_path = os.path.join(self.conf.imagedir + self.conf.run_name,
                                     str(train_step))
        elif mode == "test":
            dest_path = os.path.join(self.conf.imagedir + self.conf.run_name,
                                     str(train_step) + '_test_MI')

        print('saving sample prediction images....... ')
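        # the cls_uncertainty argument is deliberately discarded below, so the
        # per-class uncertainty maps are never plotted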
        cls_uncertainty = None
        if not self.conf.bayes or mode == 'valid':
            # run it either in validation mode or when non-bayesian network
            plot_save_preds_2d(x,
                               y,
                               y_pred,
                               path=dest_path,
                               label_names=np.array(self.conf.label_name))
        else:
            if cls_uncertainty is None:
                plot_save_preds_2d(x,
                                   y,
                                   y_pred,
                                   var,
                                   path=dest_path,
                                   label_names=np.array(self.conf.label_name))
            else:
                plot_save_preds_2d(x,
                                   y,
                                   y_pred,
                                   var,
                                   cls_uncertainty,
                                   path=dest_path,
                                   label_names=np.array(self.conf.label_name))
        print('Images saved in {}'.format(dest_path))
        print('-' * 20)
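
BaseModel never builds a network itself: a subclass is expected to define self.logits and then call configure_network(). A minimal driver sketch under that assumption (the Namespace values are illustrative, and the repo's loss helpers such as cross_entropy must be importable for loss_func()):

    from argparse import Namespace
    import tensorflow as tf

    class ToyModel(BaseModel):
        def __init__(self, sess, conf):
            super(ToyModel, self).__init__(sess, conf)
            # stand-in network: a single 1x1 convolution producing per-pixel logits
            self.logits = tf.layers.conv2d(self.inputs_pl, conf.num_cls, 1)
            self.configure_network()

    # illustrative config values; the real project reads these from its flags
    conf = Namespace(channel=3, num_cls=19, weighted_loss=False,
                     loss_type='cross-entropy', use_reg=False,
                     init_lr=1e-3, lr_min=1e-5,
                     logdir='./logs/', run_name='demo')
    with tf.Session() as sess:
        model = ToyModel(sess, conf)  # builds loss, metrics, optimizer, summaries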