Example #1
# All five examples share the same module-level setup: logging, os, and pickle
# from the standard library, plus BRCDataset, RCModel, Vocab, and the dataName
# filename prefix defined elsewhere in the surrounding project.
import logging
import os
import pickle


def predict(args):
    """
    Predicts answers for the test files.
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'),
              'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    print('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    print('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix='test.predicted')
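
Each of these entry points expects a parsed-argument object. As a point of reference, here is a minimal sketch of invoking predict with a hand-built namespace; the attribute names mirror exactly what the function reads above, while the concrete values and paths are hypothetical placeholders.

from argparse import Namespace

# Hypothetical values; the attribute names match what predict() reads above.
args = Namespace(
    vocab_dir='data/vocab',          # directory holding the pickled vocab
    model_dir='data/models',         # directory with the saved checkpoint
    result_dir='data/results',       # where 'test.predicted' files are written
    test_files=['data/test.json'],   # must be non-empty, or the assert fires
    algo='BIDAF',                    # checkpoint prefix passed to restore()
    max_p_num=5,
    max_p_len=500,
    max_q_len=60,
    batch_size=32,
)
predict(args)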
Example #2
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'),
              'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    print('Training the model...')
    rc_model.train(brc_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
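
A note on the final argument: dropout_keep_prob follows the keep-probability convention, so 1.0 disables dropout and smaller values regularize more aggressively. It is only threaded through the training call; the evaluate and predict paths in the other examples never pass it.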
Example #3
def predict_one(args, test_json_data=None):
    """
    Predicts the answer for a single test example.

    test_json_data example:
    {
        "documents":
        [
            {
                "title": "揭秘宋庆龄“第二段婚姻”传言不为人知的真相 - 红色秘史 - 红潮网 ",
                "segmented_title": "",
                "segmented_paragraphs": [[], []],
                "paragraphs": ["宋庆龄一生没有生养自己的孩子,鲜为人知的是,花甲之年时,她却有两个养女:隋永清和隋永洁。", ""],
                "bs_rank_pos": 0
            },
            {}
        ],
        "question": "宋庆龄第二任丈夫是谁",
        "segmented_question": ["宋庆龄", "第", "二", "任", "丈夫", "是", "谁"],
        "question_type": "ENTITY",    # defaults to "ENTITY"
        "fact_or_opinion": "FACT",    # defaults to "FACT"
        "question_id": 221574
    }
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'),
              'rb') as fin:
        vocab = pickle.load(fin)

    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          test_one=test_json_data)
    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    print('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    print('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix='test.predicted')
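
Following the schema documented in the docstring, the call can be driven with a hand-built payload. The sketch below simply reuses the docstring's sample values; args is the same parsed-argument object as in the other examples.

# Payload keys follow the docstring schema; values are the docstring's sample data.
test_json_data = {
    "documents": [
        {
            "title": "揭秘宋庆龄“第二段婚姻”传言不为人知的真相 - 红色秘史 - 红潮网 ",
            "segmented_title": "",
            "segmented_paragraphs": [[], []],
            "paragraphs": ["宋庆龄一生没有生养自己的孩子,鲜为人知的是,花甲之年时,她却有两个养女:隋永清和隋永洁。", ""],
            "bs_rank_pos": 0,
        },
    ],
    "question": "宋庆龄第二任丈夫是谁",
    "segmented_question": ["宋庆龄", "第", "二", "任", "丈夫", "是", "谁"],
    "question_type": "ENTITY",
    "fact_or_opinion": "FACT",
    "question_id": 221574,
}
predict_one(args, test_json_data=test_json_data)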
Example #4
def prepare(args):
    """
    checks data, creates the directories, prepare the vocabulary and embeddings
    """
    logger = logging.getLogger("brc")
    logger.info('Checking the data files...')
    print('Checking the data files...')
    for data_path in args.train_files + args.dev_files + args.test_files:
        assert os.path.exists(data_path), '{} file does not exist.'.format(
            data_path)
    logger.info('Preparing the directories...')
    for dir_path in [
            args.vocab_dir, args.model_dir, args.result_dir, args.summary_dir
    ]:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)

    logger.info('Building vocabulary...')
    print('Building vocabulary...')
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files, args.test_files)
    vocab = Vocab(lower=True)
    for word in brc_data.word_iter('train'):
        vocab.add(word)

    unfiltered_vocab_size = vocab.size()
    vocab.filter_tokens_by_cnt(min_cnt=2)
    filtered_num = unfiltered_vocab_size - vocab.size()
    logger.info('After filtering {} tokens, the final vocab size is {}'.format(
        filtered_num, vocab.size()))

    logger.info('Assigning embeddings...')
    vocab.randomly_init_embeddings(args.embed_size)

    logger.info('Saving vocab...')
    print('Saving vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'),
              'wb') as fout:
        pickle.dump(vocab, fout)

    logger.info('Done with preparing!')
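
After prepare has run, the pickled vocabulary can be reloaded exactly as the other examples do. A quick sanity check, assuming the same args and dataName as above:

import os
import pickle

# Reload the vocab written by prepare(); vocab_dir and dataName must match
# the values used during preparation.
with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'), 'rb') as fin:
    vocab = pickle.load(fin)
print('vocab size:', vocab.size())
print('pad id:', vocab.get_id(vocab.pad_token))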
Example #5
def evaluate(args):
    """
    evaluate the trained model on dev files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'),
              'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = BRCDataset(args.max_p_num,
                          args.max_p_len,
                          args.max_q_len,
                          dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    print('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    print('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev',
                                            args.batch_size,
                                            pad_id=vocab.get_id(
                                                vocab.pad_token),
                                            shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches,
                                                 result_dir=args.result_dir,
                                                 result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
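
Baselines of this kind usually expose the five stages through a single run script. A hypothetical dispatcher is sketched below; every flag corresponds to an attribute the functions above actually read, and the defaults are placeholders rather than the project's real values.

import argparse

def parse_args():
    """Hypothetical CLI mirroring the attributes read by the functions above."""
    parser = argparse.ArgumentParser('Reading comprehension on the Baidu dataset')
    parser.add_argument('--prepare', action='store_true', help='create dirs, build vocab and embeddings')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate on the dev files')
    parser.add_argument('--predict', action='store_true', help='predict answers for the test files')
    parser.add_argument('--algo', default='BIDAF', help='checkpoint prefix')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--embed_size', type=int, default=300)
    parser.add_argument('--dropout_keep_prob', type=float, default=1.0)
    parser.add_argument('--max_p_num', type=int, default=5)
    parser.add_argument('--max_p_len', type=int, default=500)
    parser.add_argument('--max_q_len', type=int, default=60)
    parser.add_argument('--train_files', nargs='+', default=[])
    parser.add_argument('--dev_files', nargs='+', default=[])
    parser.add_argument('--test_files', nargs='+', default=[])
    parser.add_argument('--vocab_dir', default='data/vocab')
    parser.add_argument('--model_dir', default='data/models')
    parser.add_argument('--result_dir', default='data/results')
    parser.add_argument('--summary_dir', default='data/summary')
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    if args.prepare:
        prepare(args)
    if args.train:
        train(args)
    if args.evaluate:
        evaluate(args)
    if args.predict:
        predict(args)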