def test_recreate_finetuned_model(self, config, factory_func):
        """A factory-built model loaded with imported weights matches the imported finetuned model.

        The fairseq model built from ``config`` is converted with
        ``import_fairseq_model``. Its ``state_dict`` is then loaded into a fresh
        model created by ``factory_func(num_out=...)``, which never receives the
        fairseq object. Both models must return equal outputs when called with
        the waveform alone and when called with explicit lengths.
        """
        num_out = 28
        batch_size = 3
        num_frames = 1024

        fairseq_model = self._get_model(config, num_out).eval()
        imported = import_fairseq_model(fairseq_model).eval()

        reloaded = factory_func(num_out=num_out)
        reloaded.load_state_dict(imported.state_dict())
        reloaded.eval()

        # Without mask: only the waveform batch is passed.
        torch.manual_seed(0)
        waveforms = torch.randn(batch_size, num_frames)
        expected, _ = imported(waveforms)
        actual, _ = reloaded(waveforms)
        self.assertEqual(expected, actual)

        # With mask: random per-sample lengths are passed as the second argument,
        # and the lengths returned by both models must agree as well.
        lengths = torch.randint(low=0, high=num_frames, size=(batch_size,))
        expected, expected_lengths = imported(waveforms, lengths)
        actual, actual_lengths = reloaded(waveforms, lengths)
        self.assertEqual(expected, actual)
        self.assertEqual(expected_lengths, actual_lengths)
