def test_rag_token_greedy_search(self):
    """Greedy (single-beam) decoding with the TF RAG-token model.

    Only the first two test questions are used; decoded answers must match
    the fixed gold strings exactly.
    """
    model_name = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
    rag_token = TFRagTokenForGeneration.from_pretrained(model_name, retriever=retriever)

    # Tokenize the first two questions only.
    batch = tokenizer(
        self.test_data_questions[:2],
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    question_ids = batch.input_ids
    question_mask = batch.attention_mask

    # Force greedy search: exactly one beam.
    rag_token.config.num_beams = 1

    generated_ids = rag_token.generate(question_ids, attention_mask=question_mask)
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " september 22, 2017",
    ]
    self.assertListEqual(answers, expected_answers)
def test_rag_token_generate_batch(self):
    """Batched generation with the PyTorch RAG-token model.

    All test questions are generated in one batch; decoded answers must
    match the fixed gold strings exactly.
    """
    model_name = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
    rag_token = RagTokenForGeneration.from_pretrained(model_name, retriever=retriever).to(torch_device)

    batch = tokenizer(
        self.test_data_questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    question_ids = batch.input_ids.to(torch_device)
    question_mask = batch.attention_mask.to(torch_device)

    generated_ids = rag_token.generate(question_ids, attention_mask=question_mask)
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " september 22, 2017",
        " amplitude modulation",
        " stefan persson",
        " april 20, 2018",
        " the 1970s",
        " 7.1. 2",
        " 13",
    ]
    self.assertListEqual(answers, expected_answers)
def test_rag_sequence_generate_batch(self):
    """Batched generation with the TF RAG-sequence model.

    All test questions are generated in one batch; decoded answers must
    match the fixed gold strings exactly.
    """
    model_name = "facebook/rag-sequence-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
    rag_sequence = TFRagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

    batch = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    question_ids = batch.input_ids
    question_mask = batch.attention_mask

    generated_ids = rag_sequence.generate(question_ids, attention_mask=question_mask)
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
    ]
    self.assertListEqual(answers, expected_answers)
