Example No. 1
0
    def test_rag_sequence_generate_batch(self):
        """End-to-end batch generation with TF RAG-sequence on the dummy exact index."""
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_sequence = TFRagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

        # Tokenize all test questions as one padded/truncated TF batch.
        batch = tokenizer(self.test_data_questions, return_tensors="tf", padding=True, truncation=True)

        generated = rag_sequence.generate(batch.input_ids, attention_mask=batch.attention_mask)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

        # Recorded gold answers for self.test_data_questions, in order.
        expected = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(decoded, expected)
Example No. 2
0
    def __init__(
            self,
            model_name_or_path: str = "facebook/rag-token-nq",
            retriever: Optional[DensePassageRetriever] = None,
            generator_type: RAGeneratorType = RAGeneratorType.TOKEN,
            top_k_answers: int = 2,
            max_length: int = 200,
            min_length: int = 2,
            num_beams: int = 2,
            embed_title: bool = True,
            prefix: Optional[str] = None,
            use_gpu: bool = True,
    ):
        """
        Load a RAG model (and its tokenizer) from Transformers.
        See https://huggingface.co/transformers/model_doc/rag.html for details.

        :param model_name_or_path: Directory of a saved model or the name of a public model, e.g.
                                   'facebook/rag-token-nq' or 'facebook/rag-sequence-nq'.
                                   See https://huggingface.co/models for the full list.
        :param retriever: `DensePassageRetriever` used to embed passages
        :param generator_type: Which RAG generator implementation to use: RAG-TOKEN or RAG-SEQUENCE
        :param top_k_answers: Number of independently generated texts to return
        :param max_length: Maximum length of generated text
        :param min_length: Minimum length of generated text
        :param num_beams: Number of beams for beam search; 1 means no beam search
        :param embed_title: Whether to embed the passage title when generating embeddings
        :param prefix: The prefix used by the generator's tokenizer
        :param use_gpu: Whether to use GPU (if available)
        """

        self.model_name_or_path = model_name_or_path
        self.max_length = max_length
        self.min_length = min_length
        self.generator_type = generator_type
        self.num_beams = num_beams
        self.embed_title = embed_title
        self.prefix = prefix
        self.retriever = retriever

        # Beam search yields at most num_beams candidates, so cap top_k_answers there.
        if top_k_answers > num_beams:
            top_k_answers = num_beams
            logger.warning(f'top_k_answers value should not be greater than num_beams, hence setting it to {num_beams}')
        self.top_k_answers = top_k_answers

        self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

        self.tokenizer = RagTokenizer.from_pretrained(model_name_or_path)

        if self.generator_type != RAGeneratorType.SEQUENCE:
            self.model = RagTokenForGeneration.from_pretrained(model_name_or_path).to(self.device)
        else:
            # TODO: Enable when transformers have it. Refer https://github.com/huggingface/transformers/issues/7905
            # Also refer https://github.com/huggingface/transformers/issues/7829
            raise NotImplementedError("RagSequenceForGeneration is not implemented yet")
    def test_rag_token_greedy_search(self):
        """Greedy (single-beam) decoding with TF RAG-token on the first two test questions."""
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_token = TFRagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever, from_pt=True
        )

        # Only the first two questions are checked in this test.
        batch = tokenizer(
            self.test_data_questions[:2], return_tensors="tf", padding=True, truncation=True
        )

        # Force greedy search by restricting decoding to a single beam.
        rag_token.config.num_beams = 1

        generated = rag_token.generate(batch.input_ids, attention_mask=batch.attention_mask)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

        expected = [
            " albert einstein",
            " september 22, 2017",
        ]
        self.assertListEqual(decoded, expected)
