def test_real_model_save_load_from_pretrained(self):
    """Round-trip a pretrained model through save_pretrained/from_pretrained
    and assert the outputs are numerically unchanged (NaNs zeroed out,
    max absolute difference <= 1e-5).
    """
    model_2 = self.get_pretrained_model()
    model_2.to(torch_device)
    input_name, inputs = self.get_inputs()
    # Decoder token ids must be drawn from the *decoder* vocabulary.
    # The previous code sampled from `config.encoder.vocab_size`, which can
    # produce out-of-range ids whenever the encoder vocab is larger than
    # the decoder vocab.
    decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
    attention_mask = ids_tensor([13, 5], vocab_size=2)
    with torch.no_grad():
        outputs = model_2(
            **{input_name: inputs},
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
        )
        out_2 = outputs[0].cpu().numpy()
        # Zero NaNs so the max-diff comparison below is well-defined.
        out_2[np.isnan(out_2)] = 0

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model_2.save_pretrained(tmp_dirname)
            model_1 = SpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
            model_1.to(torch_device)

            after_outputs = model_1(
                **{input_name: inputs},
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
            )
            out_1 = after_outputs[0].cpu().numpy()
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)
    def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
        """Check that a PyTorch model and its Flax counterpart produce
        equivalent outputs, including after cross-framework checkpoint
        round-trips (PT -> Flax and Flax -> PT).

        Args:
            pt_model: PyTorch ``SpeechEncoderDecoderModel``.
            fx_model: ``FlaxSpeechEncoderDecoderModel`` sharing its weights.
            inputs_dict: dict of array-like model inputs fed to both models.
        """
        pt_model.to(torch_device)
        pt_model.eval()

        # Mirror the Flax inputs as PyTorch tensors *on the test device*.
        # The previous code built them on CPU, which crashes on GPU runs
        # because pt_model has already been moved to `torch_device`.
        pt_inputs = {
            k: torch.tensor(v.tolist(), device=torch_device)
            for k, v in inputs_dict.items()
        }

        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs)
        pt_logits = pt_outputs.logits
        pt_outputs = pt_outputs.to_tuple()

        fx_outputs = fx_model(**inputs_dict)
        fx_logits = fx_outputs.logits
        fx_outputs = fx_outputs.to_tuple()

        self.assertEqual(len(fx_outputs), len(pt_outputs),
                         "Output lengths differ between Flax and PyTorch")
        # .cpu() is a no-op on CPU tensors but required before .numpy()
        # when the model ran on GPU.
        self.assert_almost_equals(fx_logits, pt_logits.cpu().numpy(), 4e-2)

        # PT -> Flax
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(
                tmpdirname, from_pt=True)

        fx_outputs_loaded = fx_model_loaded(**inputs_dict)
        fx_logits_loaded = fx_outputs_loaded.logits
        fx_outputs_loaded = fx_outputs_loaded.to_tuple()

        self.assertEqual(len(fx_outputs_loaded), len(pt_outputs),
                         "Output lengths differ between Flax and PyTorch")
        self.assert_almost_equals(fx_logits_loaded, pt_logits.cpu().numpy(), 4e-2)

        # Flax -> PT
        with tempfile.TemporaryDirectory() as tmpdirname:
            fx_model.save_pretrained(tmpdirname)
            pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(
                tmpdirname, from_flax=True)

        pt_model_loaded.to(torch_device)
        pt_model_loaded.eval()

        with torch.no_grad():
            pt_outputs_loaded = pt_model_loaded(**pt_inputs)
        pt_logits_loaded = pt_outputs_loaded.logits
        pt_outputs_loaded = pt_outputs_loaded.to_tuple()

        self.assertEqual(len(fx_outputs), len(pt_outputs_loaded),
                         "Output lengths differ between Flax and PyTorch")
        self.assert_almost_equals(fx_logits, pt_logits_loaded.cpu().numpy(), 4e-2)
