def test_inference_pretrained(self):
    """Compare masked-position cosine similarity of a pretrained vs. a randomly initialized model.

    Both models receive the same inputs and the same ``mask_time_indices``. For each
    model the cosine similarity between ``projected_states`` and
    ``projected_quantized_states`` is computed and averaged over the masked positions.
    The test passes when the pretrained mean exceeds five times the random mean.
    """
    model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-large-lv60", return_attention_mask=True
    )
    input_speech = self._load_datasamples(2)

    inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)

    # shape handed to the mask sampler: (batch size, feature-encoder output length)
    features_shape = (
        inputs_dict["input_values"].shape[0],
        model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
    )

    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        min_masks=2,
    )

    # Use one epsilon for both similarity computations so that the two values being
    # compared are produced in exactly the same way.
    epsilon = 1e-8

    outputs = model(
        inputs_dict.input_values,
        attention_mask=inputs_dict.attention_mask,
        mask_time_indices=mask_time_indices,
    )

    # compute cosine similarity
    cosine_sim = optax.cosine_similarity(
        outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
    )

    # retrieve cosine sim of masked features
    cosine_sim_masked = cosine_sim[mask_time_indices]

    # ... now compare to randomly initialized model
    config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
    model_rand = FlaxWav2Vec2ForPreTraining(config)

    outputs_rand = model_rand(
        inputs_dict.input_values,
        attention_mask=inputs_dict.attention_mask,
        mask_time_indices=mask_time_indices,
    )

    # compute cosine similarity
    cosine_sim_rand = optax.cosine_similarity(
        outputs_rand.projected_states, outputs_rand.projected_quantized_states, epsilon=epsilon
    )

    # retrieve cosine sim of masked features
    cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

    # a pretrained wav2vec2 model has learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states > 0.5
    # a random wav2vec2 model has not learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states is very likely < 0.1
    # Hence the pretrained mean must be larger than 5x the random mean.
    self.assertGreater(
        cosine_sim_masked.mean().item(),
        5 * cosine_sim_masked_rand.mean().item(),
    )
