def test_inference_block_sparse_pretraining(self):
        """Check pretraining-head outputs under block-sparse attention.

        Loads the ``google/bigbird-roberta-base`` checkpoint with
        ``attention_type="block_sparse"``, feeds it one 4096-token sequence,
        and compares the output shapes, a 4x4 slice of the prediction logits
        and the sequence-relationship logits with hard-coded expected values
        (``atol=1e-4``).
        """
        model = BigBirdForPreTraining.from_pretrained(
            "google/bigbird-roberta-base", attention_type="block_sparse")
        model.to(torch_device)

        # One batch entry made of the same four token ids repeated 1024 times.
        input_ids = torch.tensor([[20920, 232, 328, 1437] * 1024],
                                 dtype=torch.long,
                                 device=torch_device)
        # Inference only: no gradients are needed, so do not let autograd
        # record (and keep alive) the graph of this long forward pass.
        with torch.no_grad():
            outputs = model(input_ids)
        prediction_logits = outputs.prediction_logits
        seq_relationship_logits = outputs.seq_relationship_logits

        self.assertEqual(prediction_logits.shape, torch.Size((1, 4096, 50358)))
        self.assertEqual(seq_relationship_logits.shape, torch.Size((1, 2)))

        expected_prediction_logits_slice = torch.tensor(
            [
                [-0.2420, -0.6048, -0.0614, 7.8422],
                [-0.0596, -0.0104, -1.8408, 9.3352],
                [1.0588, 0.7999, 5.0770, 8.7555],
                [-0.1385, -1.7199, -1.7613, 6.1094],
            ],
            device=torch_device,
        )
        # Only positions 128..131 x vocabulary ids 128..131 are compared.
        self.assertTrue(
            torch.allclose(prediction_logits[0, 128:132, 128:132],
                           expected_prediction_logits_slice,
                           atol=1e-4))

        expected_seq_relationship_logits = torch.tensor([[58.8196, 56.3629]],
                                                        device=torch_device)
        self.assertTrue(
            torch.allclose(seq_relationship_logits,
                           expected_seq_relationship_logits,
                           atol=1e-4))
    def test_inference_full_pretraining(self):
        """Check pretraining-head outputs under full attention.

        Loads the ``google/bigbird-roberta-base`` checkpoint with
        ``attention_type="original_full"``, feeds it one 2048-token sequence,
        and compares the output shapes, a 4x4 slice of the prediction logits
        and the sequence-relationship logits with hard-coded expected values
        (``atol=1e-4``).
        """
        model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")
        model.to(torch_device)

        # One batch entry made of the same four token ids repeated 512 times.
        input_ids = torch.tensor([[20920, 232, 328, 1437] * 512], dtype=torch.long, device=torch_device)
        # Inference only: no gradients are needed, so do not let autograd
        # record (and keep alive) the graph of this forward pass.
        with torch.no_grad():
            outputs = model(input_ids)
        prediction_logits = outputs.prediction_logits
        seq_relationship_logits = outputs.seq_relationship_logits

        self.assertEqual(prediction_logits.shape, torch.Size((1, 512 * 4, 50358)))
        self.assertEqual(seq_relationship_logits.shape, torch.Size((1, 2)))

        expected_prediction_logits_slice = torch.tensor(
            [
                [0.1499, -1.1217, 0.1990, 8.4499],
                [-2.7757, -3.0687, -4.8577, 7.5156],
                [1.5446, 0.1982, 4.3016, 10.4281],
                [-1.3705, -4.0130, -3.9629, 5.1526],
            ],
            device=torch_device,
        )
        # Only positions 128..131 x vocabulary ids 128..131 are compared.
        self.assertTrue(
            torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4)
        )

        expected_seq_relationship_logits = torch.tensor([[41.4503, 41.2406]], device=torch_device)
        self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4))
Esempio n. 3
0
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
    """Build a BigBird PyTorch model, fill it from a TF checkpoint and save it.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint location, forwarded to
            ``load_tf_weights_in_big_bird``.
        big_bird_config_file: JSON file read via ``BigBirdConfig.from_json_file``.
        pytorch_dump_path: Destination passed to ``model.save_pretrained``.
        is_trivia_qa: When truthy, a ``BigBirdForQuestionAnswering`` model is
            built instead of ``BigBirdForPreTraining``; the flag is also
            forwarded to the weight loader.
    """
    # Instantiate the (still untrained) PyTorch model described by the config.
    config = BigBirdConfig.from_json_file(big_bird_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))

    model_class = BigBirdForQuestionAnswering if is_trivia_qa else BigBirdForPreTraining
    model = model_class(config)

    # Copy the TensorFlow weights into the freshly built model.
    load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)

    # Persist the converted model.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
 def create_and_check_for_pretraining(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Run ``BigBirdForPreTraining`` once and check both logit shapes.

     The model is built from ``config``, moved to ``torch_device`` and put in
     eval mode. ``token_labels`` are passed as ``labels`` and
     ``sequence_labels`` as ``next_sentence_label``; ``choice_labels`` is
     accepted but not used here.
     """
     pretraining_model = BigBirdForPreTraining(config=config)
     pretraining_model.to(torch_device)
     pretraining_model.eval()

     forward_kwargs = {
         "attention_mask": input_mask,
         "token_type_ids": token_type_ids,
         "labels": token_labels,
         "next_sentence_label": sequence_labels,
     }
     outputs = pretraining_model(input_ids, **forward_kwargs)

     # Token-level head: one score per vocabulary entry at every position.
     prediction_shape = outputs.prediction_logits.shape
     self.parent.assertEqual(prediction_shape, (self.batch_size, self.seq_length, self.vocab_size))

     # Sequence-level head: one row of ``config.num_labels`` scores per batch entry.
     relationship_shape = outputs.seq_relationship_logits.shape
     self.parent.assertEqual(relationship_shape, (self.batch_size, config.num_labels))
 def test_model_from_pretrained(self):
     """Load the first listed pretrained checkpoint and check a model is returned."""
     # Only the first entry of the archive list is exercised.
     checkpoint_names = BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
     for checkpoint_name in checkpoint_names:
         loaded_model = BigBirdForPreTraining.from_pretrained(checkpoint_name)
         self.assertIsNotNone(loaded_model)