Example #1
import os

import numpy as np
import requests
import torch

# get_encoder, NewsDataset, GPT2Config, GPT2LMHeadModel, load_weight, and
# `device` are assumed to be provided by the surrounding project.


def setup(data_folder):
    # Seed every RNG so repeated runs are reproducible.
    np.random.seed(0)
    torch.cuda.manual_seed_all(0)
    torch.manual_seed(0)

    codec = get_encoder()

    dataset = NewsDataset(path=data_folder,
                          ctx_length=128,
                          codec=codec,
                          start_from_zero=True)

    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    # Download the pretrained GPT-2 checkpoint on first use.
    if not os.path.exists('gpt2-pytorch_model.bin'):
        print("Downloading GPT-2 checkpoint...")
        url = 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin'
        r = requests.get(url, allow_redirects=True)
        r.raise_for_status()
        with open('gpt2-pytorch_model.bin', 'wb') as f:
            f.write(r.content)

    # Load the pretrained weights and switch to inference mode.
    model = load_weight(
        model, torch.load('gpt2-pytorch_model.bin', map_location=device))
    model = model.to(device)
    model.eval()
    return codec, model, dataset, config
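
A minimal usage sketch, assuming the project modules above are importable and that 'data/news' is a placeholder for a folder laid out the way NewsDataset expects:

codec, model, dataset, config = setup('data/news')
# The returned codec is the same BPE encoder used in Example #3.
context = codec.encode("Stocks rallied on Monday")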
Example #2
# Imports and project-level helpers as in Example #1.


def setup(n_enc_layer=1):
    # Seed every RNG so repeated runs are reproducible.
    np.random.seed(0)
    torch.cuda.manual_seed_all(0)
    torch.manual_seed(0)

    codec = get_encoder()
    config = GPT2Config(n_enc_layer=n_enc_layer)
    model = GPT2LMHeadModel(config)
    # Download the pretrained GPT-2 checkpoint on first use; in this
    # example it lives one directory up.
    if not os.path.exists('../gpt2-pytorch_model.bin'):
        print("Downloading GPT-2 checkpoint...")
        url = 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin'
        r = requests.get(url, allow_redirects=True)
        r.raise_for_status()
        with open('../gpt2-pytorch_model.bin', 'wb') as f:
            f.write(r.content)

    model = load_weight(
        model, torch.load('../gpt2-pytorch_model.bin', map_location=device))
    model = model.to(device)
    # Unlike Example #1, model.eval() is not called, so the model stays in
    # training mode.
    return codec, model, config
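
A hypothetical call, assuming GPT2Config accepts the project-specific n_enc_layer argument shown above:

codec, model, config = setup(n_enc_layer=2)
# The model comes back in training mode, ready for fine-tuning.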
Example #3
# In addition to the imports in Example #1, this example uses argparse and
# random, plus the project's sample_sequence helper.


def text_generator(state_dict):
    parser = argparse.ArgumentParser()
    parser.add_argument("--text", type=str, required=True)
    # `type=bool` would parse any non-empty string (even "False") as True,
    # so --quiet is a store_true flag instead.
    parser.add_argument("--quiet", action='store_true', default=False)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument('--unconditional',
                        action='store_true',
                        help='If true, unconditional generation.')
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    args = parser.parse_args()

    if not args.quiet:
        print(args)

    # Samples are generated in batches, so nsamples must be divisible by
    # batch_size.
    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    # Unlike the setup() examples above, generation is seeded randomly on
    # every run.
    seed = random.randint(0, 2147483647)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load Model
    enc = get_encoder()
    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    model = load_weight(model, state_dict)
    model.to(device)
    model.eval()

    # Default to half the context window; requests longer than n_ctx cannot
    # be satisfied by the model.
    if args.length == -1:
        args.length = config.n_ctx // 2
    elif args.length > config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         config.n_ctx)

    print(args.text)
    context_tokens = enc.encode(args.text)

    generated = 0
    for _ in range(args.nsamples // args.batch_size):
        out = sample_sequence(
            model=model,
            length=args.length,
            # Conditional generation feeds the encoded prompt; unconditional
            # generation starts from the <|endoftext|> token instead.
            context=context_tokens if not args.unconditional else None,
            start_token=enc.encoder['<|endoftext|>']
            if args.unconditional else None,
            batch_size=args.batch_size,
            temperature=args.temperature,
            top_k=args.top_k,
            device=device)
        # Strip the prompt (or the lone start token in unconditional mode)
        # so only newly generated tokens are decoded.
        if args.unconditional:
            out = out[:, 1:].tolist()
        else:
            out = out[:, len(context_tokens):].tolist()
        for i in range(args.batch_size):
            generated += 1
            text = enc.decode(out[i])
            if not args.quiet:
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
            print(text)
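
A hypothetical entry point, reusing the checkpoint file downloaded in the earlier setup examples:

if __name__ == '__main__':
    if os.path.exists('gpt2-pytorch_model.bin'):
        state_dict = torch.load('gpt2-pytorch_model.bin', map_location='cpu')
        text_generator(state_dict)
    else:
        print("Download gpt2-pytorch_model.bin first (see Example #1).")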