Example #1
import os
from glob import glob

import lm_model

def gen_resume_tfrecord(pat, outdir, max_len=64, out_prefix='train'):
    # Convert every resume file matching `pat` into its own TFRecord file
    # under `outdir`, named <out_prefix>-<index>.tfrecord.
    files = glob(pat)
    ds = lm_model.LMDataSet('./model/vocab.txt', max_len)
    for index, file in enumerate(files):
        processor = lm_model.LMResumeDataProcessor()
        examples = processor.get_train_examples(file)
        outfile = os.path.join(outdir,
                               '{}-{}.tfrecord'.format(out_prefix, index))
        print('\n====> begin processing {}'.format(file))
        ds.file_based_convert_examples_to_features(examples, ds.tokenizer,
                                                   outfile)
        print('====> out {}'.format(outfile))
        print('====> end processing {}\n'.format(file))
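A minimal usage sketch; the glob pattern and output directory below are hypothetical, and max_len=64 matches the value used in Example #3:

# Hypothetical call: converts every matching resume file into
# train-0.tfrecord, train-1.tfrecord, ... under ./tfrecords.
gen_resume_tfrecord('./resumes/*.txt', './tfrecords', max_len=64)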
Example #2
import tensorflow as tf

import lm_model
import mtransfer
# `create_optimizer` (matching the signature of BERT's optimization helper)
# and `make_feed_dict` are assumed to be provided by the surrounding project.

def train(train_file, vocab_file, config, log_dir, pretrained=None):
    graph = tf.Graph()
    with graph.as_default():
        # Input pipeline: a one-shot iterator over the TFRecord training set.
        ds = lm_model.LMDataSet(vocab_file, config['max_length'])
        d = ds.get_ds(train_file, config['batch_size'])
        train_iterator = d.make_one_shot_iterator()
        train_inputs = train_iterator.get_next()
        # Model graph: training loss, prediction loss, and optimizer.
        model = lm_model.LMModel(config, config['max_length'])
        loss = model.loss_train(True)
        loss_predict = model.loss_predict()  # unused in the loop; builds the predict path
        train_op = create_optimizer(loss, config['learning_rate'],
                                    config['train_steps'],
                                    config['learning_rate_warmup_steps'],
                                    False)
        # Optionally build a saver that restores only the transferable
        # variables from a pretrained checkpoint.
        partial_saver = None
        if pretrained:
            partial_saver = mtransfer.partial_transfer(pretrained)
        # The Supervisor initializes all variables and manages the logdir,
        # so no explicit tf.global_variables_initializer() call is needed.
        sv = tf.train.Supervisor(graph=graph, logdir=log_dir)
        best_loss = float('inf')
        with sv.managed_session(master='') as sess:
            train_steps = config['train_steps']
            if partial_saver:
                partial_saver.restore(sess, pretrained)
            for step in range(train_steps):
                if sv.should_stop():
                    break
                try:
                    inputs = sess.run(train_inputs)
                    feed_dicts = make_feed_dict(model, inputs)
                    loss_val, _ = sess.run([loss, train_op],
                                           feed_dict=feed_dicts)
                    if (step + 1) % 100 == 0:
                        print('====> [{}/{}]\tloss:{:.3f}'.format(
                            step + 1, train_steps, loss_val))
                    # Checkpoint whenever the training loss improves.
                    if loss_val < best_loss:
                        best_loss = loss_val
                        sv.saver.save(sess,
                                      './log/best_model',
                                      global_step=(step + 1))
                        print('====> save model {}'.format(step + 1))
                except Exception as e:
                    # On any failure (e.g. the input pipeline running dry),
                    # save a final checkpoint and stop training.
                    print(e)
                    sv.saver.save(sess,
                                  './log/final_model',
                                  global_step=(step + 1))
                    break
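A hypothetical invocation sketch; the hyperparameter values are placeholders, but every key below is one the function actually reads, and 'max_length' of 64 matches Example #3:

# Hypothetical hyperparameters; each key is read by train() above.
config = {
    'max_length': 64,
    'batch_size': 32,
    'learning_rate': 1e-4,
    'train_steps': 100000,
    'learning_rate_warmup_steps': 10000,
}
train('./train.tfrecord', './model/vocab.txt', config, './log',
      pretrained=None)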
Example #3
#!/usr/bin/env python3
import sys
sys.path.append('./bert')

import lm_model

# Read raw training examples from one file of an extracted Chinese
# Wikipedia dump.
processor = lm_model.LMWikiDataProcessor()
examples = processor.get_train_examples(
    '/Volumes/beast/data/qa/wiki/chin/extracted/AA/wiki0.txt')
print(len(examples))
# examples = examples[:20000]  # optionally truncate for a quick test run

# Tokenize the examples and serialize them into a single TFRecord file.
max_len = 64
ds = lm_model.LMDataSet('./model/vocab.txt', max_len)
ds.file_based_convert_examples_to_features(examples, ds.tokenizer,
                                           './train.tfrecord')
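The resulting ./train.tfrecord can then be passed as train_file to the train() function from Example #2.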