Example No. 4
0
    def test_rag_token_generate_batch(self):
        """Batch generation with PyTorch RAG-token against recorded gold answers."""
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
        rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
            torch_device
        )

        # One padded/truncated batch covering every test question.
        batch = tokenizer(self.test_data_questions, return_tensors="pt", padding=True, truncation=True)

        generated = rag_token.generate(
            batch.input_ids.to(torch_device),
            attention_mask=batch.attention_mask.to(torch_device),
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

        expected = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
        ]
        self.assertListEqual(decoded, expected)
 def from_pretrained(cls,
                     retriever_name_or_path,
                     actor_handles,
                     indexed_dataset=None,
                     **kwargs):
     """Build a Ray-distributed retriever from a pretrained config and tokenizer.

     If an ``indexed_dataset`` is supplied it is wrapped as a custom index;
     otherwise the index described by the config is built.
     """
     requires_datasets(cls)
     requires_faiss(cls)
     # Honour an explicitly supplied config, otherwise load it from the hub/path.
     config = kwargs.pop("config", None) or RagConfig.from_pretrained(
         retriever_name_or_path, **kwargs)
     rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path,
                                                  config=config)
     if indexed_dataset is None:
         index = cls._build_index(config)
     else:
         config.index_name = "custom"
         index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
     return cls(
         config,
         question_encoder_tokenizer=rag_tokenizer.question_encoder,
         generator_tokenizer=rag_tokenizer.generator,
         retrieval_workers=actor_handles,
         index=index,
     )
    def test_rag_sequence_generate_batch_from_context_input_ids(self):
        """Generate from pre-retrieved contexts and doc scores (PyTorch RAG-sequence)."""
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_sequence = RagSequenceForGeneration.from_pretrained(
            "facebook/rag-sequence-nq", retriever=retriever
        ).to(torch_device)

        batch = tokenizer(self.test_data_questions, return_tensors="pt", padding=True, truncation=True)
        input_ids = batch.input_ids.to(torch_device)
        attention_mask = batch.attention_mask.to(torch_device)

        # Run retrieval manually: encode the questions, then fetch candidate docs.
        question_hidden_states = rag_sequence.question_encoder(input_ids, attention_mask=attention_mask)[0]
        docs_dict = retriever(
            input_ids.cpu().detach().numpy(),
            question_hidden_states.cpu().detach().numpy(),
            return_tensors="pt",
        )
        # doc_scores[i, j] = <question_i, doc_ij>, via a batched matmul.
        doc_scores = torch.bmm(
            question_hidden_states.unsqueeze(1),
            docs_dict["retrieved_doc_embeds"].to(torch_device).float().transpose(1, 2),
        ).squeeze(1)

        generated = rag_sequence.generate(
            context_input_ids=docs_dict["context_input_ids"].to(torch_device),
            context_attention_mask=docs_dict["context_attention_mask"].to(torch_device),
            doc_scores=doc_scores.to(torch_device),
            do_deduplication=True,
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

        expected = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
            " reticular formation",
            " walls of the abdomen",
            " spodumene",
            " obama",
            " new orleans",
            " japan",
            " old trafford",
        ]
        self.assertListEqual(decoded, expected)
