def main(argv):
  # Arguments and logging boilerplate
  if len(argv) > 1:
    raise RuntimeError(argv)
  absl_logging.use_python_logging()
  utils.log_module_args(LOGGER, argv[0])

  # Load a retriever config.
  retriever_config = tf_utils.REALMConfig(
      **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value))
  assert not _FLAG_USE_SUBSET.value

  # Preparation of the output path
  time_stamp = time.strftime("%Y%m%d-%H%M%S")
  target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp.strip())
  if target_path[-1] != "/":
    target_path += "/"

  ##############################################################################
  # Setup devices and strategy
  ##############################################################################
  # Duration is pretty much instantaneous
  with utils.log_duration(LOGGER, "main", "Initializing devices"):
    tpu_config = tf_utils.init_tpus(local=_FLAG_TPU_IS_LOCAL.value,
                                    tpu_name=_FLAG_TPU_NAME.value)
    device_type = tf_utils.current_accelerator_type()
    LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use()))
    if _FLAG_TPU_NAME.value and device_type == "CPU":
      raise RuntimeError("Device is CPU and we expected a TPU.")

    if device_type == "TPU":
      if tpu_config is None:
        raise RuntimeError("We should have a tpu_config.")
      strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    elif device_type == "GPU" or device_type == "CPU":
      strategy = tf.distribute.MirroredStrategy()
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    else:
      raise RuntimeError(device_type)

  ##############################################################################
  # Load the KILT ELI5 dataset.
  ##############################################################################
  # Takes a while
  eli5 = {}
  keys = ["train", "validation", "test"]
  gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
  gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

  with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."):
    if _FLAG_DATASET_ROOT.value:
      for split in tqdm.tqdm(keys):
        load_path = os.path.join(_FLAG_DATASET_ROOT.value,
                                 "HuggingfaceDatasets",
                                 f"{split}_kilt_eli5.hf")
        with tf.device("/job:localhost"):
          eli5[split] = datasets.load_from_disk(load_path)
    else:
      eli5 = datasets.load_dataset("kilt_tasks", "eli5")

  ##############################################################################
  # Load the dataset of the text that will be retrieved.
  ##############################################################################
  # Takes a long time
  with utils.log_duration(LOGGER, "Main", "Load the textual dataset"):
    # Extract the appropriate text
    # The buffer_size is taken from the original ORQA code.
    blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records,
                                             buffer_size=512 * 1024 * 1024)
    blocks_dataset = blocks_dataset.batch(retriever_config.num_block_records,
                                          drop_remainder=False)
    blocks: tf.Tensor = tf.data.experimental.get_single_element(blocks_dataset)

  ############################################################################
  # Increase the number of maximum open file descriptors to make space
  # for all the shards.
  ############################################################################
  max_num_fd = _FLAG_NUM_SHARDS.value * 3 + _MIN_N_FD
  resource.setrlimit(resource.RLIMIT_NOFILE, (max_num_fd, max_num_fd))

  ############################################################################
  # Prepare the output files.
  ############################################################################
  writers = {}
  all_paths = {}
  for split in keys:
    maybe_subset = "_subset" if _FLAG_USE_SUBSET.value else ""
    # Prepare paths. They can't be in a generator. A function generator would
    # be fine though.
    paths = [
        os.path.join(target_path + maybe_subset, f"{split}_{i}.tfr")
        for i in range(_FLAG_NUM_SHARDS.value)
    ]
    all_paths[split] = paths
    writers[split] = []

    # Create the TFR writers.
    for i, path in enumerate(paths):
      writers[split].append(tf.io.TFRecordWriter(path))

  # Load the reference DB. We used to accidentally do this once per split :O
  with utils.log_duration(LOGGER, "main", "Loading the reference db."):
    checkpoint_path = os.path.join(retriever_config.query_embedder_path,
                                   "encoded", "encoded.ckpt")
    reference_db_device = tf_utils.device_mapping().CPUs[0].name
    with tf.device(reference_db_device):
      reference_db = tf_utils.load_reference_db(
          checkpoint_path,
          variable_name="block_emb",
      )

  ############################################################################
  # Prep the encoder and the tokenizer
  ############################################################################
  with utils.log_duration(LOGGER, "main",
                          "Loading the encoder model and the tokenizer."):
    with strategy.scope():
      query_encoder = hub.load(retriever_config.query_embedder_path, tags={})
    encode_fn = _make_encode_fn(query_encoder)
    encode_fn_strategy_run = make_encode_fn_strategy_run_fn(
        strategy=strategy,
        encode_fn=encode_fn,
    )

    vocab_file = os.path.join(retriever_config.query_embedder_path, "assets",
                              "vocab.txt")
    utils.check_exists(vocab_file)
    do_lower_case = query_encoder.signatures["tokenization_info"](
    )["do_lower_case"]
    tokenization_info = dict(vocab_file=vocab_file,
                             do_lower_case=do_lower_case)

    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
        query_encoder, tokenization_info)

  ############################################################################
  # Preprocess the dataset
  ############################################################################
  cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")),
                         tf.int32)
  sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")),
                         tf.int32)
  transform = _make_transform_fn(
      bert_tokenizer=tokenizer,
      bert_cls_token_id=cls_token_id,
      bert_sep_token_id=sep_token_id,
  )

  feature_dtypes = {
      constants.CTH5Fields.distances: tf.float32,
      constants.CTH5Fields.gpt2_retrieved_ids: tf.int32,
      constants.CTH5Fields.gpt2_answer_ids_inputs: tf.int32,
      constants.CTH5Fields.gpt2_question_ids_inputs: tf.int32,
  }

  with utils.log_duration(LOGGER, "main", "generating codes"):
    for split in keys:
      sample_count = 0
      eli5: Dict[str, datasets.Dataset]

      if split != "test":
        for_slices = dict(
            sample_id=eli5[split]["id"],
            question=eli5[split]["input"],
            answer=[sample[0]["answer"] for sample in eli5[split]["output"]])
      else:
        for_slices = dict(
            sample_id=eli5[split]["id"],
            question=eli5[split]["input"],
        )

      ds = tf.data.Dataset.from_tensor_slices(for_slices)
      ds = ds.map(transform,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
      ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size))
      ds = ds.map(_squeeze,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)

      tqdm_inner = tqdm.tqdm(
          enumerate(ds),
          total=len(eli5[split]["id"]) // _FLAG_BATCH_SIZE.value,
          desc=f"Split `{split}`: Batches")

      for i, batch in tqdm_inner:
        features = collections.defaultdict(list)

        ######################################################################
        # Enforce the current real batch size
        ######################################################################
        current_batch_size = batch["sample_id"].shape[0]
        for k, v in batch.items():
          utils.check_equal(v.shape[0], current_batch_size)
        ######################################################################

        gpt2_question_ids_inputs = _prep_field(batch["question"],
                                               gpt2_tokenizer)
        utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32)
        utils.check_equal(gpt2_question_ids_inputs.shape[0],
                          current_batch_size)

        if split != "test":
          gpt2_answer_ids_inputs = _prep_field(batch["answer"],
                                               gpt2_tokenizer)
          utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32)
          utils.check_equal(gpt2_answer_ids_inputs.shape[0],
                            current_batch_size)
          assert len(gpt2_answer_ids_inputs.shape) == 2, (
              gpt2_answer_ids_inputs.shape)

        ######################################################################
        # Save the gpt2 tokenized question and answer
        ######################################################################
        features[constants.CTH5Fields.gpt2_question_ids_inputs].extend(
            gpt2_question_ids_inputs)
        if split != "test":
          features[constants.CTH5Fields.gpt2_answer_ids_inputs].extend(
              gpt2_answer_ids_inputs)

        ######################################################################
        # Encode the samples.
        ######################################################################
        batch = strategy.experimental_distribute_values_from_function(
            tf_utils.make_dict_distribute_fn(batch))

        embeddings = encode_fn_strategy_run(batch)
        embeddings = tf_utils.process_strat_output(embeddings, "embeddings",
                                                   strategy,
                                                   current_batch_size)
        utils.check_isinstance(embeddings, ops.EagerTensor)
        utils.check_equal(embeddings.shape[0], current_batch_size)

        # pytype doesn't seem to see that we check the type
        utils.check_equal(embeddings.shape[1], _FLAG_EMBEDDING_DEPTH.value)  # pytype: disable=attribute-error

        ######################################################################
        # Retrieve.
        ######################################################################
        # Do exact retrieval
        with tf.device(reference_db_device):
          top_k, inner_prods = tf_utils.mips_exact_search(
              embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db)

        # Collate the results
        top_k = tf_utils.process_strat_output(top_k, "top_k", strategy,
                                              current_batch_size)

        # Check the shapes
        utils.check_equal(inner_prods.shape,
                          (current_batch_size, _FLAG_NUM_RETRIEVALS.value))
        utils.check_equal(top_k.shape,
                          (current_batch_size, _FLAG_NUM_RETRIEVALS.value))

        # Save the distances
        features[constants.CTH5Fields.distances].extend(inner_prods)

        # Retrieve the text fields associated to the indices gathered
        gathered = tf.gather(blocks, top_k).numpy()
        utils.check_equal(gathered.shape[0], current_batch_size)
        utils.check_equal(gathered.shape[1], _FLAG_NUM_RETRIEVALS.value)

        retrievals = []
        for index_in_batch in range(current_batch_size):
          # Put the appropriate byte strings in a list
          local_gathered = gathered[index_in_batch].tolist()
          utils.check_equal(len(local_gathered), _FLAG_NUM_RETRIEVALS.value)

          # Decode to utf-8
          local_gathered = [sample.decode() for sample in local_gathered]

          # Encode to GPT2 BPE
          token_ids = np.array(
              gpt2_tokenizer.batch_encode_plus(
                  local_gathered,
                  padding="max_length",
                  truncation=True,
              ).input_ids)

          # Make sure no line is empty
          # TODO(julesgm): Maybe optional
          for line in token_ids:
            assert not np.all(line == 0), line

          # Convert the eos_tokens
          token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1

          # Save the retrievals
          retrievals.append(token_ids)

        # Save the feature
        features[constants.CTH5Fields.gpt2_retrieved_ids] = retrievals
        utils.check_equal(
            retrievals[0].shape,
            (_FLAG_NUM_RETRIEVALS.value, _FLAG_CONTEXT_SIZE.value))

        for k, v in features.items():
          utils.check_equal(len(v), current_batch_size)

        for index_in_batch in range(current_batch_size):
          feature_dict = {}
          for feature_k, feature_v in features.items():
            # Cast the feature to its appropriate dtype
            casted_feats = tf.cast(feature_v[index_in_batch],
                                   feature_dtypes[feature_k])
            # Serialize the tensor to bytes
            feature_bytes = tf.io.serialize_tensor(casted_feats)
            # Build a bytes list tf.train.Feature object,
            # the serialization tree node
            feature_dict[feature_k] = _bytes_feature(feature_bytes)

          # Create the serialization tree root
          # Expects a list of features
          feature = tf.train.Features(feature=feature_dict)
          # Expects a tf.train.Features object
          example_obj = tf.train.Example(features=feature)

          # Serialize that to bytes
          serialized_example = example_obj.SerializeToString()

          # Write the bytes
          # TODO(julesgm): Parallelize this with a thread or a process pool &
          # futures.
          writers[split][sample_count % _FLAG_NUM_SHARDS.value].write(
              serialized_example)
          sample_count += 1

        if sample_count % 1000 == 0:
          LOGGER.debug("Paths: %s", str(all_paths[split][0]))

      LOGGER.debug("Flushing and closing the `%s` writers", split)
      for writer in tqdm.tqdm(writers[split]):
        writer.flush()
        writer.close()

  LOGGER.debug("Done.")
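################################################################################
# Illustrative sketch (not part of the original script): reading back one of
# the `.tfr` shards written by the `main` above. Every field was written with
# `tf.io.serialize_tensor`, so each one is recovered with `tf.io.parse_tensor`.
# The `read_cached_shard` name is hypothetical; `field_dtypes` is the same
# field-name-to-dtype mapping as `feature_dtypes` above (for the "test" split,
# omit the answer field, which is not written). Assumes the module-level
# `import tensorflow as tf`.
################################################################################
def read_cached_shard(shard_path, field_dtypes):
  """Yields one dict of tensors per example stored in a shard."""
  spec = {k: tf.io.FixedLenFeature([], tf.string) for k in field_dtypes}
  for raw in tf.data.TFRecordDataset(shard_path):
    parsed = tf.io.parse_single_example(raw, spec)
    # Each field is a serialized tensor; deserialize with its original dtype.
    yield {
        k: tf.io.parse_tensor(parsed[k], out_type=dtype)
        for k, dtype in field_dtypes.items()
    }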
def main(argv):
  if len(argv) > 1:
    raise RuntimeError(argv)
  absl_logging.use_python_logging()
  utils.log_module_args(LOGGER, argv[0])

  retriever_config = tf_utils.REALMSave(
      **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value))
  assert not _FLAG_USE_SUBSET.value

  time_stamp = time.strftime("%Y%m%d-%H%M%S")
  target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp.strip())
  if target_path[-1] != "/":
    target_path += "/"

  ##############################################################################
  # Setup devices and strategy
  ##############################################################################
  with utils.log_duration(LOGGER, "main", "Initializing devices"):
    tpu_config = tf_utils.init_tpus()
    device_type = tf_utils.current_accelerator_type()
    LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use()))

    if device_type == "TPU":
      if tpu_config is None:
        raise RuntimeError("We should have a tpu_config.")
      strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    elif device_type == "GPU" or device_type == "CPU":
      strategy = tf.distribute.MirroredStrategy()
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    else:
      raise RuntimeError(device_type)

  ##############################################################################
  # Load the dataset.
  ##############################################################################
  eli5 = {}
  keys = ["train", "eval", "test"]
  gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
  gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

  with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."):
    for split in tqdm.tqdm(keys):
      load_path = os.path.join(_FLAG_DATASET_ROOT.value,
                               "HuggingfaceDatasets",
                               f"{split}_kilt_eli5.hf")
      with tf.device("/job:localhost"):
        eli5[split] = datasets.load_from_disk(load_path)

  ##############################################################################
  # Load the dataset of the text that will be retrieved.
  ##############################################################################
  with utils.log_duration(LOGGER, "Main", "Load the textual dataset"):
    # Extract the appropriate text
    # The buffer_size is taken from the original ORQA code.
    blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records,
                                             buffer_size=512 * 1024 * 1024)
    blocks_dataset = blocks_dataset.batch(retriever_config.num_block_records,
                                          drop_remainder=True)
    blocks = tf.data.experimental.get_single_element(blocks_dataset)

  ############################################################################
  # Prepare the output file.
  ############################################################################
  writers = {}
  all_paths = {}
  for split in keys:
    maybe_subset = "_subset" if _FLAG_USE_SUBSET.value else ""
    paths = [
        os.path.join(target_path + maybe_subset, f"{split}_{i}.tfr")
        for i in range(_FLAG_NUM_SHARDS.value)
    ]
    all_paths[split] = paths
    writers[split] = [tf.io.TFRecordWriter(filename) for filename in paths]

  with utils.log_duration(LOGGER, "main", "Loading the reference db."):
    checkpoint_path = os.path.join(retriever_config.query_embedder_path,
                                   "encoded", "encoded.ckpt")
    reference_db_device = tf_utils.device_mapping().CPUs[0].name
    with tf.device(reference_db_device):
      reference_db = tf_utils.load_reference_db(
          checkpoint_path,
          variable_name="block_emb",
      )

  ############################################################################
  # Prep the encoder and the tokenizer
  ############################################################################
  with utils.log_duration(LOGGER, "main",
                          "Loading the encoder model and the tokenizer."):
    with strategy.scope():
      query_encoder = hub.load(retriever_config.query_embedder_path, tags={})
    encode_fn = _make_encode_fn(query_encoder)
    encode_fn_strategy_run = make_encode_fn_strategy_run_fn(
        strategy=strategy,
        encode_fn=encode_fn,
    )

    vocab_file = os.path.join(retriever_config.query_embedder_path, "assets",
                              "vocab.txt")
    utils.check_exists(vocab_file)
    do_lower_case = query_encoder.signatures["tokenization_info"](
    )["do_lower_case"]
    tokenization_info = dict(vocab_file=vocab_file,
                             do_lower_case=do_lower_case)

    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
        query_encoder, tokenization_info)

  ############################################################################
  # Preprocess the dataset
  ############################################################################
  cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")),
                         tf.int32)
  sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")),
                         tf.int32)
  transform = _make_transform_fn(
      bert_tokenizer=tokenizer,
      bert_cls_token_id=cls_token_id,
      bert_sep_token_id=sep_token_id,
  )

  feature_dtypes = {
      constants.CTH5Fields.distances: tf.float32,
      constants.CTH5Fields.gpt2_retrieved_ids: tf.int32,
      constants.CTH5Fields.gpt2_answer_ids_inputs: tf.int32,
      constants.CTH5Fields.gpt2_question_ids_inputs: tf.int32,
  }

  with utils.log_duration(LOGGER, "main", "generating codes"):
    for split in keys:
      sample_count = 0
      eli5: Dict[str, datasets.Dataset]

      if split != "test":
        for_slices = dict(
            sample_id=eli5[split]["id"],
            question=eli5[split]["input"],
            answer=[sample["answer"][0] for sample in eli5[split]["output"]])
      else:
        for_slices = dict(
            sample_id=eli5[split]["id"],
            question=eli5[split]["input"],
        )

      ds = tf.data.Dataset.from_tensor_slices(for_slices)
      ds = ds.map(transform,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
      ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size))
      ds = ds.map(_squeeze,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)

      tqdm_inner = tqdm.tqdm(
          enumerate(ds),
          total=len(eli5[split]["id"]) // _FLAG_BATCH_SIZE.value,
          desc=f"Split `{split}`: Batches")

      for i, batch in tqdm_inner:
        features = collections.defaultdict(list)

        ######################################################################
        # Enforce the current real batch size
        ######################################################################
        current_batch_size = batch["sample_id"].shape[0]
        for k, v in batch.items():
          utils.check_equal(v.shape[0], current_batch_size)
        ######################################################################

        gpt2_question_ids_inputs = _prep_field(batch["question"],
                                               gpt2_tokenizer)
        utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32)
        utils.check_equal(gpt2_question_ids_inputs.shape[0],
                          current_batch_size)

        if split != "test":
          gpt2_answer_ids_inputs = _prep_field(batch["answer"],
                                               gpt2_tokenizer)
          utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32)
          utils.check_equal(gpt2_answer_ids_inputs.shape[0],
                            current_batch_size)
          assert len(gpt2_answer_ids_inputs.shape) == 2, (
              gpt2_answer_ids_inputs.shape)

        ######################################################################
        # Save the gpt2 tokenized question and answer
        ######################################################################
        features[constants.CTH5Fields.gpt2_question_ids_inputs].extend(
            gpt2_question_ids_inputs)
        if split != "test":
          features[constants.CTH5Fields.gpt2_answer_ids_inputs].extend(
              gpt2_answer_ids_inputs)

        ######################################################################
        # Encode the samples.
        ######################################################################
        batch = strategy.experimental_distribute_values_from_function(
            tf_utils.make_dict_distribute_fn(batch))

        embeddings = encode_fn_strategy_run(batch)
        embeddings = tf_utils.process_strat_output(embeddings, "embeddings",
                                                   strategy,
                                                   current_batch_size)
        utils.check_isinstance(embeddings, ops.EagerTensor)
        utils.check_equal(embeddings.shape[0], current_batch_size)

        # pytype doesn't seem to see that we check the type
        utils.check_equal(embeddings.shape[1], _FLAG_EMBEDDING_DEPTH.value)  # pytype: disable=attribute-error

        ######################################################################
        # Retrieve.
        ######################################################################
        with tf.device(reference_db_device):
          top_k, inner_prods = tf_utils.mips_exact_search(
              embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db)
        top_k = tf_utils.process_strat_output(top_k, "top_k", strategy,
                                              current_batch_size)

        utils.check_equal(inner_prods.shape,
                          (current_batch_size, _FLAG_NUM_RETRIEVALS.value))
        utils.check_equal(top_k.shape,
                          (current_batch_size, _FLAG_NUM_RETRIEVALS.value))

        features[constants.CTH5Fields.distances].extend(inner_prods)

        gathered = tf.gather(blocks, top_k).numpy()
        utils.check_equal(gathered.shape[0], current_batch_size)

        retrievals = []
        for j in range(gathered.shape[0]):
          local_gathered = gathered[j].tolist()
          utils.check_equal(len(local_gathered), _FLAG_NUM_RETRIEVALS.value)
          local_gathered = [sample.decode() for sample in local_gathered]
          token_ids = np.array(
              gpt2_tokenizer.batch_encode_plus(
                  local_gathered,
                  padding="max_length",
                  truncation=True,
              ).input_ids)
          for line in token_ids:
            assert not np.all(line == 0), line
          token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1
          retrievals.append(token_ids)

        features[constants.CTH5Fields.gpt2_retrieved_ids] = retrievals
        utils.check_equal(
            retrievals[0].shape,
            (_FLAG_NUM_RETRIEVALS.value, _FLAG_CONTEXT_SIZE.value))

        for k, v in features.items():
          utils.check_equal(len(v), current_batch_size)

        for index_in_batch in range(current_batch_size):
          feature = tf.train.Features(
              feature={
                  k: _bytes_feature(
                      tf.io.serialize_tensor(
                          tf.cast(v[index_in_batch], feature_dtypes[k])))
                  for k, v in features.items()
              })
          writers[split][sample_count % _FLAG_NUM_SHARDS.value].write(
              tf.train.Example(features=feature).SerializeToString())
          sample_count += 1

        if sample_count % 1000 == 0:
          LOGGER.debug("Paths: %s", str(all_paths[split][0]))

  LOGGER.debug("Done.")
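################################################################################
# Illustrative sketch (not part of the original script): the `_bytes_feature`
# helper used above is defined elsewhere in the module. A plausible definition,
# following the standard tf.train.Example pattern for wrapping a serialized
# string tensor in a bytes-list feature, would be:
################################################################################
def _bytes_feature(value):
  """Returns a bytes_list tf.train.Feature from a string / bytes value."""
  if isinstance(value, type(tf.constant(0))):
    # BytesList won't unpack a string from an EagerTensor.
    value = value.numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))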
def main(argv):
  if len(argv) > 1:
    raise RuntimeError(argv)
  absl_logging.use_python_logging()

  retriever_config = tf_utils.REALMSave(
      **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value))

  extra = "_FROM_SUBSET" if _FLAG_USE_SUBSET.value else ""
  time_stamp = time.strftime("%Y%m%d-%H%M%S")
  target_path = os.path.join(_FLAG_OUTPUT_PATH.value,
                             time_stamp + extra).strip()
  if target_path[-1] != "/":
    target_path += "/"

  ##############################################################################
  # Setup devices and strategy
  ##############################################################################
  with utils.log_duration(LOGGER, "main", "Initializing devices"):
    tpu_config = tf_utils.init_tpus()
    device_type = tf_utils.current_accelerator_type()
    LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use()))

    if device_type == "TPU":
      if tpu_config is None:
        raise RuntimeError("We should have a tpu_config.")
      strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    elif device_type == "GPU" or device_type == "CPU":
      strategy = tf.distribute.MirroredStrategy()
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    else:
      raise RuntimeError(device_type)

  ##############################################################################
  # Load the dataset.
  ##############################################################################
  eli5 = {}
  keys = ["train", "eval", "test"]
  gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
  gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

  with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."):
    for split in tqdm.tqdm(keys):
      load_path = os.path.join(_FLAG_DATASET_ROOT.value,
                               "HuggingfaceDatasets",
                               f"{split}_kilt_eli5.hf")
      with tf.device("/job:localhost"):
        eli5[split] = datasets.load_from_disk(load_path)

  if _FLAG_USE_SUBSET.value:
    _warn_subset()

  ##############################################################################
  # Load the dataset of the text that will be retrieved.
  ##############################################################################
  with utils.log_duration(LOGGER, "Main", "Load the textual dataset"):
    # Extract the appropriate text
    # The buffer_size is taken from the original ORQA code.
    blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records,
                                             buffer_size=512 * 1024 * 1024)
    blocks_dataset = blocks_dataset.batch(retriever_config.num_block_records,
                                          drop_remainder=True)
    blocks = tf.data.experimental.get_single_element(blocks_dataset)

  with tempfile.TemporaryDirectory() as tmp_dir:
    ############################################################################
    # Prepare the output file.
    ############################################################################
    tmp_dir = pathlib.Path(tmp_dir)
    h5_output_path = tmp_dir / "codes.h5"
    output_file = h5py.File(h5_output_path, "w")
    flags_dict = {
        flag.name: flag.value
        for flag in flags.FLAGS.flags_by_module_dict()[argv[0]]
    }
    utils.to_json_file(tmp_dir / "params.json", flags_dict)

    for split in keys:
      with utils.log_duration(LOGGER, "main",
                              "Creating the output hdf5 file, embeddings."):
        num_entries = len(eli5[split]["id"])
        if _FLAG_USE_SUBSET.value:
          num_entries = min(num_entries, _FLAG_SUBSET_AMOUNT.value)
        split_group = output_file.create_group(split)

      with utils.log_duration(LOGGER, "main",
                              "Creating the output hdf5 file, retrieval."):
        split_group.create_dataset(
            constants.CTH5Fields.distances,
            shape=(num_entries, _FLAG_NUM_RETRIEVALS.value),
            dtype=np.float32,
        )
        split_group.create_dataset(
            constants.CTH5Fields.gpt2_question_ids_inputs,
            shape=(num_entries, _FLAG_CONTEXT_SIZE.value),
            dtype=np.int32)
        if split != "test":
          split_group.create_dataset(
              constants.CTH5Fields.gpt2_answer_ids_inputs,
              shape=(num_entries, _FLAG_CONTEXT_SIZE.value),
              dtype=np.int32)
        split_group.create_dataset(
            constants.CTH5Fields.gpt2_retrieved_ids,
            shape=(
                num_entries,
                _FLAG_NUM_RETRIEVALS.value,
                _FLAG_MAX_LENGTH_RETRIEVALS.value,
            ),
            dtype=np.int32)

    with utils.log_duration(LOGGER, "main", "Loading the reference db."):
      checkpoint_path = os.path.join(retriever_config.query_embedder_path,
                                     "encoded", "encoded.ckpt")
      reference_db_device = tf_utils.device_mapping().CPUs[0].name
      with tf.device(reference_db_device):
        reference_db = tf_utils.load_reference_db(
            checkpoint_path,
            variable_name="block_emb",
        )

    ############################################################################
    # Prep the encoder and the tokenizer
    ############################################################################
    with utils.log_duration(LOGGER, "main",
                            "Loading the encoder model and the tokenizer."):
      with strategy.scope():
        query_encoder = hub.load(retriever_config.query_embedder_path, tags={})
      encode_fn = _make_encode_fn(query_encoder)
      encode_fn_strategy_run = _make_encode_fn_strategy_run_fn(
          strategy=strategy,
          encode_fn=encode_fn,
      )

      vocab_file = os.path.join(retriever_config.query_embedder_path, "assets",
                                "vocab.txt")
      utils.check_exists(vocab_file)
      do_lower_case = query_encoder.signatures["tokenization_info"](
      )["do_lower_case"]
      tokenization_info = dict(vocab_file=vocab_file,
                               do_lower_case=do_lower_case)

      tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
          query_encoder, tokenization_info)

    ############################################################################
    # Preprocess the dataset
    ############################################################################
    cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")),
                           tf.int32)
    sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")),
                           tf.int32)
    transform = _make_transform_fn(
        bert_tokenizer=tokenizer,
        bert_cls_token_id=cls_token_id,
        bert_sep_token_id=sep_token_id,
    )

    with utils.log_duration(LOGGER, "main", "generating codes"):
      tqdm_splits = tqdm.tqdm(keys)
      for split in tqdm_splits:
        tqdm_splits.set_description(f"Split `{split}`")
        eli5: Dict[str, datasets.Dataset]
        write_start = 0

        if _FLAG_USE_SUBSET.value:
          _warn_subset(tqdm_splits)
          eli5[split] = eli5[split][:_FLAG_SUBSET_AMOUNT.value]
          utils.check_operator(operator.le, len(eli5[split]["id"]),
                               _FLAG_SUBSET_AMOUNT.value)
          utils.check_operator(operator.le, len(eli5[split]["input"]),
                               _FLAG_SUBSET_AMOUNT.value)
        else:
          utils.check_equal(len(eli5[split]), len(eli5[split]["id"]))
          utils.check_equal(len(eli5[split]), len(eli5[split]["input"]))

        if split != "test":
          for_slices = dict(
              sample_id=eli5[split]["id"],
              question=eli5[split]["input"],
              answer=[sample["answer"][0] for sample in eli5[split]["output"]])
        else:
          for_slices = dict(
              sample_id=eli5[split]["id"],
              question=eli5[split]["input"],
          )

        ds = tf.data.Dataset.from_tensor_slices(for_slices)
        ds = ds.map(transform,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size))
        ds = ds.map(_squeeze,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)

        tqdm_inner = tqdm.tqdm(
            enumerate(ds),
            total=len(eli5[split]["id"]) // _FLAG_BATCH_SIZE.value,
            desc=f"Split `{split}`: Batches")

        for i, batch in tqdm_inner:
          ######################################################################
          # Enforce the current real batch size
          ######################################################################
          current_batch_size = batch["sample_id"].shape[0]
          for k, v in batch.items():
            utils.check_equal(v.shape[0], current_batch_size)
          ######################################################################

          gpt2_question_ids_inputs = _prep_field(batch["question"],
                                                 gpt2_tokenizer)
          utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32)
          utils.check_equal(gpt2_question_ids_inputs.shape[0],
                            current_batch_size)

          if split != "test":
            gpt2_answer_ids_inputs = _prep_field(batch["answer"],
                                                 gpt2_tokenizer)
            utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32)
            utils.check_equal(gpt2_answer_ids_inputs.shape[0],
                              current_batch_size)
            assert len(gpt2_answer_ids_inputs.shape) == 2, (
                gpt2_answer_ids_inputs.shape)

          ######################################################################
          # Save the gpt2 tokenized question and answer
          ######################################################################
          end = write_start + current_batch_size
          utils.check_equal(
              output_file[split][
                  constants.CTH5Fields.gpt2_question_ids_inputs]
              [write_start:end].shape[0], current_batch_size)
          output_file[split][constants.CTH5Fields.gpt2_question_ids_inputs][
              write_start:end] = gpt2_question_ids_inputs
          if split != "test":
            output_file[split][constants.CTH5Fields.gpt2_answer_ids_inputs][
                write_start:end] = gpt2_answer_ids_inputs

          ######################################################################
          # Encode the samples.
          ######################################################################
          batch = strategy.experimental_distribute_values_from_function(
              tf_utils.make_dict_distribute_fn(batch))

          embeddings = encode_fn_strategy_run(batch)
          embeddings = tf_utils.process_strat_output(embeddings, "embeddings",
                                                     strategy,
                                                     current_batch_size)
          utils.check_isinstance(embeddings, ops.EagerTensor)
          utils.check_equal(embeddings.shape[0], current_batch_size)

          # pytype doesn't seem to see that we check the type
          utils.check_equal(embeddings.shape[1], _FLAG_EMBEDDING_DEPTH.value)  # pytype: disable=attribute-error

          ######################################################################
          # Retrieve.
          ######################################################################
          with tf.device(reference_db_device):
            top_k, inner_prods = tf_utils.mips_exact_search(
                embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db)
          top_k = tf_utils.process_strat_output(top_k, "top_k", strategy,
                                                current_batch_size)

          utils.check_equal(inner_prods.shape,
                            (current_batch_size, _FLAG_NUM_RETRIEVALS.value))
          utils.check_equal(top_k.shape,
                            (current_batch_size, _FLAG_NUM_RETRIEVALS.value))

          output_file[split]["distances"][write_start:end] = inner_prods

          gathered = tf.gather(blocks, top_k).numpy()
          utils.check_equal(gathered.shape[0], current_batch_size)
          utils.check_equal(write_start + gathered.shape[0], end)

          for j in range(gathered.shape[0]):
            local_gathered = gathered[j].tolist()
            utils.check_equal(len(local_gathered), _FLAG_NUM_RETRIEVALS.value)
            local_gathered = [sample.decode() for sample in local_gathered]
            token_ids = np.array(
                gpt2_tokenizer.batch_encode_plus(
                    local_gathered,
                    padding="max_length",
                    truncation=True,
                ).input_ids)
            for line in token_ids:
              assert not np.all(line == 0), line
            token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1
            output_file[split][constants.CTH5Fields.gpt2_retrieved_ids][
                write_start + j] = token_ids[
                    :, :_FLAG_MAX_LENGTH_RETRIEVALS.value]

          write_start += current_batch_size

    ############################################################################
    # Upload the results to GCS
    ############################################################################
    LOGGER.debug("DONE WITH THE PRODUCTION")
    output_file.close()

    with utils.log_duration(LOGGER, "main", "gsutil transfer"):
      command = [
          "/root/google-cloud-sdk/bin/gsutil", "-m", "cp", "-r",
          str(tmp_dir / "*"), target_path
      ]
      LOGGER.debug("Command: %s", " ".join(command))
      subprocess.check_call(command)

    LOGGER.debug("ALL DONE")
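################################################################################
# Illustrative sketch (not part of the original script): loading the cached
# retrievals back from the `codes.h5` file produced above, after it has been
# copied to its destination. Dataset names under each split group mirror
# `constants.CTH5Fields`; the function name `load_cached_split` is
# hypothetical. Assumes the module-level `import h5py`.
################################################################################
def load_cached_split(h5_path, split):
  """Returns every cached array for one split as an in-memory numpy dict."""
  with h5py.File(h5_path, "r") as h5_file:
    group = h5_file[split]
    # `[:]` materializes each HDF5 dataset as a numpy array.
    return {name: group[name][:] for name in group.keys()}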