def get_dummy_pytorch_distributed_retriever(
    self, init_retrieval: bool, port=12345
) -> RagPyTorchDistributedRetriever:
    # Tiny two-passage in-memory dataset with fixed embeddings so retrieval
    # results are deterministic in tests.
    dataset = Dataset.from_dict(
        {
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            "embeddings": [
                np.ones(self.retrieval_vector_size),
                2 * np.ones(self.retrieval_vector_size),
            ],
        }
    )
    dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )
    # Patch load_dataset so the retriever picks up the dummy dataset instead of
    # downloading the real wiki_dpr index.
    with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
        mock_load_dataset.return_value = dataset
        retriever = RagPyTorchDistributedRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )
        if init_retrieval:
            retriever.init_retrieval(port)
    return retriever
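
The helper above is written to be called from inside its test class; the following is a minimal single-process usage sketch, not the actual test body. Fixture attributes such as retrieval_vector_size are assumed from that class, and the retrieve() call follows the RagRetriever API.

# Sketch only: assumes the enclosing test class that defines the helper above.
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
question_hidden_states = np.ones((1, self.retrieval_vector_size), dtype=np.float32)
# RagRetriever.retrieve returns the top-n_docs embeddings, their dataset ids,
# and the corresponding passage dicts.
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(question_hidden_states, n_docs=1)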
Example #2
def get_dummy_pytorch_distributed_retriever(
    self, init_retrieval: bool, port=12345
) -> RagPyTorchDistributedRetriever:
    dataset = self.get_dummy_dataset()
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )
    # Patch load_dataset so the retriever picks up the dummy dataset instead of wiki_dpr.
    with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
        mock_load_dataset.return_value = dataset
        retriever = RagPyTorchDistributedRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )
        if init_retrieval:
            retriever.init_retrieval(port)
    return retriever
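
This variant factors the dummy corpus out into get_dummy_dataset(), which is not shown on this page. Judging from Example #1, a plausible reconstruction of that helper looks like the following sketch; the real fixture may differ.

def get_dummy_dataset(self):
    # Plausible reconstruction based on Example #1, not the verbatim fixture.
    dataset = Dataset.from_dict(
        {
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            "embeddings": [
                np.ones(self.retrieval_vector_size),
                2 * np.ones(self.retrieval_vector_size),
            ],
        }
    )
    dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
    return dataset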
Example #3
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
    dataset = self.get_dummy_dataset()
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )
    if from_disk:
        # Serialize the faiss index and the passages to disk, then let the
        # retriever load them back via config.passages_path / config.index_path.
        config.passages_path = os.path.join(self.tmpdirname, "dataset")
        config.index_path = os.path.join(self.tmpdirname, "index.faiss")
        dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
        dataset.drop_index("embeddings")
        dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
        del dataset
        retriever = RagPyTorchDistributedRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )
    else:
        # Keep the dataset in memory and hand the retriever a ready-made index.
        retriever = RagPyTorchDistributedRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
            index=CustomHFIndex(config.retrieval_vector_size, dataset),
        )
    if init_retrieval:
        retriever.init_retrieval(port)
    return retriever
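
Both branches produce an equivalent retriever; only where the index lives differs. A hedged usage sketch, again assuming the enclosing test class:

# Sketch: custom index kept in memory and passed in explicitly.
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
# Sketch: index and passages first saved under self.tmpdirname, then loaded
# back by the retriever through config.passages_path and config.index_path.
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)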
Example #4
    def __init__(self, hparams, **kwargs):
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # set retriever parameters (fall back to the values stored in the pretrained config)
        config.index_name = hparams.index_name or config.index_name
        config.passages_path = hparams.passages_path or config.passages_path
        config.index_path = hparams.index_path or config.index_path

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
        if self.is_rag_model:
            if hparams.prefix is not None:
                config.generator.prefix = hparams.prefix
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
            retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
            prefix = config.question_encoder.prefix
        else:
            if hparams.prefix is not None:
                config.prefix = hparams.prefix
            hparams, config = set_extra_model_params(extra_model_params, hparams, config)
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
            prefix = config.prefix

        tokenizer = (
            RagTokenizer.from_pretrained(hparams.model_name_or_path)
            if self.is_rag_model
            else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
        )

        super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)

        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir, max_source_length=self.hparams.max_source_length, prefix=prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port
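
The constructor reads everything off hparams (plus the pretrained config), so an illustrative hparams dict can be assembled from the attribute names used above. This is a sketch: the real flag set comes from the finetuning script's argument parser, and the class name GenerativeQAModule is an assumption based on the RAG finetuning example this snippet appears to come from.

# Illustrative only: attribute names taken from the __init__ body above.
hparams = AttrDict(
    {
        "model_type": "rag_sequence",        # or "rag_token", "bart", "t5"
        "model_name_or_path": "facebook/rag-sequence-base",
        "index_name": None,                  # None falls back to the model config
        "passages_path": None,
        "index_path": None,
        "prefix": None,
        "label_smoothing": 0.1,
        "output_dir": "outputs",
        "data_dir": "data",
        "max_source_length": 128,
        "max_target_length": 25,
        "val_max_target_length": 25,
        "test_max_target_length": 25,
        "n_train": -1,                       # negative counts mean "use the full split"
        "n_val": -1,
        "n_test": -1,
        "num_workers": 2,
        "distributed_port": 8888,
    }
)
module = GenerativeQAModule(hparams)  # hypothetical class name, see note above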