def test_inference_pretrained(self):
    """Check that the pretrained checkpoint predicts its quantized targets better than a random model.

    Loads "facebook/wav2vec2-large-lv60" (from PyTorch weights, ``from_pt=True``) and samples a
    time mask with ``_compute_mask_indices``. It then compares the mean cosine similarity between
    ``projected_states`` and ``projected_quantized_states`` on the masked positions against the
    same quantity from a randomly initialized model built from the same config. Both models are
    scored by one shared helper, so they use identical inputs, mask and ``epsilon``.
    """
    model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-large-lv60", return_attention_mask=True
    )
    input_speech = self._load_datasamples(2)

    inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)

    # (batch size, number of feature frames); the second entry presumably maps the raw input
    # length to the feature-encoder output length -- NOTE(review): confirm helper semantics.
    features_shape = (
        inputs_dict["input_values"].shape[0],
        model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
    )

    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        min_masks=2,
    )

    def masked_cosine_similarity(pretraining_model):
        """Run ``pretraining_model`` on the shared inputs; return cosine sims at masked steps."""
        outputs = pretraining_model(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )
        # Same epsilon for both models: with optax's default (0.0) a zero-norm vector yields
        # NaN, which would make the final assertion fail silently for the random model.
        cosine_sim = optax.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, epsilon=1e-8
        )
        return cosine_sim[mask_time_indices]

    cosine_sim_masked = masked_cosine_similarity(model)

    # ... now compare to randomly initialized model
    config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
    model_rand = FlaxWav2Vec2ForPreTraining(config)
    cosine_sim_masked_rand = masked_cosine_similarity(model_rand)

    # a pretrained wav2vec2 model has learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states > 0.5
    # a random wav2vec2 model has not learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states is very likely < 0.1
    self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
def compute_contrastive_loss(
    quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
):
    """Return the summed contrastive loss over masked time steps.

    At every time step, candidate 0 is the true quantized feature and candidates
    ``1..num_negatives`` are sampled distractors. The logits are the cosine similarities between
    ``transformer_features`` and each candidate, divided by ``logits_temp``. A softmax
    cross-entropy with label 0 is taken and kept only where ``mask_time_indices`` is set.

    Args:
        quantized_features: rank-3 array unpacked as (batch, time, hidden).
        transformer_features: predicted features; presumably (batch, time, hidden) so they
            broadcast against the candidate stack -- NOTE(review): confirm with caller.
        negative_indices: indices into the flattened (batch * time) axis of
            ``quantized_features``; must reshape to (batch, time, num_negatives).
        mask_time_indices: rank-2 mask (transposed with ``(1, 0)`` below); presumably
            (batch, time) with 1/True at masked steps.
        logits_temp: temperature dividing the cosine-similarity logits.
        num_negatives: number of distractors sampled per time step.

    Returns:
        Scalar sum of the masked per-step cross-entropy values.
    """
    batch_size, sequence_length, hidden_size = quantized_features.shape

    # Gather the negatives from the flattened (batch * time) rows, then put the
    # negatives axis first: (num_negatives, batch, time, hidden).
    flat_features = quantized_features.reshape(-1, hidden_size)
    gathered = flat_features[negative_indices.reshape(-1)]
    negatives = gathered.reshape(batch_size, sequence_length, num_negatives, hidden_size)
    negatives = negatives.transpose(2, 0, 1, 3)

    # Stack positive (index 0) in front of the negatives: (num_negatives + 1, batch, time, hidden).
    candidates = jnp.concatenate([quantized_features[None, :], negatives], axis=0)
    logits = optax.cosine_similarity(transformer_features, candidates) / logits_temp

    # A sampled negative that equals the positive exactly must not act as a distractor;
    # the positive slot itself is never masked out.
    duplicate_of_positive = (quantized_features == negatives).all(-1)
    positive_slot = jnp.full((1,) + logits.shape[1:], False)
    duplicate_mask = jnp.concatenate([positive_slot, duplicate_of_positive], axis=0)
    logits = jnp.where(duplicate_mask, -1e9, logits)

    num_candidates = logits.shape[0]
    # (candidates, batch, time) -> (time * batch, candidates): rows are ordered time-major.
    predictions = logits.transpose(2, 1, 0).reshape(-1, num_candidates)
    # Label 0 at masked steps, -100 elsewhere, flattened in the same time-major order.
    targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
    target_mask = jnp.where(targets >= 0, 1.0, 0.0)

    per_step_loss = optax.softmax_cross_entropy(predictions, onehot(targets, predictions.shape[-1]))
    return (per_step_loss * target_mask).sum()
def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8): # compute cosine similarity of projected and projected_quantized states cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon) loss = cosine_sim.sum() return loss
def compute_loss(
    params, input_values, attention_mask, freeze_feature_encoder: bool = False, epsilon: float = 1e-8
):
    """Run ``model`` with ``params`` and return ``(loss, outputs_as_tuple)``.

    The loss is the sum of cosine similarities between the model's ``projected_states`` and
    ``projected_quantized_states`` (``epsilon`` is forwarded to ``optax.cosine_similarity``).
    The second element is ``outputs.to_tuple()``; this (value, aux) pair presumably feeds a
    ``has_aux`` gradient transform -- NOTE(review): confirm with the caller.

    NOTE(review): ``model`` is a free name resolved from the enclosing scope, not a parameter.
    """
    model_outputs = model(
        input_values,
        attention_mask=attention_mask,
        freeze_feature_encoder=freeze_feature_encoder,
        params=params,
    )
    similarity = optax.cosine_similarity(
        model_outputs.projected_states,
        model_outputs.projected_quantized_states,
        epsilon=epsilon,
    )
    return similarity.sum(), model_outputs.to_tuple()