def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
    """Collate a list of examples into one padded batch with sampled mask/negative indices.

    The examples are padded by ``self.feature_extractor.pad`` with
    ``return_tensors="np"`` (NumPy, not PyTorch). Two entries are then added:

    * ``"mask_time_indices"`` from ``_compute_mask_indices``, using the model
      config's ``mask_time_prob`` / ``mask_time_length`` and ``min_masks=2``;
    * ``"sampled_negative_indices"`` from ``_sample_negative_indices``, using
      the config's ``num_negatives``; its shape argument is the mask shape
      extended by ``proj_codevector_dim``.

    Args:
        features: per-example feature dicts to collate.

    Returns:
        The padded batch with the two sampled index entries added.
    """
    padded = self.feature_extractor.pad(
        features,
        max_length=self.max_length,
        padding=self.padding,
        pad_to_multiple_of=self.pad_to_multiple_of,
        return_tensors="np",
    )
    input_values = padded["input_values"]
    config = self.model.config

    # Masks are sampled over the length returned by the model's
    # `_get_feat_extract_output_lengths` for the padded input length
    # (presumably the feature-encoder frame count — confirm against the model).
    num_frames = self.model._get_feat_extract_output_lengths(input_values.shape[-1])
    mask_shape = (input_values.shape[0], num_frames)

    # Randomly chosen time steps to mask.
    padded["mask_time_indices"] = _compute_mask_indices(
        mask_shape,
        config.mask_time_prob,
        config.mask_time_length,
        min_masks=2,
    )

    # Indices to take negative vectors from.
    negatives_shape = padded["mask_time_indices"].shape + (config.proj_codevector_dim,)
    padded["sampled_negative_indices"] = _sample_negative_indices(
        negatives_shape,
        config.num_negatives,
    )
    return padded
