def get_dummy_custom_hf_index_pytorch_retriever(self,
                                                init_retrieval: bool,
                                                from_disk: bool,
                                                port=12345):
    """Build a PyTorch-distributed RAG retriever that uses a "custom" index.

    Args:
        init_retrieval: When true, ``init_retrieval(port)`` is called on the
            retriever before it is returned.
        from_disk: When true, the dummy dataset and its "embeddings" index
            are written under ``self.tmpdirname`` and only their paths are
            handed over through the config. Otherwise an in-memory
            ``CustomHFIndex`` wrapping the dataset is passed explicitly.
        port: Forwarded to ``init_retrieval``.

    Returns:
        The constructed ``RagPyTorchDistributedRetriever``.
    """
    dataset = self.get_dummy_dataset()
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )

    index_kwargs = {}
    if from_disk:
        passages_path = os.path.join(self.tmpdirname, "dataset")
        index_path = os.path.join(self.tmpdirname, "index.faiss")
        config.passages_path = passages_path
        config.index_path = index_path
        # The index is saved separately and dropped before the dataset
        # itself is written to disk.
        dataset.get_index("embeddings").save(index_path)
        dataset.drop_index("embeddings")
        dataset.save_to_disk(passages_path)
        del dataset
        # No explicit index here: the retriever only receives the paths
        # stored on the config.
    else:
        index_kwargs["index"] = CustomHFIndex(config.retrieval_vector_size,
                                              dataset)

    retriever = RagPyTorchDistributedRetriever(
        config,
        question_encoder_tokenizer=self.get_dpr_tokenizer(),
        generator_tokenizer=self.get_bart_tokenizer(),
        **index_kwargs,
    )
    if init_retrieval:
        retriever.init_retrieval(port)
    return retriever
    def prepare_config_and_inputs(self):
        """Create a DPRConfig and random input tensors from the tester's settings.

        Optional tensors are ``None`` when the matching ``use_*`` flag on the
        tester is falsy.

        Returns:
            A tuple ``(config, input_ids, token_type_ids, input_mask,
            sequence_labels, token_labels, choice_labels)``.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = (
            random_attention_mask([self.batch_size, self.seq_length])
            if self.use_input_mask
            else None
        )

        token_type_ids = (
            ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            if self.use_token_type_ids
            else None
        )

        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)
        else:
            sequence_labels = token_labels = choice_labels = None

        # Every config field is read from the tester attribute of the same name.
        config_fields = (
            "projection_dim",
            "vocab_size",
            "hidden_size",
            "num_hidden_layers",
            "num_attention_heads",
            "intermediate_size",
            "hidden_act",
            "hidden_dropout_prob",
            "attention_probs_dropout_prob",
            "max_position_embeddings",
            "type_vocab_size",
            "initializer_range",
        )
        config = DPRConfig(**{field: getattr(self, field) for field in config_fields})

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )
 def get_dummy_ray_distributed_retriever(
         self, init_retrieval: bool) -> RagRayDistributedRetriever:
     """Build a Ray-distributed RAG retriever backed by the dummy dataset.

     ``load_dataset`` inside the RAG retrieval module is patched to return
     the dummy dataset while the retriever is constructed and (optionally)
     initialised.

     Args:
         init_retrieval: When true, ``init_retrieval()`` is called on the
             retriever while the patch is still active.

     Returns:
         The constructed ``RagRayDistributedRetriever``.
     """
     # Local mode is required because the sys.path modifications at the top
     # of the file are not propagated to remote workers.
     # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
     ray.init(local_mode=True)
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     # A single remote retrieval worker.
     workers = [ray.remote(RayRetriever).remote()]
     load_dataset_target = "transformers.models.rag.retrieval_rag.load_dataset"
     with patch(load_dataset_target) as mock_load_dataset:
         mock_load_dataset.return_value = self.get_dummy_dataset()
         retriever = RagRayDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             retrieval_workers=workers,
         )
         if init_retrieval:
             retriever.init_retrieval()
     return retriever
 def load_dpr_model(self):
     """Instantiate a DPRReader and load weights from ``self.src_file``.

     Checkpoint keys under ``encoder.`` (except ``encoder.encode_proj*``)
     are moved under ``encoder.bert_model.``; every other key is kept
     unchanged. The remapped weights are loaded into
     ``model.span_predictor``.

     Returns:
         The ``DPRReader`` with the checkpoint weights loaded.
     """
     bert_config_dict = BertConfig.get_config_dict("bert-base-uncased")[0]
     model = DPRReader(DPRConfig(**bert_config_dict))
     print("Loading DPR reader from {}".format(self.src_file))
     checkpoint = load_states_from_checkpoint(self.src_file)

     def _remap(name):
         # Nest encoder weights under bert_model, leaving the projection as is.
         is_encoder_weight = name.startswith("encoder.")
         is_projection = name.startswith("encoder.encode_proj")
         if is_encoder_weight and not is_projection:
             return "encoder.bert_model." + name[len("encoder."):]
         return name

     remapped = {
         _remap(name): tensor
         for name, tensor in checkpoint.model_dict.items()
     }
     model.span_predictor.load_state_dict(remapped)
     return model
Example #5
0
 def get_config(self):
     """Return a DPRConfig populated from the tester's same-named attributes."""
     field_names = (
         "projection_dim",
         "vocab_size",
         "hidden_size",
         "num_hidden_layers",
         "num_attention_heads",
         "intermediate_size",
         "hidden_act",
         "hidden_dropout_prob",
         "attention_probs_dropout_prob",
         "max_position_embeddings",
         "type_vocab_size",
         "initializer_range",
     )
     config_kwargs = {name: getattr(self, name) for name in field_names}
     return DPRConfig(**config_kwargs)
 def load_dpr_model(self):
     """Instantiate a DPRQuestionEncoder and load weights from ``self.src_file``.

     Only checkpoint keys starting with ``question_model.`` are used. The
     prefix is stripped and everything except ``encode_proj.*`` is moved
     under ``bert_model.`` before loading into ``model.question_encoder``.

     Returns:
         The ``DPRQuestionEncoder`` with the checkpoint weights loaded.
     """
     bert_config_dict = BertConfig.get_config_dict("bert-base-uncased")[0]
     model = DPRQuestionEncoder(DPRConfig(**bert_config_dict))
     print("Loading DPR biencoder from {}".format(self.src_file))
     checkpoint = load_states_from_checkpoint(self.src_file)
     encoder = model.question_encoder
     prefix = "question_model."

     remapped = {}
     for name, tensor in checkpoint.model_dict.items():
         if not name.startswith(prefix):
             # Weights belonging to other sub-models are ignored.
             continue
         stripped = name[len(prefix):]
         if not stripped.startswith("encode_proj."):
             stripped = "bert_model." + stripped
         remapped[stripped] = tensor
     encoder.load_state_dict(remapped)
     return model
 def load_dpr_model(self):
     """Instantiate a DPRReader and load weights from ``self.src_file``.

     The state dict is seeded with the freshly built model's own
     ``position_ids`` entry, then filled with the checkpoint weights.
     Checkpoint keys under ``encoder.`` (except ``encoder.encode_proj*``)
     are moved under ``encoder.bert_model.``; other keys are kept as is.

     Returns:
         The ``DPRReader`` with the checkpoint weights loaded.
     """
     bert_config_dict = BertConfig.get_config_dict("bert-base-uncased")[0]
     model = DPRReader(DPRConfig(**bert_config_dict))
     print("Loading DPR reader from {}".format(self.src_file))
     checkpoint = load_states_from_checkpoint(self.src_file)

     def _remap(name):
         # Nest encoder weights under bert_model, leaving the projection as is.
         is_encoder_weight = name.startswith("encoder.")
         is_projection = name.startswith("encoder.encode_proj")
         if is_encoder_weight and not is_projection:
             return "encoder.bert_model." + name[len("encoder."):]
         return name

     # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
     bert_embeddings = model.span_predictor.encoder.bert_model.embeddings
     remapped = {
         "encoder.bert_model.embeddings.position_ids":
         bert_embeddings.position_ids
     }
     remapped.update(
         (_remap(name), tensor)
         for name, tensor in checkpoint.model_dict.items())
     model.span_predictor.load_state_dict(remapped)
     return model
 def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool,
                                             from_disk: bool):
     """Build a Ray-distributed RAG retriever that uses a custom HF index.

     Args:
         init_retrieval: When true, ``init_retrieval()`` is called on the
             retriever before it is returned.
         from_disk: When true, the dummy dataset and its "embeddings" index
             are written under ``self.tmpdirname`` and the index is
             re-loaded from those files. Otherwise the index wraps the
             in-memory dataset directly.

     Returns:
         The constructed ``RagRayDistributedRetriever``.
     """
     # Local mode is required because the sys.path modifications at the top
     # of the file are not propagated to remote workers.
     # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
     ray.init(local_mode=True)
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )
     # A single remote retrieval worker.
     workers = [ray.remote(RayRetriever).remote()]

     if from_disk:
         passages_path = os.path.join(self.tmpdirname, "dataset")
         index_path = os.path.join(self.tmpdirname, "index.faiss")
         config.passages_path = passages_path
         config.index_path = index_path
         # The index is saved separately and dropped before the dataset
         # itself is written to disk.
         dataset.get_index("embeddings").save(index_path)
         dataset.drop_index("embeddings")
         dataset.save_to_disk(passages_path)
         del dataset
         index = CustomHFIndex.load_from_disk(
             vector_size=config.retrieval_vector_size,
             dataset_path=config.passages_path,
             index_path=config.index_path,
         )
     else:
         index = CustomHFIndex(config.retrieval_vector_size, dataset)

     retriever = RagRayDistributedRetriever(
         config,
         question_encoder_tokenizer=self.get_dpr_tokenizer(),
         generator_tokenizer=self.get_bart_tokenizer(),
         retrieval_workers=workers,
         index=index,
     )
     if init_retrieval:
         retriever.init_retrieval()
     return retriever
    def prepare_config_and_inputs(self):
        """Create a DPRConfig and random input tensors from the tester's settings.

        A ``BertConfig`` is built first and its dictionary form is expanded
        into a ``DPRConfig`` together with ``projection_dim``. Optional
        tensors are ``None`` when the matching ``use_*`` flag is falsy.

        Returns:
            A tuple ``(config, input_ids, token_type_ids, input_mask,
            sequence_labels, token_labels, choice_labels)``.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length],
                               self.vocab_size)

        # follow test_modeling_tf_ctrl.py: the mask is drawn with vocab_size=2
        input_mask = (ids_tensor([self.batch_size, self.seq_length],
                                 vocab_size=2)
                      if self.use_input_mask else None)

        token_type_ids = (ids_tensor([self.batch_size, self.seq_length],
                                     self.type_vocab_size)
                          if self.use_token_type_ids else None)

        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size],
                                         self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length],
                                      self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)
        else:
            sequence_labels = token_labels = choice_labels = None

        bert_config = BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            # MODIFY
            return_dict=False,
        )
        dpr_config = DPRConfig(projection_dim=self.projection_dim,
                               **bert_config.to_dict())

        return (
            dpr_config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )
 def get_dummy_pytorch_distributed_retriever(
     self, init_retrieval: bool, port=12345
 ) -> RagPyTorchDistributedRetriever:
     """Build a PyTorch-distributed RAG retriever backed by the dummy dataset.

     ``load_dataset`` inside the RAG retrieval module is patched to return
     the dummy dataset while the retriever is constructed and (optionally)
     initialised.

     Args:
         init_retrieval: When true, ``init_retrieval(port)`` is called while
             the patch is still active.
         port: Forwarded to ``init_retrieval``.

     Returns:
         The constructed ``RagPyTorchDistributedRetriever``.
     """
     dummy_dataset = self.get_dummy_dataset()
     rag_config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     load_dataset_target = "transformers.models.rag.retrieval_rag.load_dataset"
     with patch(load_dataset_target, return_value=dummy_dataset):
         retriever = RagPyTorchDistributedRetriever(
             rag_config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
         if init_retrieval:
             retriever.init_retrieval(port)
     return retriever
 def load_dpr_model(self):
     """Instantiate a multilingual DPRQuestionEncoder from ``self.src_file``.

     The state dict is seeded with the freshly built model's own
     ``position_ids`` entry. Only checkpoint keys starting with
     ``question_model.`` are then added: the prefix is stripped and
     everything except ``encode_proj.*`` is moved under ``bert_model.``.

     Returns:
         The ``DPRQuestionEncoder`` with the checkpoint weights loaded.
     """
     bert_config_dict = BertConfig.get_config_dict(
         "bert-base-multilingual-uncased")[0]
     model = DPRQuestionEncoder(DPRConfig(**bert_config_dict))
     print("Loading DPR biencoder from {}".format(self.src_file))
     checkpoint = load_states_from_checkpoint(self.src_file)
     encoder = model.question_encoder
     prefix = "question_model."

     # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
     bert_embeddings = model.question_encoder.bert_model.embeddings
     remapped = {
         "bert_model.embeddings.position_ids": bert_embeddings.position_ids
     }
     for name, tensor in checkpoint.model_dict.items():
         if not name.startswith(prefix):
             # Weights belonging to other sub-models are ignored.
             continue
         stripped = name[len(prefix):]
         if not stripped.startswith("encode_proj."):
             stripped = "bert_model." + stripped
         remapped[stripped] = tensor
     encoder.load_state_dict(remapped)
     return model
Example #12
0
    def __init__(self, hparams, **kwargs):
        """Set up the model, tokenizer(s), retriever and bookkeeping state.

        Args:
            hparams: Hyper-parameters as an attribute-style object, or a plain
                ``dict`` (the form received when loading from a PyTorch
                Lightning checkpoint), which is wrapped in ``AttrDict``.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If a RAG model type is requested with a
                ``distributed_retriever`` other than ``"ray"``.
            AssertionError: If the train target length exceeds the val or
                test target length.
        """
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        # Any model_type other than the three named ones falls back to T5.
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # set retriever parameters (falsy hparams keep the pretrained value)
        config.index_name = hparams.index_name or config.index_name
        config.passages_path = hparams.passages_path or config.passages_path
        config.index_path = hparams.index_path or config.index_path
        config.use_dummy_dataset = hparams.use_dummy_dataset

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop",
                              "attention_dropout", "dropout")
        if self.is_rag_model:
            if hparams.prefix is not None:
                config.generator.prefix = hparams.prefix
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(
                extra_model_params, hparams, config.generator)
            if hparams.distributed_retriever == "ray":
                # The Ray retriever needs the handles to the retriever actors.
                retriever = RagRayDistributedRetriever.from_pretrained(
                    hparams.model_name_or_path,
                    hparams.actor_handles,
                    config=config)

                if hparams.end2end:
                    ctx_encoder_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
                        "facebook/dpr-ctx_encoder-multiset-base")
                    retriever.set_ctx_encoder_tokenizer(ctx_encoder_tokenizer)
            else:
                # Previously this branch only logged and then crashed a few
                # lines below on the unassigned `retriever` variable; fail
                # here with an actionable error instead.
                raise ValueError(
                    "Unsupported distributed_retriever "
                    f"{hparams.distributed_retriever!r}: "
                    "please use RAY as the distributed retrieval method")

            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config, retriever=retriever)
            if hparams.end2end:
                ctx_encoder = DPRContextEncoder.from_pretrained(
                    hparams.context_encoder_name)
                model.set_context_encoder_for_training(ctx_encoder)
            prefix = config.question_encoder.prefix
        else:
            if hparams.prefix is not None:
                config.prefix = hparams.prefix
            hparams, config = set_extra_model_params(extra_model_params,
                                                     hparams, config)
            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config)
            prefix = config.prefix

        tokenizer = (RagTokenizer.from_pretrained(hparams.model_name_or_path)
                     if self.is_rag_model else AutoTokenizer.from_pretrained(
                         hparams.model_name_or_path))

        # DPR context-encoder config and tokenizer are kept on the instance.
        self.config_dpr = DPRConfig.from_pretrained(
            hparams.context_encoder_name)
        self.custom_config = hparams
        self.context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
            hparams.context_encoder_name)

        super().__init__(hparams,
                         config=config,
                         tokenizer=tokenizer,
                         model=model)

        # Output locations and run metadata.
        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.dpr_ctx_check_dir = str(Path(
            self.hparams.output_dir)) + "/dpr_ctx_checkpoint"
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # A negative observation count is stored as None.
        self.n_obs = {
            k: v if v >= 0 else None
            for k, v in n_observations_per_split.items()
        }
        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens[
            "val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens[
            "test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port

        # For single GPU training, init_ddp_connection is not called.
        # So we need to initialize the retrievers here.
        if hparams.gpus <= 1:
            if hparams.distributed_retriever == "ray":
                self.model.retriever.init_retrieval()
            else:
                logger.info(
                    "please use RAY as the distributed retrieval method")

        self.distributed_retriever = hparams.distributed_retriever