def test_real_model_save_load_from_pretrained(self):
    """Check that a pretrained model survives a save/load round trip.

    The pretrained model is run on random inputs, written to a temporary
    directory, reloaded, and run again on the same inputs. The first output
    of both runs (with NaNs replaced by zero) must agree within 4e-2.
    """
    reference_model = self.get_pretrained_model()

    # Same creation order as before so any shared RNG yields identical tensors.
    model_inputs = {
        "inputs": ids_tensor([13, 5], reference_model.config.encoder.vocab_size),
        "decoder_input_ids": ids_tensor([13, 1], reference_model.config.decoder.vocab_size),
        "attention_mask": ids_tensor([13, 5], vocab_size=2),
    }

    def first_output_without_nans(model):
        # Copy the first output into a NumPy array and zero out NaN entries.
        values = np.array(model(**model_inputs)[0])
        values[np.isnan(values)] = 0
        return values

    expected = first_output_without_nans(reference_model)

    with tempfile.TemporaryDirectory() as tmp_dirname:
        reference_model.save_pretrained(tmp_dirname)
        reloaded_model = FlaxSpeechEncoderDecoderModel.from_pretrained(tmp_dirname)

        actual = first_output_without_nans(reloaded_model)
        largest_gap = np.amax(np.abs(actual - expected))
        self.assertLessEqual(largest_gap, 4e-2)
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
    """Assert that a PyTorch model and a Flax model produce matching logits.

    Three comparisons are made, each with a tolerance of 4e-2 and a check
    that both frameworks return the same number of outputs:

    1. ``fx_model`` against ``pt_model`` directly.
    2. PT -> Flax: ``pt_model`` is saved and reloaded as a Flax model, whose
       logits are compared against the original PyTorch logits.
    3. Flax -> PT: ``fx_model`` is saved and reloaded as a PyTorch model,
       whose logits are compared against the original Flax logits.

    Args:
        pt_model: PyTorch model; moved to ``torch_device`` and put in eval mode.
        fx_model: Flax model to compare against.
        inputs_dict: keyword inputs for the Flax model. Every value must
            support ``.tolist()`` so it can be converted to a torch tensor.
    """
    pt_model.to(torch_device)
    pt_model.eval()

    # Create the PyTorch inputs on the model's device; CPU inputs fed to a
    # model living on an accelerator raise a device-mismatch error.
    pt_inputs = {
        name: torch.tensor(value.tolist(), device=torch_device)
        for name, value in inputs_dict.items()
    }

    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs)
    # `.cpu()` is required before `.numpy()` when the logits are on an accelerator.
    pt_logits = pt_outputs.logits.cpu().numpy()
    pt_outputs = pt_outputs.to_tuple()

    fx_outputs = fx_model(**inputs_dict)
    fx_logits = fx_outputs.logits
    fx_outputs = fx_outputs.to_tuple()

    self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
    self.assert_almost_equals(fx_logits, pt_logits, 4e-2)

    # PT -> Flax
    with tempfile.TemporaryDirectory() as tmpdirname:
        pt_model.save_pretrained(tmpdirname)
        fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

    fx_outputs_loaded = fx_model_loaded(**inputs_dict)
    fx_logits_loaded = fx_outputs_loaded.logits
    fx_outputs_loaded = fx_outputs_loaded.to_tuple()

    self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
    self.assert_almost_equals(fx_logits_loaded, pt_logits, 4e-2)

    # Flax -> PT
    with tempfile.TemporaryDirectory() as tmpdirname:
        fx_model.save_pretrained(tmpdirname)
        pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)

    pt_model_loaded.to(torch_device)
    pt_model_loaded.eval()

    with torch.no_grad():
        pt_outputs_loaded = pt_model_loaded(**pt_inputs)
    pt_logits_loaded = pt_outputs_loaded.logits.cpu().numpy()
    pt_outputs_loaded = pt_outputs_loaded.to_tuple()

    self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
    self.assert_almost_equals(fx_logits, pt_logits_loaded, 4e-2)
def check_save_and_load(
    self,
    config,
    inputs,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    **kwargs
):
    """Check that an encoder-decoder model gives the same output after a save/load round trip.

    An encoder-decoder model is assembled from the encoder and decoder
    returned by ``self.get_encoder_decoder_model``, run once, saved to a
    temporary directory, reloaded, and run again on the same inputs. The
    first output of both runs (with NaNs replaced by zero) must agree
    within 4e-2.

    Args:
        config: encoder configuration passed to ``get_encoder_decoder_model``.
        inputs: encoder inputs passed to the model as ``inputs``.
        attention_mask: encoder attention mask.
        encoder_hidden_states: not used by this check.
        decoder_config: decoder configuration passed to ``get_encoder_decoder_model``.
        decoder_input_ids: decoder input ids.
        decoder_attention_mask: decoder attention mask.
        **kwargs: extra entries supplied by the caller; not used by this check.
    """
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    # Use a dedicated name rather than overwriting the `**kwargs` parameter.
    model_kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
    enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**model_kwargs)

    outputs = enc_dec_model(
        inputs=inputs,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    )
    out_2 = np.array(outputs[0])
    out_2[np.isnan(out_2)] = 0

    with tempfile.TemporaryDirectory() as tmpdirname:
        enc_dec_model.save_pretrained(tmpdirname)
        # Keep the reloaded model: running the original model a second time
        # would compare it with itself and never exercise the round trip.
        loaded_model = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname)

        after_outputs = loaded_model(
            inputs=inputs,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )
        out_1 = np.array(after_outputs[0])
        out_1[np.isnan(out_1)] = 0
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 4e-2)
def test_flaxwav2vec2bart_pt_flax_equivalence(self):
    """Check PyTorch/Flax equivalence for the wav2vec2-2-bart-large checkpoint.

    Loads "patrickvonplaten/wav2vec2-2-bart-large" as a PyTorch model and as
    a Flax model (converted from the PyTorch weights), builds a random batch
    of 13 examples, and hands both models to ``check_pt_flax_equivalence``.
    That helper does the direct comparison and both cross-framework
    save/load comparisons.
    """
    pt_model = SpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
    fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
        "patrickvonplaten/wav2vec2-2-bart-large", from_pt=True
    )

    # prepare inputs
    batch_size = 13
    # NOTE(review): the encoder vocab size is passed as the second positional
    # argument of `floats_tensor`; confirm this is the intended value.
    input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
    attention_mask = random_attention_mask([batch_size, 512])
    decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
    decoder_attention_mask = random_attention_mask([batch_size, 4])
    inputs_dict = {
        "inputs": input_values,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
    }

    # The helper moves `pt_model` to the test device, switches it to eval
    # mode and runs the comparisons that were previously duplicated inline here.
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)