def do_eval(sess,model,valid,iteration,accusation_num_classes):
    """Evaluate ``model`` on the validation set for the accusation task.

    Feeds the validation data through the model in full batches of
    ``FLAGS.batch_size`` and collects the accusation logits. It then builds a
    confusion matrix from those logits, prints micro/macro F1, precision and
    recall, and prints MAP/MRR via ``compute_MAP_MRR``.

    Args:
        sess: session whose ``run`` evaluates ``model.loss_val`` and
            ``model.logits_accusation``.
        model: model exposing ``input_x``, ``input_y_accusation``,
            ``dropout_keep_prob`` and ``is_training_flag`` (used as feed keys).
        valid: validation data, unpacked by ``get_part_validation_data`` into
            ``(X, Y_accusation)``.
        iteration: unused; kept for interface compatibility.
        accusation_num_classes: number of accusation classes, passed to
            ``init_label_dict``.

    Returns:
        ``(mean_loss, f1_macro, f1_micro)``. ``mean_loss`` is the average of
        the per-batch losses, or 0.0 when no full batch could be formed.
    """
    valid_X, valid_Y_accusation = get_part_validation_data(valid)
    number_examples = len(valid_X)
    print("number_examples:", number_examples)
    batch_size = FLAGS.batch_size
    label_dict_accusation = init_label_dict(accusation_num_classes)

    score = []
    loss_total = 0
    counter = 0

    # Only full batches are fed; a trailing partial batch is skipped.
    # NOTE(review): presumably the graph expects exactly batch_size rows per
    # feed -- confirm before feeding partial batches. The ``+ 1`` keeps the
    # last batch when number_examples is an exact multiple of batch_size
    # (previously that full batch was silently dropped).
    starts = range(0, number_examples, batch_size)
    ends = range(batch_size, number_examples + 1, batch_size)
    for start, end in zip(starts, ends):
        feed_dict = {model.input_x: valid_X[start:end],
                     model.input_y_accusation: valid_Y_accusation[start:end],
                     model.dropout_keep_prob: 1.0, model.is_training_flag: False}

        loss, similarity_scores = sess.run([model.loss_val, model.logits_accusation], feed_dict)
        if start == 0:
            # Debug output: logits of the first batch only.
            print(similarity_scores)
        score.extend(similarity_scores)
        counter = counter + 1
        loss_total = loss_total + loss

    # ``score`` covers only the evaluated prefix of the data; truncate the
    # labels to the same length so every score is paired with its own label.
    evaluated_Y_accusation = valid_Y_accusation[:len(score)]

    label_dict_accusation = compute_confuse_matrix_batch(evaluated_Y_accusation, score, label_dict_accusation, name='accusation')

    f1_micro, pre_micro, recal_micro, f1_macro, pre_macro, recal_macro = compute_micro_macro(label_dict_accusation)
    print("f1_micro_accusation:", f1_micro, "  pre_micro: ", pre_micro, "  recal_micro: ", recal_micro)
    print("f1_macro_accusation:", f1_macro, "  pre_macro: ", pre_macro, "  recal_macro: ", recal_macro)

    # Avoid ZeroDivisionError when fewer than batch_size examples exist.
    loss2 = loss_total / float(counter) if counter else 0.0
    print('******* MAP and MRR *************')
    compute_MAP_MRR(score, evaluated_Y_accusation)
    return loss2, f1_macro, f1_micro
