예제 #1
0
    def test_inference(self, num_iterations=1000):
        """
        Benchmark the inference time / frames per second of the best checkpoint.

        Like the testing function, but instead of evaluating predictions it
        repeatedly runs the forward pass on the *same* first test sample
        (index 0) and reports the statistics gathered by ``FPSMeter``. Only the
        ``sess.run`` call is timed.

        Args:
            num_iterations: number of timed forward passes (default 1000, the
                value that was previously hard-coded).
        """
        print("INFERENCE mode will begin NOW..")

        # load the best model checkpoint to test on it
        self.load_best_model()

        # The benchmark deliberately re-feeds sample 0 on every pass, so the
        # mini-batch and feed_dict never change and are built once up front.
        x_batch = self.test_data['X'][0:1]
        flo_batch = self.test_data['Flo'][0:1]
        y_batch = self.test_data['Y'][0:1]

        # With random cropping enabled the model is fed through its
        # ``*_pl_before`` placeholders (presumably the pre-crop inputs).
        if self.args.random_cropping:
            feed_dict = {
                self.test_model.x_pl_before: x_batch,
                self.test_model.flo_pl_before: flo_batch,
                self.test_model.y_pl_before: y_batch,
                self.test_model.is_training: False,
            }
        else:
            feed_dict = {
                self.test_model.x_pl: x_batch,
                self.test_model.flo_pl: flo_batch,
                self.test_model.y_pl: y_batch,
                self.test_model.is_training: False,
            }

        fps_meter = FPSMeter()

        for _ in tqdm(range(num_iterations)):
            # perf_counter is a monotonic, high-resolution clock intended for
            # measuring short durations (time.time() can jump with clock changes).
            start = time.perf_counter()
            self.sess.run([self.test_model.out_argmax], feed_dict=feed_dict)
            fps_meter.update(time.perf_counter() - start)

        fps_meter.print_statistics()
예제 #2
0
    def realsense_inference(self):
        """
        Measure the inference speed of ``self.model`` over the whole test set.

        Every sample of ``self.test_data['X']`` is fed one at a time (a slice of
        length 1) with ``is_training`` set to False. Only the ``sess.run`` call
        is timed, and the FPS statistics are printed at the end. No checkpoint
        is loaded here; the caller is presumably responsible for that.
        """
        print("INFERENCE will begin NOW..")

        fps_meter = FPSMeter()

        for idx in tqdm(range(self.test_data_len)):
            # single-sample mini-batch
            x_batch = self.test_data['X'][idx:idx + 1]

            feed_dict = {self.model.x_pl: x_batch,
                         self.model.is_training: False}

            # time only the forward pass with a monotonic high-resolution clock
            start = time.perf_counter()
            self.sess.run([self.model.out_argmax], feed_dict=feed_dict)
            fps_meter.update(time.perf_counter() - start)

        fps_meter.print_statistics()
예제 #3
0
    def test_inference(self, save_graph=True):
        """
        Optionally export the session graph, then measure inference FPS.

        Like the testing function, but it only times the forward pass, one test
        sample at a time, and prints the ``FPSMeter`` statistics at the end.

        Args:
            save_graph: if True (the default, preserving the previous behaviour),
                write ``self.sess.graph_def`` to ``./graph.pb`` before the
                benchmark. The benchmark now always runs afterwards; previously
                an ``exit(1)`` made it unreachable.
        """
        print("INFERENCE mode will begin NOW..")

        # load the best model checkpoint to test on it
        self.load_best_model()

        if save_graph:
            # Node names recorded by the author for consumers of the graph:
            #   input_node:  network/input/Placeholder
            #   output_node: network/output/Argmax
            print("Saving graph...")
            # NOTE(review): tf.train.write_graph defaults to as_text=True, so
            # despite the .pb extension this writes a text-format GraphDef.
            tf.train.write_graph(self.sess.graph_def, ".", 'graph.pb')
            print("Graph saved successfully.\n\n")

        fps_meter = FPSMeter()

        for idx in tqdm(range(self.test_data_len)):
            # single-sample mini-batches
            x_batch = self.test_data['X'][idx:idx + 1]
            y_batch = self.test_data['Y'][idx:idx + 1]

            # NOTE(review): is_training is not fed here (it was commented out
            # originally); confirm the graph does not require it.
            if self.args.random_cropping:
                feed_dict = {self.test_model.x_pl_before: x_batch,
                             self.test_model.y_pl_before: y_batch}
            else:
                feed_dict = {self.test_model.x_pl: x_batch,
                             self.test_model.y_pl: y_batch}

            # time only the forward pass with a monotonic high-resolution clock
            start = time.perf_counter()
            self.sess.run([self.test_model.out_argmax], feed_dict=feed_dict)
            fps_meter.update(time.perf_counter() - start)

        fps_meter.print_statistics()
