    def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):

        encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

        # Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
        # the encoder/decoder models.
        # There was a (very) ugly potential fix, which wasn't integrated into `transformers`: see
        #   https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
        #   (the change in `src/transformers/modeling_tf_utils.py`)
        _tf_model = TFVisionEncoderDecoderModel(encoder_decoder_config)
        # Make sure model is built
        _tf_model(**inputs_dict)

        # Using `tf_model` (whose encoder/decoder are built as standalone models first) makes the test pass.
        encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
        decoder = _tf_model.decoder.__class__(encoder_decoder_config.decoder)
        # Make sure models are built
        encoder(encoder.dummy_inputs)
        decoder(decoder.dummy_inputs)
        tf_model = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

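        # Round-trip: save the TF encoder and decoder separately, then reload them together as a PyTorch model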
        with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:

            tf_model.encoder.save_pretrained(encoder_tmp_dirname)
            tf_model.decoder.save_pretrained(decoder_tmp_dirname)
            pt_model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_tf=True, decoder_from_tf=True
            )
            # This is only for copying some specific attributes of this particular model.
            pt_model.config = tf_model.config

        self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)

    def check_encoder_decoder_model_from_pretrained_configs(
        self,
        config,
        pixel_values,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
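        # `from_encoder_decoder_configs` sets `is_decoder=True` and `add_cross_attention=True` on the decoder config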
        self.assertTrue(encoder_decoder_config.decoder.is_decoder)

        enc_dec_model = TFVisionEncoderDecoderModel(encoder_decoder_config)

        self.assertTrue(enc_dec_model.config.is_encoder_decoder)

        outputs_encoder_decoder = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )

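        # Logits cover every decoder position over the decoder vocabulary: (batch_size, target_length, vocab_size)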
        self.assertEqual(
            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
        )
        self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
        self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
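
    # PyTorch counterpart of the TF check above (the TF and PT variants normally live in separate test classes)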
    def check_encoder_decoder_model_from_pretrained_configs(
        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
    ):
        encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
        self.assertTrue(encoder_decoder_config.decoder.is_decoder)

        enc_dec_model = VisionEncoderDecoderModel(encoder_decoder_config)
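        # Move to the test device and switch to eval mode, so dropout is disabled and the forward pass is deterministic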
        enc_dec_model.to(torch_device)
        enc_dec_model.eval()

        self.assertTrue(enc_dec_model.config.is_encoder_decoder)

        outputs_encoder_decoder = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(
            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
        )

    def test_pt_tf_equivalence(self):

        config_inputs_dict = self.prepare_config_and_inputs()
        # Keep only common arguments
        arg_names = [
            "config",
            "pixel_values",
            "decoder_config",
            "decoder_input_ids",
            "decoder_attention_mask",
            "encoder_hidden_states",
        ]
        config_inputs_dict = {k: v for k, v in config_inputs_dict.items() if k in arg_names}

        config = config_inputs_dict.pop("config")
        decoder_config = config_inputs_dict.pop("decoder_config")

        inputs_dict = config_inputs_dict
        # `encoder_hidden_states` is not used in model call/forward
        del inputs_dict["encoder_hidden_states"]

        # Avoid the case where a sequence has no place to attend (after it is combined with the causal attention mask)
        batch_size = inputs_dict["decoder_attention_mask"].shape[0]
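        # e.g. a mask row [0, 1, 1, 0] becomes [1, 1, 1, 0], so the first position can always attend to itself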
        inputs_dict["decoder_attention_mask"] = tf.constant(
            np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
        )

        # TF models don't use the `use_cache` option, and the cache is not returned by default.
        # So we disable `use_cache` here for the PyTorch model.
        decoder_config.use_cache = False

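        # These checks assume `cross_attention_hidden_size` is unset, so any hidden size mismatch is handled by `enc_to_dec_proj`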
        self.assertTrue(decoder_config.cross_attention_hidden_size is None)

        # check without `enc_to_dec_proj` projection
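        # (equal encoder/decoder hidden sizes mean no `enc_to_dec_proj` layer is created, so PT and TF weights map one-to-one)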
        self.assertTrue(config.hidden_size == decoder_config.hidden_size)
        self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
        self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)

        # This does not work, because the PT/TF equivalence test for encoder-decoder models uses
        # `from_encoder_decoder_pretrained`, which randomly initializes `enc_to_dec_proj`.
        # # check `enc_to_dec_proj` work as expected
        # decoder_config.hidden_size = decoder_config.hidden_size * 2
        # self.assertTrue(config.hidden_size != decoder_config.hidden_size)
        # self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
        # self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)

        # Let's just check `enc_to_dec_proj` can run for now
        decoder_config.hidden_size = decoder_config.hidden_size * 2
        self.assertTrue(config.hidden_size != decoder_config.hidden_size)
        encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
        model = TFVisionEncoderDecoderModel(encoder_decoder_config)
        model(**inputs_dict)

    def get_encoder_decoder_config(self):
        encoder_config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k")
        decoder_config = AutoConfig.from_pretrained("gpt2", is_decoder=True, add_cross_attention=True)
        return VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

    def get_encoder_decoder_config_small(self):
        encoder_config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-vit")
        decoder_config = AutoConfig.from_pretrained(
            "hf-internal-testing/tiny-random-gpt2", is_decoder=True, add_cross_attention=True
        )
        return VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

    def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):

        encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

        pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
        fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)

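        # `load_flax_weights_in_pytorch_model` copies the Flax parameters into the PyTorch model, so both share identical weights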
        pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

        self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)

    def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):

        encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

        pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
        fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)

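        # Convert the PyTorch state dict to Flax format and load it, so both models share identical weights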
        fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
        fx_model.params = fx_state

        self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)

    def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):

        encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

        pt_model = VisionEncoderDecoderModel(encoder_decoder_config)

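        # Round-trip: save the PT encoder and decoder separately, then reload them together as a TF model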
        with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:

            pt_model.encoder.save_pretrained(encoder_tmp_dirname)
            pt_model.decoder.save_pretrained(decoder_tmp_dirname)
            tf_model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
            )
            # This is only for copying some specific attributes of this particular model.
            tf_model.config = pt_model.config

        self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)