Example #1
def train(args):
    """
    trains the reading comprehension model
    """
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger("brc")

    file_handler = logging.FileHandler(args.log_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    logger.info(args)

    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = BRCDataset(args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
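Every example on this page unpickles a Vocab object written by a separate prepare step. For context, here is a minimal sketch of that assumed step; the Vocab class and its word_iter/filter_tokens_by_cnt helpers are assumptions modeled on the DuReader baseline, not code shown on this page:

import os
import pickle

def prepare(args):
    """Hypothetical prepare step: build the vocabulary and pickle it."""
    brc_data = BRCDataset(args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    vocab = Vocab(lower=True)                      # assumed Vocab class
    for word in brc_data.word_iter('train'):       # assumed token iterator
        vocab.add(word)
    vocab.filter_tokens_by_cnt(min_cnt=2)          # drop rare tokens (assumption)
    vocab.randomly_init_embeddings(args.embed_size)
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'wb') as fout:
        pickle.dump(vocab, fout)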
Example #2
def evaluate(args):
    """
    evaluates the trained model on dev files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev',
                                            args.batch_size,
                                            pad_id=vocab.get_id(
                                                vocab.pad_token),
                                            shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches,
                                                 result_dir=args.result_dir,
                                                 result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(os.path.join(args.result_dir)))
Example #3
File: run.py Project: lduml/dureader
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    # load the vocab object (saved by prepare), including token2id, id2token and other helpers
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    # brc_data.save_set_file(brc_data.dev_set, './save_sets', 'dev_set')
    # brc_data.save_set_file(brc_data.test_set, './save_sets', 'test_set')
    # brc_data.save_set_file(brc_data.train_set, './save_sets', 'train_set')
    logger.info('Converting text into ids...')
    # convert the raw data in [self.train_set, self.dev_set, self.test_set] to id form
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    # optionally restore the previously saved model
    # rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Training the model...')
    rc_model.train(brc_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #4
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    # load the dataset and the vocab (saved by the prepare step)
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)  # pickle is a Python standard module; reads the vocab object saved by prepare
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)  # max passage count, passage length, question length;
                                                             # training uses only train and dev files
    # use vocab to convert brc_data into ids
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)  # map question and passage tokens in the raw data to vocab ids
    # initialize the neural network
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    """
    Train the model with data
    Args:
        data: the BRCDataset class implemented in dataset.py
        epochs: number of training epochs
        batch_size:
        save_dir: the directory to save the model
        save_prefix: the prefix indicating the model type
        dropout_keep_prob: float value indicating dropout keep probability
        evaluate: whether to evaluate the model on test set after each epoch
    """
    rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #5
def evaluate(args):
    """
    evaluate the trained model on dev files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    # with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
    with open(args.vocab_path, 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev',
                                            args.batch_size,
                                            pad_id=vocab.get_id(
                                                vocab.pad_token),
                                            shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches,
                                                 result_dir=args.result_dir,
                                                 result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(
        os.path.join(args.result_dir)))
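The train/evaluate/predict entry points collected on this page are normally dispatched from a main() in run.py. A minimal sketch of that wiring, assuming boolean --train/--evaluate/--predict flags (the exact argument names vary from project to project):

import argparse

def main():
    parser = argparse.ArgumentParser('Reading comprehension on the DuReader dataset')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate the model on dev files')
    parser.add_argument('--predict', action='store_true', help='predict answers for test files')
    # ... path and hyper-parameter arguments omitted ...
    args = parser.parse_args()
    if args.train:
        train(args)
    if args.evaluate:
        evaluate(args)
    if args.predict:
        predict(args)

if __name__ == '__main__':
    main()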
Example #6
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          test_files=args.test_files,
                          dataset="Dureader")
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)

    test_loss, test_bleu_rouge = rc_model.evaluate(
        test_batches,
        result_dir=args.result_dir,
        result_prefix='test.predicted')
    logger.info('Loss on test set: {}'.format(test_loss))
Example #7
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)

    # logger.info('Assigning embeddings...')
    # vocab.embed_dim = args.embed_size
    # vocab.load_pretrained_embeddings(args.embedding_path)

    logger.info('Vocabulary size: %s' % vocab.size())

    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          vocab, args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    # rc_model = MTRCModel(vocab, args)
    logger.info('Training the model...')
    rc_model.train(brc_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #8
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          args.train_files,
                          args.dev_files,
                          data_num=args.data_num)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    rc_model.train(brc_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.experiment,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #9
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    print("vocab.size() = ", vocab.size())
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.max_word_len, args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab, args.use_char_level == 'true')
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    if args.retrain == 'true':
        rc_model.restore(model_dir=args.model_restore_dir,
                         model_prefix=args.algo_restore)
    rc_model.finalize()
    logger.info('Training the model...')
    rc_model.train(brc_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #10
File: run.py Project: zsweet/G-Reader
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #11
File: run.py Project: jiangyuenju/DuReader
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #12
File: run.py Project: jiangyuenju/DuReader
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted')
Example #13
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    logger.info('Loading Pretrained Word Embedding')
    # vocab.embed_dim = None
    # vocab.load_pretrained_embeddings(args.embedding_path)

    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          vocab,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    # rc_model = MTRCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix='test.predicted')
Example #14
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    if args.word2vec_path:
        logger.info('learn_word_embedding:{}'.format(args.learn_word_embedding))
        logger.info('loading %s\n' % args.word2vec_path)
        word2vec = gensim.models.Word2Vec.load(args.word2vec_path)
        vocab.load_pretrained_embeddings_from_w2v(word2vec.wv)
        logger.info('load pretrained embedding from %s done\n' % args.word2vec_path)

    if args.use_char_embed:
        with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
            char_vocab = pickle.load(fin)

    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    steps_per_epoch = brc_data.size('train') // args.batch_size
    args.decay_steps = args.decay_epochs * steps_per_epoch 
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    
    if args.use_char_embed:
        logger.info('Converting text into char ids...')
        brc_data.convert_to_char_ids(char_vocab)
        logger.info('Binding char_vocab to args to pass to RCModel')
        args.char_vocab = char_vocab

    RCModel = choose_model_by_gpu_setting(args)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...{}'.format(RCModel.__name__))
    rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
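Example #14 calls vocab.load_pretrained_embeddings_from_w2v(word2vec.wv), a method not shown on this page. A minimal sketch of what it might look like, assuming the vocab stores a token2id dict and an embeddings matrix (both assumptions):

import numpy as np

def load_pretrained_embeddings_from_w2v(self, wv):
    """Copy vectors from a gensim KeyedVectors object for in-vocab tokens."""
    self.embed_dim = wv.vector_size
    self.embeddings = np.random.uniform(
        -0.1, 0.1, (self.size(), self.embed_dim))  # random init for OOV tokens
    for token, idx in self.token2id.items():
        if token in wv:
            self.embeddings[idx] = wv[token]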
Example #15
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(args.vocab_dir + '/vocab.data', 'rb') as fin:
        vocab = pickle.load(fin)

    with open(args.vocab_dir + '/vocab.data', 'rb') as fin:
        vocab1 = pickle.load(fin)

    with open(args.vocab_dir + '/vocab.data', 'rb') as fin:
        vocab2 = pickle.load(fin)

    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.max_a_len, args.train_files, args.dev_files,
                          args.test_files)

    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab, vocab1, vocab2)
    logger.info('Initialize the model...')

    rc_model = RCModel(vocab, vocab1, args)
    # rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Training the model...')
    if args.train_as:
        rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    rc_model.train(brc_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #16
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix='test.predicted')
Example #17
def predict(args):
    """
    predicts answers for test files
    """
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    with open(os.path.join(args.vocab_dir, 'tar_vocab.data'), 'rb') as fin:
        vocab1 = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.max_a_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab, vocab1, vocab1)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, vocab1, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix=args.result_prefix)
Example #18
def predict(args):
    """
	predicts answers for test files
	"""
    logger = logging.getLogger("Military AI")
    logger.info('Load data_set and vocab...')
    mai_data = MilitaryAiDataset(args.train_files,
                                 args.train_raw_files,
                                 args.test_files,
                                 args.test_raw_files,
                                 args.char_embed_file,
                                 args.token_embed_file,
                                 args.elmo_dict_file,
                                 args.elmo_embed_file,
                                 char_min_cnt=1,
                                 token_min_cnt=3)
    logger.info('Assigning embeddings...')
    if not args.use_embe:
        mai_data.token_vocab.randomly_init_embeddings(args.embed_size)
        mai_data.char_vocab.randomly_init_embeddings(args.embed_size)
    logger.info('Restoring the model...')
    rc_model = RCModel(mai_data.char_vocab, mai_data.token_vocab,
                       mai_data.flag_vocab, mai_data.elmo_vocab, args)
    rc_model.restore(model_dir=args.model_dir,
                     model_prefix=args.algo + args.suffix)
    logger.info('Predicting answers for test set...')
    test_batches = mai_data.gen_mini_batches('test',
                                             args.batch_size,
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix='test.predicted')
Example #19
File: run.py Project: baiyigali/mrc
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Loading vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.pkl'), 'rb') as fin:
        vocab = pickle.load(fin)
    pad_id = vocab.get_id(vocab.pad_token)
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.prepared_dir, args.train_files, args.dev_files,
                          args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    g = tf.Graph()
    with g.as_default():
        rc_model = RCModel(vocab.embeddings, pad_id, args)
        del vocab
        # Train
        with tf.name_scope("Train"):
            logger.info('Training the model...')
            rc_model.train(brc_data,
                           args.epochs,
                           args.batch_size,
                           save_dir=args.result_dir,
                           save_prefix='test.predicted',
                           dropout_keep_prob=args.dropout_keep_prob)
        tf.summary.FileWriter(args.summary_dir, g).close()
        with tf.name_scope('Valid'):
            assert len(args.dev_files) > 0, 'No dev files are provided.'
            logger.info('Evaluating the model on dev set...')
            dev_batches = brc_data.gen_mini_batches('dev',
                                                    args.batch_size,
                                                    pad_id=pad_id,
                                                    shuffle=False)
            dev_loss, dev_bleu_rouge = rc_model.evaluate(
                dev_batches,
                result_dir=args.result_dir,
                result_prefix='dev.predicted')
            logger.info('Loss on dev set: {}'.format(dev_loss))
            logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
            logger.info('Predicted answers are saved to {}'.format(
                os.path.join(args.result_dir)))
        with tf.name_scope('Test'):
            assert len(args.test_files) > 0, 'No test files are provided.'
            logger.info('Predicting answers for test set...')
            test_batches = brc_data.gen_mini_batches('test',
                                                     args.batch_size,
                                                     pad_id=pad_id,
                                                     shuffle=False)
            rc_model.evaluate(test_batches,
                              result_dir=args.result_dir,
                              result_prefix='test.predicted')
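A note on Example #19: unlike the other variants, it builds everything inside an explicit tf.Graph() and exports the graph definition with tf.summary.FileWriter, so the trained graph can be inspected with standard TensorBoard:

tensorboard --logdir <args.summary_dir>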
Example #20
def train(args):
	"""
	trains the reading comprehension model
	"""
	logger = logging.getLogger("brc")
	logger.info('Load data_set and vocab...')
	with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
		vocab = pickle.load(fin)
	brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
						  args.train_files, args.dev_files)  # no test_files in this call
	logger.info('Converting text into ids...')
	brc_data.convert_to_ids(vocab)
	logger.info('Initialize the model...')
	rc_model = RCModel(vocab, args)
	logger.info('Training the model...')
	t = str(time.time())
	save_dir = os.path.join(args.model_dir, t)
	logger.info('model save dir:{}'.format(save_dir))
	rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=save_dir,
				   save_prefix='BERT-'+str(args.algo)+'-'+str(args.epochs)+'-'+str(args.batch_size),
				   dropout_keep_prob=args.dropout_keep_prob)
	logger.info('Done with model training!')
Example #21
File: run.py Project: jiangyuenju/DuReader
def evaluate(args):
    """
    evaluate the trained model on dev files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
                                            pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(os.path.join(args.result_dir)))
Example #22
def evaluate(args):
    """
	evaluate the trained model on dev files
	"""
    logger = logging.getLogger("Military AI")
    logger.info('Load data_set and vocab...')
    mai_data = MilitaryAiDataset(args.train_files,
                                 args.train_raw_files,
                                 args.test_files,
                                 args.test_raw_files,
                                 args.char_embed_file,
                                 args.token_embed_file,
                                 args.elmo_dict_file,
                                 args.elmo_embed_file,
                                 char_min_cnt=1,
                                 token_min_cnt=3)

    logger.info('Assigning embeddings...')
    if not args.use_embe:
        mai_data.token_vocab.randomly_init_embeddings(args.embed_size)
        mai_data.char_vocab.randomly_init_embeddings(args.embed_size)
    logger.info('Restoring the model...')
    rc_model = RCModel(mai_data.char_vocab, mai_data.token_vocab,
                       mai_data.flag_vocab, mai_data.elmo_vocab, args)
    rc_model.restore(model_dir=args.model_dir,
                     model_prefix=args.algo + args.suffix)
    logger.info('Evaluating the model on dev set...')
    dev_batches = mai_data.gen_mini_batches('dev',
                                            args.batch_size,
                                            shuffle=False)
    dev_loss, dev_main_loss, dev_bleu_rouge = rc_model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_main_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(
        os.path.join(args.result_dir)))
Example #23
File: run.py Project: colinsongf/MC-MDP
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("mrc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = MRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          args.max_s_len,
                          args.train_files,
                          args.dev_files,
                          vocab=vocab)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    if args.algo == 'MCST':
        logger.info('Use MCST Model to train...')
        rc_model = MCSTmodel(vocab, args)
        logger.info('Training MCST model...')
        rc_model.train(brc_data,
                       args.epochs,
                       args.batch_size,
                       save_dir=args.model_dir,
                       save_prefix=args.algo,
                       dropout_keep_prob=args.dropout_keep_prob)
    else:
        rc_model = RCModel(vocab, args)
        logger.info('Training the model...')
        rc_model.train(brc_data,
                       args.epochs,
                       args.batch_size,
                       save_dir=args.model_dir,
                       save_prefix=args.algo,
                       dropout_keep_prob=args.dropout_keep_prob)
Example #24
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          args.max_word_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    rc_model.finalize()
    # Call sess.graph.finalize() after all ops have been added, making the
    # whole graph read-only. Note that tf.train.Saver() also adds nodes to
    # the graph, so it must be created before finalize(); moreover, a
    # tf.train.Saver() only saves variables that already exist at the time
    # the Saver is declared!
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix='test.predicted')
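The finalize() notes in Example #24 are easy to demonstrate in isolation. A minimal TF1-style sketch (via the tf.compat.v1 shim) showing that a finalized graph rejects new ops, and that the Saver must therefore be created first:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

g = tf.Graph()
with g.as_default():
    v = tf.Variable(0.0, name='v')
    saver = tf.train.Saver()  # adds save/restore nodes, so create it before finalize()
g.finalize()                  # the graph is now read-only

try:
    with g.as_default():
        w = tf.Variable(1.0)  # raises RuntimeError: graph is finalized
except RuntimeError as e:
    print('cannot add ops:', e)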
Example #25
def train(args):
    """
	trains the reading comprehension model
	"""
    logger = logging.getLogger("Military AI")
    logger.info('Load data_set and vocab...')

    mai_data = MilitaryAiDataset(args.train_files,
                                 args.train_raw_files,
                                 args.test_files,
                                 args.test_raw_files,
                                 args.char_embed_file,
                                 args.token_embed_file,
                                 args.elmo_dict_file,
                                 args.elmo_embed_file,
                                 char_min_cnt=1,
                                 token_min_cnt=3)

    logger.info('Assigning embeddings...')
    if not args.use_embe:
        mai_data.token_vocab.randomly_init_embeddings(args.embed_size)
        mai_data.char_vocab.randomly_init_embeddings(args.embed_size)
    logger.info('Initialize the model...')
    rc_model = RCModel(mai_data.char_vocab, mai_data.token_vocab,
                       mai_data.flag_vocab, mai_data.elmo_vocab, args)
    if args.is_restore or args.restore_suffix:
        restore_prefix = args.algo + args.suffix
        if args.restore_suffix:
            restore_prefix = args.algo + args.restore_suffix
        rc_model.restore(model_dir=args.model_dir, model_prefix=restore_prefix)
    logger.info('Training the model...')
    rc_model.train(mai_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo + args.suffix,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Example #26
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    # note: only test_files are loaded here, so brc_data.size('train') is
    # presumably 0; decay_steps is only needed to construct the model
    steps_per_epoch = brc_data.size('train') // args.batch_size
    args.decay_steps = args.decay_epochs * steps_per_epoch
    RCModel = choose_model_by_gpu_setting(args)
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted')
Example #27
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(args.vocab_path, 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.algo,
                          args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          args.max_a_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')

    brc_data.convert_to_ids(vocab)

    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)
    if args.algo == 'YESNO':
        qa_resultPath = args.test_files[0]  # there will only be one file!
        (filepath, tempfilename) = os.path.split(qa_resultPath)
        (qarst_filename, extension) = os.path.splitext(tempfilename)
        result_prefix = qarst_filename
    else:
        result_prefix = 'test.predicted.qa'

    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix=result_prefix)
    if args.algo == 'YESNO':  # merge the YESNO results into the QA results
        qa_resultPath = args.test_files[0]  # there will only be one file!
        yesno_resultPath = args.result_dir + '/' + result_prefix + '.YESNO.json'
        out_file_path = args.result_dir + '/' + result_prefix + '.134.class.' + str(
            args.run_id) + '.json'

        # first load the YESNO predictions
        yesno_records = {}
        with open(yesno_resultPath, 'r') as f_in:
            for line in f_in:
                sample = json.loads(line)
                yesno_records[sample['question_id']] = line

        total_rst_num = 0
        with open(qa_resultPath, 'r') as f_in:
            with open(out_file_path, 'w') as f_out:
                for line in f_in:
                    total_rst_num += 1
                    sample = json.loads(line)
                    if sample['question_id'] in yesno_records:
                        line = yesno_records[sample['question_id']]
                    f_out.write(line)

        print('total result num:', total_rst_num)
        print('YESNO label combining done!')