def main():
    """Build a RAG model from a question-encoder checkpoint and BART, then save it.

    Parses the command line, creates a ``RagTokenForGeneration`` from the
    checkpoint at ``--model_path`` and ``facebook/bart-large``, wraps the two
    tokenizers in a ``RagTokenizer``, builds a ``RagRetriever`` from the model
    config, and calls ``save_pretrained`` on the model, the tokenizer and the
    retriever with ``--out_path``.

    Command-line options:
        --model_path: question-encoder checkpoint; also used to load the
            question-encoder tokenizer.
        --out_path: destination passed to every ``save_pretrained`` call.
        --index_name: value written to ``model.config.index_name``.
    """
    # Imported locally so the function does not depend on a module-level
    # ``import argparse`` that the file does not visibly provide.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/dccstor/dialog/sfeng/transformers_doc2dial/checkpoints/colbert-converted-60000/question_encoder/",
        help="Question-encoder checkpoint to load the encoder and its tokenizer from.",
    )
    parser.add_argument(
        "--out_path",
        type=str,
        default="tmp",
        help="Directory the model, tokenizer and retriever are saved to.",
    )
    parser.add_argument(
        "--index_name",
        type=str,
        default="exact",
        help="Value stored in the model config's index_name field.",
    )
    args = parser.parse_args()

    model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
        args.model_path, "facebook/bart-large"
    )

    question_encoder_tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
    tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

    # The retriever is constructed from model.config, so these fields must be
    # set before RagRetriever(...) is called.
    model.config.use_dummy_dataset = True
    model.config.index_name = args.index_name
    retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)

    # All three artifacts go to the same output location.
    model.save_pretrained(args.out_path)
    tokenizer.save_pretrained(args.out_path)
    retriever.save_pretrained(args.out_path)
from transformers import (
    RagTokenizer,
    RagRetriever,
    RagTokenForGeneration,
    AutoTokenizer,
)

# Checkpoint identifiers and output location used throughout this script.
_QUESTION_ENCODER_ID = "facebook/dpr-question_encoder-single-nq-base"
_GENERATOR_ID = "facebook/bart-large"
_OUTPUT_DIR = "./"

# Combine the question encoder and the generator into a single RAG model.
model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    _QUESTION_ENCODER_ID,
    _GENERATOR_ID,
)

# One tokenizer per sub-model, wrapped together in a RagTokenizer.
question_encoder_tokenizer = AutoTokenizer.from_pretrained(_QUESTION_ENCODER_ID)
generator_tokenizer = AutoTokenizer.from_pretrained(_GENERATOR_ID)
tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# Set the retrieval fields on the config before the retriever is built from it.
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(
    model.config,
    question_encoder_tokenizer,
    generator_tokenizer,
)

# Save the model, tokenizer and retriever to the same directory.
model.save_pretrained(_OUTPUT_DIR)
tokenizer.save_pretrained(_OUTPUT_DIR)
retriever.save_pretrained(_OUTPUT_DIR)