Code Example #1
    graph = tf.Graph()
    sess = tf.Session(
        graph=graph,
        config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                allocator_type='BFC',
                # allow_growth=True,  # it will cause fragmentation.
                per_process_gpu_memory_fraction=0.1)))
    graph_def = graph.as_graph_def()

    with sess.graph.as_default():

        sess.run(tf.global_variables_initializer())
        tf.keras.backend.set_session(session=sess)
        # with tf.gfile.GFile(COMPILE_MODEL_PATH.replace('.pb', '_{}.pb'.format(int(0.95 * 10000))), "rb") as f:
        #     graph_def_file = f.read()
        # graph_def.ParseFromString(graph_def_file)
        # print('{}.meta'.format(tf_checkpoint))

        model = NeuralNetwork(model_conf, RunMode.Predict, model_conf.neu_cnn,
                              model_conf.neu_recurrent)
        model.build_graph()

        saver = tf.train.Saver(var_list=tf.global_variables())
        """从项目中加载最后一次训练的网络参数"""
        saver.restore(sess,
                      tf.train.latest_checkpoint(model_conf.model_root_path))

        # _ = tf.import_graph_def(graph_def, name="")
    """定义操作符"""
    dense_decoded_op = sess.graph.get_tensor_by_name("dense_decoded:0")
    x_op = sess.graph.get_tensor_by_name('input:0')
    """固定网络"""
    sess.graph.finalize()

    true_count = 0
    false_count = 0
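
Note: Example #1 is a fragment of the same testing routine shown in full in Example #2; it stops right after fetching the two tensors. As a rough sketch (not part of the original snippet), those tensors would typically be used in TF 1.x as below, where image_batch is assumed to be a NumPy array shaped to match the model's input:0 placeholder:

    # Hypothetical usage of the tensors fetched above (sketch, not original code)
    decoded = sess.run(dense_decoded_op, feed_dict={x_op: image_batch})
    # `decoded` holds integer label indices per sample; mapping them back to
    # characters depends on the project's category configuration.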
Code Example #2
    def testing(self, image_dir, limit=None):

        graph = tf.Graph()
        sess = tf.Session(
            graph=graph,
            config=tf.ConfigProto(
                # allow_soft_placement=True,
                # log_device_placement=True,
                gpu_options=tf.GPUOptions(
                    allocator_type='BFC',
                    # allow_growth=True,  # it will cause fragmentation.
                    per_process_gpu_memory_fraction=0.1)))

        with sess.graph.as_default():

            sess.run(tf.global_variables_initializer())
            tf.keras.backend.set_session(session=sess)

            model = NeuralNetwork(self.model_conf, RunMode.Predict,
                                  self.model_conf.neu_cnn,
                                  self.model_conf.neu_recurrent)
            model.build_graph()

            saver = tf.train.Saver(var_list=tf.global_variables())
            """从项目中加载最后一次训练的网络参数"""
            saver.restore(
                sess,
                tf.train.latest_checkpoint(self.model_conf.model_root_path))

            # _ = tf.import_graph_def(graph_def, name="")
        """定义操作符"""
        dense_decoded_op = sess.graph.get_tensor_by_name("dense_decoded:0")
        x_op = sess.graph.get_tensor_by_name('input:0')
        """固定网络"""
        sess.graph.finalize()

        true_count = 0
        false_count = 0
        """
        以下为根据路径调用预测函数输出结果的demo
        """
        # Fill in your own sample path
        dir_list = os.listdir(image_dir)
        random.shuffle(dir_list)
        lines = []
        for i, p in enumerate(dir_list):
            n = os.path.join(image_dir, p)
            if limit and i > limit:
                break
            with open(n, "rb") as f:
                b = f.read()

            batch = self.get_image_batch(b)
            st = time.time()
            predict_text = self.predict_func(
                batch,
                sess,
                dense_decoded_op,
                x_op,
            )
            et = time.time()
            # t = p.split(".")[0].lower() == predict_text.lower()
            # csv_output = "{},{}".format(p.split(".")[0], predict_text)
            # lines.append(csv_output)
            # print(csv_output)
            is_mark = '_' in p
            if is_mark:
                # Files named "<label>_<id>.<ext>" carry their ground truth and
                # can be used to verify the test set.
                if 'LOWER' in self.model_conf.category_param:
                    label = p.split("_")[0].lower()
                    t = label == predict_text.lower()
                elif 'UPPER' in self.model_conf.category_param:
                    label = p.split("_")[0].upper()
                    t = label == predict_text.upper()
                else:
                    label = p.split("_")[0]
                    t = label == predict_text
                if t:
                    true_count += 1
                else:
                    false_count += 1
                print(i, p, label, predict_text, t,
                      true_count / (true_count + false_count),
                      (et - st) * 1000)
            else:
                # Unlabeled files: report only the prediction and timing; guard
                # against division by zero when no labeled file has been seen yet.
                accuracy = (true_count / (true_count + false_count)
                            if (true_count + false_count) else 0)
                print(i, p, predict_text, accuracy, (et - st) * 1000)
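
Both examples rely on self.get_image_batch and self.predict_func, which are not shown on this page. A minimal sketch of what predict_func could look like, assuming the graph's dense_decoded output pads rows with -1 and that the class holds a character list self.category (both are assumptions, not confirmed by these snippets):

    def predict_func(self, batch, sess, dense_decoded_op, x_op):
        # Run the decode op on one preprocessed batch (sketch, not original code).
        dense_decoded = sess.run(dense_decoded_op, feed_dict={x_op: batch})
        texts = []
        for row in dense_decoded:
            # Assumes -1 is the padding value used when densifying the CTC decode result.
            texts.append("".join(self.category[int(i)] for i in row if i != -1))
        # The calling loop compares one string per file, so return the first item.
        return texts[0] if texts else ""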