Пример #1
0
def predict(args):
    """Restore a trained model and run it over the configured test files.

    Steps, in order: load the pickled vocabulary from ``args.vocab_dir``,
    build a ``DataLoader`` over ``args.test_files`` only, convert the text to
    ids, restore the model from ``args.model_dir`` / ``args.algo``, and call
    ``Model.evaluate`` on unshuffled test batches with
    ``result_dir=args.result_dir`` and ``result_prefix='test.predicted'``
    (presumably where the predictions are written — confirm in Model).

    Args:
        args: parsed command-line namespace providing vocab_dir, test_files,
            model_dir, algo, result_dir, batch_size and the max_* limits.

    Raises:
        AssertionError: if ``args.test_files`` is empty.
    """
    log = logging.getLogger("QANet")

    def announce(message):
        # Each pipeline stage is reported to the logger and echoed to stdout.
        log.info(message)
        print(message)

    announce('Load data_set and vocab...')
    vocab_path = os.path.join(args.vocab_dir, dataName + 'OurVocab.data')
    # NOTE(review): pickle.load can execute arbitrary code; only load vocab
    # files produced by this project's own prepare step.
    with open(vocab_path, 'rb') as vocab_file:
        vocabulary = pickle.load(vocab_file)

    assert len(args.test_files) > 0, 'No test files are provided.'
    loader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len, args.max_ch_len,
                        test_files=args.test_files)

    announce('Converting text into ids...')
    loader.convert_to_ids(vocabulary)

    announce('Restoring the model...')
    qa_model = Model(vocabulary, args)
    qa_model.restore(args.model_dir, args.algo)

    announce('Predicting answers for test set...')
    # Batches are padded with the vocabulary's pad token at word and char level.
    pad_word_id = vocabulary.get_word_id(vocabulary.pad_token)
    pad_char_id = vocabulary.get_char_id(vocabulary.pad_token)
    batches = loader.next_batch('test', args.batch_size, pad_word_id, pad_char_id,
                                shuffle=False)
    qa_model.evaluate(batches, result_dir=args.result_dir, result_prefix='test.predicted')
Пример #2
0
def train(args):
    """Create a new model and train it on the configured train/dev files.

    Loads the pickled vocabulary from ``args.vocab_dir``, builds a
    ``DataLoader`` over ``args.train_files`` and ``args.dev_files``, converts
    the text to ids, constructs a fresh ``Model`` and calls ``Model.train``
    with the epoch count, batch size, ``save_dir=args.model_dir``,
    ``save_prefix=args.algo`` and ``dropout=args.dropout`` (presumably the
    checkpoint location — confirm in Model).

    Args:
        args: parsed command-line namespace providing vocab_dir, train_files,
            dev_files, epochs, batch_size, model_dir, algo, dropout and the
            max_* limits.
    """
    log = logging.getLogger("QANet")
    log.info("====== training ======")

    def announce(message):
        # Stages that are both logged and echoed to stdout.
        log.info(message)
        print(message)

    announce('Load data_set and vocab...')
    vocab_path = os.path.join(args.vocab_dir, dataName + 'OurVocab.data')
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(vocab_path, 'rb') as vocab_file:
        vocabulary = pickle.load(vocab_file)

    loader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len, args.max_ch_len,
                        args.train_files, args.dev_files)

    log.info('Converting text into ids...')
    loader.convert_to_ids(vocabulary)

    log.info('Initialize the model...')
    qa_model = Model(vocabulary, args)

    announce('Training the model...')
    qa_model.train(
        loader,
        args.epochs,
        args.batch_size,
        save_dir=args.model_dir,
        save_prefix=args.algo,
        dropout=args.dropout,
    )

    announce('====== Done with model training! ======')
