Ejemplo n.º 1
0
    def test_inference_pretrained_batched(self):
        """Compare batched SEW inference against stored reference activations.

        Loads the ``asapp/sew-tiny-100k`` checkpoint, feeds it the padded batch
        returned by ``self._load_datasamples(2)``, and checks the leading and
        trailing 4x4 corners of ``last_hidden_state`` as well as its total sum
        against hard-coded reference values.
        """
        checkpoint = "asapp/sew-tiny-100k"
        model = SEWModel.from_pretrained(checkpoint).to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)

        # Pad the samples to a common length so they can be stacked into one batch.
        speech_batch = self._load_datasamples(2)
        encoded = feature_extractor(speech_batch, return_tensors="pt", padding=True)
        input_values = encoded.input_values.to(torch_device)

        # Gradients are not needed for a pure forward-pass comparison.
        with torch.no_grad():
            hidden_states = model(input_values).last_hidden_state

        # expected outputs taken from the original SEW implementation
        reference_first = [
            [
                [0.1509, 0.5372, 0.3061, -0.1694],
                [-0.1700, 0.5764, 0.2753, -0.1299],
                [0.1281, 0.7949, 0.2342, -0.1624],
                [-0.1627, 0.6710, 0.2215, -0.1317],
            ],
            [
                [0.0408, 1.4355, 0.8605, -0.0968],
                [0.0393, 1.2368, 0.6826, 0.0364],
                [-0.1269, 1.9215, 1.1677, -0.1297],
                [-0.1654, 1.6524, 0.6877, -0.0196],
            ],
        ]
        reference_last = [
            [
                [1.3379, -0.1450, -0.1500, -0.0515],
                [0.8364, -0.1680, -0.1248, -0.0689],
                [1.2791, -0.1507, -0.1523, -0.0564],
                [0.8208, -0.1690, -0.1199, -0.0751],
            ],
            [
                [0.6959, -0.0861, -0.1235, -0.0861],
                [0.4700, -0.1686, -0.1141, -0.1199],
                [1.0776, -0.1137, -0.0124, -0.0472],
                [0.5774, -0.1675, -0.0376, -0.0823],
            ],
        ]
        expected_outputs_first = torch.tensor(reference_first, device=torch_device)
        expected_outputs_last = torch.tensor(reference_last, device=torch_device)
        expected_output_sum = 62146.7422

        # First four positions / first four features of every batch entry.
        leading_corner = hidden_states[:, :4, :4]
        # Last four positions / last four features of every batch entry.
        trailing_corner = hidden_states[:, -4:, -4:]
        sum_deviation = abs(hidden_states.sum() - expected_output_sum)

        self.assertTrue(torch.allclose(leading_corner, expected_outputs_first, atol=5e-3))
        self.assertTrue(torch.allclose(trailing_corner, expected_outputs_last, atol=5e-3))
        self.assertTrue(sum_deviation < 5)
Ejemplo n.º 2
0
 def test_model_from_pretrained(self):
     """Check that the ``asapp/sew-tiny-100k`` checkpoint loads into a model object."""
     checkpoint = "asapp/sew-tiny-100k"
     loaded_model = SEWModel.from_pretrained(checkpoint)
     self.assertIsNotNone(loaded_model)