Exemple #1
0
 def token_model(self):
     """Build a RAG-token model on ``torch_device`` in evaluation mode.

     The model is assembled from the
     ``facebook/dpr-question_encoder-single-nq-base`` question encoder and
     the ``facebook/bart-large-cnn`` generator checkpoints.
     """
     model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
         "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
     )
     model = model.to(torch_device)
     return model.eval()
    def test_rag_token_from_pretrained(self):
        """Check that a saved-and-reloaded RAG-token model matches one built from parts.

        A model is created with ``from_pretrained_question_encoder_generator``,
        written to a temporary directory and loaded back with
        ``from_pretrained``. Its loss on a fixed question/answer pair is then
        compared against the loss of a model built directly from a separately
        loaded question encoder and generator; the two losses must agree to
        4 decimal places.
        """
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        # Tokenize the question (model input) and the answer (used as labels).
        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids
        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids

        input_ids = input_ids.to(torch_device)
        decoder_input_ids = decoder_input_ids.to(torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            rag_token = RagTokenForGeneration.from_pretrained_question_encoder_generator(
                "facebook/dpr-question_encoder-single-nq-base",
                "facebook/bart-large-cnn",
                retriever=rag_retriever,
                config=rag_config,
            ).to(torch_device)
            # check that the from pretrained methods work
            rag_token.save_pretrained(tmp_dirname)
            # ``from_pretrained`` returns a new model rather than updating the
            # instance it is called on, so the result must be bound; otherwise
            # the loss below would come from the never-reloaded model and the
            # save/load round trip would go untested.
            rag_token = RagTokenForGeneration.from_pretrained(tmp_dirname, retriever=rag_retriever)
            rag_token.to(torch_device)

            with torch.no_grad():
                output = rag_token(
                    input_ids,
                    labels=decoder_input_ids,
                )

            loss_pretrained = output.loss
            # Drop the reference to the reloaded model before the second one is built.
            del rag_token

        # Reference model: assembled from separately loaded sub-models.
        question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
        rag_token = RagTokenForGeneration(
            config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
        )
        rag_token.to(torch_device)

        with torch.no_grad():
            output = rag_token(
                input_ids,
                labels=decoder_input_ids,
            )

        loss_init = output.loss

        self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)
Exemple #3
0
def main():
    """Assemble a RAG-token checkpoint from the command line and save it.

    Command-line options (all strings):
        --model_path: question-encoder checkpoint; also used to load the
            question-encoder tokenizer.
        --out_path: directory the model, tokenizer and retriever are saved to
            (default ``tmp``).
        --index_name: value stored in ``model.config.index_name``
            (default ``exact``).

    The generator and its tokenizer are always loaded from
    ``facebook/bart-large``.
    """
    string_options = (
        (
            "--model_path",
            "/dccstor/dialog/sfeng/transformers_doc2dial/checkpoints/colbert-converted-60000/question_encoder/",
        ),
        ("--out_path", "tmp"),
        ("--index_name", "exact"),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value in string_options:
        parser.add_argument(flag, type=str, default=default_value)
    args = parser.parse_args()

    model = RagTokenForGeneration.from_pretrained_question_encoder_generator(args.model_path, "facebook/bart-large")

    question_encoder_tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
    tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

    # The retriever is built from the model config, so the config is adjusted first.
    rag_config = model.config
    rag_config.use_dummy_dataset = True
    rag_config.index_name = args.index_name
    retriever = RagRetriever(rag_config, question_encoder_tokenizer, generator_tokenizer)

    # Save order matches the original: model, then tokenizer, then retriever.
    for component in (model, tokenizer, retriever):
        component.save_pretrained(args.out_path)
Exemple #4
0
def get_rag_generator_components(args, inference_only: bool = False, **kwargs):
    """Build the tensorizer, RAG-based generator and optimizer.

    Args:
        args: Options object. ``pretrained_model_cfg`` is read as the
            checkpoint for both the generator and the question encoder;
            ``dropout`` is optional (treated as 0.0 when absent);
            ``learning_rate``, ``adam_eps`` and ``weight_decay`` are read
            only when an optimizer is built. Also forwarded to
            ``get_rag_tensorizer``.
        inference_only: When true, no optimizer is created.
        **kwargs: Accepted but not used by this function.

    Returns:
        A ``(tensorizer, generator, optimizer)`` tuple; ``optimizer`` is
        ``None`` when ``inference_only`` is true.
    """
    # tokenizer
    tensorizer = get_rag_tensorizer(args)

    # Start from the facebook/rag-token-nq config and override dropout only
    # when a non-zero value was supplied.
    dropout = getattr(args, 'dropout', 0.0)
    rag_config = RagConfig.from_pretrained("facebook/rag-token-nq")
    if dropout != 0:
        rag_config.attention_probs_dropout_prob = dropout
        rag_config.hidden_dropout_prob = dropout

    # Custom generator / question encoder: both point at the same checkpoint
    # (original author's note: the question encoder is not required).
    generator_name_or_path = args.pretrained_model_cfg
    question_encoder_name_or_path = generator_name_or_path
    rag_config.generator = AutoConfig.from_pretrained(generator_name_or_path)
    rag_config.question_encoder = AutoConfig.from_pretrained(question_encoder_name_or_path)

    rag = RagTokenForGeneration.from_pretrained_question_encoder_generator(
        question_encoder_name_or_path,
        generator_name_or_path,
        config=rag_config,
        dummy_dataset=True,
    )
    generator = Generator(rag, tensorizer)

    # optimizer (skipped entirely for inference)
    optimizer = None
    if not inference_only:
        optimizer = get_optimizer(
            generator,
            learning_rate=args.learning_rate,
            adam_eps=args.adam_eps,
            weight_decay=args.weight_decay,
        )

    return tensorizer, generator, optimizer
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, AutoTokenizer

# Checkpoint identifiers and output location used throughout this script.
QUESTION_ENCODER_NAME = "facebook/dpr-question_encoder-single-nq-base"
GENERATOR_NAME = "facebook/bart-large"
OUTPUT_DIR = "./"

# Combine the question encoder and generator checkpoints into one RAG-token model.
model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    QUESTION_ENCODER_NAME, GENERATOR_NAME
)

# One tokenizer per sub-model, wrapped together in a RagTokenizer.
question_encoder_tokenizer = AutoTokenizer.from_pretrained(QUESTION_ENCODER_NAME)
generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_NAME)
tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# The retriever is built from the model config, so the config is adjusted first.
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(
    model.config, question_encoder_tokenizer, generator_tokenizer
)

# Save order matches the original: model, then tokenizer, then retriever.
for artifact in (model, tokenizer, retriever):
    artifact.save_pretrained(OUTPUT_DIR)