Пример #1
0
def evaluate(args):
    """
    Evaluate the trained model on the dev files.

    Loads the pickled vocab from ``args.vocab_dir``, builds a Dataset over
    ``args.dev_files``, restores the BiDAF model from ``args.model_dir``,
    and logs the dev loss and BLEU/ROUGE scores.  Predicted answers are
    written under ``args.result_dir`` with prefix ``dev.predicted``.
    """
    logger = logging.getLogger("BiDAF")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = Dataset(args.max_p_num,
                       args.max_p_len,
                       args.max_q_len,
                       dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = BiDAFModel(vocab, args)

    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.get_batches('dev',
                                       args.batch_size,
                                       pad_id=vocab.get_id(vocab.pad_token),
                                       shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches,
                                                 result_dir=args.result_dir,
                                                 result_prefix='dev.predicted')
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info('Loss on dev set: %s', dev_loss)
    logger.info('Result on dev set: %s', dev_bleu_rouge)
    # The original wrapped args.result_dir in a single-argument
    # os.path.join(), which is a no-op; use the directory directly.
    logger.info('Predicted answers are saved to %s', args.result_dir)
Пример #2
0
def train(args):
    """
    Train the reading comprehension model end to end.

    Restores any existing checkpoint from ``args.model_dir`` before
    training on ``args.train_files`` with evaluation on ``args.dev_files``.
    """
    logger = logging.getLogger("BiDAF")
    logger.info('Load data_set and vocab...')

    vocab_path = os.path.join(args.vocab_dir, 'vocab.data')
    with open(vocab_path, 'rb') as fin:
        vocabulary = pickle.load(fin)

    dataset = Dataset(train_files=args.train_files,
                      dev_files=args.dev_files,
                      max_p_length=args.max_p_len,
                      max_q_length=args.max_q_len)
    logger.info('Converting text into ids...')
    dataset.convert_to_ids(vocabulary)

    logger.info('Initialize the model...')
    rc_model = BiDAFModel(vocabulary, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)

    logger.info("Load dev dataset...")
    rc_model.dev_content_answer(args.dev_files)

    logger.info('Training the model...')
    rc_model.train(dataset,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Пример #3
0
def interactive(args):
    """
    Interactive Q&A loop: read a passage and a question from stdin,
    run the restored model, and print the predicted answer span.
    Typing 'exit' at either prompt terminates the process.
    """
    logger = logging.getLogger("BiDAF")
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)

    logger.info('Restoring the model...')
    model = BiDAFModel(vocab, args)
    model.restore(model_dir=args.model_dir, model_prefix=args.algo)

    separator = '\n==========================================================================\n'
    while True:
        passage = input('输入原文:\n')
        if passage == 'exit':
            exit(0)

        query = input('\n输入问题:\n')
        if query == 'exit':
            exit(0)

        # Segment with jieba, then map tokens to ids, truncating to the
        # configured maximum passage/question lengths.
        passage_tokens = ' '.join(jieba.cut(passage)).split()
        query_tokens = ' '.join(jieba.cut(query)).split()
        passage_ids = vocab.convert_to_ids(passage_tokens)[:args.max_p_len]
        query_ids = vocab.convert_to_ids(query_tokens)[:args.max_q_len]

        # Single-sample batch; start_id/end_id are placeholders for inference.
        batch = {
            'question_ids': [query_ids],
            'question_length': [len(query_ids)],
            'content_ids': [passage_ids],
            'content_length': [len(passage_ids)],
            'start_id': [0],
            'end_id': [0],
        }

        start, end = model.getAnswer(batch)
        print(separator)
        print('答案是 :', ''.join(passage_tokens[start:end + 1]))
        print(separator)
Пример #4
0
def predict(args):
    """
    Generate answer predictions for the test files.

    Restores the model and writes its predictions for ``args.test_files``
    under ``args.result_dir`` with prefix ``test.predicted``.
    """
    logger = logging.getLogger("BiDAF")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'

    dataset = Dataset(args.max_p_len,
                      args.max_q_len,
                      test_files=args.test_files)
    logger.info('Converting text into ids...')
    dataset.convert_to_ids(vocab)

    logger.info('Restoring the model...')
    model = BiDAFModel(vocab, args)
    model.restore(model_dir=args.model_dir, model_prefix=args.algo)

    logger.info('Predicting answers for test set...')
    pad = vocab.get_id(vocab.pad_token)
    batches = dataset.gen_mini_batches('test',
                                       args.batch_size,
                                       pad_id=pad,
                                       shuffle=False)
    model.evaluate(batches,
                   result_dir=args.result_dir,
                   result_prefix='test.predicted')
Пример #5
0
def evaluate(args):
    """
    Evaluate the trained model on the test files.

    Builds a Dataset over ``args.test_files``, converts text to ids via a
    TokenBatcher, restores the model selected by ``args.algo``, and logs
    the test loss and BLEU/ROUGE scores.  Predictions are written under
    ``args.result_dir`` with prefix ``test.predicted``.

    Raises:
        ValueError: if ``args.algo`` does not match a known model family.
    """
    logger = logging.getLogger(args.algo)
    logger.info('Load data_set and vocab...')

    # NOTE(review): hard-coded vocab location — should come from args
    # (e.g. args.vocab_dir) so the script runs outside this one machine.
    data_dir = '/home/home1/dmyan/codes/bilm-tf/bilm/data/'
    vocab_file = data_dir + 'vocab.txt'
    batcher = TokenBatcher(vocab_file)

    data = Dataset(test_files=args.test_files,
                   max_p_length=args.max_p_len,
                   max_q_length=args.max_q_len)
    logger.info('Converting text into ids...')
    data.convert_to_ids(batcher)
    logger.info('Initialize the model...')
    if args.algo.startswith("BIDAF"):
        model = BiDAFModel(args)
    elif args.algo.startswith("R-net"):
        model = RNETModel(args)
    else:
        # Previously an unrecognized algo fell through and crashed later
        # with NameError on `model`; fail fast with a clear message.
        raise ValueError('Unsupported algorithm: {}'.format(args.algo))
    model.restore(model_dir=args.model_dir + args.algo, model_prefix=args.algo)
    logger.info('Testing the model...')
    eval_batches = data.get_batches("test", args.batch_size, 0, shuffle=False)
    eval_loss, bleu_rouge = model.evaluate(eval_batches,
                                           result_dir=args.result_dir,
                                           result_prefix="test.predicted")
    logger.info("Test loss {}".format(eval_loss))
    logger.info("Test result: {}".format(bleu_rouge))
    logger.info('Done with model Testing!')
Пример #6
0
def train(args):
    """
    Train the reading comprehension model selected by ``args.algo``.

    Builds a Dataset over ``args.train_files``/``args.dev_files``,
    converts text to ids via a TokenBatcher, instantiates the model
    matching ``args.algo`` (BIDAF / R-net / QANET), and trains it,
    saving checkpoints under ``args.model_dir + args.algo``.

    Raises:
        ValueError: if ``args.algo`` does not match a known model family.
    """
    logger = logging.getLogger(args.algo)
    logger.info('Load data_set and vocab...')

    # NOTE(review): hard-coded vocab location — should come from args
    # (e.g. args.vocab_dir) so the script runs outside this one machine.
    data_dir = '/home/home1/dmyan/codes/bilm-tf/bilm/data/'
    vocab_file = data_dir + 'vocab.txt'
    batcher = TokenBatcher(vocab_file)

    data = Dataset(train_files=args.train_files,
                   dev_files=args.dev_files,
                   max_p_length=args.max_p_len,
                   max_q_length=args.max_q_len)
    logger.info('Converting text into ids...')
    data.convert_to_ids(batcher)
    logger.info('Initialize the model...')
    if args.algo.startswith("BIDAF"):
        model = BiDAFModel(args)
    elif args.algo.startswith("R-net"):
        model = RNETModel(args)
    elif args.algo.startswith("QANET"):
        model = QANetModel(args)
    else:
        # Previously an unrecognized algo fell through and crashed later
        # with NameError on `model`; fail fast with a clear message.
        raise ValueError('Unsupported algorithm: {}'.format(args.algo))
    logger.info("Load dev dataset...")
    model.dev_content_answer(args.dev_files)
    logger.info('Training the model...')
    model.train(data,
                args.epochs,
                args.batch_size,
                save_dir=args.model_dir + args.algo,
                save_prefix=args.algo,
                dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')