    # Excerpted methods of a FlaxEncoderDecoderModel test mixin; the enclosing
    # module is assumed to provide:
    #     import tempfile
    #     import numpy as np
    #     from transformers import FlaxEncoderDecoderModel
    def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
        enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)

        pad_token_id = enc_dec_model.config.decoder.pad_token_id
        eos_token_id = enc_dec_model.config.decoder.eos_token_id
        decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id

        # Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id
        if decoder_start_token_id is None:
            decoder_start_token_id = enc_dec_model.config.decoder.bos_token_id

        # Bert does not have a bos token id, so use pad_token_id instead
        # Copied from `test_modeling_encoder_decoder.py`
        if decoder_start_token_id is None:
            decoder_start_token_id = pad_token_id

        generated_output = enc_dec_model.generate(
            input_ids,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
        )
        generated_sequences = generated_output.sequences
        self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
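
    # A minimal standalone sketch (hypothetical helper, not part of the original
    # mixin) of the special-token fallback chain used above: pad falls back to
    # eos, and the decoder start token falls back to bos, then to pad.
    @staticmethod
    def _resolve_special_token_ids(decoder_config):
        pad_token_id = decoder_config.pad_token_id
        eos_token_id = decoder_config.eos_token_id
        decoder_start_token_id = decoder_config.decoder_start_token_id
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id  # e.g. GPT2 defines no pad token
        if decoder_start_token_id is None:
            decoder_start_token_id = decoder_config.bos_token_id
        if decoder_start_token_id is None:
            decoder_start_token_id = pad_token_id  # e.g. BERT defines no bos token
        return pad_token_id, eos_token_id, decoder_start_token_id
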
    def check_encoder_decoder_model_from_pretrained(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        return_dict,
        **kwargs
    ):
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
        enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )

        self.assertEqual(
            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
        )
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
        )
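
        # Shape sketch with illustrative numbers (not taken from the test): for a
        # batch of 13 source sequences of length 7 and target sequences of length
        # 5, the assertions above expect
        #     logits:                    (13, 5, decoder_config.vocab_size)
        #     encoder_last_hidden_state: (13, 7, config.hidden_size)
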
    def check_encoder_decoder_model_output_attentions(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        # make the decoder inputs a different shape from the encoder inputs to harden the test
        decoder_input_ids = decoder_input_ids[:, :-1]
        decoder_attention_mask = decoder_attention_mask[:, :-1]
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
        enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            output_attentions=True,
        )

        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

        self.assertEqual(
            encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
        )

        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
        num_decoder_layers = (
            decoder_config.num_decoder_layers
            if hasattr(decoder_config, "num_decoder_layers")
            else decoder_config.num_hidden_layers
        )
        self.assertEqual(len(decoder_attentions), num_decoder_layers)

        self.assertEqual(
            decoder_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
        )

        cross_attentions = outputs_encoder_decoder["cross_attentions"]
        self.assertEqual(len(cross_attentions), num_decoder_layers)

        cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
            1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
        )
        self.assertEqual(
            cross_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
        )
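
    # A minimal sketch (hypothetical helper, not part of the original mixin) of
    # the cross-attention query-length computation above: n-gram decoders such as
    # ProphetNet stack one extra decoder stream per predicted n-gram, so the
    # query axis of the cross-attention weights grows accordingly.
    @staticmethod
    def _cross_attention_query_len(decoder_config, decoder_seq_len):
        ngram = getattr(decoder_config, "ngram", 0)
        return decoder_seq_len * (1 + ngram)
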
    def check_encoder_decoder_model_from_encoder_decoder_pretrained(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        # assert that model attributes match those of configs
        self.assertEqual(config.use_cache, encoder_model.config.use_cache)
        self.assertEqual(decoder_config.use_cache, decoder_model.config.use_cache)

        with tempfile.TemporaryDirectory() as enc_tmpdir:
            with tempfile.TemporaryDirectory() as dec_tmpdir:
                encoder_model.save_pretrained(enc_tmpdir)
                decoder_model.save_pretrained(dec_tmpdir)
                # load a model from pretrained encoder and decoder checkpoints, setting one encoder
                # and one decoder kwarg opposite to that specified in their respective configs
                enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
                    encoder_pretrained_model_name_or_path=enc_tmpdir,
                    decoder_pretrained_model_name_or_path=dec_tmpdir,
                    encoder_use_cache=not config.use_cache,
                    decoder_use_cache=not decoder_config.use_cache,
                )

        # assert that setting encoder and decoder kwargs opposite to those in the configs has correctly been applied
        self.assertNotEqual(config.use_cache, enc_dec_model.config.encoder.use_cache)
        self.assertNotEqual(decoder_config.use_cache, enc_dec_model.config.decoder.use_cache)

        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        self.assertEqual(
            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
        )
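
    # from_encoder_decoder_pretrained routes keyword arguments by prefix:
    # "encoder_use_cache" reaches the encoder config as use_cache, and
    # "decoder_use_cache" the decoder config. A minimal sketch of that prefix
    # split (hypothetical helper, not the library's actual implementation):
    @staticmethod
    def _split_prefixed_kwargs(kwargs):
        encoder_kwargs = {k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")}
        decoder_kwargs = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
        return encoder_kwargs, decoder_kwargs
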
    def check_save_and_load(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
        enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)

        outputs = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        out_2 = np.array(outputs[0])
        out_2[np.isnan(out_2)] = 0

        with tempfile.TemporaryDirectory() as tmpdirname:
            enc_dec_model.save_pretrained(tmpdirname)
            # reload the saved model; without this assignment the check below
            # would compare the original model against itself
            enc_dec_model = FlaxEncoderDecoderModel.from_pretrained(tmpdirname)

            after_outputs = enc_dec_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_1 = np.array(after_outputs[0])
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)
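
    # The save/load check above zeroes NaNs before taking the max absolute
    # difference, so positions that are NaN in both outputs cannot fail the
    # comparison. A near-equivalent one-liner (note that np.nan_to_num also
    # clamps infinities, which the masking above does not):
    #     max_diff = np.amax(np.abs(np.nan_to_num(out_1) - np.nan_to_num(out_2)))
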
    def get_from_encoderdecoder_pretrained_model(self):
        return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

    def get_pretrained_model(self):
        return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base")