def require_distributed_retrieval(test_case):
    """
    Decorator for tests that need the dependencies used to perform retrieval with
    :class:`~transformers.RagRetriever`.

    The test is returned unchanged when PyTorch, Datasets, Faiss and psutil are all reported as available; otherwise
    it is wrapped in ``unittest.skip`` so that it is skipped.
    """
    # The availability checks stay in one short-circuiting ``and`` chain, so later checks are only evaluated when
    # the earlier ones succeed.
    dependencies_ready = (
        is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()
    )
    if dependencies_ready:
        return test_case
    return unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
def require_retrieval(test_case):
    """
    Decorator for tests that need the dependencies used to perform retrieval with
    :class:`~transformers.RagRetriever`.

    The test is returned unchanged when TensorFlow, datasets and faiss are all reported as available; otherwise it is
    wrapped in ``unittest.skip`` so that it is skipped.
    """
    dependencies_ready = is_tf_available() and is_datasets_available() and is_faiss_available()
    if dependencies_ready:
        return test_case
    return unittest.skip("test requires tensorflow, datasets and faiss")(test_case)
def require_retrieval(test_case):
    """
    Decorator for tests that need the dependencies used to perform retrieval with [`RagRetriever`].

    The test is returned unchanged when PyTorch, datasets and faiss are all reported as available; otherwise it is
    wrapped in `unittest.skip` so that it is skipped.
    """
    missing_dependency = not (is_torch_available() and is_datasets_available() and is_faiss_available())
    if missing_dependency:
        skip_decorator = unittest.skip("test requires PyTorch, datasets and faiss")
        return skip_decorator(test_case)
    return test_case
