    def check_encoder_decoder_model_from_pretrained(
            self, config, pixel_values, encoder_hidden_states, decoder_config,
            decoder_input_ids, decoder_attention_mask, return_dict, **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        kwargs = {
            "encoder_model": encoder_model,
            "decoder_model": decoder_model,
            "return_dict": return_dict
        }
        enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            **kwargs)
        outputs_encoder_decoder = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )

        self.assertEqual(outputs_encoder_decoder["logits"].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape[0],
            pixel_values.shape[0])
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1],
            config.hidden_size)

    def check_encoder_decoder_model_generate(self, pixel_values, config,
                                             decoder_config, **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        kwargs = {
            "encoder_model": encoder_model,
            "decoder_model": decoder_model
        }
        enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            **kwargs)

        pad_token_id = enc_dec_model.config.decoder.pad_token_id
        eos_token_id = enc_dec_model.config.decoder.eos_token_id
        decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id

        # Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id
        if decoder_start_token_id is None:
            decoder_start_token_id = enc_dec_model.config.decoder.bos_token_id

        # Bert does not have a bos token id, so use pad_token_id instead
        # Copied from `test_modeling_encoder_decoder.py`
        if decoder_start_token_id is None:
            decoder_start_token_id = pad_token_id

        generated_output = enc_dec_model.generate(
            pixel_values,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
        )
        generated_sequences = generated_output.sequences
        self.assertEqual(generated_sequences.shape,
                         (pixel_values.shape[0], decoder_config.max_length))

    def check_save_and_load(self, config, pixel_values, encoder_hidden_states,
                            decoder_config, decoder_input_ids,
                            decoder_attention_mask, **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        kwargs = {
            "encoder_model": encoder_model,
            "decoder_model": decoder_model
        }
        enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            **kwargs)

        outputs = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )
        out_2 = np.array(outputs[0])
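        # NaN != NaN, so zero out NaNs before comparing; otherwise any NaN
        # in the outputs would make the max-diff check below fail spuriously.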
        out_2[np.isnan(out_2)] = 0

        with tempfile.TemporaryDirectory() as tmpdirname:
            enc_dec_model.save_pretrained(tmpdirname)
            enc_dec_model = FlaxVisionEncoderDecoderModel.from_pretrained(
                tmpdirname)

            after_outputs = enc_dec_model(
                pixel_values=pixel_values,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_1 = np.array(after_outputs[0])
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

    def check_encoder_decoder_model_output_attentions(
            self, config, pixel_values, encoder_hidden_states, decoder_config,
            decoder_input_ids, decoder_attention_mask, **kwargs):
        # make the decoder inputs a different shape from the encoder inputs to harden the test
        decoder_input_ids = decoder_input_ids[:, :-1]
        decoder_attention_mask = decoder_attention_mask[:, :-1]
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        kwargs = {
            "encoder_model": encoder_model,
            "decoder_model": decoder_model
        }
        enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            **kwargs)
        outputs_encoder_decoder = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            output_attentions=True,
        )

        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

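        # Only the heads dimension is checked here: the encoder sequence
        # length depends on the image and patch sizes, which this test does
        # not recompute.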
        self.assertEqual(encoder_attentions[0].shape[-3:-2],
                         (config.num_attention_heads, ))

        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
        num_decoder_layers = (decoder_config.num_decoder_layers
                              if hasattr(decoder_config, "num_decoder_layers")
                              else decoder_config.num_hidden_layers)
        self.assertEqual(len(decoder_attentions), num_decoder_layers)

        self.assertEqual(
            decoder_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1],
             decoder_input_ids.shape[-1]),
        )

        cross_attentions = outputs_encoder_decoder["cross_attentions"]
        self.assertEqual(len(cross_attentions), num_decoder_layers)

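        # Decoders with an `ngram` attribute (e.g. ProphetNet) predict n
        # future tokens per position, which stretches the cross-attention
        # query length by a factor of (1 + ngram).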
        cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
            1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0))
        self.assertEqual(
            cross_attentions[0].shape[-3:-1],
            (decoder_config.num_attention_heads,
             cross_attention_input_seq_len),
        )

    def get_from_encoderdecoder_pretrained_model(self):
        return FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            "google/vit-base-patch16-224-in21k", "gpt2")
Example #6
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoTokenizer,
    FlaxVisionEncoderDecoderModel,
    HfArgumentParser,
)


def main():
    # `ModelArguments` is defined elsewhere in this script; a sketch of the
    # fields it needs is given after this example.
    parser = HfArgumentParser((ModelArguments, ))
    (model_args, ) = parser.parse_args_into_dataclasses()

    # Load pretrained model and tokenizer

    # Use an explicitly specified encoder config
    if model_args.encoder_config_name:
        encoder_config = AutoConfig.from_pretrained(
            model_args.encoder_config_name)
    # Otherwise, use the pretrained encoder model's config
    else:
        encoder_config = AutoConfig.from_pretrained(
            model_args.encoder_model_name_or_path)

    # Use an explicitly specified decoder config
    if model_args.decoder_config_name:
        decoder_config = AutoConfig.from_pretrained(
            model_args.decoder_config_name)
    # Otherwise, use the pretrained decoder model's config
    else:
        decoder_config = AutoConfig.from_pretrained(
            model_args.decoder_model_name_or_path)

    # necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True

    model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
        decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
        encoder_config=encoder_config,
        decoder_config=decoder_config,
    )

    # GPT2 has bos/eos tokens, but no decoder_start/pad tokens
    decoder_start_token_id = decoder_config.decoder_start_token_id
    pad_token_id = decoder_config.pad_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = decoder_config.bos_token_id
    if pad_token_id is None:
        pad_token_id = decoder_config.eos_token_id

    # This is necessary to make Flax's generate() work
    model.config.eos_token_id = decoder_config.eos_token_id
    model.config.decoder_start_token_id = decoder_start_token_id
    model.config.pad_token_id = pad_token_id

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.encoder_model_name_or_path)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.decoder_model_name_or_path)
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(
        model.config.pad_token_id)

    model.save_pretrained(model_args.output_dir)
    feature_extractor.save_pretrained(model_args.output_dir)
    tokenizer.save_pretrained(model_args.output_dir)
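
The script relies on a `ModelArguments` dataclass defined elsewhere in the
file. A minimal sketch consistent with the fields main() actually reads
(field names taken from the code above; defaults and help texts assumed):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:
    encoder_model_name_or_path: str = field(
        metadata={"help": "Pretrained encoder checkpoint, e.g. a ViT model."})
    decoder_model_name_or_path: str = field(
        metadata={"help": "Pretrained decoder checkpoint, e.g. GPT2."})
    output_dir: str = field(
        metadata={"help": "Directory to save the combined model."})
    encoder_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Optional config name overriding the encoder's."})
    decoder_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Optional config name overriding the decoder's."})

HfArgumentParser then exposes each field as a CLI flag
(--encoder_model_name_or_path, --decoder_model_name_or_path, and so on).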