示例#1
0
def main(argv):
    """Generates KILT ELI5 answers with a GPT-2 model and saves them as JSON.

    Exactly one of ``_FLAG_H5_MODEL_PATH`` / ``_FLAG_CKPT_MODEL_PATH`` must be
    set; it selects where the model is loaded from and the save mode passed to
    ``make_model_tf``. Only GPU execution is supported. Every input and its
    generation are pretty-printed with ``rich``, and all generations, together
    with the flags registered under ``argv[0]`` in
    ``flags.FLAGS.flags_by_module_dict()``, are written to
    ``<_FLAG_OUTPUT_PATH>/<split>/<approach_type>/<timestamp>.json``.

    Args:
      argv: Command line arguments (as passed by ``absl.app``). Only the
        program name is accepted.

    Raises:
      RuntimeError: If extra command line arguments are given, or if neither
        model path flag is set (the xor check is expected to catch this first).
    """
    if len(argv) > 1:
        raise RuntimeError(argv[1:])
    absl_logging.use_python_logging()
    utils.check_contained(_FLAG_APPROACH_TYPE.value, _ACCEPTABLE_APPROACHES)

    # Exactly one of the two model path flags must be provided.
    utils.check_operator(operator.xor, bool(_FLAG_H5_MODEL_PATH.value),
                         bool(_FLAG_CKPT_MODEL_PATH.value))

    if _FLAG_H5_MODEL_PATH.value:
        model_path = _FLAG_H5_MODEL_PATH.value
        mode = constants.SaveModeChoices.hfh5
    elif _FLAG_CKPT_MODEL_PATH.value:
        model_path = _FLAG_CKPT_MODEL_PATH.value
        mode = constants.SaveModeChoices.ckpt
    else:
        raise RuntimeError("Logically should never happen.")

    utils.check_exists(model_path)
    device_type = tf_utils.devices_to_use()[0].device_type

    # ONLY GPU IS SUPPORTED
    utils.check_equal(device_type, "GPU")

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Build the distribution strategy
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # The TPU branch is unreachable while the GPU-only check above is in place;
    # it is kept for when TPU support is re-enabled.
    if device_type == "TPU":
        # ONLY LOCAL TPU IS "SUPPORTED"
        utils.check_isinstance(_FLAG_IS_LOCAL_TPU.value, bool)
        assert _FLAG_IS_LOCAL_TPU.value
        tpu_config = tf_utils.init_tpus(local=True)
        utils.check_isinstance(tpu_config, tf_utils.TpuConfigType)
        utils.check_not_none(tpu_config)
        strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
    elif device_type == "GPU":
        strategy = tf.distribute.MirroredStrategy(
            devices=tf.config.experimental.list_logical_devices('GPU'))
    else:
        raise RuntimeError(device_type)

    # ONLY GPU IS SUPPORTED
    print(tf.config.list_logical_devices())
    utils.check_isinstance(strategy, tf.distribute.MirroredStrategy)

    ##############################################################################
    # Load Model
    ##############################################################################
    with utils.log_duration(LOGGER, main.__name__, "All of model preparation"):
        with strategy.scope():
            # HF isn't able to read directly from GCS, so HF .h5 models stored
            # on GCS are first copied to a local temporary directory.
            if (model_path.startswith("gs://")
                    and mode == constants.SaveModeChoices.hfh5):
                with utils.log_duration(LOGGER, main.__name__,
                                        "Download model from GS"):
                    with tempfile.TemporaryDirectory() as td:
                        td += os.path.sep

                        if os.path.exists("/root/google-cloud-sdk/bin/gsutil"):
                            exec_ = "/root/google-cloud-sdk/bin/gsutil"
                        else:
                            exec_ = "gsutil"

                        command = [
                            exec_,
                            "-m",
                            "cp",
                            "-r",
                            os.path.join(model_path, "*"),
                            td,
                        ]
                        LOGGER.debug("Running bash command: %s",
                                     " ".join(command))
                        subprocess.check_call(command)
                        LOGGER.debug("Files at the temp dir(%s): %s", td,
                                     str(os.listdir(td)))

                        # Must load while the temporary directory still exists.
                        model = make_model_tf(td, mode=mode)
            else:
                model = make_model_tf(model_path, mode=mode)

    utils.check_not_none(model)

    ##############################################################################
    # Load Dataset Pipeline
    ##############################################################################
    utils.check_contained(
        _FLAG_APPROACH_TYPE.value, {
            constants.ApproachTypeChoices.naked_lm,
            constants.ApproachTypeChoices.cached_pretok
        })
    devices = tf_utils.devices_to_use()
    utils.check_equal(devices[0].device_type, "GPU")

    # The dataset is not re-batched in this function: batches come as produced
    # by `prep_ds_for_generation`, and the real size of each batch is read from
    # the batch itself in the generation loop below.
    approach_type = _FLAG_APPROACH_TYPE.value

    logging.debug("Loading dataset.")
    tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
    ds = prep_ds_for_generation(
        dict(
            tokenizer=tokenizer,
            context_window_size=1024,
            dataset_name="kilt_eli5",
            batch_size=1,  # >> We set our own batch size elsewhere
            db_path=None,  # None,
            random_seed=0,
            use_subset=False,
            subset_size=-1,
            use_helper_words=True,
            approach_type=approach_type,
            num_retrievals=5,  # Will never change
            retrieval_temperature=1.,
            retriever=None,  # Cached retrievals don't need a retriever
            repeat=False,  # Will never change
            split=_FLAG_SPLIT.value,
            enable_debug_checks=False,
            retrieval_bank_size=5,  # Will never change
            dataset_type=_FLAG_DATASET_TYPE.value,
            tfr_prefix=_FLAG_TFR_PREFIX.value,
            qty_shuffle=1,  # Will never change
            max_length_generation=350),
        tokenizer,
        _FLAG_SPLIT.value)

    ds = strategy.experimental_distribute_dataset(ds)

    ##############################################################################
    # Generate
    ##############################################################################
    LOGGER.debug("Generating.")
    generations = []
    num_entries_in_split = (
        task_specific.DATASET_CARDINALITIES["kilt_eli5"][_FLAG_SPLIT.value])

    # Counts samples (not batches); advanced manually by each batch's size.
    entries_counter = tqdm.tqdm(total=num_entries_in_split)

    # The display helpers do not depend on the batch; build them once.
    rich_console = rich.console.Console(color_system="256")
    print_sample = make_print_sample()

    for batch_no, batch in enumerate(ds):
        # Use the real size of this batch: a batch (e.g. the last one) may
        # hold fewer rows than the configured batch size.
        current_batch_size = batch.shape[0]

        # Calling model.generate. We should make a config file with the
        # hyperparameters for generation, or make a facility in the one we already
        # have. I feel like a separate one would be better, separating concerns.
        output = strategy.run(
            model.generate,
            kwargs=dict(
                input_ids=batch,
                max_length=_FLAG_GENERATION_LENGTH_LIMIT.value,
                use_cache=True,
                # Attend to every token that is not the EOS/padding id.
                attention_mask=tf.cast(batch != tokenizer.eos_token_id,
                                       tf.int32),
                repetition_penalty=2.,
                num_beams=5,
            ))
        output = tf_utils.process_strat_output(
            strategy_outputs=output,
            current_batch_size=current_batch_size,
            strategy=strategy,
            name="generations")

        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Display the inputs and outputs.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Convert to NumPy once per batch rather than once per element.
        input_ids = batch.numpy()
        output_ids = output.numpy()

        with utils.log_duration(LOGGER, "main",
                                "all of tokenizer.decode for a batch."):
            for i in range(current_batch_size):
                input_text = tokenizer.decode(input_ids[i])
                output_text = tokenizer.decode(output_ids[i])
                print("#" * 1000)
                print(f"Batch {batch_no} Generation {i}")
                print_sample(input_text, f"input batch_no {batch_no}",
                             rich_console)
                print_sample(output_text, f"output batch_no {batch_no}",
                             rich_console)
                generations.append(output_text)
            print("#" * 1000)
        entries_counter.update(current_batch_size)
    entries_counter.close()

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Save the output to a JSON File.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    output_file = os.path.join(_FLAG_OUTPUT_PATH.value, _FLAG_SPLIT.value,
                               _FLAG_APPROACH_TYPE.value,
                               time.strftime("%Y%m%d-%H%M%S.json"))
    utils.to_json_file(
        output_file,
        dict(flags={
            flag.name: flag.value
            for flag in flags.FLAGS.flags_by_module_dict()[argv[0]]
        },
             generations=generations))
    logging.debug("Saved to: %s", output_file)
