def create_vm():
  """Creates the GCE VM instance from the command-line flags, then starts it."""
  if not _FLAG_INSTANCE_TYPE.value:
    raise ValueError("Using the full gcloud launcher is useless "
                     "without an instance type.")
  validate_instance_type_flag()

  positional = [
      "gcloud", "compute", "instances", "create",
      _FLAG_INSTANCE_NAME.value,
  ]

  if _FLAG_PREEMPTIBLE_VM.value:
    positional.append("--preemptible")

  named_flags = {
      "--zone": _FLAG_ZONE.value,
      "--image-family": _FLAG_IMAGE_FAMILY.value,
      "--image-project": "deeplearning-platform-release",
      "--machine-type": _FLAG_INSTANCE_TYPE.value,
      "--boot-disk-size": f"{_FLAG_BOOT_DISK_SIZE.value}GB",
      "--scopes": "cloud-platform",
  }

  # All keys and values must be strings, and all keys must be long-form flags.
  for key, value in named_flags.items():
    utils.check_isinstance(value, str)
    utils.check_isinstance(key, str)
  for key in named_flags:
    assert key.startswith("--"), key

  h2("Creating the VM instance.")
  command = positional + [
      f"{k}={shlex.quote(v)}" for k, v in named_flags.items()
  ]
  run_gcloud_command(command)
  print("")
  time.sleep(_FLAG_SLEEP_TIME.value)

  h2("Starting the instance.")
  command = [
      "gcloud", "compute", "instances", "start", _FLAG_INSTANCE_NAME.value
  ]
  run_gcloud_command(command)
  print("")
  time.sleep(_FLAG_SLEEP_TIME.value)
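################################################################################
# Standalone sketch: the flag plumbing in `create_vm` boils down to assembling a
# `gcloud compute instances create` invocation and shelling out. The helper
# below is illustrative only: it bypasses the absl flags, `h2` and
# `run_gcloud_command`, and its image family / disk size defaults are
# placeholder assumptions, not the project's settings.
################################################################################
import shlex
import subprocess


def launch_vm_sketch(instance_name, zone, machine_type,
                     boot_disk_gb=250, preemptible=False):
  """Builds and runs a `gcloud compute instances create` command."""
  command = ["gcloud", "compute", "instances", "create", instance_name]
  if preemptible:
    command.append("--preemptible")
  named_flags = {
      "--zone": zone,
      "--machine-type": machine_type,
      "--image-family": "common-cu110",  # Placeholder image family.
      "--image-project": "deeplearning-platform-release",
      "--boot-disk-size": f"{boot_disk_gb}GB",
      "--scopes": "cloud-platform",
  }
  # Mirrors the original: `--flag=value` pairs, quoted defensively with shlex.
  command += [f"{k}={shlex.quote(v)}" for k, v in named_flags.items()]
  print("Running:", " ".join(command))
  subprocess.check_call(command)


# Example (requires an authenticated Cloud SDK):
# launch_vm_sketch("my-vm", zone="us-central1-a", machine_type="n1-standard-8")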
def main(argv): if len(argv) > 1: raise RuntimeError(argv[1:]) absl_logging.use_python_logging() utils.check_contained(_FLAG_APPROACH_TYPE.value, _ACCEPTABLE_APPROACHES) utils.check_operator(operator.xor, bool(_FLAG_H5_MODEL_PATH.value), bool(_FLAG_CKPT_MODEL_PATH.value)) if _FLAG_H5_MODEL_PATH.value: model_path = _FLAG_H5_MODEL_PATH.value mode = constants.SaveModeChoices.hfh5 elif _FLAG_CKPT_MODEL_PATH.value: model_path = _FLAG_CKPT_MODEL_PATH.value mode = constants.SaveModeChoices.ckpt else: raise RuntimeError("Logically should never happen.") utils.check_exists(model_path) device_type = tf_utils.devices_to_use()[0].device_type # ONLY GPU IS SUPPORTED utils.check_equal(device_type, "GPU") #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Build the distribution strategy #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if device_type == "TPU": # ONLY LOCAL TPU IS "SUPPORTED" utils.check_isinstance(_FLAG_IS_LOCAL_TPU.value, bool) assert _FLAG_IS_LOCAL_TPU.value tpu_config = tf_utils.init_tpus(local=True) utils.check_isinstance(tpu_config, tf_utils.TpuConfigType) utils.check_not_none(tpu_config) strategy = tf.distribute.TPUStrategy(tpu_config.resolver) elif device_type == "GPU": strategy = tf.distribute.MirroredStrategy( devices=tf.config.experimental.list_logical_devices('GPU')) else: raise RuntimeError(device_type) # ONLY GPU IS SUPPORTED print(tf.config.list_logical_devices()) utils.check_isinstance(strategy, tf.distribute.MirroredStrategy) ############################################################################## # Load Model ############################################################################## with utils.log_duration(LOGGER, main.__name__, "All of model preparation"): with strategy.scope(): # HF isn't able to read directly from GCS if (model_path.startswith("gs://") and mode == constants.SaveModeChoices.hfh5): with utils.log_duration(LOGGER, main.__name__, "Download model from GS"): with tempfile.TemporaryDirectory() as td: td += os.path.sep if os.path.exists("/root/google-cloud-sdk/bin/gsutil"): exec_ = "/root/google-cloud-sdk/bin/gsutil" else: exec_ = "gsutil" command = [ exec_, "-m", "cp", "-r", os.path.join(model_path, "*"), td, ] LOGGER.debug("Running bash command: %s", " ".join(command)) subprocess.check_call(command) LOGGER.debug("Files at the temp dir(%s): %s", td, str(os.listdir(td))) model = make_model_tf(td, mode=mode) else: model = make_model_tf(model_path, mode=mode) utils.check_not_none(model) ############################################################################## # Load Dataset Pipeline ############################################################################## utils.check_contained( _FLAG_APPROACH_TYPE.value, { constants.ApproachTypeChoices.naked_lm, constants.ApproachTypeChoices.cached_pretok }) devices = tf_utils.devices_to_use() num_replicas = (len(devices) if devices[0].device_type in {"GPU", "TPU"} else 1) utils.check_equal(devices[0].device_type, "GPU") # Only a batch size of 1 is currently supported. 
We need attention masks batch_size = _FLAG_BATCH_SIZE.value * num_replicas approach_type = _FLAG_APPROACH_TYPE.value logging.debug("Loading dataset.") tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl") ds = prep_ds_for_generation( dict( tokenizer=tokenizer, context_window_size=1024, dataset_name="kilt_eli5", batch_size=1, # >> We set our own batch size elsewhere db_path=None, # None, random_seed=0, use_subset=False, subset_size=-1, use_helper_words=True, approach_type=approach_type, num_retrievals=5, # Will never change retrieval_temperature=1., retriever=None, # Cached retrievals don't need a retriever repeat=False, # Will never change split=_FLAG_SPLIT.value, enable_debug_checks=False, retrieval_bank_size=5, # Will never change dataset_type=_FLAG_DATASET_TYPE.value, tfr_prefix=_FLAG_TFR_PREFIX.value, qty_shuffle=1, # Will never change max_length_generation=350), tokenizer, _FLAG_SPLIT.value) ds = strategy.experimental_distribute_dataset(ds) ############################################################################## # Generate ############################################################################## LOGGER.debug("Generating.") generations = [] num_entries_in_split = ( task_specific.DATASET_CARDINALITIES["kilt_eli5"][_FLAG_SPLIT.value]) entries_counter = tqdm.tqdm(total=num_entries_in_split) for batch_no, batch in enumerate(ds): # Calling model.generate. We should make a config file with the # hyperparameters for generation, or make a facility in the one we already # have. I feel like a separate one would be better, separating concerns. output = strategy.run( model.generate, kwargs=dict( input_ids=batch, max_length=_FLAG_GENERATION_LENGTH_LIMIT.value, use_cache=True, attention_mask=tf.cast(batch != tokenizer.eos_token_id, tf.int32), repetition_penalty=2., num_beams=5, )) output = tf_utils.process_strat_output(strategy_outputs=output, current_batch_size=batch_size, strategy=strategy, name="generations") #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Display the inputs and outputs. #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ rich_console = rich.console.Console(color_system="256") print_sample = make_print_sample() with utils.log_duration(LOGGER, "main", "all of tokenizer.decode for a batch."): for i in range(batch_size): input_text = tokenizer.decode(batch.numpy()[i]) output_text = tokenizer.decode(output.numpy()[i]) print("#" * 1000) print(f"Batch {batch_no} Generation {i}") print_sample(input_text, f"input batch_no {batch_no}", rich_console) print_sample(output_text, f"output batch_no {batch_no}", rich_console) generations.append(output_text) print("#" * 1000) entries_counter.update(batch.shape[0]) #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Save the output to a JSON File. #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ utils.to_json_file( os.path.join(_FLAG_OUTPUT_PATH.value, _FLAG_SPLIT.value, _FLAG_APPROACH_TYPE.value, time.strftime("%Y%m%d-%H%M%S.json")), dict(flags={ flag.name: flag.value for flag in flags.FLAGS.flags_by_module_dict()[argv[0]] }, generations=generations)) logging.debug("Saved to: %s", _FLAG_OUTPUT_PATH.value)
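################################################################################
# Standalone sketch of the generation call above: the attention mask is derived
# by marking every token that is not the EOS/pad id. This runs on a single
# device with the small `gpt2` checkpoint purely for illustration; the real
# script distributes `model.generate` with a strategy and uses its own data
# pipeline.
################################################################################
import tensorflow as tf
import transformers


def generation_mask_sketch():
  """Single-device illustration of the attention-mask + generate pattern."""
  tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
  tokenizer.pad_token = tokenizer.eos_token
  # Left-pad so generation continues directly from the real prompt tokens.
  tokenizer.padding_side = "left"
  model = transformers.TFGPT2LMHeadModel.from_pretrained("gpt2")

  batch = tokenizer(["Why is the sky blue?", "How do magnets work?"],
                    return_tensors="tf", padding=True)
  input_ids = batch["input_ids"]

  # Padding positions share the EOS id, so the mask is simply "not EOS".
  attention_mask = tf.cast(input_ids != tokenizer.eos_token_id, tf.int32)

  output = model.generate(
      input_ids=input_ids,
      attention_mask=attention_mask,
      max_length=64,
      num_beams=5,
      repetition_penalty=2.0,
      use_cache=True,
  )
  for row in output.numpy():
    print(tokenizer.decode(row, skip_special_tokens=True))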
def _prepare_samples_w_retrieval(split, batch_size, question_ids_inputs,
                                 answer_ids_inputs, gpt2_tokenized_retrieved,
                                 distances, num_retrievals, temperature,
                                 context_size, enable_debug_checks,
                                 use_helper_words, helper_word_token_ids,
                                 max_generation_length):
  """Prepares the samples that use retrieval."""
  assert (split == constants.SplitChoices.test) == (
      answer_ids_inputs is None), (split == constants.SplitChoices.test,
                                   answer_ids_inputs)

  # If and only if
  is_not_test = split != constants.SplitChoices.test

  if not isinstance(question_ids_inputs, tf.RaggedTensor):
    question_ids_inputs = tf.RaggedTensor.from_tensor(
        question_ids_inputs, padding=constants.RAGGED_PADDING_ID)

  if enable_debug_checks:
    asserts = []
    asserts.append(
        tf.Assert(
            tf.math.reduce_all(
                question_ids_inputs != constants.RAGGED_PADDING_ID,),
            [question_ids_inputs.to_tensor()]))
    if is_not_test:
      asserts.append(
          tf.Assert(
              tf.math.reduce_all(
                  answer_ids_inputs != constants.RAGGED_PADDING_ID,),
              [answer_ids_inputs.to_tensor()]))
    with tf.control_dependencies(asserts):
      question_ids_inputs = tf.identity(question_ids_inputs)

  # These checks are at graph composition time, so OK
  utils.check_isinstance(question_ids_inputs, tf.RaggedTensor)
  if is_not_test:
    utils.check_isinstance(answer_ids_inputs, tf.RaggedTensor)

  ##############################################################################
  # Sample from the possible retrievals
  ##############################################################################
  # Choose the indices
  indices = tf_utils.sample_without_replacement(distances / temperature,
                                                num_retrievals)

  # Concatenate the retrievals
  concat_retrieved = _tokenize_and_concat_while_loop(
      gpt2_tokenized_retrieved,
      indices=indices,
      batch_size=batch_size,
      num_retrieved=num_retrievals,
  )

  # Add Context and Answer Helper Words
  if use_helper_words:
    concat_retrieved = tf.concat([
        helper_word_token_ids["context"],
        concat_retrieved,
    ], axis=1)

  # Cut the lengths down to max_lens_retrieval.
  # The eventual length of the ["question"] helper_tokens is included in
  # question_ids_inputs.
  if is_not_test:
    max_lens_retrieval = (
        context_size * tf.ones(
            shape=(batch_size,),
            dtype=tf.int64,
        ) - (
            question_ids_inputs.row_lengths() +
            # We always generate the same length of text.
            max_generation_length +
            # answer_ids_inputs.row_lengths() +
            (helper_word_token_ids["answer"].shape[1]
             if use_helper_words else 0)))
  else:
    max_lens_retrieval = (
        context_size * tf.ones(
            shape=(batch_size,),
            dtype=tf.int64,
        ) - (question_ids_inputs.row_lengths() + max_generation_length +
             (helper_word_token_ids["answer"].shape[1]
              if use_helper_words else 0)))

  concat_retrieved = tf.ragged.boolean_mask(
      concat_retrieved,
      (tf.ragged.range(concat_retrieved.row_lengths()) <
       tf.expand_dims(max_lens_retrieval, axis=1)))

  if enable_debug_checks:
    asserts = [
        tf.Assert(
            tf.math.reduce_all(max_lens_retrieval < context_size),
            [max_lens_retrieval, context_size]),
    ]
    with tf.control_dependencies(asserts):
      concat_retrieved = tf.identity(concat_retrieved)

  if use_helper_words:
    if is_not_test:
      new_input_ids = tf.concat([
          question_ids_inputs, concat_retrieved,
          helper_word_token_ids["answer"], answer_ids_inputs
      ], axis=1)
      new_label_ids = tf.concat([
          -100 * tf.ones_like(question_ids_inputs),
          -100 * tf.ones_like(concat_retrieved),
          -100 * tf.ones_like(helper_word_token_ids["answer"]),
          answer_ids_inputs
      ], axis=1)
    else:
      new_input_ids = tf.concat([
          question_ids_inputs,
          concat_retrieved,
          helper_word_token_ids["answer"],
      ], axis=1)
  else:
    if is_not_test:
      new_input_ids = tf.concat(
          [question_ids_inputs, concat_retrieved, answer_ids_inputs], axis=1)
      new_label_ids = tf.concat([
          -100 * tf.ones_like(question_ids_inputs),
          -100 * tf.ones_like(concat_retrieved), answer_ids_inputs
      ], axis=1)
    else:
      new_input_ids = tf.concat([
          question_ids_inputs,
          concat_retrieved,
      ], axis=1)

  return new_input_ids, new_label_ids if is_not_test else None
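################################################################################
# `tf_utils.sample_without_replacement` is a project helper whose code is not
# shown here. A common way to draw k indices per row without replacement, with
# probabilities proportional to softmax(logits), is the Gumbel-top-k trick
# sketched below; this is an assumption about the helper, not its actual
# implementation.
################################################################################
import tensorflow as tf


def gumbel_top_k_sample_sketch(logits, k):
  """Samples k indices per row without replacement ~ softmax(logits)."""
  # Perturb every logit with independent Gumbel(0, 1) noise...
  uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))
  # ...and keep the k largest perturbed logits.
  _, indices = tf.math.top_k(logits + gumbel_noise, k=k)
  return indices


# Usage mirroring the call above; lower temperature means sharper sampling:
# indices = gumbel_top_k_sample_sketch(distances / temperature, num_retrievals)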
def main(argv): if len(argv) > 1: raise RuntimeError(argv) absl_logging.use_python_logging() utils.log_module_args(LOGGER, argv[0]) retriever_config = tf_utils.REALMSave( **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value)) assert not _FLAG_USE_SUBSET.value time_stamp = time.strftime("%Y%m%d-%H%M%S") target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp.strip()) if target_path[-1] != "/": target_path += "/" ############################################################################## # Setup devices and strategy ############################################################################## with utils.log_duration(LOGGER, "main", "Initializing devices"): tpu_config = tf_utils.init_tpus() device_type = tf_utils.current_accelerator_type() LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use())) if device_type == "TPU": if tpu_config is None: raise RuntimeError("We should have a tpu_config.") strategy = tf.distribute.TPUStrategy(tpu_config.resolver) batch_size = len( tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value elif device_type == "GPU" or device_type == "CPU": strategy = tf.distribute.MirroredStrategy() batch_size = len( tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value else: raise RuntimeError(device_type) ############################################################################## # Load the dataset. ############################################################################## eli5 = {} keys = ["train", "eval", "test"] gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl") gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."): for split in tqdm.tqdm(keys): load_path = os.path.join(_FLAG_DATASET_ROOT.value, "HuggingfaceDatasets", f"{split}_kilt_eli5.hf") with tf.device("/job:localhost"): eli5[split] = datasets.load_from_disk(load_path) ############################################################################## # ############################################################################## with utils.log_duration(LOGGER, "Main", "Load the textual dataset"): # Extract the appropriate text # The buffer_size is taken from the original ORQA code. blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records, buffer_size=512 * 1024 * 1024) blocks_dataset = blocks_dataset.batch( retriever_config.num_block_records, drop_remainder=True) blocks = tf.data.experimental.get_single_element(blocks_dataset) ############################################################################ # Prepare the output file. 
############################################################################ writers = {} all_paths = {} for split in keys: maybe_subset = "_subset" if _FLAG_USE_SUBSET.value else "" paths = [ os.path.join(target_path + maybe_subset, f"{split}_{i}.tfr") for i in range(_FLAG_NUM_SHARDS.value) ] all_paths[split] = paths writers[split] = [tf.io.TFRecordWriter(filename) for filename in paths] with utils.log_duration(LOGGER, "main", "Loading the reference db."): checkpoint_path = os.path.join( retriever_config.query_embedder_path, "encoded", "encoded.ckpt") reference_db_device = tf_utils.device_mapping().CPUs[0].name with tf.device(reference_db_device): reference_db = tf_utils.load_reference_db( checkpoint_path, variable_name="block_emb", ) ############################################################################ # Prep the encoder and the tokenizer ############################################################################ with utils.log_duration(LOGGER, "main", "Loading the encoder model and the tokenizer."): with strategy.scope(): query_encoder = hub.load(retriever_config.query_embedder_path, tags={}) encode_fn = _make_encode_fn(query_encoder) encode_fn_strategy_run = make_encode_fn_strategy_run_fn( strategy=strategy, encode_fn=encode_fn, ) vocab_file = os.path.join(retriever_config.query_embedder_path, "assets", "vocab.txt") utils.check_exists(vocab_file) do_lower_case = query_encoder.signatures["tokenization_info"]( )["do_lower_case"] tokenization_info = dict(vocab_file=vocab_file, do_lower_case=do_lower_case) tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer( query_encoder, tokenization_info) ############################################################################ # Preprocess the dataset ############################################################################ cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")), tf.int32) sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")), tf.int32) transform = _make_transform_fn( bert_tokenizer=tokenizer, bert_cls_token_id=cls_token_id, bert_sep_token_id=sep_token_id, ) feature_dtypes = { constants.CTH5Fields.distances: tf.float32, constants.CTH5Fields.gpt2_retrieved_ids: tf.int32, constants.CTH5Fields.gpt2_answer_ids_inputs: tf.int32, constants.CTH5Fields.gpt2_question_ids_inputs: tf.int32, } with utils.log_duration(LOGGER, "main", "generating codes"): for split in keys: sample_count = 0 eli5: Dict[str, datasets.Dataset] if split != "test": for_slices = dict(sample_id=eli5[split]["id"], question=eli5[split]["input"], answer=[ sample["answer"][0] for sample in eli5[split]["output"] ]) else: for_slices = dict( sample_id=eli5[split]["id"], question=eli5[split]["input"], ) ds = tf.data.Dataset.from_tensor_slices(for_slices) ds = ds.map(transform, num_parallel_calls=tf.data.experimental.AUTOTUNE) ds = ds.apply( tf.data.experimental.dense_to_ragged_batch(batch_size)) ds = ds.map(_squeeze, num_parallel_calls=tf.data.experimental.AUTOTUNE) tqdm_inner = tqdm.tqdm(enumerate(ds), total=len(eli5[split]["id"]) // _FLAG_BATCH_SIZE.value, desc=f"Split `{split}`: Batches") for i, batch in tqdm_inner: features = collections.defaultdict(list) ###################################################################### # Enforce the current real batch size ###################################################################### current_batch_size = batch["sample_id"].shape[0] for k, v in batch.items(): utils.check_equal(v.shape[0], current_batch_size) 
###################################################################### gpt2_question_ids_inputs = _prep_field(batch["question"], gpt2_tokenizer) utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32) utils.check_equal(gpt2_question_ids_inputs.shape[0], current_batch_size) if split != "test": gpt2_answer_ids_inputs = _prep_field( batch["answer"], gpt2_tokenizer) utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32) utils.check_equal(gpt2_answer_ids_inputs.shape[0], current_batch_size) assert len(gpt2_answer_ids_inputs.shape) == 2, ( gpt2_answer_ids_inputs.shape) ###################################################################### # Save the gpt2 tokenized question and answer ###################################################################### features[constants.CTH5Fields.gpt2_question_ids_inputs].extend( gpt2_question_ids_inputs) if split != "test": features[ constants.CTH5Fields.gpt2_answer_ids_inputs].extend( gpt2_answer_ids_inputs) ###################################################################### # Encode the samples. ###################################################################### batch = strategy.experimental_distribute_values_from_function( tf_utils.make_dict_distribute_fn(batch)) embeddings = encode_fn_strategy_run(batch) embeddings = tf_utils.process_strat_output( embeddings, "embeddings", strategy, current_batch_size) utils.check_isinstance(embeddings, ops.EagerTensor) utils.check_equal(embeddings.shape[0], current_batch_size) # pytype doesn't seem to see that we check the type utils.check_equal(embeddings.shape[1], _FLAG_EMBEDDING_DEPTH.value) # pytype: disable=attribute-error ###################################################################### # Retrieve. ###################################################################### with tf.device(reference_db_device): top_k, inner_prods = tf_utils.mips_exact_search( embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db) top_k = tf_utils.process_strat_output(top_k, "top_k", strategy, current_batch_size) utils.check_equal( inner_prods.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)) utils.check_equal( top_k.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)) features[constants.CTH5Fields.distances].extend(inner_prods) gathered = tf.gather(blocks, top_k).numpy() utils.check_equal(gathered.shape[0], current_batch_size) retrievals = [] for j in range(gathered.shape[0]): local_gathered = gathered[j].tolist() utils.check_equal(len(local_gathered), _FLAG_NUM_RETRIEVALS.value) local_gathered = [ sample.decode() for sample in local_gathered ] token_ids = np.array( gpt2_tokenizer.batch_encode_plus( local_gathered, padding="max_length", truncation=True, ).input_ids) for line in token_ids: assert not np.all(line == 0), line token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1 retrievals.append(token_ids) features[constants.CTH5Fields.gpt2_retrieved_ids] = retrievals utils.check_equal( retrievals[0].shape, (_FLAG_NUM_RETRIEVALS.value, _FLAG_CONTEXT_SIZE.value)) for k, v in features.items(): utils.check_equal(len(v), current_batch_size) for k in range(current_batch_size): feature = tf.train.Features( feature={ k: _bytes_feature( tf.io.serialize_tensor( tf.cast(v[k], feature_dtypes[k]))) for k, v in features.items() }) writers[split][ sample_count % _FLAG_NUM_SHARDS.value].write( tf.train.Example( features=feature).SerializeToString()) sample_count += 1 if sample_count % 1000 == 0: LOGGER.debug("Paths: %s", str(all_paths[split][0])) LOGGER.debug("Done.")
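################################################################################
# Standalone sketch of the shard-writing path above: each tensor field is
# serialized with tf.io.serialize_tensor, wrapped in a bytes feature, packed
# into a tf.train.Example, and round-robined across shard writers. Field names,
# shapes and the output directory are illustrative; the project's own
# `_bytes_feature` is redefined locally so the sketch is self-contained.
################################################################################
import tensorflow as tf


def bytes_feature_sketch(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def serialize_sample_sketch(fields, dtypes):
  """Serializes a dict of tensors into a tf.train.Example byte string."""
  feature = {
      name: bytes_feature_sketch(
          tf.io.serialize_tensor(tf.cast(tensor, dtypes[name])).numpy())
      for name, tensor in fields.items()
  }
  return tf.train.Example(
      features=tf.train.Features(feature=feature)).SerializeToString()


def write_shards_sketch(num_samples=10, num_shards=4):
  dtypes = {"distances": tf.float32, "retrieved_ids": tf.int32}
  writers = [
      tf.io.TFRecordWriter(f"/tmp/demo_{i}.tfr") for i in range(num_shards)
  ]
  for sample_count in range(num_samples):
    fields = {
        "distances": tf.random.normal((5,)),
        "retrieved_ids": tf.random.uniform((5, 8), maxval=50000,
                                           dtype=tf.int32),
    }
    # Round-robin the serialized examples across the shard writers.
    writers[sample_count % num_shards].write(
        serialize_sample_sketch(fields, dtypes))
  for writer in writers:
    writer.close()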
def main(argv): # Arguments and logging boilerplate if len(argv) > 1: raise RuntimeError(argv) absl_logging.use_python_logging() utils.log_module_args(LOGGER, argv[0]) # Load a retriever config. retriever_config = tf_utils.REALMConfig( **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value)) assert not _FLAG_USE_SUBSET.value # Preparation of the output path time_stamp = time.strftime("%Y%m%d-%H%M%S") target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp.strip()) if target_path[-1] != "/": target_path += "/" ############################################################################## # Setup devices and strategy ############################################################################## # Duration is pretty much instantaneous with utils.log_duration(LOGGER, "main", "Initializing devices"): tpu_config = tf_utils.init_tpus(local=_FLAG_TPU_IS_LOCAL.value, tpu_name=_FLAG_TPU_NAME.value) device_type = tf_utils.current_accelerator_type() LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use())) if _FLAG_TPU_NAME.value and device_type == "CPU": raise RuntimeError("Device is CPU and we expected a TPU.") if device_type == "TPU": if tpu_config is None: raise RuntimeError("We should have a tpu_config.") strategy = tf.distribute.TPUStrategy(tpu_config.resolver) batch_size = len( tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value elif device_type == "GPU" or device_type == "CPU": strategy = tf.distribute.MirroredStrategy() batch_size = len( tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value else: raise RuntimeError(device_type) ############################################################################## # Load the KILT ELI5 dataset. ############################################################################## # Takes a while eli5 = {} keys = ["train", "validation", "test"] gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl") gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."): if _FLAG_DATASET_ROOT.value: for split in tqdm.tqdm(keys): load_path = os.path.join(_FLAG_DATASET_ROOT.value, "HuggingfaceDatasets", f"{split}_kilt_eli5.hf") with tf.device("/job:localhost"): eli5[split] = datasets.load_from_disk(load_path) else: eli5 = datasets.load_dataset("kilt_tasks", "eli5") ############################################################################## # Load the dataset of the text that will be retrieved. ############################################################################## # Takes a long time with utils.log_duration(LOGGER, "Main", "Load the textual dataset"): # Extract the appropriate text # The buffer_size is taken from the original ORQA code. blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records, buffer_size=512 * 1024 * 1024) blocks_dataset = blocks_dataset.batch( retriever_config.num_block_records, drop_remainder=False) blocks: tf.Tensor = tf.data.experimental.get_single_element( blocks_dataset) ############################################################################ # Increase the number of maximum open file descriptors to make space # for all the shards. ############################################################################ max_num_fd = _FLAG_NUM_SHARDS.value * 3 + _MIN_N_FD resource.setrlimit(resource.RLIMIT_NOFILE, (max_num_fd, max_num_fd)) ############################################################################ # Prepare the output files. 
############################################################################ writers = {} all_paths = {} for split in keys: maybe_subset = "_subset" if _FLAG_USE_SUBSET.value else "" # Prepare paths. They can't be in a generator. A function generator would be # fine though. paths = [ os.path.join(target_path + maybe_subset, f"{split}_{i}.tfr") for i in range(_FLAG_NUM_SHARDS.value) ] all_paths[split] = paths writers[split] = [] # Create The TFR writers. for i, path in enumerate(paths): writers[split].append(tf.io.TFRecordWriter(path)) # Load the reference DB. We used to accidentally do this once per split :O with utils.log_duration(LOGGER, "main", "Loading the reference db."): checkpoint_path = os.path.join(retriever_config.query_embedder_path, "encoded", "encoded.ckpt") reference_db_device = tf_utils.device_mapping().CPUs[0].name with tf.device(reference_db_device): reference_db = tf_utils.load_reference_db( checkpoint_path, variable_name="block_emb", ) ############################################################################ # Prep the encoder and the tokenizer ############################################################################ with utils.log_duration(LOGGER, "main", "Loading the encoder model and the tokenizer."): with strategy.scope(): query_encoder = hub.load(retriever_config.query_embedder_path, tags={}) encode_fn = _make_encode_fn(query_encoder) encode_fn_strategy_run = make_encode_fn_strategy_run_fn( strategy=strategy, encode_fn=encode_fn, ) vocab_file = os.path.join(retriever_config.query_embedder_path, "assets", "vocab.txt") utils.check_exists(vocab_file) do_lower_case = query_encoder.signatures["tokenization_info"]( )["do_lower_case"] tokenization_info = dict(vocab_file=vocab_file, do_lower_case=do_lower_case) tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer( query_encoder, tokenization_info) ############################################################################ # Preprocess the dataset ############################################################################ cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")), tf.int32) sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")), tf.int32) transform = _make_transform_fn( bert_tokenizer=tokenizer, bert_cls_token_id=cls_token_id, bert_sep_token_id=sep_token_id, ) feature_dtypes = { constants.CTH5Fields.distances: tf.float32, constants.CTH5Fields.gpt2_retrieved_ids: tf.int32, constants.CTH5Fields.gpt2_answer_ids_inputs: tf.int32, constants.CTH5Fields.gpt2_question_ids_inputs: tf.int32, } with utils.log_duration(LOGGER, "main", "generating codes"): for split in keys: sample_count = 0 eli5: Dict[str, datasets.Dataset] if split != "test": for_slices = dict(sample_id=eli5[split]["id"], question=eli5[split]["input"], answer=[ sample[0]["answer"] for sample in eli5[split]["output"] ]) else: for_slices = dict( sample_id=eli5[split]["id"], question=eli5[split]["input"], ) ds = tf.data.Dataset.from_tensor_slices(for_slices) ds = ds.map(transform, num_parallel_calls=tf.data.experimental.AUTOTUNE) ds = ds.apply( tf.data.experimental.dense_to_ragged_batch(batch_size)) ds = ds.map(_squeeze, num_parallel_calls=tf.data.experimental.AUTOTUNE) tqdm_inner = tqdm.tqdm(enumerate(ds), total=len(eli5[split]["id"]) // _FLAG_BATCH_SIZE.value, desc=f"Split `{split}`: Batches") for i, batch in tqdm_inner: features = collections.defaultdict(list) ###################################################################### # Enforce the current real batch size 
###################################################################### current_batch_size = batch["sample_id"].shape[0] for k, v in batch.items(): utils.check_equal(v.shape[0], current_batch_size) ###################################################################### gpt2_question_ids_inputs = _prep_field(batch["question"], gpt2_tokenizer) utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32) utils.check_equal(gpt2_question_ids_inputs.shape[0], current_batch_size) if split != "test": gpt2_answer_ids_inputs = _prep_field( batch["answer"], gpt2_tokenizer) utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32) utils.check_equal(gpt2_answer_ids_inputs.shape[0], current_batch_size) assert len(gpt2_answer_ids_inputs.shape) == 2, ( gpt2_answer_ids_inputs.shape) ###################################################################### # Save the gpt2 tokenized question and answer ###################################################################### features[constants.CTH5Fields.gpt2_question_ids_inputs].extend( gpt2_question_ids_inputs) if split != "test": features[ constants.CTH5Fields.gpt2_answer_ids_inputs].extend( gpt2_answer_ids_inputs) ###################################################################### # Encode the samples. ###################################################################### batch = strategy.experimental_distribute_values_from_function( tf_utils.make_dict_distribute_fn(batch)) embeddings = encode_fn_strategy_run(batch) embeddings = tf_utils.process_strat_output( embeddings, "embeddings", strategy, current_batch_size) utils.check_isinstance(embeddings, ops.EagerTensor) utils.check_equal(embeddings.shape[0], current_batch_size) # pytype doesn't seem to see that we check the type utils.check_equal(embeddings.shape[1], _FLAG_EMBEDDING_DEPTH.value) # pytype: disable=attribute-error ###################################################################### # Retrieve. 
###################################################################### # Do exact retrieval with tf.device(reference_db_device): top_k, inner_prods = tf_utils.mips_exact_search( embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db) # Collate the results top_k = tf_utils.process_strat_output(top_k, "top_k", strategy, current_batch_size) # Check the shapes utils.check_equal( inner_prods.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)) utils.check_equal( top_k.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)) # Save the distances features[constants.CTH5Fields.distances].extend(inner_prods) # Retrieve the text fields associated to the indices gathered = tf.gather(blocks, top_k).numpy() utils.check_equal(gathered.shape[0], current_batch_size) utils.check_equal(gathered.shape[1], _FLAG_NUM_RETRIEVALS.value) retrievals = [] for index_in_batch in range(current_batch_size): # Put the appropriate byte strings in a list local_gathered = gathered[index_in_batch].tolist() utils.check_equal(len(local_gathered), _FLAG_NUM_RETRIEVALS.value) # Decode to utf-8 local_gathered = [ sample.decode() for sample in local_gathered ] # Encode to GPT2 BPE token_ids = np.array( gpt2_tokenizer.batch_encode_plus( local_gathered, padding="max_length", truncation=True, ).input_ids) # Make sure no line is empty # TODO(julesgm): Maybe optional for line in token_ids: assert not np.all(line == 0), line # Convert the eos_tokens token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1 # Save the retrievals retrievals.append(token_ids) # Save the feature features[constants.CTH5Fields.gpt2_retrieved_ids] = retrievals utils.check_equal( retrievals[0].shape, (_FLAG_NUM_RETRIEVALS.value, _FLAG_CONTEXT_SIZE.value)) for k, v in features.items(): utils.check_equal(len(v), current_batch_size) for index_in_batch in range(current_batch_size): feature_dict = {} for feature_k, feature_v in features.items(): # Cast the feature to its appropriate dtype casted_feats = tf.cast(feature_v[index_in_batch], feature_dtypes[feature_k]) # Serialize the tensor to bytes feature_bytes = tf.io.serialize_tensor(casted_feats) # Build a bytes list tf.train.Feature object, # the serialization tree node feature_dict[feature_k] = _bytes_feature(feature_bytes) # Create the serialization tree root # Expects a list of features feature = tf.train.Features(feature=feature_dict) # Expects a tf.train.Features object example_obj = tf.train.Example(features=feature) # Serialize that to bytes serialized_example = example_obj.SerializeToString() # Write the bytes # TODO(julesgm): Parallelize this with a thread or a process pool & # futures. writers[split][sample_count % _FLAG_NUM_SHARDS.value].write( serialized_example) sample_count += 1 if sample_count % 1000 == 0: LOGGER.debug("Paths: %s", str(all_paths[split][0])) LOGGER.debug("Flushing and closing the `%s` writers", split) for writer in tqdm.tqdm(writers[split]): writer.flush() writer.close() LOGGER.debug("Done.")
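################################################################################
# `tf_utils.mips_exact_search` is a project helper; exact maximum inner product
# search over a dense reference matrix is a matmul followed by top-k, roughly as
# below. This is an assumption about the helper, not its code. As in the loop
# above, the real reference db lives on a CPU host (`reference_db_device`)
# because it is far too large for accelerator memory.
################################################################################
import tensorflow as tf


def exact_mips_sketch(queries, k, reference_db):
  """Returns (top_k_indices, top_k_inner_products).

  queries:      [batch_size, depth]
  reference_db: [num_blocks, depth]
  """
  # All pairwise inner products: [batch_size, num_blocks].
  scores = tf.linalg.matmul(queries, reference_db, transpose_b=True)
  top_inner_prods, top_indices = tf.math.top_k(scores, k=k)
  # Same return order as the call sites above: indices first, then scores.
  return top_indices, top_inner_prods


# Usage with toy shapes:
# top_k, inner_prods = exact_mips_sketch(tf.random.normal((2, 128)), 10,
#                                        tf.random.normal((10_000, 128)))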
def main(argv): if len(argv) > 1: raise RuntimeError(argv) absl_logging.use_python_logging() retriever_config = tf_utils.REALMSave( **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value)) extra = "_FROM_SUBSET" if _FLAG_USE_SUBSET.value else "" time_stamp = time.strftime("%Y%m%d-%H%M%S") target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp + extra).strip() if target_path[-1] != "/": target_path += "/" ############################################################################## # Setup devices and strategy ############################################################################## with utils.log_duration(LOGGER, "main", "Initializing devices"): tpu_config = tf_utils.init_tpus() device_type = tf_utils.current_accelerator_type() LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use())) if device_type == "TPU": if tpu_config is None: raise RuntimeError("We should have a tpu_config.") strategy = tf.distribute.TPUStrategy(tpu_config.resolver) batch_size = len( tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value elif device_type == "GPU" or device_type == "CPU": strategy = tf.distribute.MirroredStrategy() batch_size = len( tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value else: raise RuntimeError(device_type) ############################################################################## # Load the dataset. ############################################################################## eli5 = {} keys = ["train", "eval", "test"] gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl") gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."): for split in tqdm.tqdm(keys): load_path = os.path.join(_FLAGS_DATASET_ROOT.value, "HuggingfaceDatasets", f"{split}_kilt_eli5.hf") with tf.device("/job:localhost"): eli5[split] = datasets.load_from_disk(load_path) if _FLAG_USE_SUBSET.value: _warn_subset() ############################################################################## # ############################################################################## with utils.log_duration(LOGGER, "Main", "Load the textual dataset"): # Extract the appropriate text # The buffer_size is taken from the original ORQA code. blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records, buffer_size=512 * 1024 * 1024) blocks_dataset = blocks_dataset.batch( retriever_config.num_block_records, drop_remainder=True) blocks = tf.data.experimental.get_single_element(blocks_dataset) with tempfile.TemporaryDirectory() as tmp_dir: ############################################################################ # Prepare the output file. 
############################################################################ tmp_dir = pathlib.Path(tmp_dir) h5_output_path = tmp_dir / "codes.h5" output_file = h5py.File(h5_output_path, "w") flags_dict = { flag.name: flag.value for flag in flags.FLAGS.flags_by_module_dict()[argv[0]] } utils.to_json_file(tmp_dir / "params.json", flags_dict) for split in keys: with utils.log_duration( LOGGER, "main", "Creating the output hdf5 file, embeddings."): num_entries = len(eli5[split]["id"]) if _FLAG_USE_SUBSET.value: num_entries = min(num_entries, _FLAG_SUBSET_AMOUNT.value) split_group = output_file.create_group(split) with utils.log_duration( LOGGER, "main", "Creating the output hdf5 file, retrieval."): split_group.create_dataset( constants.CTH5Fields.distances, shape=(num_entries, _FLAG_NUM_RETRIEVALS.value), dtype=np.float32, ) split_group.create_dataset( constants.CTH5Fields.gpt2_question_ids_inputs, shape=(num_entries, _FLAG_CONTEXT_SIZE.value), dtype=np.int32) if split != "test": split_group.create_dataset( constants.CTH5Fields.gpt2_answer_ids_inputs, shape=(num_entries, _FLAG_CONTEXT_SIZE.value), dtype=np.int32) split_group.create_dataset( constants.CTH5Fields.gpt2_retrieved_ids, shape=( num_entries, _FLAG_NUM_RETRIEVALS.value, _FLAG_MAX_LENGTH_RETRIEVALS.value, ), dtype=np.int32) with utils.log_duration(LOGGER, "main", "Loading the reference db."): checkpoint_path = os.path.join( retriever_config.query_embedder_path, "encoded", "encoded.ckpt") reference_db_device = tf_utils.device_mapping().CPUs[0].name with tf.device(reference_db_device): reference_db = tf_utils.load_reference_db( checkpoint_path, variable_name="block_emb", ) ############################################################################ # Prep the encoder and the tokenizer ############################################################################ with utils.log_duration( LOGGER, "main", "Loading the encoder model and the tokenizer."): with strategy.scope(): query_encoder = hub.load(retriever_config.query_embedder_path, tags={}) encode_fn = _make_encode_fn(query_encoder) encode_fn_strategy_run = _make_encode_fn_strategy_run_fn( strategy=strategy, encode_fn=encode_fn, ) vocab_file = os.path.join(retriever_config.query_embedder_path, "assets", "vocab.txt") utils.check_exists(vocab_file) do_lower_case = query_encoder.signatures["tokenization_info"]( )["do_lower_case"] tokenization_info = dict(vocab_file=vocab_file, do_lower_case=do_lower_case) tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer( query_encoder, tokenization_info) ############################################################################ # Preprocess the dataset ############################################################################ cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")), tf.int32) sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")), tf.int32) transform = _make_transform_fn( bert_tokenizer=tokenizer, bert_cls_token_id=cls_token_id, bert_sep_token_id=sep_token_id, ) with utils.log_duration(LOGGER, "main", "generating codes"): tqdm_splits = tqdm.tqdm(keys) for split in tqdm_splits: tqdm_splits.set_description(f"Split `{split}`") eli5: Dict[str, datasets.Dataset] write_start = 0 if _FLAG_USE_SUBSET.value: _warn_subset(tqdm_splits) eli5[split] = eli5[split][:_FLAG_SUBSET_AMOUNT.value] utils.check_operator(operator.le, len(eli5[split]["id"]), _FLAG_SUBSET_AMOUNT.value) utils.check_operator(operator.le, len(eli5[split]["input"]), _FLAG_SUBSET_AMOUNT.value) else: utils.check_equal(len(eli5[split]), 
len(eli5[split]["id"])) utils.check_equal(len(eli5[split]), len(eli5[split]["input"])) if split != "test": for_slices = dict(sample_id=eli5[split]["id"], question=eli5[split]["input"], answer=[ sample["answer"][0] for sample in eli5[split]["output"] ]) else: for_slices = dict( sample_id=eli5[split]["id"], question=eli5[split]["input"], ) ds = tf.data.Dataset.from_tensor_slices(for_slices) ds = ds.map(transform, num_parallel_calls=tf.data.experimental.AUTOTUNE) ds = ds.apply( tf.data.experimental.dense_to_ragged_batch(batch_size)) ds = ds.map(_squeeze, num_parallel_calls=tf.data.experimental.AUTOTUNE) tqdm_inner = tqdm.tqdm(enumerate(ds), total=len(eli5[split]["id"]) // _FLAG_BATCH_SIZE.value, desc=f"Split `{split}`: Batches") for i, batch in tqdm_inner: ###################################################################### # Enforce the current real batch size ###################################################################### current_batch_size = batch["sample_id"].shape[0] for k, v in batch.items(): utils.check_equal(v.shape[0], current_batch_size) ###################################################################### gpt2_question_ids_inputs = _prep_field( batch["question"], gpt2_tokenizer) utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32) utils.check_equal(gpt2_question_ids_inputs.shape[0], current_batch_size) if split != "test": gpt2_answer_ids_inputs = _prep_field( batch["answer"], gpt2_tokenizer) utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32) utils.check_equal(gpt2_answer_ids_inputs.shape[0], current_batch_size) assert len(gpt2_answer_ids_inputs.shape) == 2, ( gpt2_answer_ids_inputs.shape) ###################################################################### # Save the gpt2 tokenized question and answer ###################################################################### end = write_start + current_batch_size utils.check_equal( output_file[split][ constants.CTH5Fields.gpt2_question_ids_inputs] [write_start:end].shape[0], current_batch_size) output_file[split][ constants.CTH5Fields.gpt2_question_ids_inputs][ write_start:end] = gpt2_question_ids_inputs if split != "test": output_file[split][ constants.CTH5Fields.gpt2_answer_ids_inputs][ write_start:end] = gpt2_answer_ids_inputs ###################################################################### # Encode the samples. ###################################################################### batch = strategy.experimental_distribute_values_from_function( tf_utils.make_dict_distribute_fn(batch)) embeddings = encode_fn_strategy_run(batch) embeddings = tf_utils.process_strat_output( embeddings, "embeddings", strategy, current_batch_size) utils.check_isinstance(embeddings, ops.EagerTensor) utils.check_equal(embeddings.shape[0], current_batch_size) # pytype doesn't seem to see that we check the type utils.check_equal(embeddings.shape[1], _FLAG_EMBEDDING_DEPTH.value) # pytype: disable=attribute-error ###################################################################### # Retrieve. 
###################################################################### with tf.device(reference_db_device): top_k, inner_prods = tf_utils.mips_exact_search( embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db) top_k = tf_utils.process_strat_output( top_k, "top_k", strategy, current_batch_size) utils.check_equal( inner_prods.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)) utils.check_equal( top_k.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)) output_file[split]["distances"][ write_start:end] = inner_prods gathered = tf.gather(blocks, top_k).numpy() utils.check_equal(gathered.shape[0], current_batch_size) utils.check_equal(write_start + gathered.shape[0], end) for j in range(gathered.shape[0]): local_gathered = gathered[j].tolist() utils.check_equal(len(local_gathered), _FLAG_NUM_RETRIEVALS.value) local_gathered = [ sample.decode() for sample in local_gathered ] token_ids = np.array( gpt2_tokenizer.batch_encode_plus( local_gathered, padding="max_length", truncation=True, ).input_ids) for line in token_ids: assert not np.all(line == 0), line token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1 output_file[split][ constants.CTH5Fields.gpt2_retrieved_ids][ write_start + j] = token_ids[:, :_FLAG_MAX_LENGTH_RETRIEVALS. value] write_start += current_batch_size ############################################################################ # Upload the results to GCS ############################################################################ LOGGER.debug("DONE WITH THE PRODUCTION") output_file.close() with utils.log_duration(LOGGER, "main", "gsutil transfer"): command = [ "/root/google-cloud-sdk/bin/gsutil", "-m", "cp", "-r", str(tmp_dir / "*"), target_path ] LOGGER.debug("Command: %s", " ".join(command)) subprocess.check_call(command) LOGGER.debug("ALL DONE")
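################################################################################
# Standalone sketch of the HDF5 write pattern above: datasets are pre-allocated
# per split with fixed shapes, then filled in batch-sized slices tracked by
# `write_start`. Shapes, field names and the output path are illustrative.
################################################################################
import h5py
import numpy as np


def h5_incremental_write_sketch(path="/tmp/codes_demo.h5",
                                num_entries=100,
                                num_retrievals=10,
                                context_size=256,
                                batch_size=16):
  with h5py.File(path, "w") as output_file:
    split_group = output_file.create_group("train")
    split_group.create_dataset("distances",
                               shape=(num_entries, num_retrievals),
                               dtype=np.float32)
    split_group.create_dataset("gpt2_question_ids_inputs",
                               shape=(num_entries, context_size),
                               dtype=np.int32)

    write_start = 0
    while write_start < num_entries:
      current_batch_size = min(batch_size, num_entries - write_start)
      end = write_start + current_batch_size
      # In the real script these slices come from the encoder and retriever.
      output_file["train"]["distances"][write_start:end] = np.random.rand(
          current_batch_size, num_retrievals)
      output_file["train"]["gpt2_question_ids_inputs"][write_start:end] = (
          np.zeros((current_batch_size, context_size), dtype=np.int32))
      write_start = end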
def check_tf_tensor(obj):
  utils.check_isinstance(obj, TfTensorTypeTuple)
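################################################################################
# `TfTensorTypeTuple` is defined elsewhere in the project; presumably it is a
# tuple of TensorFlow tensor classes, making the helper a plain isinstance gate.
# Minimal sketch under that assumption (the real tuple may differ).
################################################################################
import tensorflow as tf

TF_TENSOR_TYPES_SKETCH = (tf.Tensor, tf.RaggedTensor, tf.sparse.SparseTensor)


def check_tf_tensor_sketch(obj):
  if not isinstance(obj, TF_TENSOR_TYPES_SKETCH):
    raise TypeError(f"Expected a TF tensor type, got {type(obj)}")


# check_tf_tensor_sketch(tf.constant([1, 2, 3]))             # passes
# check_tf_tensor_sketch(tf.ragged.constant([[1], [2, 3]]))  # passes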
def _prepare_samples_w_retrieval( split, batch_size, question_ids_inputs: tf_utils.TFTensorType, answer_ids_inputs: tf_utils.TFTensorType, gpt2_tokenized_retrieved: tf_utils.TFTensorType, distances, num_retrievals_to_use, temperature, context_size, enable_debug_checks, use_helper_words, helper_word_token_ids, max_generation_length): utils.check_contained(use_helper_words, constants.HelperWordModeChoices.choices()) """Prepares the samples that use retrieval. In regards to helper words, we only use them once. This could be changed. It would have many advantages. """ assert (split == constants.SplitChoices.test) == ( answer_ids_inputs is None), (split == constants.SplitChoices.test, answer_ids_inputs) tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2-xl") # panel_title = "Begining of _prepare_samples_w_retrieval" # panel_text = [f"{question_ids_inputs.shape = }"] # panel_text += [f"{question_ids_inputs.row_lengths(axis=-1) = }"] # panel_text += [f"{answer_ids_inputs.shape = }"] # panel_text += [f"{answer_ids_inputs.row_lengths(axis=-1) = }"] # panel_text += [f"{distances.shape = }"] # panel_text += [f"{gpt2_tokenized_retrieved.shape = }"] # panel_text += [f"{gpt2_tokenized_retrieved.row_lengths(axis=-1) = }"] # print(rich.panel.Panel("\n\n".join(panel_text), title=panel_title)) is_not_test = split != constants.SplitChoices.test if not isinstance(question_ids_inputs, tf.RaggedTensor): question_ids_inputs = tf.RaggedTensor.from_tensor( question_ids_inputs, padding=constants.RAGGED_PADDING_ID) if enable_debug_checks: asserts = [] asserts.append( tf.Assert( tf.math.reduce_all( question_ids_inputs != constants.RAGGED_PADDING_ID, ), [question_ids_inputs.to_tensor()])) if is_not_test: asserts.append( tf.Assert( tf.math.reduce_all( answer_ids_inputs != constants.RAGGED_PADDING_ID, ), [answer_ids_inputs.to_tensor()])) with tf.control_dependencies(asserts): question_ids_inputs = tf.identity(question_ids_inputs) # These checks are at graph composition time, so OK utils.check_isinstance(question_ids_inputs, tf.RaggedTensor) if is_not_test: utils.check_isinstance(answer_ids_inputs, tf.RaggedTensor) ############################################################################## # Sample from the possible retrievals ############################################################################## # Choose the indices selected_context_indices = tf_utils.sample_without_replacement( distances / temperature, num_retrievals_to_use) # Concatenate the retrievals utils.check_isinstance(helper_word_token_ids, dict) utils.check_isinstance( helper_word_token_ids['context'], tuple([np.ndarray] + list(tf_utils.TfTensorTypeTuple))) concat_retrieved = _tokenize_and_concat_while_loop( gpt2_tokenized_retrieved, selected_context_indices=selected_context_indices, batch_size=batch_size, num_retrievals_to_use=num_retrievals_to_use, helper_word_mode=use_helper_words, context_helper_word_tokens=helper_word_token_ids['context'], ) if use_helper_words == constants.HelperWordModeChoices.once: concat_retrieved = tf.concat([ helper_word_token_ids["context"], concat_retrieved, ], axis=1) # _print_info( # concat_retrieved, # f"Num of 'context' helper words. Mode: {use_helper_words}", # tokenizer, # helper_word_token_ids # ) # Cut the lengths down to max_lens_retrieval. # The eventual length of the ["question"] helper_tokens is included in # question_ids_inputs. 
if is_not_test: max_lens_retrieval = ( context_size * tf.ones( shape=(batch_size, ), dtype=tf.int64, ) - ( question_ids_inputs.row_lengths() + # We always generate the same length of text. max_generation_length + # answer_ids_inputs.row_lengths() + (helper_word_token_ids["answer"].shape[1] if use_helper_words else 0))) else: max_lens_retrieval = (context_size * tf.ones( shape=(batch_size, ), dtype=tf.int64, ) - (question_ids_inputs.row_lengths() + max_generation_length + (helper_word_token_ids["answer"].shape[1] if use_helper_words else 0))) concat_retrieved = tf.ragged.boolean_mask( concat_retrieved, (tf.ragged.range(concat_retrieved.row_lengths()) < tf.expand_dims(max_lens_retrieval, axis=1))) panel_text = [] panel_text += [f"{selected_context_indices.shape = }"] panel_text += [f"{concat_retrieved.shape = }"] panel_text += [f"{concat_retrieved.row_lengths(axis=-1) = }"] panel_text += [f"{max_lens_retrieval = }"] print(rich.panel.Panel("\n\n".join(panel_text))) if enable_debug_checks: asserts = [ tf.Assert(tf.math.reduce_all(max_lens_retrieval < context_size), [max_lens_retrieval, context_size]), ] with tf.control_dependencies(asserts): concat_retrieved = tf.identity(concat_retrieved) if use_helper_words: if is_not_test: new_input_ids = tf.concat([ question_ids_inputs, concat_retrieved, helper_word_token_ids["answer"], answer_ids_inputs ], axis=1) new_label_ids = tf.concat([ -100 * tf.ones_like(question_ids_inputs), -100 * tf.ones_like(concat_retrieved), -100 * tf.ones_like(helper_word_token_ids["answer"]), answer_ids_inputs ], axis=1) else: new_input_ids = tf.concat([ question_ids_inputs, concat_retrieved, helper_word_token_ids["answer"], ], axis=1) else: if is_not_test: new_input_ids = tf.concat( [question_ids_inputs, concat_retrieved, answer_ids_inputs], axis=1) new_label_ids = tf.concat([ -100 * tf.ones_like(question_ids_inputs), -100 * tf.ones_like(concat_retrieved), answer_ids_inputs ], axis=1) else: new_input_ids = tf.concat([ question_ids_inputs, concat_retrieved, ], axis=1) new_input_ids: tf.RaggedTensor return new_input_ids, new_label_ids if is_not_test else None
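################################################################################
# Both versions of `_prepare_samples_w_retrieval` build training labels by
# concatenating the same segments as the inputs and overwriting every
# non-answer position with -100, the value that Hugging Face's causal-LM loss
# ignores. Dense-tensor sketch of that convention (the real code operates on
# RaggedTensors).
################################################################################
import tensorflow as tf

IGNORE_INDEX = -100  # Positions with this label are excluded from the LM loss.


def label_masking_sketch():
  question_ids = tf.constant([[11, 12, 13]])
  context_ids = tf.constant([[21, 22]])
  answer_ids = tf.constant([[31, 32, 33, 34]])

  input_ids = tf.concat([question_ids, context_ids, answer_ids], axis=1)
  label_ids = tf.concat([
      IGNORE_INDEX * tf.ones_like(question_ids),
      IGNORE_INDEX * tf.ones_like(context_ids),
      answer_ids,  # Only the answer tokens contribute to the loss.
  ], axis=1)

  print(input_ids.numpy())  # [[  11   12   13   21   22  31  32  33  34]]
  print(label_ids.numpy())  # [[-100 -100 -100 -100 -100  31  32  33  34]]
  return input_ids, label_ids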