Exemplo n.º 1
0
def write_socres(model_path, checkpoint_file, meta_file):
    """Restore a saved model graph and print the scores for one training batch.

    Despite its name, nothing is written to disk. The function pulls the first
    batch from the training-set iterator, feeds it to the restored
    ``output/score`` tensor and prints the result. If the training iterator is
    already exhausted, it initialises the validation iterator instead and
    returns.

    :param model_path: directory that holds the checkpoint and meta-graph files.
    :param checkpoint_file: checkpoint prefix, relative to ``model_path``.
    :param meta_file: meta-graph file name, relative to ``model_path``.
    :raises ValueError: if ``cnn.train_mode`` is not one of the id-based modes
        handled here, or the batch contains no titles (nothing to stack).
    """
    sess_config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        cnn = TextCNN(CNNConfig())
        cnn.prepare_data()

        # Restore the model from the checkpoint.
        checkpoint_dir = os.path.abspath(model_path)
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
        saver = tf.train.import_meta_graph(
            os.path.join(checkpoint_dir, meta_file))
        saver.restore(sess, checkpoint_path)
        graph = tf.get_default_graph()

        # Look up the input/output tensors in the restored graph by name.
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name(
            "dropout_keep_prob").outputs[0]
        score = graph.get_operation_by_name("output/score").outputs[0]
        training = graph.get_operation_by_name("training").outputs[0]

        train_init_op, valid_init_op, next_train_element, _ = \
            cnn.shuffle_datset()
        sess.run(train_init_op)

        try:
            lines = sess.run(next_train_element)
        except tf.errors.OutOfRangeError:
            # Training iterator exhausted: initialise the validation iterator.
            sess.run(valid_init_op)
            return

        # NOTE(review): the first value returned by convert_input is treated as
        # raw titles and re-encoded with to_id below; confirm convert_input
        # does not already return id arrays.
        titles, _batch_y = cnn.convert_input(lines)
        if cnn.train_mode not in ('CHAR-RANDOM', 'WORD-NON-STATIC'):
            raise ValueError(
                'unsupported train_mode for id input: %r' % (cnn.train_mode,))
        id_rows = [preprocess.to_id(title, cnn.vocab, cnn.train_mode)
                   for title in titles]
        if not id_rows:
            raise ValueError('empty batch: nothing to score')
        batch_x = np.stack(id_rows)

        feed_dict = {
            input_x: batch_x,
            dropout_keep_prob: 1.0,  # disable dropout for inference
            training: False,
        }
        print(sess.run(score, feed_dict))
Exemplo n.º 2
0
 def predict(self, titles):
     """Predict a label for each title in a custom batch.

     :param titles: iterable of raw title strings (any iterable, including a
         generator, is accepted).
     :return: list of ``(title, predicted_label)`` pairs in input order. The
         list is empty when ``titles`` is empty.
     """
     # Materialise once: titles is iterated twice (encoding and zip), which
     # would silently produce [] for a one-shot iterator.
     titles = list(titles)
     if not titles:
         # np.stack([]) would raise "need at least one array to stack".
         return []
     # 1. Map each title to its id representation and stack into one batch.
     batch_x = np.stack(
         [preprocess.to_id(title, self.vocab, self.pred_mode) for title in titles])
     pre = self.predictStep(batch_x)
     # 2. Translate predicted class indices into label names.
     results = [self.label[x] for x in pre]
     return list(zip(titles, results))
Exemplo n.º 3
0
    def convert_test_input(self, titles):
        """
        Convert raw test-set TSV titles into an id (or word-vector) batch.

        :param titles: iterable of GB18030-encoded byte strings, one per title.
        :return: ``np.stack`` of ``preprocess.to_id`` applied to each decoded
            title, in input order.
        """
        def _encode(raw):
            # Decode the bytes, then trim only leading/trailing tab characters.
            text = raw.decode('gb18030').strip('\t')
            return preprocess.to_id(text, self.vocab, self.train_mode)

        return np.stack([_encode(raw) for raw in titles])
