Example #1
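The three examples below are module-level functions and rely on imports that are not shown. A minimal sketch of the module header they appear to assume; the module paths for the project-specific pieces (BRCDataset, choose_model_by_gpu_setting and the RC model classes behind it) are assumptions and will differ per project layout:

# Standard-library and third-party imports used throughout the examples.
import os
import pickle
import logging

import gensim  # only needed when args.word2vec_path is set (Example #2)

# Project-specific imports; the module names here are assumptions.
from dataset import BRCDataset                            # hypothetical module path
from model_selection import choose_model_by_gpu_setting   # hypothetical module path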
def evaluate(args):
    """
    evaluate the trained model on dev files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files)
    steps_per_epoch = brc_data.size('dev') // args.batch_size
    args.decay_steps = args.decay_epochs * steps_per_epoch 
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    RCModel = choose_model_by_gpu_setting(args)
    rc_model = RCModel(vocab, args)
    logger.info('Restoring the model...{}'.format(RCModel.__name__))
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
                                            pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted', save_full_info=False)
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
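For reference, evaluate() can be driven without a command line by passing any object that exposes the attributes it reads. A minimal sketch with an argparse Namespace; the concrete values, paths and the 'BIDAF' algorithm name are purely illustrative assumptions, and the chosen RC model class will typically read further hyperparameters from args as well:

from argparse import Namespace

args = Namespace(
    vocab_dir='data/vocab',             # must contain vocab.data
    model_dir='data/models',            # checkpoints are restored with the args.algo prefix
    result_dir='data/results',          # dev.predicted.* files are written here
    dev_files=['data/devset/dev.json'],
    max_p_num=5, max_p_len=500, max_q_len=60,
    batch_size=32,
    decay_epochs=2,                     # decay_steps is derived inside evaluate()
    algo='BIDAF',
)
evaluate(args)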
Example #2
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    if args.word2vec_path:
        logger.info('learn_word_embedding:{}'.format(args.learn_word_embedding))
        logger.info('loading %s' % args.word2vec_path)
        word2vec = gensim.models.Word2Vec.load(args.word2vec_path)
        vocab.load_pretrained_embeddings_from_w2v(word2vec.wv)
        logger.info('loaded pretrained embeddings from %s' % args.word2vec_path)

    if args.use_char_embed:
        with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'rb') as fin:
            char_vocab = pickle.load(fin)

    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    steps_per_epoch = brc_data.size('train') // args.batch_size
    args.decay_steps = args.decay_epochs * steps_per_epoch 
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    
    if args.use_char_embed:
        logger.info('Converting text into char ids...')
        brc_data.convert_to_char_ids(char_vocab)
        logger.info('Binding char_vocab to args to pass to RCModel')
        args.char_vocab = char_vocab

    RCModel = choose_model_by_gpu_setting(args)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...{}'.format(RCModel.__name__))
    rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
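The args.word2vec_path branch in train() loads a gensim Word2Vec model saved with model.save(). A minimal sketch of producing such a file, assuming a whitespace-tokenised corpus; note that the dimensionality keyword is vector_size in gensim 4.x (size in 3.x), and the corpus path is an assumption:

import gensim

# One token list per line of the corpus (assumed format and path).
with open('data/corpus.txt', encoding='utf-8') as fin:
    sentences = [line.split() for line in fin]

w2v = gensim.models.Word2Vec(sentences, vector_size=300, window=5,
                             min_count=2, workers=4)
w2v.save('data/w2v/word2vec.model')  # this path is what args.word2vec_path points to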
Example #3
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    steps_per_epoch = brc_data.size('test') // args.batch_size
    args.decay_steps = args.decay_epochs * steps_per_epoch
    RCModel = choose_model_by_gpu_setting(args)
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted')
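All three functions expect to be dispatched from a single entry point that builds args from the command line. A minimal sketch of such a dispatcher; the parse_args() helper and the boolean --train/--evaluate/--predict flags are assumptions, not part of the code shown above:

def run():
    """Parse command-line flags and dispatch to train/evaluate/predict."""
    args = parse_args()  # hypothetical argparse helper defining the flags and hyperparameters
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    if args.train:
        train(args)
    if args.evaluate:
        evaluate(args)
    if args.predict:
        predict(args)


if __name__ == '__main__':
    run()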