def test_equivalence_flax_pytorch(self):
    """Check that every Flax model reproduces the outputs of its PyTorch twin.

    For each class in ``self.all_model_classes`` the PyTorch weights are ported
    to Flax in two ways and the outputs are compared both times:

    1. in memory, via ``convert_pytorch_state_dict_to_flax`` (tolerance 2e-3);
    2. through a ``save_pretrained`` / ``from_pretrained(from_pt=True)`` round
       trip in a temporary directory (tolerance 5e-3).
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    length_mismatch_msg = "Output lengths differ between Flax and PyTorch"

    for flax_cls in self.all_model_classes:
        with self.subTest(flax_cls.__name__):
            flax_inputs = self._prepare_for_class(inputs_dict, flax_cls)

            # The PyTorch class is looked up by dropping the leading "Flax" from the name.
            torch_cls = getattr(transformers, flax_cls.__name__[4:])
            torch_model = torch_cls(config).eval()

            # Give the Flax model the converted PyTorch weights.
            flax_model = flax_cls(config, dtype=jnp.float32)
            flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

            torch_inputs = {name: torch.tensor(array.tolist()) for name, array in flax_inputs.items()}
            with torch.no_grad():
                torch_outputs = torch_model(**torch_inputs).to_tuple()

            # First comparison: weights converted in memory.
            flax_outputs = flax_model(**flax_inputs)
            self.assertEqual(len(flax_outputs), len(torch_outputs), length_mismatch_msg)
            for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                self.assert_almost_equals(flax_out, torch_out.numpy(), 2e-3)

            # Second comparison: weights loaded from a saved PyTorch checkpoint.
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                torch_model.save_pretrained(checkpoint_dir)
                reloaded_flax_model = flax_cls.from_pretrained(checkpoint_dir, from_pt=True)

                reloaded_outputs = reloaded_flax_model(**flax_inputs)
                self.assertEqual(len(reloaded_outputs), len(torch_outputs), length_mismatch_msg)
                for reloaded_out, torch_out in zip(reloaded_outputs, torch_outputs):
                    self.assert_almost_equals(reloaded_out, torch_out.numpy(), 5e-3)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
    """Build paired PyTorch/Flax vision-text dual encoders and compare them.

    One combined config is assembled from ``vision_config`` and ``text_config``.
    The PyTorch model's state dict is converted and installed as the Flax
    model's parameters, then both models are handed to
    ``self.check_pt_flax_equivalence`` together with ``inputs_dict``.
    """
    dual_encoder_config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

    torch_model = VisionTextDualEncoderModel(dual_encoder_config)
    flax_model = FlaxVisionTextDualEncoderModel(dual_encoder_config)

    # Overwrite the Flax parameters with the converted PyTorch weights.
    flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

    self.check_pt_flax_equivalence(torch_model, flax_model, inputs_dict)
def test_equivalence_pt_to_flax(self):
    """Check that PyTorch weights ported to Flax give matching outputs.

    Each class in ``self.all_model_classes`` is compared against the PyTorch
    class of the same name (without the "Flax" prefix), first with weights
    converted in memory and then with weights loaded from a saved PyTorch
    checkpoint. Both comparisons use a tolerance of 4e-2.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    length_mismatch_msg = "Output lengths differ between Flax and PyTorch"

    for flax_cls in self.all_model_classes:
        with self.subTest(flax_cls.__name__):
            # Same inputs for both frameworks: Flax takes them as prepared,
            # PyTorch gets tensor copies.
            flax_inputs = self._prepare_for_class(inputs_dict, flax_cls)
            torch_inputs = {name: torch.tensor(array.tolist()) for name, array in flax_inputs.items()}

            # Strip the leading "Flax" to find the PyTorch class.
            torch_cls = getattr(transformers, flax_cls.__name__[4:])
            torch_model = torch_cls(config).eval()

            # Flax models have no `use_cache` option and do not return a cache by
            # default, so the PyTorch model must not return one either.
            torch_model.config.use_cache = False

            flax_model = flax_cls(config, dtype=jnp.float32)
            flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

            with torch.no_grad():
                torch_outputs = torch_model(**torch_inputs).to_tuple()

            # In-memory conversion.
            flax_outputs = flax_model(**flax_inputs).to_tuple()
            self.assertEqual(len(flax_outputs), len(torch_outputs), length_mismatch_msg)
            for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                self.assert_almost_equals(flax_out, torch_out.numpy(), 4e-2)

            # Round trip through a PyTorch checkpoint on disk.
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                torch_model.save_pretrained(checkpoint_dir)
                reloaded_flax_model = flax_cls.from_pretrained(checkpoint_dir, from_pt=True)

                reloaded_outputs = reloaded_flax_model(**flax_inputs).to_tuple()
                self.assertEqual(len(reloaded_outputs), len(torch_outputs), length_mismatch_msg)
                for reloaded_out, torch_out in zip(reloaded_outputs, torch_outputs):
                    self.assert_almost_equals(reloaded_out, torch_out.numpy(), 4e-2)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
    """Build paired PyTorch/Flax encoder-decoder models and compare them.

    ``config`` (encoder) and ``decoder_config`` are merged into a single
    ``EncoderDecoderConfig``. The PyTorch model's state dict is converted and
    installed as the Flax model's parameters, then both models are handed to
    ``self.check_pt_flax_equivalence`` together with ``inputs_dict``.
    """
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

    torch_model = EncoderDecoderModel(combined_config)
    flax_model = FlaxEncoderDecoderModel(combined_config)

    # Overwrite the Flax parameters with the converted PyTorch weights.
    flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

    self.check_pt_flax_equivalence(torch_model, flax_model, inputs_dict)
