def main(_):
    """Train the multi-task HierarchicalAttention model and checkpoint on improvement.

    The function configures the ``ai_law`` logger, loads the vocabulary and the
    train/valid/test splits, builds (or restores) the model inside a TF session,
    and runs the training loop. The validation set is scored every 2000 batches
    and at the end of every ``FLAGS.validate_every``-th epoch; a checkpoint is
    written to the accusation / law / imprisonment directory whenever the
    corresponding score beats the best value seen so far.

    NOTE(review): the block nesting was reconstructed from statement order;
    confirm that the "Bestscore" summary log belongs to the mid-epoch
    evaluation branch and that the learning-rate decay runs once per epoch.

    Args:
        _: Unused positional argument.
    """
    logger = logging.getLogger('ai_law')
    logger.setLevel(logging.INFO)
    # Attach handlers only once so a repeated call does not duplicate log lines.
    if not logger.handlers:
        fh = logging.FileHandler(FLAGS.log_path, mode='a')
        fh.setLevel(logging.INFO)
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        # Define the formatter.
        fmt = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
        datefmt = "%a %Y-%m-%d %H:%M:%S"  # TODO month
        formatter = logging.Formatter(fmt, datefmt)
        # Use the same output format for the file and the console.
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        # Add both handlers to the logger.
        logger.addHandler(fh)
        logger.addHandler(ch)

    logger.info("model:{}".format(FLAGS.model))

    # 1. load vocabulary, label maps and data
    vocab_word2index, accusation_label2index, articles_label2index = create_or_load_vocabulary(
        FLAGS.data_path, FLAGS.predict_path, FLAGS.traning_data_file, FLAGS.vocab_size,
        name_scope=FLAGS.name_scope, test_mode=FLAGS.test_mode)
    deathpenalty_label2index = {True: 1, False: 0}
    lifeimprisonment_label2index = {True: 1, False: 0}
    vocab_size = len(vocab_word2index)
    print("cnn_model.vocab_size:", vocab_size)
    accusation_num_classes = len(accusation_label2index)
    article_num_classes = len(articles_label2index)
    deathpenalty_num_classes = len(deathpenalty_label2index)
    lifeimprisonment_num_classes = len(lifeimprisonment_label2index)
    logger.info("accusation_num_classes:{} article_num_clasess:{} ".format(
        accusation_num_classes, article_num_classes))

    train, valid, test = load_data_multilabel(
        FLAGS.traning_data_file, FLAGS.valid_data_file, FLAGS.test_data_path, FLAGS.stopwords_file,
        vocab_word2index, accusation_label2index, articles_label2index,
        deathpenalty_label2index, lifeimprisonment_label2index,
        FLAGS.sentence_len, name_scope=FLAGS.name_scope, test_mode=FLAGS.test_mode)
    (train_X, train_Y_accusation, train_Y_article, train_Y_deathpenalty,
     train_Y_lifeimprisonment, train_Y_imprisonment,
     train_weights_accusation, train_weights_article) = train
    (valid_X, valid_Y_accusation, valid_Y_article, valid_Y_deathpenalty,
     valid_Y_lifeimprisonment, valid_Y_imprisonment,
     valid_weights_accusation, valid_weights_article) = valid
    (test_X, test_Y_accusation, test_Y_article, test_Y_deathpenalty,
     test_Y_lifeimprisonment, test_Y_imprisonment,
     test_weights_accusation, test_weights_article) = test
    logger.info("length of training data:{} ;valid data:{} ;test data:{}".format(
        len(train_X), len(valid_X), len(test_X)))

    # 2. create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Instantiate the model.
        model = HierarchicalAttention(
            accusation_num_classes, article_num_classes, deathpenalty_num_classes,
            lifeimprisonment_num_classes, FLAGS.learning_rate, FLAGS.batch_size,
            FLAGS.decay_steps, FLAGS.decay_rate, FLAGS.sentence_len, FLAGS.num_sentences,
            vocab_size, FLAGS.embed_size, FLAGS.hidden_size,
            num_filters=FLAGS.num_filters, model=FLAGS.model, filter_sizes=filter_sizes,
            stride_length=stride_length, pooling_strategy=FLAGS.pooling_strategy)

        # Initialize the saver; restore from the accusation checkpoint if one exists.
        saver = tf.train.Saver()
        # os.path.join keeps the paths valid whether or not the flag ends with a separator.
        if os.path.exists(os.path.join(FLAGS.ckpt_dir_accu, "checkpoint")):
            logger.info("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir_accu))
            for i in range(2):  # decay learning rate if necessary.
                logger.info("{} Going to decay learning rate by half.".format(i))
                sess.run(model.learning_rate_decay_half_op)
        else:
            logger.info('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding:  # load pre-trained word embedding
                vocabulary_index2word = {index: word for word, index in vocab_word2index.items()}
                assign_pretrained_word_embedding(
                    sess, vocabulary_index2word, vocab_size, model,
                    FLAGS.word2vec_model_path, model.Embedding)
        curr_epoch = sess.run(model.epoch_step)

        # 3. feed data & training
        number_of_training_data = len(train_X)
        batch_size = FLAGS.batch_size
        iteration = 0
        accasation_score_best = -100
        law_score_best = -100
        imprisonment_score_best = -100
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss_total, counter = 0.0, 0
            # Visit every full batch (including the last one); a trailing partial
            # batch is skipped so each feed holds exactly batch_size rows.
            for start in range(0, number_of_training_data - batch_size + 1, batch_size):
                end = start + batch_size
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    logger.info("trainX[start:end]: {} train_X.shape: {}".format(
                        train_X[start:end], train_X.shape))
                feed_dict = {
                    model.input_x: train_X[start:end],
                    model.input_y_accusation: train_Y_accusation[start:end],
                    model.input_y_article: train_Y_article[start:end],
                    model.input_y_deathpenalty: train_Y_deathpenalty[start:end],
                    model.input_y_lifeimprisonment: train_Y_lifeimprisonment[start:end],
                    model.input_y_imprisonment: train_Y_imprisonment[start:end],
                    model.input_weight_accusation: train_weights_accusation[start:end],
                    model.input_weight_article: train_weights_article[start:end],
                    model.dropout_keep_prob: FLAGS.keep_dropout_rate,
                    model.is_training_flag: FLAGS.is_training_flag,
                }
                (current_loss, lr, loss_accusation, loss_article, loss_deathpenalty,
                 loss_lifeimprisonment, loss_imprisonment, l2_loss, _) = sess.run(
                    [model.loss_val, model.learning_rate, model.loss_accusation,
                     model.loss_article, model.loss_deathpenalty, model.loss_lifeimprisonment,
                     model.loss_imprisonment, model.l2_loss, model.train_op],
                    feed_dict)
                loss_total, counter = loss_total + current_loss, counter + 1
                if counter % 200 == 0:
                    print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f"
                          % (epoch, counter, float(loss_total) / float(counter), lr))
                if counter % 600 == 0:
                    logger.info("Loss_accusation:{} \tLoss_article:{} \tLoss_deathpenalty:{} \tLoss_lifeimprisonment:{} \tLoss_imprisonment:{} \tL2_loss:{} \tCurrent_loss:{} \t".format(
                        loss_accusation, loss_article, loss_deathpenalty, loss_lifeimprisonment,
                        loss_imprisonment, l2_loss, current_loss))

                # Mid-epoch validation: runs every 2000 batches (never on the first batch).
                if start != 0 and start % (2000 * batch_size) == 0:
                    (loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle,
                     f1_a_death, f1_i_death, f1_a_life, f1_i_life, score_penalty) = do_eval(
                        sess, model, valid, iteration, accusation_num_classes, article_num_classes)
                    # Each task score is the mean of macro and micro F1, scaled to 0-100.
                    accasation_score = ((f1_macro_accasation + f1_micro_accasation) / 2.0) * 100.0
                    article_score = ((f1_a_article + f1_i_aritcle) / 2.0) * 100.0
                    score_all = accasation_score + article_score + score_penalty
                    logger.info("Epoch {} ValidLoss:{} \n Macro_f1_accasation:{} \tMicro_f1_accsastion:{}\tMacro_f1_article:{} \t Micro_f1_article:{} \t Macro_f1_deathpenalty:{} \t"
                                "Micro_f1_deathpenalty:{} \tMacro_f1_lifeimprisonment:{} \tMicro_f1_lifeimprisonment:{}\t".format(
                                    epoch, loss, f1_macro_accasation, f1_micro_accasation,
                                    f1_a_article, f1_i_aritcle, f1_a_death, f1_i_death,
                                    f1_a_life, f1_i_life))
                    logger.info("Epoch:{} 1.Accasation Score:{} ;2.Article Score:{} ;3.Penalty Score:{} ;Score ALL:{}\n accasation_score_best{}".format(
                        epoch, accasation_score, article_score, score_penalty, score_all,
                        accasation_score_best))
                    # Save one checkpoint per task whenever that task's score improves.
                    if accasation_score > accasation_score_best:
                        save_path = os.path.join(FLAGS.ckpt_dir_accu, "model.ckpt")
                        logger.info("going to save check point for accusation.")
                        saver.save(sess, save_path, global_step=epoch)
                        accasation_score_best = accasation_score
                    if article_score > law_score_best:
                        save_path = os.path.join(FLAGS.ckpt_dir_law, "model.ckpt")
                        logger.info("going to save check point for article.")
                        saver.save(sess, save_path, global_step=epoch)
                        law_score_best = article_score
                    if score_penalty > imprisonment_score_best:
                        save_path = os.path.join(FLAGS.ckpt_dir_imprision, "model.ckpt")
                        logger.info("going to save check point for imprisonment.")
                        saver.save(sess, save_path, global_step=epoch)
                        imprisonment_score_best = score_penalty
                    logger.info("Epoch:{} Bestscore:1 Accasation:{} ;2. Article:{} ;3.penalty:{}".format(
                        epoch, accasation_score_best, law_score_best, imprisonment_score_best))

            # epoch increment
            logger.info("going to increment epoch counter....")
            sess.run(model.epoch_increment)

            # 4. validation at the end of the epoch
            print(epoch, FLAGS.validate_every, (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                (loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle,
                 f1_a_death, f1_i_death, f1_a_life, f1_i_life, score_penalty) = do_eval(
                    sess, model, valid, iteration, accusation_num_classes, article_num_classes)
                accasation_score = ((f1_macro_accasation + f1_micro_accasation) / 2.0) * 100.0
                article_score = ((f1_a_article + f1_i_aritcle) / 2.0) * 100.0
                score_all = accasation_score + article_score + score_penalty
                print("Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
                      "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                      % (epoch, loss, f1_macro_accasation, f1_micro_accasation, f1_a_article,
                         f1_i_aritcle, f1_a_death, f1_i_death, f1_a_life, f1_i_life))
                logger.info("===>1.Accasation Score: {} ;2.Article Score: {} ;3.Penalty Score:{} ;Score ALL:{}".format(
                    accasation_score, article_score, score_penalty, score_all))
                # save model to checkpoint
                if accasation_score > accasation_score_best:
                    save_path = os.path.join(FLAGS.ckpt_dir_accu, "model.ckpt")
                    print("going to save check point.")
                    saver.save(sess, save_path, global_step=epoch)
                    accasation_score_best = accasation_score
                if article_score > law_score_best:
                    save_path = os.path.join(FLAGS.ckpt_dir_law, "model.ckpt")
                    logger.info("going to save check point for article.")
                    saver.save(sess, save_path, global_step=epoch)
                    law_score_best = article_score
                if score_penalty > imprisonment_score_best:
                    save_path = os.path.join(FLAGS.ckpt_dir_imprision, "model.ckpt")
                    logger.info("going to save check point for imprisonment.")
                    saver.save(sess, save_path, global_step=epoch)
                    imprisonment_score_best = score_penalty

            # Fixed schedule: run the decay-by-half op twice after these epochs.
            if epoch in (1, 3, 6, 9, 12, 18):
                for i in range(2):
                    print(i, "Going to decay learning rate by half.")
                    sess.run(model.learning_rate_decay_half_op)

        # 5. Final evaluation on the test set is currently disabled in this variant.
def main(_):
    """Train the HierarchicalAttention model, validate periodically, then score the test set.

    The function loads the vocabulary and the train/valid/test splits, builds
    (or restores) the model inside a TF session, and runs the training loop.
    The validation set is scored every 2000 batches and at the end of every
    ``FLAGS.validate_every``-th epoch; a checkpoint is written to
    ``FLAGS.ckpt_dir`` whenever the accusation score beats the best value seen
    so far. After training, the test split is evaluated and printed.

    NOTE(review): the block nesting was reconstructed from statement order;
    confirm that the learning-rate decay runs once per epoch and that the
    final test evaluation sits after the epoch loop.

    Args:
        _: Unused positional argument.
    """
    print("model:", FLAGS.model)
    name_scope = FLAGS.model

    # 1. load vocabulary, label maps and data
    vocab_word2index, accusation_label2index, articles_label2index = create_or_load_vocabulary(
        FLAGS.data_path, FLAGS.predict_path, FLAGS.traning_data_file, FLAGS.vocab_size,
        name_scope=name_scope, test_mode=FLAGS.test_mode)
    deathpenalty_label2index = {True: 1, False: 0}
    lifeimprisonment_label2index = {True: 1, False: 0}
    vocab_size = len(vocab_word2index)
    print("cnn_model.vocab_size:", vocab_size)
    accusation_num_classes = len(accusation_label2index)
    article_num_classes = len(articles_label2index)
    deathpenalty_num_classes = len(deathpenalty_label2index)
    lifeimprisonment_num_classes = len(lifeimprisonment_label2index)
    print("accusation_num_classes:", accusation_num_classes)
    print("article_num_clasess:", article_num_classes)

    train, valid, test = load_data_multilabel(
        FLAGS.traning_data_file, FLAGS.valid_data_file, FLAGS.test_data_path,
        vocab_word2index, accusation_label2index, articles_label2index,
        deathpenalty_label2index, lifeimprisonment_label2index,
        FLAGS.sentence_len, name_scope=name_scope, test_mode=FLAGS.test_mode)
    (train_X, train_Y_accusation, train_Y_article, train_Y_deathpenalty,
     train_Y_lifeimprisonment, train_Y_imprisonment,
     train_weights_accusation, train_weights_article) = train
    (valid_X, valid_Y_accusation, valid_Y_article, valid_Y_deathpenalty,
     valid_Y_lifeimprisonment, valid_Y_imprisonment,
     valid_weights_accusation, valid_weights_article) = valid
    (test_X, test_Y_accusation, test_Y_article, test_Y_deathpenalty,
     test_Y_lifeimprisonment, test_Y_imprisonment,
     test_weights_accusation, test_weights_article) = test

    # print some messages for debug purpose
    print("length of training data:", len(train_X), ";valid data:", len(valid_X),
          ";test data:", len(test_X))
    print("trainX_[0]:", train_X[0])
    train_Y_accusation_short1 = get_target_label_short(train_Y_accusation[0])
    train_Y_accusation_short2 = get_target_label_short(train_Y_accusation[1])
    train_Y_accusation_short3 = get_target_label_short(train_Y_accusation[2])
    train_Y_accusation_short4 = get_target_label_short(train_Y_accusation[20])
    train_Y_accusation_short5 = get_target_label_short(train_Y_accusation[200])
    train_Y_article_short = get_target_label_short(train_Y_article[0])
    # Show all five sampled labels (the fifth one used to be replaced by a repeat of the fourth).
    print("train_Y_accusation_short:", train_Y_accusation_short1, train_Y_accusation_short2,
          train_Y_accusation_short3, train_Y_accusation_short4, train_Y_accusation_short5,
          ";train_Y_article_short:", train_Y_article_short)
    print("train_Y_deathpenalty:", train_Y_deathpenalty[0],
          ";train_Y_lifeimprisonment:", train_Y_lifeimprisonment[0],
          ";train_Y_imprisonment:", train_Y_imprisonment[0])

    # 2. create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Instantiate the model.
        model = HierarchicalAttention(
            accusation_num_classes, article_num_classes, deathpenalty_num_classes,
            lifeimprisonment_num_classes, FLAGS.learning_rate, FLAGS.batch_size,
            FLAGS.decay_steps, FLAGS.decay_rate, FLAGS.sentence_len, FLAGS.num_sentences,
            vocab_size, FLAGS.embed_size, FLAGS.hidden_size,
            num_filters=FLAGS.num_filters, model=FLAGS.model, filter_sizes=filter_sizes,
            stride_length=stride_length, pooling_strategy=FLAGS.pooling_strategy)

        # Initialize the saver; restore from the checkpoint directory if one exists.
        saver = tf.train.Saver()
        # os.path.join keeps the paths valid whether or not the flag ends with a separator.
        if os.path.exists(os.path.join(FLAGS.ckpt_dir, "checkpoint")):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            for i in range(2):  # decay learning rate if necessary.
                print(i, "Going to decay learning rate by half.")
                sess.run(model.learning_rate_decay_half_op)
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding:  # load pre-trained word embedding
                vocabulary_index2word = {index: word for word, index in vocab_word2index.items()}
                assign_pretrained_word_embedding(
                    sess, vocabulary_index2word, vocab_size, model,
                    FLAGS.word2vec_model_path, model.Embedding)
        curr_epoch = sess.run(model.epoch_step)

        # 3. feed data & training
        number_of_training_data = len(train_X)
        batch_size = FLAGS.batch_size
        iteration = 0
        accasation_score_best = -100
        # Pre-bind epoch so the final test report still works when no epoch is left to run.
        epoch = curr_epoch
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss_total, counter = 0.0, 0
            # Visit every full batch (including the last one); a trailing partial
            # batch is skipped so each feed holds exactly batch_size rows.
            for start in range(0, number_of_training_data - batch_size + 1, batch_size):
                end = start + batch_size
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", train_X[start:end], "train_X.shape:", train_X.shape)
                feed_dict = {
                    model.input_x: train_X[start:end],
                    model.input_y_accusation: train_Y_accusation[start:end],
                    model.input_y_article: train_Y_article[start:end],
                    model.input_y_deathpenalty: train_Y_deathpenalty[start:end],
                    model.input_y_lifeimprisonment: train_Y_lifeimprisonment[start:end],
                    model.input_y_imprisonment: train_Y_imprisonment[start:end],
                    model.input_weight_accusation: train_weights_accusation[start:end],
                    model.input_weight_article: train_weights_article[start:end],
                    model.dropout_keep_prob: FLAGS.keep_dropout_rate,
                    model.is_training_flag: FLAGS.is_training_flag,
                }
                (current_loss, lr, loss_accusation, loss_article, loss_deathpenalty,
                 loss_lifeimprisonment, loss_imprisonment, l2_loss, _) = sess.run(
                    [model.loss_val, model.learning_rate, model.loss_accusation,
                     model.loss_article, model.loss_deathpenalty, model.loss_lifeimprisonment,
                     model.loss_imprisonment, model.l2_loss, model.train_op],
                    feed_dict)
                loss_total, counter = loss_total + current_loss, counter + 1
                if counter % 20 == 0:
                    print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f"
                          % (epoch, counter, float(loss_total) / float(counter), lr))
                if counter % 60 == 0:
                    print("Loss_accusation:%.3f\tLoss_article:%.3f\tLoss_deathpenalty:%.3f\tLoss_lifeimprisonment:%.3f\tLoss_imprisonment:%.3f\tL2_loss:%.3f\tCurrent_loss:%.3f\t"
                          % (loss_accusation, loss_article, loss_deathpenalty, loss_lifeimprisonment,
                             loss_imprisonment, l2_loss, current_loss))

                # Mid-epoch validation: runs every 2000 batches (never on the first batch).
                if start != 0 and start % (2000 * batch_size) == 0:
                    (loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle,
                     f1_a_death, f1_i_death, f1_a_life, f1_i_life, score_penalty) = do_eval(
                        sess, model, valid, iteration, accusation_num_classes,
                        article_num_classes, accusation_label2index)
                    # Each task score is the mean of macro and micro F1, scaled to 0-100.
                    accasation_score = ((f1_macro_accasation + f1_micro_accasation) / 2.0) * 100.0
                    article_score = ((f1_a_article + f1_i_aritcle) / 2.0) * 100.0
                    score_all = accasation_score + article_score + score_penalty
                    print("Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f Micro_f1_article:%.3f Macro_f1_deathpenalty:%.3f\t"
                          "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                          % (epoch, loss, f1_macro_accasation, f1_micro_accasation, f1_a_article,
                             f1_i_aritcle, f1_a_death, f1_i_death, f1_a_life, f1_i_life))
                    print("1.Accasation Score:", accasation_score, ";2.Article Score:", article_score,
                          ";3.Penalty Score:", score_penalty, ";Score ALL:", score_all)
                    # save model to checkpoint
                    if accasation_score > accasation_score_best:
                        save_path = os.path.join(FLAGS.ckpt_dir, "model.ckpt")
                        print("going to save check point.")
                        saver.save(sess, save_path, global_step=epoch)
                        accasation_score_best = accasation_score

            # epoch increment
            print("going to increment epoch counter....")
            sess.run(model.epoch_increment)

            # 4. validation at the end of the epoch
            print(epoch, FLAGS.validate_every, (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                (loss, f1_macro_accasation, f1_micro_accasation, f1_a_article, f1_i_aritcle,
                 f1_a_death, f1_i_death, f1_a_life, f1_i_life, score_penalty) = do_eval(
                    sess, model, valid, iteration, accusation_num_classes,
                    article_num_classes, accusation_label2index)
                accasation_score = ((f1_macro_accasation + f1_micro_accasation) / 2.0) * 100.0
                article_score = ((f1_a_article + f1_i_aritcle) / 2.0) * 100.0
                score_all = accasation_score + article_score + score_penalty
                print()
                print("Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
                      "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                      % (epoch, loss, f1_macro_accasation, f1_micro_accasation, f1_a_article,
                         f1_i_aritcle, f1_a_death, f1_i_death, f1_a_life, f1_i_life))
                print("===>1.Accasation Score:", accasation_score, ";2.Article Score:", article_score,
                      ";3.Penalty Score:", score_penalty, ";Score ALL:", score_all)
                # save model to checkpoint
                if accasation_score > accasation_score_best:
                    save_path = os.path.join(FLAGS.ckpt_dir, "model.ckpt")
                    print("going to save check point.")
                    saver.save(sess, save_path, global_step=epoch)
                    accasation_score_best = accasation_score

            # Fixed schedule: run the decay-by-half op twice after these epochs.
            if epoch in (0, 2, 4, 6, 9, 13):
                for i in range(2):
                    print(i, "Going to decay learning rate by half.")
                    sess.run(model.learning_rate_decay_half_op)

        # 5. Finally, evaluate on the test set and report the test metrics.
        (loss_test, f1_macro_accasation_test, f1_micro_accasation_test, f1_a_article_test,
         f1_i_aritcle_test, f1_a_death_test, f1_i_death_test, f1_a_life_test, f1_i_life_test,
         score_penalty_test) = do_eval(
            sess, model, test, iteration, accusation_num_classes, article_num_classes,
            accusation_label2index)
        print("TEST.FINAL.Epoch %d ValidLoss:%.3f\tMacro_f1_accasation:%.3f\tMicro_f1_accsastion:%.3f\tMacro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
              "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
              % (epoch, loss_test, f1_macro_accasation_test, f1_micro_accasation_test,
                 f1_a_article_test, f1_i_aritcle_test, f1_a_death_test, f1_i_death_test,
                 f1_a_life_test, f1_i_life_test))
        accasation_score_test = ((f1_macro_accasation_test + f1_micro_accasation_test) / 2.0) * 100.0
        article_score_test = ((f1_a_article_test + f1_i_aritcle_test) / 2.0) * 100.0
        score_all_test = accasation_score_test + article_score_test + score_penalty_test
        print("TEST.Accasation Score:", accasation_score_test, ";2.Article Score:", article_score_test,
              ";3.Penalty Score:", score_penalty_test, ";Score ALL:", score_all_test)
        print("training completed...")