def pred_input(params, logger, enc=None, path_to_prompt=""): unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \ "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \ "researchers was the fact that the unicorns spoke perfect English." text = unicorns if path_to_prompt == "" else open(path_to_prompt, "r").read() tokens = encode(enc, text) if len(tokens) > params["n_ctx"]: logger.info( "The length of your input prompt is longer than the model's context length - truncating input." ) tokens = tokens[len(tokens) - params["n_ctx"]:] if len(tokens) < params["n_ctx"]: tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"]) t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]]) dataset = tf.data.Dataset.from_tensors(t) def _dummy_labels(x): return x, x dataset = dataset.map(_dummy_labels) return dataset
def lambada_create_tokens_data(params, path):
    # Download the LAMBADA test set, fix the text with ftfy, tokenize it,
    # and cache the token arrays as JSON at `path`.
    with open(path, 'w') as f:
        req = requests.get(lambada_src_uri)
        req.raise_for_status()

        jsons = [json.loads(line) for line in req.iter_lines()]
        texts = [ftfy.fix_text(j['text'], normalization=normalization) for j in jsons]

        enc = fetch_encoder(params)
        arrays = [encode(enc, t) for t in texts]
        json.dump(arrays, f)

        return arrays
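# Sketch of a read-or-create wrapper (hypothetical helper, not part of the original code):
# reuse the cached token file when it already exists, otherwise download and tokenize.
def lambada_read_or_create_tokens_data(params, path):
    if os.path.exists(path):
        with open(path, 'r') as f:
            return json.load(f)
    return lambada_create_tokens_data(params, path)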
def wikitext_create_tokens_data(params, path, version="wikitext2"): assert version.lower() in ["wikitext2", "wikitext103"] wikitext2_src = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip" wikitext103_src = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip" version_src = wikitext103_src if version.lower( ) == "wikitext103" else wikitext2_src with open(path, 'w') as f: wikitext_path = f"./{version}-raw-v1.zip" os.system(f"wget {version_src} -O {wikitext_path}") os.makedirs(f"{version}", exist_ok=True) os.system(f"unzip {wikitext_path} -d {version}") n = 103 if version.lower() == "wikitext103" else 2 with open(f"./{version}/wikitext-{n}-raw/wiki.test.raw", 'r') as wt: text = ftfy.fix_text(wikitext_detokenizer(wt.read())) enc = fetch_encoder(params) encoded_text = encode(enc, text) arrays = [] for i in range(0, len(encoded_text), params["n_ctx"] - 1): arrays.append(encoded_text[i:i + params["n_ctx"] - 1]) json.dump(arrays, f) return arrays