def check_encoder_decoder_model(
    self,
    config,
    attention_mask,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    input_values=None,
    input_features=None,
    **kwargs,
):
    """Check logits shape of a composed speech encoder-decoder model.

    Builds a ``SpeechEncoderDecoderModel`` from the encoder/decoder pair
    returned by ``self.get_encoder_decoder_model(config, decoder_config)``,
    asserts the decoder-related flags on the composed config, and then runs
    two forward passes:

    1. from ``input_values`` / ``input_features`` with
       ``output_hidden_states=True``;
    2. from ``encoder_outputs`` wrapping the last encoder hidden state
       produced by the first pass.

    Both passes must yield logits whose shape equals
    ``decoder_input_ids.shape + (decoder_config.vocab_size,)``.

    Extra keyword arguments are accepted and ignored.
    """
    encoder, decoder = self.get_encoder_decoder_model(config, decoder_config)
    model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)

    # The composed config must mark the decoder as a cross-attending decoder.
    composed_config = model.config
    self.assertTrue(composed_config.decoder.is_decoder)
    self.assertTrue(composed_config.decoder.add_cross_attention)
    self.assertTrue(composed_config.is_encoder_decoder)

    model.to(torch_device)

    # Keyword arguments common to both forward passes.
    shared_kwargs = {
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
    }

    # Pass 1: run the encoder on the raw speech inputs.
    full_pass = model(
        input_values=input_values,
        input_features=input_features,
        output_hidden_states=True,
        **shared_kwargs,
    )
    expected_logits_shape = decoder_input_ids.shape + (decoder_config.vocab_size,)
    self.assertEqual(full_pass["logits"].shape, expected_logits_shape)

    # Pass 2: feed the last encoder hidden state back in as `encoder_outputs`
    # instead of passing the speech inputs again.
    last_encoder_state = full_pass.encoder_hidden_states[-1]
    precomputed_encoder = BaseModelOutput(last_hidden_state=last_encoder_state)
    decoder_only_pass = model(encoder_outputs=precomputed_encoder, **shared_kwargs)
    self.assertEqual(decoder_only_pass["logits"].shape, expected_logits_shape)
def check_encoder_decoder_model_with_inputs(
    self,
    config,
    attention_mask,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    input_values=None,
    input_features=None,
    **kwargs,
):
    """Check that the model accepts its speech input positionally and as ``inputs=``.

    Uses ``input_features`` when it is provided, otherwise ``input_values``.
    The composed ``SpeechEncoderDecoderModel`` is called twice with that
    tensor: once as the first positional argument and once through the
    ``inputs`` keyword. Each call must return logits whose shape equals
    ``decoder_input_ids.shape + (decoder_config.vocab_size,)``.

    Extra keyword arguments are accepted and ignored.
    """
    inputs = input_features if input_features is not None else input_values

    encoder, decoder = self.get_encoder_decoder_model(config, decoder_config)
    model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.to(torch_device)

    # Keyword arguments common to both call forms.
    shared_kwargs = {
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
        "output_hidden_states": True,
    }

    # Call form 1: speech input as the first positional argument.
    positional_result = model(inputs, **shared_kwargs)
    expected_logits_shape = decoder_input_ids.shape + (decoder_config.vocab_size,)
    self.assertEqual(positional_result["logits"].shape, expected_logits_shape)

    # Call form 2: the same speech input through the `inputs` keyword.
    keyword_result = model(inputs=inputs, **shared_kwargs)
    self.assertEqual(keyword_result["logits"].shape, expected_logits_shape)