示例#1
0
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
    """Convert a TensorFlow BigBird checkpoint into a saved PyTorch model.

    Args:
        tf_checkpoint_path: Location of the TensorFlow checkpoint to read weights from.
        big_bird_config_file: JSON file from which the ``BigBirdConfig`` is built.
        pytorch_dump_path: Destination handed to ``save_pretrained``.
        is_trivia_qa: When truthy a ``BigBirdForQuestionAnswering`` model is built,
            otherwise a ``BigBirdForPreTraining`` model; the flag is also forwarded
            to the weight loader.
    """
    # Build the configuration and report it before instantiating anything.
    config = BigBirdConfig.from_json_file(big_bird_config_file)
    print(f"Building PyTorch model from configuration: {config!s}")

    # Pick the architecture that matches the checkpoint, then create it.
    model_class = BigBirdForQuestionAnswering if is_trivia_qa else BigBirdForPreTraining
    model = model_class(config)

    # Copy the TensorFlow weights into the freshly created PyTorch model.
    load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)

    # Persist the converted model.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
    def test_inference_question_answering(self):
        """Check logits and decoded answers of the TriviaQA BigBird checkpoint on a fixed batch.

        Loads "google/bigbird-base-trivia-itc" with block-sparse attention, runs two
        questions against the same context, compares the start/end logits at
        positions 64..95 with hard-coded reference values, and finally checks the
        text decoded from the argmax start/end span of each example.
        """
        tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc")
        model = BigBirdForQuestionAnswering.from_pretrained(
            "google/bigbird-base-trivia-itc", attention_type="block_sparse", block_size=16, num_random_blocks=3
        )
        model.to(torch_device)

        context = "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and random attention approximates full attention, while being computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, BigBird has shown improved performance on various long document NLP tasks, such as question answering and summarization, compared to BERT or RoBERTa."

        question = [
            "Which is better for longer sequences- BigBird or BERT?",
            "What is the benefit of using BigBird over BERT?",
        ]
        # Each question is paired with the same context; padding yields one rectangular batch.
        inputs = tokenizer(
            question,
            [context, context],
            padding=True,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=256,
            truncation=True,
        )

        inputs = {k: v.to(torch_device) for k, v in inputs.items()}

        start_logits, end_logits = model(**inputs).to_tuple()

        # fmt: off
        target_start_logits = torch.tensor(
            [[-8.9304, -10.3849, -14.4997, -9.6497, -13.9469, -7.8134, -8.9687, -13.3585, -9.7987, -13.8869, -9.2632, -8.9294, -13.6721, -7.3198, -9.5434, -11.2641, -14.3245, -9.5705, -12.7367, -8.6168, -11.083, -13.7573, -8.1151, -14.5329, -7.6876, -15.706, -12.8558, -9.1135, 8.0909, -3.1925, -11.5812, -9.4822], [-11.5595, -14.5591, -10.2978, -14.8445, -10.2092, -11.1899, -13.8356, -10.5644, -14.7706, -9.9841, -11.0052, -14.1862, -8.8173, -11.1098, -12.4686, -15.0531, -11.0196, -13.6614, -10.0236, -11.8151, -14.8744, -9.5123, -15.1605, -8.6472, -15.4184, -8.898, -9.6328, -7.0258, -11.3365, -14.4065, -10.2587, -8.9103]],  # noqa: E231
            device=torch_device,
        )
        target_end_logits = torch.tensor(
            [[-12.4131, -8.5959, -15.7163, -11.1524, -15.9913, -12.2038, -7.8902, -16.0296, -12.164, -16.5017, -13.3332, -6.9488, -15.7756, -13.8506, -11.0779, -9.2893, -15.0426, -10.1963, -17.3292, -12.2945, -11.5337, -16.4514, -9.1564, -17.5001, -9.1562, -16.2971, -13.3199, -7.5724, -5.1175, 7.2168, -10.3804, -11.9873], [-10.8654, -14.9967, -11.4144, -16.9189, -14.2673, -9.7068, -15.0182, -12.8846, -16.8716, -13.665, -10.3113, -15.1436, -14.9069, -13.3364, -11.2339, -16.0118, -11.8331, -17.0613, -13.8852, -12.4163, -16.8978, -10.7772, -17.2324, -10.6979, -16.9811, -10.3427, -9.497, -13.7104, -11.1107, -13.2936, -13.855, -14.1264]],  # noqa: E231
            device=torch_device,
        )
        # fmt: on

        # Only a 32-position window of the logits is compared against the references.
        self.assertTrue(torch.allclose(start_logits[:, 64:96], target_start_logits, atol=1e-4))
        self.assertTrue(torch.allclose(end_logits[:, 64:96], target_end_logits, atol=1e-4))

        input_ids = inputs["input_ids"].tolist()
        # Compute each argmax once for the whole batch (instead of once per example)
        # and convert to plain ints so they can be used directly as slice bounds.
        start_indices = torch.argmax(start_logits, dim=-1).tolist()
        end_indices = torch.argmax(end_logits, dim=-1).tolist()
        answer = [
            ids[start_idx : end_idx + 1]
            for ids, start_idx, end_idx in zip(input_ids, start_indices, end_indices)
        ]
        answer = tokenizer.batch_decode(answer)

        # assertEqual reports both lists on failure, unlike assertTrue(a == b).
        self.assertEqual(answer, ["BigBird", "global attention"])
 def create_and_check_for_question_answering(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Build a BigBird QA model from ``config`` and verify the shape of its span logits.

     ``sequence_labels`` is supplied as both the start and the end positions;
     ``token_labels`` and ``choice_labels`` are accepted but not used here.
     """
     qa_model = BigBirdForQuestionAnswering(config=config)
     qa_model.to(torch_device)
     qa_model.eval()

     forward_kwargs = {
         "attention_mask": input_mask,
         "token_type_ids": token_type_ids,
         "start_positions": sequence_labels,
         "end_positions": sequence_labels,
     }
     outputs = qa_model(input_ids, **forward_kwargs)

     # Start and end logits must both have shape (batch_size, seq_length).
     expected_shape = (self.batch_size, self.seq_length)
     for logits_name in ("start_logits", "end_logits"):
         self.parent.assertEqual(getattr(outputs, logits_name).shape, expected_shape)
    def test_inference_question_answering(self):
        """Check logits and decoded answers of the TriviaQA BigBird checkpoint on a fixed batch.

        Loads "google/bigbird-base-trivia-itc" with block-sparse attention, runs two
        questions against the same context (truncated to 128 tokens), compares the
        start/end logits at positions 64..95 with hard-coded reference values, and
        finally checks the text decoded from the argmax start/end span of each example.
        """
        tokenizer = BigBirdTokenizer.from_pretrained(
            "google/bigbird-base-trivia-itc")
        model = BigBirdForQuestionAnswering.from_pretrained(
            "google/bigbird-base-trivia-itc",
            attention_type="block_sparse",
            block_size=16,
            num_random_blocks=3)
        model.to(torch_device)

        context = "🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch. Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a question answering dataset is the SQuAD dataset"

        question = [
            "How many pretrained models are available in 🤗 Transformers?",
            "🤗 Transformers provides interoperability between which frameworks?",
        ]
        # Each question is paired with the same context; padding yields one rectangular batch.
        inputs = tokenizer(
            question,
            [context, context],
            padding=True,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=128,
            truncation=True,
        )

        inputs = {k: v.to(torch_device) for k, v in inputs.items()}

        start_logits, end_logits = model(**inputs).to_tuple()

        # fmt: off
        target_start_logits = torch.tensor(
            [[
                -9.5889, -10.2121, -14.2158, -11.1457, -10.7376, -7.3907,
                -10.2084, -9.5659, -15.0336, -8.6686, -9.1737, -11.1457,
                -13.4722, -6.3336, -9.6311, -8.4821, -15.141, -9.1226,
                -10.3328, -11.1457, -6.6793, -3.9627, 2.7126, -5.5607, -8.4625,
                -12.499, -11.4757, -9.6334, -4.0565, -10.0474, -7.4126,
                -13.5669
            ],
             [
                 -15.3796, -12.6863, -10.3951, -7.6706, -10.1808, -11.4401,
                 -15.5868, -12.7959, -11.0186, -12.6863, -14.2198, -8.1182,
                 -11.1353, -11.6512, -15.702, -12.8964, -12.5173, -12.6863,
                 -14.4133, -13.1532, -12.2846, -14.1572, -11.2747, -11.1159,
                 -11.5219, -13.1115, -11.8779, -13.989, -11.5234, -15.0459,
                 -10.0178, -12.9253
             ]],  # noqa: E231
            device=torch_device,
        )
        target_end_logits = torch.tensor(
            [[
                -12.4895, -10.9826, -13.8226, -11.9922, -13.2647, -12.4584,
                -10.6143, -9.4091, -16.844, -14.0393, -9.5914, -11.9922,
                -15.5142, -11.4073, -10.1064, -8.3961, -16.4374, -13.9323,
                -10.791, -11.9922, -8.736, -9.5672, 0.2844, -4.0976, -13.849,
                -11.8035, -12.7784, -14.1314, -7.4138, -10.5488, -8.0133,
                -14.8779
            ],
             [
                 -14.9831, -13.4818, -13.1566, -12.7259, -10.5892, -10.8605,
                 -17.2376, -15.9398, -12.8739, -13.4818, -16.6979, -13.3403,
                 -11.6416, -11.392, -16.9553, -15.723, -13.2643, -13.4818,
                 -16.2067, -15.6688, -15.0449, -15.1253, -15.1373, -12.385,
                 -13.3652, -15.9473, -14.9587, -15.5024, -13.1482, -16.6358,
                 -12.3908, -15.7493
             ]],  # noqa: E231
            device=torch_device,
        )
        # fmt: on

        # Only a 32-position window of the logits is compared against the references.
        self.assertTrue(
            torch.allclose(start_logits[:, 64:96],
                           target_start_logits,
                           atol=1e-4))
        self.assertTrue(
            torch.allclose(end_logits[:, 64:96], target_end_logits, atol=1e-4))

        input_ids = inputs["input_ids"].tolist()
        # Compute each argmax once for the whole batch (instead of once per example)
        # and convert to plain ints so they can be used directly as slice bounds.
        start_indices = torch.argmax(start_logits, dim=-1).tolist()
        end_indices = torch.argmax(end_logits, dim=-1).tolist()
        answer = [
            ids[start_idx:end_idx + 1]
            for ids, start_idx, end_idx in zip(input_ids, start_indices, end_indices)
        ]
        answer = tokenizer.batch_decode(answer)

        # assertEqual reports both lists on failure, unlike assertTrue(a == b).
        self.assertEqual(answer, ["32", "[SEP]"])