def test_equivalence_pt_to_flax(self):
    """Check PyTorch -> Flax equivalence on all returned outputs.

    Hidden states (and attentions, when ``self.has_attentions`` is set) are
    requested so they are part of the comparison. For each Flax class the
    PyTorch weights are ported in memory and via a saved checkpoint; each time
    the non-``None`` output keys must match and ``self.check_outputs`` compares
    the values.
    """

    def present_keys(outputs):
        # Names of the outputs that were actually returned (i.e. are not None).
        return tuple(key for key, value in outputs.items() if value is not None)

    # The config is shared across iterations and modified inside the loop; the
    # two output flags get the same values every time, so this is harmless.
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for flax_cls in self.all_model_classes:
        with self.subTest(flax_cls.__name__):
            # Request every optional output for a more aggressive comparison.
            config.output_hidden_states = True
            config.output_attentions = self.has_attentions

            flax_inputs = self._prepare_for_class(inputs_dict, flax_cls)
            torch_inputs = {
                name: torch.tensor(array.tolist(), device=torch_device) for name, array in flax_inputs.items()
            }

            # Strip the leading "Flax" to find the PyTorch class.
            torch_cls = getattr(transformers, flax_cls.__name__[4:])
            torch_model = torch_cls(config).eval()

            # Flax models have no `use_cache` option and do not return a cache by
            # default, so the PyTorch model must not return one either.
            torch_model.config.use_cache = False

            flax_model = flax_cls(config, dtype=jnp.float32)
            flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

            # Move the PyTorch model to the test device only after the weights were converted.
            torch_model.to(torch_device)

            with torch.no_grad():
                torch_outputs = torch_model(**torch_inputs)
            flax_outputs = flax_model(**flax_inputs)

            flax_keys = present_keys(flax_outputs)
            torch_keys = present_keys(torch_outputs)
            self.assertEqual(flax_keys, torch_keys)
            self.check_outputs(flax_outputs.to_tuple(), torch_outputs.to_tuple(), flax_cls, names=flax_keys)

            # Same comparison with the Flax model loaded from a PyTorch checkpoint.
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                torch_model.save_pretrained(checkpoint_dir)
                reloaded_flax_model = flax_cls.from_pretrained(checkpoint_dir, from_pt=True)

            reloaded_outputs = reloaded_flax_model(**flax_inputs)

            flax_keys = present_keys(reloaded_outputs)
            torch_keys = present_keys(torch_outputs)
            self.assertEqual(flax_keys, torch_keys)
            self.check_outputs(reloaded_outputs.to_tuple(), torch_outputs.to_tuple(), flax_cls, names=flax_keys)
