def test_inference_classification_head(self):
    """Compare the pretrained ``squeezebert/squeezebert-mnli`` head against reference logits.

    Loads the checkpoint and feeds it one 11-token sequence. It then checks that
    the first element of the model output (presumably the classification logits)
    has shape ``(1, 3)`` and matches recorded values to within ``atol=1e-4``.
    """
    model = SqueezeBertForSequenceClassification.from_pretrained("squeezebert/squeezebert-mnli")
    # Make inference mode explicit rather than relying on the loader's default;
    # the recorded reference values assume deterministic (non-training) behaviour.
    model.eval()
    input_ids = torch.tensor([[1, 29414, 232, 328, 740, 1140, 12695, 69, 13, 1588, 2]])
    # Pure inference: disable autograd so no backward graph is recorded.
    with torch.no_grad():
        output = model(input_ids)[0]
    expected_shape = torch.Size((1, 3))
    self.assertEqual(output.shape, expected_shape)
    expected_tensor = torch.tensor([[0.6401, -0.0349, -0.6041]])
    self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
def create_and_check_squeezebert_for_sequence_classification(
    self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Build a sequence-classification model from ``config`` and check its logit shape.

    Args:
        config: Model config. It is mutated in place: ``num_labels`` is set to
            ``self.num_labels``.
        input_ids: Token ids fed to the model.
        input_mask: Passed to the model as ``attention_mask``.
        sequence_labels: Passed to the model as ``labels``.
        token_labels: Unused here; presumably accepted so all ``create_and_check_*``
            helpers share one signature.
        choice_labels: Unused here; same reason as ``token_labels``.

    Asserts (via ``self.parent``) that ``logits`` has shape
    ``(self.batch_size, self.num_labels)``.
    """
    config.num_labels = self.num_labels
    # ``.to`` and ``.eval`` both return the module itself, so they can be chained.
    classifier = SqueezeBertForSequenceClassification(config).to(torch_device).eval()
    outputs = classifier(input_ids, attention_mask=input_mask, labels=sequence_labels)
    wanted_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, wanted_shape)