Esempio n. 1
0
    def test_inference_pretrained(self):
        """Check that pretrained weights predict quantized states better than random ones.

        The same masked inputs are run through the pretrained checkpoint and
        through a model built from the same config without loading weights.
        For each, the cosine similarity between the projected states and the
        projected quantized states is taken at the masked positions; the
        pretrained mean must exceed five times the random-model mean.
        """
        checkpoint = "facebook/wav2vec2-large-lv60"
        # One epsilon for both similarity computations, so that a zero-norm
        # vector in either model cannot turn the comparison into NaN.
        epsilon = 1e-8

        model = FlaxWav2Vec2ForPreTraining.from_pretrained(checkpoint, from_pt=True)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint, return_attention_mask=True)
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)
        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict.attention_mask

        # (batch, number of feature frames produced for the padded input length)
        features_shape = (
            input_values.shape[0],
            model._get_feat_extract_output_lengths(np.array(input_values.shape[1])),
        )

        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            min_masks=2,
        )

        outputs = model(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity
        cosine_sim = optax.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked = cosine_sim[mask_time_indices]

        # ... now compare to randomly initialized model

        config = Wav2Vec2Config.from_pretrained(checkpoint)
        model_rand = FlaxWav2Vec2ForPreTraining(config)

        outputs_rand = model_rand(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity (same epsilon as for the pretrained model)
        cosine_sim_rand = optax.cosine_similarity(
            outputs_rand.projected_states, outputs_rand.projected_quantized_states, epsilon=epsilon
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

        # a pretrained wav2vec2 model has learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states > 0.5
        # a random wav2vec2 model has not learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states is very likely < 0.1
        # assertGreater reports both operands when the check fails.
        self.assertGreater(
            cosine_sim_masked.mean().item(),
            5 * cosine_sim_masked_rand.mean().item(),
        )
Esempio n. 2
0
def compute_contrastive_loss(
    quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
):
    """Return the summed contrastive cross-entropy loss over the selected time steps.

    For each position, the true quantized vector and ``num_negatives`` sampled
    quantized vectors are scored against ``transformer_features`` by cosine
    similarity divided by ``logits_temp``. The true vector is stacked first, so
    it is candidate 0 in the cross-entropy.

    Args:
        quantized_features: 3-D array unpacked as (batch, sequence, hidden).
        transformer_features: array scored against the candidates; presumably
            the same shape as ``quantized_features`` -- TODO confirm with caller.
        negative_indices: row indices into ``quantized_features`` flattened to
            (batch * sequence, hidden); must hold
            batch * sequence * num_negatives entries in total.
        mask_time_indices: 2-D (batch, sequence) array; a position contributes
            when ``(1 - mask) * -100 >= 0``, i.e. where a 0/1 mask equals 1.
        logits_temp: divisor applied to the cosine similarities.
        num_negatives: number of sampled negatives per position.

    Returns:
        The loss summed (not averaged) over all contributing positions.
    """
    num_batches, num_frames, feat_dim = quantized_features.shape

    # Gather the sampled negatives from the flattened feature table, then move
    # the "negative" axis to the front: (num_negatives, batch, sequence, hidden).
    flat_quantized = quantized_features.reshape(-1, feat_dim)
    flat_negative_ids = negative_indices.reshape(-1)
    gathered = flat_quantized[flat_negative_ids]
    negatives = gathered.reshape(num_batches, num_frames, num_negatives, feat_dim).transpose(2, 0, 1, 3)

    # Candidate 0 is the true quantized vector; candidates 1.. are the negatives.
    candidates = jnp.concatenate([quantized_features[None, :], negatives], axis=0)
    logits = optax.cosine_similarity(transformer_features, candidates)
    logits = logits / logits_temp

    # A sampled negative that is identical to the positive must not be
    # penalised; the leading all-False slice protects candidate 0 itself.
    negative_equals_positive = (quantized_features == negatives).all(-1)
    keep_positive = jnp.full((1,) + logits.shape[1:], False)
    suppressed = jnp.concatenate([keep_positive, negative_equals_positive], axis=0)
    logits = jnp.where(suppressed, -1e9, logits)

    # Flatten sequence-major so that predictions and targets line up row by row.
    num_candidates = logits.shape[0]
    predictions = logits.transpose(2, 1, 0).reshape(-1, num_candidates)
    targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()

    # Negative targets mark positions that are excluded from the loss.
    target_mask = jnp.where(targets >= 0, 1.0, 0.0)
    one_hot_targets = onehot(targets, predictions.shape[-1])
    per_position_loss = optax.softmax_cross_entropy(predictions, one_hot_targets) * target_mask

    return per_position_loss.sum()
Esempio n. 3
0
 def compute_loss(projected_states,
                  projected_quantized_states,
                  epsilon=1e-8):
     """Return the sum of cosine similarities between the two state arrays.

     Args:
         projected_states: first argument to ``optax.cosine_similarity``.
         projected_quantized_states: second argument to
             ``optax.cosine_similarity``.
         epsilon: forwarded unchanged to ``optax.cosine_similarity``.

     Returns:
         The similarities reduced with ``.sum()``.
     """
     similarities = optax.cosine_similarity(
         projected_states, projected_quantized_states, epsilon=epsilon
     )
     return similarities.sum()
Esempio n. 4
0
 def compute_loss(
     params, input_values, attention_mask, freeze_feature_encoder: bool = False, epsilon: float = 1e-8
 ):
     """Run the model and return a summed cosine-similarity loss plus its outputs.

     ``model`` is not a parameter; it is resolved from the enclosing scope.

     Args:
         params: passed to the model call as ``params``.
         input_values: positional input to the model call.
         attention_mask: passed to the model call as ``attention_mask``.
         freeze_feature_encoder: passed to the model call unchanged.
         epsilon: forwarded to ``optax.cosine_similarity``.

     Returns:
         A pair ``(loss, outputs.to_tuple())`` where ``loss`` is the sum of the
         cosine similarities between ``outputs.projected_states`` and
         ``outputs.projected_quantized_states``.
     """
     forward_kwargs = {
         "attention_mask": attention_mask,
         "freeze_feature_encoder": freeze_feature_encoder,
         "params": params,
     }
     outputs = model(input_values, **forward_kwargs)

     # Similarity between the projected states and their quantized counterparts.
     similarities = optax.cosine_similarity(
         outputs.projected_states,
         outputs.projected_quantized_states,
         epsilon=epsilon,
     )
     return similarities.sum(), outputs.to_tuple()