def test_train(self):
    """Run a train-mode forward pass and check the shape of the first output.

    The first output must have shape
    ``(batch_size, sequence_length, config.proj_codevector_dim)``, where
    ``sequence_length`` comes from ``model._get_feat_extract_output_lengths``
    applied to the raw input length.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    input_values = inputs_dict["input_values"]
    attention_mask = inputs_dict["attention_mask"]

    model = FlaxWav2Vec2ForPreTraining(config)

    batch_size = input_values.shape[0]
    # Mask indices are sampled over the model's reduced output length,
    # not over the raw input length.
    sequence_length = model._get_feat_extract_output_lengths(np.array(input_values.shape[1]))

    mask_prob = 0.5
    mask_length = 4
    mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

    # Separate RNG streams for dropout and Gumbel sampling in train mode.
    dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0))

    output = model(
        input_values,
        attention_mask=attention_mask,
        mask_time_indices=mask_time_indices,
        train=True,
        dropout_rng=dropout_rng,
        gumbel_rng=gumbel_rng,
    )[0]

    # assertEqual (rather than assertTrue(a == b)) reports both shapes on failure.
    self.assertEqual(output.shape, (batch_size, sequence_length, model.config.proj_codevector_dim))
def test_inference_pretrained(self):
    """Pretrained wav2vec2 predicts quantized states far better than a random init.

    The test computes the cosine similarity between ``projected_states`` and
    ``projected_quantized_states`` at the masked time steps for two models: the
    pretrained ``facebook/wav2vec2-large-lv60`` checkpoint and a randomly
    initialized model with the same config. The pretrained mean must exceed
    five times the random mean.
    """
    model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-large-lv60", return_attention_mask=True
    )
    input_speech = self._load_datasamples(2)

    inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)

    features_shape = (
        inputs_dict["input_values"].shape[0],
        model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
    )

    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        min_masks=2,
    )

    # Same epsilon for both models: with optax's default epsilon of 0, a
    # zero-norm vector gives NaN, and a NaN mean fails the comparison spuriously.
    cosine_epsilon = 1e-8

    outputs = model(
        inputs_dict.input_values,
        attention_mask=inputs_dict.attention_mask,
        mask_time_indices=mask_time_indices,
    )

    # compute cosine similarity
    cosine_sim = optax.cosine_similarity(
        outputs.projected_states, outputs.projected_quantized_states, epsilon=cosine_epsilon
    )

    # retrieve cosine sim of masked features
    cosine_sim_masked = cosine_sim[mask_time_indices]

    # ... now compare to randomly initialized model

    config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
    model_rand = FlaxWav2Vec2ForPreTraining(config)

    outputs_rand = model_rand(
        inputs_dict.input_values,
        attention_mask=inputs_dict.attention_mask,
        mask_time_indices=mask_time_indices,
    )

    # compute cosine similarity
    cosine_sim_rand = optax.cosine_similarity(
        outputs_rand.projected_states, outputs_rand.projected_quantized_states, epsilon=cosine_epsilon
    )

    # retrieve cosine sim of masked features
    cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

    # a pretrained wav2vec2 model has learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states > 0.5
    # a random wav2vec2 model has not learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states is very likely < 0.1
    self.assertGreater(
        cosine_sim_masked.mean().item(),
        5 * cosine_sim_masked_rand.mean().item(),
    )
def test_compute_mask_indices(self):
    """With length-1 spans, each row masks exactly ``mask_prob * sequence_length`` steps."""
    batch_size, sequence_length = 4, 60
    mask_prob, mask_length = 0.5, 1

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

    # Every row is expected to hit the target count exactly.
    expected_per_row = [mask_prob * sequence_length] * batch_size
    self.assertListEqual(mask.sum(axis=-1).tolist(), expected_per_row)
def test_compute_mask_indices_overlap(self):
    """With spans longer than 1, the masked count per row is only bounded above.

    Each row of the mask must have at most ``mask_prob * sequence_length``
    masked positions.
    """
    batch_size = 4
    sequence_length = 80
    mask_prob = 0.5
    mask_length = 4

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

    # Because spans can overlap, the masks don't have to add up exactly to
    # `mask_prob * sequence_length`, but they must be smaller or equal.
    max_masked = mask_prob * sequence_length
    for batch_sum in mask.sum(axis=-1):
        # assertLessEqual reports both values on failure, unlike assertTrue(a <= b).
        self.assertLessEqual(int(batch_sum), max_masked)
def __call__(
    self, features: List[Dict[str, Union[List[int], np.ndarray]]]
) -> Dict[str, np.ndarray]:
    """Collate a list of examples into one padded NumPy batch with sampled indices.

    The examples are padded by ``self.feature_extractor.pad`` with
    ``return_tensors="np"``. Two entries are then added:
    ``"mask_time_indices"`` (from ``_compute_mask_indices``) and
    ``"sampled_negative_indices"`` (from ``_sample_negative_indices``).

    If the padded batch contains an ``"attention_mask"``, it is converted to a
    boolean mask over the model's output length and passed to both samplers.
    Positions beyond each example's output length are then marked False.
    If there is no ``"attention_mask"``, the samplers receive ``None``.

    Args:
        features: per-example feature dicts to collate.

    Returns:
        The padded batch with the two sampled index entries added.
    """
    # reformat list to dict and convert to NumPy arrays
    batch = self.feature_extractor.pad(
        features,
        max_length=self.max_length,
        padding=self.padding,
        pad_to_multiple_of=self.pad_to_multiple_of,
        return_tensors="np",
    )

    mask_indices_seq_length = self.model._get_feat_extract_output_lengths(
        batch["input_values"].shape[-1])
    batch_size = batch["input_values"].shape[0]

    attention_mask = None
    # The feature extractor may not return an attention mask (it is configurable).
    # Using .get lets that case fall through to unmasked sampling instead of
    # raising KeyError, which indexing with [] would do.
    sample_attention_mask = batch.get("attention_mask")
    if sample_attention_mask is not None:
        output_lengths = self.model._get_feat_extract_output_lengths(
            sample_attention_mask.sum(-1))
        attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)

        # Put a 1 at each row's last valid output position. A reversed cumulative
        # sum then spreads it to every earlier position, so all positions before
        # the output length are attended to. This stays in NumPy (no device
        # round-trip), like the rest of the collator.
        attention_mask[(np.arange(batch_size), output_lengths - 1)] = 1
        attention_mask = np.flip(np.flip(attention_mask, -1).cumsum(-1), -1).astype(bool)

    # sample randomly masked indices
    batch["mask_time_indices"] = _compute_mask_indices(
        (batch_size, mask_indices_seq_length),
        self.model.config.mask_time_prob,
        self.model.config.mask_time_length,
        attention_mask=attention_mask,
        min_masks=2,
    )

    # sample indices to take for negative vectors
    batch["sampled_negative_indices"] = _sample_negative_indices(
        (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
        self.model.config.num_negatives,
        attention_mask=attention_mask,
    )

    return batch
def test_compute_mask_indices_attn_mask_overlap(self):
    """Positions where ``attention_mask == 0`` must never be selected for masking.

    The first two rows are padded over their second half. The test checks two
    things: every row masks at most ``mask_prob * sequence_length`` positions
    (spans may overlap), and no padded position is masked.
    """
    batch_size = 4
    sequence_length = 80
    mask_prob = 0.5
    mask_length = 4

    attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
    # Rows 0 and 1 are padding from the midpoint onwards.
    attention_mask[:2, sequence_length // 2 :] = 0

    mask = _compute_mask_indices(
        (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
    )

    max_masked = mask_prob * sequence_length
    for batch_sum in mask.sum(axis=-1):
        self.assertLessEqual(int(batch_sum), max_masked)

    # No masked position may fall inside the padded region.
    self.assertEqual(int(mask[:2, sequence_length // 2 :].sum()), 0)