def main(_):
    # 1.load data with vocabulary of words and labels
    vocabulary_word2index, vocabulary_index2word = create_voabulary(simple='simple',
                                                                    word2vec_model_path=FLAGS.word2vec_model_path,
                                                                    name_scope="cnn2")
    vocab_size = len(vocabulary_word2index)
    vocabulary_word2index_label, vocabulary_index2word_label = create_voabulary_label(name_scope="cnn2")
    questionid_question_lists = load_final_test_data(FLAGS.predict_source_file)
    test = load_data_predict(vocabulary_word2index, vocabulary_word2index_label, questionid_question_lists)
    testX = []
    question_id_list = []
    for question_id, question_string_list in test:
        question_id_list.append(question_id)
        testX.append(question_string_list)
    # 2.Data preprocessing: Sequence padding
    print("start padding....")
    testX2 = pad_sequences(testX, maxlen=FLAGS.sentence_len, value=0.)  # padding to max length
    print("end padding...")
    # 3.create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # 4.Instantiate Model
        textCNN = TextCNN(filter_sizes, FLAGS.num_filters, FLAGS.num_classes, FLAGS.learning_rate, FLAGS.batch_size,
                          FLAGS.decay_steps, FLAGS.decay_rate,
                          FLAGS.sentence_len, vocab_size, FLAGS.embed_size, FLAGS.is_training)
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print("Can't find the checkpoint.going to stop")
            return
        # 5.feed data, to get logits
        number_of_training_data = len(testX2)
        print("number_of_training_data:", number_of_training_data)
        index = 0
        predict_target_file_f = codecs.open(FLAGS.predict_target_file, 'a', 'utf8')
        # note: zipping these two ranges drops any final partial batch, and the
        # bookkeeping below (logits[0], one question id per batch) effectively
        # assumes batch_size == 1 at prediction time.
        for start, end in zip(range(0, number_of_training_data, FLAGS.batch_size),
                              range(FLAGS.batch_size, number_of_training_data + 1, FLAGS.batch_size)):
            logits = sess.run(textCNN.logits, feed_dict={textCNN.input_x: testX2[start:end],
                                                         textCNN.dropout_keep_prob: 1})  # shape of logits: (1, 1999)
            # 6. get labels using logits
            predicted_labels = get_label_using_logits(logits[0], vocabulary_index2word_label)
            # 7. write question id and labels to the file system.
            write_question_id_with_labels(question_id_list[index], predicted_labels, predict_target_file_f)
            index = index + 1
        predict_target_file_f.close()
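# The helpers called above, get_label_using_logits and write_question_id_with_labels,
# are not defined in this excerpt (nor are the imports it relies on: os, codecs,
# tensorflow, pad_sequences, TextCNN and the create_voabulary*/load_* data utilities).
# Below is a minimal sketch of what the two helpers are assumed to do, based only on
# how they are called: take the top-k logits, map the indices back to label strings,
# and write one "question_id,label1,label2,..." line per question. The top_number
# parameter and the exact output format are assumptions, not the original code.
import numpy as np

def get_label_using_logits(logits, vocabulary_index2word_label, top_number=5):
    # indices of the top_number largest logits, highest score first
    top_indices = np.argsort(logits)[-top_number:][::-1]
    return [vocabulary_index2word_label[index] for index in top_indices]

def write_question_id_with_labels(question_id, labels_list, f):
    # one comma-separated line per question: the id followed by its predicted labels
    f.write(str(question_id) + "," + ",".join(labels_list) + "\n")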
Example 3
    "target file path for final prediction")
tf.app.flags.DEFINE_string("predict_source_file",
                           'test-zhihu-forpredict-title-desc-v6.txt',
                           "target file path for final prediction"
                           )  #test-zhihu-forpredict-v4only-title.txt
tf.app.flags.DEFINE_string(
    "word2vec_model_path", "data/zhihu-word2vec-title-desc.bin-100",
    "word2vec's vocabulary and vectors")  #zhihu-word2vec.bin-100
tf.app.flags.DEFINE_integer("num_filters", 256, "number of filters")  #128

##############################################################################################################################################
filter_sizes = [1, 2, 3, 4, 5, 6, 7]  #[1,2,3,4,5,6,7]
#1.load data(X:list of lint,y:int). 2.create session. 3.feed data. 4.training (5.validation) ,(6.prediction)
# 1.load data with vocabulary of words and labels
vocabulary_word2index, vocabulary_index2word = create_voabulary(
    simple='simple',
    word2vec_model_path=FLAGS.word2vec_model_path,
    name_scope="cnn2")
vocab_size = len(vocabulary_word2index)
vocabulary_word2index_label, vocabulary_index2word_label = create_voabulary_label(
    name_scope="cnn2")
questionid_question_lists = load_final_test_data(FLAGS.predict_source_file)
test = load_data_predict(vocabulary_word2index, vocabulary_word2index_label,
                         questionid_question_lists)
testX = []
question_id_list = []
for tuple in test:
    question_id, question_string_list = tuple
    question_id_list.append(question_id)
    testX.append(question_string_list)
# 2.Data preprocessing: Sequence padding
print("start padding....")
tf.app.flags.DEFINE_integer("sentence_len",100,"max sentence length")
tf.app.flags.DEFINE_integer("embed_size",100,"embedding size")
tf.app.flags.DEFINE_boolean("is_training",False,"is traning.true:tranining,false:testing/inference")
tf.app.flags.DEFINE_integer("num_epochs",15,"number of epochs.")
tf.app.flags.DEFINE_integer("validate_every", 1, "Validate every validate_every epochs.") #每10轮做一次验证
tf.app.flags.DEFINE_string("predict_target_file","text_cnn_title_desc_checkpoint/zhihu_result_cnn_multilabel_v6_e14.csv","target file path for final prediction")
tf.app.flags.DEFINE_string("predict_source_file",'test-zhihu-forpredict-title-desc-v6.txt',"target file path for final prediction") #test-zhihu-forpredict-v4only-title.txt
tf.app.flags.DEFINE_string("word2vec_model_path","zhihu-word2vec-title-desc.bin-100","word2vec's vocabulary and vectors") #zhihu-word2vec.bin-100
tf.app.flags.DEFINE_integer("num_filters", 256, "number of filters") #128

##############################################################################################################################################
filter_sizes = [1, 2, 3, 4, 5, 6, 7]
# 1. load data (X: list of ints, y: int). 2. create session. 3. feed data. 4. training (5. validation), (6. prediction)
# 1.load data with vocabulary of words and labels
vocabulary_word2index, vocabulary_index2word = create_voabulary(simple='simple',
                                                                word2vec_model_path=FLAGS.word2vec_model_path,
                                                                name_scope="cnn2")
vocab_size = len(vocabulary_word2index)
vocabulary_word2index_label, vocabulary_index2word_label = create_voabulary_label(name_scope="cnn2")
questionid_question_lists = load_final_test_data(FLAGS.predict_source_file)
test = load_data_predict(vocabulary_word2index, vocabulary_word2index_label, questionid_question_lists)
testX = []
question_id_list = []
for question_id, question_string_list in test:
    question_id_list.append(question_id)
    testX.append(question_string_list)
# 2.Data preprocessing: Sequence padding
print("start padding....")
testX2 = pad_sequences(testX, maxlen=FLAGS.sentence_len, value=0.)  # padding to max length
print("end padding...")