# Standard-library and third-party imports; the project-specific helpers used below
# (create_vocabulary, create_vocabulary_label, load_data_new, assign_pretrained_word_embedding,
# do_eval, TextCNN, FLAGS, filter_sizes) are assumed to be defined elsewhere in this repository.
import os

import numpy as np
import tensorflow as tf


def main(_):
    # 1. Load data.
    if True:
        trainX, trainY, testX, testY = None, None, None, None
        vocabulary_word2index, vocabulary_index2word = create_vocabulary(word2vec_model_path=FLAGS.word2vec_model_path,name_scope="cnn2") 
        vocab_size = len(vocabulary_word2index)
        print("cnn_model.vocab_size:",vocab_size)
        vocabulary_word2index_label,vocabulary_index2word_label = create_vocabulary_label(name_scope="cnn2")
        train, test, _ = load_data_new(vocabulary_word2index, vocabulary_word2index_label,traning_data_path=FLAGS.traning_data_path) 
        trainX, trainY = train
        testX, testY = test
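        # trainX/testX hold word-index sequences (hence the 0-padding below);
        # trainY/testY hold the corresponding label ids.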
        # 2. Data preprocessing: sequence padding (post-padding to sentence_len).
        print("start padding...")
        # Pad short sequences with 0 and truncate long ones so every row has length sentence_len.
        trainX = np.array([row[:FLAGS.sentence_len] + [0] * max(0, FLAGS.sentence_len - len(row)) for row in trainX])
        testX = np.array([row[:FLAGS.sentence_len] + [0] * max(0, FLAGS.sentence_len - len(row)) for row in testX])
        print("trainX[0]:", trainX[0])
        # Labels are kept as integer class ids; no one-hot conversion is done here.
        print("end padding...")
    # 3. Create session.
    config=tf.ConfigProto()
    config.gpu_options.allow_growth=True
    with tf.Session(config=config) as sess:
        #Instantiate Model
        textCNN=TextCNN(filter_sizes,FLAGS.num_filters,FLAGS.num_classes, FLAGS.learning_rate, FLAGS.batch_size, FLAGS.decay_steps,
                        FLAGS.decay_rate,FLAGS.sentence_len,vocab_size,FLAGS.embed_size,FLAGS.is_training)
        # Initialize Saver
        saver=tf.train.Saver()
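        # Restore the latest checkpoint if one exists; otherwise initialize variables from scratch.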
        if os.path.exists(FLAGS.ckpt_dir+"checkpoint"):
            print("Restoring Variables from Checkpoint")
            saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_embedding: #load pre-trained word embedding
                assign_pretrained_word_embedding(sess, vocabulary_index2word, vocab_size, textCNN,word2vec_model_path=FLAGS.word2vec_model_path)
        curr_epoch=sess.run(textCNN.epoch_step)
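        # epoch_step comes from the (possibly restored) graph, so training resumes at the saved epoch.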
        #3.feed data & training
        number_of_training_data=len(trainX)
        batch_size=FLAGS.batch_size
        for epoch in range(curr_epoch,FLAGS.num_epochs):
            loss, acc, counter = 0.0, 0.0, 0
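            # Zip two ranges offset by batch_size to get (start, end) slice bounds;
            # a trailing batch smaller than batch_size is silently dropped.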
            for start, end in zip(range(0, number_of_training_data, batch_size),range(batch_size, number_of_training_data, batch_size)):
                if epoch==0 and counter==0:
                    print("trainX[start:end]:",trainX[start:end])
                feed_dict = {textCNN.input_x: trainX[start:end],textCNN.dropout_keep_prob: 0.5}
                feed_dict[textCNN.input_y] = trainY[start:end]                
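                # One training step: apply gradients via train_op and fetch the batch loss and accuracy.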
                curr_loss,curr_acc,_=sess.run([textCNN.loss_val,textCNN.accuracy,textCNN.train_op],feed_dict) 
                loss,counter,acc=loss+curr_loss,counter+1,acc+curr_acc
                if counter % 3400 == 0:
                    print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tTrain Accuracy:%.3f" %(epoch,counter,loss/float(counter),acc/float(counter))) 
            #epoch increment
            print("going to increment epoch counter....")
            sess.run(textCNN.epoch_increment)

            # 5. Validation
            print(epoch,FLAGS.validate_every,(epoch % FLAGS.validate_every==0))
            if epoch % FLAGS.validate_every==0:
                eval_loss, eval_acc=do_eval(sess,textCNN,testX,testY,batch_size,vocabulary_index2word_label)
                print("Epoch %d Validation Loss:%.3f\tValidation Accuracy: %.3f" % (epoch,eval_loss,eval_acc))
                #save model to checkpoint
                save_path=FLAGS.ckpt_dir+"model.ckpt"
                saver.save(sess,save_path,global_step=epoch)
        
        # 6. Final evaluation on the test set once training completes.
        test_loss, test_acc = do_eval(sess, textCNN, testX, testY, batch_size, vocabulary_index2word_label)
        print("Final Test Loss:%.3f\tTest Accuracy:%.3f" % (test_loss, test_acc))

# Example 2
# This example again assumes os, numpy as np, and tensorflow as tf are imported, along with the
# repository's own helpers (create_vocabulary, create_vocabulary_label, load_final_test_data,
# load_data_predict, get_label_using_logits_batch, TextRNN, FLAGS).
import codecs

def main(_):
    # 1.load data with vocabulary of words and labels
    vocabulary_word2index, vocabulary_index2word = create_vocabulary(
        simple='simple',
        word2vec_model_path=FLAGS.word2vec_model_path,
        name_scope="rnn")
    vocab_size = len(vocabulary_word2index)
    vocabulary_word2index_label, vocabulary_index2word_label = create_vocabulary_label(
        name_scope="rnn")
    keyphraseid_keyphrase_lists = load_final_test_data(
        FLAGS.predict_source_file)
    keyphrase_string_list = []
    for t in keyphraseid_keyphrase_lists:
        kid, keyphrase = t
        keyphrase_string_list.append(keyphrase)
    test = load_data_predict(vocabulary_word2index,
                             vocabulary_word2index_label,
                             keyphraseid_keyphrase_lists)
    testX = []
    keyphrase_id_list = []
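    # Split the (keyphrase_id, word-index sequence) tuples into two parallel lists.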
    for tuplee in test:
        keyphrase_id, keyphrase_string = tuplee
        keyphrase_id_list.append(keyphrase_id)
        testX.append(keyphrase_string)
    # 2. Data preprocessing: sequence padding (post-padding to sequence_length).
    print("start padding....")
    # Pad short sequences with 0 and truncate long ones so every row has length sequence_length.
    testX2 = np.array([
        row[:FLAGS.sequence_length] + [0] * max(0, FLAGS.sequence_length - len(row))
        for row in testX
    ])
    print("end padding...")
    # 3.create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # 4.Instantiate Model
        textRNN = TextRNN(FLAGS.num_classes, FLAGS.learning_rate,
                          FLAGS.batch_size, FLAGS.decay_steps,
                          FLAGS.decay_rate, FLAGS.sequence_length, vocab_size,
                          FLAGS.embed_size, FLAGS.is_training)
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint for TextRNN")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print("Can't find the checkpoint.going to stop")
            return
        # 5. Feed data to the model to get logits.
        number_of_examples = len(testX2)
        print("number_of_examples:", number_of_examples)
        index = 0
        no_of_relevant_phrases = 0
        predict_target_file_f = codecs.open(FLAGS.predict_target_file, 'w',
                                            'utf8')
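        # Batch over the test examples; the two ranges are offset by batch_size,
        # so a trailing partial batch (fewer than batch_size examples) is skipped.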
        for start, end in zip(
                range(0, number_of_examples, FLAGS.batch_size),
                range(FLAGS.batch_size, number_of_examples + 1,
                      FLAGS.batch_size)):
            logits = sess.run(textRNN.logits,
                              feed_dict={
                                  textRNN.input_x: testX2[start:end],
                                  textRNN.dropout_keep_prob: 1
                              })
            print("start:", start, ";end:", end)
            keyphrase_id_sublist = keyphrase_id_list[start:end]
            keyphrase_string_sublist = keyphrase_string_list[start:end]
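            # Convert each row of logits into label strings and write the predictions to the target file.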
            get_label_using_logits_batch(keyphrase_string_sublist, logits,
                                         vocabulary_index2word_label,
                                         predict_target_file_f)
            index = index + 1
        predict_target_file_f.close()