def check_encoder_decoder_model_generate(
    self, config, decoder_config, input_values=None, input_features=None, **kwargs
):
    """Check that ``generate`` returns a sequence of the configured length.

    A ``SpeechEncoderDecoderModel`` is assembled from the encoder/decoder pair
    returned by ``self.get_encoder_decoder_model`` and moved to
    ``torch_device``. Generation is then run on ``input_features`` when it is
    given, otherwise on ``input_values``.

    Args:
        config: Encoder configuration forwarded to
            ``self.get_encoder_decoder_model``.
        decoder_config: Decoder configuration; its ``max_length`` is the
            expected length of the generated output.
        input_values: Model input used when ``input_features`` is ``None``.
        input_features: Model input that takes precedence when provided.
        **kwargs: Accepted and ignored.
    """
    encoder, decoder = self.get_encoder_decoder_model(config, decoder_config)
    model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.to(torch_device)

    # Disable EOS everywhere it is configured so that generation cannot stop
    # early and the output length is deterministic.
    model.config.eos_token_id = None
    nested_decoder_config = getattr(model.config, "decoder", None)
    if nested_decoder_config is not None and hasattr(nested_decoder_config, "eos_token_id"):
        model.config.decoder.eos_token_id = None

    inputs = input_features
    if inputs is None:
        inputs = input_values

    # Bert does not have a bos token id, so the pad token id starts decoding.
    start_token_id = model.config.decoder.pad_token_id
    generated_output = model.generate(inputs, decoder_start_token_id=start_token_id)

    self.assertEqual(
        generated_output.shape,
        (inputs.shape[0], decoder_config.max_length),
    )
def check_encoder_decoder_model_generate(
    self, config, decoder_config, input_values=None, input_features=None, **kwargs
):
    """Check that ``generate`` produces output of length ``max_length``.

    Builds a ``SpeechEncoderDecoderModel`` from the encoder and decoder
    returned by ``self.get_encoder_decoder_model``, moves it to
    ``torch_device``, runs ``generate`` and asserts that the result has shape
    ``(inputs.shape[0], decoder_config.max_length)``.

    Args:
        config: Encoder configuration passed to
            ``self.get_encoder_decoder_model``.
        decoder_config: Decoder configuration; ``max_length`` is read from it
            for the expected output shape.
        input_values: Input used when ``input_features`` is ``None``.
        input_features: Input used in preference to ``input_values``.
        **kwargs: Accepted and ignored.
    """
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    enc_dec_model.to(torch_device)

    # Make sure the EOS token is set to None to prevent early stopping of
    # generation; otherwise the output may be shorter than max_length and the
    # shape assertion below becomes flaky.
    enc_dec_model.config.eos_token_id = None
    if hasattr(enc_dec_model.config, "decoder") and hasattr(
        enc_dec_model.config.decoder, "eos_token_id"
    ):
        enc_dec_model.config.decoder.eos_token_id = None

    inputs = input_values if input_features is None else input_features

    # Bert does not have a bos token id, so use pad_token_id instead
    generated_output = enc_dec_model.generate(
        inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
    )
    self.assertEqual(generated_output.shape, (inputs.shape[0], decoder_config.max_length))