# Code example #2
# 0
def do_eval(sess,model,valid,iteration,accusation_num_classes,article_num_classes):
    """Evaluate the multi-task model on the validation set.

    Feeds the validation data through the model in full batches of
    ``FLAGS.batch_size``. For each batch it accumulates confusion matrices
    for accusation, article, death penalty and life imprisonment, plus a
    penalty score from ``compute_penalty_score_batch``. It then computes
    micro/macro F1 for each classification task.

    Args:
        sess: session whose ``run`` evaluates the loss and the per-task logits.
        model: model exposing the input/label placeholders, per-example
            weight placeholders, ``dropout_keep_prob`` and ``is_training_flag``.
        valid: validation data, unpacked by ``get_part_validation_data``
            (the last two unpacked items are ignored).
        iteration: unused; kept for interface compatibility.
        accusation_num_classes: number of accusation classes.
        article_num_classes: number of relevant-article classes.

    Returns:
        A 10-tuple: ``(mean_loss, f1_macro_accusation, f1_micro_accusation,
        f1_macro_article, f1_micro_article, f1_macro_deathpenalty,
        f1_micro_deathpenalty, f1_macro_lifeimprisonment,
        f1_micro_lifeimprisonment, mean_penalty_score)``. The two means divide
        by ``eval_counter + small_value`` to avoid division by zero.
    """
    valid_X, valid_Y_accusation, valid_Y_article, valid_Y_deathpenalty, valid_Y_lifeimprisonment, valid_Y_imprisonment,_,_=get_part_validation_data(valid)
    number_examples=len(valid_X)
    print("number_examples:",number_examples)
    eval_loss,eval_counter=0.0,0
    batch_size=FLAGS.batch_size
    label_dict_accusation=init_label_dict(accusation_num_classes)
    label_dict_article=init_label_dict(article_num_classes)
    label_dict_deathpenalty = init_label_dict(2)
    label_dict_lifeimprisonment = init_label_dict(2)

    eval_penalty_score=0.0
    # Every example gets weight 1.0 at evaluation time. The list is identical
    # for every batch, so build it once outside the loop.
    unit_weights = [1.0] * batch_size

    # Only full batches are fed (the weight list above is exactly batch_size
    # long); a trailing partial batch is skipped. The ``+ 1`` keeps the last
    # batch when number_examples is an exact multiple of batch_size
    # (previously that full batch was silently dropped).
    starts = range(0, number_examples, batch_size)
    ends = range(batch_size, number_examples + 1, batch_size)
    for start, end in zip(starts, ends):
        feed_dict = {model.input_x: valid_X[start:end],
                     model.input_y_accusation:valid_Y_accusation[start:end],model.input_y_article:valid_Y_article[start:end],
                     model.input_y_deathpenalty:valid_Y_deathpenalty[start:end],model.input_y_lifeimprisonment:valid_Y_lifeimprisonment[start:end],
                     model.input_y_imprisonment:valid_Y_imprisonment[start:end],
                     model.input_weight_accusation:unit_weights,model.input_weight_article:unit_weights,
                     model.dropout_keep_prob: 1.0,model.is_training_flag:False}
        curr_eval_loss, logits_accusation,logits_article,logits_deathpenalty,logits_lifeimprisonment,logits_imprisonment= sess.run(
                        [model.loss_val,model.logits_accusation,model.logits_article,model.logits_deathpenalty,model.logits_lifeimprisonment,model.logits_imprisonment],feed_dict)
        # Accumulate confusion matrices for accusation, relevant article,
        # death penalty and life imprisonment, using this batch's labels only.
        label_dict_accusation=compute_confuse_matrix_batch(valid_Y_accusation[start:end],logits_accusation,label_dict_accusation,name='accusation')
        label_dict_article = compute_confuse_matrix_batch(valid_Y_article[start:end],logits_article,label_dict_article,name='article')
        label_dict_deathpenalty = compute_confuse_matrix_batch(valid_Y_deathpenalty[start:end],logits_deathpenalty,label_dict_deathpenalty,name='deathpenalty')
        label_dict_lifeimprisonment = compute_confuse_matrix_batch(valid_Y_lifeimprisonment[start:end],logits_lifeimprisonment,label_dict_lifeimprisonment,name='lifeimprisionment')
        # Imprisonment labels must be sliced like the others: previously the
        # full label array was passed against this batch's logits.
        penalty_score=compute_penalty_score_batch(valid_Y_deathpenalty[start:end], logits_deathpenalty,
                                                    valid_Y_lifeimprisonment[start:end], logits_lifeimprisonment,
                                                    valid_Y_imprisonment[start:end], logits_imprisonment)
        eval_penalty_score=eval_penalty_score+penalty_score
        eval_loss=eval_loss+curr_eval_loss
        eval_counter=eval_counter+1

    # Compute f1_micro & f1_macro for accusation, article, death penalty and life imprisonment.
    f1_micro_accusation,f1_macro_accusation=compute_micro_macro(label_dict_accusation)
    f1_micro_article, f1_macro_article = compute_micro_macro(label_dict_article)
    f1_micro_deathpenalty, f1_macro_deathpenalty = compute_micro_macro(label_dict_deathpenalty)
    f1_micro_lifeimprisonment, f1_macro_lifeimprisonment = compute_micro_macro(label_dict_lifeimprisonment)
    print("f1_micro_accusation:",f1_micro_accusation,";f1_macro_accusation:",f1_macro_accusation)
    return eval_loss/float(eval_counter+small_value),f1_macro_accusation,f1_micro_accusation, f1_macro_article, f1_micro_article, \
           f1_macro_deathpenalty, f1_micro_deathpenalty,f1_macro_lifeimprisonment, f1_micro_lifeimprisonment,eval_penalty_score/float(eval_counter+small_value)