# (removed extraction artifact: "Ejemplo n.º 3" / "0" — page markers from the
#  code-example site this snippet was scraped from, not part of the source)
    def check_save_and_load(self,
                            config,
                            attention_mask,
                            decoder_config,
                            decoder_input_ids,
                            decoder_attention_mask,
                            input_values=None,
                            input_features=None,
                            **kwargs):
        """Verify that saving and reloading an encoder-decoder model leaves
        its forward outputs unchanged (NaNs zeroed, max abs diff <= 1e-5).
        """
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model,
                                                  decoder=decoder_model)
        enc_dec_model.to(torch_device)
        enc_dec_model.eval()

        # Same keyword arguments for both the pre-save and post-load passes.
        model_kwargs = dict(
            input_values=input_values,
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        with torch.no_grad():
            reference = enc_dec_model(**model_kwargs)[0].cpu().numpy()
            # Zero out NaNs so the max-diff comparison is well-defined.
            reference[np.isnan(reference)] = 0

            with tempfile.TemporaryDirectory() as save_dir:
                enc_dec_model.save_pretrained(save_dir)
                enc_dec_model = SpeechEncoderDecoderModel.from_pretrained(
                    save_dir)
                enc_dec_model.to(torch_device)

                reloaded = enc_dec_model(**model_kwargs)[0].cpu().numpy()
                reloaded[np.isnan(reloaded)] = 0
                max_diff = np.amax(np.abs(reloaded - reference))
                self.assertLessEqual(max_diff, 1e-5)
    def test_real_model_save_load_from_pretrained(self):
        """Save a real pretrained model to disk, reload it, and check the two
        models produce numerically identical outputs (NaNs zeroed out,
        max absolute difference <= 1e-5).
        """
        original_model, inputs = self.get_pretrained_model_and_inputs()
        original_model.to(torch_device)

        with torch.no_grad():
            expected = original_model(**inputs)[0].cpu().numpy()
            # Zero out NaNs so the max-diff comparison is well-defined.
            expected[np.isnan(expected)] = 0

            with tempfile.TemporaryDirectory() as tmp_dirname:
                original_model.save_pretrained(tmp_dirname)
                restored_model = SpeechEncoderDecoderModel.from_pretrained(
                    tmp_dirname)
                restored_model.to(torch_device)

                actual = restored_model(**inputs)[0].cpu().numpy()
                actual[np.isnan(actual)] = 0
                max_diff = np.amax(np.abs(actual - expected))
                self.assertLessEqual(max_diff, 1e-5)
    def test_flaxwav2vec2bart_pt_flax_equivalence(self):
        """End-to-end PT/Flax equivalence check for the real
        wav2vec2-2-bart-large checkpoint, including PT -> Flax and
        Flax -> PT checkpoint round-trips.
        """
        pt_model = SpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
        fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
            "patrickvonplaten/wav2vec2-2-bart-large", from_pt=True
        )

        pt_model.to(torch_device)
        pt_model.eval()

        # prepare inputs
        batch_size = 13
        # NOTE(review): floats_tensor's second positional arg is a scale;
        # passing the encoder vocab_size here looks unintentional — confirm.
        input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
        attention_mask = random_attention_mask([batch_size, 512])
        decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
        decoder_attention_mask = random_attention_mask([batch_size, 4])
        inputs_dict = {
            "inputs": input_values,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }

        # PyTorch mirror of the Flax inputs, placed on the same device as
        # pt_model. The previous code built them on CPU, which crashes on
        # GPU runs since pt_model was moved to `torch_device` above.
        pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in inputs_dict.items()}

        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs)
        pt_logits = pt_outputs.logits
        pt_outputs = pt_outputs.to_tuple()

        fx_outputs = fx_model(**inputs_dict)
        fx_logits = fx_outputs.logits
        fx_outputs = fx_outputs.to_tuple()

        self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
        # .cpu() is a no-op on CPU tensors but required before .numpy() on GPU.
        self.assert_almost_equals(fx_logits, pt_logits.cpu().numpy(), 4e-2)

        # PT -> Flax
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

        fx_outputs_loaded = fx_model_loaded(**inputs_dict)
        fx_logits_loaded = fx_outputs_loaded.logits
        fx_outputs_loaded = fx_outputs_loaded.to_tuple()
        self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
        self.assert_almost_equals(fx_logits_loaded, pt_logits.cpu().numpy(), 4e-2)

        # Flax -> PT
        with tempfile.TemporaryDirectory() as tmpdirname:
            fx_model.save_pretrained(tmpdirname)
            pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)

        pt_model_loaded.to(torch_device)
        pt_model_loaded.eval()

        with torch.no_grad():
            pt_outputs_loaded = pt_model_loaded(**pt_inputs)
        pt_logits_loaded = pt_outputs_loaded.logits
        pt_outputs_loaded = pt_outputs_loaded.to_tuple()

        self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
        self.assert_almost_equals(fx_logits, pt_logits_loaded.cpu().numpy(), 4e-2)