def test_inference_block_sparse_pretraining(self):
    """Check BigBirdForPreTraining outputs with block-sparse attention on a 4096-token input.

    Loads ``google/bigbird-roberta-base`` with ``attention_type="block_sparse"``,
    runs one forward pass, and compares the output shapes, a 4x4 slice of the
    prediction logits, and the sequence-relationship logits against reference
    values (``atol=1e-4``).
    """
    model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base", attention_type="block_sparse")
    model.to(torch_device)

    # A 4-token pattern repeated 1024 times -> a single sequence of 4096 tokens.
    input_ids = torch.tensor([[20920, 232, 328, 1437] * 1024], dtype=torch.long, device=torch_device)

    # Inference only: do not build/retain an autograd graph. With 4096 tokens and a
    # 50358-wide LM head, the graph would hold a large amount of memory for nothing.
    with torch.no_grad():
        outputs = model(input_ids)
    prediction_logits = outputs.prediction_logits
    seq_relationship_logits = outputs.seq_relationship_logits

    self.assertEqual(prediction_logits.shape, torch.Size((1, 1024 * 4, 50358)))
    self.assertEqual(seq_relationship_logits.shape, torch.Size((1, 2)))

    # Reference values for positions 128..131 and vocab ids 128..131.
    expected_prediction_logits_slice = torch.tensor(
        [
            [-0.2420, -0.6048, -0.0614, 7.8422],
            [-0.0596, -0.0104, -1.8408, 9.3352],
            [1.0588, 0.7999, 5.0770, 8.7555],
            [-0.1385, -1.7199, -1.7613, 6.1094],
        ],
        device=torch_device,
    )
    self.assertTrue(
        torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4)
    )

    expected_seq_relationship_logits = torch.tensor([[58.8196, 56.3629]], device=torch_device)
    self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4))
def test_inference_full_pretraining(self):
    """Check BigBirdForPreTraining outputs with full attention on a 2048-token input.

    Loads ``google/bigbird-roberta-base`` with ``attention_type="original_full"``,
    runs one forward pass, and compares the output shapes, a 4x4 slice of the
    prediction logits, and the sequence-relationship logits against reference
    values (``atol=1e-4``).
    """
    model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")
    model.to(torch_device)

    # A 4-token pattern repeated 512 times -> a single sequence of 2048 tokens.
    input_ids = torch.tensor([[20920, 232, 328, 1437] * 512], dtype=torch.long, device=torch_device)

    # Inference only: skip autograd bookkeeping, which would otherwise keep the
    # full-attention activations alive for a graph that is never used.
    with torch.no_grad():
        outputs = model(input_ids)
    prediction_logits = outputs.prediction_logits
    seq_relationship_logits = outputs.seq_relationship_logits

    self.assertEqual(prediction_logits.shape, torch.Size((1, 512 * 4, 50358)))
    self.assertEqual(seq_relationship_logits.shape, torch.Size((1, 2)))

    # Reference values for positions 128..131 and vocab ids 128..131.
    expected_prediction_logits_slice = torch.tensor(
        [
            [0.1499, -1.1217, 0.1990, 8.4499],
            [-2.7757, -3.0687, -4.8577, 7.5156],
            [1.5446, 0.1982, 4.3016, 10.4281],
            [-1.3705, -4.0130, -3.9629, 5.1526],
        ],
        device=torch_device,
    )
    self.assertTrue(
        torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4)
    )

    expected_seq_relationship_logits = torch.tensor([[41.4503, 41.2406]], device=torch_device)
    self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4))
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
    """Build a BigBird PyTorch model, load TensorFlow weights into it, and save it.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint passed to ``load_tf_weights_in_big_bird``.
        big_bird_config_file: JSON file read with ``BigBirdConfig.from_json_file``.
        pytorch_dump_path: Output location handed to ``save_pretrained``.
        is_trivia_qa: Truthy -> build ``BigBirdForQuestionAnswering``; otherwise
            build ``BigBirdForPreTraining``. Also forwarded to the weight loader.
    """
    config = BigBirdConfig.from_json_file(big_bird_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))

    # Choose the model head according to the trivia-QA flag.
    model_class = BigBirdForQuestionAnswering if is_trivia_qa else BigBirdForPreTraining
    model = model_class(config)

    load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)

    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
def create_and_check_for_pretraining(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Run BigBirdForPreTraining on tester-generated inputs and check output shapes.

    ``token_labels`` are passed as MLM ``labels`` and ``sequence_labels`` as
    ``next_sentence_label``. ``choice_labels`` is not used by this check.
    Assertions go through ``self.parent`` (the owning test case).
    """
    model = BigBirdForPreTraining(config=config)
    model.to(torch_device)
    model.eval()
    result = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
        next_sentence_label=sequence_labels,
    )
    self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
    # The next-sentence-prediction output is binary (is-next / not-next), so it has
    # 2 columns regardless of ``config.num_labels``. This matches the (1, 2) shape
    # checked in the pretrained inference tests.
    self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))
def test_model_from_pretrained(self):
    """Smoke-test that the first archived BigBird checkpoint loads via from_pretrained."""
    # Only the first archive entry is tried; an empty list makes this a no-op.
    checkpoints_to_try = BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
    for checkpoint in checkpoints_to_try:
        loaded = BigBirdForPreTraining.from_pretrained(checkpoint)
        self.assertIsNotNone(loaded)