def check_encoder_decoder_model_standalone(self, config, inputs_dict):
    """Check that the encoder and decoder still match the full model when run standalone.

    A ``Speech2TextModel`` is built from ``config`` and run on ``inputs_dict``
    to obtain reference hidden states. Its encoder and decoder are then each
    saved to a temporary directory, reloaded through their own classes, and
    run on their own; the largest absolute deviation from the reference must
    stay below ``1e-3``.

    Args:
        config: Configuration passed to ``Speech2TextModel``.
        inputs_dict: Keyword arguments for the full model. The keys
            ``"input_features"``, ``"attention_mask"``, ``"decoder_input_ids"``
            and ``"decoder_attention_mask"`` are read explicitly.
    """
    tolerance = 1e-3

    # Reference run through the complete encoder-decoder model.
    full_model = Speech2TextModel(config=config).to(torch_device).eval()
    reference = full_model(**inputs_dict)
    reference_encoder_states = reference.encoder_last_hidden_state
    reference_decoder_states = reference.last_hidden_state

    # Encoder: save, reload through its own class, then re-run on the raw features.
    with tempfile.TemporaryDirectory() as encoder_dir:
        full_model.get_encoder().save_pretrained(encoder_dir)
        reloaded_encoder = Speech2TextEncoder.from_pretrained(encoder_dir).to(torch_device)

    standalone_encoder_states = reloaded_encoder(
        inputs_dict["input_features"],
        attention_mask=inputs_dict["attention_mask"],
    )[0]
    encoder_gap = (standalone_encoder_states - reference_encoder_states).abs().max().item()
    self.parent.assertTrue(encoder_gap < tolerance)

    # Decoder: same round trip, fed with the reference encoder states.
    with tempfile.TemporaryDirectory() as decoder_dir:
        full_model.get_decoder().save_pretrained(decoder_dir)
        reloaded_decoder = Speech2TextDecoder.from_pretrained(decoder_dir).to(torch_device)

    # The cross-attention mask is derived by the reloaded encoder from the
    # encoder output length (dim 1) and the input attention mask.
    # NOTE(review): presumably this accounts for the encoder changing the
    # sequence length -- confirm against the encoder implementation.
    encoder_sequence_length = reference_encoder_states.shape[1]
    cross_attention_mask = reloaded_encoder._get_feature_vector_attention_mask(
        encoder_sequence_length, inputs_dict["attention_mask"]
    )

    standalone_decoder_states = reloaded_decoder(
        input_ids=inputs_dict["decoder_input_ids"],
        attention_mask=inputs_dict["decoder_attention_mask"],
        encoder_hidden_states=reference_encoder_states,
        encoder_attention_mask=cross_attention_mask,
    )[0]
    decoder_gap = (standalone_decoder_states - reference_decoder_states).abs().max().item()
    self.parent.assertTrue(decoder_gap < tolerance)
def get_encoder_decoder_model(self, config, decoder_config):
    """Build an encoder/decoder pair with ``.eval()`` applied to each.

    Args:
        config: Configuration passed to ``Speech2TextEncoder``.
        decoder_config: Configuration passed to ``BertLMHeadModel``.

    Returns:
        A ``(encoder_model, decoder_model)`` tuple; each element is the
        result of calling ``.eval()`` on the freshly constructed model.
    """
    # The encoder is constructed first, matching left-to-right tuple evaluation.
    return (
        Speech2TextEncoder(config).eval(),
        BertLMHeadModel(decoder_config).eval(),
    )