def create_and_check_batch_inference(self, config, input_values, *args):
        """Compare padded batch inference against per-sample inference.

        The first three rows of ``input_values`` are cut to 1/4, 1/2 and the
        full sequence length by zeroing everything past each row's length.
        The padded batch is run through the model together with the matching
        attention mask, then every row is run on its own without padding.
        Each single-row output must match the corresponding leading slice of
        the batched output within ``atol=1e-3``.

        Args:
            config: Model configuration; its ``layerdrop`` attribute is
                overwritten with ``0.0`` in place.
            input_values: Batch of inputs; only the first three rows are used.
            *args: Accepted but not used by this check.
        """
        # Known limitation: this check does not pass for models making use of
        # `group_norm`, see https://github.com/pytorch/fairseq/issues/3227
        config.layerdrop = 0.0
        model = TFWav2Vec2Model(config)

        samples = input_values[:3]
        all_ones = tf.ones_like(samples)

        # One target length per row: a quarter, a half and the whole sequence.
        full_length = samples.shape[-1]
        sample_lengths = tf.constant(
            [full_length // divisor for divisor in [4, 2, 1]])
        keep_mask = tf.sequence_mask(sample_lengths, dtype=tf.float32)

        # Positions beyond each row's length become padding (zeros) in both
        # the inputs and the attention mask.
        padded_inputs = samples * keep_mask
        padded_attention = all_ones * keep_mask

        batched_result = model(padded_inputs,
                               attention_mask=padded_attention,
                               training=False)
        batched_hidden = batched_result.last_hidden_state

        for row in range(padded_inputs.shape[0]):
            # Run the un-padded row on its own, without an attention mask.
            single_input = padded_inputs[row:row + 1, :sample_lengths[row]]
            single_hidden = model(single_input,
                                  training=False).last_hidden_state

            # Only the frames produced for the un-padded row are compared.
            frame_count = single_hidden.shape[1]
            batched_slice = batched_hidden[row:row + 1, :frame_count]
            outputs_match = np.allclose(single_hidden, batched_slice,
                                        atol=1e-3)
            self.parent.assertTrue(outputs_match)
 def test_model_from_pretrained(self):
     """Load the ``facebook/wav2vec2-base-960h`` checkpoint and check it is not None."""
     checkpoint_name = "facebook/wav2vec2-base-960h"
     pretrained_model = TFWav2Vec2Model.from_pretrained(checkpoint_name)
     self.assertIsNotNone(pretrained_model)
 def create_and_check_model(self, config, input_values, attention_mask):
     """Run a forward pass and verify the shape of ``last_hidden_state``.

     Args:
         config: Configuration used to build the ``TFWav2Vec2Model``.
         input_values: Inputs passed positionally to the model.
         attention_mask: Mask passed to the model as ``attention_mask``.
     """
     outputs = TFWav2Vec2Model(config)(input_values, attention_mask=attention_mask)

     # The actual shape is read before the expected one, as in the plain
     # single-call form of this assertion.
     actual_shape = outputs.last_hidden_state.shape
     expected_shape = (self.batch_size, self.output_seq_length, self.hidden_size)
     self.parent.assertEqual(actual_shape, expected_shape)