Example 1
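This unit test round-trips a random database through scann_utils: it writes a
[1024, 32] float32 array to a TensorFlow checkpoint, rebuilds a ScaNN searcher
from it, and checks that a batched query returns results of shape
[num_queries, num_neighbors]. The method belongs to a tf.test.TestCase
subclass (which provides create_tempdir and assertAllEqual) and assumes os,
numpy as np, and scann_utils are imported.
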
    def test_scann_searcher(self):
        temp_dir = self.create_tempdir().full_path
        checkpoint_path = os.path.join(temp_dir, "dummy_db.ckpt")

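        # Build a random [1024, 32] float32 database and write it to a
        # checkpoint under the variable name "dummy_db".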
        dummy_db = np.random.uniform(size=[1024, 32]).astype(np.float32)
        scann_utils.write_array_to_checkpoint("dummy_db", dummy_db,
                                              checkpoint_path)

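        # Create random queries, then reload the saved database through a
        # ScaNN searcher configured to return 10 neighbors per query.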
        dummy_queries = np.random.uniform(size=[4, 32]).astype(np.float32)
        _, searcher = scann_utils.load_scann_searcher(
            var_name="dummy_db",
            checkpoint_path=checkpoint_path,
            num_neighbors=10)
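        # Both returned tensors have shape [num_queries, num_neighbors].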
        distance, index = searcher.search_batched(dummy_queries)
        self.assertAllEqual(distance.numpy().shape, [4, 10])
        self.assertAllEqual(index.numpy().shape, [4, 10])
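
The same round trip works outside a test harness. The sketch below is a
minimal standalone version; the module path language.orqa.utils.scann_utils is
an assumption, and the database and query shapes are illustrative:

import os
import tempfile

import numpy as np
from language.orqa.utils import scann_utils  # assumed module path

# Persist a random [1024, 32] float32 database as a checkpoint variable.
db = np.random.uniform(size=[1024, 32]).astype(np.float32)
checkpoint_path = os.path.join(tempfile.mkdtemp(), "db.ckpt")
scann_utils.write_array_to_checkpoint("db", db, checkpoint_path)

# Rebuild a ScaNN searcher over the saved database.
_, searcher = scann_utils.load_scann_searcher(
    var_name="db", checkpoint_path=checkpoint_path, num_neighbors=5)

# Batched search over 8 queries; both outputs have shape [8, 5].
queries = np.random.uniform(size=[8, 32]).astype(np.float32)
distance, index = searcher.search_batched(queries)
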
Example 2
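This function appears to implement the retrieval step of an open-domain QA
model (ORQA/REALM style): it tokenizes the question with a BERT tokenizer,
embeds it with a TF-Hub retriever module, looks up the nearest evidence-block
embeddings with ScaNN, and returns the block logits together with the raw
block bytes. It is TF1 graph-mode code (note hub.Module and
tf.get_local_variable) and assumes imports of os, tensorflow as tf,
tensorflow_hub as hub, and the companion bert_utils and scann_utils modules;
RetrieverOutputs is presumably a named tuple with logits and blocks fields
defined elsewhere in the same file.
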
def retrieve(features, retriever_beam_size, mode, params):
    """Do retrieval."""
    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
        params["retriever_module_path"])

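    # Tokenize the question, merge the ragged word/wordpiece dimensions, and
    # pad to a dense [1, num_tokens] int32 tensor.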
    question_token_ids = tokenizer.tokenize(
        tf.expand_dims(features["question"], 0))
    question_token_ids = tf.cast(
        question_token_ids.merge_dims(1, 2).to_tensor(), tf.int32)
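    # Wrap the token ids with BERT's [CLS] and [SEP] markers.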
    cls_token_id = vocab_lookup_table.lookup(tf.constant("[CLS]"))
    sep_token_id = vocab_lookup_table.lookup(tf.constant("[SEP]"))
    question_token_ids = tf.concat(
        [[[tf.cast(cls_token_id, tf.int32)]], question_token_ids,
         [[tf.cast(sep_token_id, tf.int32)]]], -1)

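    # Load the retriever as a TF-Hub module; the "train" tag selects the
    # training graph variant (e.g. with dropout enabled).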
    retriever_module = hub.Module(
        params["retriever_module_path"],
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
        trainable=True)

    # [1, projection_size]
    question_emb = retriever_module(
        inputs=dict(
            input_ids=question_token_ids,
            input_mask=tf.ones_like(question_token_ids),
            segment_ids=tf.zeros_like(question_token_ids)),
        signature="projected")

    block_emb, searcher = scann_utils.load_scann_searcher(
        var_name="block_emb",
        checkpoint_path=os.path.join(params["retriever_module_path"],
                                     "encoded", "encoded.ckpt"),
        num_neighbors=retriever_beam_size)

    # [1, retriever_beam_size]
    retrieved_block_ids, _ = searcher.search_batched(question_emb)

    # [1, retriever_beam_size, projection_size]
    retrieved_block_emb = tf.gather(block_emb, retrieved_block_ids)

    # [retriever_beam_size]
    retrieved_block_ids = tf.squeeze(retrieved_block_ids)

    # [retriever_beam_size, projection_size]
    retrieved_block_emb = tf.squeeze(retrieved_block_emb)

    # [1, retriever_beam_size]
    retrieved_logits = tf.matmul(question_emb,
                                 retrieved_block_emb,
                                 transpose_b=True)

    # [retriever_beam_size]
    retrieved_logits = tf.squeeze(retrieved_logits, 0)

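    # Materialize every evidence block in the TFRecord file as one string
    # tensor, cached in a local variable so the read happens only once.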
    blocks_dataset = tf.data.TFRecordDataset(params["block_records_path"],
                                             buffer_size=512 * 1024 * 1024)
    blocks_dataset = blocks_dataset.batch(params["num_block_records"],
                                          drop_remainder=True)
    blocks = tf.get_local_variable(
        "blocks",
        initializer=tf.data.experimental.get_single_element(blocks_dataset))
    retrieved_blocks = tf.gather(blocks, retrieved_block_ids)
    return RetrieverOutputs(logits=retrieved_logits, blocks=retrieved_blocks)
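
For context, a hypothetical call site: retrieve is TF1 graph-mode code, so in
practice it would run inside an Estimator model_fn. Every path and size below
is an illustrative placeholder, not a value from the original code:

import tensorflow.compat.v1 as tf  # TF1-style API, matching the code above

params = {
    "retriever_module_path": "/path/to/retriever_hub_module",  # placeholder
    "block_records_path": "/path/to/blocks.tfrecord",          # placeholder
    "num_block_records": 1000000,                              # placeholder
}
# Returns [retriever_beam_size] logits and the matching raw block bytes.
outputs = retrieve(
    features={"question": tf.constant("who wrote hamlet")},
    retriever_beam_size=5,
    mode=tf.estimator.ModeKeys.PREDICT,
    params=params)
# A softmax over the logits yields a retrieval distribution that a downstream
# reader model can consume together with outputs.blocks.
block_probs = tf.nn.softmax(outputs.logits)  # [retriever_beam_size]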