def test_inference_classification_head(self):
        """Compare the pretrained MNLI checkpoint's output with reference values.

        Loads ``squeezebert/squeezebert-mnli``, runs one fixed token sequence
        through it, and checks that the first element of the model output has
        shape ``(1, 3)`` and matches the stored reference tensor within an
        absolute tolerance of ``1e-4``.
        """
        model = SqueezeBertForSequenceClassification.from_pretrained("squeezebert/squeezebert-mnli")

        input_ids = torch.tensor([[1, 29414, 232, 328, 740, 1140, 12695, 69, 13, 1588, 2]])
        # Only forward values are compared, so skip autograd bookkeeping.
        with torch.no_grad():
            output = model(input_ids)[0]

        expected_shape = torch.Size((1, 3))
        self.assertEqual(output.shape, expected_shape)

        expected_tensor = torch.tensor([[0.6401, -0.0349, -0.6041]])
        # Include both tensors in the message so a failure shows the actual drift
        # instead of a bare "False is not true".
        self.assertTrue(
            torch.allclose(output, expected_tensor, atol=1e-4),
            msg=f"expected {expected_tensor}, got {output}",
        )
 def create_and_check_squeezebert_for_sequence_classification(
     self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Build a sequence-classification model and verify its logits shape.

     Sets ``config.num_labels`` from ``self.num_labels``, instantiates the
     model on ``torch_device`` in eval mode, and runs a forward pass with
     ``input_mask`` as the attention mask and ``sequence_labels`` as the
     labels. The resulting ``logits`` must have shape
     ``(self.batch_size, self.num_labels)``.

     ``token_labels`` and ``choice_labels`` are accepted but not used here.
     """
     config.num_labels = self.num_labels

     # ``to`` and ``eval`` both return the module, so the setup can be chained.
     classifier = SqueezeBertForSequenceClassification(config).to(torch_device).eval()

     outputs = classifier(input_ids, attention_mask=input_mask, labels=sequence_labels)

     expected_shape = (self.batch_size, self.num_labels)
     self.parent.assertEqual(outputs.logits.shape, expected_shape)