Example #1
def main(_):
    """ """
    vocab_word2index, _ = create_or_load_vocabulary(
        FLAGS.data_path,
        FLAGS.mask_lm_source_file,
        FLAGS.vocab_size,
        test_mode=FLAGS.test_mode,
        tokenize_style=FLAGS.tokenize_style)
    vocab_size = len(vocab_word2index)
    logging.info("bert pretrain vocab size: %d" % vocab_size)
    index2word = {v: k for k, v in vocab_word2index.items()}
    train, valid, test = mask_language_model(
        FLAGS.mask_lm_source_file,
        FLAGS.data_path,
        index2word,
        max_allow_sentence_length=FLAGS.max_allow_sentence_length,
        test_mode=FLAGS.test_mode,
        process_num=FLAGS.process_num)

    train_X, train_y, train_p = train
    valid_X, valid_y, valid_p = valid
    test_X, test_y, test_p = test
    print("train_X:{}, train_y:{}, train_p:{}".format(train_X.shape, \
            train_y.shape, train_p.shape))

    #1.create session
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
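
A note on the helpers: create_or_load_vocabulary and mask_language_model are not reproduced in this listing. As a minimal sketch of what a vocabulary helper of this shape typically does (the special tokens, pickle cache, and whitespace tokenizer below are assumptions, not the repo's actual implementation):

import os
import pickle
from collections import Counter

def build_or_load_vocab(cache_path, source_file, vocab_size):
    """Build a word->index dict from a corpus, or load it from a pickle cache."""
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    counter = Counter()
    with open(source_file, encoding='utf-8') as f:
        for line in f:
            counter.update(line.split())
    # Reserve low indices for special tokens (names assumed here).
    word2index = {'[PAD]': 0, '[UNK]': 1, '[MASK]': 2, '[CLS]': 3, '[SEP]': 4}
    for word, _ in counter.most_common(vocab_size - len(word2index)):
        word2index[word] = len(word2index)
    with open(cache_path, 'wb') as f:
        pickle.dump(word2index, f)
    return word2index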
Example #2
def main(_):
    vocab_word2index, _ = create_or_load_vocabulary(
        FLAGS.data_path,
        FLAGS.mask_lm_source_file,
        FLAGS.vocab_size,
        test_mode=FLAGS.test_mode,
        tokenize_style=FLAGS.tokenize_style)
    vocab_size = len(vocab_word2index)
    print("bert_pertrain_lm.vocab_size:", vocab_size)
    index2word = {v: k for k, v in vocab_word2index.items()}
    train, valid, test = mask_language_model(
        FLAGS.mask_lm_source_file,
        FLAGS.data_path,
        index2word,
        max_allow_sentence_length=FLAGS.max_allow_sentence_length,
        test_mode=FLAGS.test_mode,
        process_num=FLAGS.process_num)

    train_X, train_y, train_p = train
    valid_X, valid_y, valid_p = valid
    test_X, test_y, test_p = test

    print("length of training data:", train_X.shape, ";train_Y:",
          train_y.shape, ";train_p:", train_p.shape, ";valid data:",
          valid_X.shape, ";test data:", test_X.shape)
    # 1.create session.
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    with tf.Session(config=gpu_config) as sess:
        #Instantiate Model
        config = set_config(FLAGS, vocab_size, vocab_size)
        model = BertModel(config)
        #Initialize Save
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            for i in range(2):  #decay learning rate if necessary.
                print(i, "Going to decay learning rate by half.")
                sess.run(model.learning_rate_decay_half_op)
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding:
                vocabulary_index2word = {
                    index: word
                    for word, index in vocab_word2index.items()
                }
                assign_pretrained_word_embedding(
                    sess, vocabulary_index2word, vocab_size,
                    FLAGS.word2vec_model_path, model.embedding,
                    config.d_model)  # assign pretrained word embeddings
        curr_epoch = sess.run(model.epoch_step)

        # 2.feed data & training
        number_of_training_data = len(train_X)
        print("number_of_training_data:", number_of_training_data)
        batch_size = FLAGS.batch_size
        iteration = 0
        score_best = -100
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss_total_lm, counter = 0.0, 0
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", train_X[start:end],
                          "train_X.shape:", train_X.shape)
                feed_dict = {
                    model.x_mask_lm: train_X[start:end],
                    model.y_mask_lm: train_y[start:end],
                    model.p_mask_lm: train_p[start:end],
                    model.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                current_loss_lm, lr, l2_loss, _ = sess.run([
                    model.loss_val_lm, model.learning_rate, model.l2_loss_lm,
                    model.train_op_lm
                ], feed_dict)
                loss_total_lm, counter = loss_total_lm + current_loss_lm, counter + 1
                if counter % 30 == 0:
                    print(
                        "%d\t%d\tLearning rate:%.5f\tLoss_lm:%.3f\tCurrent_loss_lm:%.3f\tL2_loss:%.3f\t"
                        % (epoch, counter, lr, float(loss_total_lm) /
                           float(counter), current_loss_lm, l2_loss))
                if start != 0 and start % (800 * FLAGS.batch_size) == 0:  # evaluate periodically within the epoch.
                    loss_valid, acc_valid = do_eval(sess, model, valid,
                                                    batch_size)
                    print(
                        "%d\tValid.Epoch %d ValidLoss:%.3f\tAcc_valid:%.3f\t" %
                        (counter, epoch, loss_valid, acc_valid * 100))
                    # save model to checkpoint
                    if acc_valid > score_best:
                        save_path = FLAGS.ckpt_dir + "model.ckpt"
                        print("going to save check point.")
                        saver.save(sess, save_path, global_step=epoch)
                        score_best = acc_valid
            sess.run(model.epoch_increment)
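
One detail worth noting in the loop above: zip(range(0, n, batch_size), range(batch_size, n, batch_size)) not only drops a trailing partial batch, it also skips the last full batch whenever n is an exact multiple of batch_size, because range(batch_size, n, batch_size) excludes n. A tail-inclusive sketch (illustrative, not from the repo):

def iter_batches(num_examples, batch_size):
    """Yield (start, end) pairs covering every example, tail batch included."""
    for start in range(0, num_examples, batch_size):
        yield start, min(start + batch_size, num_examples)

print(list(iter_batches(10, 4)))  # [(0, 4), (4, 8), (8, 10)]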
Example #3
def main(_):
    # 1.load the token vocabulary cached from the pre-training stage and the label dict from the training file; print some basic stats.
    vocab_word2index, _ = create_or_load_vocabulary(
        FLAGS.data_path, FLAGS.training_data_file, FLAGS.vocab_size,
        test_mode=FLAGS.test_mode, tokenize_style=FLAGS.tokenize_style,
        model_name=FLAGS.model_name)
    label2index = get_lable2index(FLAGS.data_path, FLAGS.training_data_file,
                                  tokenize_style=FLAGS.tokenize_style)
    vocab_size = len(vocab_word2index)
    print("cnn_model.vocab_size:", vocab_size)
    num_classes = len(label2index)
    print("num_classes:", num_classes)
    # load training data.
    train, valid, test = load_data_multilabel(
        FLAGS.data_path, FLAGS.training_data_file, FLAGS.valid_data_file,
        FLAGS.test_data_file, vocab_word2index, label2index,
        FLAGS.sequence_length, process_num=FLAGS.process_num,
        test_mode=FLAGS.test_mode, tokenize_style=FLAGS.tokenize_style)
    train_X, train_Y = train
    valid_X, valid_Y = valid
    test_X, test_Y = test
    print("test_mode:", FLAGS.test_mode, ";length of training data:",
          train_X.shape, ";valid data:", valid_X.shape, ";test data:",
          test_X.shape, ";train_Y:", train_Y.shape)
    # 2.create session.
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    with tf.Session(config=gpu_config) as sess:
        #Instantiate Model
        config = set_config(FLAGS, num_classes, vocab_size)
        model = BertModel(config)
        #Initialize Save
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint.")
            sess.run(tf.global_variables_initializer())
            for i in range(6):  # decay learning rate if necessary.
                print(i, "Going to decay learning rate by half.")
                sess.run(model.learning_rate_decay_half_op)
            # Restore only those variables whose names and shapes exist in the model
            # from the checkpoint. For detail check:
            # https://gist.github.com/iganichev/d2d8a0b1abc6b15d4a07de83171163d4
            optimistic_restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))  # or simply: saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print('Initializing variables since no saved model exists.')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding:
                vocabulary_index2word = {
                    index: word
                    for word, index in vocab_word2index.items()
                }
                assign_pretrained_word_embedding(
                    sess, vocabulary_index2word, vocab_size,
                    FLAGS.word2vec_model_path, model.embedding,
                    config.d_model)  # assign pretrained word embeddings
        curr_epoch = sess.run(model.epoch_step)
        # 3.feed data & training
        number_of_training_data = len(train_X)
        batch_size = FLAGS.batch_size
        iteration = 0
        score_best = -100
        f1_score = 0
        epoch = 0
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss_total, counter = 0.0, 0
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", train_X[start:end],
                          "train_X.shape:", train_X.shape)
                feed_dict = {
                    model.input_x: train_X[start:end],
                    model.input_y: train_Y[start:end],
                    model.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                current_loss, lr, l2_loss, _ = sess.run([
                    model.loss_val, model.learning_rate, model.l2_loss,
                    model.train_op
                ], feed_dict)
                loss_total, counter = loss_total + current_loss, counter + 1
                if counter % 30 == 0:
                    print(
                        "Learning rate:%.7f\tLoss:%.3f\tCurrent_loss:%.3f\tL2_loss:%.3f\t"
                        % (lr, float(loss_total) / float(counter),
                           current_loss, l2_loss))
                if start != 0 and start % (4000 * FLAGS.batch_size) == 0:
                    loss_valid, f1_macro_valid, f1_micro_valid = do_eval(
                        sess, model, valid, num_classes, label2index)
                    f1_score_valid = ((f1_macro_valid + f1_micro_valid) / 2.0)  # * 100.0
                    print(
                        "Valid.Epoch %d ValidLoss:%.3f\tF1_score_valid:%.3f\tMacro_f1:%.3f\tMicro_f1:%.3f\t"
                        % (epoch, loss_valid, f1_score_valid, f1_macro_valid,
                           f1_micro_valid))

                    # save model to checkpoint
                    if f1_score_valid > score_best:
                        save_path = FLAGS.ckpt_dir_save + "model.ckpt"
                        print("going to save checkpoint.")
                        saver.save(sess, save_path, global_step=epoch)
                        score_best = f1_score_valid
            # epoch increment
            print("going to increment epoch counter....")
            sess.run(model.epoch_increment)

            # 4.validation
            print(epoch, FLAGS.validate_every,
                  (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                loss_valid, f1_macro_valid2, f1_micro_valid2 = do_eval(
                    sess, model, valid, num_classes, label2index)
                f1_score_valid2 = ((f1_macro_valid2 + f1_micro_valid2) / 2.0)  # * 100.0
                print(
                    "Valid.Epoch %d ValidLoss:%.3f\tF1 score:%.3f\tMacro_f1:%.3f\tMicro_f1:%.3f\t"
                    % (epoch, loss_valid, f1_score_valid2, f1_macro_valid2,
                       f1_micro_valid2))
                # save model to checkpoint
                if f1_score_valid2 > score_best:
                    save_path = FLAGS.ckpt_dir_save + "model.ckpt"
                    print("going to save checkpoint.")
                    saver.save(sess, save_path, global_step=epoch)
                    score_best = f1_score_valid2
            if epoch in (2, 4, 6, 9, 13):
                for i in range(1):
                    print(i, "Going to decay learning rate by half.")
                    sess.run(model.learning_rate_decay_half_op)

        # 5.report on test set
        loss_test, f1_macro_test, f1_micro_test = do_eval(
            sess, model, test, num_classes, label2index)
        f1_score_test = ((f1_macro_test + f1_micro_test) / 2.0) * 100.0
        print(
            "Test.Epoch %d TestLoss:%.3f\tF1_score:%.3f\tMacro_f1:%.3f\tMicro_f1:%.3f\t"
            % (epoch, loss_test, f1_score_test, f1_macro_test, f1_micro_test))
        print("training completed...")
Example #4
def main(_):
    vocab_word2index, label2index = create_or_load_vocabulary(
        FLAGS.data_path,
        FLAGS.training_data_file,
        FLAGS.vocab_size,
        test_mode=FLAGS.test_mode,
        tokenize_style=FLAGS.tokenize_style,
        model_name='transfomer')
    vocab_size = len(vocab_word2index)
    print("cnn_model.vocab_size:", vocab_size)
    num_classes = len(label2index)
    print("num_classes:", num_classes)
    train, valid, test = load_data_multilabel(
        FLAGS.data_path,
        FLAGS.training_data_file,
        FLAGS.valid_data_file,
        FLAGS.test_data_file,
        vocab_word2index,
        label2index,
        FLAGS.sequence_length,
        process_num=FLAGS.process_num,
        test_mode=FLAGS.test_mode,
        tokenize_style=FLAGS.tokenize_style,
        model_name='transfomer')
    train_X, train_Y = train
    valid_X, valid_Y = valid
    test_X, test_Y = test
    print("Test_mode:", FLAGS.test_mode, ";length of training data:",
          train_X.shape, ";valid data:", valid_X.shape, ";test data:",
          test_X.shape, ";train_Y:", train_Y.shape)
    # 1.create session.
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    with tf.Session(config=gpu_config) as sess:
        #Instantiate Model
        config = set_config(FLAGS, num_classes, vocab_size)
        model = TransformerModel(config)
        #Initialize Save
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            #for i in range(2): #decay learning rate if necessary.
            #    print(i,"Going to decay learning rate by half.")
            #    sess.run(model.learning_rate_decay_half_op)
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding:
                vocabulary_index2word = {
                    index: word
                    for word, index in vocab_word2index.items()
                }
                assign_pretrained_word_embedding(
                    sess, vocabulary_index2word, vocab_size,
                    FLAGS.word2vec_model_path, model.embedding,
                    config.d_model)  # assign pretrained word embeddings
        curr_epoch = sess.run(model.epoch_step)
        # 2.feed data & training
        number_of_training_data = len(train_X)
        batch_size = FLAGS.batch_size
        iteration = 0
        score_best = -100
        f1_score = 0
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss_total, counter = 0.0, 0
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", train_X[start:end],
                          "train_X.shape:", train_X.shape)
                feed_dict = {
                    model.input_x: train_X[start:end],
                    model.input_y: train_Y[start:end],
                    model.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                current_loss, lr, l2_loss, _ = sess.run([
                    model.loss_val, model.learning_rate, model.l2_loss,
                    model.train_op
                ], feed_dict)
                loss_total, counter = loss_total + current_loss, counter + 1
                if counter % 30 == 0:
                    print(
                        "Learning rate:%.5f\tLoss:%.3f\tCurrent_loss:%.3f\tL2_loss:%.3f\t"
                        % (lr, float(loss_total) / float(counter),
                           current_loss, l2_loss))
                if start != 0 and start % (3000 * FLAGS.batch_size) == 0:
                    loss_valid, f1_macro_valid, f1_micro_valid = do_eval(
                        sess, model, valid, num_classes, label2index)
                    f1_score_valid = (
                        (f1_macro_valid + f1_micro_valid) / 2.0) * 100.0
                    print(
                        "Valid.Epoch %d ValidLoss:%.3f\tF1_score_valid:%.3f\tMacro_f1:%.3f\tMicro_f1:%.3f\t"
                        % (epoch, loss_valid, f1_score_valid, f1_macro_valid,
                           f1_micro_valid))

                    # save model to checkpoint
                    if f1_score_valid > score_best:
                        save_path = FLAGS.ckpt_dir + "model.ckpt"
                        print("going to save check point.")
                        saver.save(sess, save_path, global_step=epoch)
                        score_best = f1_score_valid
            #epoch increment
            print("going to increment epoch counter....")
            sess.run(model.epoch_increment)

            # 4.validation
            print(epoch, FLAGS.validate_every,
                  (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                loss_valid, f1_macro_valid2, f1_micro_valid2 = do_eval(
                    sess, model, valid, num_classes, label2index)
                f1_score_valid2 = (
                    (f1_macro_valid2 + f1_micro_valid2) / 2.0)  #* 100.0
                print(
                    "Valid.Epoch %d ValidLoss:%.3f\tF1 score:%.3f\tMacro_f1:%.3f\tMicro_f1:%.3f\t"
                    % (epoch, loss_valid, f1_score_valid2, f1_macro_valid2,
                       f1_micro_valid2))
                #save model to checkpoint
                if f1_score_valid2 > score_best:
                    save_path = FLAGS.ckpt_dir + "model.ckpt"
                    print("going to save check point.")
                    saver.save(sess, save_path, global_step=epoch)
                    score_best = f1_score_valid2
            if epoch in (2, 4, 6, 9, 13):
                for i in range(1):
                    print(i, "Going to decay learning rate by half.")
                    sess.run(model.learning_rate_decay_half_op)

        # 5. Finally, evaluate on the test set and report test accuracy.
        loss_test, f1_macro_test, f1_micro_test = do_eval(
            sess, model, test, num_classes, label2index)
        f1_score_test = ((f1_macro_test + f1_micro_test) / 2.0)  #* 100.0
        print(
            "Test.Epoch %d TestLoss:%.3f\tF1_score:%.3f\tMacro_f1:%.3f\tMicro_f1:%.3f\t"
            % (epoch, loss_test, f1_score_test, f1_macro_test, f1_micro_test))
        print("training completed...")
Example #5
def main(_):
    print("model:",FLAGS.model)
    name_scope=FLAGS.model
    vocab_word2index, accusation_label2index,articles_label2index= create_or_load_vocabulary(FLAGS.data_path,FLAGS.predict_path,FLAGS.traning_data_file,FLAGS.vocab_size,name_scope=name_scope,test_mode=FLAGS.test_mode,tokenize_style=FLAGS.tokenize_style) #tokenize_style=FLAGS.tokenize_style
    deathpenalty_label2index={True:1,False:0}
    lifeimprisonment_label2index={True:1,False:0}
    vocab_size = len(vocab_word2index);print("cnn_model.vocab_size:",vocab_size);
    accusation_num_classes=len(accusation_label2index);article_num_classes=len(articles_label2index)
    deathpenalty_num_classes=len(deathpenalty_label2index);lifeimprisonment_num_classes=len(lifeimprisonment_label2index)
    print("accusation_num_classes:",accusation_num_classes);print("article_num_clasess:",article_num_classes)
    train, valid, test = load_data_multilabel(
        FLAGS.traning_data_file, FLAGS.valid_data_file, FLAGS.test_data_path,
        vocab_word2index, accusation_label2index, articles_label2index,
        deathpenalty_label2index, lifeimprisonment_label2index,
        FLAGS.sentence_len, name_scope=name_scope, test_mode=FLAGS.test_mode,
        tokenize_style=FLAGS.tokenize_style)
    (train_X, train_feature_X, train_Y_accusation, train_Y_article,
     train_Y_deathpenalty, train_Y_lifeimprisonment, train_Y_imprisonment,
     train_weights_accusation, train_weights_article) = train
    (valid_X, valid_feature_X, valid_Y_accusation, valid_Y_article,
     valid_Y_deathpenalty, valid_Y_lifeimprisonment, valid_Y_imprisonment,
     valid_weights_accusation, valid_weights_article) = valid
    (test_X, test_feature_X, test_Y_accusation, test_Y_article,
     test_Y_deathpenalty, test_Y_lifeimprisonment, test_Y_imprisonment,
     test_weights_accusation, test_weights_article) = test
    # print some messages for debugging purposes.
    feature_length = len(train_feature_X[0])
    print("length of training data:", len(train_X), ";valid data:",
          len(valid_X), ";test data:", len(test_X), ";feature_length:",
          feature_length)
    print("train_X[0]:", train_X[0])
    print("train_feature_X[0]:", train_feature_X[0])
    train_Y_accusation_short1 = get_target_label_short(train_Y_accusation[0])
    train_Y_accusation_short2 = get_target_label_short(train_Y_accusation[1])
    train_Y_accusation_short3 = get_target_label_short(train_Y_accusation[2])
    train_Y_accusation_short4 = get_target_label_short(train_Y_accusation[20])
    train_Y_accusation_short5 = get_target_label_short(train_Y_accusation[200])
    train_Y_article_short = get_target_label_short(train_Y_article[0])
    print("train_Y_accusation_short:", train_Y_accusation_short1,
          train_Y_accusation_short2, train_Y_accusation_short3,
          train_Y_accusation_short4, train_Y_accusation_short5,
          ";train_Y_article_short:", train_Y_article_short)
    print("train_Y_deathpenalty:", train_Y_deathpenalty[0],
          ";train_Y_lifeimprisonment:", train_Y_lifeimprisonment[0],
          ";train_Y_imprisonment:", train_Y_imprisonment[0])
    # 2.create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        #Instantiate Model
        model = HierarchicalAttention(
            accusation_num_classes, article_num_classes,
            deathpenalty_num_classes, lifeimprisonment_num_classes,
            FLAGS.learning_rate, FLAGS.batch_size, FLAGS.decay_steps,
            FLAGS.decay_rate, FLAGS.sentence_len, FLAGS.num_sentences,
            vocab_size, FLAGS.embed_size, FLAGS.hidden_size,
            num_filters=FLAGS.num_filters, model=FLAGS.model,
            filter_sizes=filter_sizes, stride_length=stride_length,
            pooling_strategy=FLAGS.pooling_strategy,
            feature_length=feature_length)
        #Initialize Save
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            for i in range(2):  # decay learning rate if necessary.
                print(i, "Going to decay learning rate by half.")
                sess.run(model.learning_rate_decay_half_op)
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding:  # load pre-trained word embedding
                vocabulary_index2word = {
                    index: word
                    for word, index in vocab_word2index.items()
                }
                assign_pretrained_word_embedding(
                    sess, vocabulary_index2word, vocab_size, model,
                    FLAGS.word2vec_model_path, model.Embedding)
                # assign_pretrained_word_embedding(sess, vocabulary_index2word, vocab_size, model, FLAGS.word2vec_model_path2, model.Embedding2)  # TODO
        curr_epoch = sess.run(model.epoch_step)
        # 3.feed data & training
        number_of_training_data = len(train_X)
        batch_size = FLAGS.batch_size
        iteration = 0
        accasation_score_best = -100

        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss_total, counter = 0.0, 0
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", train_X[start:end],
                          "train_X.shape:", train_X.shape)
                feed_dict = {
                    model.input_x: train_X[start:end],
                    model.input_feature: train_feature_X[start:end],
                    model.input_y_accusation: train_Y_accusation[start:end],
                    model.input_y_article: train_Y_article[start:end],
                    model.input_y_deathpenalty: train_Y_deathpenalty[start:end],
                    model.input_y_lifeimprisonment: train_Y_lifeimprisonment[start:end],
                    model.input_y_imprisonment: train_Y_imprisonment[start:end],
                    model.input_weight_accusation: train_weights_accusation[start:end],
                    model.input_weight_article: train_weights_article[start:end],
                    model.dropout_keep_prob: FLAGS.keep_dropout_rate,
                    model.is_training_flag: FLAGS.is_training_flag
                }  # model.iter: iteration, model.tst: not FLAGS.is_training
                (current_loss, lr, loss_accusation, loss_article,
                 loss_deathpenalty, loss_lifeimprisonment, loss_imprisonment,
                 l2_loss, _) = sess.run(
                     [model.loss_val, model.learning_rate,
                      model.loss_accusation, model.loss_article,
                      model.loss_deathpenalty, model.loss_lifeimprisonment,
                      model.loss_imprisonment, model.l2_loss, model.train_op],
                     feed_dict)  # model.update_ema
                loss_total, counter = loss_total + current_loss, counter + 1
                if counter % 20 == 0:
                    print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f"
                          % (epoch, counter,
                             float(loss_total) / float(counter), lr))
                if counter % 60 == 0:
                    print("Loss_accusation:%.3f\tLoss_article:%.3f\t"
                          "Loss_deathpenalty:%.3f\tLoss_lifeimprisonment:%.3f\t"
                          "Loss_imprisonment:%.3f\tL2_loss:%.3f\tCurrent_loss:%.3f\t"
                          % (loss_accusation, loss_article, loss_deathpenalty,
                             loss_lifeimprisonment, loss_imprisonment,
                             l2_loss, current_loss))
                if start != 0 and start % (3900 * FLAGS.batch_size) == 0:  # evaluate periodically within the epoch.
                    (loss, f1_macro_accasation, f1_micro_accasation,
                     f1_a_article, f1_i_aritcle, f1_a_death, f1_i_death,
                     f1_a_life, f1_i_life, score_penalty) = do_eval(
                         sess, model, valid, iteration,
                         accusation_num_classes, article_num_classes,
                         accusation_label2index)
                    accasation_score = ((f1_macro_accasation + f1_micro_accasation) / 2.0) * 100.0
                    article_score = ((f1_a_article + f1_i_aritcle) / 2.0) * 100.0
                    score_all = accasation_score + article_score + score_penalty
                    print("Epoch %d ValidLoss:%.3f\tMacro_f1_accusation:%.3f\tMicro_f1_accusation:%.3f\t"
                          "Macro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
                          "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                          % (epoch, loss, f1_macro_accasation,
                             f1_micro_accasation, f1_a_article, f1_i_aritcle,
                             f1_a_death, f1_i_death, f1_a_life, f1_i_life))
                    print("1.Accusation Score:", accasation_score,
                          ";2.Article Score:", article_score,
                          ";3.Penalty Score:", score_penalty,
                          ";Score ALL:", score_all)
                    # save model to checkpoint
                    if accasation_score > accasation_score_best:
                        save_path = FLAGS.ckpt_dir + "model.ckpt"  # TODO: only save a checkpoint once per epoch.
                        print("going to save checkpoint.")
                        saver.save(sess, save_path, global_step=epoch)
                        accasation_score_best = accasation_score
            # epoch increment
            print("going to increment epoch counter....")
            sess.run(model.epoch_increment)

            # 4.validation
            print(epoch, FLAGS.validate_every,
                  (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                (loss, f1_macro_accasation, f1_micro_accasation, f1_a_article,
                 f1_i_aritcle, f1_a_death, f1_i_death, f1_a_life, f1_i_life,
                 score_penalty) = do_eval(sess, model, valid, iteration,
                                          accusation_num_classes,
                                          article_num_classes,
                                          accusation_label2index)
                accasation_score = ((f1_macro_accasation + f1_micro_accasation) / 2.0) * 100.0
                article_score = ((f1_a_article + f1_i_aritcle) / 2.0) * 100.0
                score_all = accasation_score + article_score + score_penalty
                print()
                print("Epoch %d ValidLoss:%.3f\tMacro_f1_accusation:%.3f\tMicro_f1_accusation:%.3f\t"
                      "Macro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
                      "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
                      % (epoch, loss, f1_macro_accasation, f1_micro_accasation,
                         f1_a_article, f1_i_aritcle, f1_a_death, f1_i_death,
                         f1_a_life, f1_i_life))
                print("===>1.Accusation Score:", accasation_score,
                      ";2.Article Score:", article_score,
                      ";3.Penalty Score:", score_penalty,
                      ";Score ALL:", score_all)

                # save model to checkpoint
                if accasation_score > accasation_score_best:
                    save_path = FLAGS.ckpt_dir + "model.ckpt"
                    print("going to save checkpoint.")
                    saver.save(sess, save_path, global_step=epoch)
                    accasation_score_best = accasation_score
            if epoch in (0, 2, 4, 6, 9, 13):
                for i in range(2):
                    print(i, "Going to decay learning rate by half.")
                    sess.run(model.learning_rate_decay_half_op)

        # 5. finally, evaluate on the test set and report the scores.
        (loss_test, f1_macro_accasation_test, f1_micro_accasation_test,
         f1_a_article_test, f1_i_aritcle_test, f1_a_death_test,
         f1_i_death_test, f1_a_life_test, f1_i_life_test,
         score_penalty_test) = do_eval(sess, model, test, iteration,
                                       accusation_num_classes,
                                       article_num_classes,
                                       accusation_label2index)
        print("TEST.FINAL.Epoch %d TestLoss:%.3f\tMacro_f1_accusation:%.3f\tMicro_f1_accusation:%.3f\t"
              "Macro_f1_article:%.3f\tMicro_f1_article:%.3f\tMacro_f1_deathpenalty:%.3f\t"
              "Micro_f1_deathpenalty:%.3f\tMacro_f1_lifeimprisonment:%.3f\tMicro_f1_lifeimprisonment:%.3f\t"
              % (epoch, loss_test, f1_macro_accasation_test,
                 f1_micro_accasation_test, f1_a_article_test,
                 f1_i_aritcle_test, f1_a_death_test, f1_i_death_test,
                 f1_a_life_test, f1_i_life_test))
        accasation_score_test = ((f1_macro_accasation_test + f1_micro_accasation_test) / 2.0) * 100.0
        article_score_test = ((f1_a_article_test + f1_i_aritcle_test) / 2.0) * 100.0
        score_all_test = accasation_score_test + article_score_test + score_penalty_test
        print("TEST.Accusation Score:", accasation_score_test,
              ";2.Article Score:", article_score_test,
              ";3.Penalty Score:", score_penalty_test,
              ";Score ALL:", score_all_test)

        #print("Test Loss:%.3f\tMacro f1:%.3f\tMicro f1:%.3f" % (test_loss,macrof1,microf1))
        print("training completed...")
    pass
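
All of these loops halve the learning rate at fixed epochs by running model.learning_rate_decay_half_op, sometimes several times in a row. The model classes are not shown; a minimal sketch of how such an op is typically wired up in TF 1.x (names illustrative):

import tensorflow as tf

learning_rate = tf.Variable(0.001, trainable=False, dtype=tf.float32,
                            name='learning_rate')
# Each run of this op halves the stored rate in place.
learning_rate_decay_half_op = tf.assign(learning_rate, learning_rate * 0.5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(learning_rate_decay_half_op)
    sess.run(learning_rate_decay_half_op)
    print(sess.run(learning_rate))  # 0.00025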
Example #6
    :param y_mask_lm_train:
    :param p_mask_lm_train:
    :return:
    """
    return np.array(X_mask_lm_train), np.array(y_mask_lm_train), np.array(
        p_mask_lm_train)


if __name__ == "__main__":

    source_file = '/data/xuht/ChineseSTSListCorpus/corpus.txt'
    data_path = '/data/xuht/ChineseSTSListCorpus/bert/'
    training_data_path = source_file
    valid_data_path = source_file
    test_data_path = valid_data_path
    vocab_size = 500000
    process_num = 5
    test_mode = True
    sentence_len = 200
    vocab_word2index, label2index = create_or_load_vocabulary(
        data_path, training_data_path, vocab_size, test_mode=False)
    index2word = {v: k for k, v in vocab_word2index.items()}
    train, valid, test = mask_language_model(
        source_file, data_path, index2word, max_allow_sentence_length=10)

    train_X, train_y, train_p = train
    valid_X, valid_y, valid_p = valid
    test_X, test_y, test_p = test

    print(train_X.shape, train_y.shape, train_p.shape)
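
mask_language_model itself is not shown, but judging from the (X, y, p) triples it returns, each training example appears to be a token sequence with one masked position, the original token id as the target y, and the mask position p. A minimal sketch of building one such example (the [MASK] id and the single-mask scheme are assumptions):

import random
import numpy as np

MASK_ID = 2  # assumed id of the [MASK] token

def mask_one(sentence_ids):
    """Mask one random position; return (masked_x, target_y, position_p)."""
    x = np.array(sentence_ids)
    p = random.randrange(len(x))
    y = x[p]
    x[p] = MASK_ID
    return x, y, p

x, y, p = mask_one([5, 17, 9, 42])
print(x, y, p)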