Example #1
        def test_pt_tf_model_equivalence(self):
            if not is_torch_available():
                return

            import torch
            import transformers

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
            )

            for model_class in self.all_model_classes:
                pt_model_class_name = model_class.__name__[
                    2:]  # Skip the "TF" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                config.output_hidden_states = True
                tf_model = model_class(config)
                pt_model = pt_model_class(config)

                # Check we can load pt model in tf and vice-versa with model => model functions
                tf_model = transformers.load_pytorch_model_in_tf2_model(
                    tf_model, pt_model, tf_inputs=inputs_dict)
                pt_model = transformers.load_tf2_model_in_pytorch_model(
                    pt_model, tf_model)

                # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
                pt_model.eval()
                pt_inputs_dict = dict(
                    (name, torch.from_numpy(key.numpy()).to(torch.long))
                    for name, key in inputs_dict.items())
                with torch.no_grad():
                    pto = pt_model(**pt_inputs_dict)
                tfo = tf_model(inputs_dict)
                max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
                self.assertLessEqual(max_diff, 2e-2)

                # Check we can load pt model in tf and vice-versa with checkpoint => model functions
                with TemporaryDirectory() as tmpdirname:
                    pt_checkpoint_path = os.path.join(tmpdirname,
                                                      'pt_model.bin')
                    torch.save(pt_model.state_dict(), pt_checkpoint_path)
                    tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
                        tf_model, pt_checkpoint_path)

                    tf_checkpoint_path = os.path.join(tmpdirname,
                                                      'tf_model.h5')
                    tf_model.save_weights(tf_checkpoint_path)
                    pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
                        pt_model, tf_checkpoint_path)

                # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
                pt_model.eval()
                pt_inputs_dict = dict(
                    (name, torch.from_numpy(key.numpy()).to(torch.long))
                    for name, key in inputs_dict.items())
                with torch.no_grad():
                    pto = pt_model(**pt_inputs_dict)
                tfo = tf_model(inputs_dict)
                max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
                self.assertLessEqual(max_diff, 2e-2)
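
A minimal standalone sketch of the same PyTorch <-> TensorFlow round trip, using the conversion helpers exercised above (load_pytorch_model_in_tf2_model and load_tf2_model_in_pytorch_model). It assumes both torch and tensorflow are installed; the tiny BertConfig sizes are illustrative only.

    import numpy as np
    import tensorflow as tf
    import torch
    import transformers
    from transformers import BertConfig, BertModel, TFBertModel

    # Small random models sharing one config (sizes chosen only to run fast).
    config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=2,
                        num_attention_heads=4, intermediate_size=37)
    pt_model = BertModel(config)
    tf_model = TFBertModel(config)

    input_ids = np.random.randint(0, config.vocab_size, size=(2, 8))

    # Copy the PyTorch weights into the TF model, then back again,
    # so both models end up with identical parameters.
    tf_model = transformers.load_pytorch_model_in_tf2_model(
        tf_model, pt_model, tf_inputs={"input_ids": tf.constant(input_ids)})
    pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

    pt_model.eval()
    with torch.no_grad():
        pt_out = pt_model(torch.tensor(input_ids, dtype=torch.long))[0].numpy()
    tf_out = tf_model({"input_ids": tf.constant(input_ids)})[0].numpy()
    print("max abs diff:", np.abs(pt_out - tf_out).max())  # expected well below 2e-2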
Example #2
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with ``model.train()``

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as the ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
                - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from a saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). The kwargs behave differently depending on whether a ``config`` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop('config', None)
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        output_loading_info = kwargs.pop('output_loading_info', False)

        # Load config
        if config is None:
            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                proxies=proxies,
                **kwargs)
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[
                    pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False"
                        .format([
                            WEIGHTS_NAME, TF2_WEIGHTS_NAME,
                            TF_WEIGHTS_NAME + ".index"
                        ], pretrained_model_name_or_path))
            elif os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(
                    pretrained_model_name_or_path)
                archive_file = pretrained_model_name_or_path + ".index"

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies)
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                        archive_file)
                else:
                    msg = "Model name '{}' was not found in model name list ({}). " \
                        "We assumed '{}' was a path or url to model weight files named one of {} but " \
                        "couldn't find any such file at this path or url.".format(
                            pretrained_model_name_or_path,
                            ', '.join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith('.index'):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(
                    model, config,
                    resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model
                    model = load_tf2_checkpoint_in_pytorch_model(
                        model, resolved_archive_file, allow_missing_keys=True)
                except ImportError as e:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise e
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})
                module._load_from_state_dict(state_dict, prefix,
                                             local_metadata, True,
                                             missing_keys, unexpected_keys,
                                             error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ''
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                start_prefix = cls.base_model_prefix + '.'
            if hasattr(model, cls.base_model_prefix) and not any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)
            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".
                    format(model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError(
                    'Error(s) in loading state_dict for {}:\n\t{}'.format(
                        model.__class__.__name__, "\n\t".join(error_msgs)))

        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        # Set model in evaluation mode to deactivate Dropout modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs
            }
            return model, loading_info

        return model
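
A short usage sketch for the loading options documented in this docstring (output_loading_info plus a config override passed through kwargs). The 'bert-base-uncased' shortcut requires network access or a populated cache on first use.

    from transformers import BertModel

    # Override a config attribute at load time and also request loading details.
    model, loading_info = BertModel.from_pretrained(
        'bert-base-uncased',
        output_hidden_states=True,   # unused kwargs update the loaded config
        output_loading_info=True)    # also return missing/unexpected keys
    assert model.config.output_hidden_states is True
    print(loading_info["missing_keys"], loading_info["unexpected_keys"])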
Example #3
    def test_pt_tf_model_equivalence(self):
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
                return_obj_labels="PreTraining" in model_class.__name__)

            tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning

            if not hasattr(transformers, tf_model_class_name):
                # transformers does not have TF version yet
                return

            tf_model_class = getattr(transformers, tf_model_class_name)

            config.output_hidden_states = True
            config.task_obj_predict = False

            pt_model = model_class(config)
            tf_model = tf_model_class(config)

            # Check we can load pt model in tf and vice-versa with model => model functions
            pt_inputs = self._prepare_for_class(inputs_dict, model_class)

            def recursive_numpy_convert(iterable):
                return_dict = {}
                for key, value in iterable.items():
                    if isinstance(value, bool):
                        return_dict[key] = value
                    elif isinstance(value, dict):
                        return_dict[key] = recursive_numpy_convert(value)
                    elif isinstance(value, (list, tuple)):
                        return_dict[key] = (tf.convert_to_tensor(
                            iter_value.cpu().numpy(), dtype=tf.int32)
                                            for iter_value in value)
                    else:
                        return_dict[key] = tf.convert_to_tensor(
                            value.cpu().numpy(), dtype=tf.int32)
                return return_dict

            tf_inputs_dict = recursive_numpy_convert(pt_inputs)

            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model, pt_model, tf_inputs=tf_inputs_dict)
            pt_model = transformers.load_tf2_model_in_pytorch_model(
                pt_model, tf_model).to(torch_device)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()

            # Delete obj labels as we want to compute the hidden states and not the loss

            if "obj_labels" in inputs_dict:
                del inputs_dict["obj_labels"]

            pt_inputs = self._prepare_for_class(inputs_dict, model_class)
            tf_inputs_dict = recursive_numpy_convert(pt_inputs)

            with torch.no_grad():
                pto = pt_model(**pt_inputs)
            tfo = tf_model(tf_inputs_dict, training=False)
            tf_hidden_states = tfo[0].numpy()
            pt_hidden_states = pto[0].cpu().numpy()

            tf_nans = np.copy(np.isnan(tf_hidden_states))
            pt_nans = np.copy(np.isnan(pt_hidden_states))

            pt_hidden_states[tf_nans] = 0
            tf_hidden_states[tf_nans] = 0
            pt_hidden_states[pt_nans] = 0
            tf_hidden_states[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
            # Debug info (remove when fixed)
            if max_diff >= 2e-2:
                print("===")
                print(model_class)
                print(config)
                print(inputs_dict)
                print(pt_inputs)
            self.assertLessEqual(max_diff, 6e-2)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
                    tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
                    pt_model, tf_checkpoint_path)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()

            for key, value in pt_inputs.items():
                if key in ("visual_feats", "visual_pos"):
                    pt_inputs[key] = value.to(torch.float32)
                else:
                    pt_inputs[key] = value.to(torch.long)

            with torch.no_grad():
                pto = pt_model(**pt_inputs)

            tfo = tf_model(tf_inputs_dict)
            tfo = tfo[0].numpy()
            pto = pto[0].cpu().numpy()
            tf_nans = np.copy(np.isnan(tfo))
            pt_nans = np.copy(np.isnan(pto))

            pto[tf_nans] = 0
            tfo[tf_nans] = 0
            pto[pt_nans] = 0
            tfo[pt_nans] = 0

            max_diff = np.amax(np.abs(tfo - pto))
            self.assertLessEqual(max_diff, 6e-2)
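
The recursive PyTorch-to-TensorFlow input conversion above can also be written as a standalone helper; the sketch below uses illustrative input names and assumes both frameworks are installed.

    import tensorflow as tf
    import torch

    def pt_to_tf_inputs(inputs):
        """Recursively convert a dict of PyTorch tensors to TF tensors."""
        converted = {}
        for key, value in inputs.items():
            if isinstance(value, bool):
                converted[key] = value
            elif isinstance(value, dict):
                converted[key] = pt_to_tf_inputs(value)
            elif isinstance(value, (list, tuple)):
                converted[key] = [tf.convert_to_tensor(v.cpu().numpy())
                                  for v in value]
            else:
                converted[key] = tf.convert_to_tensor(value.cpu().numpy())
        return converted

    example = {"input_ids": torch.ones(2, 4, dtype=torch.long),
               "output_attentions": False,
               "visual_feats": torch.zeros(2, 4, 8)}
    print({k: type(v).__name__ for k, v in pt_to_tf_inputs(example).items()})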
Example #4
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        """Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with ``model.train()``

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as the ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
                - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from a saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). The kwargs behave differently depending on whether a ``config`` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop('config', None)
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        output_loading_info = kwargs.pop('output_loading_info', False)
        random_init = kwargs.pop("random_init", False)
        use_cdn = kwargs.pop("use_cdn", True)
        local_files_only = kwargs.pop("local_files_only", False)
        resume_download = kwargs.pop("resume_download", False)
        kwargs_config = kwargs.copy()

        mapping_keys_state_dic = kwargs.pop("mapping_keys_state_dic", None)
        kwargs_config.pop("mapping_keys_state_dic", None)

        if config is None:

            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                **kwargs_config)
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False"
                        .format(
                            [
                                WEIGHTS_NAME, TF2_WEIGHTS_NAME,
                                TF_WEIGHTS_NAME + ".index"
                            ],
                            pretrained_model_name_or_path,
                        ))
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
                    from_tf
                ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                    pretrained_model_name_or_path + ".index")
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                    use_cdn=use_cdn,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.

        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith('.index'):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(
                    model, config,
                    resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model
                    model = load_tf2_checkpoint_in_pytorch_model(
                        model, resolved_archive_file, allow_missing_keys=True)
                except ImportError as e:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise e
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []

            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata
            # assert mapping_keys_state_dic is not None, "ERROR did not found mapping dicts for {} ".format(pretrained_model_name_or_path)
            # mapping_keys_state_dic = {"roberta": "encoder", "lm_head": "head.mlm"}
            if mapping_keys_state_dic is not None:
                assert isinstance(
                    mapping_keys_state_dic,
                    dict), "ERROR mapping_keys_state_dic should be a dict"
                print(
                    "INFO : from_pretrained (assuming an original Google checkpoint): "
                    "renaming some state_dict keys {}".format(
                        mapping_keys_state_dic))
                state_dict = cls.adapt_state_dic_to_multitask(
                    state_dict,
                    keys_mapping=mapping_keys_state_dic,
                    add_prefix=pretrained_model_name_or_path ==
                    "asafaya/bert-base-arabic")
                #pdb.set_trace()

            def load(module, prefix=''):

                local_metadata = {"version": 1}

                if not prefix.startswith("head") or prefix.startswith(
                        "head.mlm"):
                    assert len(
                        missing_keys
                    ) == 0, "ERROR {} missing keys in state_dict {}".format(
                        prefix, missing_keys)
                else:
                    if len(missing_keys) > 0:
                        print(
                            "Warning {} missing keys in state_dict {} (expected for task-specific fine-tuning)"
                            .format(prefix, missing_keys))

                module._load_from_state_dict(state_dict, prefix,
                                             local_metadata, True,
                                             missing_keys, unexpected_keys,
                                             error_msgs)
                for name, child in module._modules.items():

                    # load_params_only_ls = kwargs.get("load_params_only_ls ")
                    not_load_params_ls = kwargs.get(
                        "not_load_params_ls") if kwargs.get(
                            "not_load_params_ls") is not None else []
                    assert isinstance(
                        not_load_params_ls, list
                    ), f"Argument error not_load_params_ls should be a list but is {not_load_params_ls}"
                    matching_not_load = []
                    # RANDOM-INIT
                    for pattern in not_load_params_ls:
                        matching = re.match(pattern, prefix + name)
                        if matching is not None:
                            matching_not_load.append(matching)
                    if len(matching_not_load) > 0:
                        # means at least one pattern in not_load_params_ls matched --> this child should NOT be loaded
                        print("MATCH not loading : {} parameters {} ".format(
                            prefix + name, not_load_params_ls))
                    if child is not None and len(matching_not_load) == 0:
                        #print("MODEL loading : child {} full {} ".format(name, prefix + name + '.'))
                        load(child, prefix + name + '.')
                    else:
                        print(
                            "MODEL not loading : child {} matching_not_load {} "
                            .format(child, matching_not_load))

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ''
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                start_prefix = cls.base_model_prefix + '.'
            if hasattr(model, cls.base_model_prefix) and not any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                model_to_load = getattr(model, cls.base_model_prefix)
            if not random_init:
                load(model_to_load, prefix=start_prefix)
            else:
                print("WARNING : RANDOM INTIALIZATION OF BERTMULTITASK")

            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".
                    format(model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError(
                    'Error(s) in loading state_dict for {}:\n\t{}'.format(
                        model.__class__.__name__, "\n\t".join(error_msgs)))

        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        # Set model in evaluation mode to deactivate Dropout modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs
            }
            return model, loading_info

        return model
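
In this variant, the not_load_params_ls regex patterns decide which sub-modules keep their random initialization instead of loading checkpoint weights. A tiny sketch of that matching logic, with purely illustrative pattern and parameter names:

    import re

    not_load_params_ls = [r"head\.", r"classifier"]  # illustrative patterns

    def should_skip(param_prefix):
        # A sub-module is skipped when any pattern matches its dotted prefix.
        return any(re.match(p, param_prefix) for p in not_load_params_ls)

    for name in ["bert.encoder.layer.0", "head.mlm.dense", "classifier.weight"]:
        print(name, "-> skipped" if should_skip(name) else "-> loaded")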
Example #5
    def from_pretrained(cls, pretrained_model_name_or_path, eigvecs_dict,
                        *model_args, **kwargs):
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[
                    pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False"
                        .format(
                            [
                                WEIGHTS_NAME, TF2_WEIGHTS_NAME,
                                TF_WEIGHTS_NAME + ".index"
                            ],
                            pretrained_model_name_or_path,
                        ))
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
                    from_tf
                ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                    pretrained_model_name_or_path + ".index")
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                        archive_file)
                else:
                    msg = (
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url to model weight files named one of {} but "
                        "couldn't find any such file at this path or url.".
                        format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
                        ))
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, eigvecs_dict, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            try:
                state_dict = torch.load(resolved_archive_file,
                                        map_location="cpu")
            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith(".index"):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(
                    model, config,
                    resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model

                    model = load_tf2_checkpoint_in_pytorch_model(
                        model, resolved_archive_file, allow_missing_keys=True)
                except ImportError:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if "gamma" in key:
                    new_key = key.replace("gamma", "weight")
                if "beta" in key:
                    new_key = key.replace("beta", "bias")
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module: nn.Module, prefix=""):
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict,
                    prefix,
                    local_metadata,
                    True,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ""
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                start_prefix = cls.base_model_prefix + "."
            if hasattr(model, cls.base_model_prefix) and not any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)

            if model.__class__.__name__ != model_to_load.__class__.__name__:
                base_model_state_dict = model_to_load.state_dict().keys()
                head_model_state_dict_without_base_prefix = [
                    key.split(cls.base_model_prefix + ".")[-1]
                    for key in model.state_dict().keys()
                ]

                missing_keys.extend(head_model_state_dict_without_base_prefix -
                                    base_model_state_dict)

            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".
                    format(model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        model.__class__.__name__, "\n\t".join(error_msgs)))
        model.tie_weights()  # make sure token embedding weights are still tied if needed

        # Set model in evaluation mode to deactivate Dropout modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        return model
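
When a model with a head is instantiated from a base-model checkpoint, the code above extends missing_keys with the head parameters the checkpoint could not provide. A minimal sketch of that difference, with made-up key names:

    base_keys = {"embeddings.word_embeddings.weight",
                 "encoder.layer.0.output.dense.weight"}
    head_keys = sorted(base_keys) + ["classifier.weight", "classifier.bias"]
    missing = [k for k in head_keys if k not in base_keys]
    print(missing)  # head weights reported as not initialized from the pretrained model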
Example #6
    def test_pt_tf_model_equivalence(self):
        if not is_torch_available():
            return

        import torch
        import transformers

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
        )

        for model_class in self.all_model_classes:
            pt_model_class_name = model_class.__name__[
                2:]  # Skip the "TF" at the beginning
            pt_model_class = getattr(transformers, pt_model_class_name)

            config.output_hidden_states = True

            tf_model = model_class(config)
            pt_model = pt_model_class(config)

            # Check we can load pt model in tf and vice-versa with model => model functions

            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model,
                pt_model,
                tf_inputs=self._prepare_for_class(inputs_dict, model_class))
            pt_model = transformers.load_tf2_model_in_pytorch_model(
                pt_model, tf_model)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            pt_inputs_dict = dict(
                (name, torch.from_numpy(key.numpy()).to(torch.long))
                for name, key in self._prepare_for_class(
                    inputs_dict, model_class).items())
            # need to rename encoder-decoder "inputs" for PyTorch
            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class),
                           training=False)
            tf_hidden_states = tfo[0].numpy()
            pt_hidden_states = pto[0].numpy()

            tf_nans = np.copy(np.isnan(tf_hidden_states))
            pt_nans = np.copy(np.isnan(pt_hidden_states))

            pt_hidden_states[tf_nans] = 0
            tf_hidden_states[tf_nans] = 0
            pt_hidden_states[pt_nans] = 0
            tf_hidden_states[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
            # Debug info (remove when fixed)
            if max_diff >= 2e-2:
                print("===")
                print(model_class)
                print(config)
                print(inputs_dict)
                print(pt_inputs_dict)
            self.assertLessEqual(max_diff, 2e-2)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
                    tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
                    pt_model, tf_checkpoint_path)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            pt_inputs_dict = dict(
                (name, torch.from_numpy(key.numpy()).to(torch.long))
                for name, key in self._prepare_for_class(
                    inputs_dict, model_class).items())
            # need to rename encoder-decoder "inputs" for PyTorch
            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
            tfo = tfo[0].numpy()
            pto = pto[0].numpy()
            tf_nans = np.copy(np.isnan(tfo))
            pt_nans = np.copy(np.isnan(pto))

            pto[tf_nans] = 0
            tfo[tf_nans] = 0
            pto[pt_nans] = 0
            tfo[pt_nans] = 0

            max_diff = np.amax(np.abs(tfo - pto))
            self.assertLessEqual(max_diff, 2e-2)
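
The NaN-masked maximum-difference comparison above recurs in every example below. A minimal standalone sketch of the same idea, assuming two NumPy-convertible outputs (the helper name max_abs_diff_ignoring_nans is hypothetical):

import numpy as np

def max_abs_diff_ignoring_nans(tf_array, pt_array):
    # Work on copies so the caller's arrays are not modified in place.
    tf_vals = np.array(tf_array, dtype=np.float64)
    pt_vals = np.array(pt_array, dtype=np.float64)
    # Zero out positions where either framework produced NaN,
    # mirroring the masking done in the tests above.
    nan_mask = np.isnan(tf_vals) | np.isnan(pt_vals)
    tf_vals[nan_mask] = 0.0
    pt_vals[nan_mask] = 0.0
    return np.amax(np.abs(tf_vals - pt_vals))

# Illustrative values: the test above asserts the result is <= 2e-2.
assert max_abs_diff_ignoring_nans([1.0, np.nan], [1.005, 2.0]) <= 2e-2
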
Example No. 7
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False"
                        .format(
                            [
                                WEIGHTS_NAME, TF2_WEIGHTS_NAME,
                                TF_WEIGHTS_NAME + ".index"
                            ],
                            pretrained_model_name_or_path,
                        ))
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (from_tf), (
                    "We found a TensorFlow checkpoint at {}, please set from_tf to True"
                    " to load from this checkpoint"
                ).format(pretrained_model_name_or_path + ".index")
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                    use_cdn=use_cdn,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make "
                    f"sure that:\n\n- '{pretrained_model_name_or_path}' is a correct "
                    f"model identifier listed on 'https://huggingface.co/models'\n\n- "
                    f"or '{pretrained_model_name_or_path}' is the correct path to a "
                    f"directory containing a file named one of {WEIGHTS_NAME}, "
                    f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n")
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            try:
                state_dict = torch.load(resolved_archive_file,
                                        map_location="cpu")
            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith(".index"):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(
                    model, config,
                    resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model

                    model = load_tf2_checkpoint_in_pytorch_model(
                        model, resolved_archive_file, allow_missing_keys=True)
                except ImportError:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            has_all_sub_modules = all(
                any(s.startswith(_prefix) for s in state_dict.keys())
                for _prefix in model.base_model_prefixs)
            has_prefix_module = any(
                s.startswith(model.base_model_prefix)
                for s in state_dict.keys())

            old_keys = list(state_dict.keys())
            for key in old_keys:
                new_key = key

                if "gamma" in key:
                    new_key = new_key.replace("gamma", "weight")
                if "beta" in key:
                    new_key = new_key.replace("beta", "bias")

                _state = state_dict.pop(key)
                if has_all_sub_modules:
                    state_dict[new_key] = _state
                elif not has_prefix_module:
                    for _prefix in model.base_model_prefixs:
                        _key = _prefix + "." + new_key
                        state_dict[_key] = _state
                else:
                    if new_key.startswith(model.base_model_prefix):
                        for _prefix in model.base_model_prefixs:
                            _key = _prefix + new_key[len(model.
                                                         base_model_prefix):]
                            state_dict[_key] = _state
                    else:
                        state_dict[new_key] = _state
            if hasattr(model, "hack_pretrained_state_dict"):
                state_dict = model.hack_pretrained_state_dict(state_dict)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module, prefix=""):
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict,
                    prefix,
                    local_metadata,
                    True,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            # Make sure we are able to load base models as well as derived models (with heads)
            load(model)
            load = None

            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".
                    format(model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        model.__class__.__name__, "\n\t".join(error_msgs)))
        model.tie_weights(
        )  # make sure token embedding weights are still tied if needed

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        if hasattr(config, "xla_device") and config.xla_device:
            import torch_xla.core.xla_model as xm

            model = xm.send_cpu_data_to_device(model, xm.xla_device())
            model.to(xm.xla_device())

        return model
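
A brief usage sketch of the method above, assuming a model class from transformers and a local directory previously saved with save_pretrained (the path is illustrative):

from transformers import BertModel

# Load the weights and also get back which keys were missing or unused,
# as returned when output_loading_info=True.
model, loading_info = BertModel.from_pretrained(
    "./my_model_directory/", output_loading_info=True
)
print(loading_info["missing_keys"], loading_info["unexpected_keys"])
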
    def test_pt_tf_model_equivalence(self):
        import numpy as np
        import tensorflow as tf

        import transformers

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
        )

        for model_class in self.all_model_classes:
            tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning

            if not hasattr(transformers, tf_model_class_name):
                # transformers does not have TF version yet
                return

            tf_model_class = getattr(transformers, tf_model_class_name)

            config.output_hidden_states = True

            tf_model = tf_model_class(config)
            pt_model = model_class(config)

            # make sure we only forward TF inputs that actually exist in the function args
            tf_input_keys = set(
                inspect.signature(tf_model.call).parameters.keys())

            # remove all head masks
            tf_input_keys.discard("head_mask")
            tf_input_keys.discard("cross_attn_head_mask")
            tf_input_keys.discard("decoder_head_mask")

            pt_inputs = self._prepare_for_class(inputs_dict, model_class)
            pt_inputs = {
                k: v
                for k, v in pt_inputs.items() if k in tf_input_keys
            }

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            tf_inputs_dict = {}
            for key, tensor in pt_inputs.items():
                # booleans are passed through; tensors are cast to the appropriate TF dtype
                if type(tensor) == bool:
                    tf_inputs_dict[key] = tensor
                elif key == "input_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(
                        tensor.numpy(), dtype=tf.float32)
                elif key == "pixel_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(
                        tensor.numpy(), dtype=tf.float32)
                else:
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(),
                                                               dtype=tf.int32)

            # Check we can load pt model in tf and vice-versa with model => model functions
            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model, pt_model, tf_inputs=tf_inputs_dict)
            pt_model = transformers.load_tf2_model_in_pytorch_model(
                pt_model, tf_model)

            # need to rename encoder-decoder "inputs" for PyTorch
            #            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            #                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs)
            tfo = tf_model(tf_inputs_dict, training=False)

            self.assertEqual(len(tfo), len(pto),
                             "Output lengths differ between TF and PyTorch")
            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):

                if not (isinstance(tf_output, tf.Tensor)
                        and isinstance(pt_output, torch.Tensor)):
                    continue

                tf_out = tf_output.numpy()
                pt_out = pt_output.numpy()

                self.assertEqual(
                    tf_out.shape, pt_out.shape,
                    "Output component shapes differ between TF and PyTorch")

                if len(tf_out.shape) > 0:

                    tf_nans = np.copy(np.isnan(tf_out))
                    pt_nans = np.copy(np.isnan(pt_out))

                    pt_out[tf_nans] = 0
                    tf_out[tf_nans] = 0
                    pt_out[pt_nans] = 0
                    tf_out[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_out - pt_out))
                self.assertLessEqual(max_diff, 4e-2)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
                    tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
                    pt_model, tf_checkpoint_path)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            tf_inputs_dict = {}
            for key, tensor in pt_inputs.items():
                # booleans and tensors are cast to the appropriate TF dtype
                if type(tensor) == bool:
                    tensor = np.array(tensor, dtype=bool)
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor,
                                                               dtype=tf.int32)
                elif key == "input_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(
                        tensor.numpy(), dtype=tf.float32)
                elif key == "pixel_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(
                        tensor.numpy(), dtype=tf.float32)
                else:
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(),
                                                               dtype=tf.int32)

            # need to rename encoder-decoder "inputs" for PyTorch
            #            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            #                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs)

            tfo = tf_model(tf_inputs_dict)

            self.assertEqual(len(tfo), len(pto),
                             "Output lengths differ between TF and PyTorch")
            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):

                if not (isinstance(tf_output, tf.Tensor)
                        and isinstance(pt_output, torch.Tensor)):
                    continue

                tf_out = tf_output.numpy()
                pt_out = pt_output.numpy()

                self.assertEqual(
                    tf_out.shape, pt_out.shape,
                    "Output component shapes differ between TF and PyTorch")

                if len(tf_out.shape) > 0:
                    tf_nans = np.copy(np.isnan(tf_out))
                    pt_nans = np.copy(np.isnan(pt_out))

                    pt_out[tf_nans] = 0
                    tf_out[tf_nans] = 0
                    pt_out[pt_nans] = 0
                    tf_out[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_out - pt_out))
                self.assertLessEqual(max_diff, 4e-2)
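
The signature-filtering step used above can be isolated into a small helper. A minimal sketch, assuming tf_model is any Keras model whose call() accepts keyword inputs (the function name filter_to_tf_call_args is hypothetical):

import inspect

def filter_to_tf_call_args(tf_model, inputs,
                           drop=("head_mask", "decoder_head_mask", "cross_attn_head_mask")):
    # Keep only the keys that the TF model's call() actually accepts,
    # then drop the head-mask style arguments, as in the test above.
    accepted = set(inspect.signature(tf_model.call).parameters.keys())
    accepted.difference_update(drop)
    return {k: v for k, v in inputs.items() if k in accepted}
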
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)


        if not isinstance(config, BertConfig.PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):

                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):

                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):

                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                            pretrained_model_name_or_path,
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or urlparse(pretrained_model_name_or_path).scheme in ["http", "https", "s3"]:
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
                    from_tf
                ), "We found a TensorFlow checkpoint at {}".format(
                    pretrained_model_name_or_path + ".index"
                )
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = "/".join((S3_BUCKET_PREFIX, pretrained_model_name_or_path, (TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME)))


            try:
                resolved_archive_file = BertConfig.cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError:

                raise EnvironmentError

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            try:
                state_dict = torch.load(resolved_archive_file, map_location="cpu")
            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith(".index"):

                model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])
            else:

                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model

                    model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
                except ImportError:
                    raise
        else:

            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if "gamma" in key:
                    new_key = key.replace("gamma", "weight")
                if "beta" in key:
                    new_key = key.replace("beta", "bias")
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            def load(module: nn.Module, prefix=""):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            start_prefix = ""
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()
            ):
                start_prefix = cls.base_model_prefix + "."
            if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()
            ):
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)

            if model.__class__.__name__ != model_to_load.__class__.__name__:
                base_model_state_dict = model_to_load.state_dict().keys()
                head_model_state_dict_without_base_prefix = [
                    key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
                ]

                missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

        model.tie_weights()
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        return model
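
The old-format key conversion performed above (renaming "gamma"/"beta" parameters) can be written as a standalone sketch, assuming state_dict is a plain PyTorch state dict (the function name rename_old_format_keys is hypothetical):

def rename_old_format_keys(state_dict):
    # Older checkpoints store LayerNorm parameters as "gamma"/"beta";
    # newer model code expects "weight"/"bias".
    return {
        key.replace("gamma", "weight").replace("beta", "bias"): value
        for key, value in state_dict.items()
    }
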
    def test_pt_tf_model_equivalence(self):
        from transformers import is_torch_available

        if not is_torch_available():
            return

        import torch

        import transformers

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
                return_obj_labels="PreTraining" in model_class.__name__
            )

            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
            pt_model_class = getattr(transformers, pt_model_class_name)

            config.output_hidden_states = True
            config.task_obj_predict = False

            tf_model = model_class(config)
            pt_model = pt_model_class(config)

            # Check we can load pt model in tf and vice-versa with model => model functions

            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
            )
            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()

            # Delete obj labels as we want to compute the hidden states and not the loss

            if "obj_labels" in inputs_dict:
                del inputs_dict["obj_labels"]

            def torch_type(key):
                if key in ("visual_feats", "visual_pos"):
                    return torch.float32
                else:
                    return torch.long

            def recursive_numpy_convert(iterable):
                return_dict = {}
                for key, value in iterable.items():
                    if isinstance(value, dict):
                        return_dict[key] = recursive_numpy_convert(value)
                    else:
                        if isinstance(value, (list, tuple)):
                            return_dict[key] = (
                                torch.from_numpy(iter_value.numpy()).to(torch_type(key)) for iter_value in value
                            )
                        else:
                            return_dict[key] = torch.from_numpy(value.numpy()).to(torch_type(key))
                return return_dict

            pt_inputs_dict = recursive_numpy_convert(self._prepare_for_class(inputs_dict, model_class))

            # need to rename encoder-decoder "inputs" for PyTorch
            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
            tf_hidden_states = tfo[0].numpy()
            pt_hidden_states = pto[0].numpy()

            import numpy as np

            tf_nans = np.copy(np.isnan(tf_hidden_states))
            pt_nans = np.copy(np.isnan(pt_hidden_states))

            pt_hidden_states[tf_nans] = 0
            tf_hidden_states[tf_nans] = 0
            pt_hidden_states[pt_nans] = 0
            tf_hidden_states[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
            # Debug info (remove when fixed)
            if max_diff >= 2e-2:
                print("===")
                print(model_class)
                print(config)
                print(inputs_dict)
                print(pt_inputs_dict)
            self.assertLessEqual(max_diff, 6e-2)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                import os

                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            pt_inputs_dict = dict(
                (name, torch.from_numpy(key.numpy()).to(torch.long))
                for name, key in self._prepare_for_class(inputs_dict, model_class).items()
            )

            for key, value in pt_inputs_dict.items():
                if key in ("visual_feats", "visual_pos"):
                    pt_inputs_dict[key] = value.to(torch.float32)
                else:
                    pt_inputs_dict[key] = value.to(torch.long)

            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
            tfo = tfo[0].numpy()
            pto = pto[0].numpy()
            tf_nans = np.copy(np.isnan(tfo))
            pt_nans = np.copy(np.isnan(pto))

            pto[tf_nans] = 0
            tfo[tf_nans] = 0
            pto[pt_nans] = 0
            tfo[pt_nans] = 0

            max_diff = np.amax(np.abs(tfo - pto))
            self.assertLessEqual(max_diff, 6e-2)
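
The per-key dtype handling above can be factored into a small helper. A sketch assuming a flat dict of TF tensors in which only "visual_feats" and "visual_pos" are floating point (the function name tf_dict_to_pt_dict is hypothetical):

import torch

def tf_dict_to_pt_dict(tf_inputs, float_keys=("visual_feats", "visual_pos")):
    # Visual features/positions stay float32; everything else (token ids,
    # masks) is cast to long, matching the per-key handling in the test above.
    return {
        name: torch.from_numpy(tensor.numpy()).to(
            torch.float32 if name in float_keys else torch.long
        )
        for name, tensor in tf_inputs.items()
    }
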
Example No. 11
    def test_pt_tf_model_equivalence(self):
        import torch

        import transformers

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
            pt_model_class = getattr(transformers, pt_model_class_name)

            config.output_hidden_states = True

            tf_model = model_class(config)
            pt_model = pt_model_class(config)

            # Check we can load pt model in tf and vice-versa with model => model functions

            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
            )
            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            pt_inputs_dict = {}
            for name, key in self._prepare_for_class(inputs_dict, model_class).items():
                if type(key) == bool:
                    pt_inputs_dict[name] = key
                elif name == "input_values":
                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
                elif name == "pixel_values":
                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
                else:
                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)

            # need to rename encoder-decoder "inputs" for PyTorch
            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)

            self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):

                if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                    continue

                tf_out = tf_output.numpy()
                pt_out = pt_output.numpy()

                self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

                if len(tf_out.shape) > 0:

                    tf_nans = np.copy(np.isnan(tf_out))
                    pt_nans = np.copy(np.isnan(pt_out))

                    pt_out[tf_nans] = 0
                    tf_out[tf_nans] = 0
                    pt_out[pt_nans] = 0
                    tf_out[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_out - pt_out))
                self.assertLessEqual(max_diff, 4e-2)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            pt_inputs_dict = {}
            for name, key in self._prepare_for_class(inputs_dict, model_class).items():
                if type(key) == bool:
                    key = np.array(key, dtype=bool)
                    pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
                elif name == "input_values":
                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
                elif name == "pixel_values":
                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
                else:
                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
            # need to rename encoder-decoder "inputs" for PyTorch
            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))

            self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):

                if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                    continue

                tf_out = tf_output.numpy()
                pt_out = pt_output.numpy()

                self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

                if len(tf_out.shape) > 0:
                    tf_nans = np.copy(np.isnan(tf_out))
                    pt_nans = np.copy(np.isnan(pt_out))

                    pt_out[tf_nans] = 0
                    tf_out[tf_nans] = 0
                    pt_out[pt_nans] = 0
                    tf_out[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_out - pt_out))
                self.assertLessEqual(max_diff, 4e-2)
Example No. 12
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""
    Instantiate a pretrained pytorch model from a pre-trained model configuration.

    The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
    To train the model, you should first set it back in training mode with ``model.train()``.

    The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
    task.

    The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
      pretrained_model_name_or_path (:obj:`str`, `optional`):
        Can be either:

          - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
            ``bert-base-uncased``.
          - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
            ``dbmdz/bert-base-german-cased``.
          - A path to a `directory` containing model weights saved using
            :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
          - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
            this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
            as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
            a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
          - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
            arguments ``config`` and ``state_dict``).
      model_args (sequence of positional arguments, `optional`):
        All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
      config (:obj:`typing.Union[PretrainedConfig, str]`, `optional`):
        Can be either:

          - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
          - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

        Configuration for the model to use instead of an automatically loaded configuration. Configuration can
        be automatically loaded when:

          - The model is a model provided by the library (loaded with the `shortcut name` string of a
            pretrained model).
          - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
            by suppling the save directory.
          - The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a
            configuration JSON file named `config.json` is found in the directory.
      state_dict (:obj:`typing.Dict[str, torch.Tensor]`, `optional`):
        A state dictionary to use instead of a state dictionary loaded from saved weights file.

        This option can be used if you want to create a model from a pretrained configuration but load your own
        weights. In this case though, you should check if using
        :func:`~transformers.PreTrainedModel.save_pretrained` and
        :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
      cache_dir (:obj:`str`, `optional`):
        Path to a directory in which a downloaded pretrained model configuration should be cached if the
        standard cache should not be used.
      from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
        Load the model weights from a TensorFlow checkpoint save file (see docstring of
        ``pretrained_model_name_or_path`` argument).
      force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
        Whether or not to force the (re-)download of the model weights and configuration files, overriding the
        cached versions if they exist.
      resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
        Whether or not to delete incompletely received files. Will attempt to resume the download if such a
        file exists.
      proxies (:obj:`typing.Dict[str, str]`, `optional`):
        A dictionary of proxy servers to use by protocol or endpoint, e.g.,
        :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
        request.
      output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
        Whether or not to also return a dictionary containing missing keys, unexpected keys and error
        messages.
      local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
        Whether or not to only look at local files (e.g., not try downloading the model).
      use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
        Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
        our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
      kwargs (remaining dictionary of keyword arguments, `optional`):
        Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
        :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
        automatically loaded:

          - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
            underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
            already been done)
          - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
            initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
            ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
            with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
            attribute will be passed to the underlying model's ``__init__`` function.

    Examples::

      from transformers import BertConfig, BertModel
      # Download model and configuration from S3 and cache.
      model = BertModel.from_pretrained('bert-base-uncased')
      # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
      model = BertModel.from_pretrained('./test/saved_model/')
      # Update configuration during loading.
      model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
      assert model.config.output_attention == True
      # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
      config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
      model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_tf = kwargs.pop("from_tf", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)

    model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
      if os.path.isdir(pretrained_model_name_or_path):
        if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
          # Load from a TF 1.0 checkpoint
          archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
        elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
          # Load from a TF 2.0 checkpoint
          archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
        elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
          # Load from a PyTorch checkpoint
          archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
        else:
          raise EnvironmentError(
            "Error no file named {} found in directory {} or `from_tf` set to False".format(
              [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
              pretrained_model_name_or_path,
            )
          )
      elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
        archive_file = pretrained_model_name_or_path
      elif os.path.isfile(pretrained_model_name_or_path + ".index"):
        assert (
          from_tf
        ), "We found a torch.TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
          pretrained_model_name_or_path + ".index"
        )
        archive_file = pretrained_model_name_or_path + ".index"
      else:
        archive_file = hf_bucket_url(
          pretrained_model_name_or_path,
          filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
          use_cdn=use_cdn,
        )

      try:
        # Load from URL or cache if already cached
        resolved_archive_file = cached_path(
          archive_file,
          cache_dir=cache_dir,
          force_download=force_download,
          proxies=proxies,
          resume_download=resume_download,
          local_files_only=local_files_only,
        )
        if resolved_archive_file is None:
          raise EnvironmentError
      except EnvironmentError:
        msg = (
          f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
          f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
          f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
        )
        raise EnvironmentError(msg)

      if resolved_archive_file == archive_file:
        l.logger().info("loading weights file {}".format(archive_file))
      else:
        l.logger().info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
      resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
      try:
        state_dict = torch.load(resolved_archive_file, map_location="cpu")
      except Exception:
        raise OSError(
          "Unable to load weights from pytorch checkpoint file. "
          "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
        )

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    if from_tf:
      if resolved_archive_file.endswith(".index"):
        # Load from a TensorFlow 1.X checkpoint - provided by original authors
        model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
      else:
        # Load from our TensorFlow 2.0 checkpoints
        try:
          from transformers import load_tf2_checkpoint_in_pytorch_model

          model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
        except ImportError:
          l.logger().error(
            "Loading a torch.TensorFlow model in PyTorch, requires both PyTorch and torch.TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
          )
          raise
    else:
      # Convert old format to new format if needed from a PyTorch state_dict
      old_keys = []
      new_keys = []
      for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
          new_key = key.replace("gamma", "weight")
        if "beta" in key:
          new_key = key.replace("beta", "bias")
        if new_key:
          old_keys.append(key)
          new_keys.append(new_key)
      for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

      # copy state_dict so _load_from_state_dict can modify it
      metadata = getattr(state_dict, "_metadata", None)
      state_dict = state_dict.copy()
      if metadata is not None:
        state_dict._metadata = metadata

      # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
      # so we need to apply the function recursively.
      def load(module: torch.nn.Module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
          state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
        )
        for name, child in module._modules.items():
          if child is not None:
            load(child, prefix + name + ".")

      # Make sure we are able to load base models as well as derived models (with heads)
      start_prefix = ""
      model_to_load = model
      has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
      if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
        start_prefix = cls.base_model_prefix + "."
      if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
        model_to_load = getattr(model, cls.base_model_prefix)

      load(model_to_load, prefix=start_prefix)

      if model.__class__.__name__ != model_to_load.__class__.__name__:
        base_model_state_dict = model_to_load.state_dict().keys()
        head_model_state_dict_without_base_prefix = [
          key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
        ]
        missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

      # Some models may have keys that are not in the state by design, removing them before needlessly warning
      # the user.
      if cls.authorized_missing_keys is not None:
        for pat in cls.authorized_missing_keys:
          missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

      if len(unexpected_keys) > 0:
        l.logger().warning(
          f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
          f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
          f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
          f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
          f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
          f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
      else:
        l.logger().info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
      if len(missing_keys) > 0:
        l.logger().warning(
          f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
          f"and are newly initialized: {missing_keys}\n"
          f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
      else:
        l.logger().info(
          f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
          f"If your task is similar to the task the model of the checkpoint was trained on, "
          f"you can already use {model.__class__.__name__} for predictions without further training."
        )
      if len(error_msgs) > 0:
        raise RuntimeError(
          "Error(s) in loading state_dict for {}:\n\t{}".format(
            model.__class__.__name__, "\n\t".join(error_msgs)
          )
        )
    # make sure token embedding weights are still tied if needed
    model.tie_weights()

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    if output_loading_info:
      loading_info = {
        "missing_keys": missing_keys,
        "unexpected_keys": unexpected_keys,
        "error_msgs": error_msgs,
      }
      return model, loading_info

    if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
      model = pytorch.xla_model.send_cpu_data_to_device(model, pytorch.xla_model.xla_device())
      model.to(pytorch.xla_model.xla_device())

    return model
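
The base-model / head-model prefix handling above can be summarized in a short sketch, assuming a class-level base_model_prefix (e.g. "bert") and a plain dict state_dict (the function name resolve_load_target is hypothetical):

def resolve_load_target(model, base_model_prefix, state_dict):
    # Decide whether to prepend the base-model prefix while loading, or to
    # load directly into the base sub-module, as in the method above.
    has_prefix_module = any(k.startswith(base_model_prefix) for k in state_dict)
    start_prefix = ""
    model_to_load = model
    if not hasattr(model, base_model_prefix) and has_prefix_module:
        # Checkpoint came from a model with a head; we are loading a bare base model.
        start_prefix = base_model_prefix + "."
    if hasattr(model, base_model_prefix) and not has_prefix_module:
        # Checkpoint is a bare base model; we are loading into a model with a head.
        model_to_load = getattr(model, base_model_prefix)
    return model_to_load, start_prefix
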
    def test_pt_tf_model_equivalence(self):
        import numpy as np
        import tensorflow as tf

        import transformers

        # make masks reproducible
        np.random.seed(2)

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        num_patches = int((config.image_size // config.patch_size)**2)
        noise = np.random.uniform(size=(self.model_tester.batch_size,
                                        num_patches))
        pt_noise = torch.from_numpy(noise).to(device=torch_device)
        tf_noise = tf.constant(noise)

        def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):

            tf_inputs_dict = {}
            for key, tensor in pt_inputs_dict.items():
                tf_inputs_dict[key] = tf.convert_to_tensor(
                    tensor.cpu().numpy(), dtype=tf.float32)

            return tf_inputs_dict

        def check_outputs(tf_outputs, pt_outputs, model_class, names):
            """
            Args:
                model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
                    `TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
                    debugging easier and faster.

                names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
                    Currently unused, but in the future, we could use this information to make the error message clearer
                    by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
            """

            # Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
            if type(tf_outputs) in [tuple, list]:
                self.assertEqual(type(tf_outputs), type(pt_outputs))
                self.assertEqual(len(tf_outputs), len(pt_outputs))
                if type(names) == tuple:
                    for tf_output, pt_output, name in zip(
                            tf_outputs, pt_outputs, names):
                        check_outputs(tf_output,
                                      pt_output,
                                      model_class,
                                      names=name)
                elif type(names) == str:
                    for idx, (tf_output, pt_output) in enumerate(
                            zip(tf_outputs, pt_outputs)):
                        check_outputs(tf_output,
                                      pt_output,
                                      model_class,
                                      names=f"{names}_{idx}")
                else:
                    raise ValueError(
                        f"`names` should be a `tuple` or a string. Got {type(names)} instead."
                    )
            elif isinstance(tf_outputs, tf.Tensor):
                self.assertTrue(isinstance(pt_outputs, torch.Tensor))

                tf_outputs = tf_outputs.numpy()
                if isinstance(tf_outputs, np.float32):
                    tf_outputs = np.array(tf_outputs, dtype=np.float32)
                pt_outputs = pt_outputs.detach().to("cpu").numpy()

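                # Positions that are NaN in either framework's output are zeroed in both
                # arrays, so NaNs never drive up the max-difference comparison below.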
                tf_nans = np.isnan(tf_outputs)
                pt_nans = np.isnan(pt_outputs)

                pt_outputs[tf_nans] = 0
                tf_outputs[tf_nans] = 0
                pt_outputs[pt_nans] = 0
                tf_outputs[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
                self.assertLessEqual(max_diff, 1e-5)
            else:
                raise ValueError(
                    f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
                )

        def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict):
            # we are not preparing the model with labels because the ViT MAE model
            # does not take labels in its forward pass

            # send pytorch model to the correct device
            pt_model.to(torch_device)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()

            tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)

            # send pytorch inputs to the correct device
            pt_inputs_dict = {
                k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v
                for k, v in pt_inputs_dict.items()
            }

            # Original test: check without `labels`
            with torch.no_grad():
                pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
            tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)

            tf_keys = tuple(k for k, v in tf_outputs.items() if v is not None)
            pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

            self.assertEqual(tf_keys, pt_keys)
            check_outputs(tf_outputs.to_tuple(),
                          pt_outputs.to_tuple(),
                          model_class,
                          names=tf_keys)

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
        )

        for model_class in self.all_model_classes:
            tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning

            # Output all for aggressive testing
            config.output_hidden_states = True
            config.output_attentions = self.has_attentions

            tf_model_class = getattr(transformers, tf_model_class_name)

            tf_model = tf_model_class(config)
            pt_model = model_class(config)

            # make sure only TF inputs that actually exist in the function's args are forwarded
            tf_input_keys = set(
                inspect.signature(tf_model.call).parameters.keys())

            # remove all head masks
            tf_input_keys.discard("head_mask")
            tf_input_keys.discard("cross_attn_head_mask")
            tf_input_keys.discard("decoder_head_mask")

            pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)

            pt_inputs_dict = {
                k: v
                for k, v in pt_inputs_dict.items() if k in tf_input_keys
            }

            # Check we can load pt model in tf and vice-versa with model => model functions
            tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model, pt_model, tf_inputs=tf_inputs_dict)
            pt_model = transformers.load_tf2_model_in_pytorch_model(
                pt_model, tf_model)

            check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
                    tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
                    pt_model, tf_checkpoint_path)
                pt_model = pt_model.to(torch_device)

            check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
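            # Note: the equivalence check runs twice per model class: once after the
            # in-memory model => model conversion and once after the checkpoint round
            # trip, reusing the same inputs and the same shared noise for both paths.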