Example No. 7
0
def main():
    """Assemble a RAG-token model from a custom question encoder plus BART,
    attach a retriever, and save model, tokenizer, and retriever to --out_path."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/dccstor/dialog/sfeng/transformers_doc2dial/checkpoints/colbert-converted-60000/question_encoder/",
    )
    parser.add_argument("--out_path", type=str, default="tmp")
    parser.add_argument("--index_name", type=str, default="exact")
    args = parser.parse_args()

    # Combine the custom question encoder with a pretrained BART generator.
    model = RagTokenForGeneration.from_pretrained_question_encoder_generator(args.model_path, "facebook/bart-large")

    # Each half of the RAG tokenizer mirrors the corresponding sub-model.
    question_encoder_tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
    tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

    # Configure the retriever to use the dummy dataset with the chosen index.
    model.config.use_dummy_dataset = True
    model.config.index_name = args.index_name
    retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)

    for artifact in (model, tokenizer, retriever):
        artifact.save_pretrained(args.out_path)
    def test_rag_sequence_generate_batch_from_context_input_ids(self):
        """Generate from pre-retrieved contexts and doc scores (TF RAG-sequence).

        Retrieval is performed manually (question encoder -> retriever -> doc
        scores) and the resulting tensors are fed to `generate` instead of raw
        `input_ids`.
        """
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq",
                                                 index_name="exact",
                                                 use_dummy_dataset=True)
        rag_sequence = TFRagSequenceForGeneration.from_pretrained(
            "facebook/rag-sequence-nq", retriever=retriever)
        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids

        # Encode questions and retrieve candidate documents manually.
        question_hidden_states = rag_sequence.question_encoder(input_ids)[0]
        docs_dict = retriever(input_ids.numpy(),
                              question_hidden_states.numpy(),
                              return_tensors="tf")
        # doc_scores[i, j] = <question_i, doc_ij> via batched matmul.
        # FIX: tf.expand_dims requires a scalar `axis` (it was axis=[1], which
        # TF rejects); tf.squeeze, by contrast, does accept a list of axes.
        doc_scores = tf.squeeze(
            tf.matmul(tf.expand_dims(question_hidden_states, axis=1),
                      docs_dict["retrieved_doc_embeds"],
                      transpose_b=True),
            axis=[1],
        )
        output_ids = rag_sequence.generate(
            context_input_ids=docs_dict["context_input_ids"],
            context_attention_mask=docs_dict["context_attention_mask"],
            doc_scores=doc_scores,
            do_deduplication=True,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
    def test_rag_token_generate_batch(self):
        """Batch generation with TF RAG-token (weights converted from PyTorch).

        NOTE: gold labels come from num_beams=4 decoding; with greedy labels
        the test would also pass.
        """
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_token = TFRagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever, from_pt=True
        )

        batch = tokenizer(self.test_data_questions, return_tensors="tf", padding=True, truncation=True)

        # With num_beams=1, two answers (obama, united stadium) would differ
        # from these num_beams=4 gold labels.
        generated = rag_token.generate(batch.input_ids, attention_mask=batch.attention_mask)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

        expected = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
            " step by step",
            " stomach",
            " spodumene",
            " obama",
            " northern new jersey",
            " india",
            " united stadium",
        ]
        self.assertListEqual(decoded, expected)
    def test_rag_token_generate_batch(self):
        """Batch generation with TF RAG-token, run as two half-batches.

        NOTE: gold labels come from num_beams=4, so this is effectively a
        beam-search test.
        """
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_token = TFRagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever
        )

        batch = tokenizer(self.test_data_questions, return_tensors="tf", padding=True, truncation=True)
        input_ids = batch.input_ids
        attention_mask = batch.attention_mask

        expected = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
        ]

        # Generate in two chunks of four questions to avoid GPU OOM.
        for part in (slice(None, 4), slice(4, None)):
            generated = rag_token.generate(input_ids[part], attention_mask=attention_mask[part])
            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            self.assertListEqual(decoded, expected[part])
Example No. 11
0
    def test_rag_token_generate_batch(self):
        """Full-batch generation with TF RAG-token.

        NOTE: gold labels come from num_beams=4, so this is effectively a
        beam-search test.
        """
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_token = TFRagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever
        )

        batch = tokenizer(self.test_data_questions, return_tensors="tf", padding=True, truncation=True)

        generated = rag_token.generate(batch.input_ids, attention_mask=batch.attention_mask)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

        expected = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
            " evolution",
            " stomach",
            " spodumene",
            " obama",
            " northern new jersey",
            " india",
            " united stadium",
        ]
        self.assertListEqual(decoded, expected)
Example No. 12
0
    def test_rag_sequence_generate_batch(self):
        """Batch generation against recorded answers.

        NOTE(review): despite the test/variable naming, this loads
        `RagTokenForGeneration` with the "facebook/rag-sequence-nq" checkpoint;
        the expected outputs below were evidently recorded for exactly this
        combination — confirm whether that mix is intentional.
        """
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
            torch_device
        )

        batch = tokenizer(self.test_data_questions, return_tensors="pt", padding=True, truncation=True)

        generated = rag_sequence.generate(
            batch.input_ids.to(torch_device),
            attention_mask=batch.attention_mask.to(torch_device),
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

        expected = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
            " reticular formation",
            " walls of the abdomen",
            " spodumene",
            " obama",
            " grainger's compound",
            " japan",
            " old trafford stadium",
        ]
        self.assertListEqual(decoded, expected)
Example No. 13
0
    def __init__(
        self,
        model_name_or_path: str = "facebook/rag-token-nq",
        model_version: Optional[str] = None,
        retriever: Optional[DensePassageRetriever] = None,
        generator_type: RAGeneratorType = RAGeneratorType.TOKEN,
        top_k: int = 2,
        max_length: int = 200,
        min_length: int = 2,
        num_beams: int = 2,
        embed_title: bool = True,
        prefix: Optional[str] = None,
        use_gpu: bool = True,
    ):
        """
        Load a RAG model (and its tokenizer) from Transformers.
        See https://huggingface.co/transformers/model_doc/rag.html for details.

        :param model_name_or_path: Directory of a saved model or the name of a public model, e.g.
                                   'facebook/rag-token-nq' or 'facebook/rag-sequence-nq'.
                                   See https://huggingface.co/models for the full list.
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param retriever: `DensePassageRetriever` used to embed passages for the docs passed to `predict()`.
                          Optional; only needed when the docs lack embeddings in `Document.embedding`.
        :param generator_type: Which RAG generator implementation to use: RAG-TOKEN or RAG-SEQUENCE
        :param top_k: Number of independently generated texts to return
        :param max_length: Maximum length of generated text
        :param min_length: Minimum length of generated text
        :param num_beams: Number of beams for beam search; 1 means no beam search
        :param embed_title: Whether to embed the passage title when generating embeddings
        :param prefix: The prefix used by the generator's tokenizer
        :param use_gpu: Whether to use GPU (if available)
        """

        # Record init parameters so the component config can be exported as YAML.
        self.set_config(
            model_name_or_path=model_name_or_path,
            model_version=model_version,
            retriever=retriever,
            generator_type=generator_type,
            top_k=top_k,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            embed_title=embed_title,
            prefix=prefix,
            use_gpu=use_gpu,
        )

        self.model_name_or_path = model_name_or_path
        self.max_length = max_length
        self.min_length = min_length
        self.generator_type = generator_type
        self.num_beams = num_beams
        self.embed_title = embed_title
        self.prefix = prefix
        self.retriever = retriever

        # Beam search yields at most num_beams candidates, so cap top_k there.
        if top_k > num_beams:
            top_k = num_beams
            logger.warning(
                f'top_k value should not be greater than num_beams, hence setting it to {num_beams}'
            )
        self.top_k = top_k

        self.device, _ = initialize_device_settings(use_cuda=use_gpu)

        self.tokenizer = RagTokenizer.from_pretrained(model_name_or_path)

        if self.generator_type != RAGeneratorType.SEQUENCE:
            self.model = RagTokenForGeneration.from_pretrained(
                model_name_or_path, revision=model_version
            ).to(self.device)
        else:
            # TODO: Enable when transformers have it. Refer https://github.com/huggingface/transformers/issues/7905
            # Also refer https://github.com/huggingface/transformers/issues/7829
            raise NotImplementedError(
                "RagSequenceForGeneration is not implemented yet")
Example No. 14
0
    def __init__(self, hparams, **kwargs):
        """Build a RAG/seq2seq fine-tuning Lightning module from ``hparams``.

        Chooses the model class from ``hparams.model_type``, wires up a Ray
        distributed retriever (and, for end2end training, a trainable DPR
        context encoder), then initialises dataset bookkeeping, output paths
        and run metadata.
        """
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        # Map the model_type string onto the concrete transformers class.
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # set retriever parameters (hparams override the pretrained config when given)
        config.index_name = hparams.index_name or config.index_name
        config.passages_path = hparams.passages_path or config.passages_path
        config.index_path = hparams.index_path or config.index_path
        config.use_dummy_dataset = hparams.use_dummy_dataset

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop",
                              "attention_dropout", "dropout")
        if self.is_rag_model:
            if hparams.prefix is not None:
                config.generator.prefix = hparams.prefix
            config.label_smoothing = hparams.label_smoothing
            # Move the extra params from hparams onto the generator sub-config.
            hparams, config.generator = set_extra_model_params(
                extra_model_params, hparams, config.generator)
            if hparams.distributed_retriever == "ray":
                # The Ray retriever needs the handles to the retriever actors.
                retriever = RagRayDistributedRetriever.from_pretrained(
                    hparams.model_name_or_path,
                    hparams.actor_handles,
                    config=config)

                if hparams.end2end:
                    # End2end training also fine-tunes the context encoder, so
                    # the retriever needs its tokenizer to re-encode passages.
                    ctx_encoder_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
                        "facebook/dpr-ctx_encoder-multiset-base")
                    retriever.set_ctx_encoder_tokenizer(ctx_encoder_tokenizer)
            else:
                # NOTE(review): in this branch `retriever` is never bound, so the
                # from_pretrained call below raises NameError — only the Ray path
                # appears supported; confirm and consider failing fast here.
                logger.info(
                    "please use RAY as the distributed retrieval method")

            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config, retriever=retriever)
            if hparams.end2end:
                ctx_encoder = DPRContextEncoder.from_pretrained(
                    hparams.context_encoder_name)
                model.set_context_encoder_for_training(ctx_encoder)
            prefix = config.question_encoder.prefix
        else:
            if hparams.prefix is not None:
                config.prefix = hparams.prefix
            hparams, config = set_extra_model_params(extra_model_params,
                                                     hparams, config)
            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config)
            prefix = config.prefix

        # RAG bundles two tokenizers (question encoder + generator); otherwise use Auto.
        tokenizer = (RagTokenizer.from_pretrained(hparams.model_name_or_path)
                     if self.is_rag_model else AutoTokenizer.from_pretrained(
                         hparams.model_name_or_path))

        self.config_dpr = DPRConfig.from_pretrained(
            hparams.context_encoder_name)
        self.custom_config = hparams
        self.context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
            hparams.context_encoder_name)

        super().__init__(hparams,
                         config=config,
                         tokenizer=tokenizer,
                         model=model)

        # Persist run metadata and hyperparameters next to the outputs.
        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.dpr_ctx_check_dir = str(Path(
            self.hparams.output_dir)) + "/dpr_ctx_checkpoint"
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=prefix or "",
        )
        # Per-split example caps; negative values mean "use the full split".
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {
            k: v if v >= 0 else None
            for k, v in n_observations_per_split.items()
        }
        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        # Eval splits must allow at least as many target tokens as training.
        assert self.target_lens["train"] <= self.target_lens[
            "val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens[
            "test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port

        # For single GPU training, init_ddp_connection is not called.
        # So we need to initialize the retrievers here.
        if hparams.gpus <= 1:
            if hparams.distributed_retriever == "ray":
                self.model.retriever.init_retrieval()
            else:
                logger.info(
                    "please use RAY as the distributed retrieval method")

        self.distributed_retriever = hparams.distributed_retriever
Example No. 15
0
    def __init__(self, hparams, **kwargs):
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # set retriever parameters
        config.n_docs = hparams.n_docs
        config.do_marginalize = hparams.do_marginalize or config.do_marginalize
        config.scoring_func = hparams.scoring_func or config.scoring_func
        logger.info("Using scoring function - {}".format(config.scoring_func))
        config.segmentation = hparams.segmentation or config.segmentation
        config.max_combined_length = hparams.max_combined_length or config.max_combined_length
        config.max_source_length = hparams.max_source_length or config.max_source_length
        config.index_name = hparams.index_name or config.index_name
        config.passages_path = hparams.passages_path or config.passages_path
        config.index_path = hparams.index_path or config.index_path
        config.use_dummy_dataset = hparams.use_dummy_dataset

        if hparams.bm25:
            # hparams.bm25 = load_bm25_results(hparams.bm25)
            bm25 = load_bm25(hparams.bm25)
            config.bm25 = hparams.bm25
        else:
            bm25 = None

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop",
                              "attention_dropout", "dropout")
        if self.is_rag_model:
            if hparams.prefix is not None:
                config.generator.prefix = hparams.prefix
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(
                extra_model_params, hparams, config.generator)
            if hparams.distributed_retriever == "pytorch":
                # pdb.set_trace()
                retriever = RagPyTorchDistributedRetriever.from_pretrained(
                    hparams.model_name_or_path, config=config)
            elif hparams.distributed_retriever == "ray":
                # The Ray retriever needs the handles to the retriever actors.
                retriever = RagRayDistributedRetriever.from_pretrained(
                    hparams.model_name_or_path,
                    hparams.actor_handles,
                    config=config)
            model = self.model_class.from_pretrained(
                hparams.model_name_or_path,
                config=config,
                retriever=retriever,
                bm25=bm25)
            prefix = config.question_encoder.prefix
            model.bm25 = bm25
        else:
            if hparams.prefix is not None:
                config.prefix = hparams.prefix
            hparams, config = set_extra_model_params(extra_model_params,
                                                     hparams, config)
            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config)
            prefix = config.prefix

        tokenizer = (RagTokenizer.from_pretrained(hparams.model_name_or_path)
                     if self.is_rag_model else AutoTokenizer.from_pretrained(
                         hparams.model_name_or_path))

        super().__init__(hparams,
                         config=config,
                         tokenizer=tokenizer,
                         model=model)

        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {
            k: v if v >= 0 else None
            for k, v in n_observations_per_split.items()
        }

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens[
            "val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens[
            "test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port

        # For single GPU training, init_ddp_connection is not called.
        # So we need to initialize the retrievers here.
        if hparams.gpus <= 1:
            if hparams.distributed_retriever == "ray":
                self.model.retriever.init_retrieval()
            elif hparams.distributed_retriever == "pytorch":
                self.model.retriever.init_retrieval(self.distributed_port)

        self.distributed_retriever = hparams.distributed_retriever
Ejemplo n.º 16
0
    def __init__(self, hparams, **kwargs):
        """Set up a RAG/seq2seq fine-tuning module from ``hparams``.

        Resolves the concrete model class from ``hparams.model_type``, loads
        config / distributed retriever / model / tokenizer, optionally swaps
        the question encoder for a BART-base encoder, then initializes
        dataset and bookkeeping state.

        :param hparams: argparse-style namespace (or a plain dict when
            restored from a PyTorch Lightning checkpoint).
        """
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        # Map the requested model type to its class; anything unrecognized
        # falls back to T5.
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # set retriever parameters
        config.index_name = hparams.index_name or config.index_name
        config.passages_path = hparams.passages_path or config.passages_path
        config.index_path = hparams.index_path or config.index_path
        config.use_dummy_dataset = hparams.use_dummy_dataset
        # Hard-coded experiment overrides (deliberately not read from hparams).
        config.n_docs = 4
        config.n_docs_splits = 4
        config.max_combined_length = 500
        config.n_words_to_src = 40  # using 40 tokens to add to src
        config.skip_ec = False
        config.bart_base_qe = True  # using bart encoder as qe
        config.do_deduplication = True

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop",
                              "attention_dropout", "dropout")
        if self.is_rag_model:
            if hparams.prefix is not None:
                config.generator.prefix = hparams.prefix
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(
                extra_model_params, hparams, config.generator)
            # NOTE(review): if hparams.distributed_retriever is neither
            # "pytorch" nor "ray", `retriever` is never bound and the
            # from_pretrained call below raises NameError — confirm the
            # allowed values are validated upstream.
            if hparams.distributed_retriever == "pytorch":
                retriever = RagPyTorchDistributedRetriever.from_pretrained(
                    hparams.model_name_or_path, config=config)
            elif hparams.distributed_retriever == "ray":
                # The Ray retriever needs the handles to the retriever actors.
                retriever = RagRayDistributedRetriever.from_pretrained(
                    hparams.model_name_or_path,
                    hparams.actor_handles,
                    config=config)
            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config, retriever=retriever)
            prefix = config.question_encoder.prefix
        else:
            if hparams.prefix is not None:
                config.prefix = hparams.prefix
            hparams, config = set_extra_model_params(extra_model_params,
                                                     hparams, config)
            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config)
            prefix = config.prefix

        # RAG models wrap two tokenizers (question encoder + generator).
        tokenizer = (RagTokenizer.from_pretrained(hparams.model_name_or_path)
                     if self.is_rag_model else AutoTokenizer.from_pretrained(
                         hparams.model_name_or_path))

        # if the bart base qe wants to be used
        if config.bart_base_qe:
            #print("yuh")
            # load bbforrag
            # Replace the question encoder with a BART-base encoder.
            # NOTE(review): .cuda() assumes a GPU is available — confirm this
            # path is never taken on CPU-only machines.
            bart_base_model = BartForConditionalGeneration.from_pretrained(
                "facebook/bart-base").cuda()
            model.question_encoder = bart_base_model.model.encoder
            #sys.exit()

        super().__init__(hparams,
                         config=config,
                         tokenizer=tokenizer,
                         model=model)

        # Persist run metadata next to the model outputs.
        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=prefix or "",
        )
        # Per-split caps on example counts; negative means "use the whole split".
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {
            k: v if v >= 0 else None
            for k, v in n_observations_per_split.items()
        }

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens[
            "val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens[
            "test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port

        # For single GPU training, init_ddp_connection is not called.
        # So we need to initialize the retrievers here.
        if hparams.gpus <= 1:
            if hparams.distributed_retriever == "ray":
                self.model.retriever.init_retrieval()
            elif hparams.distributed_retriever == "pytorch":
                self.model.retriever.init_retrieval(self.distributed_port)

        self.distributed_retriever = hparams.distributed_retriever
        # Inputs are tokenized with the question encoder's tokenizer when the
        # model is RAG; otherwise with the single tokenizer.
        self.source_tokenizer = (self.tokenizer.question_encoder if isinstance(
            self.tokenizer, RagTokenizer) else self.tokenizer)
Ejemplo n.º 17
0
from transformers import pipeline

# Open and read the article
question = "What is the capital of the Netherlands?"

# The 'r' means raw string so ignores escape codes e.g. ignores /n
context = r"The four largest cities in the Netherlands are Amsterdam, Rotterdam, The Hague and Utrecht.[17] Amsterdam is the country's most populous city and nominal capital,[18] while The Hague holds the seat of the States General, Cabinet and Supreme Court.[19] The Port of Rotterdam is the busiest seaport in Europe, and the busiest in any country outside East Asia and Southeast Asia, behind only China and Singapore."

# Generating an answer to the question in context
qa = pipeline("question-answering")
answer = qa(question=question, context=context)

# Print the answer
print(f"Question: {question}")
print(f"Answer: '{answer['answer']}' with score {answer['score']}")

# Test RAG working
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq",
                                         index_name="exact",
                                         use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq",
                                              retriever=retriever)

# FIX: `prepare_seq2seq_batch` is deprecated (removed in recent transformers);
# calling the tokenizer directly is the documented equivalent.
input_dict = tokenizer("who holds the record in 100m freestyle",
                       return_tensors="pt")

generated = model.generate(input_ids=input_dict["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
Ejemplo n.º 18
0
    def __init__(self, hparams, **kwargs):
        """Initialize the fine-tuning module.

        Resolves the model class from ``hparams.model_type``, loads config,
        (for RAG) a PyTorch distributed retriever, the model and tokenizer,
        then sets up dataset/bookkeeping state.

        :param hparams: argparse-style namespace (or plain dict when restored
            from a PyTorch Lightning checkpoint) with model/data options.
        """
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        # Map model_type to the concrete model class; default to T5.
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # set retriever parameters
        # BUG FIX: these previously read from an undefined global ``args``
        # (NameError at runtime); the sibling implementations read ``hparams``.
        config.index_name = hparams.index_name or config.index_name
        config.passages_path = hparams.passages_path or config.passages_path
        config.index_path = hparams.index_path or config.index_path

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
        if self.is_rag_model:
            if hparams.prefix is not None:
                config.generator.prefix = hparams.prefix
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
            retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
            prefix = config.question_encoder.prefix
        else:
            if hparams.prefix is not None:
                config.prefix = hparams.prefix
            hparams, config = set_extra_model_params(extra_model_params, hparams, config)
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
            prefix = config.prefix

        # RAG models wrap two tokenizers (question encoder + generator).
        tokenizer = (
            RagTokenizer.from_pretrained(hparams.model_name_or_path)
            if self.is_rag_model
            else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
        )

        super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)

        # Persist run metadata next to the model outputs.
        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir, max_source_length=self.hparams.max_source_length, prefix=prefix or "",
        )
        # Per-split caps on example counts; negative means "use the whole split".
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port
Ejemplo n.º 19
0
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """End-to-end RAG example: build a knowledge dataset from a TSV file,
    index it with FAISS HNSW, load a RAG model over it, and answer a question.
    """
    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # RAG needs a dataset with three columns:
    #   title (string), text (string), embeddings (d-dim DPR vectors).
    assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # Source documents come from a tab-separated file with "title" and "text".
    kb = load_dataset(
        "csv",
        data_files=[rag_example_args.csv_path],
        split="train",
        delimiter="\t",
        column_names=["title", "text"],
    )

    # Chunk each document into ~100-word passages.
    kb = kb.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # Embed every passage with the DPR context encoder.
    encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    enc_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    kb = kb.map(
        partial(embed, ctx_encoder=encoder, ctx_tokenizer=enc_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
    )

    # Persist the processed dataset so it can be reloaded later
    # (datasets.load_from_disk on the same path).
    saved_passages_dir = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    kb.save_to_disk(saved_passages_dir)

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # HNSW gives fast approximate nearest-neighbor search over the embeddings.
    hnsw = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    kb.add_faiss_index("embeddings", custom_index=hnsw)

    # Save the index next to the dataset
    # (reload with dataset.load_faiss_index("embeddings", path)).
    faiss_index_file = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    kb.get_index("embeddings").save(faiss_index_file)

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # For distributed fine-tuning, pass passages_path/index_path instead of
    # an in-memory indexed_dataset.
    rag_retriever = RagRetriever.from_pretrained(
        rag_example_args.rag_model_name, index_name="custom", indexed_dataset=kb
    )
    rag_model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=rag_retriever)
    rag_tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    q = rag_example_args.question or "What does Moses' rod turn into ?"
    q_ids = rag_tokenizer.question_encoder(q, return_tensors="pt")["input_ids"]
    answer_ids = rag_model.generate(q_ids)
    answer_text = rag_tokenizer.batch_decode(answer_ids, skip_special_tokens=True)[0]
    logger.info("Q: " + q)
    logger.info("A: " + answer_text)
Ejemplo n.º 20
0
def get_rag_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True):
    """Return a RagTokenizer whose generator and question-encoder tokenizers
    are both loaded from *pretrained_cfg_name*.

    The wrapper itself comes from the stock ``facebook/rag-token-nq``
    checkpoint; only its two sub-tokenizers are replaced.
    """
    def _sub_tokenizer():
        # Fresh AutoTokenizer instance for each slot, as in the original.
        return AutoTokenizer.from_pretrained(pretrained_cfg_name, do_lower_case=do_lower_case)

    rag_tok = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    rag_tok.generator = _sub_tokenizer()
    rag_tok.question_encoder = _sub_tokenizer()
    return rag_tok
Ejemplo n.º 21
0
        for input_batch in tqdm(dataloader):
            input_batch = {
                k: v.to(args.device)
                for k, v in input_batch.items()
            }
            logits = model(**input_batch)
            logits = logits[0]
            attention = input_batch["attention_mask"]
            argmax = [
                l[a == 1].softmax(1).max(1) for l, a in zip(logits, attention)
            ]
            preds = [idx[val >= args.thresh] for val, idx in argmax]


if __name__ == "__main__":
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
    trainset = RAGEDataset(args, tokenizer)
    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base")
    model = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-base",
                                                  retriever=retriever)
    lit_rage = LitRage(args, trainset, model)

    trainloader = DataLoader(
        trainset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        collate_fn=trainset.collate,
    )

    checkpoint = ModelCheckpoint(
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, AutoTokenizer

# Assemble a RAG model from a standalone DPR question encoder and a BART generator.
model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large")

# Load the matching tokenizer for each half of the model.
question_encoder_tokenizer = AutoTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# Point the retriever at the small dummy exact index for quick experimentation.
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)

# Persist every component into the current directory.
for component in (model, tokenizer, retriever):
    component.save_pretrained("./")
Ejemplo n.º 23
0
 def load_tokenizer(self) -> None:
     """Load the RagTokenizer named by ``self.rag_sequence`` and store it on
     ``self.tokenizer``.
     """
     logger.debug('loading rag tokenizer: %s', self.name)
     self.tokenizer = RagTokenizer.from_pretrained(self.rag_sequence)
Ejemplo n.º 24
0
def main():
    """Entry point: build options, datasets, a RAG-token VQA model and its
    optimizer, then either evaluate or run the training loop.
    """
    global args, best_acc1
    args = parser.parse_args()

    #########################################################################################
    # Create options
    #########################################################################################

    # Defaults from the CLI; optionally overridden by a YAML options file.
    options = {
        'vqa': {
            'trainsplit': args.vqa_trainsplit
        },
        'logs': {
            'dir_logs': args.dir_logs
        },
        'model': {
            'arch': args.arch,
            'seq2vec': {
                'type': args.st_type,
                'dropout': args.st_dropout,
                'fixed_emb': args.st_fixed_emb
            }
        },
        'optim': {
            'lr': args.learning_rate,
            'batch_size': args.batch_size,
            'epochs': args.epochs
        }
    }
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            # NOTE(review): yaml.load without an explicit Loader is deprecated
            # and unsafe on untrusted files — consider yaml.safe_load.
            options_yaml = yaml.load(handle)
        options = utils.update_values(options, options_yaml)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)
    if args.help_opt:
        return

    # Set datasets options
    if 'vgenome' not in options:
        options['vgenome'] = None

    #########################################################################################
    # Create needed datasets
    #########################################################################################

    trainset = datasets.factory_VQA(options['vqa']['trainsplit'],
                                    options['vqa'], options['coco'],
                                    options['vgenome'])
    train_loader = trainset.data_loader(
        batch_size=options['optim']['batch_size'],
        num_workers=args.workers,
        shuffle=True)

    # Validation set only exists when training on the plain 'train' split.
    if options['vqa']['trainsplit'] == 'train':
        valset = datasets.factory_VQA('val', options['vqa'], options['coco'])
        val_loader = valset.data_loader(batch_size=2, num_workers=args.workers)

    if options['vqa']['trainsplit'] == 'trainval' or args.evaluate:
        testset = datasets.factory_VQA('test', options['vqa'], options['coco'])
        test_loader = testset.data_loader(
            batch_size=options['optim']['batch_size'],
            num_workers=args.workers)

    #########################################################################################
    # Create model, criterion and optimizer
    #########################################################################################
    # RAG config tuned for VQA: legacy index, real (non-dummy) dataset,
    # 10 retrieved documents per query.
    config = RagConfig.from_pretrained("facebook/rag-token-nq")
    config.index_name = "legacy"
    config.use_dummy_dataset = False
    config.question_encoder.return_dict = True
    config.n_docs = 10
    # config.n_docs = 15
    # import pdb;
    # pdb.set_trace ()
    # Fresh run: start from the public checkpoint; resume/evaluate: reload
    # the components saved at epoch `args.start_epoch`.
    if not args.evaluate and not args.resume:
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base",
                                                 config=config)
        retriever = RagRetriever.from_pretrained("facebook/rag-token-base",
                                                 config=config)
        model = RagTokenForGeneration.from_pretrained(
            "facebook/rag-token-base", retriever=retriever, config=config)
    else:
        tokenizer = RagTokenizer.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)),
                                                 config=config)
        retriever = RagRetriever.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)),
                                                 config=config)
        model = RagTokenForGeneration.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)),
                                                      retriever=retriever,
                                                      config=config)

    model.cuda()
    criterion = criterions.factory(options['vqa'], cuda=True)
    no_decay = ["bias", "LayerNorm.weight"]

    # NOTE(review): both parameter groups use weight_decay 0.0, which makes
    # the no_decay split a no-op — confirm whether the first group was meant
    # to carry a non-zero weight decay.
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=options['optim']['lr'],
                      eps=1e-8)
    # optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=options['optim']['lr'], momentum=0.9)

    #########################################################################################
    # args.resume: resume from a checkpoint OR create logs directory
    #########################################################################################

    exp_logger = None

    # Or create logs directory
    # os.system('mkdir -p ' + options['logs']['dir_logs'])
    # Snapshot resolved options and raw args into the logs directory.
    path_new_opt = os.path.join(options['logs']['dir_logs'],
                                os.path.basename(args.path_opt))
    path_args = os.path.join(options['logs']['dir_logs'], 'args.yaml')
    with open(path_new_opt, 'w') as f:
        yaml.dump(options, f, default_flow_style=False)
    with open(path_args, 'w') as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    if exp_logger is None:
        # Set loggers
        exp_name = os.path.basename(
            options['logs']['dir_logs'])  # add timestamp
        exp_logger = logger.Experiment(exp_name, options)
        exp_logger.add_meters('train', make_meters())
        exp_logger.add_meters('test', make_meters())
        if options['vqa']['trainsplit'] == 'train':
            exp_logger.add_meters('val', make_meters())
        exp_logger.info['model_params'] = utils.params_count(model)
        print('Model has {} parameters'.format(
            exp_logger.info['model_params']))

    #########################################################################################
    # args.evaluate: on valset OR/AND on testset
    #########################################################################################

    if args.evaluate:
        path_logger_json = os.path.join(options['logs']['dir_logs'],
                                        'logger.json')

        if options['vqa']['trainsplit'] == 'train':
            acc1, val_results = engine.validate(val_loader, model, retriever,
                                                tokenizer, criterion,
                                                exp_logger, args.start_epoch,
                                                100)
            # save results and compute OpenEnd accuracy
            exp_logger.to_json(path_logger_json)
            save_results(val_results, args.start_epoch, valset.split_name(),
                         options['logs']['dir_logs'], options['vqa']['dir'])

        return
    else:
        for epoch in range(args.start_epoch + 1, options['optim']['epochs']):
            engine.train(train_loader, model, retriever, tokenizer, criterion,
                         optimizer, exp_logger, epoch, args.print_freq)

            # remember best prec@1 and save checkpoint
            is_best = True
            # NOTE(review): 'best_accs1' looks like a typo for 'best_acc1'
            # and is never read; the global best_acc1 (unchanged) is what
            # gets checkpointed below — confirm intent.
            best_accs1 = -1
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': options['model']['arch'],
                    'best_acc1': best_acc1,
                    'exp_logger': exp_logger
                }, model, tokenizer, retriever, options['logs']['dir_logs'],
                args.save_model, True)