Example #1
# `tf` is TensorFlow and `encode` is the tokenizer helper; both are imported at
# module level in the source file this example is taken from.
def pred_input(params, logger, enc=None, path_to_prompt=""):

    # Default prompt used when no prompt file is supplied.
    unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
               "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
               "researchers was the fact that the unicorns spoke perfect English."

    text = unicorns if path_to_prompt == "" else open(path_to_prompt,
                                                      "r").read()
    tokens = encode(enc, text)

    # Keep only the last n_ctx tokens if the prompt exceeds the context window.
    if len(tokens) > params["n_ctx"]:
        logger.info(
            "The length of your input prompt is longer than the model's context length - truncating input."
        )
        tokens = tokens[len(tokens) - params["n_ctx"]:]
    # Otherwise right-pad the prompt up to exactly n_ctx tokens.
    if len(tokens) < params["n_ctx"]:
        tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]],
                        constant_values=params["padding_id"])
    # Repeat the single prompt across the batch and wrap it in a tf.data.Dataset.
    t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]])
    dataset = tf.data.Dataset.from_tensors(t)

    def _dummy_labels(x):
        # The input pipeline expects (features, labels) pairs; labels are unused
        # at prediction time, so the tokens are simply returned twice.
        return x, x

    dataset = dataset.map(_dummy_labels)
    return dataset
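To make the shapes concrete, here is a minimal, self-contained sketch (not the project's code) that mimics the truncate/pad/broadcast steps of pred_input on a dummy token list; the n_ctx, batch_size, and padding_id values and the token ids are made up for illustration and assume only that TensorFlow is installed.

import tensorflow as tf

# Hypothetical hyperparameters standing in for the GPT-Neo params dict.
params = {"n_ctx": 8, "batch_size": 2, "padding_id": 0}

tokens = [101, 7592, 2088, 102]  # dummy token ids standing in for encode(enc, text)

# Truncate from the left if the prompt is longer than the context window.
if len(tokens) > params["n_ctx"]:
    tokens = tokens[len(tokens) - params["n_ctx"]:]

# Right-pad to exactly n_ctx tokens, as pred_input does.
if len(tokens) < params["n_ctx"]:
    tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]],
                    constant_values=params["padding_id"])

# Repeat the single prompt across the batch dimension and build the dataset.
batch = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]])
dataset = tf.data.Dataset.from_tensors(batch).map(lambda x: (x, x))

for features, labels in dataset:
    print(features.shape)  # (2, 8): one identical prompt per batch row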
Example #2
File: tasks.py Project: zxhjiutian/gpt-neo
def lambada_create_tokens_data(params, path):
    with open(path, 'w') as f:
        # Download the LAMBADA test set (one JSON object per line) from the
        # module-level lambada_src_uri.
        req = requests.get(lambada_src_uri)
        req.raise_for_status()
        jsons = [json.loads(l) for l in req.iter_lines()]
        # Clean up each example's text with ftfy, then tokenize it.
        texts = [ftfy.fix_text(j['text'], normalization=normalization) for j in jsons]
        enc = fetch_encoder(params)
        arrays = [encode(enc, t) for t in texts]
        # Persist the token-id lists as JSON and return them to the caller.
        json.dump(arrays, f)
        return arrays
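The download, cleanup, and dump steps above lean on the project's own helpers (lambada_src_uri, normalization, fetch_encoder, encode). As a rough, self-contained sketch of the same per-record processing, here is the JSON-lines parsing and ftfy cleanup applied to two made-up records; the "NFKC" normalization value is an assumption, not taken from the excerpt.

import json

import ftfy

# Two made-up JSON-lines records standing in for the LAMBADA download that the
# real function streams from lambada_src_uri with requests.
raw_lines = [
    '{"text": "The cafe was quiet when she walked in."}',
    '{"text": "He picked up the phone and dialed."}',
]

jsons = [json.loads(l) for l in raw_lines]
# The excerpt passes a module-level `normalization` constant; "NFKC" is an assumed value.
texts = [ftfy.fix_text(j["text"], normalization="NFKC") for j in jsons]

# A real run would now tokenize each cleaned text with encode(fetch_encoder(params), t)
# and json.dump the resulting token-id lists to the output file.
print(texts)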
Example #3
File: tasks.py Project: SiZ-oLab/GPTNeo
def wikitext_create_tokens_data(params, path, version="wikitext2"):
    assert version.lower() in ["wikitext2", "wikitext103"]
    wikitext2_src = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip"
    wikitext103_src = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip"
    version_src = wikitext103_src if version.lower() == "wikitext103" else wikitext2_src
    with open(path, 'w') as f:
        # Download and unpack the requested raw WikiText archive.
        wikitext_path = f"./{version}-raw-v1.zip"
        os.system(f"wget {version_src} -O {wikitext_path}")
        os.makedirs(f"{version}", exist_ok=True)
        os.system(f"unzip {wikitext_path} -d {version}")
        n = 103 if version.lower() == "wikitext103" else 2
        # Detokenize the raw test split and run it through ftfy before encoding.
        with open(f"./{version}/wikitext-{n}-raw/wiki.test.raw", 'r') as wt:
            text = ftfy.fix_text(wikitext_detokenizer(wt.read()))
        enc = fetch_encoder(params)
        encoded_text = encode(enc, text)
        # Split the token stream into consecutive chunks of n_ctx - 1 tokens.
        arrays = []
        for i in range(0, len(encoded_text), params["n_ctx"] - 1):
            arrays.append(encoded_text[i:i + params["n_ctx"] - 1])
        # Persist the chunks as JSON and return them to the caller.
        json.dump(arrays, f)
        return arrays
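The chunking loop at the end is easy to miss: it walks the token stream in strides of n_ctx - 1, so each stored example is one token short of the context window, and the final chunk may be shorter still. A minimal, self-contained illustration with made-up numbers:

# Made-up values for illustration only; the real n_ctx comes from params and
# encoded_text from encode(enc, text).
n_ctx = 5
encoded_text = list(range(10))

arrays = []
for i in range(0, len(encoded_text), n_ctx - 1):
    arrays.append(encoded_text[i:i + n_ctx - 1])

print(arrays)
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]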