def main(args):
    """Evaluate one or more RAG/BART checkpoints on a question file.

    Depending on ``args.eval_mode`` this either scores end-to-end answers
    ("e2e") or retrieval precision@k. Predictions are written batch-by-batch
    to ``args.predictions_path``; if that file already exists and
    ``args.recalculate`` is not set, the existing predictions are re-scored
    instead of regenerating them.

    Args:
        args: parsed CLI namespace. Reads model_type, model_name_or_path,
            n_docs, index_name, index_path, eval_all_checkpoints, eval_mode,
            predictions_path, recalculate, gold_data_path, evaluation_set,
            eval_batch_size, device.
    """
    model_kwargs = {}
    if args.model_type is None:
        args.model_type = infer_model_type(args.model_name_or_path)
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # raise would be more robust for input validation.
        assert args.model_type is not None
    if args.model_type.startswith("rag"):
        model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
        model_kwargs["n_docs"] = args.n_docs
        if args.index_name is not None:
            model_kwargs["index_name"] = args.index_name
        if args.index_path is not None:
            model_kwargs["index_path"] = args.index_path
    else:
        model_class = BartForConditionalGeneration

    # Either every sub-directory of the model path (one per saved epoch),
    # or just the single given checkpoint.
    checkpoints = (
        [f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
        if args.eval_all_checkpoints
        else [args.model_name_or_path]
    )
    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
    evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval

    for checkpoint in checkpoints:
        # Reuse an existing predictions file unless --recalculate was given.
        # NOTE(review): all checkpoints share one predictions_path, so after
        # the first checkpoint runs, later iterations will hit this branch
        # and just re-score the same file — confirm this is intended.
        if os.path.exists(args.predictions_path) and (not args.recalculate):
            logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
            score_fn(args, args.predictions_path, args.gold_data_path)
            continue

        logger.info("***** Running evaluation for {} *****".format(checkpoint))
        logger.info(" Batch size = %d", args.eval_batch_size)
        logger.info(" Predictions will be stored under {}".format(args.predictions_path))

        if args.model_type.startswith("rag"):
            # The retriever must be built separately and handed to the model;
            # init_retrieval() loads the index into memory.
            retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
            model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
            model.retriever.init_retrieval()
        else:
            model = model_class.from_pretrained(checkpoint, **model_kwargs)
        model.to(args.device)

        with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
            questions = []
            for line in tqdm(eval_file):
                questions.append(line.strip())
                if len(questions) == args.eval_batch_size:
                    answers = evaluate_batch_fn(args, model, questions)
                    # Trailing newline so the next batch starts on a new line.
                    preds_file.write("\n".join(answers) + "\n")
                    preds_file.flush()
                    questions = []
            # Flush the final, possibly partial, batch (no trailing newline).
            if len(questions) > 0:
                answers = evaluate_batch_fn(args, model, questions)
                preds_file.write("\n".join(answers))
                preds_file.flush()

            score_fn(args, args.predictions_path, args.gold_data_path)
def load_model(self) -> None:
    """Build the custom-index RAG retriever and attach it to a freshly
    loaded RAG-sequence model, storing the result on ``self.model``.
    """
    logger.debug('loading rag retriever: %s', self.name)
    rag_retriever = RagRetriever.from_pretrained(
        self.rag_sequence,
        index_name='custom',
        indexed_dataset=self.dataset,
    )
    logger.debug('loading rag model: %s', self.name)
    self.model = RagSequenceForGeneration.from_pretrained(
        self.rag_sequence,
        retriever=rag_retriever,
    )
def test_rag_sequence_generate_batch_from_context_input_ids(self):
    """Generation from pre-retrieved contexts (PyTorch RAG-sequence).

    Instead of letting ``generate`` call the retriever internally, this
    runs the question encoder and retriever by hand, computes doc scores
    with a batched matmul, and feeds context ids/mask/scores directly.
    """
    model_name = "facebook/rag-sequence-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
    rag_sequence = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever).to(torch_device)

    batch = tokenizer(
        self.test_data_questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    question_ids = batch.input_ids.to(torch_device)
    question_mask = batch.attention_mask.to(torch_device)

    # Encode questions, then retrieve documents for the resulting vectors.
    question_hidden_states = rag_sequence.question_encoder(question_ids, attention_mask=question_mask)[0]
    docs_dict = retriever(
        question_ids.cpu().detach().numpy(),
        question_hidden_states.cpu().detach().numpy(),
        return_tensors="pt",
    )

    # doc_scores[b, d] = <question_b, doc_embedding_{b,d}>
    retrieved_embeds = docs_dict["retrieved_doc_embeds"].to(torch_device).float()
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1),
        retrieved_embeds.transpose(1, 2),
    ).squeeze(1)

    generated_ids = rag_sequence.generate(
        context_input_ids=docs_dict["context_input_ids"].to(torch_device),
        context_attention_mask=docs_dict["context_attention_mask"].to(torch_device),
        doc_scores=doc_scores.to(torch_device),
        do_deduplication=True,
    )
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
        " reticular formation",
        " walls of the abdomen",
        " spodumene",
        " obama",
        " new orleans",
        " japan",
        " old trafford",
    ]
    self.assertListEqual(answers, expected_answers)
def __init__(self, **args):
    """Set up the Lightning module: load a RAG retriever backed by the
    pre-built wiki passages/index and a RAG-sequence model that uses it.

    Hyperparameters are taken from ``self.hparams`` (populated by
    ``save_hyperparameters``): rag_ckpt_path, wiki_ds_path, wiki_index_path.
    """
    super().__init__()
    self.save_hyperparameters()
    ckpt_path = self.hparams['rag_ckpt_path']
    self.rag_retriever = RagRetriever.from_pretrained(
        ckpt_path,
        index_name='custom',
        passages_path=self.hparams['wiki_ds_path'],
        index_path=self.hparams['wiki_index_path'],
    )
    self.rag = RagSequenceForGeneration.from_pretrained(
        ckpt_path,
        retriever=self.rag_retriever,
    )
