Example #1
def require_distributed_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    :class:`~transformers.RagRetriever`.

    These tests are skipped when the respective libraries are not installed.

    """
    if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
        test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
    return test_case
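A minimal usage sketch of the decorator (the test class and test name are hypothetical, not part of the original source):

@require_distributed_retrieval
class HypotheticalRagRetrievalTest(unittest.TestCase):
    def test_retrieve(self):
        # Runs only when PyTorch, Datasets, Faiss and psutil are all
        # installed; otherwise unittest reports the test as skipped.
        self.assertTrue(True)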
Example #2
def require_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    :class:`~transformers.RagRetriever`.

    These tests are skipped when the respective libraries are not installed.

    """
    if not (is_tf_available() and is_datasets_available() and is_faiss_available()):
        test_case = unittest.skip("test requires tensorflow, datasets and faiss")(test_case)
    return test_case
Example #3
def require_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    [`RagRetriever`].

    These tests are skipped when the respective libraries are not installed.

    """
    if not (is_torch_available() and is_datasets_available()
            and is_faiss_available()):
        test_case = unittest.skip("test requires PyTorch, datasets and faiss")(
            test_case)
    return test_case
Example #4
class RagTestMixin:

    all_model_classes = ((RagModel, RagTokenForGeneration,
                          RagSequenceForGeneration)
                         if is_torch_available() and is_datasets_available()
                         and is_faiss_available() else ())

    retrieval_vector_size = 32
    n_docs = 3
    max_combined_length = 16

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path,
                                       DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = [
            "#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""
        ]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(bart_tokenizer_path,
                                       BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path,
                                        BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        t5_tokenizer = T5Tokenizer(T5_SAMPLE_VOCAB)
        t5_tokenizer_path = os.path.join(self.tmpdirname, "t5_tokenizer")
        t5_tokenizer.save_pretrained(t5_tokenizer_path)

    @cached_property
    def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        return DPRQuestionEncoderTokenizer.from_pretrained(
            os.path.join(self.tmpdirname, "dpr_tokenizer"))

    @cached_property
    def bart_tokenizer(self) -> BartTokenizer:
        return BartTokenizer.from_pretrained(
            os.path.join(self.tmpdirname, "bart_tokenizer"))

    @cached_property
    def t5_tokenizer(self) -> T5Tokenizer:
        return T5Tokenizer.from_pretrained(
            os.path.join(self.tmpdirname, "t5_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def get_retriever(self, config):
        dataset = Dataset.from_dict({
            "id": ["0", "1", "3"],
            "text": ["foo", "bar", "qux"],
            "title": ["Foo", "Bar", "Qux"],
            "embeddings": [
                np.ones(self.retrieval_vector_size),
                2 * np.ones(self.retrieval_vector_size),
                3 * np.ones(self.retrieval_vector_size),
            ],
        })
        dataset.add_faiss_index("embeddings",
                                string_factory="Flat",
                                metric_type=faiss.METRIC_INNER_PRODUCT)
        tokenizer = self.bart_tokenizer if config.generator.model_type == "bart" else self.t5_tokenizer
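        # Patch the dataset loader so RagRetriever reads the in-memory toy
        # dataset above instead of downloading a real index.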
        with patch("transformers.retrieval_rag.load_dataset"
                   ) as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.dpr_tokenizer,
                generator_tokenizer=tokenizer,
            )
        return retriever

    def check_model_with_retriever(self, config, input_ids, attention_mask,
                                   decoder_input_ids, decoder_attention_mask,
                                   **kwargs):
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(
                config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            # logits
            self.assertEqual(
                outputs.logits.shape,
                (self.n_docs * decoder_input_ids.shape[0],
                 decoder_input_ids.shape[1], config.generator.vocab_size),
            )
            # generator encoder last hidden states
            self.assertEqual(
                outputs.generator_enc_last_hidden_state.shape,
                (self.n_docs * decoder_input_ids.shape[0],
                 self.max_combined_length, config.generator.hidden_size),
            )
            # doc scores
            self.assertEqual(outputs.doc_scores.shape,
                             (input_ids.shape[0], self.n_docs))

    def check_model_generate(self, config, input_ids, attention_mask,
                             decoder_input_ids, decoder_attention_mask,
                             **kwargs):
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        # Skip RagModel (index 0): only the two *ForGeneration heads implement generate().
        for model_class in self.all_model_classes[1:]:
            model = model_class(
                config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model.generate(
                input_ids=input_ids,
                num_beams=2,
                num_return_sequences=2,
                decoder_start_token_id=config.generator.eos_token_id,
            )

            self.assertIsNotNone(outputs)

    def check_model_without_retriever(self, config, input_ids, attention_mask,
                                      decoder_input_ids,
                                      decoder_attention_mask, **kwargs):
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            question_hidden_states = model.question_encoder(
                input_ids, attention_mask=attention_mask)[0]

            out = retriever(
                input_ids,
                question_hidden_states.cpu().detach().to(
                    torch.float32).numpy(),
                prefix=config.generator.prefix,
                return_tensors="pt",
            )

            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
                out["context_input_ids"],
                out["context_attention_mask"],
                out["retrieved_doc_embeds"],
            )

            # cast
            retrieved_doc_embeds = retrieved_doc_embeds.to(
                question_hidden_states)
            context_input_ids = context_input_ids.to(input_ids)
            context_attention_mask = context_attention_mask.to(input_ids)

            # compute doc_scores
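            # (batch, 1, dim) x (batch, dim, n_docs) -> (batch, 1, n_docs),
            # squeezed to (batch, n_docs): one inner-product score per doc.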
            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1),
                                   retrieved_doc_embeds.transpose(
                                       1, 2)).squeeze(1)

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            # logits
            self.assertEqual(
                outputs.logits.shape,
                (self.n_docs * decoder_input_ids.shape[0],
                 decoder_input_ids.shape[1], config.generator.vocab_size),
            )
            # generator encoder last hidden states
            self.assertEqual(
                outputs.generator_enc_last_hidden_state.shape,
                (self.n_docs * decoder_input_ids.shape[0],
                 self.max_combined_length, config.generator.hidden_size),
            )
            # doc scores
            self.assertEqual(outputs.doc_scores.shape,
                             (input_ids.shape[0], self.n_docs))

    def check_model_custom_n_docs(self, config, input_ids, attention_mask,
                                  decoder_input_ids, decoder_attention_mask,
                                  n_docs, **kwargs):
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            question_hidden_states = model.question_encoder(
                input_ids, attention_mask=attention_mask)[0]

            out = retriever(
                input_ids,
                question_hidden_states.cpu().detach().to(
                    torch.float32).numpy(),
                prefix=config.generator.prefix,
                return_tensors="pt",
                n_docs=n_docs,
            )

            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
                out["context_input_ids"],
                out["context_attention_mask"],
                out["retrieved_doc_embeds"],
            )

            # cast
            retrieved_doc_embeds = retrieved_doc_embeds.to(
                question_hidden_states)
            context_input_ids = context_input_ids.to(input_ids)
            context_attention_mask = context_attention_mask.to(input_ids)

            # compute doc_scores
            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1),
                                   retrieved_doc_embeds.transpose(
                                       1, 2)).squeeze(1)

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                n_docs=n_docs,
            )

            # logits
            self.assertEqual(
                outputs.logits.shape,
                (n_docs * decoder_input_ids.shape[0],
                 decoder_input_ids.shape[1], config.generator.vocab_size),
            )
            # generator encoder last hidden states
            self.assertEqual(
                outputs.generator_enc_last_hidden_state.shape,
                (n_docs * decoder_input_ids.shape[0], self.max_combined_length,
                 config.generator.hidden_size),
            )
            # doc scores
            self.assertEqual(outputs.doc_scores.shape,
                             (input_ids.shape[0], n_docs))

    def check_model_with_mismatch_n_docs_value(self, config, input_ids,
                                               attention_mask,
                                               decoder_input_ids,
                                               decoder_attention_mask,
                                               retriever_n_docs,
                                               generator_n_docs, **kwargs):
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            question_hidden_states = model.question_encoder(
                input_ids, attention_mask=attention_mask)[0]

            out = retriever(
                input_ids,
                question_hidden_states.cpu().detach().to(
                    torch.float32).numpy(),
                prefix=config.generator.prefix,
                return_tensors="pt",
                n_docs=retriever_n_docs,
            )

            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
                out["context_input_ids"],
                out["context_attention_mask"],
                out["retrieved_doc_embeds"],
            )

            # cast
            retrieved_doc_embeds = retrieved_doc_embeds.to(
                question_hidden_states)
            context_input_ids = context_input_ids.to(input_ids)
            context_attention_mask = context_attention_mask.to(input_ids)

            # compute doc_scores
            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1),
                                   retrieved_doc_embeds.transpose(
                                       1, 2)).squeeze(1)

            # The retriever returned retriever_n_docs documents, but the model
            # is told n_docs=generator_n_docs; the mismatch must trip the
            # model's internal shape assertion.
            self.assertRaises(
                AssertionError,
                model.__call__,
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                n_docs=generator_n_docs,
            )

    def check_model_with_encoder_outputs(self, config, input_ids,
                                         attention_mask, decoder_input_ids,
                                         decoder_attention_mask, **kwargs):
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(
                config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            # Re-wrap the encoder hidden states so the next forward pass can
            # run the generator alone, skipping question encoding and retrieval.
            encoder_outputs = BaseModelOutput(
                outputs.generator_enc_last_hidden_state)

            # run only generator
            outputs = model(
                encoder_outputs=encoder_outputs,
                doc_scores=outputs.doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            # logits
            self.assertEqual(
                outputs.logits.shape,
                (self.n_docs * decoder_input_ids.shape[0],
                 decoder_input_ids.shape[1], config.generator.vocab_size),
            )
            # generator encoder last hidden states
            self.assertEqual(
                outputs.generator_enc_last_hidden_state.shape,
                (self.n_docs * decoder_input_ids.shape[0],
                 self.max_combined_length, config.generator.hidden_size),
            )
            # doc scores
            self.assertEqual(outputs.doc_scores.shape,
                             (input_ids.shape[0], self.n_docs))

    def test_model_with_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_retriever(**inputs_dict)

    def test_model_without_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_without_retriever(**inputs_dict)

    def test_model_with_encoder_outputs(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_encoder_outputs(**inputs_dict)

    def test_model_generate(self):
        inputs_dict = self.config_and_inputs
        self.check_model_generate(**inputs_dict)

    def test_model_with_custom_n_docs(self):
        inputs_dict = self.config_and_inputs
        inputs_dict["n_docs"] = 1
        self.check_model_custom_n_docs(**inputs_dict)

    def test_model_with_mismatch_n_docs_value(self):
        inputs_dict = self.config_and_inputs
        inputs_dict["retriever_n_docs"] = 3
        inputs_dict["generator_n_docs"] = 2
        self.check_model_with_mismatch_n_docs_value(**inputs_dict)
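RagTestMixin never defines config_and_inputs; each concrete tester supplies it. A hypothetical sketch of how a subclass could plug in, reusing the DPR and BART model testers imported in the next example (the class name, tuple unpacking, and input wiring are illustrative assumptions, not the original tests):

class HypotheticalBartRagTester(RagTestMixin, unittest.TestCase):

    @cached_property
    def config_and_inputs(self):
        # Assumed sketch: build small sub-configs and inputs from the existing
        # testers, then combine them into a RagConfig for the mixin's checks.
        (question_encoder_config, input_ids, _, input_mask, _, _, _) = DPRModelTester(self).prepare_config_and_inputs()
        generator_config, bart_inputs = BartModelTester(self).prepare_config_and_inputs_for_common()
        config = RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
        )
        return {
            "config": config,
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "decoder_input_ids": bart_inputs["input_ids"],
            "decoder_attention_mask": bart_inputs["attention_mask"],
        }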
Example #5
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES

from .test_modeling_bart import ModelTester as BartModelTester
from .test_modeling_dpr import DPRModelTester
from .test_modeling_t5 import T5ModelTester

TOLERANCE = 1e-3

T5_SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               "fixtures/test_sentencepiece.model")

if is_torch_available() and is_datasets_available() and is_faiss_available():
    import torch
    from datasets import Dataset

    import faiss
    from transformers import (
        AutoConfig,
        AutoModel,
        AutoModelForSeq2SeqLM,
        RagConfig,
        RagModel,
        RagRetriever,
        RagSequenceForGeneration,
        RagTokenForGeneration,
        RagTokenizer,
    )
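For reference, a minimal standalone sketch of the inner-product retrieval the toy dataset in Example #4 exercises (toy data; assumes only numpy and faiss are installed):

import faiss
import numpy as np

dim = 32
doc_embeddings = np.stack([np.ones(dim), 2 * np.ones(dim), 3 * np.ones(dim)]).astype("float32")
index = faiss.IndexFlatIP(dim)  # "Flat" index with METRIC_INNER_PRODUCT semantics
index.add(doc_embeddings)

query = np.ones((1, dim), dtype="float32")
scores, ids = index.search(query, k=2)  # top-2 docs by inner product
print(ids)  # [[2 1]] -- the all-threes doc scores highest against an all-ones query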
Example #6
import numpy as np
import pandas as pd
import torch

from transformers import (AutoModel, AutoTokenizer, BertModel, BertTokenizer,
                          BertTokenizerFast, DPRQuestionEncoder,
                          DPRQuestionEncoderTokenizer, RobertaTokenizer)
from transformers.file_utils import is_faiss_available, requires_backends

from pyserini.util import (download_encoded_queries, download_prebuilt_index,
                           get_dense_indexes_info)

from ._model import AnceEncoder

if is_faiss_available():
    import faiss


class QueryEncoder:
    def __init__(self, encoded_query_dir: str = None):
        self.has_model = False
        self.has_encoded_query = False
        if encoded_query_dir:
            self.embedding = self._load_embeddings(encoded_query_dir)
            self.has_encoded_query = True

    def encode(self, query: str):
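        # Look up the precomputed embedding for this exact query string.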
        return self.embedding[query]

    @classmethod