Example #1
    def config_and_inputs(self):
        # Build the question encoder (DPR) config and inputs.
        question_encoder_tester = DPRModelTester(self)
        dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()

        # Build the generator (T5) config and inputs; the small vocabulary and
        # short position range keep the test model lightweight.
        generator_tester = T5ModelTester(self, vocab_size=1100, n_positions=30)
        t5_config_and_inputs = generator_tester.prepare_config_and_inputs()

        (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
        (generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs

        # Combine both sub-configs into a single RAG config.
        config = RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
            use_cache=False,
        )

        return {
            "config": config,
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
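For context, a companion check in the test class might consume the dictionary returned above roughly as follows. This is a hedged sketch, not part of the original snippet: the check_rag_model name and the self.get_retriever helper are assumptions, and the model is randomly initialized from the combined config.

import torch
from transformers import RagTokenForGeneration

def check_rag_model(self, config, input_ids, attention_mask,
                    decoder_input_ids, decoder_attention_mask):
    # Hypothetical companion test method: build a randomly initialized RAG
    # model from the combined config plus a retriever provided elsewhere in
    # the test suite (self.get_retriever is an assumption here).
    model = RagTokenForGeneration(config=config, retriever=self.get_retriever(config))
    model.eval()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )
    # The last logits dimension corresponds to the generator's vocabulary size.
    self.assertEqual(outputs.logits.shape[-1], model.config.generator.vocab_size)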
Example #2
    def get_rag_config(self):
        # Load the pretrained sub-configs for the DPR question encoder and the
        # BART generator, then merge them into a single RagConfig with explicit
        # token ids, retrieval settings, and the dummy wiki_dpr index.
        question_encoder_config = AutoConfig.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")
        generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
        return RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            bos_token_id=0,
            decoder_start_token_id=2,
            eos_token_id=2,
            is_encoder_decoder=True,
            pad_token_id=1,
            vocab_size=50264,
            title_sep=" / ",
            doc_sep=" // ",
            n_docs=5,
            max_combined_length=300,
            dataset="wiki_dpr",
            dataset_split="train",
            index_name="exact",
            index_path=None,
            use_dummy_dataset=True,
            retrieval_vector_size=768,
            retrieval_batch_size=8,
        )
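The returned config is typically paired with tokenizers and a retriever before a RAG model can be exercised end to end. The following is a minimal sketch (not part of the original snippet), assuming it runs inside the same test class and that downloading the dummy wiki_dpr index is acceptable; the model built here is randomly initialized because only configurations are combined.

from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration

# Hypothetical usage inside the same test class:
config = self.get_rag_config()
question_encoder_tokenizer = AutoTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

# Retriever over the dummy wiki_dpr index configured above (use_dummy_dataset=True).
retriever = RagRetriever(config, question_encoder_tokenizer, generator_tokenizer)

# Randomly initialized RAG model wired to that retriever; real tests would
# usually load pretrained question encoder / generator weights instead.
model = RagSequenceForGeneration(config=config, retriever=retriever)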
Example #3
    def config_and_inputs(self):
        # Build the question encoder (DPR) config and inputs.
        question_encoder_tester = DPRModelTester(self)
        dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()

        # Build the generator (BART) config and its shared inputs dict.
        generator_tester = BartModelTester(self)
        bart_config_and_inputs = generator_tester.prepare_config_and_inputs_for_common()

        (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
        (generator_config, bart_inputs_dict) = bart_config_and_inputs
        decoder_input_ids = bart_inputs_dict["input_ids"]
        decoder_attention_mask = bart_inputs_dict["attention_mask"]

        # Combine both sub-configs into a single RAG config.
        config = RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
            use_cache=False,
        )

        return {
            "config": config,
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
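Example #3 mirrors Example #1 but swaps the T5 generator for BART and takes the decoder inputs from the shared inputs dict returned by prepare_config_and_inputs_for_common(). Because the resulting dictionary exposes the same keys, the same downstream checks can be reused across generator architectures.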