import pytest
import tensorflow as tf

from transformers import TFWav2Vec2ForCTC

# NOTE: `ids_tensor` is a random-ids helper from the transformers test
# utilities, and these checker methods are written for a model-tester object
# whose `self.parent` is the enclosing unittest.TestCase.
def check_ctc_loss(self, config, input_values, *args):
        model = TFWav2Vec2ForCTC(config)

        input_values = input_values[:3]
        attention_mask = tf.ones_like(input_values)

        # simulate three different input lengths (1/4, 1/2 and the full sequence)
        input_lengths = tf.constant(
            [input_values.shape[-1] // i for i in [4, 2, 1]])
        max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(
            input_lengths)
        # labels must fit into the shortest feature-extractor output sequence
        labels = ids_tensor(
            (input_values.shape[0], min(max_length_labels) - 1),
            model.config.vocab_size)

        length_mask = tf.sequence_mask(input_lengths, dtype=tf.float32)

        # zero out positions beyond each sample's input length so they act as padding
        input_values = input_values * length_mask
        attention_mask = attention_mask * length_mask

        # run the identical batch under both CTC loss reduction modes
        model.config.ctc_loss_reduction = "sum"
        sum_loss = model(input_values,
                         attention_mask=attention_mask,
                         labels=labels).loss

        model.config.ctc_loss_reduction = "mean"
        mean_loss = model(input_values,
                          attention_mask=attention_mask,
                          labels=labels).loss

        self.parent.assertTrue(
            abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2)
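
# The tolerance check above relies on "mean" CTC reduction being the "sum"
# reduction divided by the batch size. A minimal, self-contained sketch of
# that identity, using made-up per-example losses rather than the model above:
def _sketch_sum_vs_mean_reduction():
    per_example_loss = tf.constant([1.25, 2.5, 3.75])
    sum_loss = tf.reduce_sum(per_example_loss)    # 7.5
    mean_loss = tf.reduce_mean(per_example_loss)  # 2.5
    batch_size = per_example_loss.shape[0]        # 3
    assert abs(batch_size * float(mean_loss) - float(sum_loss)) < 1e-2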

def check_training(self, config, input_values, *args):
        model = TFWav2Vec2ForCTC(config)

        # freeze feature encoder
        model.freeze_feature_extractor()

        input_values = input_values[:3]

        input_lengths = tf.constant(
            [input_values.shape[-1] // i for i in [4, 2, 1]])
        max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(
            input_lengths)
        # size labels against the longest feature-extractor output; the two
        # missing positions are re-padded with -100 below
        labels = ids_tensor(
            (input_values.shape[0], max(max_length_labels) - 2),
            model.config.vocab_size)

        length_mask = tf.sequence_mask(input_lengths, dtype=tf.float32)

        input_values = input_values * length_mask

        # pad the labels back out to the longest label length; -100 marks
        # positions that the loss computation ignores
        pad_size = max(max_length_labels) - labels.shape[1]
        labels = tf.pad(labels, ((0, 0), (0, pad_size)), constant_values=-100)

        loss = model(input_values, labels=labels, training=True).loss

        self.parent.assertFalse(tf.math.is_inf(loss))
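
# The check above pads labels with -100, the conventional "ignore" value for
# label padding in the transformers loss code. A small, self-contained sketch
# of how such padding can be masked back out (illustrative only, not the
# model's internal implementation):
def _sketch_ignore_label_padding():
    labels = tf.constant([[5, 7, -100, -100], [3, 2, 9, -100]])
    label_mask = tf.cast(tf.math.not_equal(labels, -100), tf.int32)
    label_lengths = tf.reduce_sum(label_mask, axis=-1)  # [2, 3]
    return label_lengths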

def check_labels_out_of_vocab(self, config, input_values, *args):
        model = TFWav2Vec2ForCTC(config)
        input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
        max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths)
        # deliberately draw label ids beyond the vocabulary size
        labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size + 100)
        # the model is expected to reject out-of-vocabulary labels
        with pytest.raises(ValueError):
            model(input_values, labels=labels)
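
# For context, checker methods like these are normally driven from a
# unittest-style test case, roughly as sketched below. The tester class and
# its prepare_config_and_inputs helper are assumptions about the surrounding
# test suite rather than part of this excerpt:
#
#     class TFWav2Vec2ModelTest(unittest.TestCase):
#         def setUp(self):
#             self.model_tester = TFWav2Vec2ModelTester(self)
#
#         def test_ctc_loss_inference(self):
#             config_and_inputs = self.model_tester.prepare_config_and_inputs()
#             self.model_tester.check_ctc_loss(*config_and_inputs)
#
#         def test_labels_out_of_vocab(self):
#             config_and_inputs = self.model_tester.prepare_config_and_inputs()
#             self.model_tester.check_labels_out_of_vocab(*config_and_inputs)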