示例#2
0
def main(argv):
    """Generates KILT ELI5 answers with a GPT-2 model and saves them as JSON.

    The model is loaded from ``_FLAG_MODEL_PATH``. This can be:
      * an existing local directory holding ``config.json`` and
        ``tf_model.h5``;
      * a ``gs://`` directory with the same files, which is first copied
        locally with ``gsutil``;
      * any other identifier passed straight to
        ``TFGPT2LMHeadModel.from_pretrained``.

    Prompts are built from the dataset's ``input_ids`` with EOS ids removed.
    For non-test splits, only positions whose ``label_ids`` are -100 are kept.
    Prompts are batched with EOS padding and generated from, and all
    generations are written, together with the flags registered under
    ``argv[0]``, to
    ``<_FLAG_OUTPUT_PATH>/<split>/<approach_type>/<timestamp>.json``.

    Args:
      argv: Command line arguments (as passed by ``absl.app``). Only the
        program name is accepted.

    Raises:
      RuntimeError: If extra command line arguments are given or the device
        type is not TPU, GPU or CPU.
    """
    if len(argv) > 1:
        raise RuntimeError(argv[1:])
    absl_logging.use_python_logging()
    utils.check_contained(_FLAG_APPROACH_TYPE.value, _ACCEPTABLE_APPROACHES)
    db_path = _FLAG_DB_PATH.value
    model_path = _FLAG_MODEL_PATH.value
    tpu_config = tf_utils.init_tpus()
    device_type = tf_utils.devices_to_use()[0].device_type
    if device_type == "TPU":
        assert isinstance(tpu_config, tf_utils.TpuConfigType)
        strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
    elif device_type in ("GPU", "CPU"):
        # With a single device (one GPU, or CPUs only) MirroredStrategy simply
        # runs one replica.
        strategy = tf.distribute.MirroredStrategy()
    else:
        raise RuntimeError(device_type)

    ##############################################################################
    # Load Model
    ##############################################################################
    with utils.log_duration(LOGGER, main.__name__, "All of model preparation"):

        def make_model_tf(path):
            """Loads a TFGPT2LMHeadModel from a local dir or a HF identifier."""
            with utils.log_duration(LOGGER, make_model_tf.__name__,
                                    "Load model."):
                if os.path.exists(path):
                    config_path = os.path.join(path, "config.json")
                    weights_path = os.path.join(path, "tf_model.h5")
                    utils.check_exists(config_path)
                    utils.check_exists(weights_path)
                    config = transformers.GPT2Config.from_pretrained(
                        config_path)
                    return transformers.TFGPT2LMHeadModel.from_pretrained(
                        weights_path, config=config)
                else:
                    return transformers.TFGPT2LMHeadModel.from_pretrained(
                        path, )

        with strategy.scope():
            if model_path.startswith("gs://"):
                with utils.log_duration(LOGGER, main.__name__,
                                        "Download model from GS"):
                    with tempfile.TemporaryDirectory() as td:
                        td += os.path.sep

                        if os.path.exists("/root/google-cloud-sdk/bin/gsutil"):
                            exec_ = "/root/google-cloud-sdk/bin/gsutil"
                        else:
                            exec_ = "gsutil"

                        command = [
                            exec_,
                            "-m",
                            "cp",
                            "-r",
                            os.path.join(model_path, "*"),
                            td,
                        ]
                        LOGGER.debug("Running bash command: %s",
                                     " ".join(command))
                        subprocess.check_call(command)
                        LOGGER.debug("Files at the temp dir(%s): %s", td,
                                     str(os.listdir(td)))

                        # Must load while the temporary directory still exists.
                        model = make_model_tf(td)
            else:
                model = make_model_tf(model_path)

            # NOTE(review): Python resolves implicit calls such as `model(...)`
            # through `type(model).__call__`, so this instance attribute
            # probably does not change how `model.generate` invokes the model.
            # Confirm whether the compiled call is actually in effect.
            model.__call__ = tf.function(
                model.__call__,
                experimental_relax_shapes=True,
                experimental_compile=True,
            )

    ##############################################################################
    # Load Dataset Pipeline
    ##############################################################################

    # NOTE(review): only `naked_lm` is accepted here; confirm whether a second
    # approach (e.g. `cached_pretok`) should also be allowed.
    utils.check_contained(_FLAG_APPROACH_TYPE.value,
                          {constants.ApproachTypeChoices.naked_lm})
    devices = tf_utils.devices_to_use()
    num_replicas = (len(devices)
                    if devices[0].device_type in {"GPU", "TPU"} else 1)
    # Only a batch size of 1 is currently supported. We need attention masks
    utils.check_equal(_FLAG_BATCH_SIZE.value, 1)
    batch_size = _FLAG_BATCH_SIZE.value * num_replicas
    approach_type = _FLAG_APPROACH_TYPE.value

    # Things that will never change:
    random_seed = 0
    use_helper_words = True
    retrieval_temperature = 1
    context_window_size = 1024

    logging.debug("Loading dataset.")
    tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
    ds = task_specific.create_lm_ds_kilt_eli5(
        tokenizer=tokenizer,
        context_window_size=context_window_size,
        dataset_name="kilt_eli5",
        batch_size=1,  # >> We set our own batch size elsewhere
        db_path=db_path,
        random_seed=random_seed,
        use_subset=False,
        subset_size=-1,
        use_helper_words=use_helper_words,
        approach_type=approach_type,
        num_retrievals=5,  # Will never change
        retrieval_temperature=retrieval_temperature,
        retriever=None,  # Cached retrievals don't need a retriever
        repeat=False,  # Will never change
        split=_FLAG_SPLIT.value,
        enable_debug_checks=False,
        retrieval_bank_size=5,  # Will never change
        dataset_type=_FLAG_DATASET_TYPE.value,
        tfr_prefix=_FLAG_TFR_PREFIX.value,
        qty_shuffle=1,  # Will never change
        max_length_generation=_FLAG_GENERATION_LENGTH_LIMIT.value)

    def further_prep_generate_not_test(batch):
        # Keep only the prompt: positions without a label (-100) that are not
        # EOS.
        batch = tf.boolean_mask(
            batch["input_ids"],
            tf.logical_and(batch["label_ids"] == -100,
                           batch["input_ids"] != tokenizer.eos_token_id))
        return batch

    @tf.function
    def further_prep_generate_test(batch):
        # The test split has no labels: keep every non-EOS input id.
        batch = tf.boolean_mask(batch["input_ids"],
                                batch["input_ids"] != tokenizer.eos_token_id)
        return batch

    if _FLAG_SPLIT.value == constants.SplitChoices.test:
        ds = ds.map(further_prep_generate_test)
    else:
        ds = ds.map(further_prep_generate_not_test)

    # Pads with the EOS id; since EOS ids were stripped from the prompts above,
    # every EOS position in a batch is padding.
    ds = ds.padded_batch(batch_size=batch_size,
                         padding_values=tokenizer.eos_token_id)
    ds = strategy.experimental_distribute_dataset(ds)

    ##############################################################################
    # Generate
    ##############################################################################
    LOGGER.debug("Generating.")
    generations = []
    # Counts samples, not batches, so it is advanced manually by the size of
    # each batch instead of wrapping (and auto-advancing over) the dataset.
    counter = tqdm.tqdm(total=task_specific.DATASET_CARDINALITIES["kilt_eli5"][
        _FLAG_SPLIT.value])

    for batch_no, batch in enumerate(ds):
        # `padded_batch` keeps the remainder, so the last batch may be smaller
        # than `batch_size`.
        current_batch_size = batch.shape[0]
        # 1 for real tokens, 0 for EOS padding.
        attention_mask = tf.cast(batch != tokenizer.eos_token_id, tf.int32)
        output = strategy.run(
            model.generate,
            kwargs=dict(input_ids=batch,
                        max_length=_FLAG_GENERATION_LENGTH_LIMIT.value,
                        use_cache=True,
                        attention_mask=attention_mask))

        LOGGER.debug("INPUT: %s", tokenizer.decode(batch[0]))
        output = tf_utils.process_strat_output(
            strategy_outputs=output,
            current_batch_size=current_batch_size,
            strategy=strategy,
            name="generations")

        output_ids = output.numpy()
        with utils.log_duration(LOGGER, "main",
                                "all of tokenizer.decode for a batch."):
            for i in range(current_batch_size):
                text = tokenizer.decode(output_ids[i])
                LOGGER.debug("Batch %d Generation %d", batch_no, i)
                LOGGER.debug(text.replace("\n", " <\\n> "))
                generations.append(text)

        counter.update(current_batch_size)
    counter.close()

    output_file = os.path.join(_FLAG_OUTPUT_PATH.value, _FLAG_SPLIT.value,
                               _FLAG_APPROACH_TYPE.value,
                               time.strftime("%Y%m%d-%H%M%S.json"))
    utils.to_json_file(
        output_file,
        dict(flags={
            flag.name: flag.value
            for flag in flags.FLAGS.flags_by_module_dict()[argv[0]]
        },
             generations=generations))
    logging.debug("Saved to: %s", output_file)
