def test_inference_block_sparse_pretraining(self):
        """Check the block-sparse BigBird pretraining outputs against references.

        Loads ``google/bigbird-roberta-base`` with
        ``attention_type="block_sparse"``, runs a single sequence of 4096 token
        ids (a 4-id pattern repeated 1024 times) through it, and verifies the
        shapes of both logit tensors plus a 4x4 window of the prediction
        logits and the full sequence-relationship logits.
        """
        bigbird = BigBirdForPreTraining.from_pretrained(
            "google/bigbird-roberta-base",
            attention_type="block_sparse",
        )
        bigbird.to(torch_device)

        # One batch entry: the same four ids tiled to a length of 4 * 1024.
        token_pattern = [20920, 232, 328, 1437]
        batch = torch.tensor([token_pattern * 1024], dtype=torch.long, device=torch_device)

        result = bigbird(batch)
        pred_logits = result.prediction_logits
        seq_rel_logits = result.seq_relationship_logits

        self.assertEqual(pred_logits.shape, torch.Size((1, 4096, 50358)))
        self.assertEqual(seq_rel_logits.shape, torch.Size((1, 2)))

        # Hard-coded reference values for positions/columns 128..131.
        # NOTE(review): presumably recorded from a known-good run - confirm.
        reference_window = torch.tensor(
            [
                [-0.2420, -0.6048, -0.0614, 7.8422],
                [-0.0596, -0.0104, -1.8408, 9.3352],
                [1.0588, 0.7999, 5.0770, 8.7555],
                [-0.1385, -1.7199, -1.7613, 6.1094],
            ],
            device=torch_device,
        )
        window = slice(128, 132)
        window_matches = torch.allclose(pred_logits[0, window, window], reference_window, atol=1e-4)
        self.assertTrue(window_matches)

        reference_seq_rel = torch.tensor([[58.8196, 56.3629]], device=torch_device)
        seq_rel_matches = torch.allclose(seq_rel_logits, reference_seq_rel, atol=1e-4)
        self.assertTrue(seq_rel_matches)
# Example no. 2
# 0
    def test_inference_full_pretraining(self):
        """Check the full-attention BigBird pretraining outputs against references.

        Loads ``google/bigbird-roberta-base`` with
        ``attention_type="original_full"``, runs a single sequence of 2048
        token ids (a 4-id pattern repeated 512 times) through it, and verifies
        the shapes of both logit tensors plus a 4x4 window of the prediction
        logits and the full sequence-relationship logits.
        """
        checkpoint = "google/bigbird-roberta-base"
        bigbird = BigBirdForPreTraining.from_pretrained(checkpoint, attention_type="original_full")
        bigbird.to(torch_device)

        # One batch entry: the same four ids tiled to a length of 4 * 512.
        token_pattern = [20920, 232, 328, 1437]
        batch = torch.tensor(
            [token_pattern * 512],
            dtype=torch.long,
            device=torch_device,
        )

        result = bigbird(batch)
        pred_logits = result.prediction_logits
        seq_rel_logits = result.seq_relationship_logits

        self.assertEqual(pred_logits.shape, torch.Size((1, 2048, 50358)))
        self.assertEqual(seq_rel_logits.shape, torch.Size((1, 2)))

        # Hard-coded reference values for positions/columns 128..131.
        # NOTE(review): presumably recorded from a known-good run - confirm.
        reference_rows = [
            [0.1499, -1.1217, 0.1990, 8.4499],
            [-2.7757, -3.0687, -4.8577, 7.5156],
            [1.5446, 0.1982, 4.3016, 10.4281],
            [-1.3705, -4.0130, -3.9629, 5.1526],
        ]
        reference_window = torch.tensor(reference_rows, device=torch_device)
        observed_window = pred_logits[0, 128:132, 128:132]
        self.assertTrue(torch.allclose(observed_window, reference_window, atol=1e-4))

        reference_seq_rel = torch.tensor([[41.4503, 41.2406]], device=torch_device)
        self.assertTrue(torch.allclose(seq_rel_logits, reference_seq_rel, atol=1e-4))
 def test_model_from_pretrained(self):
     """Load the first listed pretrained checkpoint and check it is not None.

     Only the first entry of ``BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST`` is
     tried; if the list is empty the loop body never runs.
     """
     first_checkpoints = BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
     for checkpoint in first_checkpoints:
         self.assertIsNotNone(BigBirdForPreTraining.from_pretrained(checkpoint))