import logging
import os
import pickle

# BRCDataset, RCModel, Vocab, and the dataName prefix are assumed to be
# provided by this project's own modules (dataset, vocab, and model code).


def predict(args):
    """Predicts answers for the test files."""
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    print('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    print('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted')
def train(args):
    """Trains the reading comprehension model."""
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initializing the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    print('Training the model...')
    rc_model.train(brc_data, args.epochs, args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
def predict_one(args, test_json_data=None):
    """Predicts the answer for a single test example.

    test_json_data example:
    {
        "documents": [
            {
                "title": "揭秘宋庆龄“第二段婚姻”传言不为人知的真相 - 红色秘史 - 红潮网 ",
                "segmented_title": "",
                "segmented_paragraphs": [[], []],
                "paragraphs": ["宋庆龄一生没有生养自己的孩子,鲜为人知的是,花甲之年时,她却有两个养女:隋永清和隋永洁。", ""],
                "bs_rank_pos": 0
            },
            {}
        ],
        "question": "宋庆龄第二任丈夫是谁",
        "segmented_question": ["宋庆龄", "第", "二", "任", "丈夫", "是", "谁"],
        "question_type": "ENTITY",      # defaults to "ENTITY"
        "fact_or_opinion": "FACT",      # defaults to "FACT"
        "question_id": 221574
    }
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          test_one=test_json_data)
    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    print('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    print('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted')
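# Usage sketch for predict_one (an addition for illustration, not part of the
# original module): `args` is the same argument namespace the other entry
# points receive, and the sample dict follows the docstring example above.
def _predict_one_demo(args):
    sample = {
        "documents": [{
            "title": "揭秘宋庆龄“第二段婚姻”传言不为人知的真相 - 红色秘史 - 红潮网 ",
            "segmented_title": "",
            "segmented_paragraphs": [[], []],
            "paragraphs": ["宋庆龄一生没有生养自己的孩子,鲜为人知的是,花甲之年时,她却有两个养女:隋永清和隋永洁。", ""],
            "bs_rank_pos": 0,
        }],
        "question": "宋庆龄第二任丈夫是谁",
        "segmented_question": ["宋庆龄", "第", "二", "任", "丈夫", "是", "谁"],
        "question_type": "ENTITY",
        "fact_or_opinion": "FACT",
        "question_id": 221574,
    }
    predict_one(args, test_json_data=sample)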
def prepare(args):
    """Checks the data, creates the directories, and prepares the vocabulary and embeddings."""
    logger = logging.getLogger("brc")
    logger.info('Checking the data files...')
    print('Checking the data files...')
    for data_path in args.train_files + args.dev_files + args.test_files:
        assert os.path.exists(data_path), '{} file does not exist.'.format(data_path)
    logger.info('Preparing the directories...')
    for dir_path in [args.vocab_dir, args.model_dir, args.result_dir, args.summary_dir]:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
    logger.info('Building vocabulary...')
    print('Building vocabulary...')
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files, args.test_files)
    vocab = Vocab(lower=True)
    for word in brc_data.word_iter('train'):
        vocab.add(word)
    unfiltered_vocab_size = vocab.size()
    vocab.filter_tokens_by_cnt(min_cnt=2)
    filtered_num = unfiltered_vocab_size - vocab.size()
    logger.info('After filtering {} tokens, the final vocab size is {}'.format(
        filtered_num, vocab.size()))
    logger.info('Assigning embeddings...')
    vocab.randomly_init_embeddings(args.embed_size)
    logger.info('Saving vocab...')
    print('Saving vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'), 'wb') as fout:
        pickle.dump(vocab, fout)
    logger.info('Done with preparing!')
def evaluate(args):
    """Evaluates the trained model on the dev files."""
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, dataName + 'BaiduVocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    print('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    print('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
                                            pad_id=vocab.get_id(vocab.pad_token),
                                            shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
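# A sketch of a command-line entry point wiring the stages above together
# (an addition for illustration, not part of the original module). The flag
# names mirror the attributes the functions above read from `args`; the
# defaults are assumptions, and RCModel may require further flags (e.g.
# hidden size, learning rate) that are omitted here.
import argparse


def main():
    parser = argparse.ArgumentParser('Reading comprehension on the Baidu dataset')
    parser.add_argument('--prepare', action='store_true', help='build vocab and directories')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate on the dev set')
    parser.add_argument('--predict', action='store_true', help='predict on the test set')
    parser.add_argument('--algo', default='BIDAF', help='prefix used when saving/restoring the model')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--embed_size', type=int, default=300)
    parser.add_argument('--dropout_keep_prob', type=float, default=1.0)
    parser.add_argument('--max_p_num', type=int, default=5)
    parser.add_argument('--max_p_len', type=int, default=500)
    parser.add_argument('--max_q_len', type=int, default=60)
    parser.add_argument('--train_files', nargs='+', default=[])
    parser.add_argument('--dev_files', nargs='+', default=[])
    parser.add_argument('--test_files', nargs='+', default=[])
    parser.add_argument('--vocab_dir', default='data/vocab/')
    parser.add_argument('--model_dir', default='data/models/')
    parser.add_argument('--result_dir', default='data/results/')
    parser.add_argument('--summary_dir', default='data/summary/')
    args = parser.parse_args()
    if args.prepare:
        prepare(args)
    if args.train:
        train(args)
    if args.evaluate:
        evaluate(args)
    if args.predict:
        predict(args)


if __name__ == '__main__':
    main()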