Example #1
def evaluate(args):
    """
    evaluate the trained model on dev files
    """
    logger = logging.getLogger("brc")
    logger.info('Load dataset and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    brc_data = Dataset(args.max_p_num,
                       args.max_p_len,
                       args.max_q_len,
                       args.max_w_len,
                       dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = brc_data.gen_mini_batches('dev',
                                            args.batch_size,
                                            pad_id=vocab.get_id(
                                                vocab.pad_token),
                                            shuffle=False)
    dev_loss, dev_bleu_rouge = rc_model.evaluate(dev_batches,
                                                 result_dir=args.result_dir,
                                                 result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
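
For context, evaluate only reads a handful of attributes from args. Below is a hypothetical argparse wiring that satisfies it; the attribute names are taken from the function body, while every flag default is an illustrative placeholder.

import argparse

# Hypothetical CLI wiring for evaluate(args); defaults are placeholders.
parser = argparse.ArgumentParser()
parser.add_argument('--vocab_dir', default='data/vocab')
parser.add_argument('--model_dir', default='data/models')
parser.add_argument('--result_dir', default='data/results')
parser.add_argument('--dev_files', nargs='+', default=['data/dev.json'])
parser.add_argument('--algo', default='BIDAF')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--max_p_num', type=int, default=5)
parser.add_argument('--max_p_len', type=int, default=500)
parser.add_argument('--max_q_len', type=int, default=60)
parser.add_argument('--max_w_len', type=int, default=10)
args = parser.parse_args()

evaluate(args)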
Example #2
def evaluate_test(args):
    """
    evaluate the trained model on test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load dataset and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    pro_data = Dataset(test_files=args.test_files)
    logger.info('Converting text into ids...')
    pro_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.restore_epoch)
    logger.info('Evaluating the model on test set...')
    test_batches = pro_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)
    test_loss, test_acc = rc_model.evaluate(test_batches,
                                            result_dir=args.result_dir,
                                            result_prefix='test.predicted')
    logger.info('Loss on test set: {}'.format(test_loss))
    logger.info('Accuracy on test set: {}'.format(test_acc))
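
RCModel.restore(model_dir=..., model_prefix=...) is not shown in any of these examples. In TensorFlow 1.x readers of this kind it is typically a thin wrapper around tf.train.Saver; a minimal sketch, assuming the model holds a session in self.sess and a saver in self.saver:

import os

class RCModel:
    # Assumed: self.sess (a tf.Session) and self.saver (a tf.train.Saver)
    # are created during graph construction, which is elided here.

    def restore(self, model_dir, model_prefix):
        # Checkpoints are assumed to be written as <model_dir>/<model_prefix>,
        # matching how restore is called with model_dir and model_prefix above.
        self.saver.restore(self.sess, os.path.join(model_dir, model_prefix))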
Example #3
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")
    logger.info('Load dataset and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    brc_data = Dataset(args.max_p_num, args.max_p_len, args.max_q_len,
                       args.max_w_len, args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    if args.restore:
        logger.info('Restoring the model...')
        rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
        logger.info('Evaluating the model on dev set...')
        dev_batches = brc_data.gen_mini_batches('dev',
                                                args.batch_size,
                                                pad_id=vocab.get_id(
                                                    vocab.pad_token),
                                                shuffle=False)
        _, dev_bleu_rouge = rc_model.evaluate(dev_batches)
        rc_model.max_rouge_l = dev_bleu_rouge['ROUGE-L']
    logger.info('Training the model...')
    rc_model.train(brc_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
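
The args.restore branch seeds rc_model.max_rouge_l from a dev-set evaluation before training resumes, which implies that RCModel.train checkpoints only when dev ROUGE-L improves. A sketch of that gating follows; evaluate, gen_mini_batches, and max_rouge_l appear in the examples above, while _train_epoch and save are assumed names.

def train(self, data, epochs, batch_size, save_dir, save_prefix,
          dropout_keep_prob=1.0):
    # Sketch of the epoch loop implied by the callers above; assumes the
    # model keeps its vocab as self.vocab, as suggested by RCModel(vocab, args).
    pad_id = self.vocab.get_id(self.vocab.pad_token)
    for epoch in range(1, epochs + 1):
        train_batches = data.gen_mini_batches('train', batch_size,
                                              pad_id=pad_id, shuffle=True)
        self._train_epoch(train_batches, dropout_keep_prob)  # assumed helper
        dev_batches = data.gen_mini_batches('dev', batch_size,
                                            pad_id=pad_id, shuffle=False)
        _, bleu_rouge = self.evaluate(dev_batches)
        # Save only when dev ROUGE-L improves, matching how the caller
        # seeds max_rouge_l after restoring a checkpoint.
        if bleu_rouge['ROUGE-L'] > self.max_rouge_l:
            self.max_rouge_l = bleu_rouge['ROUGE-L']
            self.save(save_dir, save_prefix)  # assumed counterpart of restore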
Example #4
def debug(args):
    """
    debug the pipeline quickly on a small dataset
    """
    logger = logging.getLogger("brc")
    logger.info('Load dataset and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    # the test files stand in as training data here, presumably for a fast debug run
    rc_data = Dataset(train_files=args.test_files, dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    rc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    rc_model.train(rc_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
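
Note that Dataset is constructed both with four positional length limits (Examples 1 and 3) and with keyword file lists alone (Examples 2 and 4), so its constructor presumably defaults all of its arguments. A compatible skeleton, where every default value and the _load helper are assumptions:

import json

class Dataset:
    def __init__(self, max_p_num=5, max_p_len=500, max_q_len=60, max_w_len=10,
                 train_files=None, dev_files=None, test_files=None):
        # The limits cap passages per sample and passage/question/word lengths;
        # the default values here are illustrative only.
        self.max_p_num, self.max_p_len = max_p_num, max_p_len
        self.max_q_len, self.max_w_len = max_q_len, max_w_len
        self.train_set = self._load(train_files) if train_files else []
        self.dev_set = self._load(dev_files) if dev_files else []
        self.test_set = self._load(test_files) if test_files else []

    def _load(self, files):
        # Hypothetical loader; the actual file format is not shown, so this
        # assumes one JSON object per line.
        samples = []
        for path in files:
            with open(path, encoding='utf-8') as fin:
                samples.extend(json.loads(line) for line in fin)
        return samples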
Example #5
def predict(args):
    """
    predicts answers for test files
    """
    logger = logging.getLogger("brc")
    logger.info('Load dataset and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    brc_data = Dataset(args.max_p_num,
                       args.max_p_len,
                       args.max_q_len,
                       args.max_w_len,
                       test_files=args.test_files)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = brc_data.gen_mini_batches('test',
                                             args.batch_size,
                                             pad_id=vocab.get_id(
                                                 vocab.pad_token),
                                             shuffle=False)
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir,
                      result_prefix='test.predicted',
                      hack=True)
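
Here rc_model.evaluate writes the predicted answers under args.result_dir with the prefix test.predicted. Assuming the common one-JSON-object-per-line output format (the exact file name and field names below are assumptions), the predictions could be inspected like this:

import json
import os

result_dir = 'data/results'  # i.e. args.result_dir
result_file = os.path.join(result_dir, 'test.predicted.json')  # assumed name
with open(result_file, encoding='utf-8') as fin:
    for line in fin:
        sample = json.loads(line)
        # 'question_id' and 'answers' are assumed field names
        print(sample.get('question_id'), sample.get('answers'))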
Example #6
def getSoftmax(args):
    """
    computes softmax outputs of the restored model on the dev or test set
    """
    logger = logging.getLogger("brc")

    # load the vocab and embedding data
    logger.info('Load word embedding from text...')
    with open(args.pretrained_embedding_path, 'rb') as f:
        embedding = pickle.load(f)

    logger.info('Load char embedding from text...')
    with open(args.pretrained_char_embedding_path, 'rb') as f:
        char_embedding = pickle.load(f)

    # loading data
    logger.info('loading the data...')
    with open(args.train_data_path, 'rb') as f:
        train_data = pickle.load(f)
    logger.info('train data size:' + str(len(train_data)))
    with open(args.dev_data_path, 'rb') as f:
        dev_data = pickle.load(f)

    with open(args.testa_data_path, 'rb') as f:
        test_data = pickle.load(f)
    logger.info('dev data size:' + str(len(dev_data)))
    logger.info('testa data size:' + str(len(test_data)))

    # shuffle the data (dead code here: nothing below uses trainShuffleData)
    #trainShuffleData = shuffle_data(train_data, 'pWordId')
    #trainShuffleData = train_data

    logger.info('Initialize the model...')
    rc_model = RCModel(args, embedding, char_embedding)
    logger.info("loading model from {}{}".format(args.model_dir, args.algo +
                                                 args.model_prefix))
    rc_model.restore(args.model_dir, args.algo + args.model_prefix)
    logger.info('Computing softmax results...')
    if args.softmax_mode == 'valid':
        rc_model.get_softmax_result(train_data=dev_data,
                                    batch_size=args.batch_size,
                                    dropout_keep_prob=args.dropout_keep_prob,
                                    outputPath=args.softmax_log_output_path)
    else:
        rc_model.get_softmax_result(train_data=test_data,
                                    batch_size=args.batch_size,
                                    dropout_keep_prob=args.dropout_keep_prob,
                                    outputPath=args.softmax_log_output_path)
    logger.info('Done computing softmax results!')
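
get_softmax_result itself is not shown; from the call site it runs the restored model over one split with a given batch size and writes per-example probabilities to outputPath. A minimal numpy sketch, where predict_logits is a hypothetical stand-in for the real forward pass:

import json
import numpy as np

def softmax(logits, axis=-1):
    # Numerically stable softmax over the last axis.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def get_softmax_result(model, train_data, batch_size, dropout_keep_prob, outputPath):
    # model.predict_logits is a hypothetical stand-in for the real forward
    # pass; dropout_keep_prob would normally be threaded into that call.
    with open(outputPath, 'w', encoding='utf-8') as fout:
        for start in range(0, len(train_data), batch_size):
            batch = train_data[start:start + batch_size]
            probs = softmax(model.predict_logits(batch))
            for p in probs:
                fout.write(json.dumps({'probs': p.tolist()}) + '\n')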
Example #7
def train(args):
    """
    trains the reading comprehension model
    """
    logger = logging.getLogger("brc")

    # load the vocab and embedding data
    logger.info('Load word embedding from text...')
    with open(args.pretrained_embedding_path, 'rb') as f:
        embedding = pickle.load(f)

    logger.info('Load char embedding from text...')
    with open(args.pretrained_char_embedding_path, 'rb') as f:
        char_embedding = pickle.load(f)

    # loading data
    logger.info('loading the data...')
    with open(args.train_data_path, 'rb') as f:
        train_data = pickle.load(f)
    logger.info('train data size:' + str(len(train_data)))
    with open(args.dev_data_path, 'rb') as f:
        dev_data = pickle.load(f)
    logger.info('dev data size:' + str(len(dev_data)))

    # shuffle the data (shuffle_data is disabled, so the original order is used)
    logger.info('Shuffling disabled; using the data in its original order...')
    #trainShuffleData = shuffle_data(train_data, 'pWordId')
    trainShuffleData = train_data

    logger.info('Initialize the model...')
    rc_model = RCModel(args, embedding, char_embedding)
    logger.info('Training the model...')
    rc_model.train(trainShuffleData,
                   dev_data,
                   args.epochs,
                   args.batch_size,
                   save_dir=args.model_dir,
                   save_prefix=args.algo + args.model_prefix,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
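
The commented-out shuffle_data(train_data, 'pWordId') call hints at a shuffling helper keyed on a field name. Its real behavior is not shown; the simplest plausible version just permutes the examples, with the key argument accepted for compatibility:

import random

def shuffle_data(data, key, seed=None):
    # Simplest plausible behavior: return the examples in random order.
    # The key field (e.g. 'pWordId') is accepted but unused here; the
    # original helper may have used it for length-based bucketing.
    shuffled = list(data)
    random.Random(seed).shuffle(shuffled)
    return shuffled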