コード例 #1
0
 def __init__(self, config):
     self.config = config
     self.model = TextCNN(self.config)
     self.session = tf.Session()
     self.session.run(tf.global_variables_initializer())
     saver = tf.train.Saver()
     saver.restore(sess=self.session,
                   save_path=os.path.join(config.model_save_dir,
                                          config.model_file_prefix))
コード例 #2
0
ファイル: main_cnn.py プロジェクト: shxliang/text-cnn-rnn
def main_test():
    """
    模型测试
    :return:
    """
    id_to_word, word_to_id = read_vocab(FLAGS.vocab_dir)
    id_to_cat, cat_to_id = read_category(FLAGS.category_dir)
    config = load_config(FLAGS.config_file)
    model = TextCNN(config)

    test(model, config, word_to_id, cat_to_id, id_to_cat)
コード例 #3
0
ファイル: predict.py プロジェクト: shxliang/text-cnn-rnn
    def __init__(self):
        self.config = TCNNConfig()
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.config.vocab_size = len(self.words)
        self.model = TextCNN(self.config)

        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=self.session, save_path=save_path)  # 读取保存的模型
コード例 #4
0
ファイル: main_cnn.py プロジェクト: shxliang/text-cnn-rnn
def save_cnn_for_java():
    """
    保存模型用于Java加载调用
    :return:
    """
    config = load_config(FLAGS.config_file)
    model = TextCNN(config)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=FLAGS.save_dir)  # 读取保存的模型

        builder = tf.saved_model.builder.SavedModelBuilder("tmp/cnn_model")
        builder.add_meta_graph_and_variables(
            session, [tf.saved_model.tag_constants.SERVING])
        builder.save()
コード例 #5
0
ファイル: main_cnn.py プロジェクト: shunsunsun/text-cnn-rnn
def main_train():
    """
    模型训练
    :return:
    """
    # 如果不存在词汇表则新建
    if not os.path.exists(FLAGS.vocab_dir):
        build_vocab(FLAGS.train_dir, FLAGS.vocab_dir, FLAGS.vocab_size)
    id_to_word, word_to_id = read_vocab(FLAGS.vocab_dir)

    if not os.path.exists(FLAGS.category_dir):
        build_category(FLAGS.train_dir, FLAGS.category_dir)
    id_to_cat, cat_to_id = read_category(FLAGS.category_dir)

    if os.path.isfile(FLAGS.config_file):
        config = load_config(FLAGS.config_file)
    else:
        config = create_cnn_config_model(FLAGS, id_to_word, id_to_cat)

    model = TextCNN(config)

    train(model, config, word_to_id, cat_to_id)
コード例 #6
0
    parser.add_argument("--base_dir",
                        type=str,
                        default="./dataset",
                        help="The Base Directory of Dataset")
    parser.add_argument("--data_dir",
                        type=str,
                        default="asks",
                        help="The Base Directory of Asks of 169 kang")
    parser.add_argument("--labels_json_file",
                        type=str,
                        default="labels.json",
                        help="json file of labels")
    parser.add_argument("--print_interval", type=int, default=100)
    parser.add_argument("--train_device", type=str, default="cpu")

    # global variables here
    FLAGS, unknown = parser.parse_known_args()
    TRAIN_FILE = os.path.join(FLAGS.base_dir, FLAGS.data_dir, 'train.txt')
    TEST_FILE = os.path.join(FLAGS.base_dir, FLAGS.data_dir, 'test.txt')
    VALID_FILE = os.path.join(FLAGS.base_dir, FLAGS.data_dir, 'val.txt')
    VOCAB_FILE = os.path.join(FLAGS.base_dir, FLAGS.data_dir, 'vocab.txt')
    LABELS = json.loads(
        open(os.path.join(FLAGS.base_dir, FLAGS.data_dir,
                          FLAGS.labels_json_file),
             encoding='utf-8').read())
    TEXT_CNN_CONFIG = TextCnnConfig()
    TEXT_CNN_CONFIG.train_device = "/{}:0".format(FLAGS.train_device)
    MODEL = TextCNN(TEXT_CNN_CONFIG)

    tf.app.run(main=main, argv=[sys.argv[0]] + unknown)