def run_test(model_dir, data_dir, mode, config_path='345M/', beam_width=10):
    """Restore a GPT-2 LM-head model from a checkpoint and prepare test data.

    Steps: create a log file, build ``GPT2LMHeadModel`` from a JSON config,
    load weights via ``load_model(..., test=True)``, move the model to CUDA,
    switch it to eval mode, then build a ``GPT2Tokenizer``. When
    ``mode == 'test'``, also build a ``GPT2DataLoader`` over ``data_dir``.

    Args:
        model_dir: Directory containing the checkpoint ``GPT_model.pkl``. The
            log file ``test_data.log`` is created in the same directory.
        data_dir: Passed as ``data_path`` to ``GPT2DataLoader`` in test mode.
        mode: Only the value ``'test'`` triggers loading of the test dataset.
        config_path: Model-size directory such as ``'345M/'`` (a trailing
            slash is optional). The model config is read from
            ``./configs/<config_path>/config.json``. The vocab and BPE merges
            files are taken from ``<config_path>/vocab.json`` and
            ``<config_path>/merges.txt``.
        beam_width: Not referenced in this body.

    Returns:
        Nothing; this body has no ``return`` statement.
    """
    # Keep the size directory and the config *file* path in separate names.
    # Previously ``config_path`` was overwritten with '<dir>config.json', so
    # the vocab/merges paths became '<dir>config.jsonvocab.json' and
    # '<dir>config.jsonmerges.txt'.
    config_file = os.path.join(config_path, 'config.json')
    # NOTE(review): unlike config.json, vocab/merges are resolved relative to
    # the working directory, not under './configs/' -- confirm intended location.
    vocab_path = os.path.join(config_path, 'vocab.json')
    merge_path = os.path.join(config_path, 'merges.txt')
    checkpoint_path = os.path.join(model_dir, 'GPT_model.pkl')
    log_filename = os.path.join(model_dir, 'test_data.log')

    config = GPT2Config.from_json_file(os.path.join('./configs/', config_file))

    create_log(log_filename)
    print("Building model")
    model = load_model(GPT2LMHeadModel(config), checkpoint_path,
                       test=True).cuda()
    model.eval()  # inference mode for evaluation
    # NOTE(review): ``tokenizer`` and ``test_data_loader`` are not used later in
    # this body; the function may be incomplete.
    tokenizer = GPT2Tokenizer(vocab_path, merge_path)
    if mode == 'test':
        print('Loading test dataset...')
        test_data_loader = GPT2DataLoader(data_path=data_dir,
                                          vocab_file=vocab_path,
                                          bpe_merges=merge_path,
                                          bucket=2,
                                          batch_size=1,
                                          max_seq_len=512)
# Example No. 2
# 0
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config
from data_loader import GPT2DataLoader
from train import run
import os
import torch

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_size = 'small'
    if model_size == 'small':
        config_path = '117M/config.json'
    elif model_size == 'middle':
        config_path = '345M/config.json'
    elif model_size == 'big':
        config_path = '762M/config.json'
    config = GPT2Config.from_json_file(os.path.join('./configs/', config_path))
    model = load_model(GPT2LMHeadModel(config), "checkpoints/small_fs.pkl")
    model = model.to(device)

    train_data_loader = GPT2DataLoader(data_path='DailyDialog/train_text.txt',
                                       vocab_file='./vocab_file/encoder.json',
                                       bpe_merges='vocab_file/merges.txt',
                                       bucket=2,
                                       batch_size=5,
                                       max_seq_len=512)

    valid_data_loader = GPT2DataLoader(data_path='DailyDialog/test_text.txt',
                                       vocab_file='./vocab_file/encoder.json',
                                       bpe_merges='vocab_file/merges.txt',
                                       bucket=2,
                                       batch_size=5,