Exemplo n.º 1
0
def evaluate(args):
    """
    Evaluate the trained model on the dev files.

    Validates the file arguments, loads the pickled vocabulary, builds the
    dev data set, restores the saved model and runs its evaluation, then
    logs the loss, the evaluation result and the directory the predictions
    were written to.

    Args:
        args: parsed options object. Attributes read directly here:
            ``train_files``, ``dev_files``, ``test_files``, ``data_type``,
            ``vocab_dir``, ``vocab_file``, ``max_p_num``, ``max_p_len``,
            ``max_q_len``, ``badcase_sample_log_file``, ``use_oov2unk``,
            ``model_dir``, ``result_dir``, ``desc``, ``algo`` and
            ``batch_size``. The whole object is also handed to the model
            constructor.

    Raises:
        ValueError: if no dev files are provided, or if any train/dev/test
            file path does not contain ``args.data_type``.
    """
    logger = logging.getLogger("brc")

    # Validate the arguments before any expensive I/O. An explicit raise is
    # used instead of ``assert`` so the check survives ``python -O``.
    if not args.dev_files:
        raise ValueError('No dev files are provided.')

    # data_type easily gets out of sync with the data files, so check it here.
    for f in args.train_files + args.dev_files + args.test_files:
        if args.data_type not in f:
            raise ValueError('Inconsistency between data_type and files')

    logger.info('Load data_set and vocab...')
    vocab_path = os.path.join(args.vocab_dir, args.data_type, args.vocab_file)
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # vocab_path at vocabulary files produced by a trusted run.
    with open(vocab_path, 'rb') as fin:
        logger.info('load vocab from {}'.format(vocab_path))
        vocab = pickle.load(fin)

    brc_data = Dataset(args.max_p_num,
                       args.max_p_len,
                       args.max_q_len,
                       dev_files=args.dev_files,
                       badcase_sample_log_file=args.badcase_sample_log_file)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab, args.use_oov2unk)

    logger.info('Build the model...')
    rc_model = MultiAnsModel(vocab, args)
    model_dir = os.path.join(args.model_dir, args.data_type)
    model_prefix = args.desc + args.algo
    logger.info('restore model from {}, with prefix {}'.format(
        model_dir, model_prefix))
    rc_model.restore(model_dir=model_dir, model_prefix=model_prefix)

    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev',
                                            args.batch_size,
                                            pad_id=vocab.get_id(
                                                vocab.pad_token),
                                            shuffle=False)
    # Number of batches = ceil(dev_size / batch_size): a trailing partial
    # batch counts as one more batch.
    full_batches, remainder = divmod(brc_data.get_data_length('dev'),
                                     args.batch_size)
    total_batch_count = full_batches + int(remainder != 0)

    result_dir = os.path.join(args.result_dir, args.data_type)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(total_batch_count,
                                                 dev_batches,
                                                 result_dir=result_dir,
                                                 result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    # Report the directory that was actually passed to the model, i.e.
    # including the data_type sub-directory.
    logger.info('Predicted answers are saved to {}'.format(result_dir))
Exemplo n.º 2
0
def train(args):
    """
    Train the reading comprehension model.

    Steps, in order: make sure the per-data_type model/result/summary
    directories exist (creating missing ones), check that every
    train/dev/test file path contains ``args.data_type``, load the pickled
    vocabulary, build the data set and convert it to ids, then construct the
    model and run its training routine.

    Args:
        args: parsed options object; it is read here for paths, data-set
            limits and training settings, and is also passed to the model
            constructor.

    Raises:
        ValueError: if any train/dev/test file path does not contain
            ``args.data_type``.
    """
    log = logging.getLogger("brc")

    log.info('check the directories...')
    base_dirs = (args.model_dir, args.result_dir, args.summary_dir)
    for base_dir in base_dirs:
        target_dir = os.path.join(base_dir, args.data_type)
        if os.path.exists(target_dir):
            continue
        log.warning(
            "don't exist {} directory, so we create it!".format(target_dir))
        os.makedirs(target_dir)

    # data_type easily gets out of sync with the data files, so check it here.
    all_files = args.train_files + args.dev_files + args.test_files
    if any(args.data_type not in file_path for file_path in all_files):
        raise ValueError('Inconsistency between data_type and files')

    log.info('Load data_set and vocab...')
    vocab_file_path = os.path.join(args.vocab_dir, args.data_type,
                                   args.vocab_file)
    with open(vocab_file_path, 'rb') as vocab_handle:
        log.info('load vocab from {}'.format(vocab_file_path))
        vocabulary = pickle.load(vocab_handle)

    data_set = Dataset(
        args.max_p_num,
        args.max_p_len,
        args.max_q_len,
        args.max_a_len,
        train_answer_len_cut_bins=args.train_answer_len_cut_bins,
        train_files=args.train_files,
        dev_files=args.dev_files,
        badcase_sample_log_file=args.badcase_sample_log_file)
    log.info('Converting text into ids...')
    data_set.convert_to_ids(vocabulary, args.use_oov2unk)

    log.info('Initialize the model...')
    model = MultiAnsModel(vocabulary, args)

    log.info('Training the model...')
    checkpoint_dir = os.path.join(args.model_dir, args.data_type)
    checkpoint_prefix = args.desc + args.algo
    model.train_and_evaluate_several_batchly(
        data=data_set,
        epochs=args.epochs,
        batch_size=args.batch_size,
        evaluate_cnt_in_one_epoch=args.evaluate_cnt_in_one_epoch,
        save_dir=checkpoint_dir,
        save_prefix=checkpoint_prefix)
    log.info('Done with model training!')