Example #1
File: generate.py  Project: databill86/duet
import torch
# Assuming the Hugging Face transformers package provides these classes.
from transformers import GPT2LMHeadModel, GPT2Tokenizer


def main(model: GPT2LMHeadModel, enc: GPT2Tokenizer, phrase: str = ''):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Sampling hyperparameters: a single 40-token sample with nucleus
    # sampling (top_p=0.9), a slightly creative temperature, and top-k
    # filtering disabled.
    nsamples = 1
    length = 40
    temperature = 1.2
    top_k = 0
    top_p = 0.9
    batch_size = 1
    # Stop generating at end-of-text or at sentence-ending punctuation.
    stop_token = [enc.encoder[x] for x in ('<|endoftext|>', '.', '?', '!')]
    assert nsamples % batch_size == 0

    # length == -1 means "use half the model's context window"; anything
    # longer than the window cannot be generated in one pass.
    if length == -1:
        length = model.config.n_ctx // 2
    elif length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         model.config.n_ctx)

    # Encode the prompt, falling back to the end-of-text token when no
    # prompt is given.
    context_tokens = enc.encode(phrase) if phrase else [
        enc.encoder['<|endoftext|>']
    ]
    # sample_sequence is defined elsewhere in generate.py.
    out = sample_sequence(model=model,
                          length=length,
                          context=context_tokens,
                          start_token=None,
                          batch_size=batch_size,
                          temperature=temperature,
                          top_k=top_k,
                          device=device,
                          top_p=top_p,
                          stop_token=stop_token)
    # Drop the prompt tokens and decode only the generated continuation.
    out = out[:, len(context_tokens):].tolist()
    return enc.decode(out[0])
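
Below is a hypothetical invocation sketch for the function above. It assumes the Hugging Face transformers package for loading weights; the "gpt2" checkpoint name and the prompt are placeholders, and sample_sequence (defined elsewhere in generate.py) must be in scope for main() to run.

# Hypothetical usage sketch; "gpt2" and the prompt are placeholders.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")  # placeholder checkpoint
model.to("cuda" if torch.cuda.is_available() else "cpu")  # match main()'s device choice
model.eval()
enc = GPT2Tokenizer.from_pretrained("gpt2")

completion = main(model, enc, phrase="The weather today is")
print(completion)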
Example #2
from typing import Iterable

import torch
# Assuming the Hugging Face transformers package.
from transformers import GPT2Tokenizer


def encode_many_texts(tokenizer: GPT2Tokenizer,
                      texts: Iterable[str]) -> torch.Tensor:
    """Encodes a batch of texts into a single tensor, using -1 as padding."""
    encoded_texts = [tokenizer.encode(text) for text in texts]
    max_len = max(len(text) for text in encoded_texts)
    # Right-pad every sequence with -1 up to the longest sequence's length
    # so the batch becomes rectangular.
    padded_encoded_texts = [
        text + [-1] * (max_len - len(text)) for text in encoded_texts
    ]
    return torch.tensor(padded_encoded_texts)
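
A minimal usage sketch for encode_many_texts, again assuming the Hugging Face transformers package; the input strings are placeholders.

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
batch = encode_many_texts(tokenizer, ["Hello world", "Hi"])
# Shorter rows are right-padded with -1 to the longest row's length.
print(batch.shape)

Since -1 is not a valid GPT-2 vocabulary id, the padding stands out from real token ids and is easy to mask out before the tensor reaches a model.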