Example #1
    def train(self, max_steps=30001, batch_size=3):
        # To train the Bayesian variant, FLAG_OPT should be SGD;
        # to train the plain SegNet, FLAG_OPT should be Adam.

        image_filename, label_filename = get_filename_list(self.train_file, self.config)
        val_image_filename, val_label_filename = get_filename_list(self.val_file, self.config)

        with self.graph.as_default():
            if self.images_tr is None:
                self.images_tr, self.labels_tr = dataset_inputs(image_filename, label_filename, batch_size, self.config)
                self.images_val, self.labels_val = dataset_inputs(val_image_filename, val_label_filename, batch_size,
                                                                  self.config)

            loss, accuracy, prediction = cal_loss(logits=self.logits, labels=self.labels_pl,
                                                  number_class=self.num_classes)
            train, global_step = train_op(total_loss=loss, opt=self.opt)

            summary_op = tf.summary.merge_all()

            with self.sess.as_default():
                self.sess.run(tf.local_variables_initializer())
                self.sess.run(tf.global_variables_initializer())

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord)
                # Basic reference for queue runners:
                # https://www.tensorflow.org/versions/r0.12/how_tos/threading_and_queues
                train_writer = tf.summary.FileWriter(self.tb_logs, self.sess.graph)
                for step in range(max_steps):
                    print("OK")
                    image_batch, label_batch = self.sess.run([self.images_tr, self.labels_tr])
                    feed_dict = {self.inputs_pl: image_batch,
                                 self.labels_pl: label_batch,
                                 self.is_training_pl: True,
                                 self.keep_prob_pl: 0.5,
                                 self.with_dropout_pl: True,
                                 self.batch_size_pl: batch_size}

                    _, _loss, _accuracy, summary = self.sess.run([train, loss, accuracy, summary_op],
                                                                 feed_dict=feed_dict)
                    self.train_loss.append(_loss)
                    self.train_accuracy.append(_accuracy)
                    print("Iteration {}: Train Loss{:6.3f}, Train Accu {:6.3f}".format(step, self.train_loss[-1],
                                                                                       self.train_accuracy[-1]))

                    if step % 100 == 0:
                        conv_classifier = self.sess.run(self.logits, feed_dict=feed_dict)
                        print('per-class accuracy from logits at training time:',
                              per_class_acc(conv_classifier, label_batch, self.num_classes))
                        # per_class_acc is a helper from utils
                        train_writer.add_summary(summary, step)

                    if step % 1000 == 0:
                        print("start validating.......")
                        _val_loss = []
                        _val_acc = []
                        hist = np.zeros((self.num_classes, self.num_classes))
                        for test_step in range(20):
                            fetches_valid = [loss, accuracy, self.logits]
                            image_batch_val, label_batch_val = self.sess.run([self.images_val, self.labels_val])
                            feed_dict_valid = {self.inputs_pl: image_batch_val,
                                               self.labels_pl: label_batch_val,
                                               self.is_training_pl: True,
                                               self.keep_prob_pl: 1.0,
                                               self.with_dropout_pl: False,
                                               self.batch_size_pl: batch_size}
                            # We still feed mini-batches, so batch norm keeps phase_train=True;
                            # since the train op is not run here, the weights are not updated.
                            _loss, _acc, _val_pred = self.sess.run(fetches_valid, feed_dict_valid)
                            _val_loss.append(_loss)
                            _val_acc.append(_acc)
                            hist += get_hist(_val_pred, label_batch_val)

                        print_hist_summary(hist)

                        self.val_loss.append(np.mean(_val_loss))
                        self.val_acc.append(np.mean(_val_acc))

                        print(
                            "Iteration {}: Train Loss {:6.3f}, Train Acc {:6.3f}, Val Loss {:6.3f}, Val Acc {:6.3f}".format(
                                step, self.train_loss[-1], self.train_accuracy[-1], self.val_loss[-1],
                                self.val_acc[-1]))

                coord.request_stop()
                coord.join(threads)
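
Both the training and validation loops above rely on `get_hist` and `per_class_acc`, which are imported from a utils module not shown here. A minimal NumPy sketch of what such helpers typically compute, assuming logits of shape [N, H, W, C] and integer labels (an illustrative reconstruction, not the original utils code):

import numpy as np

def get_hist(logits, labels):
    # Accumulate a num_classes x num_classes confusion matrix where
    # hist[i, j] counts pixels of true class i predicted as class j.
    num_classes = logits.shape[-1]
    pred = np.argmax(logits, axis=-1).flatten()
    gt = labels.flatten()
    mask = (gt >= 0) & (gt < num_classes)
    return np.bincount(num_classes * gt[mask].astype(int) + pred[mask],
                       minlength=num_classes ** 2).reshape(num_classes, num_classes)

def per_class_acc(logits, labels, num_classes):
    # Per-class accuracy: correct pixels / ground-truth pixels of that class.
    hist = get_hist(logits, labels)
    return np.diag(hist) / np.maximum(hist.sum(axis=1), 1)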
Example #2
    def test(self):
        image_filename, label_filename = get_filename_list(self.test_file, self.config)

        with self.graph.as_default():
            with self.sess as sess:  # note: exiting this block closes self.sess
                loss, accuracy, prediction = normal_loss(self.logits, self.labels_pl, self.num_classes)
                prob = tf.nn.softmax(self.logits, axis=-1)  # `axis` replaces the deprecated `dim` argument
                prob = tf.reshape(prob, [self.input_h, self.input_w, self.num_classes])

                images, labels = get_all_test_data(image_filename, label_filename)

                NUM_SAMPLE = []
                for i in range(30):
                    NUM_SAMPLE.append(2 * i + 1)

                acc_final = []
                iu_final = []
                iu_mean_final = []
                # uncomment the line below to evaluate just two sample counts.
                # NUM_SAMPLE = [1, 30]
                NUM_SAMPLE = [1]
                for num_sample_generate in NUM_SAMPLE:

                    loss_tot = []
                    acc_tot = []
                    pred_tot = []
                    var_tot = []
                    hist = np.zeros((self.num_classes, self.num_classes))
                    step = 0
                    for image_batch, label_batch in zip(images, labels):
                        image_batch = np.reshape(image_batch, [1, self.input_h, self.input_w, self.input_c])
                        label_batch = np.reshape(label_batch, [1, self.input_h, self.input_w, 1])
                        # comment out the branch below to apply dropout for every sample count
                        if num_sample_generate == 1:
                            feed_dict = {self.inputs_pl: image_batch, self.labels_pl: label_batch,
                                         self.is_training_pl: False,
                                         self.keep_prob_pl: 0.5, self.with_dropout_pl: False,
                                         self.batch_size_pl: 1}
                        else:
                            feed_dict = {self.inputs_pl: image_batch, self.labels_pl: label_batch,
                                         self.is_training_pl: False,
                                         self.keep_prob_pl: 0.5, self.with_dropout_pl: True,
                                         self.batch_size_pl: 1}
                        # uncomment the line below to run dropout for all the samples
                        # feed_dict = {test_data_tensor: image_batch, test_label_tensor: label_batch, phase_train: False, keep_prob: 0.5, phase_train_dropout: True}
                        fetches = [loss, accuracy, self.logits, prediction]
                        if not self.bayes:
                            loss_per, acc_per, logit, pred = sess.run(fetches=fetches, feed_dict=feed_dict)
                            var_one = []
                        else:
                            logit_iter_tot = []
                            loss_iter_tot = []
                            acc_iter_tot = []
                            prob_iter_tot = []
                            logit_iter_temp = []
                            for iter_step in range(num_sample_generate):
                                loss_iter_step, acc_iter_step, logit_iter_step, prob_iter_step = sess.run(
                                    fetches=[loss, accuracy, self.logits, prob], feed_dict=feed_dict)
                                loss_iter_tot.append(loss_iter_step)
                                acc_iter_tot.append(acc_iter_step)
                                logit_iter_tot.append(logit_iter_step)
                                prob_iter_tot.append(prob_iter_step)
                                logit_iter_temp.append(
                                    np.reshape(logit_iter_step, [self.input_h, self.input_w, self.num_classes]))

                            loss_per = np.nanmean(loss_iter_tot)
                            acc_per = np.nanmean(acc_iter_tot)
                            logit = np.nanmean(logit_iter_tot, axis=0)
                            print(np.shape(prob_iter_tot))

                            prob_mean = np.nanmean(prob_iter_tot, axis=0)
                            prob_variance = np.var(prob_iter_tot, axis=0)
                            logit_variance = np.var(logit_iter_temp, axis=0)

                            # Tau (the precision term) is not included here.
                            pred = np.reshape(np.argmax(prob_mean, axis=-1), [-1])  # pred is the predicted label

                            var_sep = []  # variance of the probability assigned to label k, for each pixel
                            length_cur = 0  # number of pixels processed so far in this image
                            for row in np.reshape(prob_variance, [self.input_h * self.input_w, self.num_classes]):
                                temp = row[pred[length_cur]]
                                length_cur += 1
                                var_sep.append(temp)
                            # var_one is the variance associated with the predicted ("optimal") label per pixel
                            var_one = np.reshape(var_sep, [self.input_h, self.input_w])
                            pred = np.reshape(pred, [self.input_h, self.input_w])

                        loss_tot.append(loss_per)
                        acc_tot.append(acc_per)
                        pred_tot.append(pred)
                        var_tot.append(var_one)
                        print("Image Index {}: TEST Loss{:6.3f}, TEST Accu {:6.3f}".format(step, loss_tot[-1], acc_tot[-1]))
                        step = step + 1
                        per_class_acc(logit, label_batch, self.num_classes)
                        hist += get_hist(logit, label_batch)

                    acc_tot = np.diag(hist).sum() / hist.sum()
                    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))

                    print("Total Accuracy for test image: ", acc_tot)
                    print("Total MoI for test images: ", iu)
                    print("mean MoI for test images: ", np.nanmean(iu))

                    acc_final.append(acc_tot)
                    iu_final.append(iu)
                    iu_mean_final.append(np.nanmean(iu))

            # Note: prob_variance, logit_variance, and var_tot are only meaningful
            # when self.bayes is True; they come from the last processed image.
            return acc_final, iu_final, iu_mean_final, prob_variance, logit_variance, pred_tot, var_tot
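
The Bayesian branch above implements Monte Carlo dropout: the same input is passed through the network num_sample_generate times with dropout kept active, the softmax outputs are averaged, and the per-pixel variance of the winning class serves as an uncertainty map. A self-contained NumPy sketch of that aggregation, assuming a hypothetical `run_once()` callable that returns one stochastic softmax map of shape [H, W, C]:

import numpy as np

def mc_dropout_predict(run_once, num_samples):
    # run_once() is assumed to perform one stochastic forward pass with
    # dropout active and return a softmax map of shape [H, W, C].
    probs = np.stack([run_once() for _ in range(num_samples)], axis=0)
    prob_mean = probs.mean(axis=0)        # predictive mean, [H, W, C]
    prob_var = probs.var(axis=0)          # per-class predictive variance
    pred = np.argmax(prob_mean, axis=-1)  # most likely label per pixel
    # Variance of the probability assigned to the chosen label, per pixel.
    var_one = np.take_along_axis(prob_var, pred[..., None], axis=-1).squeeze(-1)
    return pred, var_one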
Example #3
    def test(self):
        image_filename, label_filename = get_filename_list(self.test_file)

        with tf.Session() as sess:
            # Restore saved session
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.runtime_dir))

            loss, accuracy, prediction = normal_loss(self.logits,
                                                     self.labels_pl,
                                                     self.n_classes)

            images, labels = get_all_test_data(image_filename, label_filename)

            NUM_SAMPLE = []
            for i in range(30):
                NUM_SAMPLE.append(2 * i + 1)

            acc_final = []
            iu_final = []
            iu_mean_final = []
            # uncomment the line below to evaluate just two sample counts.
            # NUM_SAMPLE = [1, 30]
            NUM_SAMPLE = [1]
            for num_sample_generate in NUM_SAMPLE:

                loss_tot = []
                acc_tot = []

                hist = np.zeros((self.n_classes, self.n_classes))
                step = 0
                for image_batch, label_batch in zip(images, labels):
                    image_batch = np.reshape(
                        image_batch,
                        [1, self.input_h, self.input_w, self.input_c])
                    label_batch = np.reshape(
                        label_batch, [1, self.input_h, self.input_w, 1])
                    # comment out the branch below to apply dropout for every sample count
                    if num_sample_generate == 1:
                        feed_dict = {
                            self.inputs_pl: image_batch,
                            self.labels_pl: label_batch,
                            self.is_training_pl: False,
                            self.keep_prob_pl: 0.5,
                            self.with_dropout_pl: False,
                            self.batch_size_pl: 1
                        }
                    else:
                        feed_dict = {
                            self.inputs_pl: image_batch,
                            self.labels_pl: label_batch,
                            self.is_training_pl: False,
                            self.keep_prob_pl: 0.5,
                            self.with_dropout_pl: True,
                            self.batch_size_pl: 1
                        }

                    loss_per, acc_per, logit, pred = sess.run(
                        [loss, accuracy, self.logits, prediction],
                        feed_dict=feed_dict)

                    loss_tot.append(loss_per)
                    acc_tot.append(acc_per)
                    print("Image Index {}: TEST Loss {:6.3f}, TEST Acc {:6.3f}".format(
                        step, loss_tot[-1], acc_tot[-1]))
                    step += 1
                    per_class_acc(logit, label_batch, self.n_classes)
                    hist += get_hist(logit, label_batch)

                acc_tot = np.diag(hist).sum() / hist.sum()
                iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) -
                                      np.diag(hist))

                print("Total Accuracy for test image: ", acc_tot)
                print("Total MoI for test images: ", iu)
                print("mean MoI for test images: ", np.nanmean(iu))

                acc_final.append(acc_tot)
                iu_final.append(iu)
                iu_mean_final.append(np.nanmean(iu))

            return acc_final, iu_final, iu_mean_final
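
Both `test()` variants reduce the accumulated confusion matrix the same way; that reduction can be factored into a small helper (a sketch under the same row/column convention as the `get_hist` sketch above):

import numpy as np

def summarize_hist(hist):
    # hist rows index ground-truth classes, columns index predictions.
    overall_acc = np.diag(hist).sum() / hist.sum()
    # Per-class IoU: TP / (TP + FP + FN).
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    return overall_acc, iu, np.nanmean(iu)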