import logging
import os
import pickle


def evaluate(args):
    """
    Evaluates the trained model on the dev files.
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'rb') as fin:
        word_vocab = pickle.load(fin)
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
        char_vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(word_vocab, char_vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(word_vocab, char_vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
                                            pad_id=word_vocab.get_id(word_vocab.pad_token),
                                            shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
def demo(args):
    """
    Restores the trained model and runs the interactive demo.
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'rb') as fin:
        word_vocab = pickle.load(fin)
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
        char_vocab = pickle.load(fin)
    logger.info('Restoring the model...')
    rc_model = RCModel(word_vocab, char_vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Starting the demo...')
    Demo(rc_model, args)
def train(args):
    """
    Trains the reading comprehension model.
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'rb') as fin:
        word_vocab = pickle.load(fin)
    logger.info('Loading word dict finished')
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
        char_vocab = pickle.load(fin)
    logger.info('Loading char dict finished')
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(word_vocab, char_vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(word_vocab, char_vocab, args)
    logger.info('Training the model...')
    rc_model.train(brc_data, args.epochs, args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
def predict(args):
    """
    Predicts answers for the test files.
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'rb') as fin:
        word_vocab = pickle.load(fin)
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
        char_vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(word_vocab, char_vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(word_vocab, char_vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=word_vocab.get_id(word_vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted')
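# A minimal sketch of how these entry points might be wired together behind an
# argparse CLI. The flag names and defaults below are assumptions for
# illustration, not this project's actual command-line interface.
def run():
    import argparse
    parser = argparse.ArgumentParser('Reading comprehension on the BRC dataset')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate on dev files')
    parser.add_argument('--predict', action='store_true', help='predict answers for test files')
    parser.add_argument('--demo', action='store_true', help='run the interactive demo')
    parser.add_argument('--algo', default='BIDAF', help='prefix used when saving/restoring the model')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--dropout_keep_prob', type=float, default=1.0)
    parser.add_argument('--max_p_num', type=int, default=5)
    parser.add_argument('--max_p_len', type=int, default=500)
    parser.add_argument('--max_q_len', type=int, default=60)
    parser.add_argument('--train_files', nargs='+', default=[])
    parser.add_argument('--dev_files', nargs='+', default=[])
    parser.add_argument('--test_files', nargs='+', default=[])
    parser.add_argument('--vocab_dir', default='data/vocab/')
    parser.add_argument('--model_dir', default='data/models/')
    parser.add_argument('--result_dir', default='data/results/')
    args = parser.parse_args()

    if args.train:
        train(args)
    if args.evaluate:
        evaluate(args)
    if args.predict:
        predict(args)
    if args.demo:
        demo(args)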
from torch.utils.data import DataLoader
import torch.nn as nn

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=num_workers,
                          collate_fn=transform.batchify,
                          shuffle=True)
dev_loader = DataLoader(dataset=dev_dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=num_workers,
                        collate_fn=transform.batchify)

model_params = cur_cfg.model_params
model_params['c_max_len'] = 64
model = RCModel(model_params, embed_lists, mode=mode, emb_trainable=True)
model = model.cuda()

# Select the main training objective according to the run mode.
if mode == MODE_MRT:
    criterion_main = RougeLoss().cuda()
elif mode == MODE_OBJ:
    criterion_main = ObjDetectionLoss(B=transform.B, S=transform.S,
                                      dynamic_score=False).cuda()
else:
    criterion_main = PointerLoss().cuda()
criterion_extra = nn.MultiLabelSoftMarginLoss().cuda()

# Only parameters with requires_grad=True are handed to the optimizer.
param_to_update = list(
    filter(lambda p: p.requires_grad, model.parameters()))
model_param_num = 0
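# A minimal sketch of what a pointer-style span loss such as PointerLoss could
# compute: the sum of cross-entropies of the predicted start and end
# distributions against the gold span boundaries. The repo's actual PointerLoss
# may differ; the tensor shapes here are assumptions.
import torch.nn.functional as F


class PointerLossSketch(nn.Module):
    def forward(self, start_logits, end_logits, start_target, end_target):
        # start_logits / end_logits: (batch, seq_len) unnormalized scores
        # start_target / end_target: (batch,) gold boundary indices
        start_loss = F.cross_entropy(start_logits, start_target)
        end_loss = F.cross_entropy(end_logits, end_target)
        return start_loss + end_loss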
import multiprocessing as mp

import torch

transform = MaiIndexTransform(jieba_base_v, jieba_sgns_v, jieba_flag_v)

test_data_source = MaiDirDataSource(testset_roots)
test_loader = DataLoader(
    dataset=MaiDirDataset(test_data_source.data, transform),
    batch_sampler=MethodBasedBatchSampler(test_data_source.data,
                                          batch_size=32, shuffle=False),
    num_workers=mp.cpu_count(),
    collate_fn=transform.batchify)

model_params = cur_cfg.model_params
model = RCModel(model_params, embed_lists_train)  # , normalize=(not use_mrt))

print('loading model, ', model_path)
state = torch.load(model_path)
best_score = state['best_score']
best_epoch = state['best_epoch']
best_step = state['best_step']
print('best_epoch:%2d, best_step:%5d, best_score:%.4f'
      % (best_epoch, best_step, best_score))
model.load_state_dict(state['best_model_state'])
# Swap the training-vocabulary embeddings for the test-set embeddings.
model.reset_embeddings(embed_lists_test)
model = model.cuda()
model.eval()

answer_prob_dict = {}
answer_dict = {}
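# A minimal sketch of the kind of inference loop that might follow, assuming
# the model returns per-token start/end probabilities. The batch layout
# (batch['inputs'], batch['qids']) and the output names are assumptions, not
# this repo's actual interface.
with torch.no_grad():
    for batch in test_loader:
        inputs, qids = batch['inputs'].cuda(), batch['qids']  # hypothetical layout
        start_probs, end_probs = model(inputs)
        starts = start_probs.argmax(dim=1)
        ends = end_probs.argmax(dim=1)
        for qid, s, e in zip(qids, starts.tolist(), ends.tolist()):
            answer_dict[qid] = (s, e)  # span boundaries; detokenization would follow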