Example #1
def main(_):
    logger = logging.getLogger('ai_law')
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(FLAGS.log_path, mode='a')
    fh.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    # define the formatter
    fmt = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
    datefmt = "%a %Y-%m-%d %H:%M:%S"  # TODO month
    formatter = logging.Formatter(fmt, datefmt)

    # set the output format for the file and console handlers
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)

    # attach both handlers to the logger
    logger.addHandler(fh)
    logger.addHandler(ch)
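    # note: handlers accumulate on the named logger; if main() ran more than
    # once in the same process, guarding with `if not logger.handlers:` would
    # avoid duplicated log lines.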

    logger.info("model:{}".format(FLAGS.model))
    vocab_word2index, accusation_label2index,articles_label2index= create_or_load_vocabulary(FLAGS.data_path,FLAGS.predict_path,FLAGS.traning_data_file,FLAGS.vocab_size,name_scope=FLAGS.name_scope,test_mode=FLAGS.test_mode)
    deathpenalty_label2index={True:1,False:0}
    lifeimprisonment_label2index={True:1,False:0}
    vocab_size = len(vocab_word2index)
    print("cnn_model.vocab_size:", vocab_size)
    accusation_num_classes = len(accusation_label2index)
    article_num_classes = len(articles_label2index)
    deathpenalty_num_classes = len(deathpenalty_label2index)
    lifeimprisonment_num_classes = len(lifeimprisonment_label2index)
    logger.info("accusation_num_classes:{} article_num_clasess:{} ".format(accusation_num_classes, article_num_classes))
    train,valid, test= load_data_multilabel(FLAGS.traning_data_file,FLAGS.valid_data_file,FLAGS.test_data_path,FLAGS.stopwords_file ,vocab_word2index, accusation_label2index,articles_label2index,deathpenalty_label2index,lifeimprisonment_label2index,
                                      FLAGS.sentence_len,name_scope=FLAGS.name_scope,test_mode=FLAGS.test_mode)
    train_X, train_Y_accusation, train_Y_article, train_Y_deathpenalty, train_Y_lifeimprisonment, train_Y_imprisonment,train_weights_accusation,train_weights_article = train
    valid_X, valid_Y_accusation, valid_Y_article, valid_Y_deathpenalty, valid_Y_lifeimprisonment, valid_Y_imprisonment,valid_weights_accusation,valid_weights_article = valid
    test_X, test_Y_accusation, test_Y_article, test_Y_deathpenalty, test_Y_lifeimprisonment, test_Y_imprisonment,test_weights_accusation,test_weights_article = test
    # print some messages for debugging purposes
    # print("length of training data:",len(train_X),";valid data:",len(valid_X),";test data:",len(test_X))
    logger.info("length of training data:{} ;valid data:{} ;test data:{}".format(len(train_X),len(valid_X),len(test_X)))
    # print("trainX_[0]:", train_X[0]);
    train_Y_accusation_short = get_target_label_short(train_Y_accusation[0])
    train_Y_article_short = get_target_label_short(train_Y_article[0])
    # print("train_Y_accusation_short:", train_Y_accusation_short,";train_Y_article_short:",train_Y_article_short)
    # print("train_Y_deathpenalty:",train_Y_deathpenalty[0],";train_Y_lifeimprisonment:",train_Y_lifeimprisonment[0],";train_Y_imprisonment:",train_Y_imprisonment[0])
    #2.create session.
    config=tf.ConfigProto()
    config.gpu_options.allow_growth=True
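    # allow_growth makes TensorFlow claim GPU memory on demand instead of
    # reserving the whole device up front.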
    with tf.Session(config=config) as sess:
        #Instantiate Model
        model=HierarchicalAttention( accusation_num_classes,article_num_classes, deathpenalty_num_classes,lifeimprisonment_num_classes,FLAGS.learning_rate,FLAGS.batch_size,
                            FLAGS.decay_steps, FLAGS.decay_rate, FLAGS.sentence_len, FLAGS.num_sentences,vocab_size, FLAGS.embed_size,FLAGS.hidden_size,
                                     num_filters=FLAGS.num_filters,model=FLAGS.model,filter_sizes=filter_sizes,stride_length=stride_length,pooling_strategy=FLAGS.pooling_strategy)
        #Initialize Save
        saver=tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir_accu+"checkpoint"):
            logger.info("Restoring Variables from Checkpoint.")
            saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir_accu))
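            # note: this restore path unconditionally halves the learning rate
            # twice (a net factor of 1/4); drop the loop below to resume
            # training at the saved rate instead.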
            for i in range(2):  # halve the learning rate twice
                logger.info("{} Going to decay learning rate by half.".format(i))
                sess.run(model.learning_rate_decay_half_op)
                #sess.run(model.learning_rate_decay_half_op)

        else:
            logger.info('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding: #load pre-trained word embedding
                vocabulary_index2word={index:word for word,index in vocab_word2index.items()}
                assign_pretrained_word_embedding(sess, vocabulary_index2word, vocab_size, model,FLAGS.word2vec_model_path,model.Embedding)
                #assign_pretrained_word_embedding(sess, vocabulary_index2word, vocab_size, model,FLAGS.word2vec_model_path2,model.Embedding2) #TODO

        curr_epoch=sess.run(model.epoch_step)
        #3.feed data & training
        number_of_training_data=len(train_X)
        batch_size=FLAGS.batch_size
        iteration=0
        accasation_score_best=-100
        law_score_best=-100
        imprisonment_score_best=-100

        for epoch in range(curr_epoch,FLAGS.num_epochs):
            loss_total, counter =  0.0, 0
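            # note: zipping the two ranges drops the final partial batch when
            # number_of_training_data is not a multiple of batch_size.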
            for start, end in zip(range(0, number_of_training_data, batch_size),range(batch_size, number_of_training_data, batch_size)):
                iteration=iteration+1
                if epoch==0 and counter==0:
                    logger.info("trainX[start:end]: {} train_X.shape: {}".format(train_X[start:end], train_X.shape))
                feed_dict = {model.input_x: train_X[start:end],model.input_y_accusation:train_Y_accusation[start:end],model.input_y_article:train_Y_article[start:end],
                             model.input_y_deathpenalty:train_Y_deathpenalty[start:end],model.input_y_lifeimprisonment:train_Y_lifeimprisonment[start:end],
                             model.input_y_imprisonment:train_Y_imprisonment[start:end],model.input_weight_accusation:train_weights_accusation[start:end],
                             model.input_weight_article:train_weights_article[start:end],model.dropout_keep_prob: FLAGS.keep_dropout_rate,
                             model.is_training_flag:FLAGS.is_training_flag}
                             #model.iter: iteration,model.tst: not FLAGS.is_training
                current_loss,lr,loss_accusation,loss_article,loss_deathpenalty,loss_lifeimprisonment,loss_imprisonment,l2_loss,_=\
                    sess.run([model.loss_val,model.learning_rate,model.loss_accusation,model.loss_article,model.loss_deathpenalty,
                                         model.loss_lifeimprisonment,model.loss_imprisonment,model.l2_loss,model.train_op],feed_dict) #model.update_ema
                loss_total,counter=loss_total+current_loss,counter+1
                if counter %200==0:
                    print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f" %(epoch,counter,float(loss_total)/float(counter),lr))
                if counter %600==0:
                    # print("Loss_accusation:%.3f\tLoss_article:%.3f\tLoss_deathpenalty:%.3f\tLoss_lifeimprisonment:%.3f\tLoss_imprisonment:%.3f\tL2_loss:%.3f\tCurrent_loss:%.3f\t"
                          # %(loss_accusation,loss_article,loss_deathpenalty,loss_lifeimprisonment,loss_imprisonment,l2_loss,current_loss))
                    logger.info("Loss_accusation:{} \tLoss_article:{} \tLoss_deathpenalty:{} \tLoss_lifeimprisonment:{} \tLoss_imprisonment:{} \tL2_loss:{} \tCurrent_loss:{} \t".format(loss_accusation,loss_article,loss_deathpenalty,loss_lifeimprisonment,loss_imprisonment,l2_loss,current_loss))
                ########################################################################################################
                if start!=0 and start%(2000*FLAGS.batch_size)==0: # evaluate every 2000 batches
                    loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle, f1_a_death, f1_i_death, f1_a_life, f1_i_life, score_penalty = \
                        do_eval(sess, model, valid,iteration,accusation_num_classes,article_num_classes)
                    accasation_score=((f1_macro_accasation+f1_micro_accasation)/2.0)*100.0
                    article_score=((f1_a_article+f1_i_aritcle)/2.0)*100.0
                    score_all=accasation_score+article_score+score_penalty
                    # print("Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f Micro_f1_article:%.3f Macro_f1_deathpenalty:%.3f\t"
                                # "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                                # % (epoch, loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle,f1_a_death, f1_i_death, f1_a_life, f1_i_life))
                    logger.info("Epoch {} ValidLoss:{} \n Macro_f1_accasation:{} \tMicro_f1_accsastion:{}\tMacro_f1_article:{} \t Micro_f1_article:{} \t Macro_f1_deathpenalty:{} \t"
                                "Micro_f1_deathpenalty:{} \tMacro_f1_lifeimprisonment:{} \tMicro_f1_lifeimprisonment:{}\t".format(epoch, loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle,f1_a_death, f1_i_death, f1_a_life, f1_i_life))
                    # print("1.Accasation Score:", accasation_score, ";2.Article Score:", article_score, ";3.Penalty Score:",score_penalty, ";Score ALL:", score_all)
                    logger.info("Epoch:{} 1.Accasation Score:{} ;2.Article Score:{} ;3.Penalty Score:{} ;Score ALL:{}\n accasation_score_best{}".format(epoch,accasation_score, article_score, score_penalty, score_all, accasation_score_best))
                    # save model to checkpoint
                    if accasation_score>accasation_score_best:
                        save_path = FLAGS.ckpt_dir_accu + "model.ckpt" # TODO: temporary; revert to saving a checkpoint only once per epoch.
                        logger.info("going to save checkpoint for accusation.")
                        saver.save(sess, save_path, global_step=epoch)
                        accasation_score_best=accasation_score
                    if article_score > law_score_best:
                        save_path = FLAGS.ckpt_dir_law + "model.ckpt" # TODO: temporary; revert to saving a checkpoint only once per epoch.
                        logger.info("going to save checkpoint for article.")
                        saver.save(sess, save_path, global_step=epoch)
                        law_score_best = article_score
                    if score_penalty > imprisonment_score_best:
                        save_path = FLAGS.ckpt_dir_imprision + "model.ckpt" # TODO: temporary; revert to saving a checkpoint only once per epoch.
                        logger.info("going to save checkpoint for imprisonment.")
                        saver.save(sess, save_path, global_step=epoch)
                        imprisonment_score_best = score_penalty

                    logger.info("Epoch:{} Bestscore:1 Accasation:{} ;2. Article:{} ;3.penalty:{}".format(epoch, accasation_score_best, law_score_best, imprisonment_score_best))
            #epoch increment
            # print("going to increment epoch counter....")
            logger.info("going to increment epoch counter....")
            sess.run(model.epoch_increment)

            # 4.validation
            print(epoch,FLAGS.validate_every,(epoch % FLAGS.validate_every==0))
            if epoch % FLAGS.validate_every==0:
                loss,f1_macro_accasation,f1_micro_accasation,f1_a_article,f1_i_aritcle,f1_a_death,f1_i_death,f1_a_life,f1_i_life,score_penalty=\
                    do_eval(sess,model,valid,iteration,accusation_num_classes,article_num_classes)
                accasation_score = ((f1_macro_accasation + f1_micro_accasation) / 2.0) * 100.0
                article_score = ((f1_a_article + f1_i_aritcle) / 2.0) * 100.0
                score_all = accasation_score + article_score + score_penalty
                print("Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
                      "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                      % (epoch,loss,f1_macro_accasation,f1_micro_accasation,f1_a_article,f1_i_aritcle,f1_a_death,f1_i_death,f1_a_life,f1_i_life))
                # print("===>1.Accasation Score:", accasation_score, ";2.Article Score:", article_score,";3.Penalty Score:",score_penalty,";Score ALL:",score_all)
                logger.info("===>1.Accasation Score: {} ;2.Article Score: {} ;3.Penalty Score:{} ;Score ALL:{}".format(accasation_score, article_score, score_penalty, score_all))
                #save model to checkpoint
                if accasation_score > accasation_score_best:
                    save_path=FLAGS.ckpt_dir_accu+"model.ckpt"
                    print("going to save check point.")
                    saver.save(sess,save_path,global_step=epoch)
                    accasation_score_best = accasation_score
                if article_score > law_score_best:
                    save_path = FLAGS.ckpt_dir_law + "model.ckpt" # TODO: temporary; revert to saving a checkpoint only once per epoch.
                    logger.info("going to save checkpoint for article.")
                    saver.save(sess, save_path, global_step=epoch)
                    law_score_best = article_score
                if score_penalty > imprisonment_score_best:
                    save_path = FLAGS.ckpt_dir_imprision + "model.ckpt" # TODO: temporary; revert to saving a checkpoint only once per epoch.
                    logger.info("going to save checkpoint for imprisonment.")
                    saver.save(sess, save_path, global_step=epoch)
                    imprisonment_score_best = score_penalty
            #if (epoch == 2 or epoch == 4 or epoch == 7 or epoch==10 or epoch == 13  or epoch==19):
            if (epoch == 1 or epoch == 3 or epoch == 6 or epoch == 9 or epoch == 12 or epoch == 18):
                for i in range(2):
                    print(i, "Going to decay learning rate by half.")
                    sess.run(model.learning_rate_decay_half_op)

        # 5. Finally, evaluate on the test set and report test accuracy.
        # test_loss,macrof1,microf1 = do_eval(sess, flags.model, testX, testY,iteration)
        # print("Test Loss:%.3f\tMacro f1:%.3f\tMicro f1:%.3f" % (test_loss,macrof1,microf1))
        # print("training completed...")
    pass
Example #2
def main(_):
    training_data_path = '/Users/liyangyang/Downloads/bdci/train.txt'
    vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, vocabulary_index2label = \
        data_util.create_vocabulary(training_data_path, 17259, name_scope='cnn')
    vocab_size = len(vocabulary_word2index) + 1
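    # the +1 presumably reserves one extra index (e.g. id 0 for padding);
    # this depends on how data_util.create_vocabulary assigns ids.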
    print("cnn_model.vocab_size:", vocab_size)
    num_classes = len(vocabulary_index2label)
    print("num_classes:", num_classes)
    print(vocabulary_index2label)
    train, test = data_util.load_data_multilabel(training_data_path,
                                                 vocabulary_word2index,
                                                 vocabulary_label2index, 200)
    trainX, trainY = train
    testX, testY = test
    # trainX = trainX[0:8000]
    # trainY = trainY[0:8000]
    # testX = testX[0:500]
    # testY = testY[0:500]
    # print some messages for debugging purposes
    print("length of training data:", len(trainX),
          ";length of validation data:", len(testX))
    print("trainX.shape", np.array(trainX).shape)
    print("trainY.shape", np.array(trainY).shape)
    print("trainX[0]:", trainX[1])
    print("trainY[0]:", trainY[1])

    print("end padding & transform to one hot...")

    # 2.create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Instantiate Model
        textCNN = TextCNN(filter_sizes, FLAGS.num_filters, FLAGS.num_classes,
                          FLAGS.learning_rate, FLAGS.batch_size,
                          FLAGS.decay_steps, FLAGS.decay_rate,
                          FLAGS.sentence_len, vocab_size, FLAGS.embed_size,
                          FLAGS.is_training)
        # Initialize Save
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            # for i in range(3): #decay learning rate if necessary.
            #    print(i,"Going to decay learning rate by half.")
            #    sess.run(textCNN.learning_rate_decay_half_op)
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_embedding:  # load pre-trained word embedding
                assign_pretrained_word_embedding(sess, vocabulary_index2word,
                                                 vocab_size, textCNN)
        curr_epoch = sess.run(textCNN.epoch_step)
        # 3.feed data & training
        number_of_training_data = len(trainX)
        batch_size = FLAGS.batch_size
        iteration = 0
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss, acc, counter = 0.0, 0.0, 0
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", trainX[start:end])
                    print("trainY[start:end]:", trainY[start:end])
                feed_dict = {
                    textCNN.input_x: trainX[start:end],
                    textCNN.dropout_keep_prob: 0.5,
                    textCNN.iter: iteration,
                    textCNN.tst: not FLAGS.is_training
                }
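                # assumption based on the field names: `tst` switches the model
                # into test/inference mode (e.g. for batch norm) and `iter` feeds
                # the step counter used for moving-average updates.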
                if not FLAGS.multi_label_flag:
                    feed_dict[textCNN.input_y] = trainY[start:end]
                else:
                    feed_dict[textCNN.input_y_multilabel] = trainY[start:end]
                curr_loss, lr, curr_acc, _ = sess.run([
                    textCNN.loss_val, textCNN.learning_rate, textCNN.accuracy,
                    textCNN.train_op
                ], feed_dict)
                loss, counter, acc = loss + curr_loss, counter + 1, acc + curr_acc
                if counter % 2 == 0:
                    print(
                        "Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f\tTrain Accuracy:%.3f"
                        % (epoch, counter, loss / float(counter), lr,
                           acc / float(counter)))

                ########################################################################################################
                # if start % (2000 * FLAGS.batch_size) == 0:  # eval every 3000 steps.
                #     eval_loss, f1_score, precision, recall = do_eval(sess, textCNN, testX, testY, iteration)
                #     print("Epoch %d Validation Loss:%.3f\tF1 Score:%.3f\tPrecision:%.3f\tRecall:%.3f" % (
                #         epoch, eval_loss, f1_score, precision, recall))
                #     # save model to checkpoint
                #     save_path = FLAGS.ckpt_dir + "model.ckpt"
                #     saver.save(sess, save_path, global_step=epoch)
                ########################################################################################################
            # epoch increment
            print("going to increment epoch counter....")
            sess.run(textCNN.epoch_increment)

            # 4.validation
            print(epoch, FLAGS.validate_every,
                  (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                # save model to checkpoint
                save_path = FLAGS.ckpt_dir + "model.ckpt"
                saver.save(sess, save_path, global_step=epoch)

                eval_loss, eval_acc = do_eval(sess, textCNN, testX, testY,
                                              iteration, batch_size)
                print(
                    "Epoch %d Validation Loss:%.3f\tValidation Accuracy: %.3f"
                    % (epoch, eval_loss, eval_acc))

        # 5. Finally, evaluate on the test set and report test accuracy.
        eval_loss, eval_acc = do_eval(sess, textCNN, testX, testY, iteration,
                                      batch_size)
        print("Test Loss:%.3f" % (eval_loss))
    pass
Example #3
def main(_):
    trainX, trainY, testX, testY = None, None, None, None
    vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, vocabulary_index2label = create_vocabulary(
        FLAGS.traning_data_path, FLAGS.vocab_size, name_scope=FLAGS.name_scope)
    vocab_size = len(vocabulary_word2index)
    print("cnn_model.vocab_size:", vocab_size)
    num_classes = len(vocabulary_index2label)
    print("num_classes:", num_classes)
    train, test = load_data_multilabel(FLAGS.traning_data_path,
                                       vocabulary_word2index,
                                       vocabulary_label2index,
                                       FLAGS.sentence_len)
    trainX, trainY = train
    testX, testY = test
    # print some messages for debugging purposes
    print("length of training data:", len(trainX),
          ";length of validation data:", len(testX))
    print("trainX[0]:", trainX[0])
    print("trainY[0]:", trainY[0])
    train_y_short = get_target_label_short(trainY[0])
    print("train_y_short:", train_y_short)

    #2.create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        #Instantiate Model
        textCNN = TextCNN(filter_sizes,
                          FLAGS.num_filters,
                          num_classes,
                          FLAGS.learning_rate,
                          FLAGS.batch_size,
                          FLAGS.decay_steps,
                          FLAGS.decay_rate,
                          FLAGS.sentence_len,
                          vocab_size,
                          FLAGS.embed_size,
                          FLAGS.is_training,
                          multi_label_flag=FLAGS.multi_label_flag)
        #Initialize Save
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            #for i in range(3): #decay learning rate if necessary.
            #    print(i,"Going to decay learning rate by half.")
            #    sess.run(textCNN.learning_rate_decay_half_op)
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_embedding:  #load pre-trained word embedding
                assign_pretrained_word_embedding(sess, vocabulary_index2word,
                                                 vocab_size, textCNN,
                                                 FLAGS.word2vec_model_path)
        curr_epoch = sess.run(textCNN.epoch_step)
        #3.feed data & training
        number_of_training_data = len(trainX)
        batch_size = FLAGS.batch_size
        iteration = 0
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss, counter = 0.0, 0
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", trainX[start:end])
                feed_dict = {
                    textCNN.input_x: trainX[start:end],
                    textCNN.dropout_keep_prob: 0.5,
                    textCNN.iter: iteration,
                    textCNN.tst: not FLAGS.is_training
                }
                if not FLAGS.multi_label_flag:
                    feed_dict[textCNN.input_y] = trainY[start:end]
                else:
                    feed_dict[textCNN.input_y_multilabel] = trainY[start:end]
                curr_loss, lr, _, _ = sess.run([
                    textCNN.loss_val, textCNN.learning_rate,
                    textCNN.update_ema, textCNN.train_op
                ], feed_dict)
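                # running update_ema together with train_op keeps the exponential
                # moving averages (typically batch-norm statistics) in step with
                # the weight updates.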
                loss, counter = loss + curr_loss, counter + 1
                if counter % 50 == 0:
                    print(
                        "Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f"
                        % (epoch, counter, loss / float(counter), lr))

                ########################################################################################################
                if start % (2000 * FLAGS.batch_size) == 0:  # evaluate every 2000 batches (also fires at start == 0)
                    eval_loss, f1_score, precision, recall = do_eval(
                        sess, textCNN, testX, testY, iteration)
                    print(
                        "Epoch %d Validation Loss:%.3f\tF1 Score:%.3f\tPrecision:%.3f\tRecall:%.3f"
                        % (epoch, eval_loss, f1_score, precision, recall))
                    # save model to checkpoint
                    save_path = FLAGS.ckpt_dir + "model.ckpt"
                    saver.save(sess, save_path, global_step=epoch)
                ########################################################################################################
            #epoch increment
            print("going to increment epoch counter....")
            sess.run(textCNN.epoch_increment)

            # 4.validation
            print(epoch, FLAGS.validate_every,
                  (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                eval_loss, f1_score, precision, recall = do_eval(
                    sess, textCNN, testX, testY, iteration)
                print(
                    "Epoch %d Validation Loss:%.3f\tF1 Score:%.3f\tPrecision:%.3f\tRecall:%.3f"
                    % (epoch, eval_loss, f1_score, precision, recall))
                #save model to checkpoint
                save_path = FLAGS.ckpt_dir + "model.ckpt"
                saver.save(sess, save_path, global_step=epoch)

        # 5. Finally, evaluate on the test set and report test accuracy.
        test_loss, _, _, _ = do_eval(sess, textCNN, testX, testY, iteration)
        print("Test Loss:%.3f" % (test_loss))
    pass
Example #4
def main(_):
    # 1. load data (X: list of int, y: int).
    # if os.path.exists(FLAGS.cache_path):  # if the cache file exists, load the (vocabulary-indexed) stories from it
    #    with open(FLAGS.cache_path, 'r') as data_f:
    #        trainX, trainY, testX, testY, vocabulary_index2word=pickle.load(data_f)
    #        vocab_size=len(vocabulary_index2word)
    # else:
    if True:  # the cached-data branch above is disabled; always rebuild from the raw file
        # vocab_processor_path = '/Users/liyangyang/PycharmProjects/mypy/venv/dwb/testcnn/vocab'
        # # print("end padding & transform to one hot...")
        # x_train, y = data_helpers.load_data_and_labels(FLAGS.data_file)
        #
        # # vocab_processor = learn.preprocessing.VocabularyProcessor(2000,min_frequency=2)
        # # x = np.array(list(vocab_processor.fit_transform(x_train)))
        # # vocab_processor.save(vocab_processor_path)
        #
        # vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_processor_path)
        # x = np.array(list(vocab_processor.transform(x_train)))
        #
        # trainX = x[:100000]
        # testX = x[100000:]
        # trainY = y[:100000]
        # testY = y[100000:]
        # vocab_size = len(vocab_processor.vocabulary_)
        # print('vocab_size', vocab_size)
        # print("trainX[0]:", trainX[0])  # ;print("trainY[0]:", trainY[0])
        # # Converting labels to binary vectors
        # print("end padding & transform to one hot...")
        training_data_path = '/Users/liyangyang/Downloads/dwb/new_data/train_set.txt'
        vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, vocabulary_index2label = \
            data_util.create_vocabulary(training_data_path, 345325, name_scope='cnn')
        vocab_size = len(vocabulary_word2index) + 1
        print("cnn_model.vocab_size:", vocab_size)
        num_classes = len(vocabulary_index2label)
        print("num_classes:", num_classes)
        print(vocabulary_index2label)
        train, test = data_util.load_data_multilabel(training_data_path,
                                                     vocabulary_word2index,
                                                     vocabulary_label2index,
                                                     5000)
        trainX, trainY = train
        testX, testY = test
        trainX = trainX[0:1000]
        trainY = trainY[0:1000]
        testX = testX[0:500]
        testY = testY[0:500]
        # print some messages for debugging purposes
        print("length of training data:", len(trainX),
              ";length of validation data:", len(testX))
        print("trainX.shape", np.array(trainX).shape)
        print("trainY.shape", np.array(trainY).shape)
        print("trainX[0]:", trainX[1])
        print("trainY[0]:", trainY[1])

        print("end padding & transform to one hot...")
    # 2.create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Instantiate Model
        textRNN = TextRNN(FLAGS.num_classes, FLAGS.learning_rate,
                          FLAGS.batch_size, FLAGS.decay_steps,
                          FLAGS.decay_rate, FLAGS.sequence_length, vocab_size,
                          FLAGS.embed_size, FLAGS.is_training)
        # Initialize Save
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint for rnn model.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_embedding:  # load pre-trained word embedding
                assign_pretrained_word_embedding(sess, vocabulary_index2word,
                                                 vocab_size, textRNN)
        curr_epoch = sess.run(textRNN.epoch_step)
        # 3.feed data & training
        number_of_training_data = len(trainX)
        batch_size = FLAGS.batch_size
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss, acc, counter = 0.0, 0.0, 0
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", trainX[start:end]
                          )  # ;print("trainY[start:end]:",trainY[start:end])
                curr_loss, curr_acc, _ = sess.run(
                    [textRNN.loss_val, textRNN.accuracy, textRNN.train_op],
                    feed_dict={
                        textRNN.input_x: trainX[start:end],
                        textRNN.input_y: trainY[start:end],
                        textRNN.dropout_keep_prob: 1
                    }
                )  # curr_acc is textRNN.accuracy evaluated on this batch
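                # note: dropout_keep_prob is fed as 1 above, which disables
                # dropout during training; the CNN examples use 0.5 instead.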
                loss, counter, acc = loss + curr_loss, counter + 1, acc + curr_acc
                if counter % 1 == 0:
                    print(
                        "Epoch %d\tBatch %d\tTrain Loss:%.3f\tTrain Accuracy:%.3f"
                        % (epoch, counter, loss / float(counter),
                           acc / float(counter))
                    )  # train accuracy is acc / float(counter)
            # epoch increment
            print("going to increment epoch counter....")
            sess.run(textRNN.epoch_increment)
            # 4.validation
            print(epoch, FLAGS.validate_every,
                  (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                eval_loss, eval_acc = do_eval(sess, textRNN, testX, testY,
                                              batch_size)
                print(
                    "Epoch %d Validation Loss:%.3f\tValidation Accuracy: %.3f"
                    % (epoch, eval_loss, eval_acc))
                # save model to checkpoint
                save_path = FLAGS.ckpt_dir + "model.ckpt"
                saver.save(sess, save_path, global_step=epoch)

        # 5. Finally, evaluate on the test set and report test accuracy.
        test_loss, test_acc = do_eval(sess, textRNN, testX, testY, batch_size)
    pass
Example #5
def main(_):
    print("model:",FLAGS.model)
    name_scope=FLAGS.model
    vocab_word2index, accusation_label2index,articles_label2index= create_or_load_vocabulary(FLAGS.data_path,FLAGS.predict_path,FLAGS.traning_data_file,FLAGS.vocab_size,name_scope=name_scope,test_mode=FLAGS.test_mode)
    deathpenalty_label2index={True:1,False:0}
    lifeimprisonment_label2index={True:1,False:0}
    vocab_size = len(vocab_word2index)
    print("cnn_model.vocab_size:", vocab_size)
    accusation_num_classes = len(accusation_label2index)
    article_num_classes = len(articles_label2index)
    deathpenalty_num_classes = len(deathpenalty_label2index)
    lifeimprisonment_num_classes = len(lifeimprisonment_label2index)
    print("accusation_num_classes:", accusation_num_classes)
    print("article_num_classes:", article_num_classes)
    train,valid, test= load_data_multilabel(FLAGS.traning_data_file,FLAGS.valid_data_file,FLAGS.test_data_path,vocab_word2index, accusation_label2index,articles_label2index,deathpenalty_label2index,lifeimprisonment_label2index,
                                      FLAGS.sentence_len,name_scope=name_scope,test_mode=FLAGS.test_mode)
    train_X, train_Y_accusation, train_Y_article, train_Y_deathpenalty, train_Y_lifeimprisonment, train_Y_imprisonment,train_weights_accusation,train_weights_article = train
    valid_X, valid_Y_accusation, valid_Y_article, valid_Y_deathpenalty, valid_Y_lifeimprisonment, valid_Y_imprisonment,valid_weights_accusation,valid_weights_article = valid
    test_X, test_Y_accusation, test_Y_article, test_Y_deathpenalty, test_Y_lifeimprisonment, test_Y_imprisonment,test_weights_accusation,test_weights_article = test
    # print some messages for debugging purposes
    print("length of training data:",len(train_X),";valid data:",len(valid_X),";test data:",len(test_X))
    print("trainX_[0]:", train_X[0]);
    train_Y_accusation_short1 = get_target_label_short(train_Y_accusation[0]);train_Y_accusation_short2 = get_target_label_short(train_Y_accusation[1]);train_Y_accusation_short3 = get_target_label_short(train_Y_accusation[2]);train_Y_accusation_short4 = get_target_label_short(train_Y_accusation[20]);train_Y_accusation_short5 = get_target_label_short(train_Y_accusation[200])
    train_Y_article_short = get_target_label_short(train_Y_article[0])
    print("train_Y_accusation_short:", train_Y_accusation_short1,train_Y_accusation_short2,train_Y_accusation_short3,train_Y_accusation_short4,train_Y_accusation_short4,";train_Y_article_short:",train_Y_article_short)
    print("train_Y_deathpenalty:",train_Y_deathpenalty[0],";train_Y_lifeimprisonment:",train_Y_lifeimprisonment[0],";train_Y_imprisonment:",train_Y_imprisonment[0])
    #2.create session.
    config=tf.ConfigProto()
    config.gpu_options.allow_growth=True
    with tf.Session(config=config) as sess:
        #Instantiate Model
        model=HierarchicalAttention( accusation_num_classes,article_num_classes, deathpenalty_num_classes,lifeimprisonment_num_classes,FLAGS.learning_rate,FLAGS.batch_size,
                            FLAGS.decay_steps, FLAGS.decay_rate, FLAGS.sentence_len, FLAGS.num_sentences,vocab_size, FLAGS.embed_size,FLAGS.hidden_size,
                                     num_filters=FLAGS.num_filters,model=FLAGS.model,filter_sizes=filter_sizes,stride_length=stride_length,pooling_strategy=FLAGS.pooling_strategy)
        #Initialize Save
        saver=tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir+"checkpoint"):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            for i in range(2):  # halve the learning rate twice after restoring
                print(i,"Going to decay learning rate by half.")
                sess.run(model.learning_rate_decay_half_op)
                #sess.run(model.learning_rate_decay_half_op)

        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding: #load pre-trained word embedding
                vocabulary_index2word={index:word for word,index in vocab_word2index.items()}
                assign_pretrained_word_embedding(sess, vocabulary_index2word, vocab_size, model,FLAGS.word2vec_model_path,model.Embedding)
                #assign_pretrained_word_embedding(sess, vocabulary_index2word, vocab_size, model,FLAGS.word2vec_model_path2,model.Embedding2) #TODO

        curr_epoch=sess.run(model.epoch_step)
        #3.feed data & training
        number_of_training_data=len(train_X)
        batch_size=FLAGS.batch_size
        iteration=0
        accasation_score_best=-100


        for epoch in range(curr_epoch,FLAGS.num_epochs):
            loss_total, counter =  0.0, 0
            for start, end in zip(range(0, number_of_training_data, batch_size),range(batch_size, number_of_training_data, batch_size)):
                iteration=iteration+1
                if epoch==0 and counter==0:
                    print("trainX[start:end]:",train_X[start:end],"train_X.shape:",train_X.shape)
                feed_dict = {model.input_x: train_X[start:end],model.input_y_accusation:train_Y_accusation[start:end],model.input_y_article:train_Y_article[start:end],
                             model.input_y_deathpenalty:train_Y_deathpenalty[start:end],model.input_y_lifeimprisonment:train_Y_lifeimprisonment[start:end],
                             model.input_y_imprisonment:train_Y_imprisonment[start:end],model.input_weight_accusation:train_weights_accusation[start:end],
                             model.input_weight_article:train_weights_article[start:end],model.dropout_keep_prob: FLAGS.keep_dropout_rate,
                             model.is_training_flag:FLAGS.is_training_flag}
                             #model.iter: iteration,model.tst: not FLAGS.is_training
                current_loss,lr,loss_accusation,loss_article,loss_deathpenalty,loss_lifeimprisonment,loss_imprisonment,l2_loss,_=\
                    sess.run([model.loss_val,model.learning_rate,model.loss_accusation,model.loss_article,model.loss_deathpenalty,
                                         model.loss_lifeimprisonment,model.loss_imprisonment,model.l2_loss,model.train_op],feed_dict) #model.update_ema
                loss_total,counter=loss_total+current_loss,counter+1
                if counter %20==0:
                    print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f" %(epoch,counter,float(loss_total)/float(counter),lr))
                if counter %60==0:
                    print("Loss_accusation:%.3f\tLoss_article:%.3f\tLoss_deathpenalty:%.3f\tLoss_lifeimprisonment:%.3f\tLoss_imprisonment:%.3f\tL2_loss:%.3f\tCurrent_loss:%.3f\t"
                          %(loss_accusation,loss_article,loss_deathpenalty,loss_lifeimprisonment,loss_imprisonment,l2_loss,current_loss))
                ########################################################################################################
                if start!=0 and start%(2000*FLAGS.batch_size)==0: # evaluate every 2000 batches
                    loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle, f1_a_death, f1_i_death, f1_a_life, f1_i_life, score_penalty = \
                        do_eval(sess, model, valid,iteration,accusation_num_classes,article_num_classes,accusation_label2index)
                    accasation_score=((f1_macro_accasation+f1_micro_accasation)/2.0)*100.0
                    article_score=((f1_a_article+f1_i_aritcle)/2.0)*100.0
                    score_all=accasation_score+article_score+score_penalty
                    print("Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f Micro_f1_article:%.3f Macro_f1_deathpenalty:%.3f\t"
                                "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                                % (epoch, loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle,f1_a_death, f1_i_death, f1_a_life, f1_i_life))
                    print("1.Accasation Score:", accasation_score, ";2.Article Score:", article_score, ";3.Penalty Score:",score_penalty, ";Score ALL:", score_all)
                    # save model to checkpoint
                    if accasation_score>accasation_score_best:
                        save_path = FLAGS.ckpt_dir + "model.ckpt" #TODO temp remove==>only save checkpoint for each epoch once.
                        print("going to save check point.")
                        saver.save(sess, save_path, global_step=epoch)
                        accasation_score_best=accasation_score
            #epoch increment
            print("going to increment epoch counter....")
            sess.run(model.epoch_increment)

            # 4.validation
            print(epoch,FLAGS.validate_every,(epoch % FLAGS.validate_every==0))
            if epoch % FLAGS.validate_every==0:
                loss,f1_macro_accasation,f1_micro_accasation,f1_a_article,f1_i_aritcle,f1_a_death,f1_i_death,f1_a_life,f1_i_life,score_penalty=\
                    do_eval(sess,model,valid,iteration,accusation_num_classes,article_num_classes,accusation_label2index)
                accasation_score = ((f1_macro_accasation + f1_micro_accasation) / 2.0) * 100.0
                article_score = ((f1_a_article + f1_i_aritcle) / 2.0) * 100.0
                score_all = accasation_score + article_score + score_penalty
                print()
                print("Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
                      "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                      % (epoch,loss,f1_macro_accasation,f1_micro_accasation,f1_a_article,f1_i_aritcle,f1_a_death,f1_i_death,f1_a_life,f1_i_life))
                print("===>1.Accasation Score:", accasation_score, ";2.Article Score:", article_score,";3.Penalty Score:",score_penalty,";Score ALL:",score_all)

                #save model to checkpoint
                if accasation_score > accasation_score_best:
                    save_path=FLAGS.ckpt_dir+"model.ckpt"
                    print("going to save check point.")
                    saver.save(sess,save_path,global_step=epoch)
                    accasation_score_best = accasation_score
            #if (epoch == 2 or epoch == 4 or epoch == 7 or epoch==10 or epoch == 13  or epoch==19):
            #if (epoch == 1 or epoch == 3 or epoch == 6 or epoch == 9 or epoch == 12 or epoch == 18):
            if (epoch == 0 or epoch == 2 or epoch == 4 or epoch == 6 or epoch == 9 or epoch == 13):

                for i in range(2):
                    print(i, "Going to decay learning rate by half.")
                    sess.run(model.learning_rate_decay_half_op)

        # 5. Finally, evaluate on the test set and report test accuracy.
        loss_test, f1_macro_accasation_test, f1_micro_accasation_test, f1_a_article_test, f1_i_aritcle_test, f1_a_death_test, f1_i_death_test, f1_a_life_test, f1_i_life_test, score_penalty_test=\
            do_eval(sess, model, test, iteration, accusation_num_classes, article_num_classes, accusation_label2index)
        print("TEST.FINAL.Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
                    "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                    % (epoch, loss_test, f1_macro_accasation_test, f1_micro_accasation_test, f1_a_article_test, f1_i_aritcle_test, f1_a_death_test,
                       f1_i_death_test, f1_a_life_test, f1_i_life_test))
        accasation_score_test = ((f1_macro_accasation_test + f1_micro_accasation_test) / 2.0) * 100.0
        article_score_test = ((f1_a_article_test + f1_i_aritcle_test) / 2.0) * 100.0
        score_all_test = accasation_score_test + article_score_test + score_penalty_test
        print("TEST.Accasation Score:", accasation_score_test, ";2.Article Score:", article_score_test, ";3.Penalty Score:",score_penalty_test, ";Score ALL:", score_all_test)

        #print("Test Loss:%.3f\tMacro f1:%.3f\tMicro f1:%.3f" % (test_loss,macrof1,microf1))
        print("training completed...")
    pass
Example #6
tf.app.flags.DEFINE_boolean(
    "is_training_flag", True, "is training. true: training, false: testing/inference")
tf.app.flags.DEFINE_integer("num_epochs", 15, "number of epochs to run.")
tf.app.flags.DEFINE_integer(
    "validate_every", 1, "Validate every validate_every epochs.")  # 每10轮做一次验证
tf.app.flags.DEFINE_boolean("use_embedding", False,
                            "whether to use embedding or not.")
tf.app.flags.DEFINE_integer(
    "num_filters", 128, "number of filters")  # 256--->512
tf.app.flags.DEFINE_string(
    "word2vec_model_path", "word2vec-title-desc.bin", "word2vec's vocabulary and vectors")
tf.app.flags.DEFINE_string("name_scope", "cnn", "name scope value.")
tf.app.flags.DEFINE_boolean(
    "multi_label_flag", False, "use multi label or single label.")
filter_sizes = [6, 7, 8]
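# filter_sizes is read at module scope; the TextCNN constructors in the
# examples above take it as their first positional argument.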

print("Restoring Variables from Checkpoint.")
saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
trainX, trainY, testX, testY = None, None, None, None
vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, _ = create_vocabulary(FLAGS.traning_data_path, FLAGS.vocab_size, name_scope=FLAGS.name_scope)
vocab_size = len(vocabulary_word2index)
print("cnn_model.vocab_size:", vocab_size)
num_classes = len(vocabulary_label2index)
print("num_classes:", num_classes)
#num_examples,FLAGS.sentence_len=trainX.shape
#print("num_examples of training:",num_examples,";sentence_len:",FLAGS.sentence_len)
train, test = load_data_multilabel(
    FLAGS.traning_data_path, vocabulary_word2index, vocabulary_label2index, FLAGS.sentence_len)
trainX, trainY = train
testX, testY = test
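
# Usage sketch (an assumption, not from the source): fragments like this one
# come from scripts launched via tf.app.run(), which parses the flags defined
# above and then calls main(_):
#
#     if __name__ == "__main__":
#         tf.app.run()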