def evaluate(args):
    """
    Evaluates the trained model on the dev files.
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    vocab_path = os.path.join(args.vocab_dir, args.data_type, args.vocab_file)
    with open(vocab_path, 'rb') as fin:
        logger.info('load vocab from {}'.format(vocab_path))
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    # data_type can easily get out of sync with the data files, so check here
    for f in args.train_files + args.dev_files + args.test_files:
        if args.data_type not in f:
            raise ValueError('Inconsistency between data_type and files')
    brc_data = Dataset(args.max_p_num, args.max_p_len, args.max_q_len,
                       dev_files=args.dev_files,
                       badcase_sample_log_file=args.badcase_sample_log_file)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab, args.use_oov2unk)
    logger.info('Build the model...')
    rc_model = MultiAnsModel(vocab, args)
    logger.info('restore model from {}, with prefix {}'.format(
        os.path.join(args.model_dir, args.data_type), args.desc + args.algo))
    rc_model.restore(model_dir=os.path.join(args.model_dir, args.data_type),
                     model_prefix=args.desc + args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches(
        'dev', args.batch_size,
        pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    total_batch_count = (brc_data.get_data_length('dev') // args.batch_size
                         + int(brc_data.get_data_length('dev') % args.batch_size != 0))
    dev_loss, dev_bleu_rouge = rc_model.evaluate(
        total_batch_count, dev_batches,
        result_dir=os.path.join(args.result_dir, args.data_type),
        result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(
        os.path.join(args.result_dir, args.data_type)))
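
# Note: the total_batch_count expression above is just ceiling division
# (floor quotient, plus one if there is a remainder). An equivalent and
# arguably clearer sketch, using only the stdlib:
#
#   import math
#   total_batch_count = math.ceil(
#       brc_data.get_data_length('dev') / args.batch_size)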
def train(args):
    """
    Trains the reading comprehension model.
    """
    logger = logging.getLogger("brc")
    logger.info('Check the directories...')
    for dir_path in [
            os.path.join(args.model_dir, args.data_type),
            os.path.join(args.result_dir, args.data_type),
            os.path.join(args.summary_dir, args.data_type)
    ]:
        if not os.path.exists(dir_path):
            logger.warning(
                'Directory {} does not exist, creating it.'.format(dir_path))
            os.makedirs(dir_path)
    # data_type can easily get out of sync with the data files, so check here
    for f in args.train_files + args.dev_files + args.test_files:
        if args.data_type not in f:
            raise ValueError('Inconsistency between data_type and files')
    logger.info('Load data_set and vocab...')
    vocab_path = os.path.join(args.vocab_dir, args.data_type, args.vocab_file)
    with open(vocab_path, 'rb') as fin:
        logger.info('load vocab from {}'.format(vocab_path))
        vocab = pickle.load(fin)
    brc_data = Dataset(
        args.max_p_num, args.max_p_len, args.max_q_len, args.max_a_len,
        train_answer_len_cut_bins=args.train_answer_len_cut_bins,
        train_files=args.train_files, dev_files=args.dev_files,
        badcase_sample_log_file=args.badcase_sample_log_file)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab, args.use_oov2unk)
    logger.info('Initialize the model...')
    rc_model = MultiAnsModel(vocab, args)
    logger.info('Training the model...')
    rc_model.train_and_evaluate_several_batchly(
        data=brc_data, epochs=args.epochs, batch_size=args.batch_size,
        evaluate_cnt_in_one_epoch=args.evaluate_cnt_in_one_epoch,
        save_dir=os.path.join(args.model_dir, args.data_type),
        save_prefix=args.desc + args.algo)
    logger.info('Done with model training!')
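
# A minimal entry-point sketch showing how train() and evaluate() might be
# wired to a command line. Every flag and default below is inferred from the
# attributes the two functions above access; the project's real CLI almost
# certainly defines more flags (model hyper-parameters consumed by
# MultiAnsModel, learning rate, etc.) and may use different names and
# defaults, so treat this as a hypothetical illustration only.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser('Reading comprehension runner (sketch)')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true',
                        help='evaluate the model on dev files')
    # Data / path flags (names inferred from the functions above).
    parser.add_argument('--data_type', default='search',
                        help='must appear in every data file path')
    parser.add_argument('--train_files', nargs='*', default=[])
    parser.add_argument('--dev_files', nargs='*', default=[])
    parser.add_argument('--test_files', nargs='*', default=[])
    parser.add_argument('--vocab_dir', default='vocab/')
    parser.add_argument('--vocab_file', default='vocab.data')
    parser.add_argument('--model_dir', default='models/')
    parser.add_argument('--result_dir', default='results/')
    parser.add_argument('--summary_dir', default='summary/')
    parser.add_argument('--badcase_sample_log_file', default=None)
    # Model / training flags (hypothetical defaults).
    parser.add_argument('--algo', default='BIDAF')
    parser.add_argument('--desc', default='')
    parser.add_argument('--max_p_num', type=int, default=5)
    parser.add_argument('--max_p_len', type=int, default=500)
    parser.add_argument('--max_q_len', type=int, default=60)
    parser.add_argument('--max_a_len', type=int, default=200)
    parser.add_argument('--train_answer_len_cut_bins', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--evaluate_cnt_in_one_epoch', type=int, default=0)
    parser.add_argument('--use_oov2unk', action='store_true')
    args = parser.parse_args()
    if args.train:
        train(args)
    if args.evaluate:
        evaluate(args)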