def test_compute_mask_indices(self):
    batch_size = 4
    sequence_length = 60
    mask_prob = 0.5
    mask_length = 1

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
    self.assertListEqual(
        mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)]
    )

    attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
    attention_mask[:, -sequence_length // 2 :] = 0

    mask = _compute_mask_indices(
        (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
    )
    self.assertListEqual(
        mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)]
    )
def test_compute_mask_indices_overlap(self):
    batch_size = 4
    sequence_length = 60
    mask_prob = 0.5
    mask_length = 4

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

    # because of overlap there is a range of possible masks
    for batch_sum in mask.sum(axis=-1):
        self.assertIn(
            int(batch_sum),
            list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
        )

    attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
    attention_mask[:, -sequence_length // 2 :] = 0

    mask = _compute_mask_indices(
        (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
    )

    # because of overlap there is a range of possible masks
    for batch_sum in mask.sum(axis=-1):
        self.assertIn(
            int(batch_sum),
            list(
                range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
            ),
        )
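# The two tests above pin down the span-masking behaviour: with mask_length == 1
# the masked total is exact, while longer spans can overlap and push the total
# below `mask_prob * sequence_length`. A minimal, hypothetical sketch of that
# mechanism (not the real `_compute_mask_indices`, which also handles attention
# masks, `min_masks`, and edge cases):
import numpy as np

def span_mask_sketch(shape, mask_prob, mask_length, seed=0):
    """Draw span start indices, then mark `mask_length` consecutive frames per span."""
    rng = np.random.default_rng(seed)
    batch_size, seq_len = shape
    num_spans = int(mask_prob * seq_len / mask_length)
    mask = np.zeros(shape, dtype=bool)
    for b in range(batch_size):
        starts = rng.choice(seq_len - mask_length + 1, num_spans, replace=False)
        for start in starts:
            # spans may overlap, so the number of unique masked frames can be
            # smaller than num_spans * mask_length
            mask[b, start : start + mask_length] = True
    return mask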
def test_inference_integration(self):
    model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
    model.to(torch_device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-base", return_attention_mask=True
    )
    input_speech = self._load_datasamples(2)

    inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)

    features_shape = (
        inputs_dict["input_values"].shape[0],
        model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
    )

    torch.manual_seed(0)
    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        device=inputs_dict["input_values"].device,
        min_masks=2,
    ).to(torch_device)

    with torch.no_grad():
        outputs = model(
            inputs_dict.input_values.to(torch_device),
            attention_mask=inputs_dict.attention_mask.to(torch_device),
            mask_time_indices=mask_time_indices,
        )

    # compute cosine similarity
    cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)

    # retrieve cosine sim of masked features
    cosine_sim_masked = cosine_sim[mask_time_indices]

    # fmt: off
    expected_cosine_sim_masked = torch.tensor(
        [
            0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415,
            0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746,
            0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776,
            0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997
        ],
        device=torch_device,
    )
    # fmt: on

    self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
def test_model_for_pretraining(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = Wav2Vec2ForPreTraining(config).to(torch_device)

    features_shape = (
        inputs_dict["input_values"].shape[0],
        model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
    )

    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        device=inputs_dict["input_values"].device,
        min_masks=2,
    ).to(torch_device)

    loss = model(
        inputs_dict["input_values"],
        attention_mask=inputs_dict["attention_mask"],
        mask_time_indices=mask_time_indices,
    ).loss

    mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
    loss_more_masked = model(
        inputs_dict["input_values"],
        attention_mask=inputs_dict["attention_mask"],
        mask_time_indices=mask_time_indices,
    ).loss

    # `loss_more_masked` has to be greater than or equal to `loss`, since more masked inputs have to be predicted
    self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())
def test_compute_mask_indices(self):
    batch_size = 4
    sequence_length = 60
    mask_prob = 0.5
    mask_length = 1

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

    self.assertListEqual(
        mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)]
    )
def forward(self, wav):
    """Takes an input waveform and returns its corresponding wav2vec encoding.

    Arguments
    ---------
    wav : torch.Tensor (signal)
        A batch of audio signals to transform to features.
    """
    batch_size, raw_sequence_length = wav.shape

    if self.normalize_wav:
        wav = F.layer_norm(wav, wav.shape)

    sequence_length = self.model._get_feat_extract_output_lengths(raw_sequence_length)

    # 1. Compute the indices that will be masked
    mask_time_indices = _compute_mask_indices(
        (batch_size, sequence_length),
        mask_prob=self.mask_prob,
        mask_length=self.mask_length,
    )
    torch_mask_time_indices = torch.tensor(
        mask_time_indices,
        device=wav.device,
        dtype=torch.long,
    )

    # 2. Sample the negative samples from the entire sequence.
    # Fairseq does it only on the masked indices, but this only works for
    # long sentences. For more versatility, we sample over the entire sequence.
    full_sentence_indices = np.ones((batch_size, sequence_length))
    negative_sample_indices = torch.tensor(
        transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices(
            (batch_size, sequence_length),
            num_negatives=self.config.num_negatives,
            mask_time_indices=full_sentence_indices,
        ),
        device=wav.device,
        dtype=torch.long,
    )

    return (
        self.model(
            wav,
            mask_time_indices=torch_mask_time_indices,
            sampled_negative_indices=negative_sample_indices,
        ),
        torch_mask_time_indices,
    )
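# The negative-sampling step above can be exercised on its own. This sketch
# calls the same transformers helper with every frame marked as a candidate
# (`full_sentence_indices`), mirroring the design choice in the forward pass:
# negatives are drawn from the whole sequence rather than only from masked
# frames, which matters for short utterances.
import numpy as np
from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices

batch_size, sequence_length, num_negatives = 2, 20, 5
full_sentence_indices = np.ones((batch_size, sequence_length))
negatives = _sample_negative_indices(
    (batch_size, sequence_length), num_negatives, mask_time_indices=full_sentence_indices
)
# one row of `num_negatives` distractor frame indices per time step
print(negatives.shape)  # (2, 20, 5) as a numpy array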
def test_compute_mask_indices_overlap(self):
    batch_size = 4
    sequence_length = 80
    mask_prob = 0.5
    mask_length = 4

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

    # because of overlap, the masked counts don't have to add up exactly to
    # `mask_prob * sequence_length`, but they have to be smaller or equal
    for batch_sum in mask.sum(axis=-1):
        self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_loss_pretraining(self):
    model = Wav2Vec2ForPreTraining.from_pretrained(
        "facebook/wav2vec2-base",
        attention_dropout=0.0,
        feat_proj_dropout=0.0,
        hidden_dropout=0.0,
        layerdrop=0.0,
    )
    model.to(torch_device).train()

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-base", return_attention_mask=True
    )
    input_speech = self._load_datasamples(2)

    inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)

    features_shape = (
        inputs_dict["input_values"].shape[0],
        model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
    )

    torch.manual_seed(0)
    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        device=inputs_dict["input_values"].device,
        min_masks=2,
    ).to(torch_device)

    with torch.no_grad():
        outputs = model(
            inputs_dict.input_values.to(torch_device),
            attention_mask=inputs_dict.attention_mask.to(torch_device),
            mask_time_indices=mask_time_indices,
        )

    # check diversity loss
    num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
    diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
    self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)

    # check overall loss (contrastive loss + diversity loss)
    expected_loss = 62.5170
    self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
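# For scale: `facebook/wav2vec2-base` uses 2 codevector groups with 320
# codevectors per group (the Wav2Vec2Config defaults), so the codevector
# perplexity is bounded by 640. A quick back-of-the-envelope check of the
# asserted diversity loss value:
num_codevectors = 2 * 320  # groups x codevectors per group
implied_perplexity = (1 - 0.8859) * num_codevectors
print(round(implied_perplexity, 1))  # ~73.0 -> only a small slice of the codebook is active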
def __call__(
    self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
    # reformat list to dict and set to pytorch format
    batch = self.feature_extractor.pad(
        features,
        padding=self.padding,
        pad_to_multiple_of=self.pad_to_multiple_of,
        return_tensors="pt",
    )

    device = batch["input_values"].device
    batch_size = batch["input_values"].shape[0]

    mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
    # make sure masked sequence length is a Python scalar
    mask_indices_seq_length = int(mask_indices_seq_length)

    # make sure that no loss is computed on padded inputs
    if batch.get("attention_mask") is not None:
        # compute real output lengths according to convolution formula
        batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask(
            mask_indices_seq_length, batch["attention_mask"]
        )

    features_shape = (batch_size, mask_indices_seq_length)

    # sample randomly masked indices
    mask_time_indices = _compute_mask_indices(
        features_shape,
        self.model.config.mask_time_prob,
        self.model.config.mask_time_length,
        attention_mask=batch.get("sub_attention_mask"),
    )

    # sample negative indices
    sampled_negative_indices = _sample_negative_indices(
        features_shape,
        self.model.config.num_negatives,
        mask_time_indices=mask_time_indices,
    )
    batch["mask_time_indices"] = torch.tensor(mask_time_indices, dtype=torch.long, device=device)
    batch["sampled_negative_indices"] = torch.tensor(sampled_negative_indices, dtype=torch.long, device=device)

    return batch
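# A hedged usage sketch for the collator above. The class name
# `DataCollatorForWav2Vec2Pretraining` and its constructor fields are
# assumptions inferred from the attributes the __call__ uses
# (`feature_extractor`, `model`, `padding`, `pad_to_multiple_of`):
from torch.utils.data import DataLoader

data_collator = DataCollatorForWav2Vec2Pretraining(  # hypothetical name
    model=model,
    feature_extractor=feature_extractor,
    padding="longest",
    pad_to_multiple_of=None,
)
dataloader = DataLoader(dataset, batch_size=8, collate_fn=data_collator)
batch = next(iter(dataloader))
# batch now carries `input_values`, `mask_time_indices`, and
# `sampled_negative_indices`, ready for Wav2Vec2ForPreTraining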
def test_compute_mask_indices_attn_mask_overlap(self):
    batch_size = 4
    sequence_length = 80
    mask_prob = 0.5
    mask_length = 4

    attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
    attention_mask[:2, sequence_length // 2 :] = 0

    mask = _compute_mask_indices(
        (batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
    )

    for batch_sum in mask.sum(axis=-1):
        self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)

    self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
def test_compute_mask_indices_overlap(self):
    batch_size = 4
    sequence_length = 60
    mask_prob = 0.5
    mask_length = 4

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

    # because of overlap there is a range of possible masks
    for batch_sum in mask.sum(axis=-1):
        self.assertIn(
            int(batch_sum),
            list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
        )
def __call__(
    self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
    # reformat list to dict and set to pytorch format
    batch = self.feature_extractor.pad(
        features,
        max_length=self.max_length,
        padding=self.padding,
        pad_to_multiple_of=self.pad_to_multiple_of,
        return_tensors="pt",
    )
    mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])

    batch_size = batch["input_values"].shape[0]

    # make sure that no loss is computed on padded inputs
    if batch["attention_mask"] is not None:
        # compute real output lengths according to convolution formula
        output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)).to(
            torch.long
        )

        attention_mask = torch.zeros(
            (batch_size, mask_indices_seq_length), dtype=torch.long, device=batch["input_values"].device
        )

        # these two operations make sure that all values
        # before the output lengths indices are attended to
        attention_mask[
            (torch.arange(attention_mask.shape[0], device=batch["input_values"].device), output_lengths - 1)
        ] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()

    # sample randomly masked indices
    batch["mask_time_indices"] = _compute_mask_indices(
        (batch_size, mask_indices_seq_length),
        self.model.config.mask_time_prob,
        self.model.config.mask_time_length,
        device=batch["input_values"].device,
        attention_mask=attention_mask,
        min_masks=2,
    )

    return batch
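# The flip/cumsum/flip trick above turns a single 1 at each sequence's last
# valid frame into a prefix mask. A standalone, runnable illustration:
import torch

output_lengths = torch.tensor([3, 5])
attention_mask = torch.zeros((2, 6), dtype=torch.long)
attention_mask[(torch.arange(2), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])
print(attention_mask)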
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
    # reformat list to dict and set to pytorch format
    batch = self.feature_extractor.pad(
        features,
        max_length=self.max_length,
        padding=self.padding,
        pad_to_multiple_of=self.pad_to_multiple_of,
        return_tensors="pt",
    )
    mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])

    # sample randomly masked indices
    batch["mask_time_indices"] = _compute_mask_indices(
        (batch["input_values"].shape[0], mask_indices_seq_length),
        self.model.config.mask_time_prob,
        self.model.config.mask_time_length,
        device=batch["input_values"].device,
        min_masks=2,
    )

    return batch
def test_inference_pretrained(self):
    model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
    model.to(torch_device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-base", return_attention_mask=True
    )
    input_speech = self._load_datasamples(2)

    inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)

    features_shape = (
        inputs_dict["input_values"].shape[0],
        model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
    )

    torch.manual_seed(0)
    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        device=inputs_dict["input_values"].device,
        min_masks=2,
    ).to(torch_device)

    with torch.no_grad():
        outputs = model(
            inputs_dict.input_values.to(torch_device),
            attention_mask=inputs_dict.attention_mask.to(torch_device),
            mask_time_indices=mask_time_indices,
        )

    # compute cosine similarity
    cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)

    # retrieve cosine sim of masked features
    cosine_sim_masked = cosine_sim[mask_time_indices]

    # ... now compare to randomly initialized model
    config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
    model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()

    with torch.no_grad():
        outputs_rand = model_rand(
            inputs_dict.input_values.to(torch_device),
            attention_mask=inputs_dict.attention_mask.to(torch_device),
            mask_time_indices=mask_time_indices,
        )

    # compute cosine similarity
    cosine_sim_rand = torch.cosine_similarity(
        outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
    )

    # retrieve cosine sim of masked features
    cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

    # a pretrained wav2vec2 model has learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states > 0.5
    # a random wav2vec2 model has not learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states is very likely < 0.1
    self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)