def test_rag_sequence_from_pretrained(self):
    rag_config = self.get_rag_config()
    rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    rag_retriever = RagRetriever(
        rag_config,
        question_encoder_tokenizer=rag_question_encoder_tokenizer,
        generator_tokenizer=rag_decoder_tokenizer,
    )

    input_ids = rag_question_encoder_tokenizer(
        "who sings does he love me with reba", return_tensors="pt"
    ).input_ids
    decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids

    input_ids = input_ids.to(torch_device)
    decoder_input_ids = decoder_input_ids.to(torch_device)

    with tempfile.TemporaryDirectory() as tmp_dirname:
        rag_sequence = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base",
            "facebook/bart-large-cnn",
            retriever=rag_retriever,
            config=rag_config,
        ).to(torch_device)
        # check that the from_pretrained methods work: save, then reload from disk
        # (reassign, otherwise the reloaded model would be discarded)
        rag_sequence.save_pretrained(tmp_dirname)
        rag_sequence = RagSequenceForGeneration.from_pretrained(tmp_dirname, retriever=rag_retriever)
        rag_sequence.to(torch_device)

        with torch.no_grad():
            output = rag_sequence(input_ids, labels=decoder_input_ids)

        loss_pretrained = output.loss
        del rag_sequence

    question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
    rag_sequence = RagSequenceForGeneration(
        config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
    )
    rag_sequence.to(torch_device)

    with torch.no_grad():
        output = rag_sequence(input_ids, labels=decoder_input_ids)

    loss_init = output.loss

    self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)
def load_model(self) -> None:
    logger.debug('loading rag retriever: %s', self.name)
    retriever = RagRetriever.from_pretrained(
        self.rag_sequence, index_name='custom', indexed_dataset=self.dataset)
    logger.debug('loading rag model: %s', self.name)
    self.model = RagSequenceForGeneration.from_pretrained(
        self.rag_sequence, retriever=retriever)
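# A minimal query helper to pair with load_model() above. It assumes a
# RagTokenizer for the same checkpoint is available as `self.tokenizer`;
# both the method name and that attribute are illustrative, not part of
# the original class.
def answer(self, question: str) -> str:
    input_ids = self.tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
    generated = self.model.generate(input_ids)
    return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]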
def sequence_model(self):
    return (
        RagSequenceForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
        )
        .to(torch_device)
        .eval()
    )
def test_rag_sequence_generate_batch_from_context_input_ids(self):
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    rag_sequence = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq", retriever=retriever
    ).to(torch_device)

    input_dict = tokenizer(
        self.test_data_questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    input_ids = input_dict.input_ids.to(torch_device)
    attention_mask = input_dict.attention_mask.to(torch_device)

    # retrieve the docs manually, then pass the pre-computed context tensors
    # and doc scores to generate() instead of letting the model retrieve
    question_hidden_states = rag_sequence.question_encoder(input_ids, attention_mask=attention_mask)[0]
    docs_dict = retriever(
        input_ids.cpu().detach().numpy(),
        question_hidden_states.cpu().detach().numpy(),
        return_tensors="pt",
    )
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1),
        docs_dict["retrieved_doc_embeds"].to(torch_device).float().transpose(1, 2),
    ).squeeze(1)

    output_ids = rag_sequence.generate(
        context_input_ids=docs_dict["context_input_ids"].to(torch_device),
        context_attention_mask=docs_dict["context_attention_mask"].to(torch_device),
        doc_scores=doc_scores.to(torch_device),
        do_deduplication=True,
    )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    EXPECTED_OUTPUTS = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
        " reticular formation",
        " walls of the abdomen",
        " spodumene",
        " obama",
        " new orleans",
        " japan",
        " old trafford",
    ]
    self.assertListEqual(outputs, EXPECTED_OUTPUTS)
def __init__(self, **args):
    super(RagTrainer, self).__init__()
    self.save_hyperparameters()
    self.rag_retriever = RagRetriever.from_pretrained(
        self.hparams['rag_ckpt_path'],
        index_name='custom',
        passages_path=self.hparams['wiki_ds_path'],
        index_path=self.hparams['wiki_index_path'])
    self.rag = RagSequenceForGeneration.from_pretrained(
        self.hparams['rag_ckpt_path'], retriever=self.rag_retriever)
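# A minimal training_step sketch to go with the __init__ above. It assumes the
# dataloader yields batches with "input_ids", "attention_mask" and "labels"
# (target token ids); the RAG forward computes the marginalized NLL loss itself
# when labels are passed. The batch keys are an assumption, not from the original.
def training_step(self, batch, batch_idx):
    outputs = self.rag(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
    )
    loss = outputs.loss
    self.log("train_loss", loss)
    return loss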
def test_rag_sequence_generate_batch(self):
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    rag_sequence = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq", retriever=retriever
    ).to(torch_device)

    input_dict = tokenizer(
        self.test_data_questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    input_ids = input_dict.input_ids.to(torch_device)
    attention_mask = input_dict.attention_mask.to(torch_device)

    output_ids = rag_sequence.generate(
        input_ids,
        attention_mask=attention_mask,
    )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    EXPECTED_OUTPUTS = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
        " reticular formation",
        " walls of the abdomen",
        " spodumene",
        " obama",
        " new orleans",
        " japan",
        " old trafford",
    ]
    self.assertListEqual(outputs, EXPECTED_OUTPUTS)
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # Let's say you have documents in tab-separated csv files with columns "title" and "text"
    assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # You can load a Dataset object this way
    dataset = load_dataset(
        "csv",
        data_files=[rag_example_args.csv_path],
        split="train",
        delimiter="\t",
        column_names=["title", "text"],
    )

    # More info about loading csv files in the documentation:
    # https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

    # Then split the documents into passages of 100 words
    dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # And compute the embeddings
    ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)

    # And save the index
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Easy way to load the model
    retriever = RagRetriever.from_pretrained(
        rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
    )
    model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    # For distributed fine-tuning you'll need to provide the paths instead,
    # as the dataset and the index are loaded separately:
    # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)
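# The map() calls in Step 1 reference split_documents and embed helpers that are
# not shown above. One possible implementation, matching the shapes those calls
# expect (the 100-word splitting granularity and the module-level `device` are
# assumptions carried over from the surrounding script):
from typing import List

def split_text(text: str, n: int = 100, character: str = " ") -> List[str]:
    # split the text every n-th occurrence of character
    words = text.split(character)
    return [character.join(words[i : i + n]).strip() for i in range(0, len(words), n)]

def split_documents(documents: dict) -> dict:
    # turn each document into several 100-word passages, keeping the title
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}

def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    # compute the DPR embeddings of document passages
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy()}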
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to("cuda:0")

# add_tokens(), n_docs_splits, skip_ec and the extra_context argument below are
# not part of the stock transformers RAG API; they come from a patched variant.
model.add_tokens()
model.config.n_docs = 6
retriever.config.n_docs = 6
model.config.n_docs_splits = 3
retriever.config.n_docs_splits = 3
model.skip_ec = True

input_dict = tokenizer.prepare_seq2seq_batch("am i a cool person", return_tensors="pt").to("cuda:0")
generated = model.generate(**input_dict, extra_context=["Cats are cool animals!"], num_beams=4)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
# should give 54 => google says either 44 or 51
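# For comparison, a minimal sketch of the same query against the unmodified
# transformers RAG API (no add_tokens/skip_ec/extra_context); this mirrors the
# generation pattern used in the example script above and assumes a CUDA device:
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to("cuda:0")

input_ids = tokenizer.question_encoder("am i a cool person", return_tensors="pt")["input_ids"].to("cuda:0")
generated = model.generate(input_ids, num_beams=4)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])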