Пример #3
0
def evaluate(args):
    """Restore a trained model and score it on the configured dev files.

    Loads the pickled vocabulary from ``args.vocab_dir``, builds a
    ``DataLoader`` over ``args.dev_files`` only, converts the text to ids,
    restores the model from ``args.model_dir`` / ``args.algo`` and calls
    ``Model.evaluate`` on unshuffled dev batches with
    ``result_dir=args.result_dir`` and ``result_prefix='dev.predicted'``.
    The returned loss and BLEU/ROUGE result are logged.

    Args:
        args: parsed command-line namespace providing vocab_dir, dev_files,
            model_dir, algo, result_dir, batch_size and the max_* limits.

    Raises:
        AssertionError: if ``args.dev_files`` is empty.
    """
    logger = logging.getLogger("QANet")
    logger.info("====== evaluating ======")
    logger.info('Load data_set and vocab...')
    print('Load data_set and vocab...')
    # NOTE(review): pickle.load can execute arbitrary code; only load vocab
    # files produced by this project's own prepare step.
    with open(os.path.join(args.vocab_dir, dataName+'OurVocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)

    assert len(args.dev_files) > 0, 'No dev files are provided.'
    # Only the dev split is consumed below. As in predict(), pass it by
    # keyword so the training files are neither required nor loaded and
    # converted to ids just to be discarded.
    dataloader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len,
                            args.max_ch_len, dev_files=args.dev_files)

    logger.info('Converting text into ids...')
    print('Converting text into ids...')
    dataloader.convert_to_ids(vocab)

    logger.info('Restoring the model...')
    print('Restoring the model...')
    model = Model(vocab, args)
    model.restore(args.model_dir, args.algo)
    logger.info('Evaluating the model on dev set...')
    print('Evaluating the model on dev set...')
    dev_batches = dataloader.next_batch('dev', args.batch_size, vocab.get_word_id(vocab.pad_token),
                                        vocab.get_char_id(vocab.pad_token), shuffle=False)

    dev_loss, dev_bleu_rouge = model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')

    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    # os.path.join with a single argument was a no-op; log the directory directly.
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
Пример #4
0
def prepare(args):
    """Check the data files, create output directories and build the vocab.

    Steps, in order:
      1. assert that every train/dev/test file exists;
      2. create vocab_dir, model_dir, result_dir and summary_dir if missing;
      3. count words and characters from the training split into a
         lower-casing ``Vocab``, then drop words and chars seen fewer than
         2 times;
      4. load pretrained word/char embeddings when a path is given, otherwise
         initialise them randomly with the configured sizes;
      5. pickle the vocabulary to ``<vocab_dir>/<dataName>OurVocab.data``.

    Args:
        args: parsed command-line namespace providing the data file lists,
            output directories, embedding paths/sizes and the max_* limits.

    Raises:
        AssertionError: if any listed data file does not exist.
    """
    logger = logging.getLogger("QANet")
    logger.info("====== preprocessing ======")
    logger.info('Checking the data files...')
    print('Checking the data files...')
    for data_path in args.train_files + args.dev_files + args.test_files:
        assert os.path.exists(data_path), '{} file does not exist.'.format(data_path)

    logger.info('Preparing the directories...')
    print('Preparing the directories...')
    for dir_path in [args.vocab_dir, args.model_dir, args.result_dir, args.summary_dir]:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)

    logger.info('Building vocabulary...')
    print('Building vocabulary...')
    dataloader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len, args.max_ch_len,
                            args.train_files, args.dev_files, args.test_files)

    # The vocabulary is built from the training split only.
    vocab = Vocab(lower=True)
    for word in dataloader.word_iter('train'):
        vocab.add_word(word)
        # add_char is called only for its side effect; a plain loop avoids
        # building and discarding a throwaway list for every word.
        for ch in word:
            vocab.add_char(ch)

    unfiltered_vocab_size = vocab.word_size()
    vocab.filter_words_by_cnt(min_cnt=2)
    filtered_num = unfiltered_vocab_size - vocab.word_size()
    # The char size reported here is taken before the char filter below runs.
    logger.info('After filter {} tokens, the final vocab size is {}, char size is {}'.format(
        filtered_num, vocab.word_size(), vocab.char_size()))

    unfiltered_vocab_char_size = vocab.char_size()
    vocab.filter_chars_by_cnt(min_cnt=2)
    filtered_char_num = unfiltered_vocab_char_size - vocab.char_size()
    logger.info('After filter {} chars, the final char vocab size is {}'.format(
        filtered_char_num, vocab.char_size()))

    logger.info('Assigning embeddings...')
    if args.pretrained_word_path is not None:
        vocab.load_pretrained_word_embeddings(args.pretrained_word_path)
    else:
        vocab.randomly_init_word_embeddings(args.word_embed_size)

    if args.pretrained_char_path is not None:
        vocab.load_pretrained_char_embeddings(args.pretrained_char_path)
    else:
        vocab.randomly_init_char_embeddings(args.char_embed_size)

    logger.info('Saving vocab...')
    print('Saving vocab...')
    with open(os.path.join(args.vocab_dir, dataName+'OurVocab.data'), 'wb') as fout:
        pickle.dump(vocab, fout)

    logger.info('====== Done with preparing! ======')