Example #1
def evaluate(args):
    """
    evaluate the trained model on dev files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'rb') as fin:
        word_vocab = pickle.load(fin)
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
        char_vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(word_vocab, char_vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(word_vocab, char_vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
                                            pad_id=word_vocab.get_id(word_vocab.pad_token), shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
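The function above only reads plain attributes from `args`, so it can be exercised without a full command-line parser. The sketch below builds a minimal namespace by hand and calls evaluate(); every path, length, and the 'BIDAF' prefix are illustrative assumptions, and a real RCModel will typically expect additional hyperparameter attributes that are omitted here.

from types import SimpleNamespace

# Minimal driver sketch: attribute names match what evaluate() reads above,
# but all values are assumptions chosen for illustration.
args = SimpleNamespace(
    vocab_dir='data/vocab',              # contains word_vocab.data and char_vocab.data
    model_dir='data/models',             # directory the trained model was saved to
    result_dir='data/results',           # where dev.predicted results are written
    dev_files=['data/devset/dev.json'],  # at least one dev file is required
    algo='BIDAF',                        # model prefix used when saving/restoring
    max_p_num=5, max_p_len=500, max_q_len=60,
    batch_size=32,
)
evaluate(args)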
Example #2
def demo(args):
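    """
    restores the trained model and runs a demo with it
    """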
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'rb') as fin:
        word_vocab = pickle.load(fin)
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
        char_vocab = pickle.load(fin)

    logger.info('Restoring the model...')
    rc_model = RCModel(word_vocab, char_vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Running the demo...')
    Demo(rc_model, args)
Example #3
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'rb') as fin:
        word_vocab = pickle.load(fin)
    print("load word dict finished")
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
        char_vocab = pickle.load(fin)
    print("load char dict finished")
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(word_vocab, char_vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(word_vocab, char_vocab, args)
    logger.info('Training the model...')
    rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
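These entry points would typically be selected from a single run script. The dispatcher below is a hedged sketch of such wiring; the flag names and defaults are assumptions, and a real parser would also define the data and model hyperparameters (max_p_num, max_p_len, dropout_keep_prob, and so on) that the functions above read from args.

import argparse

# Sketch of a run-script entry point; flag names and defaults are assumptions.
def parse_args():
    parser = argparse.ArgumentParser('Reading comprehension on the BRC dataset')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate on dev files')
    parser.add_argument('--predict', action='store_true', help='predict on test files')
    parser.add_argument('--vocab_dir', default='data/vocab')
    parser.add_argument('--model_dir', default='data/models')
    parser.add_argument('--result_dir', default='data/results')
    # ...plus the data/model hyperparameters read by the functions above
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    if args.train:
        train(args)
    if args.evaluate:
        evaluate(args)
    if args.predict:
        predict(args)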
Example #4
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'rb') as fin:
        word_vocab = pickle.load(fin)
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
        char_vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(word_vocab, char_vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(word_vocab, char_vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=word_vocab.get_id(word_vocab.pad_token), shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted')
Example #5
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              num_workers=num_workers,
                              collate_fn=transform.batchify,
                              shuffle=True)

    dev_loader = DataLoader(dataset=dev_dataset,
                            batch_size=BATCH_SIZE,
                            num_workers=num_workers,
                            collate_fn=transform.batchify)

    model_params = cur_cfg.model_params
    model_params['c_max_len'] = 64

    model = RCModel(model_params, embed_lists, mode=mode, emb_trainable=True)
    model = model.cuda()
    if mode == MODE_MRT:
        criterion_main = RougeLoss().cuda()
    elif mode == MODE_OBJ:
        criterion_main = ObjDetectionLoss(B=transform.B,
                                          S=transform.S,
                                          dynamic_score=False).cuda()
    else:
        criterion_main = PointerLoss().cuda()

    criterion_extra = nn.MultiLabelSoftMarginLoss().cuda()

    param_to_update = list(
        filter(lambda p: p.requires_grad, model.parameters()))
    model_param_num = 0
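The excerpt stops right after gathering the trainable parameters and zeroing model_param_num; a typical continuation counts those parameters and hands them to an optimizer. The lines below are only a sketch of that next step, and the choice of Adam with lr=1e-3 is an assumption, not the project's configuration.

# Sketch of a typical continuation (not the project's code): count the
# trainable parameters collected above and build an optimizer over them.
# torch is assumed to be imported in the surrounding file.
for p in param_to_update:
    model_param_num += p.numel()
print('trainable parameters: %d' % model_param_num)

optimizer = torch.optim.Adam(param_to_update, lr=1e-3)  # optimizer choice is an assumption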
Example #6
File: test.py Project: eguilg/mrc
    transform = MaiIndexTransform(jieba_base_v, jieba_sgns_v, jieba_flag_v)

    test_data_source = MaiDirDataSource(testset_roots)

    test_loader = DataLoader(
        dataset=MaiDirDataset(test_data_source.data, transform),
        batch_sampler=MethodBasedBatchSampler(test_data_source.data,
                                              batch_size=32,
                                              shuffle=False),
        num_workers=mp.cpu_count(),
        collate_fn=transform.batchify)

    model_params = cur_cfg.model_params

    model = RCModel(model_params,
                    embed_lists_train)  #, normalize=(not use_mrt))
    print('loading model, ', model_path)

    state = torch.load(model_path)
    best_score = state['best_score']
    best_epoch = state['best_epoch']
    best_step = state['best_step']
    print('best_epoch:%2d, best_step:%5d, best_score:%.4f' %
          (best_epoch, best_step, best_score))
    model.load_state_dict(state['best_model_state'])
    model.reset_embeddings(embed_lists_test)
    model = model.cuda()
    model.eval()

    answer_prob_dict = {}
    answer_dict = {}
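The excerpt ends just before the loop that fills answer_prob_dict and answer_dict. The batch layout produced by transform.batchify is not visible here, but once a pointer-style model has produced per-position start and end probabilities for a passage, decoding the best span is a standard step. The helper below is a generic, self-contained sketch of that decoding, not the project's implementation; the max_len limit is an assumption.

import torch

def decode_best_span(start_probs, end_probs, max_len=30):
    """Generic span-decoding sketch: pick (start, end) maximizing
    start_probs[s] * end_probs[e] with s <= e < s + max_len."""
    scores = start_probs.unsqueeze(1) * end_probs.unsqueeze(0)   # [L, L] pair scores
    scores = torch.triu(scores)                                  # keep end >= start
    scores = torch.tril(scores, diagonal=max_len - 1)            # cap the span length
    idx = torch.argmax(scores).item()
    length = scores.size(1)
    return idx // length, idx % length

With such a helper, answer_dict could map each question id to the text of its decoded span, while answer_prob_dict keeps the raw probabilities.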