Exemplo n.º 4
0
    def convert_input(self, lines):
        """
        Convert training-set rows into an id (or word-vector) batch plus labels.

        Each row is GBK-encoded bytes. The last comma-separated field is the
        label, and the fields before it form the title.

        NOTE(review): the title fields are re-joined with '' rather than ',',
        so commas inside a title are dropped. Confirm this is intended.

        :param lines: iterable of GBK-encoded byte strings.
        :return: tuple ``(np.stack(id_rows), labels)``, where labels is a list
            of str in input order.
        """
        id_rows, labels = [], []
        for raw in lines:
            *title_fields, label = raw.decode("gbk").strip().split(',')
            id_rows.append(
                preprocess.to_id(''.join(title_fields), self.vocab, self.train_mode))
            labels.append(label)
        return np.stack(id_rows), labels
Exemplo n.º 5
0
def get_score(model_path, checkpoint_file, meta_file, titles):
    '''
    Return the scores produced by the restored model's ``output/score``
    tensor for the given titles.

    Per the original author, this is the softmax-layer output with shape
    ``[None, 1258]`` (1258 is presumably the number of labels; verify).

    :param model_path: directory that holds the checkpoint and meta-graph files.
    :param checkpoint_file: checkpoint prefix, relative to ``model_path``.
    :param meta_file: meta-graph file name, relative to ``model_path``.
    :param titles: iterable of raw title strings to score.
    :return: the evaluated ``output/score`` tensor.
    :raises ValueError: if ``model.train_mode`` is not one of the id-based
        modes handled here, or ``titles`` is empty (nothing to stack).
    '''
    sess_config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        checkpoint_dir = os.path.abspath(model_path)
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
        saver = tf.train.import_meta_graph(
            os.path.join(checkpoint_dir, meta_file))
        saver.restore(sess, checkpoint_path)
        graph = tf.get_default_graph()

        model = cnn(CNNConfig())
        # Per the original comment, this loads the test set and the vocabulary.
        # Only model.vocab is needed below; the returned dataset is unused.
        _dataset, _next_element = model.prepare_test_data()

        # Look up the input/output tensors in the restored graph by name.
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name(
            "dropout_keep_prob").outputs[0]
        score = graph.get_operation_by_name("output/score").outputs[0]
        training = graph.get_operation_by_name("training").outputs[0]

        if model.train_mode not in ('CHAR-RANDOM', 'WORD-NON-STATIC'):
            raise ValueError(
                'unsupported train_mode for id input: %r' % (model.train_mode,))
        # 1. Map each title to its id representation.
        id_rows = [preprocess.to_id(title, model.vocab, model.train_mode)
                   for title in titles]
        if not id_rows:
            raise ValueError('titles is empty: nothing to score')
        batch_x = np.stack(id_rows)

        # Dropout is disabled (keep_prob 1.0) and training mode off for inference.
        feed_dict = {input_x: batch_x, dropout_keep_prob: 1.0, training: False}
        return sess.run(score, feed_dict)
Exemplo n.º 6
0
    def predict_train(self, src_path='data/train.tsv',
                      dst_path='data/train_predict.tsv'):
        """Annotate every training-set row with the model's predicted label.

        Reads the GBK-encoded TSV at ``src_path``: a header row, then
        ``title<TAB>label`` rows. Writes ``title<TAB>label<TAB>prediction``
        rows to ``dst_path``, and copies the (stripped) header row first.
        Undecodable input bytes are ignored. Blank or whitespace-only rows
        are skipped.

        :param src_path: input TSV path (defaults to the original hard-coded path).
        :param dst_path: output TSV path (defaults to the original hard-coded path).
        :raises ValueError: if a non-blank row does not split into exactly two
            tab-separated fields.
        """
        # Context managers guarantee both files are closed even if prediction
        # raises part-way through.
        with open(src_path, mode='r', encoding='gbk', errors='ignore') as rf, \
                open(dst_path, mode='w', encoding='gbk') as wf:
            # Copy the header row through.
            wf.write(rf.readline().strip() + '\n')
            for line in rf:
                row = line.strip()
                if not row:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # failing the two-field unpacking below.
                    continue
                title, label = row.split('\t')
                batch_x = np.asarray(
                    [preprocess.to_id(title, self.vocab, self.pred_mode)],
                    dtype=np.int32)
                pre = self.predictStep(batch_x)
                result = self.label[pre[0]]
                wf.write(title + '\t' + label + '\t' + result + '\n')