Example #1
def main():
    print('loading data...')
    tokenizer = FullTokenizer(config.bert_vocab, do_lower_case=config.to_lower)
    pos_2_id, id_2_pos = read_dict(config.pos_dict)
    tag_2_id, id_2_tag = read_dict(config.tag_dict)
    config.num_pos = len(pos_2_id)
    config.num_tag = len(tag_2_id)

    data_reader = DataReader(config, tokenizer, pos_2_id, tag_2_id)
    input_file = args.input
    print('input file: {}'.format(input_file))
    input_data = data_reader.load_data_from_file(input_file)

    print('building model...')
    model = get_model(config, is_training=False)

    saver = tf.train.Saver(max_to_keep=1)
    with tf.Session(config=sess_config) as sess:
        # Inference needs a trained checkpoint; look it up once and reuse it.
        ckpt = tf.train.latest_checkpoint(config.result_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('loading model from {}'.format(ckpt))

            batch_iter = make_batch_iter(list(zip(*input_data)), config.batch_size, shuffle=False)
            outputs = inference(sess, model, batch_iter, verbose=True)

            print('==========  Saving Result  ==========')
            output_file = args.output
            save_result(outputs, output_file, tokenizer, id_2_tag)
        else:
            print('model not found.')

        print('done')
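
Every example on this page calls read_dict to build a mapping from entries to ids together with its inverse. The implementation itself is not shown here; below is a minimal sketch, assuming the dictionary file stores one entry per line with the line number serving as the id (the real format may differ):

def read_dict(dict_file):
    # Sketch only: one entry per line, line index = id.
    item_2_id, id_2_item = {}, {}
    with open(dict_file, 'r', encoding='utf-8') as fin:
        for idx, line in enumerate(fin):
            item = line.rstrip('\n')
            item_2_id[item] = idx
            id_2_item[idx] = item
    return item_2_id, id_2_item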
Example #2
def test():
    print('loading data...')
    tokenizer = FullTokenizer(config.bert_vocab, do_lower_case=config.to_lower)
    pos_2_id, id_2_pos = read_dict(config.pos_dict)
    tag_2_id, id_2_tag = read_dict(config.tag_dict)
    config.num_pos = len(pos_2_id)
    config.num_tag = len(tag_2_id)

    data_reader = DataReader(config, tokenizer, pos_2_id, tag_2_id)
    test_data = data_reader.read_test_data()

    print('building model...')
    model = get_model(config, is_training=False)

    saver = tf.train.Saver(max_to_keep=1)
    with tf.Session(config=sess_config) as sess:
        # Evaluation also requires a trained checkpoint.
        ckpt = tf.train.latest_checkpoint(config.result_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('loading model from {}'.format(ckpt))

            print('==========  Test  ==========')
            test_batch_iter = make_batch_iter(list(zip(*test_data)), config.batch_size, shuffle=False)
            outputs, test_loss, test_accu = evaluate(sess, model, test_batch_iter, verbose=True)
            print('The average test loss is {:>.4f}, average test accuracy is {:>.4f}'.format(test_loss, test_accu))

            print('==========  Saving Result  ==========')
            save_result(outputs, config.test_result, tokenizer, id_2_tag)
        else:
            print('model not found.')

        print('done')
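
Both of the snippets above feed the model through make_batch_iter, which is likewise not shown on this page. A minimal sketch, assuming it takes a list of examples (here, the rows produced by list(zip(*data))) and yields fixed-size batches:

import random

def make_batch_iter(data, batch_size, shuffle=True):
    # Sketch only: yield successive batches; the last one may be smaller.
    data = list(data)
    if shuffle:
        random.shuffle(data)
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]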
Example #3
def main():
    config = Config('.', 'temp')
    pos_2_id, id_2_pos = read_dict(config.pos_dict)
    tag_2_id, id_2_tag = read_dict(config.tag_dict)
    tokenizer = Tokenizer(config.bert_vocab, do_lower_case=config.to_lower)
    data_reader = DataReader(config, tokenizer, pos_2_id, tag_2_id)

    valid_data = data_reader.read_valid_data()
    check_data(valid_data, tokenizer, id_2_pos, id_2_tag)

    print('done')
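
The two dictionaries returned by read_dict are inverses of each other, which is what lets check_data and the save_result calls elsewhere decode id sequences back into labels. A quick sanity check of that property (the file path is hypothetical):

pos_2_id, id_2_pos = read_dict('pos_dict.txt')  # hypothetical path
assert all(id_2_pos[i] == p for p, i in pos_2_id.items())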
Example #4
def __init__(self, dic_path, eng_dic_path=None):
    self.dict = read_dict(dic_path)
    # The English dictionary is optional and loaded only when a path is given.
    if eng_dic_path:
        self.eng_dict = read_dict(eng_dic_path)
    else:
        self.eng_dict = None
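
The class this __init__ belongs to is not shown. Assuming a dictionary-backed lookup wrapper (the class name Lookup and the file paths below are hypothetical), construction would look like:

class Lookup:  # hypothetical name for the enclosing class
    def __init__(self, dic_path, eng_dic_path=None):
        self.dict = read_dict(dic_path)
        self.eng_dict = read_dict(eng_dic_path) if eng_dic_path else None

monolingual = Lookup('main.dic')                         # hypothetical paths
bilingual = Lookup('main.dic', eng_dic_path='eng.dic')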
Example #5
def main():
    os.makedirs(config.result_dir, exist_ok=True)
    os.makedirs(config.train_log_dir, exist_ok=True)
    os.makedirs(config.valid_log_dir, exist_ok=True)

    print('preparing data...')
    config.word_2_id, config.id_2_word = read_dict(config.word_dict)
    config.attr_2_id, config.id_2_attr = read_dict(config.attr_dict)
    # Cap the known vocabulary; entries beyond vocab_size count as OOV.
    config.vocab_size = min(config.vocab_size, len(config.word_2_id))
    config.oov_vocab_size = len(config.word_2_id) - config.vocab_size
    config.attr_size = len(config.attr_2_id)

    embedding_matrix = None
    if args.do_train:
        if os.path.exists(config.glove_file):
            print('loading embedding matrix from file: {}'.format(config.glove_file))
            embedding_matrix, config.word_em_size = load_glove_embedding(config.glove_file, list(config.word_2_id.keys()))
            print('shape of embedding matrix: {}'.format(embedding_matrix.shape))
    else:
        # Not training: only the embedding size is needed, so read it off the
        # first line of the GloVe file instead of loading the whole matrix.
        if os.path.exists(config.glove_file):
            with open(config.glove_file, 'r', encoding='utf-8') as fin:
                line = fin.readline()
                config.word_em_size = len(line.strip().split()) - 1

    data_reader = DataReader(config)
    evaluator = Evaluator('description')

    print('building model...')
    model = get_model(config, embedding_matrix)
    saver = tf.train.Saver(max_to_keep=10)

    if args.do_train:
        print('loading data...')
        train_data = data_reader.read_train_data()
        valid_data = data_reader.read_valid_data_small()

        print_title('Trainable Variables')
        for v in tf.trainable_variables():
            print(v)

        print_title('Gradients')
        for g in model.gradients:
            print(g)

        with tf.Session(config=sess_config) as sess:
            # Prefer an explicitly passed args.model_file; otherwise fall back
            # to the latest checkpoint, or initialize from scratch if neither exists.
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(config.result_dir)
            if model_file is not None:
                print('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
            else:
                print('initializing from scratch...')
                tf.global_variables_initializer().run()

            train_writer = tf.summary.FileWriter(config.train_log_dir, sess.graph)
            valid_writer = tf.summary.FileWriter(config.valid_log_dir, sess.graph)

            run_train(sess, model, train_data, valid_data, saver, evaluator, train_writer, valid_writer, verbose=True)

    if args.do_eval:
        print('loading data...')
        valid_data = data_reader.read_valid_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(config.result_dir)
            if model_file is not None:
                print('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)

                predicted_ids, alignment_history, valid_loss, valid_accu = run_evaluate(sess, model, valid_data, verbose=True)
                print('average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'.format(valid_loss, valid_accu))

                print_title('Saving Result')
                save_result(predicted_ids, alignment_history, config.id_2_word, config.valid_data, config.valid_result)
                evaluator.evaluate(config.valid_data, config.valid_result, config.to_lower)
            else:
                print('model not found!')

    if args.do_test:
        print('loading data...')
        test_data = data_reader.read_test_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(config.result_dir)
            if model_file is not None:
                print('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)

                predicted_ids, alignment_history = run_test(sess, model, test_data, verbose=True)

                print_title('Saving Result')
                save_result(predicted_ids, alignment_history, config.id_2_word, config.test_data, config.test_result)
                evaluator.evaluate(config.test_data, config.test_result, config.to_lower)
            else:
                print('model not found!')
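
Given the call above, load_glove_embedding must return both the embedding matrix and the embedding size. A minimal sketch, assuming the standard GloVe text format (a word followed by its vector on each line) and random vectors for vocabulary words missing from the file; the real implementation may differ:

import numpy as np

def load_glove_embedding(glove_file, vocab):
    # Sketch only: align GloVe vectors with the vocab order; words absent
    # from the GloVe file keep a small random initialization.
    vectors = {}
    with open(glove_file, 'r', encoding='utf-8') as fin:
        for line in fin:
            parts = line.rstrip().split(' ')
            vectors[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
    word_em_size = len(next(iter(vectors.values())))
    matrix = np.random.normal(0.0, 0.1, (len(vocab), word_em_size)).astype(np.float32)
    for i, word in enumerate(vocab):
        if word in vectors:
            matrix[i] = vectors[word]
    return matrix, word_em_size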
Example #6
def train():
    os.makedirs(config.result_dir, exist_ok=True)
    os.makedirs(config.train_log_dir, exist_ok=True)
    os.makedirs(config.valid_log_dir, exist_ok=True)

    print('loading data...')
    tokenizer = FullTokenizer(config.bert_vocab, do_lower_case=config.to_lower)
    pos_2_id, id_2_pos = read_dict(config.pos_dict)
    tag_2_id, id_2_tag = read_dict(config.tag_dict)
    config.num_pos = len(pos_2_id)
    config.num_tag = len(tag_2_id)

    data_reader = DataReader(config, tokenizer, pos_2_id, tag_2_id)
    train_data = data_reader.read_train_data()
    valid_data = data_reader.read_valid_data()

    print('building model...')
    model = get_model(config, is_training=True)

    # Warm-start from the pre-trained BERT checkpoint: init_from_checkpoint
    # rewires variable initializers, so the pre-trained weights are actually
    # loaded when the variables are initialized below.
    tvars = tf.trainable_variables()
    assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, config.bert_ckpt)
    tf.train.init_from_checkpoint(config.bert_ckpt, assignment_map)

    print('==========  Trainable Variables  ==========')
    for v in tvars:
        init_string = ''
        if v.name in initialized_variable_names:
            init_string = '<INIT_FROM_CKPT>'
        print(v.name, v.shape, init_string)

    print('==========  Gradients  ==========')
    for g in model.gradients:
        print(g)

    best_score = 0.0
    saver = tf.train.Saver(max_to_keep=1)
    with tf.Session(config=sess_config) as sess:
        # Resume from the latest checkpoint if one exists, else start fresh.
        ckpt = tf.train.latest_checkpoint(config.result_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('loading model from {}'.format(ckpt))
        else:
            tf.global_variables_initializer().run()
            print('initializing from scratch.')

        train_writer = tf.summary.FileWriter(config.train_log_dir, sess.graph)

        for i in range(config.num_epoch):
            print('==========  Epoch {} Train  =========='.format(i + 1))
            train_batch_iter = make_batch_iter(list(zip(*train_data)), config.batch_size, shuffle=True)
            train_loss, train_accu = run_epoch(sess, model, train_batch_iter, train_writer, verbose=True)
            print('The average train loss is {:>.4f}, average train accuracy is {:>.4f}'.format(train_loss, train_accu))

            print('==========  Epoch {} Valid  =========='.format(i + 1))
            valid_batch_iter = make_batch_iter(list(zip(*valid_data)), config.batch_size, shuffle=False)
            outputs, valid_loss, valid_accu = evaluate(sess, model, valid_batch_iter, verbose=True)
            print('The average valid loss is {:>.4f}, average valid accuracy is {:>.4f}'.format(valid_loss, valid_accu))

            print('==========  Saving Result  ==========')
            save_result(outputs, config.valid_result, tokenizer, id_2_tag)

            # Keep only the best checkpoint, ranked by validation accuracy
            # (max_to_keep=1 above discards older saves).
            if valid_accu > best_score:
                best_score = valid_accu
                saver.save(sess, config.model_file)
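
Examples #1, #2, and #6 all hand their predictions to save_result. Its implementation is not shown; a minimal sketch, assuming outputs is a list of predicted tag-id sequences that are written one per line after mapping through id_2_tag (the real function may also detokenize the input via the tokenizer argument):

def save_result(outputs, output_file, tokenizer, id_2_tag):
    # Sketch only: one space-separated tag sequence per example.
    with open(output_file, 'w', encoding='utf-8') as fout:
        for tag_ids in outputs:
            fout.write(' '.join(id_2_tag[i] for i in tag_ids) + '\n')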