示例#1
0
def require_distributed_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform distributed retrieval with
    :class:`~transformers.RagRetriever`.

    The test is returned unchanged when every availability check (torch, datasets, faiss, psutil) succeeds;
    otherwise it is wrapped with :func:`unittest.skip`.

    """
    dependencies_present = (
        is_torch_available()
        and is_datasets_available()
        and is_faiss_available()
        and is_psutil_available()
    )
    if dependencies_present:
        return test_case
    return unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
示例#2
0
def require_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    :class:`~transformers.RagRetriever`.

    The test is returned unchanged when every availability check (tf, datasets, faiss) succeeds; otherwise it is
    wrapped with :func:`unittest.skip`.

    """
    retrieval_stack_ready = is_tf_available() and is_datasets_available() and is_faiss_available()
    if retrieval_stack_ready:
        return test_case
    return unittest.skip("test requires tensorflow, datasets and faiss")(test_case)
示例#3
0
def require_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    [`RagRetriever`].

    The test is returned unchanged when every availability check (torch, datasets, faiss) succeeds; otherwise it
    is wrapped with `unittest.skip`.

    """
    retrieval_stack_ready = (
        is_torch_available()
        and is_datasets_available()
        and is_faiss_available()
    )
    if retrieval_stack_ready:
        return test_case
    return unittest.skip("test requires PyTorch, datasets and faiss")(test_case)
示例#4
0
    def test_finetune_bert2bert(self):
        """
        Smoke-test fine-tuning of a BERT2BERT encoder-decoder model with `Seq2SeqTrainer`.

        A ``prajjwal1/bert-tiny`` encoder/decoder pair is trained on 32 CNN/DailyMail samples and evaluated on 16
        validation samples with a ROUGE-2 metric callback. The test makes no assertion on the metric values; it
        only checks that training runs end to end.

        The test is reported as skipped (rather than silently passing) when `datasets` is not available.
        """
        if not is_datasets_available():
            # A bare ``return`` would make the runner report this test as passed although nothing ran.
            self.skipTest("test requires datasets")

        import datasets

        bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
            "prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
        bert2bert.config.eos_token_id = tokenizer.sep_token_id
        bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
        bert2bert.config.max_length = 128

        train_dataset = datasets.load_dataset("cnn_dailymail",
                                              "3.0.0",
                                              split="train[:1%]")
        val_dataset = datasets.load_dataset("cnn_dailymail",
                                            "3.0.0",
                                            split="validation[:1%]")

        # keep the run small: only a handful of samples are needed for a smoke test
        train_dataset = train_dataset.select(range(32))
        val_dataset = val_dataset.select(range(16))

        rouge = datasets.load_metric("rouge")

        batch_size = 4
        encoder_max_length = 512
        decoder_max_length = 128
        model_columns = [
            "input_ids", "attention_mask", "decoder_input_ids",
            "decoder_attention_mask", "labels"
        ]

        def _map_to_encoder_decoder_inputs(batch):
            # The tokenizer adds its own special tokens around the text.
            inputs = tokenizer(batch["article"],
                               padding="max_length",
                               truncation=True,
                               max_length=encoder_max_length)
            outputs = tokenizer(batch["highlights"],
                                padding="max_length",
                                truncation=True,
                                max_length=decoder_max_length)
            batch["input_ids"] = inputs.input_ids
            batch["attention_mask"] = inputs.attention_mask

            batch["decoder_input_ids"] = outputs.input_ids
            # Padding positions are replaced by -100 in the labels.
            batch["labels"] = [[
                -100 if token == tokenizer.pad_token_id else token
                for token in labels
            ] for labels in outputs.input_ids]
            batch["decoder_attention_mask"] = outputs.attention_mask

            assert all(
                len(x) == encoder_max_length for x in inputs.input_ids)
            assert all(
                len(x) == decoder_max_length for x in outputs.input_ids)

            return batch

        def _compute_metrics(pred):
            labels_ids = pred.label_ids
            pred_ids = pred.predictions

            # all unnecessary tokens are removed
            pred_str = tokenizer.batch_decode(pred_ids,
                                              skip_special_tokens=True)
            label_str = tokenizer.batch_decode(labels_ids,
                                               skip_special_tokens=True)

            rouge_output = rouge.compute(predictions=pred_str,
                                         references=label_str,
                                         rouge_types=["rouge2"])["rouge2"].mid

            return {
                "rouge2_precision": round(rouge_output.precision, 4),
                "rouge2_recall": round(rouge_output.recall, 4),
                "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
            }

        def _prepare(dataset):
            # Tokenize the raw text columns and expose the result as torch tensors.
            dataset = dataset.map(
                _map_to_encoder_decoder_inputs,
                batched=True,
                batch_size=batch_size,
                remove_columns=["article", "highlights"],
            )
            dataset.set_format(type="torch", columns=model_columns)
            return dataset

        # identical preprocessing for the train and validation splits
        train_dataset = _prepare(train_dataset)
        val_dataset = _prepare(val_dataset)

        output_dir = self.get_auto_remove_tmp_dir()

        training_args = Seq2SeqTrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            predict_with_generate=True,
            evaluation_strategy="steps",
            do_train=True,
            do_eval=True,
            warmup_steps=0,
            eval_steps=2,
            logging_steps=2,
        )

        # instantiate trainer
        trainer = Seq2SeqTrainer(
            model=bert2bert,
            args=training_args,
            compute_metrics=_compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
        )

        # start training
        trainer.train()
示例#5
0
class RagTestMixin:
    """
    Shared fixtures and checks for the RAG model tests.

    ``setUp`` writes small DPR, BART and T5 tokenizer files to a temporary directory, ``get_retriever`` builds a
    `RagRetriever` over a three-document in-memory dataset, and the ``check_*`` methods run the model classes in
    ``all_model_classes`` and verify output shapes.

    The ``test_*`` methods read ``self.config_and_inputs``, which is not defined in this mixin and has to be
    supplied by the concrete test class as a mapping whose entries are unpacked into the ``check_*`` methods.
    """

    all_model_classes = ((RagModel, RagTokenForGeneration,
                          RagSequenceForGeneration)
                         if is_torch_available() and is_datasets_available()
                         and is_faiss_available() else ())

    retrieval_vector_size = 32
    n_docs = 3
    max_combined_length = 16

    def setUp(self):
        """Create a temporary directory holding DPR, BART and T5 tokenizer files."""
        self.tmpdirname = tempfile.mkdtemp()

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path,
                                       DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = [
            "#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""
        ]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        # NOTE: this overwrites the DPR path stored in ``self.vocab_file`` above.
        self.vocab_file = os.path.join(bart_tokenizer_path,
                                       BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path,
                                        BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        # T5 tok
        t5_tokenizer = T5Tokenizer(T5_SAMPLE_VOCAB)
        t5_tokenizer_path = os.path.join(self.tmpdirname, "t5_tokenizer")
        t5_tokenizer.save_pretrained(t5_tokenizer_path)

    @cached_property
    def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        """DPR question-encoder tokenizer loaded from the directory written in ``setUp``."""
        return DPRQuestionEncoderTokenizer.from_pretrained(
            os.path.join(self.tmpdirname, "dpr_tokenizer"))

    @cached_property
    def bart_tokenizer(self) -> BartTokenizer:
        """BART tokenizer loaded from the directory written in ``setUp``."""
        return BartTokenizer.from_pretrained(
            os.path.join(self.tmpdirname, "bart_tokenizer"))

    @cached_property
    def t5_tokenizer(self) -> T5Tokenizer:
        """T5 tokenizer loaded from the directory written in ``setUp``."""
        return T5Tokenizer.from_pretrained(
            os.path.join(self.tmpdirname, "t5_tokenizer"))

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        shutil.rmtree(self.tmpdirname)

    def get_retriever(self, config):
        """
        Build a `RagRetriever` over a three-document in-memory dataset with a flat inner-product faiss index.

        ``transformers.retrieval_rag.load_dataset`` is patched while the retriever is constructed so that any call
        to it returns the in-memory dataset. The generator tokenizer is the BART tokenizer when
        ``config.generator.model_type == "bart"`` and the T5 tokenizer otherwise.
        """
        dataset = Dataset.from_dict({
            "id": ["0", "1", "3"],
            "text": ["foo", "bar", "qux"],
            "title": ["Foo", "Bar", "Qux"],
            "embeddings": [
                np.ones(self.retrieval_vector_size),
                2 * np.ones(self.retrieval_vector_size),
                3 * np.ones(self.retrieval_vector_size),
            ],
        })
        dataset.add_faiss_index("embeddings",
                                string_factory="Flat",
                                metric_type=faiss.METRIC_INNER_PRODUCT)
        tokenizer = self.bart_tokenizer if config.generator.model_type == "bart" else self.t5_tokenizer
        with patch("transformers.retrieval_rag.load_dataset"
                   ) as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.dpr_tokenizer,
                generator_tokenizer=tokenizer,
            )
        return retriever

    def _retrieve_context(self, config, model, retriever, input_ids,
                          attention_mask, n_docs=None):
        """
        Encode the question, call ``retriever`` and compute the document scores.

        ``n_docs`` is forwarded to the retriever only when it is not ``None``. Returns the tuple
        ``(context_input_ids, context_attention_mask, doc_scores)``.
        """
        question_hidden_states = model.question_encoder(
            input_ids, attention_mask=attention_mask)[0]

        retriever_kwargs = {
            "prefix": config.generator.prefix,
            "return_tensors": "pt",
        }
        if n_docs is not None:
            retriever_kwargs["n_docs"] = n_docs

        # the retriever is given the hidden states as a float32 numpy array
        out = retriever(
            input_ids,
            question_hidden_states.cpu().detach().to(torch.float32).numpy(),
            **retriever_kwargs,
        )

        # cast
        retrieved_doc_embeds = out["retrieved_doc_embeds"].to(
            question_hidden_states)
        context_input_ids = out["context_input_ids"].to(input_ids)
        context_attention_mask = out["context_attention_mask"].to(input_ids)

        # compute doc_scores
        doc_scores = torch.bmm(question_hidden_states.unsqueeze(1),
                               retrieved_doc_embeds.transpose(1,
                                                              2)).squeeze(1)

        return context_input_ids, context_attention_mask, doc_scores

    def _assert_output_shapes(self, outputs, config, input_ids,
                              decoder_input_ids, n_docs):
        """Check the shapes of ``logits``, ``generator_enc_last_hidden_state`` and ``doc_scores``."""
        # logits
        self.assertEqual(
            outputs.logits.shape,
            (n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1],
             config.generator.vocab_size),
        )
        # generator encoder last hidden states
        self.assertEqual(
            outputs.generator_enc_last_hidden_state.shape,
            (n_docs * decoder_input_ids.shape[0], self.max_combined_length,
             config.generator.hidden_size),
        )
        # doc scores
        self.assertEqual(outputs.doc_scores.shape,
                         (input_ids.shape[0], n_docs))

    def check_model_with_retriever(self, config, input_ids, attention_mask,
                                   decoder_input_ids, decoder_attention_mask,
                                   **kwargs):
        """Run every model class with an attached retriever and check the output shapes."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(
                config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids,
                                       decoder_input_ids, self.n_docs)

    def check_model_generate(self, config, input_ids, attention_mask,
                             decoder_input_ids, decoder_attention_mask,
                             **kwargs):
        """Call ``generate`` on every model class except the first entry and check that it returns something."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        # the first entry of ``all_model_classes`` (RagModel) is skipped here
        for model_class in self.all_model_classes[1:]:
            model = model_class(
                config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model.generate(
                input_ids=input_ids,
                num_beams=2,
                num_return_sequences=2,
                decoder_start_token_id=config.generator.eos_token_id,
            )

            self.assertIsNotNone(outputs)

    def check_model_without_retriever(self, config, input_ids, attention_mask,
                                      decoder_input_ids,
                                      decoder_attention_mask, **kwargs):
        """Run every model class without an attached retriever, feeding it externally retrieved context."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                config, model, retriever, input_ids, attention_mask)

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids,
                                       decoder_input_ids, self.n_docs)

    def check_model_custom_n_docs(self, config, input_ids, attention_mask,
                                  decoder_input_ids, decoder_attention_mask,
                                  n_docs, **kwargs):
        """Same as ``check_model_without_retriever`` but with ``n_docs`` passed to both retriever and model."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                config,
                model,
                retriever,
                input_ids,
                attention_mask,
                n_docs=n_docs)

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                n_docs=n_docs,
            )

            self._assert_output_shapes(outputs, config, input_ids,
                                       decoder_input_ids, n_docs)

    def check_model_with_mismatch_n_docs_value(self, config, input_ids,
                                               attention_mask,
                                               decoder_input_ids,
                                               decoder_attention_mask,
                                               retriever_n_docs,
                                               generator_n_docs, **kwargs):
        """Check that the model raises ``AssertionError`` when its ``n_docs`` differs from the retriever's."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                config,
                model,
                retriever,
                input_ids,
                attention_mask,
                n_docs=retriever_n_docs)

            with self.assertRaises(AssertionError):
                model(
                    context_input_ids=context_input_ids,
                    context_attention_mask=context_attention_mask,
                    doc_scores=doc_scores,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    n_docs=generator_n_docs,
                )

    def check_model_with_encoder_outputs(self, config, input_ids,
                                         attention_mask, decoder_input_ids,
                                         decoder_attention_mask, **kwargs):
        """Run a full forward pass, then re-run the model from the generator encoder output of that pass."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(
                config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            encoder_outputs = BaseModelOutput(
                outputs.generator_enc_last_hidden_state)

            # run only generator
            outputs = model(
                encoder_outputs=encoder_outputs,
                doc_scores=outputs.doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids,
                                       decoder_input_ids, self.n_docs)

    def test_model_with_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_retriever(**inputs_dict)

    def test_model_without_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_without_retriever(**inputs_dict)

    def test_model_with_encoder_outputs(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_encoder_outputs(**inputs_dict)

    def test_model_generate(self):
        inputs_dict = self.config_and_inputs
        self.check_model_generate(**inputs_dict)

    def test_model_with_custom_n_docs(self):
        # build a copy so the mapping returned by ``config_and_inputs`` is not mutated
        inputs_dict = {**self.config_and_inputs, "n_docs": 1}
        self.check_model_custom_n_docs(**inputs_dict)

    def test_model_with_mismatch_n_docs_value(self):
        # build a copy so the mapping returned by ``config_and_inputs`` is not mutated
        inputs_dict = {
            **self.config_and_inputs,
            "retriever_n_docs": 3,
            "generator_n_docs": 2,
        }
        self.check_model_with_mismatch_n_docs_value(**inputs_dict)
示例#6
0
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES

from .test_modeling_bart import ModelTester as BartModelTester
from .test_modeling_dpr import DPRModelTester
from .test_modeling_t5 import T5ModelTester

# Numeric tolerance constant. It is not referenced in the visible part of this file;
# presumably used for approximate float comparisons in the tests (TODO confirm).
TOLERANCE = 1e-3

# Path to the SentencePiece model fixture, resolved relative to this file's directory.
T5_SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               "fixtures/test_sentencepiece.model")

if is_torch_available() and is_datasets_available() and is_faiss_available():
    import torch
    from datasets import Dataset

    import faiss
    from transformers import (
        AutoConfig,
        AutoModel,
        AutoModelForSeq2SeqLM,
        RagConfig,
        RagModel,
        RagRetriever,
        RagSequenceForGeneration,
        RagTokenForGeneration,
        RagTokenizer,
    )
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.file_utils import is_datasets_available
from transformers.testing_utils import TestCasePlus, require_datasets, slow


if is_datasets_available():
    import datasets


class Seq2seqTrainerTester(TestCasePlus):
    @slow
    @require_datasets
    def test_finetune_bert2bert(self):

        bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
        bert2bert.config.eos_token_id = tokenizer.sep_token_id
        bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
        bert2bert.config.max_length = 128