def test_rag_sequence_generate_batch_from_context_input_ids(self):
    """Generation from pre-retrieved contexts (TF RAG-sequence).

    Runs question encoder and retriever manually, computes doc scores via
    matmul, and passes context ids/mask/scores straight to ``generate``.
    """
    model_name = "facebook/rag-sequence-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
    rag_sequence = TFRagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

    batch = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    question_ids = batch.input_ids

    # Encode the questions, then retrieve supporting documents.
    question_hidden_states = rag_sequence.question_encoder(question_ids)[0]
    docs_dict = retriever(question_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")

    # doc_scores[b, d] = <question_b, doc_embedding_{b,d}>
    expanded_question = tf.expand_dims(question_hidden_states, axis=[1])
    scores_per_doc = tf.matmul(expanded_question, docs_dict["retrieved_doc_embeds"], transpose_b=True)
    doc_scores = tf.squeeze(scores_per_doc, axis=[1])

    generated_ids = rag_sequence.generate(
        context_input_ids=docs_dict["context_input_ids"],
        context_attention_mask=docs_dict["context_attention_mask"],
        doc_scores=doc_scores,
        do_deduplication=True,
    )
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
    ]
    self.assertListEqual(answers, expected_answers)
def test_rag_token_generate_batch(self):
    """Batched beam-search generation with the TF RAG-token model loaded
    from PyTorch weights (``from_pt=True``).

    NOTE: gold labels comes from num_beam=4 -- if change gold labels to
    greedy-generated, test will pass.
    """
    model_name = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
    rag_token = TFRagTokenForGeneration.from_pretrained(model_name, retriever=retriever, from_pt=True)

    batch = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    question_ids = batch.input_ids
    question_mask = batch.attention_mask

    # rag_token.config.num_beams = 1 -> different in 2 answers
    # (obama, united stadium) to num_beams=4 labels
    generated_ids = rag_token.generate(question_ids, attention_mask=question_mask)
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " september 22, 2017",
        " amplitude modulation",
        " stefan persson",
        " april 20, 2018",
        " the 1970s",
        " 7.1. 2",
        " 13",
        " step by step",
        " stomach",
        " spodumene",
        " obama",
        " northern new jersey",
        " india",
        " united stadium",
    ]
    self.assertListEqual(answers, expected_answers)
def test_rag_token_generate_batch(self):
    """Batched generation with the TF RAG-token model, run as two
    half-batches of 4 questions to avoid GPU OOM.

    NOTE: gold labels comes from num_beam=4, so this is effectively a
    beam-search test.
    """
    model_name = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
    rag_token = TFRagTokenForGeneration.from_pretrained(model_name, retriever=retriever)

    batch = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    question_ids = batch.input_ids
    question_mask = batch.attention_mask

    expected_answers = [
        " albert einstein",
        " september 22, 2017",
        " amplitude modulation",
        " stefan persson",
        " april 20, 2018",
        " the 1970s",
        " 7.1. 2",
        " 13",
    ]

    # First half-batch of 4 examples.
    generated_ids = rag_token.generate(question_ids[:4], attention_mask=question_mask[:4])
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    self.assertListEqual(answers, expected_answers[:4])

    # Second half-batch of 4 examples.
    generated_ids = rag_token.generate(question_ids[4:], attention_mask=question_mask[4:])
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    self.assertListEqual(answers, expected_answers[4:])
def test_rag_sequence_generate_batch(self):
    """Batched generation with the PyTorch RAG-sequence model.

    Fix: the original loaded the "facebook/rag-sequence-nq" checkpoint with
    ``RagTokenForGeneration``. The test name, the ``rag_sequence`` variable,
    and the gold answers all correspond to the sequence-marginalization
    model, so ``RagSequenceForGeneration`` is the correct class here.
    """
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    # Was RagTokenForGeneration — wrong generation head for a rag-sequence
    # checkpoint and these expected outputs.
    rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
        torch_device
    )

    input_dict = tokenizer(
        self.test_data_questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    input_ids = input_dict.input_ids.to(torch_device)
    attention_mask = input_dict.attention_mask.to(torch_device)

    output_ids = rag_sequence.generate(
        input_ids,
        attention_mask=attention_mask,
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    EXPECTED_OUTPUTS = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
        " reticular formation",
        " walls of the abdomen",
        " spodumene",
        " obama",
        " grainger's compound",
        " japan",
        " old trafford stadium",
    ]
    self.assertListEqual(outputs, EXPECTED_OUTPUTS)
def test_rag_token_generate_batch(self):
    """Batched generation with the TF RAG-token model over all questions.

    NOTE: gold labels comes from num_beam=4, so this is effectively a
    beam-search test.
    """
    model_name = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
    rag_token = TFRagTokenForGeneration.from_pretrained(model_name, retriever=retriever)

    batch = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    question_ids = batch.input_ids
    question_mask = batch.attention_mask

    generated_ids = rag_token.generate(question_ids, attention_mask=question_mask)
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " september 22, 2017",
        " amplitude modulation",
        " stefan persson",
        " april 20, 2018",
        " the 1970s",
        " 7.1. 2",
        " 13",
        " evolution",
        " stomach",
        " spodumene",
        " obama",
        " northern new jersey",
        " india",
        " united stadium",
    ]
    self.assertListEqual(answers, expected_answers)
