def get_encoder_decoder_model(self, config, decoder_config):
    """Instantiate the encoder/decoder pair for the Wav2Vec2 + BART combination.

    Args:
        config: Configuration handed to ``FlaxWav2Vec2Model``.
        decoder_config: Configuration handed to ``FlaxBartForCausalLM``.

    Returns:
        A 2-tuple ``(encoder, decoder)``. The encoder is constructed first,
        then the decoder (tuple elements are evaluated left to right).
    """
    return FlaxWav2Vec2Model(config), FlaxBartForCausalLM(decoder_config)
 def get_encoder_decoder_model(self, config, decoder_config):
     """Instantiate the encoder/decoder pair for the Wav2Vec2 + GPT-2 combination.

     Args:
         config: Configuration handed to ``FlaxWav2Vec2Model``.
         decoder_config: Configuration handed to ``FlaxGPT2LMHeadModel``.

     Returns:
         A 2-tuple ``(encoder, decoder)``. The encoder is constructed first,
         then the decoder (tuple elements are evaluated left to right).
     """
     return FlaxWav2Vec2Model(config), FlaxGPT2LMHeadModel(decoder_config)