# Example #2
# 0
def _import_model(model):
    """Convert a fairseq model with torchaudio's ``import_fairseq_model``.

    When ``model`` is a ``HubertCtc`` or ``Wav2VecCtc`` instance (matched by
    class name), its ``w2v_encoder`` attribute is converted instead of the
    wrapper object itself.

    Returns:
        Whatever ``import_fairseq_model`` returns for the selected module.
    """
    # Function-scope import, so torchaudio is only required when this is called.
    from torchaudio.models.wav2vec2.utils import import_fairseq_model

    ctc_wrapper_names = ('HubertCtc', 'Wav2VecCtc')
    is_ctc_wrapper = model.__class__.__name__ in ctc_wrapper_names
    target = model.w2v_encoder if is_ctc_wrapper else model
    return import_fairseq_model(target)
    def test_recreate_pretraining_model(self, config, factory_func):
        """A factory-built model loaded with imported weights matches the imported pretraining model.

        The fairseq model built from ``config`` is converted with
        ``import_fairseq_model``. Its ``state_dict`` is then loaded into a fresh
        model created by ``factory_func()``, which never receives the fairseq
        object. Both models must return equal outputs when called with the
        waveform alone and when called with explicit lengths.
        """
        batch_size = 3
        num_frames = 1024

        fairseq_model = self._get_model(config).eval()
        imported = import_fairseq_model(fairseq_model).eval()

        reloaded = factory_func()
        reloaded.load_state_dict(imported.state_dict())
        reloaded.eval()

        # Both random inputs are drawn up front, before any forward pass.
        waveforms = torch.randn(batch_size, num_frames)
        lengths = torch.randint(low=0, high=num_frames, size=(batch_size,))

        # Without mask: only the waveform batch is passed.
        expected, _ = imported(waveforms)
        actual, _ = reloaded(waveforms)
        self.assertEqual(expected, actual)

        # With mask: lengths are passed too, and the returned lengths must agree.
        expected, expected_lengths = imported(waveforms, lengths)
        actual, actual_lengths = reloaded(waveforms, lengths)
        self.assertEqual(expected, actual)
        self.assertEqual(expected_lengths, actual_lengths)
    def test_import_hubert_pretraining_model(self, config, factory_func):
        """HuBERT pretraining models from fairseq can be imported and yield the same results.

        The first and last entries returned by the imported model's
        ``extract_features`` are compared with the fairseq model's
        ``extract_features`` output for ``output_layer=1`` and
        ``output_layer=len(encoder.layers)`` respectively.
        """
        batch_size = 3
        num_frames = 1024

        torch.manual_seed(0)
        fairseq_model = self._get_model(config).eval()
        imported = import_fairseq_model(fairseq_model).eval()

        waveforms = torch.randn(batch_size, num_frames)
        padding_mask = torch.zeros_like(waveforms)  # all zeros: nothing is marked as padding
        layer_outputs, _ = imported.extract_features(waveforms)

        # Last layer: compared with explicit tolerances.
        num_layers = len(fairseq_model.encoder.layers)
        expected_last, _ = fairseq_model.extract_features(
            waveforms,
            padding_mask=padding_mask,
            output_layer=num_layers,
        )
        # The xlarge factory gets a looser absolute tolerance than the others.
        if factory_func is hubert_xlarge:
            atol = 3.0e-05
        else:
            atol = 1.0e-5
        self.assertEqual(layer_outputs[-1], expected_last, atol=atol, rtol=1.3e-6)

        # First layer: compared with the default tolerances of assertEqual.
        expected_first, _ = fairseq_model.extract_features(
            waveforms,
            padding_mask=padding_mask,
            output_layer=1,
        )
        self.assertEqual(layer_outputs[0], expected_first)
    def test_import_pretrained_model(self, config):
        """Pretrained wav2vec2 models from fairseq can be imported and yield the same results.

        The output of the fairseq model's ``feature_extractor`` (with axes 1
        and 2 swapped) is compared against the first value returned by the
        imported model's ``extract_features`` for the same random waveform batch.
        """
        num_out = 28
        batch_size, num_frames = 3, 1024

        original = self._get_model(config, num_out).eval()
        # Pass ``num_out`` rather than repeating the literal, so the fairseq model
        # and the import call cannot drift apart if the value is ever edited.
        imported = import_fairseq_model(original, num_out).eval()

        x = torch.randn(batch_size, num_frames)
        # Swap axes 1 and 2 of the fairseq features before comparing them with
        # the imported model's output.
        ref = original.feature_extractor(x).transpose(1, 2)
        hyp, _ = imported.extract_features(x)
        self.assertEqual(ref, hyp)
    def test_import_wave2vec2_pretraining_model(self, config, _):
        """Wav2vec2 pretraining models from fairseq can be imported and yield the same results.

        Every entry of the fairseq model's ``layer_results`` is compared with
        the entry at the same index in the imported model's
        ``extract_features`` output. The second positional parameter is unused.
        """
        batch_size = 3
        num_frames = 1024

        torch.manual_seed(0)
        fairseq_model = self._get_model(config).eval()
        imported = import_fairseq_model(fairseq_model).eval()

        waveforms = torch.randn(batch_size, num_frames)
        layer_outputs, _ = imported.extract_features(waveforms)
        fairseq_result = fairseq_model.extract_features(
            waveforms,
            padding_mask=torch.zeros_like(waveforms),
            layer=-1,
        )
        # Each 'layer_results' entry is unpacked as a pair; only its first item
        # is compared, after swapping axes 0 and 1.
        for index, (expected, _) in enumerate(fairseq_result['layer_results']):
            self.assertEqual(layer_outputs[index], expected.transpose(0, 1))
    def test_import_finetuned_model(self, config):
        """Finetuned wav2vec2 models from fairseq can be imported and yield the same results.

        The fairseq model's ``'encoder_out'`` (with axes 0 and 1 swapped) is
        compared with the imported model's output, first without any padding
        and then with random per-sample lengths.
        """
        num_out = 28
        batch_size = 3
        num_frames = 1024

        fairseq_model = self._get_model(config, num_out).eval()
        imported = import_fairseq_model(fairseq_model).eval()

        # Without mask: fairseq receives an all-zero mask, the imported model none.
        waveforms = torch.randn(batch_size, num_frames)
        no_padding = torch.zeros_like(waveforms)
        expected = fairseq_model(waveforms, no_padding)['encoder_out'].transpose(0, 1)
        actual, _ = imported(waveforms)
        self.assertEqual(expected, actual)

        # With mask: the mask is True wherever the frame index is >= the sample's length.
        lengths = torch.randint(low=0, high=num_frames, size=(batch_size,))
        frame_index = torch.arange(num_frames).expand(batch_size, num_frames)
        padding_mask = frame_index >= lengths[:, None]
        expected = fairseq_model(waveforms, padding_mask)['encoder_out'].transpose(0, 1)
        actual, output_lengths = imported(waveforms, lengths)
        # Only the first ``valid`` output frames of each sample are compared.
        for row, valid in enumerate(output_lengths):
            self.assertEqual(expected[row, :valid, ...], actual[row, :valid, ...])