def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
    enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)

    pad_token_id = enc_dec_model.config.decoder.pad_token_id
    eos_token_id = enc_dec_model.config.decoder.eos_token_id
    decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id

    # Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
    if pad_token_id is None and eos_token_id is not None:
        pad_token_id = eos_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = enc_dec_model.config.decoder.bos_token_id

    # Bert does not have a bos token id, so use pad_token_id instead
    # Copied from `test_modeling_encoder_decoder.py`
    if decoder_start_token_id is None:
        decoder_start_token_id = pad_token_id

    generated_output = enc_dec_model.generate(
        input_ids,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        decoder_start_token_id=decoder_start_token_id,
    )
    generated_sequences = generated_output.sequences
    self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
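# A minimal driver sketch for the generate check above. `prepare_config_and_inputs`
# is the conventional model-tester hook assumed here; it is not defined in this
# snippet and may be named differently in your test harness.
def test_encoder_decoder_model_generate(self):
    input_ids_dict = self.prepare_config_and_inputs()
    self.check_encoder_decoder_model_generate(**input_ids_dict)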
def check_encoder_decoder_model_from_pretrained(
    self,
    config,
    input_ids,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    return_dict,
    **kwargs,
):
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
    enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
    outputs_encoder_decoder = enc_dec_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        return_dict=True,
    )

    self.assertEqual(
        outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
    )
    self.assertEqual(
        outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
    )
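# Hedged driver sketch: since `return_dict` is forwarded into
# `from_encoder_decoder_pretrained` above, a driver would typically exercise both
# settings (again assuming the `prepare_config_and_inputs` hook).
def test_encoder_decoder_model_from_pretrained(self):
    input_ids_dict = self.prepare_config_and_inputs()
    self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
    self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)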
def check_encoder_decoder_model_output_attentions(
    self,
    config,
    input_ids,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    **kwargs,
):
    # make the decoder inputs a different shape from the encoder inputs to harden the test
    decoder_input_ids = decoder_input_ids[:, :-1]
    decoder_attention_mask = decoder_attention_mask[:, :-1]
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
    enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
    outputs_encoder_decoder = enc_dec_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        output_attentions=True,
    )

    encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
    self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
    self.assertEqual(
        encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
    )

    decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
    num_decoder_layers = (
        decoder_config.num_decoder_layers
        if hasattr(decoder_config, "num_decoder_layers")
        else decoder_config.num_hidden_layers
    )
    self.assertEqual(len(decoder_attentions), num_decoder_layers)
    self.assertEqual(
        decoder_attentions[0].shape[-3:],
        (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
    )

    cross_attentions = outputs_encoder_decoder["cross_attentions"]
    self.assertEqual(len(cross_attentions), num_decoder_layers)
    cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
        1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
    )
    self.assertEqual(
        cross_attentions[0].shape[-3:],
        (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
    )
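# Worked example of the cross-attention sequence length computed above: for an
# n-gram decoder (e.g. a ProphetNet-style config with `ngram = 2`, an assumed
# illustration) and decoder_input_ids of length 7,
#     cross_attention_input_seq_len = 7 * (1 + 2) = 21,
# while ordinary decoders without an `ngram` attribute keep the factor at 1. A
# minimal driver for the check, using the same assumed `prepare_config_and_inputs`
# hook as the other sketches:
def test_encoder_decoder_model_output_attentions(self):
    input_ids_dict = self.prepare_config_and_inputs()
    self.check_encoder_decoder_model_output_attentions(**input_ids_dict)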
def check_encoder_decoder_model_from_encoder_decoder_pretrained(
    self,
    config,
    input_ids,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    **kwargs,
):
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    # assert that model attributes match those of configs
    self.assertEqual(config.use_cache, encoder_model.config.use_cache)
    self.assertEqual(decoder_config.use_cache, decoder_model.config.use_cache)

    with tempfile.TemporaryDirectory() as enc_tmpdir:
        with tempfile.TemporaryDirectory() as dec_tmpdir:
            encoder_model.save_pretrained(enc_tmpdir)
            decoder_model.save_pretrained(dec_tmpdir)
            # load a model from pretrained encoder and decoder checkpoints, setting one encoder
            # and one decoder kwarg opposite to that specified in their respective configs
            enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder_pretrained_model_name_or_path=enc_tmpdir,
                decoder_pretrained_model_name_or_path=dec_tmpdir,
                encoder_use_cache=not config.use_cache,
                decoder_use_cache=not decoder_config.use_cache,
            )

    # assert that setting encoder and decoder kwargs opposite to those in the configs has correctly been applied
    self.assertNotEqual(config.use_cache, enc_dec_model.config.encoder.use_cache)
    self.assertNotEqual(decoder_config.use_cache, enc_dec_model.config.decoder.use_cache)

    outputs_encoder_decoder = enc_dec_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        output_hidden_states=True,
        return_dict=True,
    )

    self.assertEqual(
        outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
    )
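# Driver sketch for the kwarg-override check above (same assumed
# `prepare_config_and_inputs` hook as the other sketches):
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
    input_ids_dict = self.prepare_config_and_inputs()
    self.check_encoder_decoder_model_from_encoder_decoder_pretrained(**input_ids_dict)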
def check_save_and_load(
    self,
    config,
    input_ids,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    **kwargs,
):
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
    enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)

    outputs = enc_dec_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    out_2 = np.array(outputs[0])
    # zero out NaNs so the element-wise comparison below is well-defined
    out_2[np.isnan(out_2)] = 0

    with tempfile.TemporaryDirectory() as tmpdirname:
        enc_dec_model.save_pretrained(tmpdirname)
        # reload from disk and re-run the forward pass on the restored model,
        # so the comparison actually exercises the save/load round trip
        enc_dec_model = FlaxEncoderDecoderModel.from_pretrained(tmpdirname)

        after_outputs = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        out_1 = np.array(after_outputs[0])
        out_1[np.isnan(out_1)] = 0
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)
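# Driver sketch for the save/load round-trip check above (same assumed hook):
def test_save_and_load_from_pretrained(self):
    input_ids_dict = self.prepare_config_and_inputs()
    self.check_save_and_load(**input_ids_dict)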
def get_from_encoderdecoder_pretrained_model(self):
    return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
def get_pretrained_model(self):
    return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base")
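# Hedged smoke-test sketch for the two real-checkpoint factories above: build the
# cross-architecture pair and run one forward pass on dummy inputs. The input
# shapes and the `jnp.ones` placeholders are illustrative assumptions, and the
# first call downloads the "bert-base-cased" and "facebook/bart-base" checkpoints.
def test_real_model_forward_smoke(self):
    import jax.numpy as jnp

    model = self.get_pretrained_model()
    input_ids = jnp.ones((1, 8), dtype="i4")
    decoder_input_ids = jnp.ones((1, 4), dtype="i4")
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    # logits cover each decoder position over the decoder vocabulary
    self.assertEqual(outputs.logits.shape[:2], (1, 4))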