コード例 #1
0
    def visual_results_external_image(self, images, FLAG_MAX_VOTE=False):
        """Run the restored model on external images and collect predictions.

        Each image is resized to the configured network input size
        (``self.config["INPUT_HEIGHT"]`` x ``self.config["INPUT_WIDTH"]``)
        and fed through the graph one image at a time. When
        ``self.config["BAYES"]`` is truthy, the softmax output is sampled
        ``num_sample_generate`` times with ``with_dropout_pl`` set to True.
        The samples are combined either by ``MAX_VOTE`` (if ``FLAG_MAX_VOTE``)
        or by averaging the probabilities, and ``var_calculate`` turns the
        result into a per-image uncertainty value.

        Args:
            images: sequence of images accepted by ``misc.imresize``.
            FLAG_MAX_VOTE: in Bayesian mode, combine the samples with
                ``MAX_VOTE`` instead of the mean softmax probability.

        Returns:
            ``(pred_tot, var_tot, inference_time)``:
            - ``pred_tot``: per-image predicted label maps reshaped to
              (INPUT_HEIGHT, INPUT_WIDTH).
            - ``var_tot``: per-image ``var_calculate`` results; an empty list
              per image in non-Bayesian mode.
            - ``inference_time``: per-image wall-clock times in seconds.
        """
        import logging

        image_w = self.config["INPUT_WIDTH"]
        image_h = self.config["INPUT_HEIGHT"]
        image_c = self.config["INPUT_CHANNELS"]
        train_dir = self.config["SAVE_MODEL_DIR"]
        FLAG_BAYES = self.config["BAYES"]

        # Resize to the configured network input size so that the reshape to
        # [1, image_h, image_w, image_c] below always matches the data.
        images = [misc.imresize(image, (image_h, image_w)) for image in images]

        # External images have no ground truth. labels_pl is fed a constant
        # all-ones map, built once and reused for every image.
        dummy_label = np.ones((1, image_h, image_w, 1), dtype=int)

        # NOTE(review): entering a tf.Session as a context manager closes it
        # on exit, so self.sess is unusable after this method returns —
        # confirm this is only meant to be called once per instance.
        with self.sess as sess:
            # Restore saved session
            saver = tf.train.Saver()
            saver.restore(sess, train_dir)

            _, _, prediction = cal_loss(logits=self.logits,
                                        labels=self.labels_pl)
            prob = tf.nn.softmax(self.logits, dim=-1)

            num_sample_generate = 30
            pred_tot = []
            var_tot = []
            inference_time = []
            start_time = time.time()

            for image in images:
                image_batch = np.reshape(image, [1, image_h, image_w, image_c])
                # NOTE(review): keep_prob 0.5 is fed even at inference;
                # presumably dropout is gated by is_training_pl /
                # with_dropout_pl — confirm.
                feed_dict = {self.inputs_pl: image_batch,
                             self.labels_pl: dummy_label,
                             self.is_training_pl: False,
                             self.keep_prob_pl: 0.5,
                             self.batch_size_pl: 1}

                if not FLAG_BAYES:
                    pred = sess.run(fetches=[prediction], feed_dict=feed_dict)
                    pred = np.reshape(pred, [image_h, image_w])
                    var_one = []
                else:
                    # Draw several stochastic softmax samples with dropout on.
                    feed_dict[self.with_dropout_pl] = True
                    prob_iter_tot = []
                    pred_iter_tot = []
                    for _ in range(num_sample_generate):
                        prob_iter_step = sess.run(fetches=[prob],
                                                  feed_dict=feed_dict)
                        prob_iter_tot.append(prob_iter_step)
                        pred_iter_tot.append(
                            np.reshape(np.argmax(prob_iter_step, axis=-1), [-1]))

                    if FLAG_MAX_VOTE:
                        prob_variance, pred = MAX_VOTE(pred_iter_tot,
                                                       prob_iter_tot,
                                                       self.config["NUM_CLASSES"])
                    else:
                        prob_mean = np.nanmean(prob_iter_tot, axis=0)
                        prob_variance = np.var(prob_iter_tot, axis=0)
                        # Predicted label = argmax of the mean sampled
                        # probability (author's note: no tau term included).
                        pred = np.reshape(np.argmax(prob_mean, axis=-1), [-1])

                    var_one = var_calculate(pred, prob_variance)
                    pred = np.reshape(pred, [image_h, image_w])

                pred_tot.append(pred)
                var_tot.append(var_one)
                inference_time.append(time.time() - start_time)
                start_time = time.time()

            try:
                draw_plots_bayes_external(images, pred_tot, var_tot)
            except Exception:
                # Plotting is best-effort: report the failure instead of
                # silently discarding it, and still return the results.
                logging.getLogger(__name__).warning(
                    "draw_plots_bayes_external failed", exc_info=True)
            return pred_tot, var_tot, inference_time
コード例 #2
0
ファイル: SegNet.py プロジェクト: skugele/SegNet-tensorflow
    def visual_results_external_image(self, images, model_file):
        """Run the restored model on external images and return predictions.

        Images are resized to (self.input_h, self.input_w) and fed through
        the graph one at a time. Only the ``prediction`` op from ``cal_loss``
        is evaluated; no sampling is performed.

        Args:
            images: sequence of images accepted by ``misc.imresize``.
            model_file: checkpoint name relative to ``FLAGS.runtime_dir``.
                If None, the latest checkpoint in that directory is restored.

        Returns:
            ``(pred_tot, var_tot, inference_time)``:
            - ``pred_tot``: per-image predicted label maps of shape
              (input_h, input_w).
            - ``var_tot``: always an empty list (nothing appends to it).
            - ``inference_time``: per-image wall-clock times in seconds.
        """
        import logging

        images = [
            misc.imresize(image, (self.input_h, self.input_w))
            for image in images
        ]

        # External images have no ground truth. labels_pl is fed a constant
        # all-ones map, built once and reused for every image.
        dummy_label = np.ones((1, self.input_h, self.input_w, 1), dtype=int)

        with tf.Session() as sess:
            # Restore saved session
            saver = tf.train.Saver()
            if model_file is None:
                checkpoint = tf.train.latest_checkpoint(FLAGS.runtime_dir)
            else:
                checkpoint = os.path.join(FLAGS.runtime_dir, model_file)
            saver.restore(sess, checkpoint)

            _, _, prediction = cal_loss(logits=self.logits,
                                        labels=self.labels_pl,
                                        n_classes=self.n_classes)

            pred_tot = []
            var_tot = []
            inference_time = []
            start_time = time.time()

            for image in images:
                image_batch = np.reshape(
                    image, [1, self.input_h, self.input_w, self.input_c])
                feed_dict = {
                    self.inputs_pl: image_batch,
                    self.labels_pl: dummy_label,
                    self.is_training_pl: False,
                    self.keep_prob_pl: 0.5,
                    self.batch_size_pl: 1,
                }
                pred = sess.run(fetches=[prediction], feed_dict=feed_dict)
                pred_tot.append(np.reshape(pred, [self.input_h, self.input_w]))
                inference_time.append(time.time() - start_time)
                start_time = time.time()

            try:
                # NOTE(review): called without var_tot — confirm this matches
                # draw_plots_bayes_external's signature in this project.
                draw_plots_bayes_external(images, pred_tot)
            except Exception:
                # Plotting is best-effort: report the failure (e.g. a
                # signature mismatch) instead of silently discarding it.
                logging.getLogger(__name__).warning(
                    "draw_plots_bayes_external failed", exc_info=True)
            return pred_tot, var_tot, inference_time