示例#3
0
def main(argv):
    """Retrieves passages for every KILT ELI5 sample and writes TFRecord shards.

    For each of the ``train``, ``validation`` and ``test`` splits:
      1. Every question is embedded with the query encoder at
         ``retriever_config.query_embedder_path``.
      2. The ``_FLAG_NUM_RETRIEVALS`` best blocks are found by exact maximum
         inner product search against the reference block embeddings.
      3. One ``tf.train.Example`` per sample is written with the fields below.

    Fields written (keys from ``constants.CTH5Fields``):
      * ``gpt2_question_ids_inputs``: the GPT-2 tokenized question.
      * ``gpt2_answer_ids_inputs``: the GPT-2 tokenized first answer (omitted
        for the ``test`` split).
      * ``distances``: the inner products of the retrieved blocks.
      * ``gpt2_retrieved_ids``: the GPT-2 tokenized retrieved blocks, with EOS
        ids replaced by -1.

    Samples are spread round-robin over ``_FLAG_NUM_SHARDS`` files named
    ``<_FLAG_OUTPUT_PATH>/<timestamp>/<split>_<shard>.tfr``.

    Args:
      argv: Command line arguments (as passed by ``absl.app``). Only the
        program name is accepted.

    Raises:
      RuntimeError: On extra command line arguments, a CPU device when a TPU
        name was given, a missing TPU config on TPU, or an unsupported device.
    """
    # Arguments and logging boilerplate
    if len(argv) > 1:
        raise RuntimeError(argv)

    absl_logging.use_python_logging()
    utils.log_module_args(LOGGER, argv[0])

    # Load a retriever config.
    retriever_config = tf_utils.REALMConfig(
        **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value))
    assert not _FLAG_USE_SUBSET.value

    # Preparation of the output path
    time_stamp = time.strftime("%Y%m%d-%H%M%S")
    target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp.strip())
    if target_path[-1] != "/":
        target_path += "/"

    ##############################################################################
    # Setup devices and strategy
    ##############################################################################
    # Duration is pretty much instantaneous
    with utils.log_duration(LOGGER, "main", "Initializing devices"):
        tpu_config = tf_utils.init_tpus(local=_FLAG_TPU_IS_LOCAL.value,
                                        tpu_name=_FLAG_TPU_NAME.value)
        device_type = tf_utils.current_accelerator_type()
        LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use()))
        if _FLAG_TPU_NAME.value and device_type == "CPU":
            raise RuntimeError("Device is CPU and we expected a TPU.")

        if device_type == "TPU":
            if tpu_config is None:
                raise RuntimeError("We should have a tpu_config.")
            strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
        elif device_type in ("GPU", "CPU"):
            strategy = tf.distribute.MirroredStrategy()
        else:
            raise RuntimeError(device_type)

        # Global batch size: per-device batch size times the number of devices.
        batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value

    ##############################################################################
    # Load the KILT ELI5 dataset.
    ##############################################################################
    # Takes a while
    eli5 = {}
    keys = ["train", "validation", "test"]
    gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
    # Padding uses the EOS token, so padded positions hold eos_token_id.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."):
        if _FLAG_DATASET_ROOT.value:
            for split in tqdm.tqdm(keys):
                load_path = os.path.join(_FLAG_DATASET_ROOT.value,
                                         "HuggingfaceDatasets",
                                         f"{split}_kilt_eli5.hf")
                with tf.device("/job:localhost"):
                    eli5[split] = datasets.load_from_disk(load_path)
        else:
            eli5 = datasets.load_dataset("kilt_tasks", "eli5")

    ##############################################################################
    # Load the dataset of the text that will be retrieved.
    ##############################################################################
    # Takes a long time
    with utils.log_duration(LOGGER, "Main", "Load the textual dataset"):
        # Extract the appropriate text
        # The buffer_size is taken from the original ORQA code.
        blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records,
                                                 buffer_size=512 * 1024 * 1024)
        # One batch holding every block record, extracted as a single tensor.
        blocks_dataset = blocks_dataset.batch(
            retriever_config.num_block_records, drop_remainder=False)
        blocks: tf.Tensor = tf.data.experimental.get_single_element(
            blocks_dataset)

    ############################################################################
    # Increase the number of maximum open file descriptors to make space
    # for all the shards.
    ############################################################################
    # Only ever raise the limits: never lower a soft or hard limit that is
    # already large enough (a lowered hard limit cannot be raised back without
    # privileges).
    max_num_fd = _FLAG_NUM_SHARDS.value * 3 + _MIN_N_FD
    soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft_limit != resource.RLIM_INFINITY and soft_limit < max_num_fd:
        if hard_limit != resource.RLIM_INFINITY and hard_limit < max_num_fd:
            # Raising the hard limit may require privileges; raises if denied.
            hard_limit = max_num_fd
        resource.setrlimit(resource.RLIMIT_NOFILE, (max_num_fd, hard_limit))

    ############################################################################
    # Prepare the output files.
    ############################################################################
    writers = {}
    all_paths = {}

    for split in keys:
        maybe_subset = "_subset" if _FLAG_USE_SUBSET.value else ""
        # Prepare paths. They can't be in a generator. A function generator would be
        # fine though.
        paths = [
            os.path.join(target_path + maybe_subset, f"{split}_{i}.tfr")
            for i in range(_FLAG_NUM_SHARDS.value)
        ]
        all_paths[split] = paths

        # Create the TFR writers, one per shard.
        writers[split] = [tf.io.TFRecordWriter(path) for path in paths]

    # Load the reference DB. We used to accidentally do this once per split :O
    with utils.log_duration(LOGGER, "main", "Loading the reference db."):
        checkpoint_path = os.path.join(retriever_config.query_embedder_path,
                                       "encoded", "encoded.ckpt")
        reference_db_device = tf_utils.device_mapping().CPUs[0].name
        with tf.device(reference_db_device):
            reference_db = tf_utils.load_reference_db(
                checkpoint_path,
                variable_name="block_emb",
            )

    ############################################################################
    # Prep the encoder and the tokenizer
    ############################################################################
    with utils.log_duration(LOGGER, "main",
                            "Loading the encoder model and the tokenizer."):
        with strategy.scope():
            query_encoder = hub.load(retriever_config.query_embedder_path,
                                     tags={})
        encode_fn = _make_encode_fn(query_encoder)
        encode_fn_strategy_run = make_encode_fn_strategy_run_fn(
            strategy=strategy,
            encode_fn=encode_fn,
        )

        vocab_file = os.path.join(retriever_config.query_embedder_path,
                                  "assets", "vocab.txt")
        utils.check_exists(vocab_file)
        do_lower_case = query_encoder.signatures["tokenization_info"](
        )["do_lower_case"]
        tokenization_info = dict(vocab_file=vocab_file,
                                 do_lower_case=do_lower_case)

        tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
            query_encoder, tokenization_info)

    ############################################################################
    # Preprocess the dataset
    ############################################################################
    cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")),
                           tf.int32)
    sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")),
                           tf.int32)
    transform = _make_transform_fn(
        bert_tokenizer=tokenizer,
        bert_cls_token_id=cls_token_id,
        bert_sep_token_id=sep_token_id,
    )

    # dtype each feature is cast to before serialization.
    feature_dtypes = {
        constants.CTH5Fields.distances: tf.float32,
        constants.CTH5Fields.gpt2_retrieved_ids: tf.int32,
        constants.CTH5Fields.gpt2_answer_ids_inputs: tf.int32,
        constants.CTH5Fields.gpt2_question_ids_inputs: tf.int32,
    }

    with utils.log_duration(LOGGER, "main", "generating codes"):
        for split in keys:
            sample_count = 0
            eli5: Dict[str, datasets.Dataset]
            # Fetch the id column once; it is used both for the slices and for
            # the progress bar total.
            sample_ids = eli5[split]["id"]

            if split != "test":
                # Only the first reference answer of each sample is kept.
                for_slices = dict(sample_id=sample_ids,
                                  question=eli5[split]["input"],
                                  answer=[
                                      sample[0]["answer"]
                                      for sample in eli5[split]["output"]
                                  ])
            else:
                for_slices = dict(
                    sample_id=sample_ids,
                    question=eli5[split]["input"],
                )

            ds = tf.data.Dataset.from_tensor_slices(for_slices)
            ds = ds.map(transform,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)

            ds = ds.apply(
                tf.data.experimental.dense_to_ragged_batch(batch_size))
            ds = ds.map(_squeeze,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)

            # Batches are of the global `batch_size` and the remainder is kept,
            # hence the ceiling division.
            num_batches = (len(sample_ids) + batch_size - 1) // batch_size
            tqdm_inner = tqdm.tqdm(ds,
                                   total=num_batches,
                                   desc=f"Split `{split}`: Batches")

            for batch in tqdm_inner:
                features = collections.defaultdict(list)

                ######################################################################
                # Enforce the current real batch size
                ######################################################################
                current_batch_size = batch["sample_id"].shape[0]
                for v in batch.values():
                    utils.check_equal(v.shape[0], current_batch_size)
                ######################################################################

                gpt2_question_ids_inputs = _prep_field(batch["question"],
                                                       gpt2_tokenizer)
                utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32)
                utils.check_equal(gpt2_question_ids_inputs.shape[0],
                                  current_batch_size)

                if split != "test":
                    gpt2_answer_ids_inputs = _prep_field(
                        batch["answer"], gpt2_tokenizer)
                    utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32)
                    utils.check_equal(gpt2_answer_ids_inputs.shape[0],
                                      current_batch_size)

                    assert len(gpt2_answer_ids_inputs.shape) == 2, (
                        gpt2_answer_ids_inputs.shape)

                ######################################################################
                # Save the gpt2 tokenized question and answer
                ######################################################################

                features[constants.CTH5Fields.gpt2_question_ids_inputs].extend(
                    gpt2_question_ids_inputs)

                if split != "test":
                    features[
                        constants.CTH5Fields.gpt2_answer_ids_inputs].extend(
                            gpt2_answer_ids_inputs)

                ######################################################################
                # Encode the samples.
                ######################################################################
                batch = strategy.experimental_distribute_values_from_function(
                    tf_utils.make_dict_distribute_fn(batch))

                embeddings = encode_fn_strategy_run(batch)
                embeddings = tf_utils.process_strat_output(
                    embeddings, "embeddings", strategy, current_batch_size)
                utils.check_isinstance(embeddings, ops.EagerTensor)
                utils.check_equal(embeddings.shape[0], current_batch_size)

                # pytype doesn't seem to see that we check the type
                utils.check_equal(embeddings.shape[1],
                                  _FLAG_EMBEDDING_DEPTH.value)  # pytype: disable=attribute-error

                ######################################################################
                # Retrieve.
                ######################################################################
                # Do exact retrieval
                with tf.device(reference_db_device):
                    top_k, inner_prods = tf_utils.mips_exact_search(
                        embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db)

                # Collate the results
                top_k = tf_utils.process_strat_output(top_k, "top_k", strategy,
                                                      current_batch_size)

                # Check the shapes
                utils.check_equal(
                    inner_prods.shape,
                    (current_batch_size, _FLAG_NUM_RETRIEVALS.value))
                utils.check_equal(
                    top_k.shape,
                    (current_batch_size, _FLAG_NUM_RETRIEVALS.value))

                # Save the distances
                features[constants.CTH5Fields.distances].extend(inner_prods)

                # Retrieve the text fields associated to the indices
                gathered = tf.gather(blocks, top_k).numpy()
                utils.check_equal(gathered.shape[0], current_batch_size)
                utils.check_equal(gathered.shape[1],
                                  _FLAG_NUM_RETRIEVALS.value)

                retrievals = []
                for index_in_batch in range(current_batch_size):
                    # Put the appropriate byte strings in a list
                    local_gathered = gathered[index_in_batch].tolist()
                    utils.check_equal(len(local_gathered),
                                      _FLAG_NUM_RETRIEVALS.value)
                    # Decode to utf-8
                    local_gathered = [
                        sample.decode() for sample in local_gathered
                    ]
                    # Encode to GPT2 BPE
                    token_ids = np.array(
                        gpt2_tokenizer.batch_encode_plus(
                            local_gathered,
                            padding="max_length",
                            truncation=True,
                        ).input_ids)

                    # Make sure no line is empty
                    # TODO(julesgm): Maybe optional
                    # NOTE(review): padding here is the EOS id (pad_token is set
                    # to eos_token above), not 0, so an empty passage would be
                    # all eos_token_id and likely pass this check; consider
                    # comparing against gpt2_tokenizer.eos_token_id instead.
                    for line in token_ids:
                        assert not np.all(line == 0), line

                    # Convert the eos_tokens
                    token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1

                    # Save the retrievals
                    retrievals.append(token_ids)

                # Save the feature
                features[constants.CTH5Fields.gpt2_retrieved_ids] = retrievals

                utils.check_equal(
                    retrievals[0].shape,
                    (_FLAG_NUM_RETRIEVALS.value, _FLAG_CONTEXT_SIZE.value))

                for v in features.values():
                    utils.check_equal(len(v), current_batch_size)

                for index_in_batch in range(current_batch_size):
                    feature_dict = {}
                    for feature_k, feature_v in features.items():
                        # Cast the feature to its appropriate dtype
                        casted_feats = tf.cast(feature_v[index_in_batch],
                                               feature_dtypes[feature_k])
                        # Serialize the tensor to bytes
                        feature_bytes = tf.io.serialize_tensor(casted_feats)
                        # Build a bytes list tf.train.Feature object,
                        # the serialization tree node
                        feature_dict[feature_k] = _bytes_feature(feature_bytes)

                    # Create the serialization tree root
                    # Expects a list of features
                    feature = tf.train.Features(feature=feature_dict)
                    # Expects a tf.train.Features object
                    example_obj = tf.train.Example(features=feature)

                    # Serialize that to bytes
                    serialized_example = example_obj.SerializeToString()

                    # Write the bytes, spreading samples round-robin over the
                    # shards.
                    # TODO(julesgm): Parallelize this with a thread or a process pool &
                    #   futures.
                    writers[split][sample_count %
                                   _FLAG_NUM_SHARDS.value].write(
                                       serialized_example)
                    sample_count += 1

                # Only fires when a batch ends exactly on a multiple of 1000.
                if sample_count % 1000 == 0:
                    LOGGER.debug("Paths: %s", str(all_paths[split][0]))

            LOGGER.debug("Flushing and closing the `%s` writers", split)
            for writer in tqdm.tqdm(writers[split]):
                writer.flush()
                writer.close()

    LOGGER.debug("Done.")
