def evaluate(args): """ evaluate the trained model on dev files """ logger = logging.getLogger("BiDAF") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.dev_files) > 0, 'No dev files are provided.' brc_data = Dataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = BiDAFModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Evaluating the model on dev set...') dev_batches = brc_data.get_batches('dev', args.batch_size, pad_id=vocab.get_id(vocab.pad_token), shuffle=False) dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted') logger.info('Loss on dev set: {}'.format(dev_loss)) logger.info('Result on dev set: {}'.format(dev_bleu_rouge)) logger.info('Predicted answers are saved to {}'.format( os.path.join(args.result_dir)))
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("BiDAF") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) data = Dataset(train_files=args.train_files, dev_files=args.dev_files, max_p_length=args.max_p_len, max_q_length=args.max_q_len) logger.info('Converting text into ids...') data.convert_to_ids(vocab) logger.info('Initialize the model...') model = BiDAFModel(vocab, args) model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info("Load dev dataset...") model.dev_content_answer(args.dev_files) logger.info('Training the model...') model.train(data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def interactive(args):
    """ answers user-supplied questions in an interactive loop """
    logger = logging.getLogger("BiDAF")
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    logger.info('Restoring the model...')
    rc_model = BiDAFModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    while True:
        content = input('Enter the passage:\n')
        if content == 'exit':
            exit(0)
        question = input('\nEnter the question:\n')
        if question == 'exit':
            exit(0)
        # Segment the raw Chinese text with jieba, then map tokens to ids
        # and truncate to the model's maximum passage/question lengths.
        content_segs = ' '.join(jieba.cut(content)).split()
        question_segs = ' '.join(jieba.cut(question)).split()
        content_ids = vocab.convert_to_ids(content_segs)[:args.max_p_len]
        question_ids = vocab.convert_to_ids(question_segs)[:args.max_q_len]
        # Wrap the single example into a one-element batch; start_id/end_id
        # are placeholders since there is no gold answer at inference time.
        batch_data = {
            'question_ids': [question_ids],
            'question_length': [len(question_ids)],
            'content_ids': [content_ids],
            'content_length': [len(content_ids)],
            'start_id': [0],
            'end_id': [0],
        }
        start_idx, end_idx = rc_model.getAnswer(batch_data)
        print(
            '\n==========================================================================\n'
        )
        print('Answer:', ''.join(content_segs[start_idx:end_idx + 1]))
        print(
            '\n==========================================================================\n'
        )

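# A minimal sketch of the segmentation step used above; the exact token
# boundaries depend on jieba's dictionary and version, so the output shown
# here is illustrative rather than guaranteed:
#
#   >>> import jieba
#   >>> ' '.join(jieba.cut('今天天气很好')).split()
#   ['今天', '天气', '很', '好']
#
# Each token list is then mapped to ids with vocab.convert_to_ids and
# truncated to max_p_len / max_q_len before being fed to the model.
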
def predict(args): """ predicts answers for test files """ logger = logging.getLogger("BiDAF") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.test_files) > 0, 'No test files are provided.' brc_data = Dataset(args.max_p_len, args.max_q_len, test_files=args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = BiDAFModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id( vocab.pad_token), shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix='test.predicted')
def evaluate(args): """ evaluate the trained model on dev files """ logger = logging.getLogger(args.algo) logger.info('Load data_set and vocab...') # with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: # vocab = pickle.load(fin) data_dir = '/home/home1/dmyan/codes/bilm-tf/bilm/data/' vocab_file = data_dir + 'vocab.txt' batcher = TokenBatcher(vocab_file) data = Dataset(test_files=args.test_files, max_p_length=args.max_p_len, max_q_length=args.max_q_len) logger.info('Converting text into ids...') data.convert_to_ids(batcher) logger.info('Initialize the model...') if args.algo.startswith("BIDAF"): model = BiDAFModel(args) elif args.algo.startswith("R-net"): model = RNETModel(args) model.restore(model_dir=args.model_dir + args.algo, model_prefix=args.algo) #logger.info("Load dev dataset...") #model.dev_content_answer(args.dev_files) logger.info('Testing the model...') eval_batches = data.get_batches("test", args.batch_size, 0, shuffle=False) eval_loss, bleu_rouge = model.evaluate(eval_batches, result_dir=args.result_dir, result_prefix="test.predicted") logger.info("Test loss {}".format(eval_loss)) logger.info("Test result: {}".format(bleu_rouge)) logger.info('Done with model Testing!')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger(args.algo) logger.info('Load data_set and vocab...') # with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: # vocab = pickle.load(fin) data_dir = '/home/home1/dmyan/codes/bilm-tf/bilm/data/' vocab_file = data_dir + 'vocab.txt' batcher = TokenBatcher(vocab_file) data = Dataset(train_files=args.train_files, dev_files=args.dev_files, max_p_length=args.max_p_len, max_q_length=args.max_q_len) logger.info('Converting text into ids...') data.convert_to_ids(batcher) logger.info('Initialize the model...') if args.algo.startswith("BIDAF"): model = BiDAFModel(args) elif args.algo.startswith("R-net"): model = RNETModel(args) elif args.algo.startswith("QANET"): model = QANetModel(args) #model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info("Load dev dataset...") model.dev_content_answer(args.dev_files) logger.info('Training the model...') model.train(data, args.epochs, args.batch_size, save_dir=args.model_dir + args.algo, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')