import logging
import os
import pickle

# Dataset and RCModel are project-local classes; these module paths are an
# assumption based on common layouts for this kind of repo.
from dataset import Dataset
from rc_model import RCModel


def evaluate(args):
    """ evaluates the trained model on dev files """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = Dataset(args.max_p_num, args.max_p_len, args.max_q_len,
                       args.max_w_len, dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
                                            pad_id=vocab.get_id(vocab.pad_token),
                                            shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches,
                                                 result_dir=args.result_dir,
                                                 result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))

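# NOTE: a minimal sketch of the padding that gen_mini_batches is assumed to
# perform with pad_id; pad_sequences and its signature are hypothetical names
# for illustration, not part of Dataset's actual API.
def pad_sequences(token_id_lists, pad_id, max_len):
    """ pads (or truncates) each id sequence to max_len using pad_id """
    padded = []
    for ids in token_id_lists:
        ids = list(ids[:max_len])                # truncate long sequences
        ids += [pad_id] * (max_len - len(ids))   # right-pad short ones
        padded.append(ids)
    return padded
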
def evaluate_test(args):
    """ evaluates the trained model on test files """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    pro_data = Dataset(test_files=args.test_files)
    logger.info('Converting text into ids...')
    pro_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.restore_epoch)
    logger.info('Evaluating the model on test set...')
    test_batches = pro_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token),
                                             shuffle=False)
    test_loss, test_acc = rc_model.evaluate(test_batches,
                                            result_dir=args.result_dir,
                                            result_prefix='test.predicted')
    logger.info('Loss on test set: {}'.format(test_loss))
    logger.info('Accuracy on test set: {}'.format(test_acc))

def train(args):
    """ trains the reading comprehension model """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = Dataset(args.max_p_num, args.max_p_len, args.max_q_len,
                       args.max_w_len, args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    if args.restore:
        logger.info('Restoring the model...')
        rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
        # Seed the best-score tracker with the restored model's dev ROUGE-L,
        # so training only saves checkpoints that improve on it.
        logger.info('Evaluating the model on dev set...')
        dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
                                                pad_id=vocab.get_id(vocab.pad_token),
                                                shuffle=False)
        _, dev_bleu_rouge = rc_model.evaluate(dev_batches)
        rc_model.max_rouge_l = dev_bleu_rouge['ROUGE-L']
    logger.info('Training the model...')
    rc_model.train(brc_data, args.epochs, args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')

def debug(args):
    """ trains on a small batch of data for debugging """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    # Use the (small) test files as training data for a quick debugging run.
    rc_data = Dataset(train_files=args.test_files, dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    rc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    rc_model.train(rc_data, args.epochs, args.batch_size,
                   save_dir=args.model_dir,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')

def predict(args):
    """ predicts answers for test files """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = Dataset(args.max_p_num, args.max_p_len, args.max_q_len,
                       args.max_w_len, test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix='test.predicted', hack=True)

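# NOTE: a minimal sketch of how these entry points are typically dispatched
# from the command line; the flag names (--train, --evaluate, --predict,
# --debug) are assumptions, not confirmed by this file, and a real parser
# would also need to define the path/size arguments used above.
def run():
    """ parses command-line flags and calls the matching entry point """
    import argparse
    parser = argparse.ArgumentParser('Reading comprehension model runner')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate on dev files')
    parser.add_argument('--predict', action='store_true', help='predict answers for test files')
    parser.add_argument('--debug', action='store_true', help='run a small debugging job')
    args, _ = parser.parse_known_args()
    if args.train:
        train(args)
    elif args.evaluate:
        evaluate(args)
    elif args.predict:
        predict(args)
    elif args.debug:
        debug(args)
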
def getSoftmax(args):
    """ restores a trained model and writes out its softmax results """
    logger = logging.getLogger("brc")
    # Load the pretrained embeddings.
    logger.info('Load vocab and word embedding from text...')
    with open(args.pretrained_embedding_path, 'rb') as f:
        embedding = pickle.load(f)
    logger.info('Load vocab and char embedding from text...')
    with open(args.pretrained_char_embedding_path, 'rb') as f:
        char_embedding = pickle.load(f)
    # Load the data.
    logger.info('loading the data...')
    with open(args.train_data_path, 'rb') as f:
        train_data = pickle.load(f)
    logger.info('train data size: ' + str(len(train_data)))
    with open(args.dev_data_path, 'rb') as f:
        dev_data = pickle.load(f)
    with open(args.testa_data_path, 'rb') as f:
        test_data = pickle.load(f)
    logger.info('dev data size: ' + str(len(dev_data)))
    logger.info('testa data size: ' + str(len(test_data)))
    # Shuffling is currently disabled.
    # trainShuffleData = shuffle_data(train_data, 'pWordId')
    # trainShuffleData = train_data
    logger.info('Initialize the model...')
    rc_model = RCModel(args, embedding, char_embedding)
    logger.info("loading model from {}{}".format(args.model_dir,
                                                 args.algo + args.model_prefix))
    rc_model.restore(args.model_dir, args.algo + args.model_prefix)
    logger.info('Computing softmax results...')
    if args.softmax_mode == 'valid':
        rc_model.get_softmax_result(train_data=dev_data,
                                    batch_size=args.batch_size,
                                    dropout_keep_prob=args.dropout_keep_prob,
                                    outputPath=args.softmax_log_output_path)
    else:
        rc_model.get_softmax_result(train_data=test_data,
                                    batch_size=args.batch_size,
                                    dropout_keep_prob=args.dropout_keep_prob,
                                    outputPath=args.softmax_log_output_path)
    logger.info('Done with softmax computation!')

def train(args):
    """ trains the reading comprehension model """
    logger = logging.getLogger("brc")
    # Load the pretrained embeddings.
    logger.info('Load vocab and word embedding from text...')
    with open(args.pretrained_embedding_path, 'rb') as f:
        embedding = pickle.load(f)
    logger.info('Load vocab and char embedding from text...')
    with open(args.pretrained_char_embedding_path, 'rb') as f:
        char_embedding = pickle.load(f)
    # Load the data.
    logger.info('loading the data...')
    with open(args.train_data_path, 'rb') as f:
        train_data = pickle.load(f)
    logger.info('train data size: ' + str(len(train_data)))
    with open(args.dev_data_path, 'rb') as f:
        dev_data = pickle.load(f)
    logger.info('dev data size: ' + str(len(dev_data)))
    # Shuffle the data (shuffling is currently disabled).
    logger.info('shuffle the data...')
    # trainShuffleData = shuffle_data(train_data, 'pWordId')
    trainShuffleData = train_data
    logger.info('Initialize the model...')
    rc_model = RCModel(args, embedding, char_embedding)
    logger.info('Training the model...')
    rc_model.train(trainShuffleData, dev_data, args.epochs, args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo + args.model_prefix,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
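
# NOTE: the commented-out shuffle_data(train_data, 'pWordId') call above
# refers to a helper that is not defined in this file; a minimal sketch of
# what it might look like, assuming the data is a list of samples. The key
# argument is accepted only to match the call site and is unused here.
def shuffle_data(data, key=None):
    """ returns a shuffled copy of the data without mutating the original """
    import random
    shuffled = list(data)
    random.shuffle(shuffled)
    return shuffled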