def main(argv):
    """Caches GPT-2 tokenized ELI5 samples and their retrievals as TFRecords.

    For each split in `train`, `eval` and `test`:
      - the questions (and, except for `test`, the first answer of each sample)
        are tokenized with the GPT-2 tokenizer;
      - the questions are embedded with the REALM query encoder;
      - the top `_FLAG_NUM_RETRIEVALS` blocks of the reference database are
        retrieved with `tf_utils.mips_exact_search`.
    Every sample is serialized as a `tf.train.Example` and written round-robin
    to one of `_FLAG_NUM_SHARDS` TFRecord shards in a timestamped
    sub-directory of `_FLAG_OUTPUT_PATH`.

    Args:
      argv: Command line arguments left over after absl flag parsing. Only the
        program name is accepted.

    Raises:
      RuntimeError: If unexpected positional arguments are passed, if subset
        mode is requested (unsupported by this writer), or if the accelerator
        type is not supported.
    """
    if len(argv) > 1:
        raise RuntimeError(argv)
    absl_logging.use_python_logging()
    utils.log_module_args(LOGGER, argv[0])

    retriever_config = tf_utils.REALMSave(
        **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value))
    # Subset mode is not supported by this writer. Raise explicitly instead of
    # using `assert`, which is stripped when Python runs with `-O`.
    if _FLAG_USE_SUBSET.value:
        raise RuntimeError(
            "Subset mode (`_FLAG_USE_SUBSET`) is not supported here.")

    time_stamp = time.strftime("%Y%m%d-%H%M%S")
    target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp.strip())
    if target_path[-1] != "/":
        target_path += "/"

    ##############################################################################
    # Setup devices and strategy
    ##############################################################################
    with utils.log_duration(LOGGER, "main", "Initializing devices"):
        tpu_config = tf_utils.init_tpus()
        device_type = tf_utils.current_accelerator_type()
        LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use()))

        if device_type == "TPU":
            if tpu_config is None:
                raise RuntimeError("We should have a tpu_config.")
            strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
        elif device_type in ("GPU", "CPU"):
            strategy = tf.distribute.MirroredStrategy()
        else:
            raise RuntimeError(device_type)
        # Global batch size: `_FLAG_BATCH_SIZE` samples per device in use.
        batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value

    ##############################################################################
    # Load the dataset.
    ##############################################################################
    eli5: Dict[str, datasets.Dataset] = {}
    keys = ["train", "eval", "test"]
    gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
    # Padding uses the EOS token; padded positions of the retrievals are later
    # replaced by -1.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."):
        for split in tqdm.tqdm(keys):
            load_path = os.path.join(_FLAG_DATASET_ROOT.value,
                                     "HuggingfaceDatasets",
                                     f"{split}_kilt_eli5.hf")
            with tf.device("/job:localhost"):
                eli5[split] = datasets.load_from_disk(load_path)

    ##############################################################################
    # Load the textual dataset of the retrievable blocks.
    ##############################################################################
    with utils.log_duration(LOGGER, "Main", "Load the textual dataset"):
        # Extract the appropriate text.
        # The buffer_size is taken from the original ORQA code.
        blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records,
                                                 buffer_size=512 * 1024 * 1024)
        # One single batch holding all `num_block_records` records, so that
        # `blocks` can be indexed directly with the retrieved indices below.
        blocks_dataset = blocks_dataset.batch(
            retriever_config.num_block_records, drop_remainder=True)
        blocks = tf.data.experimental.get_single_element(blocks_dataset)

    ############################################################################
    # Prepare the output files: `_FLAG_NUM_SHARDS` TFRecord shards per split.
    ############################################################################
    writers = {}
    all_paths = {}
    for split in keys:
        # Subset mode was rejected above, so no "_subset" suffix is needed.
        paths = [
            os.path.join(target_path, f"{split}_{i}.tfr")
            for i in range(_FLAG_NUM_SHARDS.value)
        ]
        all_paths[split] = paths
        writers[split] = [tf.io.TFRecordWriter(filename) for filename in paths]

    ############################################################################
    # Load the reference db once; it is shared by all the splits.
    ############################################################################
    with utils.log_duration(LOGGER, "main", "Loading the reference db."):
        checkpoint_path = os.path.join(retriever_config.query_embedder_path,
                                       "encoded", "encoded.ckpt")
        # The MIPS search below runs on this same device.
        reference_db_device = tf_utils.device_mapping().CPUs[0].name
        with tf.device(reference_db_device):
            reference_db = tf_utils.load_reference_db(
                checkpoint_path,
                variable_name="block_emb",
            )

    ############################################################################
    # Prep the encoder and the tokenizer
    ############################################################################
    with utils.log_duration(LOGGER, "main",
                            "Loading the encoder model and the tokenizer."):
        with strategy.scope():
            query_encoder = hub.load(retriever_config.query_embedder_path,
                                     tags={})
        encode_fn = _make_encode_fn(query_encoder)
        encode_fn_strategy_run = make_encode_fn_strategy_run_fn(
            strategy=strategy,
            encode_fn=encode_fn,
        )

        vocab_file = os.path.join(retriever_config.query_embedder_path,
                                  "assets", "vocab.txt")
        utils.check_exists(vocab_file)
        do_lower_case = query_encoder.signatures["tokenization_info"](
        )["do_lower_case"]
        tokenization_info = dict(vocab_file=vocab_file,
                                 do_lower_case=do_lower_case)

        tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
            query_encoder, tokenization_info)

    ############################################################################
    # Preprocess the dataset
    ############################################################################
    cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")),
                           tf.int32)
    sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")),
                           tf.int32)
    transform = _make_transform_fn(
        bert_tokenizer=tokenizer,
        bert_cls_token_id=cls_token_id,
        bert_sep_token_id=sep_token_id,
    )

    # dtype each feature is cast to before being serialized.
    feature_dtypes = {
        constants.CTH5Fields.distances: tf.float32,
        constants.CTH5Fields.gpt2_retrieved_ids: tf.int32,
        constants.CTH5Fields.gpt2_answer_ids_inputs: tf.int32,
        constants.CTH5Fields.gpt2_question_ids_inputs: tf.int32,
    }

    with utils.log_duration(LOGGER, "main", "generating codes"):
        for split in keys:
            sample_count = 0

            if split != "test":
                for_slices = dict(sample_id=eli5[split]["id"],
                                  question=eli5[split]["input"],
                                  answer=[
                                      sample["answer"][0]
                                      for sample in eli5[split]["output"]
                                  ])
            else:
                # Answers are not used for the test split.
                for_slices = dict(
                    sample_id=eli5[split]["id"],
                    question=eli5[split]["input"],
                )

            ds = tf.data.Dataset.from_tensor_slices(for_slices)
            ds = ds.map(transform,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)

            ds = ds.apply(
                tf.data.experimental.dense_to_ragged_batch(batch_size))
            ds = ds.map(_squeeze,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)

            # Batches hold `batch_size` samples; the last one may be smaller,
            # hence the ceiling division.
            num_batches = -(-len(eli5[split]["id"]) // batch_size)
            tqdm_inner = tqdm.tqdm(ds,
                                   total=num_batches,
                                   desc=f"Split `{split}`: Batches")

            for batch in tqdm_inner:
                # Maps each feature name to a list with one entry per sample.
                features = collections.defaultdict(list)

                ##################################################################
                # Enforce the current real batch size
                ##################################################################
                current_batch_size = batch["sample_id"].shape[0]
                for value in batch.values():
                    utils.check_equal(value.shape[0], current_batch_size)

                gpt2_question_ids_inputs = _prep_field(batch["question"],
                                                       gpt2_tokenizer)
                utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32)
                utils.check_equal(gpt2_question_ids_inputs.shape[0],
                                  current_batch_size)

                if split != "test":
                    gpt2_answer_ids_inputs = _prep_field(
                        batch["answer"], gpt2_tokenizer)
                    utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32)
                    utils.check_equal(gpt2_answer_ids_inputs.shape[0],
                                      current_batch_size)

                    assert len(gpt2_answer_ids_inputs.shape) == 2, (
                        gpt2_answer_ids_inputs.shape)

                ##################################################################
                # Save the gpt2 tokenized question and answer
                ##################################################################
                features[constants.CTH5Fields.gpt2_question_ids_inputs].extend(
                    gpt2_question_ids_inputs)

                if split != "test":
                    features[
                        constants.CTH5Fields.gpt2_answer_ids_inputs].extend(
                            gpt2_answer_ids_inputs)

                ##################################################################
                # Encode the samples.
                ##################################################################
                batch = strategy.experimental_distribute_values_from_function(
                    tf_utils.make_dict_distribute_fn(batch))

                embeddings = encode_fn_strategy_run(batch)
                embeddings = tf_utils.process_strat_output(
                    embeddings, "embeddings", strategy, current_batch_size)
                utils.check_isinstance(embeddings, ops.EagerTensor)
                utils.check_equal(embeddings.shape[0], current_batch_size)

                # pytype doesn't seem to see that we check the type
                utils.check_equal(embeddings.shape[1],
                                  _FLAG_EMBEDDING_DEPTH.value)  # pytype: disable=attribute-error

                ##################################################################
                # Retrieve.
                ##################################################################
                with tf.device(reference_db_device):
                    top_k, inner_prods = tf_utils.mips_exact_search(
                        embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db)
                top_k = tf_utils.process_strat_output(top_k, "top_k", strategy,
                                                      current_batch_size)
                utils.check_equal(
                    inner_prods.shape,
                    (current_batch_size, _FLAG_NUM_RETRIEVALS.value))
                utils.check_equal(
                    top_k.shape,
                    (current_batch_size, _FLAG_NUM_RETRIEVALS.value))

                features[constants.CTH5Fields.distances].extend(inner_prods)

                # Text of the retrieved blocks, one row per sample.
                gathered = tf.gather(blocks, top_k).numpy()
                utils.check_equal(gathered.shape[0], current_batch_size)
                retrievals = []
                for j in range(gathered.shape[0]):
                    local_gathered = gathered[j].tolist()
                    utils.check_equal(len(local_gathered),
                                      _FLAG_NUM_RETRIEVALS.value)
                    local_gathered = [
                        sample.decode() for sample in local_gathered
                    ]
                    token_ids = np.array(
                        gpt2_tokenizer.batch_encode_plus(
                            local_gathered,
                            padding="max_length",
                            truncation=True,
                        ).input_ids)
                    for line in token_ids:
                        assert not np.all(line == 0), line

                    # The pad token is the EOS token, so this marks padding
                    # (and any EOS token) with -1.
                    token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1
                    retrievals.append(token_ids)
                features[constants.CTH5Fields.gpt2_retrieved_ids] = retrievals

                utils.check_equal(
                    retrievals[0].shape,
                    (_FLAG_NUM_RETRIEVALS.value, _FLAG_CONTEXT_SIZE.value))

                for value in features.values():
                    utils.check_equal(len(value), current_batch_size)

                # `index_in_batch` must not be named like the comprehension's
                # loop variables: the comprehension has its own scope, and a
                # shared name would index the per-sample lists with a feature
                # name instead of the sample position.
                for index_in_batch in range(current_batch_size):
                    feature = tf.train.Features(
                        feature={
                            feature_k: _bytes_feature(
                                tf.io.serialize_tensor(
                                    tf.cast(feature_v[index_in_batch],
                                            feature_dtypes[feature_k])))
                            for feature_k, feature_v in features.items()
                        })

                    # Round-robin over the shards of this split.
                    writers[split][
                        sample_count % _FLAG_NUM_SHARDS.value].write(
                            tf.train.Example(
                                features=feature).SerializeToString())
                    sample_count += 1
                if sample_count % 1000 == 0:
                    LOGGER.debug("Paths: %s", str(all_paths[split][0]))

            # Make sure every buffered record reaches the files.
            LOGGER.debug("Flushing and closing the `%s` writers", split)
            for writer in tqdm.tqdm(writers[split]):
                writer.flush()
                writer.close()

    LOGGER.debug("Done.")
