Example #1
def main(_):
    """ """
    vocab_word2index, _ = create_or_load_vocabulary(
        FLAGS.data_path, FLAGS.mask_lm_source_file, FLAGS.vocab_size,
        test_mode=FLAGS.test_mode, tokenize_style=FLAGS.tokenize_style)
    vocab_size = len(vocab_word2index)
    logging.info("bert pretrain vocab size: %d" % vocab_size)
    index2word = {v: k for k, v in vocab_word2index.items()}
    train, valid, test = mask_language_model(
        FLAGS.mask_lm_source_file, FLAGS.data_path, index2word,
        max_allow_sentence_length=FLAGS.max_allow_sentence_length,
        test_mode=FLAGS.test_mode, process_num=FLAGS.process_num)

    train_X, train_y, train_p = train
    valid_X, valid_y, valid_p = valid
    test_X, test_y, test_p = test
    print("train_X:{}, train_y:{}, train_p:{}".format(train_X.shape, \
            train_y.shape, train_p.shape))

    #1.create session
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True

Example #2
def main(_):
    vocab_word2index, _ = create_or_load_vocabulary(
        FLAGS.data_path,
        FLAGS.mask_lm_source_file,
        FLAGS.vocab_size,
        test_mode=FLAGS.test_mode,
        tokenize_style=FLAGS.tokenize_style)
    vocab_size = len(vocab_word2index)
    print("bert_pertrain_lm.vocab_size:", vocab_size)
    index2word = {v: k for k, v in vocab_word2index.items()}
    train, valid, test = mask_language_model(
        FLAGS.mask_lm_source_file,
        FLAGS.data_path,
        index2word,
        max_allow_sentence_length=FLAGS.max_allow_sentence_length,
        test_mode=FLAGS.test_mode,
        process_num=FLAGS.process_num)

    train_X, train_y, train_p = train
    valid_X, valid_y, valid_p = valid
    test_X, test_y, test_p = test

    print("length of training data:", train_X.shape, ";train_Y:",
          train_y.shape, ";train_p:", train_p.shape, ";valid data:",
          valid_X.shape, ";test data:", test_X.shape)
    # 1.create session.
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    with tf.Session(config=gpu_config) as sess:
        #Instantiate Model
        config = set_config(FLAGS, vocab_size, vocab_size)
        model = BertModel(config)
        # Initialize saver.
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint.")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            for i in range(2):  #decay learning rate if necessary.
                print(i, "Going to decay learning rate by half.")
                sess.run(model.learning_rate_decay_half_op)
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_pretrained_embedding:
                vocabulary_index2word = {
                    index: word
                    for word, index in vocab_word2index.items()
                }
                assign_pretrained_word_embedding(
                    sess, vocabulary_index2word, vocab_size,
                    FLAGS.word2vec_model_path, model.embedding,
                    config.d_model)  # assign pretrained word embeddings
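        # epoch_step was restored from the checkpoint (or initialized to 0),
        # so training resumes from the epoch where it previously stopped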
        curr_epoch = sess.run(model.epoch_step)

        # 2.feed data & training
        number_of_training_data = len(train_X)
        print("number_of_training_data:", number_of_training_data)
        batch_size = FLAGS.batch_size
        iteration = 0
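        # best validation accuracy seen so far; a checkpoint is saved only when it improves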
        score_best = -100
        for epoch in range(curr_epoch, FLAGS.num_epochs):
            loss_total_lm, counter = 0.0, 0
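            # mini-batch loop; the paired ranges drop any final partial batch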
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                iteration = iteration + 1
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", train_X[start:end],
                          "train_X.shape:", train_X.shape)
                feed_dict = {
                    model.x_mask_lm: train_X[start:end],
                    model.y_mask_lm: train_y[start:end],
                    model.p_mask_lm: train_p[start:end],
                    model.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                current_loss_lm, lr, l2_loss, _ = sess.run([
                    model.loss_val_lm, model.learning_rate, model.l2_loss_lm,
                    model.train_op_lm
                ], feed_dict)
                loss_total_lm, counter = loss_total_lm + current_loss_lm, counter + 1
                if counter % 30 == 0:
                    print(
                        "%d\t%d\tLearning rate:%.5f\tLoss_lm:%.3f\tCurrent_loss_lm:%.3f\tL2_loss:%.3f\t"
                        % (epoch, counter, lr, float(loss_total_lm) /
                           float(counter), current_loss_lm, l2_loss))
                # run validation every 800 batches (skip the very first step)
                if start != 0 and start % (800 * FLAGS.batch_size) == 0:
                    loss_valid, acc_valid = do_eval(sess, model, valid,
                                                    batch_size)
                    print(
                        "%d\tValid.Epoch %d ValidLoss:%.3f\tAcc_valid:%.3f\t" %
                        (counter, epoch, loss_valid, acc_valid * 100))
                    # save model to checkpoint
                    if acc_valid > score_best:
                        save_path = FLAGS.ckpt_dir + "model.ckpt"
                        print("going to save check point.")
                        saver.save(sess, save_path, global_step=epoch)
                        score_best = acc_valid
            sess.run(model.epoch_increment)
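
The loop above calls a do_eval helper that is not included in the snippet. The sketch below is only an illustration of what such a validation pass could look like, reusing the same placeholders as the training feed; the tensor name model.accuracy_lm and the function body are assumptions for illustration, not the repository's actual implementation.

def do_eval(sess, model, eval_data, batch_size):
    """Hypothetical masked-LM evaluation pass: returns average loss and accuracy."""
    eval_X, eval_y, eval_p = eval_data
    number_of_eval_data = len(eval_X)
    loss_sum, acc_sum, batches = 0.0, 0.0, 0
    for start, end in zip(
            range(0, number_of_eval_data, batch_size),
            range(batch_size, number_of_eval_data, batch_size)):
        feed_dict = {
            model.x_mask_lm: eval_X[start:end],
            model.y_mask_lm: eval_y[start:end],
            model.p_mask_lm: eval_p[start:end],
            model.dropout_keep_prob: 1.0  # no dropout during evaluation
        }
        loss, acc = sess.run([model.loss_val_lm, model.accuracy_lm], feed_dict)
        loss_sum, acc_sum, batches = loss_sum + loss, acc_sum + acc, batches + 1
    return loss_sum / max(batches, 1), acc_sum / max(batches, 1)

Keeping dropout_keep_prob at 1.0 during evaluation is the usual convention, so the reported validation loss and accuracy are deterministic rather than affected by dropout noise.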