def main():
    """Train or evaluate a RAG-token model on VQA data.

    Workflow: parse CLI args, merge them with an optional YAML options file,
    build datasets/loaders, load (or resume) RAG tokenizer/retriever/model,
    set up the optimizer and experiment logger, then either validate
    (``args.evaluate``) or run the training loop with checkpointing.

    Uses module-level ``args`` and ``best_acc1`` globals.
    """
    global args, best_acc1
    args = parser.parse_args()

    #########################################################################
    # Create options: CLI defaults, optionally overridden by a YAML file
    #########################################################################
    options = {
        'vqa': {
            'trainsplit': args.vqa_trainsplit
        },
        'logs': {
            'dir_logs': args.dir_logs
        },
        'model': {
            'arch': args.arch,
            'seq2vec': {
                'type': args.st_type,
                'dropout': args.st_dropout,
                'fixed_emb': args.st_fixed_emb
            }
        },
        'optim': {
            'lr': args.learning_rate,
            'batch_size': args.batch_size,
            'epochs': args.epochs
        }
    }
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            # NOTE(review): yaml.load without an explicit Loader is
            # deprecated and unsafe on untrusted files — prefer
            # yaml.safe_load; confirm the options files are trusted.
            options_yaml = yaml.load(handle)
        options = utils.update_values(options, options_yaml)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)
    if args.help_opt:
        return

    # Set datasets options
    if 'vgenome' not in options:
        options['vgenome'] = None

    #########################################################################
    # Create needed datasets
    #########################################################################
    trainset = datasets.factory_VQA(options['vqa']['trainsplit'],
                                    options['vqa'],
                                    options['coco'],
                                    options['vgenome'])
    train_loader = trainset.data_loader(batch_size=options['optim']['batch_size'],
                                        num_workers=args.workers,
                                        shuffle=True)

    if options['vqa']['trainsplit'] == 'train':
        valset = datasets.factory_VQA('val', options['vqa'], options['coco'])
        # NOTE(review): validation batch size is hard-coded to 2 while the
        # other loaders use options['optim']['batch_size'] — confirm this
        # is intentional (e.g. memory limits during generation).
        val_loader = valset.data_loader(batch_size=2,
                                        num_workers=args.workers)

    if options['vqa']['trainsplit'] == 'trainval' or args.evaluate:
        testset = datasets.factory_VQA('test', options['vqa'], options['coco'])
        test_loader = testset.data_loader(batch_size=options['optim']['batch_size'],
                                          num_workers=args.workers)

    #########################################################################
    # Create model, criterion and optimizer
    #########################################################################
    # RAG config: legacy index over the full (non-dummy) wiki_dpr dataset,
    # 10 retrieved documents per question.
    config = RagConfig.from_pretrained("facebook/rag-token-nq")
    config.index_name = "legacy"
    config.use_dummy_dataset = False
    config.question_encoder.return_dict = True
    config.n_docs = 10

    if not args.evaluate and not args.resume:
        # Fresh start: load the base pretrained checkpoint.
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base", config=config)
        retriever = RagRetriever.from_pretrained("facebook/rag-token-base", config=config)
        model = RagTokenForGeneration.from_pretrained(
            "facebook/rag-token-base", retriever=retriever, config=config)
    else:
        # Resume/evaluate: load from the epoch checkpoint directory.
        tokenizer = RagTokenizer.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)), config=config)
        retriever = RagRetriever.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)), config=config)
        model = RagTokenForGeneration.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)),
            retriever=retriever, config=config)

    model.cuda()
    criterion = criterions.factory(options['vqa'], cuda=True)

    # Split parameters into decay / no-decay groups (bias and LayerNorm
    # weights are conventionally excluded from weight decay).
    # NOTE(review): BOTH groups use weight_decay 0.0, which makes the split
    # a no-op — the first group was presumably meant to use a non-zero
    # decay value; confirm.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=options['optim']['lr'], eps=1e-8)

    #########################################################################
    # args.resume: resume from a checkpoint OR create logs directory
    #########################################################################
    exp_logger = None

    # Persist the merged options and the raw CLI args next to the logs.
    path_new_opt = os.path.join(options['logs']['dir_logs'],
                                os.path.basename(args.path_opt))
    path_args = os.path.join(options['logs']['dir_logs'], 'args.yaml')
    with open(path_new_opt, 'w') as f:
        yaml.dump(options, f, default_flow_style=False)
    with open(path_args, 'w') as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    if exp_logger is None:
        # Set loggers
        exp_name = os.path.basename(
            options['logs']['dir_logs'])  # add timestamp
        exp_logger = logger.Experiment(exp_name, options)
        exp_logger.add_meters('train', make_meters())
        exp_logger.add_meters('test', make_meters())
        if options['vqa']['trainsplit'] == 'train':
            exp_logger.add_meters('val', make_meters())
        exp_logger.info['model_params'] = utils.params_count(model)
        print('Model has {} parameters'.format(
            exp_logger.info['model_params']))

    #########################################################################
    # args.evaluate: on valset OR/AND on testset
    #########################################################################
    if args.evaluate:
        path_logger_json = os.path.join(options['logs']['dir_logs'], 'logger.json')
        if options['vqa']['trainsplit'] == 'train':
            acc1, val_results = engine.validate(val_loader, model, retriever, tokenizer,
                                                criterion, exp_logger, args.start_epoch, 100)
            # save results and compute OpenEnd accuracy
            exp_logger.to_json(path_logger_json)
            save_results(val_results, args.start_epoch, valset.split_name(),
                         options['logs']['dir_logs'], options['vqa']['dir'])
        return
    else:
        for epoch in range(args.start_epoch + 1, options['optim']['epochs']):
            engine.train(train_loader, model, retriever, tokenizer, criterion,
                         optimizer, exp_logger, epoch, args.print_freq)
            # remember best prec@1 and save checkpoint
            # NOTE(review): `is_best` is always True and `best_accs1 = -1`
            # looks like a typo for `best_acc1` — as written, `best_acc1`
            # is never updated from validation and every epoch is saved as
            # "best"; confirm intent.
            is_best = True
            best_accs1 = -1
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': options['model']['arch'],
                    'best_acc1': best_acc1,
                    'exp_logger': exp_logger
                },
                model, tokenizer, retriever,
                options['logs']['dir_logs'], args.save_model, True)
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """Build a custom RAG knowledge base from a CSV and query it once.

    Pipeline: (1) load a tab-separated csv of documents, split into ~100-word
    passages and embed them with DPR; (2) build a FAISS HNSW index over the
    embeddings; (3) load a RAG-sequence model wired to the custom index;
    (4) answer one example question.
    """
    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # Let's say you have documents in tab-separated csv files with columns "title" and "text"
    # NOTE(review): assert is stripped under `python -O`; a raised error
    # would be a more robust input check.
    assert os.path.isfile(
        rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # You can load a Dataset object this way
    dataset = load_dataset("csv",
                           data_files=[rag_example_args.csv_path],
                           split="train",
                           delimiter="\t",
                           column_names=["title", "text"])
    # More info about loading csv files in the documentation:
    # https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

    # Then split the documents into passages of 100 words
    dataset = dataset.map(split_documents,
                          batched=True,
                          num_proc=processing_args.num_proc)

    # And compute the embeddings
    ctx_encoder = DPRContextEncoder.from_pretrained(
        rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        rag_example_args.dpr_ctx_encoder_model_name)
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir,
                                 "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Let's use the Faiss implementation of HNSW for fast approximate
    # nearest neighbor search; inner-product metric matches DPR's
    # dot-product similarity.
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m,
                                faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)

    # And save the index
    index_path = os.path.join(rag_example_args.output_dir,
                              "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Easy way to load the model
    retriever = RagRetriever.from_pretrained(rag_example_args.rag_model_name,
                                             index_name="custom",
                                             indexed_dataset=dataset)
    model = RagSequenceForGeneration.from_pretrained(
        rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    # For distributed fine-tuning you'll need to provide the paths instead,
    # as the dataset and the index are loaded separately:
    # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom",
    #                                          passages_path=passages_path, index_path=index_path)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question,
                                           return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated,
                                              skip_special_tokens=True)[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)
k: v.to(args.device) for k, v in input_batch.items() } logits = model(**input_batch) logits = logits[0] attention = input_batch["attention_mask"] argmax = [ l[a == 1].softmax(1).max(1) for l, a in zip(logits, attention) ] preds = [idx[val >= args.thresh] for val, idx in argmax] if __name__ == "__main__": tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base") trainset = RAGEDataset(args, tokenizer) retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base") model = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever) lit_rage = LitRage(args, trainset, model) trainloader = DataLoader( trainset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, collate_fn=trainset.collate, ) checkpoint = ModelCheckpoint( filepath=os.path.join(args.output, "{epoch:02d}-{global_step:02d}-{val_loss:.2f}"),
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, RagConfig

# Demo script: RAG-sequence generation with an extra in-context passage
# supplied at generate() time, using the dummy "exact" index on GPU 0.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever).to("cuda:0")
# NOTE(review): add_tokens(), extra_context=, and skip_ec are not part of
# the stock RagSequenceForGeneration API — this script presumably runs
# against a patched/forked transformers; confirm.
model.add_tokens()

# Keep model and retriever in agreement on retrieval fan-out.
model.config.n_docs = 6
retriever.config.n_docs = 6
model.config.n_docs_splits = 3
retriever.config.n_docs_splits = 3

model.skip_ec = True
# NOTE(review): duplicate of the line above — the second assignment was
# possibly meant to set a different flag (e.g. on the retriever); confirm.
model.skip_ec = True

input_dict = tokenizer.prepare_seq2seq_batch("am i a cool person", return_tensors="pt").to("cuda:0")
generated = model.generate(**input_dict, extra_context=["Cats are cool animals!"], num_beams=4)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
# should give 54 => google says either 44 or 51
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration, BartForConditionalGeneration, RagConfig, DPRQuestionEncoder
import torch
from transformers.models.auto import AutoModel

# Exploration script: load a RAG-token model against the legacy (full,
# non-dummy) wiki_dpr index and run the question encoder on one question.
config = RagConfig.from_pretrained("facebook/rag-token-nq")
config.index_name = "legacy"            # original FAISS index shipped with wiki_dpr
config.use_dummy_dataset = False        # full index, not the tiny test one
config.question_encoder.return_dict = True

print("==> load tokenizer")
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

print("==> load retriever")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", config=config)
print("dataset info")
print(dir(retriever.index))

print("==> load generator")
# Alternative construction path (kept for reference): build RAG from a
# separate DPR question encoder + BART generator instead of the combined
# pretrained checkpoint:
# question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
# generator = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
# config = RagConfig.from_question_encoder_generator_configs(question_encoder.config, generator.config)
# model = RagTokenForGeneration(config, question_encoder=question_encoder, generator=generator, retriever=retriever)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("What kind of vehicle uses fire hydrant?", return_tensors="pt")
print(input_dict.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
input_ids = input_dict['input_ids']

print("==> encode")
question_hidden_states = model.question_encoder(input_ids)[0]
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

# Quickstart: answer one question with RAG-sequence over the compressed
# wiki_dpr index.
MODEL_NAME = "facebook/rag-sequence-nq"

tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
retriever = RagRetriever.from_pretrained(MODEL_NAME, dataset="wiki_dpr", index_name='compressed')
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)

question_batch = tokenizer.prepare_seq2seq_batch(
    "how many countries are in europe", return_tensors="pt")
answer_ids = model.generate(input_ids=question_batch["input_ids"])
print(tokenizer.batch_decode(answer_ids, skip_special_tokens=True)[0])