Пример #1
0
    def test_run_wav2vec2_pretraining(self):
        """Smoke-test the no-trainer wav2vec2 pretraining example script end to end.

        Runs ``run_wav2vec2_pretraining_no_trainer.main()`` for a few steps on the
        tiny random checkpoint and the dummy LibriSpeech dataset, then checks that a
        ``Wav2Vec2ForPreTraining`` model can be loaded back from the output directory.
        """
        # Mirror log output to stdout while the script runs. The handler is always
        # detached afterwards; otherwise every run of this test would attach one more
        # handler to the module-level ``logger`` and duplicate all later log lines.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        try:
            tmp_dir = self.get_auto_remove_tmp_dir()
            testargs = f"""
                run_wav2vec2_pretraining_no_trainer.py
                --output_dir {tmp_dir}
                --model_name_or_path hf-internal-testing/tiny-random-wav2vec2
                --dataset_name hf-internal-testing/librispeech_asr_dummy
                --dataset_config_names clean
                --dataset_split_names validation
                --learning_rate 1e-4
                --per_device_train_batch_size 2
                --per_device_eval_batch_size 2
                --preprocessing_num_workers 16
                --max_train_steps 5
                --validation_split_percentage 5
                --seed 42
            """.split()

            # Only request fp16 when the environment helper reports CUDA + apex.
            if is_cuda_and_apex_available():
                testargs.append("--fp16")

            # Feed the CLI arguments to the script through a patched sys.argv.
            with patch.object(sys, "argv", testargs):
                run_wav2vec2_pretraining_no_trainer.main()
                model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
                self.assertIsNotNone(model)
        finally:
            logger.removeHandler(stream_handler)
    def test_inference_integration(self):
        """Compare masked-frame cosine similarities of ``facebook/wav2vec2-base`` to references.

        Two speech samples are featurized, a seeded set of time steps is masked, and the
        cosine similarity between ``projected_states`` and ``projected_quantized_states``
        at the masked positions must match hard-coded reference values (atol=1e-3).
        """
        checkpoint = "facebook/wav2vec2-base"
        model = Wav2Vec2ForPreTraining.from_pretrained(checkpoint)
        model.to(torch_device)
        extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint, return_attention_mask=True)

        speech = self._load_datasamples(2)
        batch = extractor(speech, return_tensors="pt", padding=True)
        input_values = batch["input_values"]
        attention_mask = batch["attention_mask"]

        # Mask grid: (number of samples, frame count reported by
        # _get_feat_extract_output_lengths for the padded input length).
        num_frames = model._get_feat_extract_output_lengths(torch.tensor(input_values.shape[1]))
        mask_shape = (input_values.shape[0], num_frames)

        # Seed immediately before sampling so the masked positions are reproducible
        # and line up with the reference values below.
        torch.manual_seed(0)
        mask_time_indices = _compute_mask_indices(
            mask_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=input_values.device,
            min_masks=2,
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(
                input_values.to(torch_device),
                attention_mask=attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
            )

        similarity = torch.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, dim=-1
        )
        # Keep only the similarities at masked time steps.
        masked_similarity = similarity[mask_time_indices]

        # fmt: off
        expected = torch.tensor(
            [
                0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257,
                0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696,
                0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839,
                0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001,
                0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997
            ],
            device=torch_device,
        )
        # fmt: on

        self.assertTrue(torch.allclose(masked_similarity, expected, atol=1e-3))
    def test_loss_pretraining(self):
        """Check diversity loss and total pretraining loss of ``facebook/wav2vec2-base``.

        The model is loaded with all dropout probabilities and layerdrop set to 0 and
        run in train mode on two samples with a seeded time mask. The diversity loss
        derived from ``codevector_perplexity`` and ``outputs.loss`` are compared with
        reference values (absolute tolerance 1e-3).
        """
        model = Wav2Vec2ForPreTraining.from_pretrained(
            "facebook/wav2vec2-base",
            attention_dropout=0.0,
            feat_proj_dropout=0.0,
            hidden_dropout=0.0,
            layerdrop=0.0,
        )
        model.to(torch_device).train()

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-base", return_attention_mask=True)
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech,
                                        return_tensors="pt",
                                        padding=True)

        # Mask grid: (number of samples, frame count for the padded input length).
        # The length is passed as a tensor, matching how the other masking tests in
        # this file call _get_feat_extract_output_lengths.
        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(
                torch.tensor(inputs_dict["input_values"].shape[1])),
        )

        # Seed right before sampling so the mask (and hence the losses) is reproducible.
        torch.manual_seed(0)
        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(
                inputs_dict.input_values.to(torch_device),
                attention_mask=inputs_dict.attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
            )

        # check diversity loss: (total codevectors - perplexity) / total codevectors
        num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
        diversity_loss = (num_codevectors -
                          outputs.codevector_perplexity) / num_codevectors
        expected_diversity_loss = 0.8859
        # assertLess reports both operands on failure, unlike assertTrue(... < ...).
        self.assertLess(abs(diversity_loss.item() - expected_diversity_loss), 1e-3)

        # check overall loss (contrastive loss + diversity loss)
        expected_loss = 62.5170
        self.assertLess(abs(outputs.loss.item() - expected_loss), 1e-3)
    def test_inference_pretrained(self):
        """Pretrained weights must predict quantized latents far better than random ones.

        The same seeded time mask and inputs are fed to ``facebook/wav2vec2-base`` and
        to a freshly initialized model built from the same config. The mean cosine
        similarity between projected and projected-quantized states at masked steps
        must be more than five times larger for the pretrained model.
        """
        checkpoint = "facebook/wav2vec2-base"
        pretrained = Wav2Vec2ForPreTraining.from_pretrained(checkpoint)
        pretrained.to(torch_device)
        extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint, return_attention_mask=True)

        batch = extractor(self._load_datasamples(2), return_tensors="pt", padding=True)
        input_values = batch["input_values"]
        attention_mask = batch["attention_mask"]

        # Mask grid: (number of samples, frame count for the padded input length).
        mask_shape = (
            input_values.shape[0],
            pretrained._get_feat_extract_output_lengths(torch.tensor(input_values.shape[1])),
        )

        torch.manual_seed(0)
        mask_time_indices = _compute_mask_indices(
            mask_shape,
            pretrained.config.mask_time_prob,
            pretrained.config.mask_time_length,
            device=input_values.device,
            min_masks=2,
        ).to(torch_device)

        def masked_similarity(net):
            """Run ``net`` on the shared batch and return cosine sims at masked steps."""
            with torch.no_grad():
                out = net(
                    input_values.to(torch_device),
                    attention_mask=attention_mask.to(torch_device),
                    mask_time_indices=mask_time_indices,
                )
            sim = torch.cosine_similarity(
                out.projected_states, out.projected_quantized_states, dim=-1
            )
            return sim[mask_time_indices]

        sim_pretrained = masked_similarity(pretrained)

        # ... now compare to a randomly initialized model with the same config
        config = Wav2Vec2Config.from_pretrained(checkpoint)
        random_model = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
        sim_random = masked_similarity(random_model)

        # A pretrained wav2vec2 model has learned to predict the quantized latent states
        # (cosine similarity expected > 0.5); a random one has not (very likely < 0.1).
        self.assertTrue(sim_pretrained.mean().item() -
                        5 * sim_random.mean().item() > 0)