def get_dummy_custom_hf_index_pytorch_retriever(self,
                                                 init_retrieval: bool,
                                                 from_disk: bool,
                                                 port=12345):
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )
     if from_disk:
         config.passages_path = os.path.join(self.tmpdirname, "dataset")
         config.index_path = os.path.join(self.tmpdirname, "index.faiss")
         dataset.get_index("embeddings").save(
             os.path.join(self.tmpdirname, "index.faiss"))
         dataset.drop_index("embeddings")
         dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
         del dataset
         retriever = RagPyTorchDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
     else:
         retriever = RagPyTorchDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             index=CustomHFIndex(config.retrieval_vector_size, dataset),
         )
     if init_retrieval:
         retriever.init_retrieval(port)
     return retriever
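# A hedged usage sketch (not part of the original test class): exercise the
# helper above with a dummy query embedding. The batch size, n_docs=1, and the
# test-method name are illustrative assumptions, mirroring how the distributed
# retriever tests call retrieve().
def test_custom_hf_index_pytorch_retriever_sketch(self):
    import numpy as np

    retriever = self.get_dummy_custom_hf_index_pytorch_retriever(
        init_retrieval=True, from_disk=False)
    hidden_states = np.ones((2, self.retrieval_vector_size), dtype="float32")
    retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=1)
    self.assertEqual(retrieved_doc_embeds.shape, (2, 1, self.retrieval_vector_size))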
Example #2
 def get_rag_config(self):
     question_encoder_config = AutoConfig.from_pretrained(
         "facebook/dpr-question_encoder-single-nq-base")
     generator_config = AutoConfig.from_pretrained(
         "facebook/bart-large-cnn")
     return RagConfig.from_question_encoder_generator_configs(
         question_encoder_config,
         generator_config,
         bos_token_id=0,
         decoder_start_token_id=2,
         eos_token_id=2,
         is_encoder_decoder=True,
         pad_token_id=1,
         vocab_size=50264,
         title_sep=" / ",
         doc_sep=" // ",
         n_docs=5,
         max_combined_length=300,
         dataset="wiki_dpr",
         dataset_split="train",
         index_name="exact",
         index_path=None,
         use_dummy_dataset=True,
         retrieval_vector_size=768,
         retrieval_batch_size=8,
     )
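# A hedged round-trip sketch (not in the original): the composed RagConfig
# nests the DPR and BART sub-configs and survives save_pretrained/from_pretrained.
# The method name is an illustrative assumption.
def test_rag_config_roundtrip_sketch(self):
    import tempfile

    config = self.get_rag_config()
    self.assertEqual(config.question_encoder.model_type, "dpr")
    self.assertEqual(config.generator.model_type, "bart")
    with tempfile.TemporaryDirectory() as tmp:
        config.save_pretrained(tmp)
        reloaded = RagConfig.from_pretrained(tmp)
    self.assertEqual(reloaded.n_docs, 5)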
Example #3
    def config_and_inputs(self):
        question_encoder_tester = DPRModelTester(self)
        dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()
        generator_tester = T5ModelTester(self, vocab_size=1100, n_positions=30)
        t5_config_and_inputs = generator_tester.prepare_config_and_inputs()

        (question_encoder_config, input_ids, _, input_mask, _, _,
         _) = dpr_config_and_inputs
        (generator_config, _, decoder_input_ids, _, decoder_attention_mask,
         _) = t5_config_and_inputs
        config = RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
            use_cache=False,
        )

        return {
            "config": config,
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
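    # A hedged sketch of consuming the dict above: split off the config, build
    # a RagModel, and run a forward pass. `self.get_retriever(config)`, plus
    # RagModel and torch being in scope, are assumptions about the test mixin.
    def check_model_forward_sketch(self):
        inputs = self.config_and_inputs()
        config = inputs.pop("config")
        model = RagModel(config=config, retriever=self.get_retriever(config)).eval()
        with torch.no_grad():
            outputs = model(**inputs)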
Example #4
    def config_and_inputs(self):
        question_encoder_tester = DPRModelTester(self)
        dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()
        generator_tester = BartModelTester(self)
        bart_config_and_inputs = generator_tester.prepare_config_and_inputs_for_common()

        (question_encoder_config, input_ids, _, input_mask, _, _,
         _) = dpr_config_and_inputs
        (generator_config, bart_inputs_dict) = bart_config_and_inputs
        decoder_input_ids, decoder_attention_mask = bart_inputs_dict[
            "input_ids"], bart_inputs_dict["attention_mask"]

        config = RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
            use_cache=False,
        )

        return {
            "config": config,
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
    def test_init_and_from_pretrained(self):
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained(
            "facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        rag_config = RagConfig.from_pretrained("facebook/rag-sequence-base")
        rag = TFRagTokenForGeneration(rag_config, retriever=rag_retriever)

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba",
            return_tensors="tf").input_ids
        decoder_input_ids = rag_decoder_tokenizer(
            "Linda Davis", return_tensors="tf").input_ids

        rag(
            input_ids,
            decoder_input_ids=decoder_input_ids,
        )

        # this should not give any warnings
        with tempfile.TemporaryDirectory() as tmpdirname:
            rag.save_pretrained(tmpdirname)
            rag = TFRagTokenForGeneration.from_pretrained(
                tmpdirname, retriever=rag_retriever)
 @classmethod
 def from_pretrained(cls,
                     retriever_name_or_path,
                     actor_handles,
                     indexed_dataset=None,
                     **kwargs):
     requires_datasets(cls)
     requires_faiss(cls)
     config = kwargs.pop("config", None) or RagConfig.from_pretrained(
         retriever_name_or_path, **kwargs)
     rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path,
                                                  config=config)
     question_encoder_tokenizer = rag_tokenizer.question_encoder
     generator_tokenizer = rag_tokenizer.generator
     if indexed_dataset is not None:
         config.index_name = "custom"
         index = CustomHFIndex(config.retrieval_vector_size,
                               indexed_dataset)
     else:
         index = cls._build_index(config)
     return cls(
         config,
         question_encoder_tokenizer=question_encoder_tokenizer,
         generator_tokenizer=generator_tokenizer,
         retrieval_workers=actor_handles,
         index=index,
     )
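# A hedged usage sketch (not in the original): build the distributed retriever
# from a pretrained RAG checkpoint with a single Ray actor. Downloading the
# real wiki_dpr index is heavy, so this is illustrative only.
import ray

ray.init()
workers = [ray.remote(RayRetriever).remote() for _ in range(1)]
retriever = RagRayDistributedRetriever.from_pretrained(
    "facebook/rag-token-nq", actor_handles=workers)
retriever.init_retrieval()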
 def get_dummy_ray_distributed_retriever(
         self, init_retrieval: bool) -> RagRayDistributedRetriever:
     # Have to run in local mode because sys.path modifications at top of
      # file are not propagated to remote workers.
     # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
     ray.init(local_mode=True)
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     remote_cls = ray.remote(RayRetriever)
     workers = [remote_cls.remote() for _ in range(1)]
      with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
         mock_load_dataset.return_value = self.get_dummy_dataset()
         retriever = RagRayDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             retrieval_workers=workers,
         )
         if init_retrieval:
             retriever.init_retrieval()
     return retriever
 def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool,
                                             from_disk: bool):
     # Have to run in local mode because sys.path modifications at top of
      # file are not propagated to remote workers.
     # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
     ray.init(local_mode=True)
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )
     remote_cls = ray.remote(RayRetriever)
     workers = [remote_cls.remote() for _ in range(1)]
     if from_disk:
         config.passages_path = os.path.join(self.tmpdirname, "dataset")
         config.index_path = os.path.join(self.tmpdirname, "index.faiss")
         dataset.get_index("embeddings").save(
             os.path.join(self.tmpdirname, "index.faiss"))
         dataset.drop_index("embeddings")
         dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
         del dataset
         retriever = RagRayDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             retrieval_workers=workers,
             index=CustomHFIndex.load_from_disk(
                 vector_size=config.retrieval_vector_size,
                 dataset_path=config.passages_path,
                 index_path=config.index_path,
             ),
         )
     else:
         retriever = RagRayDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             retrieval_workers=workers,
             index=CustomHFIndex(config.retrieval_vector_size, dataset),
         )
     if init_retrieval:
         retriever.init_retrieval()
     return retriever
def consolidate(
    model_type: str,
    generator_name_or_path: str,
    question_encoder_name_or_path: str,
    dest_dir: Path,
    config_name_or_path: str = None,
    generator_tokenizer_name_or_path: str = None,
    question_encoder_tokenizer_name_or_path: str = None,
):

    if config_name_or_path is None:
        config_name_or_path = "facebook/rag-token-base" if model_type == "rag_token" else "facebook/rag-sequence-base"

    if generator_tokenizer_name_or_path is None:
        generator_tokenizer_name_or_path = generator_name_or_path

    if question_encoder_tokenizer_name_or_path is None:
        question_encoder_tokenizer_name_or_path = question_encoder_name_or_path

    model_class = RagTokenForGeneration if model_type == "rag_token" else RagSequenceForGeneration

    # Save model.
    rag_config = RagConfig.from_pretrained(config_name_or_path)
    gen_config = AutoConfig.from_pretrained(generator_name_or_path)
    question_encoder_config = AutoConfig.from_pretrained(
        question_encoder_name_or_path)

    rag_config.generator = gen_config
    rag_config.question_encoder = question_encoder_config

    rag_model = model_class.from_pretrained_question_encoder_generator(
        question_encoder_name_or_path,
        generator_name_or_path,
        config=rag_config)
    rag_model.save_pretrained(dest_dir)

    # Sanity check.
    model_class.from_pretrained(dest_dir)

    # Save tokenizers.
    gen_tokenizer = AutoTokenizer.from_pretrained(
        generator_tokenizer_name_or_path)
    gen_tokenizer.save_pretrained(dest_dir / "generator_tokenizer/")
    question_encoder_tokenizer = AutoTokenizer.from_pretrained(
        question_encoder_tokenizer_name_or_path)
    question_encoder_tokenizer.save_pretrained(dest_dir /
                                               "question_encoder_tokenizer/")
 def get_dummy_pytorch_distributed_retriever(
     self, init_retrieval: bool, port=12345
 ) -> RagPyTorchDistributedRetriever:
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
         mock_load_dataset.return_value = dataset
         retriever = RagPyTorchDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
         if init_retrieval:
             retriever.init_retrieval(port)
     return retriever
Example #11
def get_rag_generator_components(args, inference_only: bool = False, **kwargs):

    # tokenizer
    tensorizer = get_rag_tensorizer(args)

    # generator
    dropout = args.dropout if hasattr(args, 'dropout') else 0.0
    rag_config = RagConfig.from_pretrained("facebook/rag-token-nq")
    if dropout != 0:
        rag_config.attention_probs_dropout_prob = dropout
        rag_config.hidden_dropout_prob = dropout

    # facebook/rag-token-nq
    # rag = RagTokenForGeneration.from_pretrained(args.pretrained_model_cfg, config=rag_config, use_dummy_dataset=True)

    # Customize the RAG generator/question_encoder.
    # Note: the question_encoder is not strictly required here.
    generator_name_or_path = args.pretrained_model_cfg
    question_encoder_name_or_path = generator_name_or_path
    gen_config = AutoConfig.from_pretrained(generator_name_or_path)
    question_encoder_config = AutoConfig.from_pretrained(question_encoder_name_or_path)
    rag_config.generator = gen_config
    rag_config.question_encoder = question_encoder_config

    rag = RagTokenForGeneration.from_pretrained_question_encoder_generator(
        question_encoder_name_or_path, generator_name_or_path, config=rag_config, dummy_dataset=True
    )

    generator = Generator(rag, tensorizer)

    # optimizer
    optimizer = get_optimizer(generator,
                              learning_rate=args.learning_rate,
                              adam_eps=args.adam_eps, weight_decay=args.weight_decay,
                              ) if not inference_only else None

    return tensorizer, generator, optimizer
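# A hedged sketch of the `args` object this factory expects; the field names
# mirror the attributes read above, and the values are illustrative.
from types import SimpleNamespace

args = SimpleNamespace(
    pretrained_model_cfg="facebook/bart-large",
    dropout=0.1,
    learning_rate=3e-5,
    adam_eps=1e-8,
    weight_decay=0.0,
)
tensorizer, generator, optimizer = get_rag_generator_components(args)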
Example #12
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):

    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # Let's say you have documents in tab-separated csv files with columns "title" and "text"
    assert os.path.isfile(
        rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # You can load a Dataset object this way
    dataset = load_dataset("csv",
                           data_files=[rag_example_args.csv_path],
                           split="train",
                           delimiter="\t",
                           column_names=["title", "text"])

    # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

    # Then split the documents into passages of n words (changing the param in split_text to what you want n to be)
    dataset = dataset.map(split_documents,
                          batched=True,
                          num_proc=processing_args.num_proc)
    # And compute the embeddings
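    # NOTE: `use_generated_model` is not defined in this snippet; it is
    # assumed to be a flag set elsewhere in the script.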
    if use_generated_model:
        model_path = 'ragfinetune_4_4_false_50_true/checkpoint2/'  # SET PATH TO CHECKPOINT
        config = RagConfig.from_pretrained(model_path)
        config.n_docs = 4
        config.n_docs_splits = 4
        retriever = RagRetriever.from_pretrained(model_path, config=config)
        checkpoint_model = RagSequenceForGeneration.from_pretrained(
            model_path, config=config, retriever=retriever).cuda()
        ctx_encoder = checkpoint_model.generator.get_encoder()
        ctx_tokenizer = checkpoint_model.retriever.generator_tokenizer
    else:
        ctx_encoder = DPRContextEncoder.from_pretrained(
            rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
        ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
            rag_example_args.dpr_ctx_encoder_model_name)
    new_features = Features({
        "text": Value("string"),
        "title": Value("string"),
        "embeddings": Sequence(Value("float32"))
    })  # optional, save as float32 instead of float64 to save space
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
        features=new_features,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir,
                                 "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m,
                                faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)

    # And save the index
    index_path = os.path.join(rag_example_args.output_dir,
                              "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Easy way to load the model
    retriever = RagRetriever.from_pretrained(rag_example_args.rag_model_name,
                                             index_name="custom",
                                             indexed_dataset=dataset)
    model = RagSequenceForGeneration.from_pretrained(
        rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    # For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
    # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question,
                                           return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated,
                                              skip_special_tokens=True)[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)
Example #13
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration, BartForConditionalGeneration, RagConfig, DPRQuestionEncoder
import torch
from transformers.models.auto import AutoModel

config = RagConfig.from_pretrained("facebook/rag-token-nq")
config.index_name = "legacy"
config.use_dummy_dataset = False
config.question_encoder.return_dict = True
print("==> load tokenizer")
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
print("==> load retriever")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", config=config)
print("dataset info")
print(dir(retriever.index))
print("==> load generator")
# question encoder
# question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
# generator = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

# config = RagConfig.from_question_encoder_generator_configs(question_encoder.config, generator.config)
# model = RagTokenForGeneration(config, question_encoder=question_encoder,generator=generator, retriever=retriever)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

# input_dict = tokenizer.prepare_seq2seq_batch("USA president in 1999?", return_tensors="pt")
input_dict = tokenizer.prepare_seq2seq_batch("What kind of vehicle uses fire hydrant?", return_tensors="pt")
# input_dict = tokenizer.prepare_seq2seq_batch("what phylum does cat belong to?", return_tensors="pt")

print(input_dict.keys()) # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
input_ids = input_dict['input_ids']
print("==> encode")
question_hidden_states = model.question_encoder(input_ids)[0]
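# A possible continuation (not in the original snippet): retrieve passages for
# the encoded question and generate an answer, following the documented
# RagRetriever / RagTokenForGeneration flow.
print("==> retrieve")
docs_dict = retriever(input_ids.numpy(),
                      question_hidden_states.detach().numpy(),
                      return_tensors="pt")
doc_scores = torch.bmm(
    question_hidden_states.unsqueeze(1),
    docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
print("==> generate")
generated = model.generate(context_input_ids=docs_dict["context_input_ids"],
                           context_attention_mask=docs_dict["context_attention_mask"],
                           doc_scores=doc_scores)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))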
Example #14
def main():
    global args, best_acc1
    args = parser.parse_args()

    #########################################################################################
    # Create options
    #########################################################################################

    options = {
        'vqa': {
            'trainsplit': args.vqa_trainsplit
        },
        'logs': {
            'dir_logs': args.dir_logs
        },
        'model': {
            'arch': args.arch,
            'seq2vec': {
                'type': args.st_type,
                'dropout': args.st_dropout,
                'fixed_emb': args.st_fixed_emb
            }
        },
        'optim': {
            'lr': args.learning_rate,
            'batch_size': args.batch_size,
            'epochs': args.epochs
        }
    }
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options_yaml = yaml.safe_load(handle)
        options = utils.update_values(options, options_yaml)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)
    if args.help_opt:
        return

    # Set datasets options
    if 'vgenome' not in options:
        options['vgenome'] = None

    #########################################################################################
    # Create needed datasets
    #########################################################################################

    trainset = datasets.factory_VQA(options['vqa']['trainsplit'],
                                    options['vqa'], options['coco'],
                                    options['vgenome'])
    train_loader = trainset.data_loader(
        batch_size=options['optim']['batch_size'],
        num_workers=args.workers,
        shuffle=True)

    if options['vqa']['trainsplit'] == 'train':
        valset = datasets.factory_VQA('val', options['vqa'], options['coco'])
        val_loader = valset.data_loader(batch_size=2, num_workers=args.workers)

    if options['vqa']['trainsplit'] == 'trainval' or args.evaluate:
        testset = datasets.factory_VQA('test', options['vqa'], options['coco'])
        test_loader = testset.data_loader(
            batch_size=options['optim']['batch_size'],
            num_workers=args.workers)

    #########################################################################################
    # Create model, criterion and optimizer
    #########################################################################################
    config = RagConfig.from_pretrained("facebook/rag-token-nq")
    config.index_name = "legacy"
    config.use_dummy_dataset = False
    config.question_encoder.return_dict = True
    config.n_docs = 10
    # config.n_docs = 15
    # import pdb;
    # pdb.set_trace ()
    if not args.evaluate and not args.resume:
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base",
                                                 config=config)
        retriever = RagRetriever.from_pretrained("facebook/rag-token-base",
                                                 config=config)
        model = RagTokenForGeneration.from_pretrained(
            "facebook/rag-token-base", retriever=retriever, config=config)
    else:
        ckpt_dir = os.path.join(options['logs']['dir_logs'],
                                "epoch_{}".format(args.start_epoch))
        tokenizer = RagTokenizer.from_pretrained(ckpt_dir, config=config)
        retriever = RagRetriever.from_pretrained(ckpt_dir, config=config)
        model = RagTokenForGeneration.from_pretrained(ckpt_dir,
                                                      retriever=retriever,
                                                      config=config)

    model.cuda()
    criterion = criterions.factory(options['vqa'], cuda=True)
    no_decay = ["bias", "LayerNorm.weight"]

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=options['optim']['lr'],
                      eps=1e-8)
    # optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=options['optim']['lr'], momentum=0.9)

    #########################################################################################
    # args.resume: resume from a checkpoint OR create logs directory
    #########################################################################################

    exp_logger = None

    # Or create logs directory
    # os.system('mkdir -p ' + options['logs']['dir_logs'])
    path_new_opt = os.path.join(options['logs']['dir_logs'],
                                os.path.basename(args.path_opt))
    path_args = os.path.join(options['logs']['dir_logs'], 'args.yaml')
    with open(path_new_opt, 'w') as f:
        yaml.dump(options, f, default_flow_style=False)
    with open(path_args, 'w') as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    if exp_logger is None:
        # Set loggers
        exp_name = os.path.basename(
            options['logs']['dir_logs'])  # add timestamp
        exp_logger = logger.Experiment(exp_name, options)
        exp_logger.add_meters('train', make_meters())
        exp_logger.add_meters('test', make_meters())
        if options['vqa']['trainsplit'] == 'train':
            exp_logger.add_meters('val', make_meters())
        exp_logger.info['model_params'] = utils.params_count(model)
        print('Model has {} parameters'.format(
            exp_logger.info['model_params']))

    #########################################################################################
    # args.evaluate: on valset OR/AND on testset
    #########################################################################################

    if args.evaluate:
        path_logger_json = os.path.join(options['logs']['dir_logs'],
                                        'logger.json')

        if options['vqa']['trainsplit'] == 'train':
            acc1, val_results = engine.validate(val_loader, model, retriever,
                                                tokenizer, criterion,
                                                exp_logger, args.start_epoch,
                                                100)
            # save results and compute OpenEnd accuracy
            exp_logger.to_json(path_logger_json)
            save_results(val_results, args.start_epoch, valset.split_name(),
                         options['logs']['dir_logs'], options['vqa']['dir'])

        return
    else:
        for epoch in range(args.start_epoch + 1, options['optim']['epochs']):
            engine.train(train_loader, model, retriever, tokenizer, criterion,
                         optimizer, exp_logger, epoch, args.print_freq)

            # remember best prec@1 and save checkpoint
            is_best = True
            best_accs1 = -1
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': options['model']['arch'],
                    'best_acc1': best_acc1,
                    'exp_logger': exp_logger
                }, model, tokenizer, retriever, options['logs']['dir_logs'],
                args.save_model, True)