def check_encoder_decoder_model(
        self,
        config,
        attention_mask,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        input_values=None,
        input_features=None,
        **kwargs
    ):
        """Check the logits shape of a composed speech encoder-decoder model.

        An encoder/decoder pair is obtained from
        ``self.get_encoder_decoder_model`` and wrapped in a
        ``SpeechEncoderDecoderModel``. The composite config is checked for
        the decoder / cross-attention / encoder-decoder flags, then the model
        is run twice: once from the raw speech inputs and once from the
        encoder's last hidden state of the first run. Both runs must yield
        logits of shape ``decoder_input_ids.shape + (vocab_size,)``.

        Args:
            config: Encoder configuration, forwarded to
                ``self.get_encoder_decoder_model``.
            attention_mask: Passed to the model as ``attention_mask``.
            decoder_config: Decoder configuration; its ``vocab_size`` gives
                the expected size of the last logits dimension.
            decoder_input_ids: Passed to the model as ``decoder_input_ids``.
            decoder_attention_mask: Passed to the model as
                ``decoder_attention_mask``.
            input_values: Passed to the model as ``input_values`` (may be
                ``None``).
            input_features: Passed to the model as ``input_features`` (may be
                ``None``).
            **kwargs: Accepted but not used by this check.
        """
        encoder, decoder = self.get_encoder_decoder_model(config, decoder_config)
        model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)

        # The composite config must describe a cross-attending decoder.
        self.assertTrue(model.config.decoder.is_decoder)
        self.assertTrue(model.config.decoder.add_cross_attention)
        self.assertTrue(model.config.is_encoder_decoder)

        model.to(torch_device)

        # Keyword arguments common to both forward passes below.
        common_kwargs = {
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
        }

        # Pass 1: run the full model from the raw speech inputs and keep the
        # hidden states so the encoder output can be reused afterwards.
        from_inputs = model(
            input_values=input_values,
            input_features=input_features,
            output_hidden_states=True,
            **common_kwargs
        )
        logits_shape = from_inputs["logits"].shape
        expected_shape = decoder_input_ids.shape + (decoder_config.vocab_size,)
        self.assertEqual(logits_shape, expected_shape)

        # Pass 2: skip the encoder by feeding its last hidden state directly.
        last_encoder_state = from_inputs.encoder_hidden_states[-1]
        precomputed = BaseModelOutput(last_hidden_state=last_encoder_state)
        from_encoder_outputs = model(encoder_outputs=precomputed, **common_kwargs)

        self.assertEqual(from_encoder_outputs["logits"].shape, expected_shape)
Esempio n. 2
0
    def check_encoder_decoder_model_with_inputs(
        self,
        config,
        attention_mask,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        input_values=None,
        input_features=None,
        **kwargs
    ):
        """Check that the speech input can be given positionally or as ``inputs=``.

        ``input_features`` is used when it is not ``None``; otherwise
        ``input_values`` is used. A ``SpeechEncoderDecoderModel`` is built
        from the pair returned by ``self.get_encoder_decoder_model`` and
        called twice with the same arguments, first with the speech input as
        the leading positional argument and then as the ``inputs`` keyword.
        Each call must produce logits of shape
        ``decoder_input_ids.shape + (decoder_config.vocab_size,)``.

        Args:
            config: Encoder configuration, forwarded to
                ``self.get_encoder_decoder_model``.
            attention_mask: Passed to the model as ``attention_mask``.
            decoder_config: Decoder configuration; its ``vocab_size`` gives
                the expected size of the last logits dimension.
            decoder_input_ids: Passed to the model as ``decoder_input_ids``.
            decoder_attention_mask: Passed to the model as
                ``decoder_attention_mask``.
            input_values: Speech input used when ``input_features`` is
                ``None``.
            input_features: Speech input that takes precedence when given.
            **kwargs: Accepted but not used by this check.
        """
        # Pick whichever speech input was supplied; features take precedence.
        if input_features is None:
            inputs = input_values
        else:
            inputs = input_features

        encoder, decoder = self.get_encoder_decoder_model(config, decoder_config)
        model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)
        model.to(torch_device)

        # Keyword arguments shared by the positional and the keyword call.
        common_kwargs = {
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "output_hidden_states": True,
        }

        # Speech input as the first positional argument.
        positional_result = model(inputs, **common_kwargs)
        positional_shape = positional_result["logits"].shape
        expected_shape = decoder_input_ids.shape + (decoder_config.vocab_size,)
        self.assertEqual(positional_shape, expected_shape)

        # Same speech input, this time through the ``inputs`` keyword.
        keyword_result = model(inputs=inputs, **common_kwargs)
        self.assertEqual(
            keyword_result["logits"].shape,
            decoder_input_ids.shape + (decoder_config.vocab_size,),
        )