Example no. 1
    def test_greedy_generate_pt_fx(self):
        config, input_ids, _, max_length = self._get_input_ids_and_config()
        config.do_sample = False
        config.max_length = max_length
        config.decoder_start_token_id = 0

        for model_class in self.all_generative_model_classes:
            flax_model = model_class(config)

            pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
            pt_model_class = getattr(transformers, pt_model_class_name)
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)

            flax_generation_outputs = flax_model.generate(input_ids).sequences
            pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))

            # align lengths: the Flax output may be longer than the PyTorch one
            if flax_generation_outputs.shape[-1] > pt_generation_outputs.shape[-1]:
                flax_generation_outputs = flax_generation_outputs[:, : pt_generation_outputs.shape[-1]]

            # greedy decoding is deterministic, so the token ids must match exactly
            self.assertListEqual(pt_generation_outputs.numpy().tolist(), flax_generation_outputs.tolist())
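For readers who want to try the same greedy-equivalence check outside the test harness, the sketch below does it with a public checkpoint. It is illustrative only: the `gpt2` checkpoint, the tokenizer call, and the exact-match assertion are assumptions of this sketch, not part of the original test; fp32 greedy decoding usually, but not always, matches token for token.

# Illustrative only (not part of the test suite): same greedy check with the
# public "gpt2" checkpoint. fp32 greedy decoding usually matches exactly,
# but tiny numerical differences are possible.
import numpy as np
import torch
from transformers import FlaxGPT2LMHeadModel, GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
flax_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
pt_model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tokenizer("Flax and PyTorch should agree", return_tensors="np").input_ids
flax_seq = flax_model.generate(input_ids, max_length=20, do_sample=False).sequences
pt_seq = pt_model.generate(torch.tensor(np.asarray(input_ids)), max_length=20, do_sample=False)

assert flax_seq.tolist() == pt_seq.numpy().tolist()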
Example no. 2
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
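A note on the `flatten_dict`/`unfreeze` pair used above: `flax.traverse_util.flatten_dict` turns the nested parameter tree into a flat dict keyed by path tuples, which is what makes the per-weight comparison loop possible. A minimal illustration:

# Minimal illustration of flatten_dict on a nested parameter tree.
from flax.traverse_util import flatten_dict

params = {"encoder": {"layer_0": {"kernel": [[1.0]]}}}
print(flatten_dict(params))
# {('encoder', 'layer_0', 'kernel'): [[1.0]]}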
Example no. 3
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()
                # the PyTorch CLIPModel returns the loss as its first output; we drop it here because the JAX/Flax models do not return a loss
                pt_outputs = pt_outputs[1:]

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
                pt_outputs_loaded = pt_outputs_loaded[1:]

                self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
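`assert_almost_equals` is a helper defined elsewhere in the test mixin; a plausible shape for it (an assumption of this note, not the exact upstream code) compares the maximum absolute difference against the tolerance:

# Assumed shape of the helper: max absolute difference vs. tolerance.
import numpy as np

def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float):
    diff = np.abs(a - b).max()
    assert diff <= tol, f"Difference between Flax and PyTorch is {diff} (>= {tol})."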
Example no. 4
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**prepared_inputs_dict)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                # send pytorch model to the correct device
                pt_model_loaded.to(torch_device)
                pt_model_loaded.eval()

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
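`check_outputs` is another mixin helper that is not reproduced here. The hedged sketch below captures the idea such a helper has to implement, namely walking possibly nested output tuples and comparing the leaf arrays; the function name and tolerance are assumptions, not the upstream code.

# Hypothetical recursive comparator in the spirit of `check_outputs`.
import numpy as np

def compare_nested_outputs(fx_output, pt_output, tol=4e-2, name="output"):
    if isinstance(fx_output, (tuple, list)):
        assert len(fx_output) == len(pt_output), f"{name}: length mismatch"
        for i, (f, p) in enumerate(zip(fx_output, pt_output)):
            compare_nested_outputs(f, p, tol, name=f"{name}[{i}]")
    else:
        pt_array = pt_output.detach().cpu().numpy()
        diff = np.abs(np.asarray(fx_output) - pt_array).max()
        assert diff <= tol, f"{name}: max diff {diff} exceeds {tol}"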
Example no. 5
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # load corresponding PyTorch class
                pt_model = model_class(config).eval()

                # Flax models don't use the `use_cache` option and the cache is not returned by default,
                # so we disable `use_cache` here for the PyTorch model.
                pt_model.config.use_cache = False

                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    # no flax model exists for this class
                    return

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load Flax class
                fx_model = fx_model_class(config, dtype=jnp.float32)
                # make sure only Flax inputs that actually exist in the function args are forwarded
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                fx_outputs = fx_model(**fx_inputs).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

                for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
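The `inspect.signature` trick in this example is worth isolating: it stops the PyTorch-prepared inputs from carrying kwargs that the Flax `__call__` does not accept. A standalone helper (the name below is made up for illustration) would look like this:

# Keep only the kwargs a callable actually accepts (helper name is hypothetical).
import inspect

def filter_kwargs_for(fn, kwargs):
    accepted = inspect.signature(fn).parameters.keys()
    return {k: v for k, v in kwargs.items() if k in accepted}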
Example no. 6
    def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):

        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

        pt_model = VisionTextDualEncoderModel(config)
        fx_model = FlaxVisionTextDualEncoderModel(config)

        pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

        self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
Example no. 7
    def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):

        encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

        pt_model = EncoderDecoderModel(encoder_decoder_config)
        fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)

        pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

        self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
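Examples no. 6 and no. 7 both delegate to `check_pt_flax_equivalence`, which is not shown on this page. A hedged sketch of what such a helper typically does (the exact name, tolerance, and handling of non-array outputs below are assumptions): run both models on the same inputs and compare the top-level output arrays within a small tolerance.

# Assumed shape of a pt/flax comparison helper, not the upstream code.
import numpy as np
import torch

def check_pt_flax_equivalence(pt_model, fx_model, inputs_dict, tol=4e-2):
    pt_inputs = {k: torch.tensor(np.asarray(v)) for k, v in inputs_dict.items()}
    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs).to_tuple()
    fx_outputs = fx_model(**inputs_dict).to_tuple()
    assert len(fx_outputs) == len(pt_outputs), "Output lengths differ"
    # compares array-valued outputs; nested tuples would need a recursive walk
    for fx_output, pt_output in zip(fx_outputs, pt_outputs):
        diff = np.abs(np.asarray(fx_output) - pt_output.numpy()).max()
        assert diff <= tol, f"max diff {diff} exceeds {tol}"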
Example no. 8
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
                # randomize the attention masks so every sample gets a different run of masked leading tokens
                batch_size, seq_length = pt_inputs["input_ids"].shape
                rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
                for batch_idx, start_index in enumerate(rnd_start_indices):
                    pt_inputs["attention_mask"][batch_idx, :start_index] = 0
                    pt_inputs["attention_mask"][batch_idx, start_index:] = 1
                    prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
                    prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
                    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
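The random `attention_mask` loop in this last example builds what are effectively left-padding masks: zeros up to a random start index, ones afterwards. Only the last position (`[:, -1]`) is compared, presumably because earlier positions fall on masked tokens whose values are not meaningful to compare. The same masks, vectorized with NumPy for illustration (the sizes below are arbitrary):

# Vectorized equivalent of the mask-randomization loop above (illustrative).
import numpy as np

batch_size, seq_length = 4, 16
starts = np.random.randint(0, seq_length - 1, size=(batch_size, 1))
attention_mask = (np.arange(seq_length)[None, :] >= starts).astype(np.int64)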