import tempfile

from transformers import Speech2TextModel
from transformers.models.speech_to_text.modeling_speech_to_text import (
    Speech2TextDecoder,
    Speech2TextEncoder,
)
from transformers.testing_utils import torch_device


def check_encoder_decoder_model_standalone(self, config, inputs_dict):
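        """Check that the encoder and decoder, saved and reloaded as standalone models,
        reproduce the hidden states computed by the full Speech2TextModel forward pass."""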
        model = Speech2TextModel(config=config).to(torch_device).eval()
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

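        # Round-trip the encoder through save_pretrained / from_pretrained so it is
        # reloaded as a standalone Speech2TextEncoder.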
        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = Speech2TextEncoder.from_pretrained(tmpdirname).to(torch_device)

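        # The reloaded encoder should reproduce the reference encoder hidden states.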
        encoder_last_hidden_state_2 = encoder(
            inputs_dict["input_features"], attention_mask=inputs_dict["attention_mask"]
        )[0]

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

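        # Round-trip the decoder the same way, reloading it as a standalone Speech2TextDecoder.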
        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = Speech2TextDecoder.from_pretrained(tmpdirname).to(torch_device)

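        # The encoder's convolutional subsampling shortens the time dimension, so the input
        # attention mask is downsampled to match the length of the encoder hidden states.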
        encoder_attention_mask = encoder._get_feature_vector_attention_mask(
            encoder_last_hidden_state.shape[1], inputs_dict["attention_mask"]
        )

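        # Run the standalone decoder against the reference encoder states; its output should
        # match the decoder hidden states of the full model.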
        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
            encoder_attention_mask=encoder_attention_mask,
        )[0]

        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
Example #2
from transformers import BertLMHeadModel
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextEncoder


def get_encoder_decoder_model(self, config, decoder_config):
    encoder_model = Speech2TextEncoder(config).eval()
    decoder_model = BertLMHeadModel(decoder_config).eval()
    return encoder_model, decoder_model
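
For context, a minimal sketch of how a helper like get_encoder_decoder_model is typically consumed. The SpeechEncoderDecoderModel wrapper and the tester, config, and decoder_config names below are illustrative assumptions, not part of the snippets above.

from transformers import SpeechEncoderDecoderModel

# Pair the standalone speech encoder with the BERT decoder in a single
# encoder-decoder model; assumes decoder_config was created with
# is_decoder=True and add_cross_attention=True so the decoder can attend
# to the encoder hidden states.
encoder_model, decoder_model = tester.get_encoder_decoder_model(config, decoder_config)
model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.eval()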