class RagTestMixin:
    """
    Shared fixtures and checks for the RAG model tests.

    `setUp` writes small DPR, BART and T5 tokenizer files into a temporary directory, `get_retriever` builds a
    `RagRetriever` over a three-document in-memory dataset, and the `check_*` methods run every model in
    `all_model_classes` and assert on the shapes of its outputs.

    The `test_*` methods read `self.config_and_inputs`, which is not defined in this class and has to be supplied by
    the class that uses this mixin. It is unpacked with `**` into the `check_*` methods, so it must be a mapping
    with (at least) the keys `config`, `input_ids`, `attention_mask`, `decoder_input_ids` and
    `decoder_attention_mask`.
    """

    # Empty when any of torch / datasets / faiss is missing, so the check loops simply do nothing.
    all_model_classes = (
        (RagModel, RagTokenForGeneration, RagSequenceForGeneration)
        if is_torch_available() and is_datasets_available() and is_faiss_available()
        else ()
    )

    # Length of each embedding vector stored in the dummy dataset built by `get_retriever`.
    retrieval_vector_size = 32
    # Number of documents expected in the shape assertions when no custom `n_docs` is used.
    n_docs = 3
    # Sequence length expected for the generator encoder hidden states in the shape assertions.
    max_combined_length = 16

    def setUp(self):
        """Create a temporary directory holding DPR, BART and T5 tokenizer files."""
        self.tmpdirname = tempfile.mkdtemp()

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        # NOTE: `self.vocab_file` is reassigned here, so after `setUp` it points at the BART vocab, not the DPR one.
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        # T5 tok: built from the sample sentencepiece model and saved next to the other tokenizers.
        t5_tokenizer = T5Tokenizer(T5_SAMPLE_VOCAB)
        t5_tokenizer_path = os.path.join(self.tmpdirname, "t5_tokenizer")
        t5_tokenizer.save_pretrained(t5_tokenizer_path)

    @cached_property
    def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        """Question-encoder tokenizer loaded from the files written by `setUp`."""
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    @cached_property
    def bart_tokenizer(self) -> BartTokenizer:
        """BART tokenizer loaded from the files written by `setUp`."""
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    @cached_property
    def t5_tokenizer(self) -> T5Tokenizer:
        """T5 tokenizer loaded from the directory saved by `setUp`."""
        return T5Tokenizer.from_pretrained(os.path.join(self.tmpdirname, "t5_tokenizer"))

    def tearDown(self):
        """Remove the temporary directory created by `setUp`."""
        shutil.rmtree(self.tmpdirname)

    def get_retriever(self, config):
        """
        Build a `RagRetriever` backed by a three-document dummy dataset.

        The dataset's `embeddings` column holds constant vectors (all ones, twos and threes) of length
        `retrieval_vector_size` and gets a flat inner-product faiss index. `load_dataset` is patched inside
        `transformers.retrieval_rag` so that constructing the retriever receives this dataset.

        Args:
            config: RAG configuration; `config.generator.model_type` selects the BART tokenizer when it equals
                `"bart"` and the T5 tokenizer otherwise.

        Returns:
            The constructed `RagRetriever`.
        """
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1", "3"],
                "text": ["foo", "bar", "qux"],
                "title": ["Foo", "Bar", "Qux"],
                "embeddings": [
                    np.ones(self.retrieval_vector_size),
                    2 * np.ones(self.retrieval_vector_size),
                    3 * np.ones(self.retrieval_vector_size),
                ],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
        tokenizer = self.bart_tokenizer if config.generator.model_type == "bart" else self.t5_tokenizer
        with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.dpr_tokenizer,
                generator_tokenizer=tokenizer,
            )
        return retriever

    def _retrieve_context(self, config, model, retriever, input_ids, attention_mask, **retriever_kwargs):
        """
        Run the question encoder and the retriever by hand, outside of the model's forward pass.

        Args:
            config: RAG configuration; `config.generator.prefix` is forwarded to the retriever.
            model: Model whose `question_encoder` is used to encode `input_ids`.
            retriever: Callable retriever, as returned by `get_retriever`.
            input_ids: Question token ids.
            attention_mask: Attention mask passed to the question encoder.
            **retriever_kwargs: Extra keyword arguments forwarded unchanged to the retriever call (the callers in
                this class only ever pass `n_docs`).

        Returns:
            Tuple `(context_input_ids, context_attention_mask, doc_scores)`.
        """
        question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]

        # The retriever is fed a float32 numpy copy of the question embeddings.
        out = retriever(
            input_ids,
            question_hidden_states.cpu().detach().to(torch.float32).numpy(),
            prefix=config.generator.prefix,
            return_tensors="pt",
            **retriever_kwargs,
        )

        # cast: `.to(other_tensor)` aligns the retriever outputs with the tensors they are combined with.
        retrieved_doc_embeds = out["retrieved_doc_embeds"].to(question_hidden_states)
        context_input_ids = out["context_input_ids"].to(input_ids)
        context_attention_mask = out["context_attention_mask"].to(input_ids)

        # compute doc_scores
        doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(1)
        return context_input_ids, context_attention_mask, doc_scores

    def _check_output_shapes(self, outputs, config, input_ids, decoder_input_ids, n_docs):
        """Assert the shapes of `logits`, `generator_enc_last_hidden_state` and `doc_scores` for `n_docs` docs."""
        # logits
        self.assertEqual(
            outputs.logits.shape,
            (n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
        )
        # generator encoder last hidden states
        self.assertEqual(
            outputs.generator_enc_last_hidden_state.shape,
            (n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
        )
        # doc scores
        self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], n_docs))

    def check_model_with_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Run each model with an attached retriever and check the output shapes against `self.n_docs`."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._check_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def check_model_generate(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call `generate` on every model class except the first one and check that something is returned."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        # The first entry of `all_model_classes` (`RagModel`) is skipped for generation.
        for model_class in self.all_model_classes[1:]:
            model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model.generate(
                input_ids=input_ids,
                num_beams=2,
                num_return_sequences=2,
                decoder_start_token_id=config.generator.eos_token_id,
            )

            self.assertIsNotNone(outputs)

    def check_model_without_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Retrieve the context manually, feed it to retriever-less models and check the output shapes."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                config, model, retriever, input_ids, attention_mask
            )

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._check_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def check_model_custom_n_docs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, n_docs, **kwargs
    ):
        """Same as `check_model_without_retriever`, but with `n_docs` passed to both retriever and model."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                config, model, retriever, input_ids, attention_mask, n_docs=n_docs
            )

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                n_docs=n_docs,
            )

            self._check_output_shapes(outputs, config, input_ids, decoder_input_ids, n_docs)

    def check_model_with_mismatch_n_docs_value(
        self,
        config,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        retriever_n_docs,
        generator_n_docs,
        **kwargs
    ):
        """
        Check that the model raises `AssertionError` when the context was retrieved with `retriever_n_docs` but the
        forward pass is called with `n_docs=generator_n_docs`.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                config, model, retriever, input_ids, attention_mask, n_docs=retriever_n_docs
            )

            with self.assertRaises(AssertionError):
                model(
                    context_input_ids=context_input_ids,
                    context_attention_mask=context_attention_mask,
                    doc_scores=doc_scores,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    n_docs=generator_n_docs,
                )

    def check_model_with_encoder_outputs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """
        Run a full forward pass, then re-run the model from the resulting generator encoder states and doc scores
        only, and check the output shapes of that second pass.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            encoder_outputs = BaseModelOutput(outputs.generator_enc_last_hidden_state)

            # run only generator
            outputs = model(
                encoder_outputs=encoder_outputs,
                doc_scores=outputs.doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._check_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def test_model_with_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_retriever(**inputs_dict)

    def test_model_without_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_without_retriever(**inputs_dict)

    def test_model_with_encoder_outputs(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_encoder_outputs(**inputs_dict)

    def test_model_generate(self):
        inputs_dict = self.config_and_inputs
        self.check_model_generate(**inputs_dict)

    def test_model_with_custom_n_docs(self):
        # Work on a copy so the extra key is not written into the `config_and_inputs` mapping itself.
        inputs_dict = dict(self.config_and_inputs, n_docs=1)
        self.check_model_custom_n_docs(**inputs_dict)

    def test_model_with_mismatch_n_docs_value(self):
        # Work on a copy so the extra keys are not written into the `config_and_inputs` mapping itself.
        inputs_dict = dict(self.config_and_inputs, retriever_n_docs=3, generator_n_docs=2)
        self.check_model_with_mismatch_n_docs_value(**inputs_dict)
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
# The BERT vocab-file names are imported under a DPR alias, and the RoBERTa ones under a BART alias.
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES

from .test_modeling_bart import ModelTester as BartModelTester
from .test_modeling_dpr import DPRModelTester
from .test_modeling_t5 import T5ModelTester


# NOTE(review): not referenced in the visible part of this file; presumably a tolerance for numeric comparisons
# made elsewhere — confirm against its users.
TOLERANCE = 1e-3

# Path to the sample sentencepiece model under `fixtures/`, resolved relative to this file's directory.
T5_SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

# torch, datasets, faiss and the RAG classes are imported only when all three availability checks pass
# (presumably so that this module can still be imported when one of them is missing).
if is_torch_available() and is_datasets_available() and is_faiss_available():
    import torch
    from datasets import Dataset

    import faiss
    from transformers import (
        AutoConfig,
        AutoModel,
        AutoModelForSeq2SeqLM,
        RagConfig,
        RagModel,
        RagRetriever,
        RagSequenceForGeneration,
        RagTokenForGeneration,
        RagTokenizer,
    )
import numpy as np import pandas as pd from transformers import (AutoModel, AutoTokenizer, BertModel, BertTokenizer, BertTokenizerFast, DPRQuestionEncoder, DPRQuestionEncoderTokenizer, RobertaTokenizer) from transformers.file_utils import is_faiss_available, requires_backends from pyserini.util import (download_encoded_queries, download_prebuilt_index, get_dense_indexes_info) from ._model import AnceEncoder import torch if is_faiss_available(): import faiss class QueryEncoder: def __init__(self, encoded_query_dir: str = None): self.has_model = False self.has_encoded_query = False if encoded_query_dir: self.embedding = self._load_embeddings(encoded_query_dir) self.has_encoded_query = True def encode(self, query: str): return self.embedding[query] @classmethod