import os
from glob import glob

import lm_model


def gen_resume_tfrecord(pat, outdir, out_prefix='train', max_len=64):
    """Convert every resume file matching `pat` into a TFRecord file in `outdir`."""
    files = glob(pat)
    ds = lm_model.LMDataSet('./model/vocab.txt', max_len)
    for index, file in enumerate(files):
        processor = lm_model.LMResumeDataProcessor()
        examples = processor.get_train_examples(file)
        outfile = os.path.join(outdir, '{}-{}.tfrecord'.format(out_prefix, index))
        print('\n')
        print('====> begin processing {}'.format(file))
        ds.file_based_convert_examples_to_features(examples, ds.tokenizer, outfile)
        print('====> out {}'.format(outfile))
        print('====> end processing {}'.format(file))
        print('\n')
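# A minimal usage sketch for gen_resume_tfrecord(). The glob pattern and
# output directory below are hypothetical placeholders, not paths taken
# from this project:
#
# gen_resume_tfrecord('./data/resumes/*.txt', './data/tfrecords', out_prefix='train')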
def train(train_file, vocab_file, config, log_dir, pretrained=None):
    graph = tf.Graph()
    with graph.as_default():
        # Input pipeline: TFRecord dataset -> one-shot iterator -> next batch.
        ds = lm_model.LMDataSet(vocab_file, config['max_length'])
        d = ds.get_ds(train_file, config['batch_size'])
        train_iterator = d.make_one_shot_iterator()
        train_inputs = train_iterator.get_next()

        model = lm_model.LMModel(config, config['max_length'])
        loss = model.loss_train(True)
        loss_predict = model.loss_predict()  # built alongside the graph; unused in this loop
        train_op = create_optimizer(loss, config['learning_rate'],
                                    config['train_steps'],
                                    config['learning_rate_warmup_steps'], False)

        # Optionally restore a compatible subset of variables from a
        # pretrained checkpoint via the project's transfer helper.
        partial_saver = None
        if pretrained:
            partial_saver = mtransfer.partial_transfer(pretrained)

    sv = tf.train.Supervisor(graph=graph, logdir=log_dir)
    best_loss = float('inf')
    with sv.managed_session(master='') as sess:
        train_steps = config['train_steps']
        # The Supervisor has already run the variable initializers.
        if partial_saver:
            partial_saver.restore(sess, pretrained)
        for step in range(train_steps):
            if sv.should_stop():
                break
            try:
                inputs = sess.run(train_inputs)
                feed_dicts = make_feed_dict(model, inputs)
                loss_val, _ = sess.run([loss, train_op], feed_dict=feed_dicts)
                if (step + 1) % 100 == 0:
                    print('====> [{}/{}]\tloss:{:.3f}'.format(
                        step + 1, train_steps, loss_val))
                    # Checkpoint whenever the reported loss improves.
                    if loss_val < best_loss:
                        best_loss = loss_val
                        sv.saver.save(sess, os.path.join(log_dir, 'best_model'),
                                      global_step=step + 1)
                        print('====> save model {}'.format(step + 1))
            except Exception as e:
                # The one-shot iterator raises once the data is exhausted;
                # stop training rather than retrying a failing step.
                print(e)
                break
        sv.saver.save(sess, os.path.join(log_dir, 'final_model'),
                      global_step=step + 1)
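# A sketch of how train() might be invoked. The paths are hypothetical
# placeholders; the config keys shown are exactly the ones train() reads,
# though LMModel may expect additional model hyperparameters beyond these:
#
# config = {
#     'max_length': 64,
#     'batch_size': 32,
#     'learning_rate': 1e-4,
#     'train_steps': 100000,
#     'learning_rate_warmup_steps': 10000,
# }
# train('./train.tfrecord', './model/vocab.txt', config, './log',
#       pretrained='./model/pretrained.ckpt')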
#!/usr/bin/env python3
import sys
sys.path.append('./bert')

import tensorflow as tf

import lm_model

# Convert one extracted Chinese-Wikipedia dump file into TFRecord training features.
processor = lm_model.LMWikiDataProcessor()
examples = processor.get_train_examples(
    '/Volumes/beast/data/qa/wiki/chin/extracted/AA/wiki0.txt')
print(len(examples))
# examples = examples[:20000]

max_len = 64
ds = lm_model.LMDataSet('./model/vocab.txt', max_len)
ds.file_based_convert_examples_to_features(examples, ds.tokenizer, './train.tfrecord')
# print(len(examples))