Example #1
def load(kmdl, path):
    # First load the weights whose names and shapes match exactly: std->std; lhc-formable->lhc-formable.
    kmdl.load_weights(path, by_name=True, skip_mismatch=True)

    file0 = file = h5py.File(path, 'r')
    if 'layer_names' not in file.attrs and 'model_weights' in file:
        file = file['model_weights']
    from tensorflow.python.keras.saving.hdf5_format import _legacy_weights, load_attributes_from_hdf5_group, \
        preprocess_weights_for_loading

    if 'keras_version' in file.attrs:
        original_keras_version = file.attrs['keras_version']
        # h5py may return bytes here; decode so version checks work on str.
        if hasattr(original_keras_version, 'decode'):
            original_keras_version = original_keras_version.decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in file.attrs:
        original_backend = file.attrs['backend']
        if hasattr(original_backend, 'decode'):
            original_backend = original_backend.decode('utf8')
    else:
        original_backend = None

    layer_names = load_attributes_from_hdf5_group(file, 'layer_names')
    index = {}
    for layer in kmdl.layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    # load the remaining parts
    weight_value_tuples = []
    for name in layer_names:
        g = file[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        weight_values = [
            np.asarray(g[weight_name]) for weight_name in weight_names
        ]

        layer = index.get(name, [])
        if len(layer) == 0:
            continue
        assert len(layer) == 1
        layer = layer[0]

        if type(layer) in (Conv2dLhcf, Conv2dLhcr):
            weight_values = preprocess_weights_for_loading(
                layer, weight_values, original_keras_version, original_backend)
            wdict = dict(zip(weight_names, weight_values))

            symbolic_weights = _legacy_weights(layer)
            symbol_names = [s.name for s in symbolic_weights]
            sdict = dict(zip(symbol_names, symbolic_weights))

            for pname in Conv2dLhcf.VAR_NAMES[:3]:
                symb = [w for n, w in sdict.items() if pname in n]
                wght = [w for n, w in wdict.items() if pname in n]
                assert len(symb) == 1 and len(wght) <= 1
                if len(wght) == 1:
                    weight_value_tuples.append((symb[0], wght[0]))

    KB.batch_set_value(weight_value_tuples)
    file0.close()
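
A minimal usage sketch for the loader above (the model constructor and checkpoint path are assumptions): standard layers are restored via ``load_weights(by_name=True)``, and the LHC-formable conv layers via the name-based matching that follows.

    kmdl = build_kmdl()          # hypothetical constructor mixing std and lhc-formable layers
    load(kmdl, 'checkpoint.h5')  # hypothetical checkpoint path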
Example #2

def load_tf_weights(model, resolved_archive_file):
    """
    Load the TF weights from an H5 file.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.
    """
    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(
            hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        weight_value_tuples = []

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(
                    g, "weight_names")
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                saved_weight_names_values = {}

                for weight_name in saved_weight_names:
                    name = "/".join(weight_name.split("/")[1:])
                    saved_weight_names_values[name] = np.asarray(
                        g[weight_name])

                for symbolic_weight in symbolic_weights:
                    split_names = symbolic_weight.name.split("/")[1:]
                    symbolic_weight_name = "/".join(split_names)

                    if symbolic_weight_name in saved_weight_names_values:
                        saved_weight_value = saved_weight_names_values[
                            symbolic_weight_name]

                        if K.int_shape(
                                symbolic_weight) != saved_weight_value.shape:
                            try:
                                array = np.reshape(
                                    saved_weight_value,
                                    K.int_shape(symbolic_weight))
                            except ValueError as e:
                                # np.reshape raises ValueError on incompatible
                                # shapes; attach both shapes for debugging.
                                e.args += (K.int_shape(symbolic_weight),
                                           saved_weight_value.shape)
                                raise e
                        else:
                            array = saved_weight_value

                        weight_value_tuples.append((symbolic_weight, array))

    K.batch_set_value(weight_value_tuples)
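
A hedged usage sketch (the model object and file path are assumptions; the model must be built before its variables can be assigned):

    model(model.dummy_inputs, training=False)       # build the variables first
    load_tf_weights(model, '/path/to/tf_model.h5')  # hypothetical path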
Example #3

def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
    """
    Detect missing and unexpected layers.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.

    Returns:
        Two lists, one for the missing layers, and another one for the unexpected layers.
    """
    missing_layers = []
    unexpected_layers = []

    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(
            hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_layers = list(model_layer_names - saved_layer_names)
        unexpected_layers = list(saved_layer_names - model_layer_names)

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(
                    g, "weight_names")
                saved_weight_names_set = set(
                    "/".join(weight_name.split("/")[2:])
                    for weight_name in saved_weight_names)
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                symbolic_weights_names = set(
                    "/".join(symbolic_weight.name.split("/")[2:])
                    for symbolic_weight in symbolic_weights)
                missing_layers.extend(
                    list(symbolic_weights_names - saved_weight_names_set))
                unexpected_layers.extend(
                    list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers
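
A hedged companion sketch (same assumed model and path as above) that reports what did not line up:

    missing, unexpected = detect_tf_missing_unexpected_layers(
        model, '/path/to/tf_model.h5')
    if missing:
        print('Not found in the checkpoint:', missing)
    if unexpected:
        print('In the checkpoint but unused by the model:', unexpected)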
Example #4
    @classmethod
    def load_from_bert_pretrained(cls, config_file, pretrained_model_name='bert-base-uncased', **kwargs):
        model = cls(config_file, **kwargs)
        model(model.dummy_inputs, training=False)

        # Map each model layer name to a checkpoint layer name: a config value
        # of '0,2,4,6' maps model layer 0 -> ckpt layer 0, 1 -> 2, and so on.
        ckpt_layer_mapping = {}
        for vind, ckpt_ind in enumerate(model.config.ckpt_layer_mapping.split(',')):
            ckpt_layer_mapping['layer_._{}'.format(vind)] = 'layer_._{}'.format(ckpt_ind)

        archive_file = hf_bucket_url(pretrained_model_name, filename=TF2_WEIGHTS_NAME, use_cdn=True)
        resolved_archive_file = cached_path(archive_file, cache_dir=None, force_download=False, resume_download=False,
                                            proxies=None)
        f = h5py.File(resolved_archive_file, mode='r')

        layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
        g = f[layer_names[0]]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
        weights_map = {'/'.join(name.split('/')[2:]): i for i, name in enumerate(weight_names)}
        weight_value_tuples = []
        w_names = []
        for w in model.layers[0].weights:
            w_name = '/'.join(w.name.split('/')[3:])
            for k in ckpt_layer_mapping:
                # `str.find` returns -1 (truthy) when absent and 0 (falsy) when
                # the match is at the start, so test membership instead.
                if k in w_name:
                    w_name = w_name.replace(k, ckpt_layer_mapping[k])
                    break

            if w_name in weights_map and w.shape == weight_values[weights_map[w_name]].shape:
                w_names.append(w_name)
                weight_value_tuples.append((w, weight_values[weights_map[w_name]]))

        logger.info("Loaded %d weights" % (len(w_names)))
        logger.info("Loaded weights names are: %s" % (", ".join(w_names)))

        K.batch_set_value(weight_value_tuples)

        print("Loaded %d weights" % (len(w_names)))
        print("Loaded weights names are: %s" % (", ".join(w_names)))

        model(model.dummy_inputs, training=False)
        return model
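
A hedged usage sketch (the subclass name and config file are assumptions; ``config.ckpt_layer_mapping`` is expected to be a comma-separated list of checkpoint layer indices):

    # Hypothetical 4-layer student initialized from BERT layers 0, 2, 4 and 6
    # (i.e. ckpt_layer_mapping = '0,2,4,6' in the student's config).
    student = TFStudentBert.load_from_bert_pretrained(
        'student_config.json', pretrained_model_name='bert-base-uncased')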
Example #5
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~xz_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) one of:
                    - an instance of a class derived from :class:`~xz_transformers.PretrainedConfig`, or
                    - a string valid as input to :func:`~xz_transformers.PretrainedConfig.from_pretrained()`
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~xz_transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            from_pt: (`optional`) boolean, default False:
                Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Whether to force (re-)downloading the model weights and configuration files, overriding the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file, and attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~xz_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            # For example purposes. Not runnable.
            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a PyTorch model file instead of a TensorFlow checkpoint (slower)
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = BertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        from_pt = kwargs.pop("from_pt", False)
        output_loading_info = kwargs.pop("output_loading_info", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                return_unused_kwargs=True,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            archive_file = None
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[
                    pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "No file named {} found in directory {} and `from_pt` is set to False"
                        .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME],
                                pretrained_model_name_or_path))
            elif os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"

            # The provided pretrained_model_name_or_path is invalid.
            if archive_file is None:
                raise EnvironmentError(
                    "Can't find a weights file for '{}'".format(
                        pretrained_model_name_or_path))
            else:
                resolved_archive_file = archive_file
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(
                model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs,
              training=False)  # build the network with dummy inputs

        assert os.path.isfile(
            resolved_archive_file), "Error retrieving file {}".format(
                resolved_archive_file)
        # 'by_name' allows us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs,
              training=False)  # Make sure restore ops are run

        # Compare the model's layers with the checkpoint's to report loading information
        with h5py.File(resolved_archive_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(
                hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        error_msgs = []

        if len(missing_keys) > 0:
            logger.info(
                "Layers of {} not initialized from pretrained model: {}".
                format(model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info(
                "Layers from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)))
        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs
            }
            return model, loading_info

        return model
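
A short usage sketch for the ``output_loading_info`` branch above (the model class is the one from the docstring examples):

    model, info = BertModel.from_pretrained('bert-base-uncased',
                                            output_loading_info=True)
    print(info['missing_keys'], info['unexpected_keys'], info['error_msgs'])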
Example #6
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                    - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                      ``bert-base-uncased``.
                    - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                      ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformersTF.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch model in a
                      TensorFlow model using the provided conversion scripts and loading the TensorFlow model
                      afterwards.
                    - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `shortcut name` string of a
                      pretrained model).
                    - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch state_dict save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.,
                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
                request.
            output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
                messages.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
                our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and initialize the model
                (e.g., :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided
                or automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            from transformers import BertConfig, TFBertModel
            # Download model and configuration from S3 and cache.
            model = TFBertModel.from_pretrained('bert-base-uncased')
            # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            model = TFBertModel.from_pretrained('./test/saved_model/')
            # Update configuration during loading.
            model = TFBertModel.from_pretrained('bert-base-uncased', output_attention=True)
            assert model.config.output_attention == True
            # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "No file named {} found in directory {} and `from_pt` is set to False"
                        .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME],
                                pretrained_model_name_or_path))
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                    use_cdn=use_cdn,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(
                model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs,
              training=False)  # build the network with dummy inputs

        assert os.path.isfile(
            resolved_archive_file), "Error retrieving file {}".format(
                resolved_archive_file)
        # 'by_name' allows us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs,
              training=False)  # Make sure restore ops are run

        # Compare the model's layers with the checkpoint's to report loading information
        with h5py.File(resolved_archive_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(
                hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        error_msgs = []

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.warning(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.warning(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)))
        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs
            }
            return model, loading_info

        return model
Example #7
def from_pretrained_detailed(model_class, pretrained_model_name_or_path,
                             *model_args, **kwargs):
    r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
    It is up to you to train those weights with a downstream fine-tuning task.

    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:
            - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
            - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
            - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

        config: (`optional`) one of:
                - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`

            Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

        from_pt: (`optional`) boolean, default False:
            Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model
            configuration should be cached if the standard cache should not be used.

        force_download: (`optional`) boolean, default False:
            Whether to force (re-)downloading the model weights and configuration files, overriding the cached versions if they exist.

        resume_download: (`optional`) boolean, default False:
            Do not delete an incompletely received file, and attempt to resume the download if such a file exists.

        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
            The proxies are used on each request.

        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
            
            - If layer pruning is supported, ``layer_pruning`` will be passed as a dictionary containing the layer pruning configuration, as follows:
                - strategy:
                    one of {`top`, `buttom`, `symmetric`, `alternate`, `custom`}
                - k:
                    the number of layers to prune; mandatory if strategy is one of {`top`, `buttom`, `symmetric`, `alternate`}
                - layers_indexes:
                    an array of the indices of the layers to prune; mandatory if strategy is `custom`
                - is_odd:
                    whether to alternate over the odd-indexed layers; mandatory if strategy is `alternate`

    Examples::

        # For example purposes. Not runnable.
        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a PyTorch model file instead of a TensorFlow checkpoint (slower)
        config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
        model = BertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)
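        # Hedged layer-pruning example: keys follow the `layer_pruning`
        # description above; only meaningful for models that support pruning.
        model = from_pretrained_detailed(BertModel, 'bert-base-uncased',
                                         layer_pruning={'strategy': 'top', 'k': 4})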

    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)

    # mwahdan: Read layer_pruning config if exist
    layer_pruning = kwargs.pop("layer_pruning", None)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = model_class.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(
                    os.path.join(pretrained_model_name_or_path,
                                 TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path,
                                            TF2_WEIGHTS_NAME)
            elif from_pt and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path,
                                            WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "No file named {} found in directory {} and `from_pt` is set to False"
                    .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME],
                            pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
            )

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            if resolved_archive_file is None:
                raise EnvironmentError
        except EnvironmentError:
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)
        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # mwahdan: Modify config
    if layer_pruning:
        layer_pruning_k = layer_pruning_layers_indexes = layer_pruning_is_odd = None
        layer_pruning_strategy = get_mandatory_parameter(
            'strategy', layer_pruning)
        if layer_pruning_strategy in {'top', 'buttom', 'symmetric'}:
            layer_pruning_k = get_mandatory_parameter('k', layer_pruning)
            config, original_num_layers = modify_num_of_layers(
                config, k=layer_pruning_k)
        elif layer_pruning_strategy == 'custom':
            layer_pruning_layers_indexes = get_mandatory_parameter(
                'layers_indexes', layer_pruning)
            config, original_num_layers = modify_num_of_layers(
                config, layers_indexes=layer_pruning_layers_indexes)
        elif layer_pruning_strategy == 'alternate':
            layer_pruning_k = get_mandatory_parameter('k', layer_pruning)
            layer_pruning_is_odd = get_mandatory_parameter(
                'is_odd', layer_pruning)
            config, original_num_layers = modify_num_of_layers(
                config, k=layer_pruning_k, is_alternate=True)
        else:
            raise ValueError('`%s` is not a supported layer pruning strategy' %
                             layer_pruning_strategy)

    # Instantiate model.
    model = model_class(config, *model_args, **model_kwargs)

    # mwahdan: Rename layers
    if layer_pruning:
        model = rename_layers_in_strategy(model, layer_pruning_strategy,
                                          original_num_layers, layer_pruning_k,
                                          layer_pruning_layers_indexes,
                                          layer_pruning_is_odd)

    if from_pt:
        # Load from a PyTorch checkpoint
        model = load_pytorch_checkpoint_in_tf2_model(model,
                                                     resolved_archive_file,
                                                     allow_missing_keys=True)
        # mwahdan: Rename layers
        if layer_pruning is not None:
            model = rename_layers(model)
        return model

    model(model.dummy_inputs,
          training=False)  # build the network with dummy inputs

    assert os.path.isfile(
        resolved_archive_file), "Error retrieving file {}".format(
            resolved_archive_file)
    # 'by_name' allows us to do transfer learning by skipping/adding layers
    # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
    try:
        # added skip_mismatch=True because we will prune full layers
        model.load_weights(resolved_archive_file,
                           by_name=True,
                           skip_mismatch=True)
        # mwahdan: Rename layers
    except OSError:
        raise OSError(
            "Unable to load weights from h5 file. "
            "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
        )

    model(model.dummy_inputs, training=False)  # Make sure restore ops are run

    # mwahdan: Rename layers
    if layer_pruning is not None:
        model = rename_layers(model)

    # Compare the model's layers with the checkpoint's to report loading information
    with h5py.File(resolved_archive_file, "r") as f:
        if "layer_names" not in f.attrs and "model_weights" in f:
            f = f["model_weights"]
        hdf5_layer_names = set(
            hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
    model_layer_names = set(layer.name for layer in model.layers)
    missing_keys = list(model_layer_names - hdf5_layer_names)
    unexpected_keys = list(hdf5_layer_names - model_layer_names)
    error_msgs = []

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.warning(
            f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
        )
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
            f"If your task is similar to the task the model of the ckeckpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )
    if len(error_msgs) > 0:
        raise RuntimeError("Error(s) in loading weights for {}:\n\t{}".format(
            model.__class__.__name__, "\n\t".join(error_msgs)))
    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "error_msgs": error_msgs
        }
        return model, loading_info

    return model
Example #8
def load_weights_from_hdf5_group_by_name_mapping(f,
                                                 layers,
                                                 name_mapping,
                                                 skip_mismatch=False):
    """Implements name-based weight loading.
    (instead of topological weight loading).
    Layers that have no matching name are skipped.
    Args:
        f: A pointer to a HDF5 group.
        layers: a list of target layers.
        name_mapping : name mapping dict
        skip_mismatch: Boolean, whether to skip loading of layers
            where there is a mismatch in the number of weights,
            or a mismatch in the shape of the weights.
    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file and skip_match=False.
    """
    if 'keras_version' in f.attrs:
        original_keras_version = f.attrs['keras_version']
        if hasattr(original_keras_version, 'decode'):
            original_keras_version = original_keras_version.decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in f.attrs:
        original_backend = f.attrs['backend']
        if hasattr(original_backend, 'decode'):
            original_backend = original_backend.decode('utf8')
    else:
        original_backend = None

    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, 'layer_names')

    # Reverse index of layer name to list of layers with name.
    index = {}
    for layer in layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        weight_values = [
            np.asarray(g[weight_name]) for weight_name in weight_names
        ]

        # Translate the saved layer name through the mapping (falling back to
        # the saved name) so renamed layers can still be matched.
        for layer in index.get(name_mapping.get(name, name), []):
            symbolic_weights = _legacy_weights(layer)
            weight_values = preprocess_weights_for_loading(
                layer, weight_values, original_keras_version, original_backend)
            if len(weight_values) != len(symbolic_weights):
                if skip_mismatch:
                    logging.warning(
                        'Skipping loading of weights for '
                        'layer {}'.format(layer.name) + ' due to mismatch '
                        'in number of weights ({} vs {}).'.format(
                            len(symbolic_weights), len(weight_values)))
                    continue
                raise ValueError('Layer #' + str(k) + ' (named "' +
                                 layer.name + '") expects ' +
                                 str(len(symbolic_weights)) +
                                 ' weight(s), but the saved weights' +
                                 ' have ' + str(len(weight_values)) +
                                 ' element(s).')
            # Set values.
            for i in range(len(weight_values)):
                if backend.int_shape(
                        symbolic_weights[i]) != weight_values[i].shape:
                    if skip_mismatch:
                        logging.warning('Skipping loading of weights for '
                                        'layer {}'.format(layer.name) +
                                        ' due to '
                                        'mismatch in shape ({} vs {}).'.format(
                                            symbolic_weights[i].shape,
                                            weight_values[i].shape))
                        continue
                    raise ValueError(
                        'Layer #' + str(k) + ' (named "' + layer.name +
                        '"), weight ' + str(symbolic_weights[i]) +
                        ' has shape {}'.format(
                            backend.int_shape(symbolic_weights[i])) +
                        ', but the saved weight has shape ' +
                        str(weight_values[i].shape) + '.')

                else:
                    weight_value_tuples.append(
                        (symbolic_weights[i], weight_values[i]))
    backend.batch_set_value(weight_value_tuples)
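
A hedged usage sketch (the file path, model object and mapping entries are assumptions):

    import h5py

    with h5py.File('pretrained_weights.h5', 'r') as f:
        # Restore weights whose saved layer names differ from the current
        # model's layer names, skipping shapes that no longer match.
        load_weights_from_hdf5_group_by_name_mapping(
            f, model.layers,
            name_mapping={'dense_old': 'dense_new'},
            skip_mismatch=True)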