Example #1
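A TensorFlow 1.x entry point that builds the word and attribute vocabularies, optionally loads GloVe embeddings, constructs the model, and dispatches to training, evaluation, or testing based on the do_train/do_eval/do_test flags.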
def main():
    os.makedirs(config.result_dir, exist_ok=True)
    os.makedirs(config.train_log_dir, exist_ok=True)
    os.makedirs(config.valid_log_dir, exist_ok=True)

    print('preparing data...')
    config.word_2_id, config.id_2_word = read_dict(config.word_dict)
    config.attr_2_id, config.id_2_attr = read_dict(config.attr_dict)
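    # Cap the modeled vocabulary; dictionary entries beyond vocab_size
    # are counted as OOV words.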
    config.vocab_size = min(config.vocab_size, len(config.word_2_id))
    config.oov_vocab_size = len(config.word_2_id) - config.vocab_size
    config.attr_size = len(config.attr_2_id)

    embedding_matrix = None
    if args.do_train:
        if os.path.exists(config.glove_file):
            print('loading embedding matrix from file: {}'.format(config.glove_file))
            embedding_matrix, config.word_em_size = load_glove_embedding(config.glove_file, list(config.word_2_id.keys()))
            print('shape of embedding matrix: {}'.format(embedding_matrix.shape))
    else:
        if os.path.exists(config.glove_file):
            with open(config.glove_file, 'r', encoding='utf-8') as fin:
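                # Infer the embedding width from the first line of the
                # GloVe file ("<token> v1 v2 ... vd").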
                line = fin.readline()
                config.word_em_size = len(line.strip().split()) - 1

    data_reader = DataReader(config)
    evaluator = Evaluator('description')

    print('building model...')
    model = get_model(config, embedding_matrix)
    saver = tf.train.Saver(max_to_keep=10)

    if args.do_train:
        print('loading data...')
        train_data = data_reader.read_train_data()
        valid_data = data_reader.read_valid_data_small()

        print_title('Trainable Variables')
        for v in tf.trainable_variables():
            print(v)

        print_title('Gradients')
        for g in model.gradients:
            print(g)

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(config.result_dir)
            if model_file is not None:
                print('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
            else:
                print('initializing from scratch...')
                tf.global_variables_initializer().run()

            train_writer = tf.summary.FileWriter(config.train_log_dir, sess.graph)
            valid_writer = tf.summary.FileWriter(config.valid_log_dir, sess.graph)

            run_train(sess, model, train_data, valid_data, saver, evaluator, train_writer, valid_writer, verbose=True)

    if args.do_eval:
        print('loading data...')
        valid_data = data_reader.read_valid_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(config.result_dir)
            if model_file is not None:
                print('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)

                predicted_ids, alignment_history, valid_loss, valid_accu = run_evaluate(sess, model, valid_data, verbose=True)
                print('average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'.format(valid_loss, valid_accu))

                print_title('Saving Result')
                save_result(predicted_ids, alignment_history, config.id_2_word, config.valid_data, config.valid_result)
                evaluator.evaluate(config.valid_data, config.valid_result, config.to_lower)
            else:
                print('model not found!')

    if args.do_test:
        print('loading data...')
        test_data = data_reader.read_test_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(config.result_dir)
            if model_file is not None:
                print('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)

                predicted_ids, alignment_history = run_test(sess, model, test_data, verbose=True)

                print_title('Saving Result')
                save_result(predicted_ids, alignment_history, config.id_2_word, config.test_data, config.test_result)
                evaluator.evaluate(config.test_data, config.test_result, config.to_lower)
            else:
                print('model not found!')
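All four examples pass a module-level session config (`sess_config` here, `config_proto` in Example #3) into `tf.Session` without showing its definition. A minimal sketch of what such a config typically looks like in TF1 code; the `allow_growth` setting is an assumption, not something the snippets specify:

import tensorflow as tf

# Assumed definition (not shown in the snippets): a TF1 session config
# that claims GPU memory on demand instead of reserving it all up front.
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True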
Example #2
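A BERT-based sequence-tagging trainer: it tokenizes with FullTokenizer, warm-starts the encoder from a pre-trained BERT checkpoint, and keeps only the checkpoint with the best validation accuracy.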
def train():
    os.makedirs(config.result_dir, exist_ok=True)
    os.makedirs(config.train_log_dir, exist_ok=True)
    os.makedirs(config.valid_log_dir, exist_ok=True)

    print('loading data...')
    tokenizer = FullTokenizer(config.bert_vocab, do_lower_case=config.to_lower)
    pos_2_id, id_2_pos = read_dict(config.pos_dict)
    tag_2_id, id_2_tag = read_dict(config.tag_dict)
    config.num_pos = len(pos_2_id)
    config.num_tag = len(tag_2_id)

    data_reader = DataReader(config, tokenizer, pos_2_id, tag_2_id)
    train_data = data_reader.read_train_data()
    valid_data = data_reader.read_valid_data()

    print('building model...')
    model = get_model(config, is_training=True)

    tvars = tf.trainable_variables()
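    # Map the graph's BERT variables to their counterparts in the
    # pre-trained checkpoint, then warm-start them from it.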
    assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, config.bert_ckpt)
    tf.train.init_from_checkpoint(config.bert_ckpt, assignment_map)

    print('==========  Trainable Variables  ==========')
    for v in tvars:
        init_string = ''
        if v.name in initialized_variable_names:
            init_string = '<INIT_FROM_CKPT>'
        print(v.name, v.shape, init_string)

    print('==========  Gradients  ==========')
    for g in model.gradients:
        print(g)

    best_score = 0.0
    saver = tf.train.Saver(max_to_keep=1)
    with tf.Session(config=sess_config) as sess:
        if tf.train.latest_checkpoint(config.result_dir):
            saver.restore(sess, tf.train.latest_checkpoint(config.result_dir))
            print('loading model from {}'.format(tf.train.latest_checkpoint(config.result_dir)))
        else:
            tf.global_variables_initializer().run()
            print('initializing from scratch.')

        train_writer = tf.summary.FileWriter(config.train_log_dir, sess.graph)

        for i in range(config.num_epoch):
            print('==========  Epoch {} Train  =========='.format(i + 1))
            train_batch_iter = make_batch_iter(list(zip(*train_data)), config.batch_size, shuffle=True)
            train_loss, train_accu = run_epoch(sess, model, train_batch_iter, train_writer, verbose=True)
            print('The average train loss is {:>.4f}, average train accuracy is {:>.4f}'.format(train_loss, train_accu))

            print('==========  Epoch {} Valid  =========='.format(i + 1))
            valid_batch_iter = make_batch_iter(list(zip(*valid_data)), config.batch_size, shuffle=False)
            outputs, valid_loss, valid_accu = evaluate(sess, model, valid_batch_iter, verbose=True)
            print('The average valid loss is {:>.4f}, average valid accuracy is {:>.4f}'.format(valid_loss, valid_accu))

            print('==========  Saving Result  ==========')
            save_result(outputs, config.valid_result, tokenizer, id_2_tag)

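            # Keep only the checkpoint with the best validation accuracy
            # (the Saver was created with max_to_keep=1).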
            if valid_accu > best_score:
                best_score = valid_accu
                saver.save(sess, config.model_file)
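Examples #2 and #3 both call a `make_batch_iter` helper defined elsewhere in their projects. A minimal sketch consistent with the call sites (a list of examples, a batch size, an optional shuffle flag); the original implementations may differ in details such as padding:

import random

def make_batch_iter(data, batch_size, shuffle=False):
    # data is a list of examples, e.g. list(zip(*train_data)) at the call sites.
    data = list(data)
    if shuffle:
        random.shuffle(data)
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]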
Example #3
File: train.py  Project: snowlixue/LegalAtt
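An epoch-based trainer for the LegalAtt legal-judgment model: it loads (or randomly initializes) word embeddings, then reports micro and macro F1 for accusations and relevant law articles after every epoch.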
def train():
    os.makedirs(config.result_dir, exist_ok=True)

    print('load data...')
    word_2_id, id_2_word = read_dict(config.word_dict)
    accu_2_id, id_2_accu = read_dict(config.accu_dict)
    art_2_id, id_2_art = read_dict(config.art_dict)

    if os.path.exists(config.word2vec_model):
        embedding_matrix = load_embedding(config.word2vec_model,
                                          word_2_id.keys())
    else:
        embedding_matrix = np.random.uniform(
            -0.5, 0.5, [len(word_2_id), config.embedding_size])

    data_reader = DataReader(config)
    train_data = data_reader.read_train_data(word_2_id, accu_2_id, art_2_id)
    valid_data = data_reader.read_valid_data(word_2_id, accu_2_id, art_2_id)
    art_data = data_reader.read_article(art_2_id.keys(), word_2_id)

    print('build model...')
    with tf.variable_scope('model'):
        model = get_model(config, embedding_matrix, is_training=True)

    print('==========  Trainable Variables  ==========')
    for v in tf.trainable_variables():
        print(v)

    saver = tf.train.Saver(max_to_keep=1)
    with tf.Session(config=config_proto) as sess:
        tf.global_variables_initializer().run()
        saver.save(sess, config.model_file)

        for i in range(config.num_epoch):
            print('==========  Epoch %2d Train  ==========' % (i + 1))
            train_batch_iter = make_batch_iter(list(zip(*train_data)),
                                               config.batch_size,
                                               shuffle=True)
            train_loss, _ = run_epoch(sess,
                                      model,
                                      train_batch_iter,
                                      art_data,
                                      verbose=True)
            print('The average train loss of epoch %2d is %.4f' %
                  ((i + 1), train_loss))

            print('==========  Epoch %2d Valid  ==========' % (i + 1))
            valid_batch_iter = make_batch_iter(list(zip(*valid_data)),
                                               config.batch_size,
                                               shuffle=False)
            outputs = inference(sess,
                                model,
                                valid_batch_iter,
                                art_data,
                                verbose=True)

            print('==========  Saving model  ==========')
            saver.save(sess, config.model_file)

            save_result(outputs, config.valid_result, id_2_accu, id_2_art)
            result = judger.get_result(config.valid_data, config.valid_result)
            accu_micro_f1, accu_macro_f1 = judger.calc_f1(result[0])
            article_micro_f1, article_macro_f1 = judger.calc_f1(result[1])
            score = [(accu_micro_f1 + accu_macro_f1) / 2,
                     (article_micro_f1 + article_macro_f1) / 2]
            print('Micro-F1 of accusation: %.4f' % accu_micro_f1)
            print('Macro-F1 of accusation: %.4f' % accu_macro_f1)
            print('Micro-F1 of relevant articles: %.4f' % article_micro_f1)
            print('Macro-F1 of relevant articles: %.4f' % article_macro_f1)
            print('Score: ', score)
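The per-epoch score above is the mean of micro and macro F1 for each task, with the counting delegated to judger.calc_f1. For reference, the same averaging can be sketched with scikit-learn; y_true and y_pred are hypothetical label lists, not part of the original project:

from sklearn.metrics import f1_score

# Hypothetical labels standing in for one task's predictions.
y_true = [0, 1, 2, 2, 1]
y_pred = [0, 2, 2, 2, 1]

micro_f1 = f1_score(y_true, y_pred, average='micro')
macro_f1 = f1_score(y_true, y_pred, average='macro')
score = (micro_f1 + macro_f1) / 2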
Example #4
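A logger-based variant of Example #1: the same train/eval/test dispatch, but with checkpoints kept under a per-model subdirectory (config.current_model) and evaluation results saved as JSON.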
def main():
    os.makedirs(config.temp_dir, exist_ok=True)
    os.makedirs(config.result_dir, exist_ok=True)
    os.makedirs(config.train_log_dir, exist_ok=True)

    logger.setLevel(logging.INFO)
    init_logger(logging.INFO, 'temp.log.txt', 'w')

    logger.info('preparing data...')
    config.word_2_id, config.id_2_word = read_json_dict(config.vocab_dict)
    config.vocab_size = min(config.vocab_size, len(config.word_2_id))
    config.oov_vocab_size = min(config.oov_vocab_size,
                                len(config.word_2_id) - config.vocab_size)

    embedding_matrix = None
    if args.do_train:
        if os.path.exists(config.glove_file):
            logger.info('loading embedding matrix from file: {}'.format(
                config.glove_file))
            embedding_matrix, config.word_em_size = load_glove_embedding(
                config.glove_file, list(config.word_2_id.keys()))
            logger.info('shape of embedding matrix: {}'.format(
                embedding_matrix.shape))
    else:
        if os.path.exists(config.glove_file):
            with open(config.glove_file, 'r', encoding='utf-8') as fin:
                line = fin.readline()
                config.word_em_size = len(line.strip().split()) - 1

    data_reader = DataReader(config)
    evaluator = Evaluator('tgt')

    logger.info('building model...')
    model = get_model(config, embedding_matrix)
    saver = tf.train.Saver(max_to_keep=10)

    if args.do_train:
        logger.info('loading data...')
        train_data = data_reader.read_train_data()
        valid_data = data_reader.read_valid_data()

        logger.info(log_title('Trainable Variables'))
        for v in tf.trainable_variables():
            logger.info(v)

        logger.info(log_title('Gradients'))
        for g in model.gradients:
            logger.info(g)

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
            else:
                logger.info('initializing from scratch...')
                tf.global_variables_initializer().run()

            train_writer = tf.summary.FileWriter(config.train_log_dir,
                                                 sess.graph)

            valid_log_history = run_train(sess, model, train_data, valid_data,
                                          saver, evaluator, train_writer)
            save_json(
                valid_log_history,
                os.path.join(config.result_dir, config.current_model,
                             'valid_log_history.json'))

    if args.do_eval:
        logger.info('loading data...')
        valid_data = data_reader.read_valid_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)

                predicted_ids, valid_loss, valid_accu = run_evaluate(
                    sess, model, valid_data)
                logger.info(
                    'average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'
                    .format(valid_loss, valid_accu))

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_word,
                             config.valid_data, config.valid_outputs)
                results = evaluator.evaluate(config.valid_data,
                                             config.valid_outputs,
                                             config.to_lower)
                save_json(results, config.valid_results)
            else:
                logger.info('model not found!')

    if args.do_test:
        logger.info('loading data...')
        test_data = data_reader.read_test_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)

                predicted_ids = run_test(sess, model, test_data)

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_word, config.test_data,
                             config.test_outputs)
                results = evaluator.evaluate(config.test_data,
                                             config.test_outputs,
                                             config.to_lower)
                save_json(results, config.test_results)
            else:
                logger.info('model not found!')
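Example #4 loads its vocabulary through a `read_json_dict` helper (Examples #1-#3 use a similar `read_dict`). A minimal sketch under the assumption that the file stores a token-to-id mapping as JSON; the original helper may differ:

import json

def read_json_dict(dict_file):
    # Load {token: id} and derive the reverse {id: token} mapping,
    # matching the (word_2_id, id_2_word) pairs unpacked above.
    with open(dict_file, 'r', encoding='utf-8') as fin:
        word_2_id = json.load(fin)
    id_2_word = {i: w for w, i in word_2_id.items()}
    return word_2_id, id_2_word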