def check_save_and_load(self, config, pixel_values, encoder_hidden_states,
                            decoder_config, decoder_input_ids,
                            decoder_attention_mask, **kwargs):
        """Check that saving and reloading the model preserves its outputs.

        An encoder-decoder model is assembled from the encoder and decoder
        returned by ``self.get_encoder_decoder_model``, run once, written to a
        temporary directory with ``save_pretrained``, read back with
        ``from_pretrained`` and run again on the same inputs. The first
        output of both runs (NaNs replaced by 0) must agree within 1e-5.

        Args:
            config: encoder configuration forwarded to
                ``self.get_encoder_decoder_model``.
            pixel_values: image input passed to the model.
            encoder_hidden_states: not used by this check.
            decoder_config: decoder configuration forwarded to
                ``self.get_encoder_decoder_model``.
            decoder_input_ids: decoder token ids passed to the model.
            decoder_attention_mask: decoder attention mask passed to the model.
            **kwargs: extra keyword arguments; ignored.
        """
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            encoder_model=encoder_model, decoder_model=decoder_model)

        outputs = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )
        out_2 = np.array(outputs[0])
        out_2[np.isnan(out_2)] = 0

        with tempfile.TemporaryDirectory() as tmpdirname:
            enc_dec_model.save_pretrained(tmpdirname)
            # Keep the reloaded model and run *it*: calling the original
            # model a second time would compare the model with itself and
            # say nothing about the save/load round trip.
            reloaded_model = FlaxVisionEncoderDecoderModel.from_pretrained(
                tmpdirname)

            after_outputs = reloaded_model(
                pixel_values=pixel_values,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_1 = np.array(after_outputs[0])
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)
    def test_real_model_save_load_from_pretrained(self):
        """Round-trip the pretrained model through disk and compare outputs.

        The model from ``self.get_pretrained_model`` is run on random inputs
        with batch size 13, saved to a temporary directory, reloaded, and run
        again on the same inputs. The first output of both runs (NaNs
        replaced by 0) must agree within 1e-5.
        """

        def first_output_without_nans(model_outputs):
            # Copy the first output into a NumPy array and zero out NaNs so
            # they do not poison the max-difference computation.
            values = np.array(model_outputs[0])
            values[np.isnan(values)] = 0
            return values

        original_model = self.get_pretrained_model()
        encoder_config = original_model.config.encoder
        batch_size = 13

        pixel_values = floats_tensor([
            batch_size,
            encoder_config.num_channels,
            encoder_config.image_size,
            encoder_config.image_size,
        ])
        decoder_input_ids = ids_tensor(
            [batch_size, 1], original_model.config.decoder.vocab_size)

        reference = first_output_without_nans(
            original_model(
                pixel_values=pixel_values,
                decoder_input_ids=decoder_input_ids,
            ))

        with tempfile.TemporaryDirectory() as save_dir:
            original_model.save_pretrained(save_dir)
            reloaded_model = FlaxVisionEncoderDecoderModel.from_pretrained(
                save_dir)

            reloaded = first_output_without_nans(
                reloaded_model(
                    pixel_values=pixel_values,
                    decoder_input_ids=decoder_input_ids,
                ))
            largest_gap = np.abs(reloaded - reference).max()
            self.assertLessEqual(largest_gap, 1e-5)
    def test_inference_coco_en(self):
        """Check logits, beam-search score and caption of the COCO checkpoint.

        Loads the feature extractor, tokenizer and Flax model from
        ``ydshieh/vit-gpt2-coco-en``, runs them on the image returned by
        ``prepare_img`` and compares against recorded reference values.
        """
        checkpoint = "ydshieh/vit-gpt2-coco-en"

        feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = FlaxVisionEncoderDecoderModel.from_pretrained(checkpoint)

        image = prepare_img()
        encoded_image = feature_extractor(images=image, return_tensors="np")
        pixel_values = encoded_image.pixel_values

        # A single decoding step starting from the decoder start token.
        start_ids = np.array([[model.config.decoder_start_token_id]])
        logits = np.array(model(pixel_values, start_ids)[0])

        # The logits must cover the whole decoder vocabulary for one token.
        self.assertEqual(logits.shape,
                         (1, 1, model.config.decoder.vocab_size))

        reference_logits = np.array([
            -38.705837,
            -30.639936,
            -31.41905,
            -39.01204,
            -38.38698,
            -34.887215,
            -33.29087,
            -35.684475,
            -38.50852,
            -36.124676,
        ])
        logit_gap = np.abs(logits[0, 0, :10] - reference_logits).max()
        self.assertLessEqual(logit_gap, 1e-4)

        # Caption generation with beam search.
        generated = model.generate(pixel_values, max_length=16, num_beams=4)
        decoded = tokenizer.batch_decode(generated.sequences,
                                         skip_special_tokens=True)
        captions = [caption.strip() for caption in decoded]

        reference_scores = np.array([-0.59563464])
        score_gap = np.abs(np.array(generated.scores) -
                           reference_scores).max()
        self.assertLessEqual(score_gap, 1e-4)

        # The generated caption must match the recorded one exactly.
        self.assertEqual(
            captions, ["a cat laying on top of a couch next to another cat"])
    def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
        """Check that PyTorch and Flax models produce matching outputs.

        Three comparisons are made, each through
        ``self.assert_almost_equals`` with a tolerance argument of 1e-5:

        1. ``pt_model`` against ``fx_model`` as given;
        2. ``pt_model`` against a Flax model loaded from the checkpoint saved
           by ``pt_model`` (``from_pt=True``);
        3. ``fx_model`` against a PyTorch model loaded from the checkpoint
           saved by ``fx_model`` (``from_flax=True``).

        Args:
            pt_model: PyTorch model; moved to ``torch_device`` and switched
                to eval mode.
            fx_model: Flax model, called with ``inputs_dict`` as is.
            inputs_dict: mapping from keyword argument name to an array
                exposing ``tolist()``.
        """

        pt_model.to(torch_device)
        pt_model.eval()

        # prepare inputs
        flax_inputs = inputs_dict
        # The model was moved to ``torch_device``, so the inputs have to be
        # moved there as well; otherwise the forward pass fails with a device
        # mismatch whenever ``torch_device`` is not the CPU.
        pt_inputs = {
            k: torch.tensor(v.tolist()).to(torch_device)
            for k, v in flax_inputs.items()
        }

        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs).to_tuple()

        fx_outputs = fx_model(**inputs_dict).to_tuple()
        self.assertEqual(len(fx_outputs), len(pt_outputs),
                         "Output lengths differ between Flax and PyTorch")
        for fx_output, pt_output in zip(fx_outputs, pt_outputs):
            # ``.cpu()`` is needed before ``.numpy()`` for tensors that live
            # on an accelerator; on the CPU it returns the tensor unchanged.
            self.assert_almost_equals(fx_output,
                                      pt_output.cpu().numpy(), 1e-5)

        # PT -> Flax
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            fx_model_loaded = FlaxVisionEncoderDecoderModel.from_pretrained(
                tmpdirname, from_pt=True)

        fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
        self.assertEqual(len(fx_outputs_loaded), len(pt_outputs),
                         "Output lengths differ between Flax and PyTorch")
        for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
            self.assert_almost_equals(fx_output_loaded,
                                      pt_output.cpu().numpy(), 1e-5)

        # Flax -> PT
        with tempfile.TemporaryDirectory() as tmpdirname:
            fx_model.save_pretrained(tmpdirname)
            pt_model_loaded = VisionEncoderDecoderModel.from_pretrained(
                tmpdirname, from_flax=True)

        pt_model_loaded.to(torch_device)
        pt_model_loaded.eval()

        with torch.no_grad():
            pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

        self.assertEqual(len(fx_outputs), len(pt_outputs_loaded),
                         "Output lengths differ between Flax and PyTorch")
        for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
            self.assert_almost_equals(fx_output,
                                      pt_output_loaded.cpu().numpy(), 1e-5)