예제 #4
0
    def test_eval(self, pkl=False):
        """
        Run the test set through the network, save predictions and report FPS.

        For every test sample (a slice of length 1) this:
          * times the forward pass that fetches ``out_argmax`` and
            ``segmented_summary``;
          * saves the colour visualisation to ``<out_dir>/imgs/<name>``;
          * saves ``postprocess(out_argmax[0])``, resized to [1024, 2048] with
            'nearest' interpolation, to ``<out_dir>/results/<name>`` for the
            official evaluation.
        ``<name>`` is ``self.names_mapper['Y'][idx]`` decoded from bytes. FPS
        statistics are printed at the end.

        Args:
            pkl: if True, do not load the best checkpoint (weights are presumably
                provided some other way). Also refine the prediction with
                ``self.linknet_postprocess`` and re-colourise it with
                ``decode_labels`` instead of using the network's summary image.
        """
        print("Testing mode will begin NOW..")

        # load the best model checkpoint to test on it
        if not pkl:
            self.load_best_model()

        tt = tqdm(range(self.test_data_len))
        fps_meter = FPSMeter()

        for idx in tt:
            # single-sample mini-batch
            x_batch = self.test_data['X'][idx:idx + 1]

            if self.args.random_cropping:
                feed_dict = {self.test_model.x_pl_before: x_batch,
                             self.test_model.is_training: False}
            else:
                feed_dict = {self.test_model.x_pl: x_batch,
                             self.test_model.is_training: False}

            # time only the forward pass with a monotonic high-resolution clock
            start = time.perf_counter()
            out_argmax, segmented_imgs = self.sess.run(
                [self.test_model.out_argmax,
                 self.test_model.segmented_summary],
                feed_dict=feed_dict)
            fps_meter.update(time.perf_counter() - start)

            if pkl:
                out_argmax[0] = self.linknet_postprocess(out_argmax[0])
                segmented_imgs = decode_labels(out_argmax, 20)

            name = self.names_mapper['Y'][idx].decode()

            # os.path.join is correct whether or not out_dir ends with a
            # separator; plain concatenation produced e.g. "outimgs/..." without one.
            # makedirs(exist_ok=True) avoids the exists()/makedirs() race.

            # Colored results for visualization
            colored_save_path = os.path.join(self.args.out_dir, 'imgs', name)
            os.makedirs(os.path.dirname(colored_save_path), exist_ok=True)
            plt.imsave(colored_save_path, segmented_imgs[0])

            # Results for official evaluation
            save_path = os.path.join(self.args.out_dir, 'results', name)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output = postprocess(out_argmax[0])
            # NOTE(review): assuming `misc` is scipy.misc, imsave/imresize were
            # removed in SciPy 1.2/1.3; this line needs an older SciPy or a port.
            misc.imsave(save_path, misc.imresize(output, [1024, 2048], 'nearest'))

        # close the progress bar
        tt.close()

        fps_meter.print_statistics()