Example #1
    def test_rag_token_greedy_search(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq",
                                                 index_name="exact",
                                                 use_dummy_dataset=True)
        rag_token = TFRagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever)

        # check first two questions
        input_dict = tokenizer(
            self.test_data_questions[:2],
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids
        attention_mask = input_dict.attention_mask

        # make sure only 1 beam is used
        rag_token.config.num_beams = 1

        output_ids = rag_token.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " september 22, 2017",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Example #2
    def test_rag_token_generate_batch(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
        rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
            torch_device
        )

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids.to(torch_device)
        attention_mask = input_dict.attention_mask.to(torch_device)

        output_ids = rag_token.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Example #3
    def test_rag_sequence_generate_batch(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_sequence = TFRagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids
        attention_mask = input_dict.attention_mask

        output_ids = rag_sequence.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Example #4
def main(args):
    model_kwargs = {}
    if args.model_type is None:
        args.model_type = infer_model_type(args.model_name_or_path)
        assert args.model_type is not None
    if args.model_type.startswith("rag"):
        model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
        model_kwargs["n_docs"] = args.n_docs
        if args.index_name is not None:
            model_kwargs["index_name"] = args.index_name
        if args.index_path is not None:
            model_kwargs["index_path"] = args.index_path
    else:
        model_class = BartForConditionalGeneration

    checkpoints = (
        [f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
        if args.eval_all_checkpoints
        else [args.model_name_or_path]
    )

    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
    evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval
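    # eval_mode "e2e" scores the generated answers against the gold answers;
    # otherwise precision@k of the retrieved documents is scored instead.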

    for checkpoint in checkpoints:
        if os.path.exists(args.predictions_path) and (not args.recalculate):
            logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
            score_fn(args, args.predictions_path, args.gold_data_path)
            continue

        logger.info("***** Running evaluation for {} *****".format(checkpoint))
        logger.info("  Batch size = %d", args.eval_batch_size)
        logger.info("  Predictions will be stored under {}".format(args.predictions_path))

        if args.model_type.startswith("rag"):
            retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
            model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
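            # init_retrieval() loads the retrieval index into memory before generation.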
            model.retriever.init_retrieval()
        else:
            model = model_class.from_pretrained(checkpoint, **model_kwargs)
        model.to(args.device)

        with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
            questions = []
            for line in tqdm(eval_file):
                questions.append(line.strip())
                if len(questions) == args.eval_batch_size:
                    answers = evaluate_batch_fn(args, model, questions)
                    preds_file.write("\n".join(answers) + "\n")
                    preds_file.flush()
                    questions = []
            if len(questions) > 0:
                answers = evaluate_batch_fn(args, model, questions)
                preds_file.write("\n".join(answers))
                preds_file.flush()

            score_fn(args, args.predictions_path, args.gold_data_path)
Example #5
    def load_model(self) -> None:
        logger.debug('loading rag retriever: %s', self.name)
        retriever = RagRetriever.from_pretrained(self.rag_sequence,
                                                 index_name='custom',
                                                 indexed_dataset=self.dataset)
        logger.debug('loading rag model: %s', self.name)
        self.model = RagSequenceForGeneration.from_pretrained(
            self.rag_sequence, retriever=retriever)
Example #6
    def test_rag_sequence_generate_batch_from_context_input_ids(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq",
                                                 index_name="exact",
                                                 use_dummy_dataset=True)
        rag_sequence = RagSequenceForGeneration.from_pretrained(
            "facebook/rag-sequence-nq", retriever=retriever).to(torch_device)

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids.to(torch_device)
        attention_mask = input_dict.attention_mask.to(torch_device)

        question_hidden_states = rag_sequence.question_encoder(
            input_ids, attention_mask=attention_mask)[0]
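        # question_encoder output [0] is the pooled DPR question embedding, shape
        # (batch, hidden_dim); it is passed to the retriever to fetch documents.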
        docs_dict = retriever(input_ids.cpu().detach().numpy(),
                              question_hidden_states.cpu().detach().numpy(),
                              return_tensors="pt")
        doc_scores = torch.bmm(
            question_hidden_states.unsqueeze(1),
            docs_dict["retrieved_doc_embeds"].to(
                torch_device).float().transpose(1, 2),
        ).squeeze(1)

        output_ids = rag_sequence.generate(
            context_input_ids=docs_dict["context_input_ids"].to(torch_device),
            context_attention_mask=docs_dict["context_attention_mask"].to(
                torch_device),
            doc_scores=doc_scores.to(torch_device),
            do_deduplication=True,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
            " reticular formation",
            " walls of the abdomen",
            " spodumene",
            " obama",
            " new orleans",
            " japan",
            " old trafford",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Example #7
    def __init__(self, **args):
        super(RagTrainer, self).__init__()
        self.save_hyperparameters()
        self.rag_retriever = RagRetriever.from_pretrained(
            self.hparams['rag_ckpt_path'],
            index_name='custom',
            passages_path=self.hparams['wiki_ds_path'],
            index_path=self.hparams['wiki_index_path'])

        self.rag = RagSequenceForGeneration.from_pretrained(
            self.hparams['rag_ckpt_path'], retriever=self.rag_retriever)
Example #8
    def test_rag_sequence_generate_batch_from_context_input_ids(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq",
                                                 index_name="exact",
                                                 use_dummy_dataset=True)
        rag_sequence = TFRagSequenceForGeneration.from_pretrained(
            "facebook/rag-sequence-nq", retriever=retriever)
        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids

        question_hidden_states = rag_sequence.question_encoder(input_ids)[0]
        docs_dict = retriever(input_ids.numpy(),
                              question_hidden_states.numpy(),
                              return_tensors="tf")
        doc_scores = tf.squeeze(
            tf.matmul(tf.expand_dims(question_hidden_states, axis=1),
                      docs_dict["retrieved_doc_embeds"],
                      transpose_b=True),
            axis=1,
        )
        output_ids = rag_sequence.generate(
            context_input_ids=docs_dict["context_input_ids"],
            context_attention_mask=docs_dict["context_attention_mask"],
            doc_scores=doc_scores,
            do_deduplication=True,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Example #9
    def test_rag_token_generate_batch(self):
        # NOTE: the gold labels come from num_beams=4 -- if the gold labels were regenerated greedily, this test would pass
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq",
                                                 index_name="exact",
                                                 use_dummy_dataset=True)
        rag_token = TFRagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever, from_pt=True)

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids
        attention_mask = input_dict.attention_mask

        # rag_token.config.num_beams = 1 -> differs in 2 answers (obama, united stadium) from the num_beams=4 labels
        output_ids = rag_token.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
            " step by step",
            " stomach",
            " spodumene",
            " obama",
            " northern new jersey",
            " india",
            " united stadium",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Example #10
    def test_rag_token_generate_batch(self):
        # NOTE: the gold labels come from num_beams=4, so this is effectively a beam-search test
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq",
                                                 index_name="exact",
                                                 use_dummy_dataset=True)
        rag_token = TFRagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever)

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids
        attention_mask = input_dict.attention_mask

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
        ]

        # Split into 2 batches of 4 examples to avoid GPU OOM.
        output_ids = rag_token.generate(
            input_ids[:4],
            attention_mask=attention_mask[:4],
        )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        self.assertListEqual(outputs, EXPECTED_OUTPUTS[:4])

        output_ids = rag_token.generate(
            input_ids[4:],
            attention_mask=attention_mask[4:],
        )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        self.assertListEqual(outputs, EXPECTED_OUTPUTS[4:])
Example #11
    def test_rag_sequence_generate_batch(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        )
        rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
            torch_device
        )

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids.to(torch_device)
        attention_mask = input_dict.attention_mask.to(torch_device)

        output_ids = rag_sequence.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
            " reticular formation",
            " walls of the abdomen",
            " spodumene",
            " obama",
            " grainger's compound",
            " japan",
            " old trafford stadium",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Example #12
    def test_rag_token_generate_batch(self):
        # NOTE: the gold labels come from num_beams=4, so this is effectively a beam-search test
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq",
                                                 index_name="exact",
                                                 use_dummy_dataset=True)
        rag_token = TFRagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever)

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids
        attention_mask = input_dict.attention_mask

        output_ids = rag_token.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
            " evolution",
            " stomach",
            " spodumene",
            " obama",
            " northern new jersey",
            " india",
            " united stadium",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Example #13
def main():
    global args, best_acc1
    args = parser.parse_args()

    #########################################################################################
    # Create options
    #########################################################################################

    options = {
        'vqa': {
            'trainsplit': args.vqa_trainsplit
        },
        'logs': {
            'dir_logs': args.dir_logs
        },
        'model': {
            'arch': args.arch,
            'seq2vec': {
                'type': args.st_type,
                'dropout': args.st_dropout,
                'fixed_emb': args.st_fixed_emb
            }
        },
        'optim': {
            'lr': args.learning_rate,
            'batch_size': args.batch_size,
            'epochs': args.epochs
        }
    }
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options_yaml = yaml.safe_load(handle)
        options = utils.update_values(options, options_yaml)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)
    if args.help_opt:
        return

    # Set datasets options
    if 'vgenome' not in options:
        options['vgenome'] = None

    #########################################################################################
    # Create needed datasets
    #########################################################################################

    trainset = datasets.factory_VQA(options['vqa']['trainsplit'],
                                    options['vqa'], options['coco'],
                                    options['vgenome'])
    train_loader = trainset.data_loader(
        batch_size=options['optim']['batch_size'],
        num_workers=args.workers,
        shuffle=True)

    if options['vqa']['trainsplit'] == 'train':
        valset = datasets.factory_VQA('val', options['vqa'], options['coco'])
        val_loader = valset.data_loader(batch_size=2, num_workers=args.workers)

    if options['vqa']['trainsplit'] == 'trainval' or args.evaluate:
        testset = datasets.factory_VQA('test', options['vqa'], options['coco'])
        test_loader = testset.data_loader(
            batch_size=options['optim']['batch_size'],
            num_workers=args.workers)

    #########################################################################################
    # Create model, criterion and optimizer
    #########################################################################################
    config = RagConfig.from_pretrained("facebook/rag-token-nq")
    config.index_name = "legacy"
    config.use_dummy_dataset = False
    config.question_encoder.return_dict = True
    config.n_docs = 10
    # config.n_docs = 15
    # import pdb;
    # pdb.set_trace ()
    if not args.evaluate and not args.resume:
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base",
                                                 config=config)
        retriever = RagRetriever.from_pretrained("facebook/rag-token-base",
                                                 config=config)
        model = RagTokenForGeneration.from_pretrained(
            "facebook/rag-token-base", retriever=retriever, config=config)
    else:
        tokenizer = RagTokenizer.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)),
                                                 config=config)
        retriever = RagRetriever.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)),
                                                 config=config)
        model = RagTokenForGeneration.from_pretrained(os.path.join(
            options['logs']['dir_logs'], "epoch_{}".format(args.start_epoch)),
                                                      retriever=retriever,
                                                      config=config)

    model.cuda()
    criterion = criterions.factory(options['vqa'], cuda=True)
    no_decay = ["bias", "LayerNorm.weight"]
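
    # NOTE: as written, both parameter groups below use weight_decay 0.0, so the
    # no_decay split is a no-op; normally the first group would get a nonzero weight decay.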

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=options['optim']['lr'],
                      eps=1e-8)
    # optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=options['optim']['lr'], momentum=0.9)

    #########################################################################################
    # args.resume: resume from a checkpoint OR create logs directory
    #########################################################################################

    exp_logger = None

    # Or create logs directory
    # os.system('mkdir -p ' + options['logs']['dir_logs'])
    path_new_opt = os.path.join(options['logs']['dir_logs'],
                                os.path.basename(args.path_opt))
    path_args = os.path.join(options['logs']['dir_logs'], 'args.yaml')
    with open(path_new_opt, 'w') as f:
        yaml.dump(options, f, default_flow_style=False)
    with open(path_args, 'w') as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    if exp_logger is None:
        # Set loggers
        exp_name = os.path.basename(
            options['logs']['dir_logs'])  # add timestamp
        exp_logger = logger.Experiment(exp_name, options)
        exp_logger.add_meters('train', make_meters())
        exp_logger.add_meters('test', make_meters())
        if options['vqa']['trainsplit'] == 'train':
            exp_logger.add_meters('val', make_meters())
        exp_logger.info['model_params'] = utils.params_count(model)
        print('Model has {} parameters'.format(
            exp_logger.info['model_params']))

    #########################################################################################
    # args.evaluate: on valset OR/AND on testset
    #########################################################################################

    if args.evaluate:
        path_logger_json = os.path.join(options['logs']['dir_logs'],
                                        'logger.json')

        if options['vqa']['trainsplit'] == 'train':
            acc1, val_results = engine.validate(val_loader, model, retriever,
                                                tokenizer, criterion,
                                                exp_logger, args.start_epoch,
                                                100)
            # save results and compute OpenEnd accuracy
            exp_logger.to_json(path_logger_json)
            save_results(val_results, args.start_epoch, valset.split_name(),
                         options['logs']['dir_logs'], options['vqa']['dir'])

        return
    else:
        for epoch in range(args.start_epoch + 1, options['optim']['epochs']):
            engine.train(train_loader, model, retriever, tokenizer, criterion,
                         optimizer, exp_logger, epoch, args.print_freq)

            # remember best prec@1 and save checkpoint
            is_best = True
            best_accs1 = -1
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': options['model']['arch'],
                    'best_acc1': best_acc1,
                    'exp_logger': exp_logger
                }, model, tokenizer, retriever, options['logs']['dir_logs'],
                args.save_model, True)
Example #14
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):

    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage
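    # For illustration (values invented), a single row might look like:
    #   {"title": "Aaron", "text": "Aaron is a prophet ...", "embeddings": [0.1, -0.2, ...]}
    # with len(embeddings) equal to the DPR context encoder's hidden size (768 by default).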

    # Let's say you have documents in tab-separated csv files with columns "title" and "text"
    assert os.path.isfile(
        rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # You can load a Dataset object this way
    dataset = load_dataset("csv",
                           data_files=[rag_example_args.csv_path],
                           split="train",
                           delimiter="\t",
                           column_names=["title", "text"])

    # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

    # Then split the documents into passages of 100 words
    dataset = dataset.map(split_documents,
                          batched=True,
                          num_proc=processing_args.num_proc)

    # And compute the embeddings
    ctx_encoder = DPRContextEncoder.from_pretrained(
        rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        rag_example_args.dpr_ctx_encoder_model_name)
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir,
                                 "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m,
                                faiss.METRIC_INNER_PRODUCT)
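    # d must match the embedding dimension produced by the DPR context encoder (768 by
    # default); m sets the number of HNSW links per node (a speed/accuracy trade-off).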
    dataset.add_faiss_index("embeddings", custom_index=index)

    # And save the index
    index_path = os.path.join(rag_example_args.output_dir,
                              "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Easy way to load the model
    retriever = RagRetriever.from_pretrained(rag_example_args.rag_model_name,
                                             index_name="custom",
                                             indexed_dataset=dataset)
    model = RagSequenceForGeneration.from_pretrained(
        rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    # For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
    # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question,
                                           return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated,
                                              skip_special_tokens=True)[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)
Example #15
            input_batch = {
                k: v.to(args.device)
                for k, v in input_batch.items()
            }
            logits = model(**input_batch)
            logits = logits[0]
            attention = input_batch["attention_mask"]
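            # Per example: keep only non-padded positions, softmax over the class
            # dimension, then keep the argmax ids whose probability clears args.thresh.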
            argmax = [
                l[a == 1].softmax(1).max(1) for l, a in zip(logits, attention)
            ]
            preds = [idx[val >= args.thresh] for val, idx in argmax]


if __name__ == "__main__":
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
    trainset = RAGEDataset(args, tokenizer)
    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base")
    model = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-base",
                                                  retriever=retriever)
    lit_rage = LitRage(args, trainset, model)

    trainloader = DataLoader(
        trainset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        collate_fn=trainset.collate,
    )

    checkpoint = ModelCheckpoint(
        filepath=os.path.join(args.output,
                              "{epoch:02d}-{global_step:02d}-{val_loss:.2f}"),
Example #16
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, RagConfig

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq",
                                         index_name="exact",
                                         use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever).to("cuda:0")
model.add_tokens()

model.config.n_docs = 6
retriever.config.n_docs = 6

model.config.n_docs_splits = 3
retriever.config.n_docs_splits = 3

model.skip_ec = True

input_dict = tokenizer.prepare_seq2seq_batch("am i a cool person",
                                             return_tensors="pt").to("cuda:0")
generated = model.generate(**input_dict,
                           extra_context=["Cats are cool animals!"],
                           num_beams=4)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])

# should give 54 => google says either 44 or 51
Example #17
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration, BartForConditionalGeneration, RagConfig, DPRQuestionEncoder
import torch
from transformers.models.auto import AutoModel

config = RagConfig.from_pretrained("facebook/rag-token-nq")
config.index_name = "legacy"
config.use_dummy_dataset = False
config.question_encoder.return_dict = True
print("==> load tokenizer")
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
print("==> load retriever")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", config=config)
print("dataset info")
print(dir(retriever.index))
print("==> load generator")
# question encoder
# question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
# generator = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

# config = RagConfig.from_question_encoder_generator_configs(question_encoder.config, generator.config)
# model = RagTokenForGeneration(config, question_encoder=question_encoder,generator=generator, retriever=retriever)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

# input_dict = tokenizer.prepare_seq2seq_batch("USA president in 1999?", return_tensors="pt")
input_dict = tokenizer.prepare_seq2seq_batch("What kind of vehicle uses fire hydrant?", return_tensors="pt")
# input_dict = tokenizer.prepare_seq2seq_batch("what phylum does cat belong to?", return_tensors="pt")

print(input_dict.keys()) # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
input_ids = input_dict['input_ids']
print("==> encode")
question_hidden_states = model.question_encoder(input_ids)[0]
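
# A hypothetical continuation (mirroring the manual retrieval flow of Example #6, not
# part of the original snippet): retrieve documents for the encoded question, score
# them by embedding inner product, and generate an answer.
docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(),
                      return_tensors="pt")
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1),
                       docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
generated = model.generate(context_input_ids=docs_dict["context_input_ids"],
                           context_attention_mask=docs_dict["context_attention_mask"],
                           doc_scores=doc_scores)
print("==> answer")
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])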
Example #18
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq",
                                         dataset="wiki_dpr",
                                         index_name='compressed')
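# wiki_dpr's "compressed" index is the smaller, quantized FAISS index (less RAM,
# approximate search); index_name="exact" retrieves more precisely but needs far more memory.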
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq",
                                                 retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch(
    "how many countries are in europe", return_tensors="pt")

generated = model.generate(input_ids=input_dict["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])