    def test_compute_mask_indices(self):
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob,
                                     mask_length)

        self.assertListEqual(
            mask.sum(axis=-1).tolist(),
            [mask_prob * sequence_length for _ in range(batch_size)])

        attention_mask = torch.ones((batch_size, sequence_length),
                                    device=torch_device,
                                    dtype=torch.long)
        attention_mask[:, -sequence_length // 2:] = 0

        mask = _compute_mask_indices((batch_size, sequence_length),
                                     mask_prob,
                                     mask_length,
                                     attention_mask=attention_mask)

        self.assertListEqual(
            mask.sum(axis=-1).tolist(),
            [mask_prob * sequence_length // 2 for _ in range(batch_size)])
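For the non-overlapping case (mask_length == 1), the property checked above reduces to "each row gets exactly mask_prob * sequence_length masked positions". Below is a minimal sketch of that behaviour, written only for illustration; it is not the transformers implementation.

import torch

def toy_compute_mask_indices(shape, mask_prob, mask_length=1):
    # illustrative only: pick mask_prob * sequence_length random positions per row
    batch_size, sequence_length = shape
    num_masked = int(mask_prob * sequence_length)
    mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
    for i in range(batch_size):
        idx = torch.randperm(sequence_length)[:num_masked]
        mask[i, idx] = True
    return mask

# toy_compute_mask_indices((4, 60), 0.5).sum(axis=-1) -> tensor([30, 30, 30, 30])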
    def test_compute_mask_indices_overlap(self):
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        # because of overlap there is a range of possible masks
        for batch_sum in mask.sum(axis=-1):
            self.assertIn(
                int(batch_sum),
                list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
            )

        attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
        attention_mask[:, -sequence_length // 2 :] = 0

        mask = _compute_mask_indices(
            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
        )

        # because of overlap there is a range of possible masks
        for batch_sum in mask.sum(axis=-1):
            self.assertIn(
                int(batch_sum),
                list(
                    range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
                ),
            )
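A quick worked check of the bounds asserted above (arithmetic only, not part of the test): with mask_prob = 0.5 and mask_length = 4, the floor division 0.5 // 4 evaluates to 0.0, so the lower end of the range is 0 and the assertion effectively only enforces the upper bound.

mask_prob, mask_length, sequence_length = 0.5, 4, 60
lower = int(mask_prob // mask_length * sequence_length)  # int(0.0 * 60) == 0
upper = int(mask_prob * sequence_length)                 # 30
assert (lower, upper) == (0, 30)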
    def test_inference_integration(self):
        model = Wav2Vec2ForPreTraining.from_pretrained(
            "facebook/wav2vec2-base")
        model.to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-base", return_attention_mask=True)
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech,
                                        return_tensors="pt",
                                        padding=True)

        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(
                torch.tensor(inputs_dict["input_values"].shape[1])),
        )

        torch.manual_seed(0)
        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(
                inputs_dict.input_values.to(torch_device),
                attention_mask=inputs_dict.attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
            )

        # compute cosine similarity
        cosine_sim = torch.cosine_similarity(
            outputs.projected_states,
            outputs.projected_quantized_states,
            dim=-1)

        # retrieve cosine sim of masked features
        cosine_sim_masked = cosine_sim[mask_time_indices]

        # fmt: off
        expected_cosine_sim_masked = torch.tensor(
            [
                0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257,
                0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696,
                0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839,
                0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001,
                0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997
            ],
            device=torch_device,
        )
        # fmt: on

        self.assertTrue(
            torch.allclose(cosine_sim_masked,
                           expected_cosine_sim_masked,
                           atol=1e-3))
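The boolean indexing used above to retrieve cosine_sim_masked flattens the similarities at the masked time steps into a 1-D tensor. A tiny standalone illustration with toy values (not taken from the test):

import torch

cosine_sim = torch.tensor([[0.9, 0.1, 0.8],
                           [0.2, 0.7, 0.3]])
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
print(cosine_sim[mask])  # tensor([0.9000, 0.8000, 0.7000])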
    def test_model_for_pretraining(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
        )
        model = Wav2Vec2ForPreTraining(config).to(torch_device)

        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(
                torch.tensor(inputs_dict["input_values"].shape[1])),
        )

        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)

        loss = model(
            inputs_dict["input_values"],
            attention_mask=inputs_dict["attention_mask"],
            mask_time_indices=mask_time_indices,
        ).loss

        mask_time_indices[:, :mask_time_indices.shape[-1] // 2] = True
        loss_more_masked = model(
            inputs_dict["input_values"],
            attention_mask=inputs_dict["attention_mask"],
            mask_time_indices=mask_time_indices,
        ).loss

        # loss_more_masked has to be bigger than or equal to loss since more masked inputs have to be predicted
        self.assertTrue(
            loss.detach().item() <= loss_more_masked.detach().item())
    def test_compute_mask_indices(self):
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
    def forward(self, wav):
        """Takes an input waveform and return its corresponding wav2vec encoding.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.
        """
        batch_size, raw_sequence_length = wav.shape

        if self.normalize_wav:
            wav = F.layer_norm(wav, wav.shape)

        sequence_length = self.model._get_feat_extract_output_lengths(
            raw_sequence_length)

        # 1. Compute the indices that will be masked
        mask_time_indices = _compute_mask_indices(
            (batch_size, sequence_length),
            mask_prob=self.mask_prob,
            mask_length=self.mask_length,
        )
        torch_mask_time_indices = torch.tensor(
            mask_time_indices,
            device=wav.device,
            dtype=torch.long,
        )

        # 2. Sample the negative samples from the entire sequence.
        # Fairseq does it only on the masked indices, but this only works if you
        # have long sentences. For more versatility, we sample on the entire
        # sequence.
        full_sentence_indices = np.ones((batch_size, sequence_length))

        negative_sample_indices = torch.tensor(
            transformers.models.wav2vec2.modeling_wav2vec2.
            _sample_negative_indices(
                (batch_size, sequence_length),
                num_negatives=self.config.num_negatives,
                mask_time_indices=full_sentence_indices,
            ),
            device=wav.device,
            dtype=torch.long,
        )

        return (
            self.model(
                wav,
                mask_time_indices=torch_mask_time_indices,
                sampled_negative_indices=negative_sample_indices,
            ),
            torch_mask_time_indices,
        )
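A hypothetical call of the wrapper above; the variable names are illustrative, and the surrounding class is assumed to hold a Wav2Vec2ForPreTraining in self.model and to be used as an ordinary torch.nn.Module:

import torch

# pretrain_module is assumed to be an instance of the module defined above
wav = torch.randn(2, 16000)                        # two 1-second signals at 16 kHz
outputs, mask_time_indices = pretrain_module(wav)
loss = outputs.loss                                # contrastive + diversity loss, available because both
loss.backward()                                    # mask_time_indices and sampled_negative_indices were passed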
    def test_compute_mask_indices_overlap(self):
        batch_size = 4
        sequence_length = 80
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

        # because of overlap, masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
        for batch_sum in mask.sum(axis=-1):
            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
    def test_loss_pretraining(self):
        model = Wav2Vec2ForPreTraining.from_pretrained(
            "facebook/wav2vec2-base",
            attention_dropout=0.0,
            feat_proj_dropout=0.0,
            hidden_dropout=0.0,
            layerdrop=0.0,
        )
        model.to(torch_device).train()

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-base", return_attention_mask=True)
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech,
                                        return_tensors="pt",
                                        padding=True)

        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(
                inputs_dict["input_values"].shape[1]),
        )

        torch.manual_seed(0)
        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(
                inputs_dict.input_values.to(torch_device),
                attention_mask=inputs_dict.attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
            )

        # check diversity loss
        num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
        diversity_loss = (num_codevectors -
                          outputs.codevector_perplexity) / num_codevectors
        self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)

        # check overall loss (contrastive loss + diversity loss)
        expected_loss = 62.5170

        self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
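A worked check of the diversity-loss assertion above, assuming wav2vec2-base's default codebook setup of 2 groups with 320 codevectors each (an assumption about the config, not stated in the test):

num_codevectors = 320 * 2                               # assumed num_codevectors_per_group * num_codevector_groups
codevector_perplexity = num_codevectors * (1 - 0.8859)  # ~73.0, the perplexity that reproduces the expected value
diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
assert abs(diversity_loss - 0.8859) < 1e-6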
    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # reformat list to dict and set to pytorch format
        batch = self.feature_extractor.pad(
            features,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        device = batch["input_values"].device
        batch_size = batch["input_values"].shape[0]

        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(
            batch["input_values"].shape[-1])
        # make sure masked sequence length is a Python scalar
        mask_indices_seq_length = int(mask_indices_seq_length)

        # make sure that no loss is computed on padded inputs
        if batch.get("attention_mask") is not None:
            # compute real output lengths according to convolution formula
            batch[
                "sub_attention_mask"] = self.model._get_feature_vector_attention_mask(
                    mask_indices_seq_length, batch["attention_mask"])

        features_shape = (batch_size, mask_indices_seq_length)

        # sample randomly masked indices
        mask_time_indices = _compute_mask_indices(
            features_shape,
            self.model.config.mask_time_prob,
            self.model.config.mask_time_length,
            attention_mask=batch.get("sub_attention_mask"),
        )

        # sample negative indices
        sampled_negative_indices = _sample_negative_indices(
            features_shape,
            self.model.config.num_negatives,
            mask_time_indices=mask_time_indices,
        )
        batch["mask_time_indices"] = torch.tensor(mask_time_indices,
                                                  dtype=torch.long,
                                                  device=device)
        batch["sampled_negative_indices"] = torch.tensor(
            sampled_negative_indices, dtype=torch.long, device=device)

        return batch
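Illustrative wiring of the collator above into a DataLoader; the DataCollatorForWav2Vec2Pretraining name, its constructor arguments, and train_dataset are assumptions made for the example:

from torch.utils.data import DataLoader

# hypothetical construction; the real constructor signature is not shown in this snippet
data_collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor, padding=True)
dataloader = DataLoader(train_dataset, batch_size=8, collate_fn=data_collator, shuffle=True)

batch = next(iter(dataloader))
# batch now carries "input_values", "mask_time_indices" and "sampled_negative_indices",
# ready to be passed to Wav2Vec2ForPreTraining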
    def test_compute_mask_indices_attn_mask_overlap(self):
        batch_size = 4
        sequence_length = 80
        mask_prob = 0.5
        mask_length = 4

        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
        attention_mask[:2, sequence_length // 2 :] = 0

        mask = _compute_mask_indices(
            (batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
        )

        for batch_sum in mask.sum(axis=-1):
            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)

        self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
    def test_compute_mask_indices_overlap(self):
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob,
                                     mask_length, torch_device)

        # because of overlap there is a range of possible masks
        for batch_sum in mask.sum(axis=-1):
            self.assertIn(
                int(batch_sum),
                list(
                    range(int(mask_prob // mask_length * sequence_length),
                          int(mask_prob * sequence_length))),
            )
    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # reformat list to dict and set to pytorch format
        batch = self.feature_extractor.pad(
            features,
            max_length=self.max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(
            batch["input_values"].shape[-1])

        batch_size = batch["input_values"].shape[0]

        # make sure that no loss is computed on padded inputs
        if batch["attention_mask"] is not None:
            # compute real output lengths according to convolution formula
            output_lengths = self.model._get_feat_extract_output_lengths(
                batch["attention_mask"].sum(-1)).to(torch.long)

            attention_mask = torch.zeros((batch_size, mask_indices_seq_length),
                                         dtype=torch.long,
                                         device=batch["input_values"].device)

            # these two operations make sure that all values before
            # the output length indices are attended to
            attention_mask[(torch.arange(attention_mask.shape[0],
                                         device=batch["input_values"].device),
                            output_lengths - 1)] = 1
            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(
                [-1]).bool()

        # sample randomly masked indices
        batch["mask_time_indices"] = _compute_mask_indices(
            (batch_size, mask_indices_seq_length),
            self.model.config.mask_time_prob,
            self.model.config.mask_time_length,
            device=batch["input_values"].device,
            attention_mask=attention_mask,
            min_masks=2,
        )

        return batch
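The flip/cumsum/flip trick above turns "a single 1 at position output_length - 1" into "1s at every position up to output_length". A small standalone demonstration with toy values, independent of the collator:

import torch

output_lengths = torch.tensor([3, 5])
attention_mask = torch.zeros((2, 6), dtype=torch.long)
attention_mask[(torch.arange(2), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])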
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # reformat list to dict and set to pytorch format
        batch = self.feature_extractor.pad(
            features,
            max_length=self.max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])

        # sample randomly masked indices
        batch["mask_time_indices"] = _compute_mask_indices(
            (batch["input_values"].shape[0], mask_indices_seq_length),
            self.model.config.mask_time_prob,
            self.model.config.mask_time_length,
            device=batch["input_values"].device,
            min_masks=2,
        )

        return batch
    def test_inference_pretrained(self):
        model = Wav2Vec2ForPreTraining.from_pretrained(
            "facebook/wav2vec2-base")
        model.to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-base", return_attention_mask=True)
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech,
                                        return_tensors="pt",
                                        padding=True)

        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(
                torch.tensor(inputs_dict["input_values"].shape[1])),
        )

        torch.manual_seed(0)
        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(
                inputs_dict.input_values.to(torch_device),
                attention_mask=inputs_dict.attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
            )

        # compute cosine similarity
        cosine_sim = torch.cosine_similarity(
            outputs.projected_states,
            outputs.projected_quantized_states,
            dim=-1)

        # retrieve cosine sim of masked features
        cosine_sim_masked = cosine_sim[mask_time_indices]

        # ... now compare to randomly initialized model

        config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
        model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()

        with torch.no_grad():
            outputs_rand = model_rand(
                inputs_dict.input_values.to(torch_device),
                attention_mask=inputs_dict.attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
            )

        # compute cosine similarity
        cosine_sim_rand = torch.cosine_similarity(
            outputs_rand.projected_states,
            outputs_rand.projected_quantized_states,
            dim=-1)

        # retrieve cosine sim of masked features
        cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

        # a pretrained wav2vec2 model has learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states > 0.5
        # a random wav2vec2 model has not learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states is very likely < 0.1
        self.assertTrue(cosine_sim_masked.mean().item() -
                        5 * cosine_sim_masked_rand.mean().item() > 0)