def train(args): """ 训练阅读理解模型 """ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger("brc") file_handler = logging.FileHandler(args.log_path) file_handler.setLevel(logging.INFO) file_handler.setFormatter(formatter) logger.addHandler(file_handler) console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_handler.setFormatter(formatter) logger.addHandler(console_handler) logger.info(args) logger.info('加载数据集和词汇表...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) brc_data = BRCDataset(args.max_p_len, args.max_q_len, args.train_files, args.dev_files) logger.info('词语转化为id序列...') brc_data.convert_to_ids(vocab) logger.info('初始化模型...') rc_model = RCModel(vocab, args) logger.info('训练模型...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('训练完成!')
def evaluate(args): """ 对训练好的模型进行验证 """ logger = logging.getLogger("brc") logger.info('加wudi...') logger.info('加载数据集和词汇表...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.dev_files) > 0, '找不到验证文件.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files) logger.info('把文本转化为id序列...') brc_data.convert_to_ids(vocab) logger.info('重载模型...') rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('验证模型...') dev_batches = brc_data.gen_mini_batches('dev', args.batch_size, pad_id=vocab.get_id( vocab.pad_token), shuffle=False) dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted') logger.info('验证集上的损失为: {}'.format(dev_loss)) logger.info('验证集的结果: {}'.format(dev_bleu_rouge)) logger.info('预测的答案证保存到 {}'.format(os.path.join(args.result_dir)))
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') # 加载 vocab对象 ,包括 token2id id2token 以及其它方法 with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.train_files, args.dev_files) # brc_data.save_set_file(brc_data.dev_set, './save_sets', 'dev_set') # brc_data.save_set_file(brc_data.test_set, './save_sets', 'test_set') # brc_data.save_set_file(brc_data.train_set, './save_sets', 'train_set') logger.info('Converting text into ids...') # [self.train_set, self.dev_set, self.test_set] 原始数据 转为id形式 brc_data.convert_to_ids(vocab) logger.info('Initialize the model...') rc_model = RCModel(vocab, args) # 加载上次保存的模型 # rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) # **************************************************************** logger.info('Training the model...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") # 加载数据集 和 辞典(prepare保存的) logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) # pickle python的标准模块 --prepare运行时vocab的对象信息读取 brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.train_files, args.dev_files) # 最大 文章数,文章长度,问题长度, # train时候只有训练文件,验证文件 # 利用vocab 把brc_data 转换 成 id logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) # 把原始数据的问题和文章的单词转换成辞典保存的id # 初始化神经网络 logger.info('Initialize the model...') rc_model = RCModel(vocab, args) logger.info('Training the model...') """ Train the model with data Args: data: the BRCDataset class implemented in dataset.py epochs: number of training epochs batch_size: save_dir: the directory to save the model save_prefix: the prefix indicating the model type dropout_keep_prob: float value indicating dropout keep probability evaluate: whether to evaluate the model on test set after each epoch """ rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def evaluate(args): """ evaluate the trained model on dev files """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') # with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: with open(args.vocab_path, 'rb') as fin: vocab = pickle.load(fin) assert len(args.dev_files) > 0, 'No dev files are provided.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Evaluating the model on dev set...') dev_batches = brc_data.gen_mini_batches('dev', args.batch_size, pad_id=vocab.get_id( vocab.pad_token), shuffle=False) dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted') logger.info('Loss on dev set: {}'.format(dev_loss)) logger.info('Result on dev set: {}'.format(dev_bleu_rouge)) logger.info('Predicted answers are saved to {}'.format( os.path.join(args.result_dir)))
def predict(args): """ predicts answers for test files """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.test_files) > 0, 'No test files are provided.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, test_files=args.test_files, dataset="Dureader") logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id( vocab.pad_token), shuffle=False) test_loss, test_bleu_rouge = rc_model.evaluate( test_batches, result_dir=args.result_dir, result_prefix='test.predicted') logger.info('测试集上的损失为: {}'.format(test_loss))
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) # logger.info('Assigning embeddings...') # vocab.embed_dim = args.embed_size # vocab.load_pretrained_embeddings(args.embedding_path) logger.info('Vocabulary %s' % vocab.size()) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, vocab, args.train_files, args.dev_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Initialize the model...') rc_model = RCModel(vocab, args) # rc_model = MTRCModel(vocab, args) logger.info('Training the model...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.train_files, args.dev_files, data_num=args.data_num) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Initialize the model...') rc_model = RCModel(vocab, args) logger.info('Training the model...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.experiment, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) print("vocab.size() = ", vocab.size()) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.max_word_len, args.train_files, args.dev_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab, args.use_char_level == 'true') logger.info('Initialize the model...') rc_model = RCModel(vocab, args) if args.retrain == 'true': rc_model.restore(model_dir=args.model_restore_dir, model_prefix=args.algo_restore) rc_model.finalize() logger.info('Training the model...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def train(args): """ 训练阅读理解模型 """ logger = logging.getLogger("brc") logger.info('加载数据集和词汇表...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.train_files, args.dev_files) logger.info('词语转化为id序列...') brc_data.convert_to_ids(vocab) logger.info('初始化模型...') rc_model = RCModel(vocab, args) logger.info('训练模型...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('训练完成!')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.train_files, args.dev_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Initialize the model...') rc_model = RCModel(vocab, args) logger.info('Training the model...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def predict(args): """ predicts answers for test files """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.test_files) > 0, 'No test files are provided.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, test_files=args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id(vocab.pad_token), shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix='test.predicted')
def predict(args): """ predicts answers for test files """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) logger.info('Loading Pretrained Word Embedding') # vocab.embed_dim = None # vocab.load_pretrained_embeddings(args.embedding_path) assert len(args.test_files) > 0, 'No test files are provided.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, vocab, test_files=args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = RCModel(vocab, args) # rc_model = MTRCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id( vocab.pad_token), shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix='test.predicted')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) if args.word2vec_path: logger.info('learn_word_embedding:{}'.format(args.learn_word_embedding)) logger.info('loadding %s \n' % args.word2vec_path) word2vec = gensim.models.Word2Vec.load(args.word2vec_path) vocab.load_pretrained_embeddings_from_w2v(word2vec.wv) logger.info('load pretrained embedding from %s done\n' % args.word2vec_path) if args.use_char_embed: with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin: char_vocab = pickle.load(fin) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.train_files, args.dev_files) steps_per_epoch = brc_data.size('train') // args.batch_size args.decay_steps = args.decay_epochs * steps_per_epoch logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) if args.use_char_embed: logger.info('Converting text into char ids...') brc_data.convert_to_char_ids(char_vocab) logger.info('Binding char_vocab to args to pass to RCModel') args.char_vocab = char_vocab RCModel = choose_model_by_gpu_setting(args) logger.info('Initialize the model...') rc_model = RCModel(vocab, args) logger.info('Training the model...{}'.format(RCModel.__name__)) rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(args.vocab_dir + '/vocab.data', 'rb') as fin: vocab = pickle.load(fin) with open(args.vocab_dir + '/vocab.data', 'rb') as fin: vocab1 = pickle.load(fin) with open(args.vocab_dir + '/vocab.data', 'rb') as fin: vocab2 = pickle.load(fin) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.max_a_len, args.train_files, args.dev_files, args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab, vocab1, vocab2) logger.info('Initialize the model...') rc_model = RCModel(vocab, vocab1, args) # rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Training the model...') if args.train_as: rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def predict(args): """ 预测测试文件的答案 """ logger = logging.getLogger("brc") logger.info('加载数据集和词汇表...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.test_files) > 0, '找不到测试文件.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, test_files=args.test_files) logger.info('把文本转化为id序列...') brc_data.convert_to_ids(vocab) logger.info('重载模型...') rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('预测测试集的答案...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id( vocab.pad_token), shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix='test.predicted')
def predict(args): """ predicts answers for test files """ if not os.path.exists(args.result_dir): os.makedirs(args.result_dir) logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) with open(os.path.join(args.vocab_dir, 'tar_vocab.data'), 'rb') as fin: vocab1 = pickle.load(fin) assert len(args.test_files) > 0, 'No test files are provided.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,args.max_a_len, test_files=args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab, vocab1,vocab1) logger.info('Restoring the model...') rc_model = RCModel(vocab,vocab1, args) rc_model.restore(model_dir=args.model_dir , model_prefix=args.algo ) logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id(vocab.pad_token), shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix=args.result_prefix)
def predict(args): """ predicts answers for test files """ logger = logging.getLogger("Military AI") logger.info('Load data_set and vocab...') mai_data = MilitaryAiDataset(args.train_files, args.train_raw_files, args.test_files, args.test_raw_files, args.char_embed_file, args.token_embed_file, args.elmo_dict_file, args.elmo_embed_file, char_min_cnt=1, token_min_cnt=3) logger.info('Assigning embeddings...') if not args.use_embe: mai_data.token_vocab.randomly_init_embeddings(args.embed_size) mai_data.char_vocab.randomly_init_embeddings(args.embed_size) logger.info('Restoring the model...') rc_model = RCModel(mai_data.char_vocab, mai_data.token_vocab, mai_data.flag_vocab, mai_data.elmo_vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo + args.suffix) logger.info('Predicting answers for test set...') test_batches = mai_data.gen_mini_batches('test', args.batch_size, shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix='test.predicted')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Loading vocab...') with open(os.path.join(args.vocab_dir, 'vocab.pkl'), 'rb') as fin: vocab = pickle.load(fin) fin.close() pad_id = vocab.get_id(vocab.pad_token) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.prepared_dir, args.train_files, args.dev_files, args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) g = tf.Graph() with g.as_default(): rc_model = RCModel(vocab.embeddings, pad_id, args) del vocab # Train with tf.name_scope("Train"): logger.info('Training the model...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.result_dir, save_prefix='test.predicted', dropout_keep_prob=args.dropout_keep_prob) tf.summary.FileWriter(args.summary_dir, g).close() with tf.name_scope('Valid'): assert len(args.dev_files) > 0, 'No dev files are provided.' logger.info('Evaluating the model on dev set...') dev_batches = brc_data.gen_mini_batches('dev', args.batch_size, pad_id=pad_id, shuffle=False) dev_loss, dev_bleu_rouge = rc_model.evaluate( dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted') logger.info('Loss on dev set: {}'.format(dev_loss)) logger.info('Result on dev set: {}'.format(dev_bleu_rouge)) logger.info('Predicted answers are saved to {}'.format( os.path.join(args.result_dir))) with tf.name_scope('Test'): assert len(args.test_files) > 0, 'No test files are provided.' logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=pad_id, shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix='test.predicted')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.train_files, args.dev_files) # 此函数中没有test_files logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Initialize the model...') rc_model = RCModel(vocab, args) logger.info('Training the model...') t = str(time.time()) save_dir = os.path.join(args.model_dir, t) logger.info('model save dir:{}'.format(save_dir)) rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=save_dir, save_prefix='BERT-'+str(args.algo)+'-'+str(args.epochs)+'-'+str(args.batch_size), dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def evaluate(args): """ evaluate the trained model on dev files """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.dev_files) > 0, 'No dev files are provided.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Evaluating the model on dev set...') dev_batches = brc_data.gen_mini_batches('dev', args.batch_size, pad_id=vocab.get_id(vocab.pad_token), shuffle=False) dev_loss, dev_bleu_rouge = rc_model.evaluate( dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted') logger.info('Loss on dev set: {}'.format(dev_loss)) logger.info('Result on dev set: {}'.format(dev_bleu_rouge)) logger.info('Predicted answers are saved to {}'.format(os.path.join(args.result_dir)))
def evaluate(args): """ evaluate the trained model on dev files """ logger = logging.getLogger("Military AI") logger.info('Load data_set and vocab...') mai_data = MilitaryAiDataset(args.train_files, args.train_raw_files, args.test_files, args.test_raw_files, args.char_embed_file, args.token_embed_file, args.elmo_dict_file, args.elmo_embed_file, char_min_cnt=1, token_min_cnt=3) logger.info('Assigning embeddings...') if not args.use_embe: mai_data.token_vocab.randomly_init_embeddings(args.embed_size) mai_data.char_vocab.randomly_init_embeddings(args.embed_size) logger.info('Restoring the model...') rc_model = RCModel(mai_data.char_vocab, mai_data.token_vocab, mai_data.flag_vocab, mai_data.elmo_vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo + args.suffix) logger.info('Evaluating the model on dev set...') dev_batches = mai_data.gen_mini_batches('dev', args.batch_size, shuffle=False) dev_loss, dev_main_loss, dev_bleu_rouge = rc_model.evaluate( dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted') logger.info('Loss on dev set: {}'.format(dev_main_loss)) logger.info('Result on dev set: {}'.format(dev_bleu_rouge)) logger.info('Predicted answers are saved to {}'.format( os.path.join(args.result_dir)))
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("mrc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) brc_data = MRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.max_s_len, args.train_files, args.dev_files, vocab=vocab) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Initialize the model...') rc_model = RCModel(vocab, args) logger.info('Training the model...') if args.algo == 'MCST': logger.info('Use MCST Model to train...') rc_model = MCSTmodel(vocab, args) logger.info('Training MCST model...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob) else: rc_model = RCModel(vocab, args) logger.info('Training the model...') rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout_keep_prob=args.dropout_keep_prob)
def predict(args): """ predicts answers for test files """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.test_files) > 0, 'No test files are provided.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, args.max_word_len, test_files=args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) rc_model.finalize() # 增加完所有操作后采用sess.graph.finalize() # 来使得整个graph变为只读的 # 注意:tf.train.Saver() # 也算是往graph中添加node, 所以也必须放在finilize前 # 但是,,tf.train.Saver() # 只会存储 # 在该Saver声明时已经存在的变量!!! logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id( vocab.pad_token), shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix='test.predicted')
def train(args): """ trains the reading comprehension model """ logger = logging.getLogger("Military AI") logger.info('Load data_set and vocab...') mai_data = MilitaryAiDataset(args.train_files, args.train_raw_files, args.test_files, args.test_raw_files, args.char_embed_file, args.token_embed_file, args.elmo_dict_file, args.elmo_embed_file, char_min_cnt=1, token_min_cnt=3) logger.info('Assigning embeddings...') if not args.use_embe: mai_data.token_vocab.randomly_init_embeddings(args.embed_size) mai_data.char_vocab.randomly_init_embeddings(args.embed_size) logger.info('Initialize the model...') rc_model = RCModel(mai_data.char_vocab, mai_data.token_vocab, mai_data.flag_vocab, mai_data.elmo_vocab, args) if args.is_restore or args.restore_suffix: restore_prefix = args.algo + args.suffix if args.restore_suffix: restore_prefix = args.algo + args.restore_suffix rc_model.restore(model_dir=args.model_dir, model_prefix=restore_prefix) logger.info('Training the model...') rc_model.train(mai_data, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo + args.suffix, dropout_keep_prob=args.dropout_keep_prob) logger.info('Done with model training!')
def predict(args): """ predicts answers for test files """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: vocab = pickle.load(fin) assert len(args.test_files) > 0, 'No test files are provided.' brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, test_files=args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') steps_per_epoch = brc_data.size('train') // args.batch_size args.decay_steps = args.decay_epochs * steps_per_epoch RCModel = choose_model_by_gpu_setting(args) rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id(vocab.pad_token), shuffle=False) rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix='test.predicted')
def predict(args): """ predicts answers for test files """ logger = logging.getLogger("brc") logger.info('Load data_set and vocab...') with open(args.vocab_path, 'rb') as fin: vocab = pickle.load(fin) assert len(args.test_files) > 0, 'No test files are provided.' brc_data = BRCDataset(args.algo, args.max_p_num, args.max_p_len, args.max_q_len, args.max_a_len, test_files=args.test_files) logger.info('Converting text into ids...') brc_data.convert_to_ids(vocab) logger.info('Restoring the model...') rc_model = RCModel(vocab, args) rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo) logger.info('Predicting answers for test set...') test_batches = brc_data.gen_mini_batches('test', args.batch_size, pad_id=vocab.get_id( vocab.pad_token), shuffle=False) if args.algo == 'YESNO': qa_resultPath = args.test_files[0] #只会有一个文件! (filepath, tempfilename) = os.path.split(qa_resultPath) (qarst_filename, extension) = os.path.splitext(tempfilename) result_prefix = qarst_filename else: result_prefix = 'test.predicted.qa' rc_model.evaluate(test_batches, result_dir=args.result_dir, result_prefix=result_prefix) if args.algo == 'YESNO': #将YESNO结果合并入QA结果 qa_resultPath = args.test_files[0] #只会有一个文件! yesno_resultPath = args.result_dir + '/' + result_prefix + '.YESNO.json' out_file_path = args.result_dir + '/' + result_prefix + '.134.class.' + str( args.run_id) + '.json' #首先载入YESNO部分的预测结果 yesno_records = {} with open(yesno_resultPath, 'r') as f_in: for line in f_in: sample = json.loads(line) yesno_records[sample['question_id']] = line total_rst_num = 0 with open(qa_resultPath, 'r') as f_in: with open(out_file_path, 'w') as f_out: for line in f_in: total_rst_num += 1 sample = json.loads(line) if sample['question_id'] in yesno_records: line = yesno_records[sample['question_id']] f_out.write(line) print('total rst num : ', total_rst_num) print('yes no label combining done!')