Example #1
    def __init__(self):
        self.config = TCNNConfig()
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.config.vocab_size = len(self.words)
        self.model = TextCNN(self.config)

        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=self.session, save_path=save_path)  # restore the saved model
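A minimal sketch of how the restored session might then be used to classify a single message; the placeholder names (input_x, keep_prob), the y_pred_cls tensor, and config.seq_length are assumptions based on common TextCNN layouts, not confirmed by this snippet.

    # requires: import tensorflow.contrib.keras as kr  (TF 1.x)
    def predict(self, message):
        # map tokens to ids, dropping anything outside the vocabulary
        data = [self.word_to_id[w] for w in message if w in self.word_to_id]
        feed_dict = {
            self.model.input_x: kr.preprocessing.sequence.pad_sequences(
                [data], self.config.seq_length),   # pad/truncate to a fixed length
            self.model.keep_prob: 1.0,             # disable dropout at inference time
        }
        cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
        return self.categories[cls[0]]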
Example #2
    def __init__(self):
        self.config = BILSTMConfig()
        self.categories, self.cat_to_id = read_category()
        self.words = np.load('./datas/dict.npy')
        self.word_to_id = np.load('./datas/dict.npy').tolist()
        self.config.vocab_size = len(self.word_to_id.keys())
        self.model = BILSTMModel(self.config)

        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=self.session, save_path=save_path)  # restore the saved model
Example #3
File: predict.py Project: reganzm/ai
    def __init__(self):
        self.config = TCNNConfig()
        self.categories, self.cat_to_id = read_category()
        self.words = np.load('./datas/all_phrase.npy')
        self.word_to_id = np.load('./datas/phrase_to_id.npy').tolist()
        self.config.vocab_size = len(self.words)
        self.model = TextCNN(self.config)

        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=self.session, save_path=save_path)  # restore the saved model
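The ./datas/*.npy files loaded in Examples #2 and #3 presumably hold a token array and a token-to-id dict written with np.save; below is a minimal sketch of how such files could be produced and read back under that assumption (the helper names are hypothetical). Note that NumPy 1.16.3+ requires allow_pickle=True to load a pickled dict.

import numpy as np

def save_vocab(words, words_path, mapping_path):
    # words: list of tokens; the dict is pickled into a 0-d object array by np.save
    word_to_id = {w: i for i, w in enumerate(words)}
    np.save(words_path, np.array(words))
    np.save(mapping_path, word_to_id)

def load_vocab(words_path, mapping_path):
    words = np.load(words_path, allow_pickle=True)
    word_to_id = np.load(mapping_path, allow_pickle=True).tolist()  # back to a plain dict
    return words, word_to_id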
Example #4
File: run_cnn_v2.py Project: reganzm/ai
    logger.info(
        metrics.classification_report(y_test_cls,
                                      y_pred_cls,
                                      target_names=categories))

    # confusion matrix
    logger.info("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    logger.info(cm)

    time_dif = get_time_dif(start_time)
    logger.info("Time usage:", time_dif)


if __name__ == '__main__':

    config = TCNNConfig()
    #if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
    #    build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    #words, word_to_id = read_vocab(vocab_dir)
    words = np.load('./datas/dict_token.npy')
    word_to_id = np.load('./datas/token_to_id.npy').tolist()
    config.vocab_size = len(words)
    model = TextCNN(config)
    option = 'train'
    if option == 'train':
        train()
    else:
        test()
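For context, a hedged sketch of the kind of evaluation loop that could produce the y_test_cls / y_pred_cls arrays consumed by the classification report and confusion matrix above; the tensor names (input_x, keep_prob, y_pred_cls) and the batched feeding are assumptions, not taken from run_cnn_v2.py.

import numpy as np

def predict_classes(session, model, x_test, y_test, batch_size=128):
    # x_test: padded id sequences, y_test: one-hot labels
    y_test_cls = np.argmax(y_test, 1)
    y_pred_cls = np.zeros(len(x_test), dtype=np.int32)
    for start in range(0, len(x_test), batch_size):
        end = min(start + batch_size, len(x_test))
        feed_dict = {
            model.input_x: x_test[start:end],
            model.keep_prob: 1.0,   # no dropout during evaluation
        }
        y_pred_cls[start:end] = session.run(model.y_pred_cls, feed_dict=feed_dict)
    return y_test_cls, y_pred_cls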