def test_equivalence_pt_to_flax(self):
    """Check that each PyTorch model matches its Flax counterpart, if one exists.

    For every PyTorch class in ``self.all_model_classes`` the Flax class named
    ``"Flax" + <class name>`` is looked up on ``transformers``. Classes without
    such a counterpart are skipped individually. For the others the PyTorch
    weights are ported to Flax in memory and through a saved checkpoint, and
    the first four outputs are compared with a tolerance of 4e-2.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    length_mismatch_msg = "Output lengths differ between Flax and PyTorch"

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            fx_model_class_name = "Flax" + model_class.__name__

            if not hasattr(transformers, fx_model_class_name):
                # No Flax counterpart for this class. Skip only this class: a
                # `return` here would end the whole test and silently drop every
                # remaining class, including ones that do have a Flax version.
                continue

            fx_model_class = getattr(transformers, fx_model_class_name)

            # load PyTorch class
            pt_model = model_class(config).eval()
            # Flax models don't use the `use_cache` option and cache is not returned as a default.
            # So we disable `use_cache` here for PyTorch model.
            pt_model.config.use_cache = False

            # load Flax class
            fx_model = fx_model_class(config, dtype=jnp.float32)

            # make sure only inputs that actually exist in the Flax call signature are forwarded
            fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

            # prepare inputs and drop the arguments Flax does not accept
            pt_inputs = self._prepare_for_class(inputs_dict, model_class)
            pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

            fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

            with torch.no_grad():
                pt_outputs = pt_model(**pt_inputs).to_tuple()

            # convert inputs to Flax; non-tensor entries are not forwarded
            fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}

            fx_outputs = fx_model(**fx_inputs).to_tuple()
            self.assertEqual(len(fx_outputs), len(pt_outputs), length_mismatch_msg)
            # Only the first four outputs are compared element-wise.
            for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
                self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

            # Repeat the comparison with the Flax model loaded from a PyTorch checkpoint.
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)

            fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
            self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), length_mismatch_msg)
            for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
                self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
def test_equivalence_pt_to_flax(self):
    """Check PyTorch -> Flax equivalence under a random left-masked attention mask.

    For every sample a random start index is drawn; attention-mask positions
    before it are set to 0 and the rest to 1, identically for the PyTorch and
    the Flax inputs. Outputs are compared only at the last sequence position
    (tolerance 4e-2), once with weights converted in memory and once with
    weights loaded from a saved PyTorch checkpoint.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    length_mismatch_msg = "Output lengths differ between Flax and PyTorch"

    for flax_cls in self.all_model_classes:
        with self.subTest(flax_cls.__name__):
            flax_inputs = self._prepare_for_class(inputs_dict, flax_cls)
            torch_inputs = {name: torch.tensor(array.tolist()) for name, array in flax_inputs.items()}

            # Strip the leading "Flax" to find the PyTorch class.
            torch_cls = getattr(transformers, flax_cls.__name__[4:])

            # Mask out a random-length prefix of every sequence in both input sets.
            batch_size, seq_length = torch_inputs["input_ids"].shape
            start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
            for row, start in enumerate(start_indices):
                for mask in (torch_inputs["attention_mask"], flax_inputs["attention_mask"]):
                    mask[row, :start] = 0
                    mask[row, start:] = 1

            torch_model = torch_cls(config).eval()
            flax_model = flax_cls(config, dtype=jnp.float32)
            flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

            with torch.no_grad():
                torch_outputs = torch_model(**torch_inputs).to_tuple()

            # NOTE(review): only the last position is compared, presumably because
            # masked positions are not expected to agree between frameworks - confirm.
            flax_outputs = flax_model(**flax_inputs).to_tuple()
            self.assertEqual(len(flax_outputs), len(torch_outputs), length_mismatch_msg)
            for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                self.assert_almost_equals(flax_out[:, -1], torch_out[:, -1].numpy(), 4e-2)

            # Same comparison with the Flax model loaded from a PyTorch checkpoint.
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                torch_model.save_pretrained(checkpoint_dir)
                reloaded_flax_model = flax_cls.from_pretrained(checkpoint_dir, from_pt=True)

            reloaded_outputs = reloaded_flax_model(**flax_inputs).to_tuple()
            self.assertEqual(len(reloaded_outputs), len(torch_outputs), length_mismatch_msg)
            for reloaded_out, torch_out in zip(reloaded_outputs, torch_outputs):
                self.assert_almost_equals(reloaded_out[:, -1], torch_out[:, -1].numpy(), 4e-2)