示例#5
0
def main(argv):
    """Caches GPT-2 tokenized ELI5 samples and their retrievals in HDF5.

    For each split in `train`, `eval` and `test`:
      - the questions (and, except for `test`, the first answer of each sample)
        are tokenized with the GPT-2 tokenizer;
      - the questions are embedded with the REALM query encoder;
      - the top `_FLAG_NUM_RETRIEVALS` blocks of the reference database are
        retrieved with `tf_utils.mips_exact_search`.
    Everything is written to a `codes.h5` file (one group per split) in a
    temporary directory, next to a `params.json` dump of this module's flags.
    The directory content is then copied with `gsutil` to a timestamped
    sub-directory of `_FLAG_OUTPUT_PATH`.

    Args:
      argv: Command line arguments left over after absl flag parsing. Only the
        program name is accepted; `argv[0]` is also used to look up this
        module's flags.

    Raises:
      RuntimeError: If unexpected positional arguments are passed or if the
        accelerator type is not supported.
      subprocess.CalledProcessError: If the `gsutil` copy fails.
    """
    if len(argv) > 1:
        raise RuntimeError(argv)
    absl_logging.use_python_logging()
    retriever_config = tf_utils.REALMSave(
        **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value))

    extra = "_FROM_SUBSET" if _FLAG_USE_SUBSET.value else ""
    time_stamp = time.strftime("%Y%m%d-%H%M%S")
    target_path = os.path.join(_FLAG_OUTPUT_PATH.value,
                               time_stamp + extra).strip()
    if target_path[-1] != "/":
        target_path += "/"

    ##############################################################################
    # Setup devices and strategy
    ##############################################################################
    with utils.log_duration(LOGGER, "main", "Initializing devices"):
        tpu_config = tf_utils.init_tpus()
        device_type = tf_utils.current_accelerator_type()
        LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use()))

        if device_type == "TPU":
            if tpu_config is None:
                raise RuntimeError("We should have a tpu_config.")
            strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
        elif device_type in ("GPU", "CPU"):
            strategy = tf.distribute.MirroredStrategy()
        else:
            raise RuntimeError(device_type)
        # Global batch size: `_FLAG_BATCH_SIZE` samples per device in use.
        batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value

    ##############################################################################
    # Load the dataset.
    ##############################################################################
    eli5: Dict[str, datasets.Dataset] = {}
    keys = ["train", "eval", "test"]
    gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
    # Padding uses the EOS token; padded positions of the retrievals are later
    # replaced by -1.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."):
        for split in tqdm.tqdm(keys):
            # NOTE(review): every other flag in this code is spelled `_FLAG_*`
            # (e.g. `_FLAG_DATASET_ROOT`); confirm that `_FLAGS_DATASET_ROOT`
            # is the name actually defined in this module.
            load_path = os.path.join(_FLAGS_DATASET_ROOT.value,
                                     "HuggingfaceDatasets",
                                     f"{split}_kilt_eli5.hf")
            with tf.device("/job:localhost"):
                eli5[split] = datasets.load_from_disk(load_path)

    if _FLAG_USE_SUBSET.value:
        _warn_subset()

    ##############################################################################
    # Load the textual dataset of the retrievable blocks.
    ##############################################################################
    with utils.log_duration(LOGGER, "Main", "Load the textual dataset"):
        # Extract the appropriate text.
        # The buffer_size is taken from the original ORQA code.
        blocks_dataset = tf.data.TFRecordDataset(retriever_config.text_records,
                                                 buffer_size=512 * 1024 * 1024)
        # One single batch holding all `num_block_records` records, so that
        # `blocks` can be indexed directly with the retrieved indices below.
        blocks_dataset = blocks_dataset.batch(
            retriever_config.num_block_records, drop_remainder=True)
        blocks = tf.data.experimental.get_single_element(blocks_dataset)

    with tempfile.TemporaryDirectory() as tmp_dir:
        ############################################################################
        # Prepare the output file.
        ############################################################################
        tmp_dir = pathlib.Path(tmp_dir)
        h5_output_path = tmp_dir / "codes.h5"
        output_file = h5py.File(h5_output_path, "w")
        flags_dict = {
            flag.name: flag.value
            for flag in flags.FLAGS.flags_by_module_dict()[argv[0]]
        }
        utils.to_json_file(tmp_dir / "params.json", flags_dict)

        for split in keys:
            with utils.log_duration(
                    LOGGER, "main",
                    "Creating the output hdf5 file, embeddings."):
                num_entries = len(eli5[split]["id"])
                if _FLAG_USE_SUBSET.value:
                    num_entries = min(num_entries, _FLAG_SUBSET_AMOUNT.value)
                split_group = output_file.create_group(split)

            with utils.log_duration(
                    LOGGER, "main",
                    "Creating the output hdf5 file, retrieval."):
                split_group.create_dataset(
                    constants.CTH5Fields.distances,
                    shape=(num_entries, _FLAG_NUM_RETRIEVALS.value),
                    dtype=np.float32,
                )
                split_group.create_dataset(
                    constants.CTH5Fields.gpt2_question_ids_inputs,
                    shape=(num_entries, _FLAG_CONTEXT_SIZE.value),
                    dtype=np.int32)
                if split != "test":
                    split_group.create_dataset(
                        constants.CTH5Fields.gpt2_answer_ids_inputs,
                        shape=(num_entries, _FLAG_CONTEXT_SIZE.value),
                        dtype=np.int32)

                split_group.create_dataset(
                    constants.CTH5Fields.gpt2_retrieved_ids,
                    shape=(
                        num_entries,
                        _FLAG_NUM_RETRIEVALS.value,
                        _FLAG_MAX_LENGTH_RETRIEVALS.value,
                    ),
                    dtype=np.int32)

        ############################################################################
        # Load the reference db once; it is shared by all the splits.
        ############################################################################
        with utils.log_duration(LOGGER, "main", "Loading the reference db."):
            checkpoint_path = os.path.join(
                retriever_config.query_embedder_path, "encoded",
                "encoded.ckpt")

            # The MIPS search below runs on this same device.
            reference_db_device = tf_utils.device_mapping().CPUs[0].name
            with tf.device(reference_db_device):
                reference_db = tf_utils.load_reference_db(
                    checkpoint_path,
                    variable_name="block_emb",
                )

        ############################################################################
        # Prep the encoder and the tokenizer
        ############################################################################
        with utils.log_duration(
                LOGGER, "main",
                "Loading the encoder model and the tokenizer."):
            with strategy.scope():
                query_encoder = hub.load(retriever_config.query_embedder_path,
                                         tags={})
            encode_fn = _make_encode_fn(query_encoder)
            encode_fn_strategy_run = _make_encode_fn_strategy_run_fn(
                strategy=strategy,
                encode_fn=encode_fn,
            )

            vocab_file = os.path.join(retriever_config.query_embedder_path,
                                      "assets", "vocab.txt")
            utils.check_exists(vocab_file)
            do_lower_case = query_encoder.signatures["tokenization_info"](
            )["do_lower_case"]
            tokenization_info = dict(vocab_file=vocab_file,
                                     do_lower_case=do_lower_case)

            tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
                query_encoder, tokenization_info)

        ############################################################################
        # Preprocess the dataset
        ############################################################################
        cls_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[CLS]")),
                               tf.int32)
        sep_token_id = tf.cast(vocab_lookup_table.lookup(tf.constant("[SEP]")),
                               tf.int32)
        transform = _make_transform_fn(
            bert_tokenizer=tokenizer,
            bert_cls_token_id=cls_token_id,
            bert_sep_token_id=sep_token_id,
        )

        with utils.log_duration(LOGGER, "main", "generating codes"):
            tqdm_splits = tqdm.tqdm(keys)
            for split in tqdm_splits:
                tqdm_splits.set_description(f"Split `{split}`")
                # Next row of this split's HDF5 datasets to be written.
                write_start = 0
                split_group = output_file[split]

                if _FLAG_USE_SUBSET.value:
                    _warn_subset(tqdm_splits)
                    # Slicing turns the split into a dict of column lists.
                    eli5[split] = eli5[split][:_FLAG_SUBSET_AMOUNT.value]
                    utils.check_operator(operator.le, len(eli5[split]["id"]),
                                         _FLAG_SUBSET_AMOUNT.value)
                    utils.check_operator(operator.le,
                                         len(eli5[split]["input"]),
                                         _FLAG_SUBSET_AMOUNT.value)
                else:
                    utils.check_equal(len(eli5[split]), len(eli5[split]["id"]))
                    utils.check_equal(len(eli5[split]),
                                      len(eli5[split]["input"]))

                if split != "test":
                    for_slices = dict(sample_id=eli5[split]["id"],
                                      question=eli5[split]["input"],
                                      answer=[
                                          sample["answer"][0]
                                          for sample in eli5[split]["output"]
                                      ])
                else:
                    # Answers are not used for the test split.
                    for_slices = dict(
                        sample_id=eli5[split]["id"],
                        question=eli5[split]["input"],
                    )

                ds = tf.data.Dataset.from_tensor_slices(for_slices)
                ds = ds.map(transform,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)

                ds = ds.apply(
                    tf.data.experimental.dense_to_ragged_batch(batch_size))
                ds = ds.map(_squeeze,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)

                # Batches hold `batch_size` samples; the last one may be
                # smaller, hence the ceiling division.
                num_batches = -(-len(eli5[split]["id"]) // batch_size)
                tqdm_inner = tqdm.tqdm(ds,
                                       total=num_batches,
                                       desc=f"Split `{split}`: Batches")

                for batch in tqdm_inner:
                    ##################################################################
                    # Enforce the current real batch size
                    ##################################################################
                    current_batch_size = batch["sample_id"].shape[0]
                    for value in batch.values():
                        utils.check_equal(value.shape[0], current_batch_size)

                    gpt2_question_ids_inputs = _prep_field(
                        batch["question"], gpt2_tokenizer)
                    utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32)
                    utils.check_equal(gpt2_question_ids_inputs.shape[0],
                                      current_batch_size)

                    if split != "test":
                        gpt2_answer_ids_inputs = _prep_field(
                            batch["answer"], gpt2_tokenizer)
                        utils.check_equal(gpt2_answer_ids_inputs.dtype,
                                          np.int32)
                        utils.check_equal(gpt2_answer_ids_inputs.shape[0],
                                          current_batch_size)

                        assert len(gpt2_answer_ids_inputs.shape) == 2, (
                            gpt2_answer_ids_inputs.shape)

                    ##################################################################
                    # Save the gpt2 tokenized question and answer
                    ##################################################################
                    end = write_start + current_batch_size

                    question_dataset = split_group[
                        constants.CTH5Fields.gpt2_question_ids_inputs]
                    # Guards against writing past the end of the dataset.
                    utils.check_equal(
                        question_dataset[write_start:end].shape[0],
                        current_batch_size)
                    question_dataset[write_start:end] = gpt2_question_ids_inputs

                    if split != "test":
                        split_group[
                            constants.CTH5Fields.gpt2_answer_ids_inputs][
                                write_start:end] = gpt2_answer_ids_inputs

                    ##################################################################
                    # Encode the samples.
                    ##################################################################
                    batch = strategy.experimental_distribute_values_from_function(
                        tf_utils.make_dict_distribute_fn(batch))

                    embeddings = encode_fn_strategy_run(batch)
                    embeddings = tf_utils.process_strat_output(
                        embeddings, "embeddings", strategy, current_batch_size)
                    utils.check_isinstance(embeddings, ops.EagerTensor)
                    utils.check_equal(embeddings.shape[0], current_batch_size)

                    # pytype doesn't seem to see that we check the type
                    utils.check_equal(embeddings.shape[1],
                                      _FLAG_EMBEDDING_DEPTH.value)  # pytype: disable=attribute-error

                    ##################################################################
                    # Retrieve.
                    ##################################################################
                    with tf.device(reference_db_device):
                        top_k, inner_prods = tf_utils.mips_exact_search(
                            embeddings, _FLAG_NUM_RETRIEVALS.value,
                            reference_db)
                    top_k = tf_utils.process_strat_output(
                        top_k, "top_k", strategy, current_batch_size)
                    utils.check_equal(
                        inner_prods.shape,
                        (current_batch_size, _FLAG_NUM_RETRIEVALS.value))
                    utils.check_equal(
                        top_k.shape,
                        (current_batch_size, _FLAG_NUM_RETRIEVALS.value))

                    # Use the same key the dataset was created with.
                    split_group[constants.CTH5Fields.distances][
                        write_start:end] = inner_prods

                    # Text of the retrieved blocks, one row per sample.
                    gathered = tf.gather(blocks, top_k).numpy()
                    utils.check_equal(gathered.shape[0], current_batch_size)

                    utils.check_equal(write_start + gathered.shape[0], end)
                    retrieved_dataset = split_group[
                        constants.CTH5Fields.gpt2_retrieved_ids]
                    max_length = _FLAG_MAX_LENGTH_RETRIEVALS.value
                    for j in range(gathered.shape[0]):
                        local_gathered = gathered[j].tolist()
                        utils.check_equal(len(local_gathered),
                                          _FLAG_NUM_RETRIEVALS.value)
                        local_gathered = [
                            sample.decode() for sample in local_gathered
                        ]
                        token_ids = np.array(
                            gpt2_tokenizer.batch_encode_plus(
                                local_gathered,
                                padding="max_length",
                                truncation=True,
                            ).input_ids)
                        for line in token_ids:
                            assert not np.all(line == 0), line

                        # The pad token is the EOS token, so this marks padding
                        # (and any EOS token) with -1.
                        token_ids[token_ids ==
                                  gpt2_tokenizer.eos_token_id] = -1
                        retrieved_dataset[write_start + j] = (
                            token_ids[:, :max_length])

                    write_start += current_batch_size
        ############################################################################
        # Upload the results to GCS
        ############################################################################
        LOGGER.debug("DONE WITH THE PRODUCTION")
        output_file.close()
        with utils.log_duration(LOGGER, "main", "gsutil transfer"):
            # No shell is involved: the `*` wildcard is expanded by gsutil.
            command = [
                "/root/google-cloud-sdk/bin/gsutil", "-m", "cp", "-r",
                str(tmp_dir / "*"), target_path
            ]
            LOGGER.debug("Command: %s", " ".join(command))
            subprocess.check_call(command)
        LOGGER.debug("ALL DONE")