def test_fill_mask(self):
        """Check that the masked-LM checkpoint predicts "happiness" for the ``[MASK]`` slot."""
        checkpoint = "google/bigbird-roberta-base"
        tokenizer = BigBirdTokenizer.from_pretrained(checkpoint)
        model = BigBirdForMaskedLM.from_pretrained(checkpoint)
        model.to(torch_device)

        encoding = tokenizer("The goal of life is [MASK] .", return_tensors="pt")
        input_ids = encoding.input_ids.to(torch_device)
        logits = model(input_ids).logits

        # The test reads position 6 of the sequence, where [MASK] is expected.
        # Slicing with 6:7 keeps the axis, so decode() gets a one-element tensor.
        mask_logits = logits[0, 6:7]
        predicted_id = torch.argmax(mask_logits, axis=-1)
        self.assertEqual(tokenizer.decode(predicted_id), "happiness")
    def test_tokenizer_inference_longformer_abstract(self):
        """Run block-sparse BigBird on a long abstract and compare an output slice.

        This method previously shared the name ``test_tokenizer_inference``
        with a later method of the same class. The later definition replaced
        it in the class namespace, so this test was never collected. It now
        has its own name so that both tests run.
        """
        tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
        model = BigBirdModel.from_pretrained(
            "google/bigbird-roberta-base",
            attention_type="block_sparse",
            num_random_blocks=3,
            block_size=16,
        )
        model.to(torch_device)

        text = [
            "Transformer-based models are unable to process long sequences due to their self-attention operation,"
            " which scales quadratically with the sequence length. To address this limitation, we introduce the"
            " Longformer with an attention mechanism that scales linearly with sequence length, making it easy to"
            " process documents of thousands of tokens or longer. Longformer’s attention mechanism is a drop-in"
            " replacement for the standard self-attention and combines a local windowed attention with a task"
            " motivated global attention. Following prior work on long-sequence transformers, we evaluate Longformer"
            " on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In"
            " contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream"
            " tasks. Our pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new"
            " state-of-the-art results on WikiHop and TriviaQA."
        ]
        inputs = tokenizer(text)

        # The tokenizer is called without return_tensors, so convert every
        # field to a long tensor on the test device by hand.
        for k in inputs:
            inputs[k] = torch.tensor(inputs[k], device=torch_device, dtype=torch.long)

        # Inference only: no gradients are needed for the comparison below.
        with torch.no_grad():
            prediction = model(**inputs)[0]

        self.assertEqual(prediction.shape, torch.Size((1, 199, 768)))

        expected_prediction = torch.tensor(
            [
                [-0.0213, -0.2213, -0.0061, 0.0687],
                [0.0977, 0.1858, 0.2374, 0.0483],
                [0.2112, -0.2524, 0.5793, 0.0967],
                [0.2473, -0.5070, -0.0630, 0.2174],
                [0.2885, 0.1139, 0.6071, 0.2991],
                [0.2328, -0.2373, 0.3648, 0.1058],
                [0.2517, -0.0689, 0.0555, 0.0880],
                [0.1021, -0.1495, -0.0635, 0.1891],
                [0.0591, -0.0722, 0.2243, 0.2432],
                [-0.2059, -0.2679, 0.3225, 0.6183],
                [0.2280, -0.2618, 0.1693, 0.0103],
                [0.0183, -0.1375, 0.2284, -0.1707],
            ],
            device=torch_device,
        )
        # Compare a 12 x 4 window (positions 52..63, features 320..323) of the
        # first batch element against the reference values.
        self.assertTrue(
            torch.allclose(prediction[0, 52:64, 320:324], expected_prediction, atol=1e-4)
        )
    def test_inference_question_answering_bigbird_context(self):
        """Check QA logits and decoded answers for two questions about a BigBird passage.

        This method previously shared the name
        ``test_inference_question_answering`` with a later method of the same
        class. The later definition replaced it in the class namespace, so this
        test was never collected. It now has its own name so that both tests
        run.
        """
        tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc")
        model = BigBirdForQuestionAnswering.from_pretrained(
            "google/bigbird-base-trivia-itc", attention_type="block_sparse", block_size=16, num_random_blocks=3
        )
        model.to(torch_device)

        context = "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and random attention approximates full attention, while being computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, BigBird has shown improved performance on various long document NLP tasks, such as question answering and summarization, compared to BERT or RoBERTa."

        question = [
            "Which is better for longer sequences- BigBird or BERT?",
            "What is the benefit of using BigBird over BERT?",
        ]
        # Each question is paired with the same context passage.
        inputs = tokenizer(
            question,
            [context, context],
            padding=True,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=256,
            truncation=True,
        )

        inputs = {k: v.to(torch_device) for k, v in inputs.items()}

        # Inference only: no gradients are needed for the comparisons below.
        with torch.no_grad():
            start_logits, end_logits = model(**inputs).to_tuple()

        # fmt: off
        target_start_logits = torch.tensor(
            [[-8.9304, -10.3849, -14.4997, -9.6497, -13.9469, -7.8134, -8.9687, -13.3585, -9.7987, -13.8869, -9.2632, -8.9294, -13.6721, -7.3198, -9.5434, -11.2641, -14.3245, -9.5705, -12.7367, -8.6168, -11.083, -13.7573, -8.1151, -14.5329, -7.6876, -15.706, -12.8558, -9.1135, 8.0909, -3.1925, -11.5812, -9.4822], [-11.5595, -14.5591, -10.2978, -14.8445, -10.2092, -11.1899, -13.8356, -10.5644, -14.7706, -9.9841, -11.0052, -14.1862, -8.8173, -11.1098, -12.4686, -15.0531, -11.0196, -13.6614, -10.0236, -11.8151, -14.8744, -9.5123, -15.1605, -8.6472, -15.4184, -8.898, -9.6328, -7.0258, -11.3365, -14.4065, -10.2587, -8.9103]],  # noqa: E231
            device=torch_device,
        )
        target_end_logits = torch.tensor(
            [[-12.4131, -8.5959, -15.7163, -11.1524, -15.9913, -12.2038, -7.8902, -16.0296, -12.164, -16.5017, -13.3332, -6.9488, -15.7756, -13.8506, -11.0779, -9.2893, -15.0426, -10.1963, -17.3292, -12.2945, -11.5337, -16.4514, -9.1564, -17.5001, -9.1562, -16.2971, -13.3199, -7.5724, -5.1175, 7.2168, -10.3804, -11.9873], [-10.8654, -14.9967, -11.4144, -16.9189, -14.2673, -9.7068, -15.0182, -12.8846, -16.8716, -13.665, -10.3113, -15.1436, -14.9069, -13.3364, -11.2339, -16.0118, -11.8331, -17.0613, -13.8852, -12.4163, -16.8978, -10.7772, -17.2324, -10.6979, -16.9811, -10.3427, -9.497, -13.7104, -11.1107, -13.2936, -13.855, -14.1264]],  # noqa: E231
            device=torch_device,
        )
        # fmt: on

        # Only columns 64..95 of each logits row are compared with the references.
        self.assertTrue(torch.allclose(start_logits[:, 64:96], target_start_logits, atol=1e-4))
        self.assertTrue(torch.allclose(end_logits[:, 64:96], target_end_logits, atol=1e-4))

        # The answer span runs from the arg-max start position to the arg-max
        # end position, inclusive. The arg-max is computed once per batch
        # rather than once per row.
        input_ids = inputs["input_ids"].tolist()
        start_positions = torch.argmax(start_logits, dim=-1).tolist()
        end_positions = torch.argmax(end_logits, dim=-1).tolist()
        answer = [
            ids[start : end + 1]
            for ids, start, end in zip(input_ids, start_positions, end_positions)
        ]
        answer = tokenizer.batch_decode(answer)

        # assertEqual reports both lists on failure, unlike assertTrue(a == b).
        self.assertEqual(answer, ["BigBird", "global attention"])
    def test_tokenizer_inference(self):
        """Feed punctuation-heavy text through block-sparse BigBird and compare an output slice."""
        checkpoint = "google/bigbird-roberta-base"
        tokenizer = BigBirdTokenizer.from_pretrained(checkpoint)
        model = BigBirdModel.from_pretrained(
            checkpoint, attention_type="block_sparse", num_random_blocks=3, block_size=16
        )
        model.to(torch_device)

        text = [
            'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth ... This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth ,, I was born in 92000, and this is falsé.'
        ]
        encoded = tokenizer(text)

        # The tokenizer is called without return_tensors, so each field is
        # turned into a long tensor on the test device here.
        model_inputs = {
            field: torch.tensor(values, device=torch_device, dtype=torch.long)
            for field, values in encoded.items()
        }

        first_output = model(**model_inputs)[0]
        self.assertEqual(first_output.shape, torch.Size((1, 128, 768)))

        reference_rows = [
            [-0.0745, 0.0689, -0.1126, -0.0610],
            [-0.0343, 0.0111, -0.0269, -0.0858],
            [0.1150, 0.0896, 0.0492, 0.0149],
            [-0.0657, 0.2035, 0.0444, -0.0535],
            [0.1143, 0.0465, 0.1583, -0.1855],
            [-0.0216, 0.0807, 0.0536, 0.1371],
            [-0.1879, 0.0097, -0.1916, 0.1701],
            [0.7616, 0.1240, 0.0669, 0.2588],
            [0.1096, -0.1810, -0.1987, 0.0445],
            [0.1810, -0.3608, -0.0081, 0.1764],
            [-0.0472, 0.0460, 0.0976, -0.0021],
            [-0.0274, -0.3274, -0.0788, 0.0465],
        ]
        expected_slice = torch.tensor(reference_rows, device=torch_device)

        # 12 x 4 window: positions 52..63, features 320..323 of the first batch element.
        actual_slice = first_output[0, 52:64, 320:324]
        self.assertTrue(torch.allclose(actual_slice, expected_slice, atol=1e-4))
    def test_inference_question_answering(self):
        """Check QA logits and decoded answers for two questions about one passage.

        Both questions are paired with the same context, run through the
        block-sparse question-answering checkpoint, and checked twice: a slice
        of the start/end logits against reference values, and the decoded
        arg-max answer spans against the expected strings.
        """
        tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc")
        model = BigBirdForQuestionAnswering.from_pretrained(
            "google/bigbird-base-trivia-itc",
            attention_type="block_sparse",
            block_size=16,
            num_random_blocks=3,
        )
        model.to(torch_device)

        context = "🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch. Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a question answering dataset is the SQuAD dataset"

        question = [
            "How many pretrained models are available in 🤗 Transformers?",
            "🤗 Transformers provides interoperability between which frameworks?",
        ]
        # Each question is paired with the same context passage.
        inputs = tokenizer(
            question,
            [context, context],
            padding=True,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=128,
            truncation=True,
        )

        inputs = {k: v.to(torch_device) for k, v in inputs.items()}

        # Inference only: no gradients are needed for the comparisons below.
        with torch.no_grad():
            start_logits, end_logits = model(**inputs).to_tuple()

        # fmt: off
        target_start_logits = torch.tensor(
            [[
                -9.5889, -10.2121, -14.2158, -11.1457, -10.7376, -7.3907,
                -10.2084, -9.5659, -15.0336, -8.6686, -9.1737, -11.1457,
                -13.4722, -6.3336, -9.6311, -8.4821, -15.141, -9.1226,
                -10.3328, -11.1457, -6.6793, -3.9627, 2.7126, -5.5607, -8.4625,
                -12.499, -11.4757, -9.6334, -4.0565, -10.0474, -7.4126,
                -13.5669
            ],
             [
                 -15.3796, -12.6863, -10.3951, -7.6706, -10.1808, -11.4401,
                 -15.5868, -12.7959, -11.0186, -12.6863, -14.2198, -8.1182,
                 -11.1353, -11.6512, -15.702, -12.8964, -12.5173, -12.6863,
                 -14.4133, -13.1532, -12.2846, -14.1572, -11.2747, -11.1159,
                 -11.5219, -13.1115, -11.8779, -13.989, -11.5234, -15.0459,
                 -10.0178, -12.9253
             ]],  # noqa: E231
            device=torch_device,
        )
        target_end_logits = torch.tensor(
            [[
                -12.4895, -10.9826, -13.8226, -11.9922, -13.2647, -12.4584,
                -10.6143, -9.4091, -16.844, -14.0393, -9.5914, -11.9922,
                -15.5142, -11.4073, -10.1064, -8.3961, -16.4374, -13.9323,
                -10.791, -11.9922, -8.736, -9.5672, 0.2844, -4.0976, -13.849,
                -11.8035, -12.7784, -14.1314, -7.4138, -10.5488, -8.0133,
                -14.8779
            ],
             [
                 -14.9831, -13.4818, -13.1566, -12.7259, -10.5892, -10.8605,
                 -17.2376, -15.9398, -12.8739, -13.4818, -16.6979, -13.3403,
                 -11.6416, -11.392, -16.9553, -15.723, -13.2643, -13.4818,
                 -16.2067, -15.6688, -15.0449, -15.1253, -15.1373, -12.385,
                 -13.3652, -15.9473, -14.9587, -15.5024, -13.1482, -16.6358,
                 -12.3908, -15.7493
             ]],  # noqa: E231
            device=torch_device,
        )
        # fmt: on

        # Only columns 64..95 of each logits row are compared with the references.
        self.assertTrue(torch.allclose(start_logits[:, 64:96], target_start_logits, atol=1e-4))
        self.assertTrue(torch.allclose(end_logits[:, 64:96], target_end_logits, atol=1e-4))

        # The answer span runs from the arg-max start position to the arg-max
        # end position, inclusive. The arg-max is computed once per batch
        # rather than once per row.
        input_ids = inputs["input_ids"].tolist()
        start_positions = torch.argmax(start_logits, dim=-1).tolist()
        end_positions = torch.argmax(end_logits, dim=-1).tolist()
        answer = [
            ids[start : end + 1]
            for ids, start, end in zip(input_ids, start_positions, end_positions)
        ]
        answer = tokenizer.batch_decode(answer)

        # assertEqual reports both lists on failure, unlike assertTrue(a == b).
        self.assertEqual(answer, ["32", "[SEP]"])