def test_train(self):
    """Run one forward pass with ``train=True`` and check the shape of the first output.

    Masked time indices are sampled for the feature-encoder output length, and
    separate dropout / gumbel PRNG keys are passed to the model. The first element
    of the returned output must have shape
    ``(batch_size, sequence_length, config.proj_codevector_dim)``.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = FlaxWav2Vec2ForPreTraining(config)

    input_values = inputs_dict["input_values"]
    attention_mask = inputs_dict["attention_mask"]

    # (batch, encoded length) as seen after the feature encoder
    raw_length = np.array(input_values.shape[1])
    batch_size = input_values.shape[0]
    sequence_length = model._get_feat_extract_output_lengths(raw_length)

    mask_prob = 0.5
    mask_length = 4
    mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

    # one key for dropout, one for the gumbel sampling
    rng_keys = jax.random.split(jax.random.PRNGKey(0))
    dropout_rng, gumbel_rng = rng_keys

    model_outputs = model(
        input_values,
        attention_mask=attention_mask,
        mask_time_indices=mask_time_indices,
        train=True,
        dropout_rng=dropout_rng,
        gumbel_rng=gumbel_rng,
    )
    output = model_outputs[0]

    expected_shape = (batch_size, sequence_length, model.config.proj_codevector_dim)
    self.assertTrue(output.shape == expected_shape)
def test_freeze_feature_encoder(self):
    """Check that toggling ``freeze_feature_encoder`` leaves a dummy loss and its gradients unchanged.

    The model is run twice on the same inputs, once with
    ``freeze_feature_encoder=False`` and once with ``True``. A cosine-similarity
    loss is evaluated on the projected / quantized states of each run, and its
    value and gradient (taken with respect to the projected states) must agree
    within ``1e-5``.
    """
    # NOTE(review): another method with this exact name appears later in this file;
    # if both live in the same class the later definition replaces this one -- confirm.
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    input_values = inputs_dict["input_values"]
    attention_mask = inputs_dict["attention_mask"]

    # dummy loss function: summed cosine similarity of projected and quantized states
    def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
        return optax.cosine_similarity(
            projected_states, projected_quantized_states, epsilon=epsilon
        ).sum()

    # returns (loss, gradient w.r.t. the first argument)
    grad_fn = jax.value_and_grad(compute_loss)

    model = FlaxWav2Vec2ForPreTraining(config)

    unfrozen_out = model(
        input_values,
        attention_mask=attention_mask,
        freeze_feature_encoder=False,
    )
    frozen_out = model(
        input_values,
        attention_mask=attention_mask,
        freeze_feature_encoder=True,
    )

    # loss and gradients for the unfrozen run
    loss, grads = grad_fn(unfrozen_out.projected_states, unfrozen_out.projected_quantized_states)

    # loss and gradients for the frozen run
    loss_frozen, grads_frozen = grad_fn(frozen_out.projected_states, frozen_out.projected_quantized_states)

    loss_gap = np.abs(loss - loss_frozen)
    self.assertLessEqual(loss_gap, 1e-5)

    self.assertEqual(grads.shape, grads_frozen.shape)

    max_diff = np.amax(np.abs(grads - grads_frozen))
    self.assertLessEqual(max_diff, 1e-5)
def test_freeze_feature_encoder(self):
    """Check that ``freeze_feature_encoder=True`` zeroes only the feature-extractor gradients.

    A dummy cosine-similarity loss is differentiated with respect to the model
    parameters, once with ``freeze_feature_encoder=False`` and once with ``True``.
    The test asserts that:

    * the model outputs and the loss are exactly equal in both runs;
    * both gradient trees have the same flattened keys;
    * for keys containing ``"feature_extractor"``, the frozen gradient is all zeros
      while the unfrozen gradient has at least one non-zero entry;
    * for every other key, frozen and unfrozen gradients are exactly equal.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    input_values = inputs_dict["input_values"]
    attention_mask = inputs_dict["attention_mask"]

    model = FlaxWav2Vec2ForPreTraining(config)
    params = model.params

    # dummy loss function
    def compute_loss(
        params, input_values, attention_mask, freeze_feature_encoder: bool = False, epsilon: float = 1e-8
    ):
        outputs = model(
            input_values,
            attention_mask=attention_mask,
            freeze_feature_encoder=freeze_feature_encoder,
            params=params,
        )
        # compute cosine similarity of projected and projected_quantized states
        cosine_sim = optax.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
        )
        loss = cosine_sim.sum()
        # outputs are returned as auxiliary data so they can be compared below
        return loss, outputs.to_tuple()

    # transform the loss function to get the gradients
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

    # compute loss, outputs and gradients for unfrozen model
    (loss, outputs), grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False)

    # compare to loss, outputs and gradients for frozen model
    (loss_frozen, outputs_frozen), grads_frozen = grad_fn(
        params, input_values, attention_mask, freeze_feature_encoder=True
    )

    # ensure that the outputs and losses remain precisely equal
    for output, output_frozen in zip(outputs, outputs_frozen):
        self.assertTrue((output == output_frozen).all())
    self.assertEqual(loss, loss_frozen)

    grads = flatten_dict(grads)
    grads_frozen = flatten_dict(grads_frozen)

    # ensure that the dicts of gradients contain the same keys
    self.assertEqual(grads.keys(), grads_frozen.keys())

    # guard against a vacuous pass: there must be feature extractor parameters to check
    feature_extractor_keys = [key for key in grads if "feature_extractor" in key]
    self.assertTrue(feature_extractor_keys)

    # Pair gradients by key (not by position) so each unfrozen gradient is always
    # compared against the frozen gradient of the same parameter.
    for key, grad in grads.items():
        grad_frozen = grads_frozen[key]
        if "feature_extractor" in key:
            # gradients of the feature extractor layers are precisely zero when frozen
            # and contain non-zero entries (of either sign) when unfrozen
            self.assertTrue((grad_frozen == 0.0).all())
            self.assertTrue((grad != 0.0).any())
        else:
            # gradients of all unfrozen layers remain equal
            self.assertTrue((grad == grad_frozen).all())