예제 #1
0
    def test_inference_pretrained(self):
        """Check that pretrained weights predict masked quantized targets better than random weights.

        Loads ``facebook/wav2vec2-large-lv60`` (converted with ``from_pt=True``), masks
        time steps of two speech samples and computes the cosine similarity between
        ``projected_states`` and ``projected_quantized_states`` on the masked positions.
        The same quantity is computed for a randomly initialised model built from the
        same config and the same mask, and the pretrained mean must exceed five times
        the random mean.
        """
        model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-large-lv60", return_attention_mask=True
        )
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)

        # (batch size, number of frames produced from the padded input length)
        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
        )

        # one mask shared by both models so the comparison is like-for-like;
        # min_masks=2 presumably guarantees a non-empty masked selection below
        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            min_masks=2,
        )

        # use the same epsilon on both sides: optax defaults it to 0, which would make
        # the two similarities inconsistent and allow a division by a zero norm
        epsilon = 1e-8

        outputs = model(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity
        cosine_sim = optax.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked = cosine_sim[mask_time_indices]

        # ... now compare to randomly initialized model

        config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
        model_rand = FlaxWav2Vec2ForPreTraining(config)

        outputs_rand = model_rand(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity
        cosine_sim_rand = optax.cosine_similarity(
            outputs_rand.projected_states, outputs_rand.projected_quantized_states, epsilon=epsilon
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

        # a pretrained wav2vec2 model has learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states is expected to be > 0.5
        # a random wav2vec2 model has not learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states is very likely < 0.1
        # Those absolute thresholds are not asserted; only the relative margin below is.
        self.assertGreater(
            cosine_sim_masked.mean().item(),
            5 * cosine_sim_masked_rand.mean().item(),
        )

    def test_train(self):
        """Run one training-mode forward pass and check the projected output shape.

        Masks half of the feature frames in spans of 4, calls the model with
        ``train=True`` and explicit dropout / Gumbel RNG keys, and asserts that the
        first output has shape ``(batch_size, sequence_length, proj_codevector_dim)``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]

        model = FlaxWav2Vec2ForPreTraining(config)

        # number of frames the feature encoder produces for the (padded) input length
        batch_size = input_values.shape[0]
        sequence_length = model._get_feat_extract_output_lengths(np.array(input_values.shape[1]))

        mask_prob = 0.5
        mask_length = 4
        mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        # deterministic RNG keys so the training-mode pass is reproducible
        dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0))

        output = model(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            train=True,
            dropout_rng=dropout_rng,
            gumbel_rng=gumbel_rng,
        )[0]

        # assertEqual (rather than assertTrue(a == b)) reports both shapes on failure
        self.assertEqual(output.shape, (batch_size, sequence_length, model.config.proj_codevector_dim))
예제 #3
0
    def test_freeze_feature_encoder(self):
        """Check that ``freeze_feature_encoder=True`` stops gradients into the feature encoder.

        A dummy cosine-similarity loss is differentiated with respect to the model
        *parameters*. Differentiating with respect to the model outputs would never
        reach the feature encoder, so freezing could not be observed that way.

        The test asserts, within a 1e-5 tolerance, that the loss and every
        non-``feature_extractor`` gradient are unchanged by freezing. It also asserts
        that each ``feature_extractor`` gradient is exactly zero when frozen and has
        at least one non-zero entry when unfrozen.
        """
        # local import keeps this method self-contained
        from flax.traverse_util import flatten_dict

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]

        model = FlaxWav2Vec2ForPreTraining(config)
        tolerance = 1e-5

        # dummy loss function
        def compute_loss(params, freeze_feature_encoder=False, epsilon=1e-8):
            """Sum of cosine similarities between projected and projected-quantized states."""
            outputs = model(
                input_values,
                attention_mask=attention_mask,
                freeze_feature_encoder=freeze_feature_encoder,
                params=params,
            )
            cosine_sim = optax.cosine_similarity(
                outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
            )
            return cosine_sim.sum()

        # differentiate w.r.t. the first positional argument, i.e. the parameter tree
        grad_fn = jax.value_and_grad(compute_loss)

        loss, grads = grad_fn(model.params, freeze_feature_encoder=False)
        loss_frozen, grads_frozen = grad_fn(model.params, freeze_feature_encoder=True)

        # freezing is expected to affect only the backward pass, so the loss must not change
        self.assertLessEqual(np.abs(loss - loss_frozen), tolerance)

        grads = flatten_dict(grads)
        grads_frozen = flatten_dict(grads_frozen)

        # both gradient trees must cover exactly the same parameters
        self.assertEqual(grads.keys(), grads_frozen.keys())

        num_feature_extractor_grads = 0
        for key, grad in grads.items():
            grad_frozen = grads_frozen[key]
            self.assertEqual(grad.shape, grad_frozen.shape)
            if "feature_extractor" in key:
                num_feature_extractor_grads += 1
                # frozen: no gradient may flow into the feature encoder
                self.assertTrue((grad_frozen == 0.0).all())
                # unfrozen: gradients do flow (sign-agnostic check)
                self.assertTrue((grad != 0.0).any())
            else:
                # all other layers are unaffected by freezing
                self.assertLessEqual(np.amax(np.abs(grad - grad_frozen)), tolerance)

        # make sure the feature-encoder checks above actually ran
        self.assertGreater(num_feature_extractor_grads, 0)

    def test_freeze_feature_encoder(self):
        """Check that ``freeze_feature_encoder=True`` only blocks gradients into the feature encoder.

        A dummy cosine-similarity loss is differentiated with respect to the model
        parameters twice, once with the feature encoder frozen and once without.
        The test asserts that:

        * the forward outputs and the loss are exactly identical;
        * every ``feature_extractor`` gradient is exactly zero when frozen and has at
          least one non-zero entry when unfrozen;
        * all other gradients are exactly identical.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]

        model = FlaxWav2Vec2ForPreTraining(config)
        params = model.params

        # dummy loss function
        def compute_loss(
            params,
            input_values,
            attention_mask,
            freeze_feature_encoder: bool = False,
            epsilon: float = 1e-8,
        ):
            outputs = model(
                input_values,
                attention_mask=attention_mask,
                freeze_feature_encoder=freeze_feature_encoder,
                params=params,
            )
            # compute cosine similarity of projected and projected_quantized states
            cosine_sim = optax.cosine_similarity(
                outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
            )
            loss = cosine_sim.sum()
            # outputs are returned as auxiliary data so they can be compared too
            return loss, outputs.to_tuple()

        # transform the loss function to get the gradients w.r.t. `params`
        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

        # compute loss, outputs and gradients for unfrozen model
        (loss, outputs), grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False)

        # compare to loss, outputs and gradients for frozen model
        (loss_frozen, outputs_frozen), grads_frozen = grad_fn(
            params, input_values, attention_mask, freeze_feature_encoder=True
        )

        # ensure that the outputs and losses remain precisely equal
        # (length check first: zip would silently truncate a mismatch)
        self.assertEqual(len(outputs), len(outputs_frozen))
        for output, output_frozen in zip(outputs, outputs_frozen):
            self.assertTrue((output == output_frozen).all())
        self.assertEqual(loss, loss_frozen)

        grads = flatten_dict(grads)
        grads_frozen = flatten_dict(grads_frozen)

        # ensure that the dicts of gradients contain the same keys
        self.assertEqual(grads.keys(), grads_frozen.keys())

        # guard against a vacuous test if no feature-extractor parameters are found
        self.assertTrue(any("feature_extractor" in key for key in grads))

        # look gradients up by key instead of zipping two separately built tuples
        for key, grad in grads.items():
            grad_frozen = grads_frozen[key]
            if "feature_extractor" in key:
                # frozen: precisely zero; unfrozen: at least one non-zero entry
                # (gradients may be negative, so `> 0.0` would be the wrong test)
                self.assertTrue((grad_frozen == 0.0).all())
                self.assertTrue((grad != 0.0).any())
            else:
                # all layers other than the frozen 'feature_extractor' remain equal
                self.assertTrue((grad == grad_frozen).all())