Example #1
    def __init__(self, path: str = 'small', device=None, **kwargs):
        if device is not None:
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        if path in model_map or is_remote_url(path) or os.path.isfile(path):
            proxies = kwargs.pop("proxies", None)
            cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
            force_download = kwargs.pop("force_download", False)
            resume_download = kwargs.pop("resume_download", False)
            local_files_only = kwargs.pop("local_files_only", False)
            path = cached_path(model_map.get(path, path),
                               cache_dir=cache_dir,
                               force_download=force_download,
                               proxies=proxies,
                               resume_download=resume_download,
                               local_files_only=local_files_only,
                               extract_compressed_file=True)
        elif not os.path.isdir(path):
            raise FileNotFoundError()
        ckpt = torch.load(os.path.join(path, "ltp.model"),
                          map_location=self.device)
        ckpt['model_config']['init'].pop('pretrained')
        self.cache_dir = path
        self.model = Model.from_params(ckpt['model_config'],
                                       config=ckpt['pretrained_config']).to(
                                           self.device)
        self.model.load_state_dict(ckpt['model'])
        self.model.eval()
        # todo fp16
        self.max_length = self.model.pretrained.config.max_position_embeddings
        self.seg_vocab = [WORD_START, WORD_MIDDLE]
        self.pos_vocab = ckpt['pos']
        self.ner_vocab = ckpt['ner']
        self.dep_vocab = ckpt['dep']
        self.sdp_vocab = ckpt['sdp']
        self.srl_vocab = [
            # normalize 'ARGn' to 'An' and drop the 'ARGM-' prefix; note that
            # str.lstrip('ARGM-') would strip a character *set*, not the prefix
            re.sub(r'ARG(\d)', r'A\1', re.sub(r'^ARGM-', '', tag))
            for tag in ckpt['srl']
        ]
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, config=self.model.pretrained.config, use_fast=True)
        self.trie = Trie()

        if kwargs.pop("need_config", False):
            config = ckpt['model_config']
            config['init']['seg']['vocab'] = self.seg_vocab
            config['init']['pos']['vocab'] = self.pos_vocab
            config['init']['ner']['vocab'] = self.ner_vocab
            config['init']['dep']['vocab'] = self.dep_vocab
            config['init']['sdp']['vocab'] = self.sdp_vocab
            config['init']['srl']['vocab'] = self.srl_vocab
            config['pretrained_config'] = ckpt['pretrained_config']
            self.config = config
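
For orientation, a minimal usage sketch of this constructor, assuming it belongs to the `LTP` pipeline class exported by the ltp package (the snippet above shows only `__init__`):

# Hypothetical usage; assumes the constructor above is ltp.LTP.__init__.
from ltp import LTP

nlp = LTP(path='small', device='cuda:0')  # explicit device string
# With device=None the constructor falls back to CUDA when available, else CPU.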
Example #2
File: ltp.py Project: xwsss1/ltp
    def __init__(self,
                 path: str = 'small',
                 batch_size: int = 10,
                 device=None,
                 vocab: str = None,
                 **kwargs):
        if device is not None:
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        if os.path.isdir(path):
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        elif path in model_map or is_remote_url(path) or os.path.isfile(path):
            cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
            force_download = kwargs.pop("force_download", False)
            resume_download = kwargs.pop("resume_download", False)
            proxies = kwargs.pop("proxies", None)
            local_files_only = kwargs.pop("local_files_only", False)
            resolved_archive_path = cached_path(
                model_map.get(path, path),
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                extract_compressed_file=True)
            resolved_ckpt_file = os.path.join(resolved_archive_path,
                                              "ltp.model")
            ckpt = torch.load(resolved_ckpt_file, map_location=self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(
                resolved_archive_path, use_fast=True)
        else:
            raise FileNotFoundError()

        ckpt['model_config']['init'].pop('pretrained')
        self.model = Model.from_params(ckpt['model_config'],
                                       config=ckpt['pretrained_config'])
        self.model.to(self.device)
        self.model.load_state_dict(ckpt['model'])
        self.model.eval()
        self.seg_vocab = [WORD_START, WORD_MIDDLE]
        self.pos_vocab = ckpt['pos']
        self.ner_vocab = ckpt['ner']
        self.dep_vocab = ckpt['dep']
        self.sdp_vocab = ckpt['sdp']
        self.srl_vocab = ckpt['srl']
        self.dep_fix = len(self.dep_vocab)
        self.split = lambda a: map(lambda b: a[b:b + batch_size],
                                   range(0, len(a), batch_size))
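
The `self.split` lambda above chunks a sequence into consecutive batches of `batch_size`. A standalone sketch of the same idea:

# Plain-function equivalent of the self.split lambda.
def split(a, batch_size=10):
    return [a[b:b + batch_size] for b in range(0, len(a), batch_size)]

assert split(list(range(7)), batch_size=3) == [[0, 1, 2], [3, 4, 5], [6]]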
Example #3
def download_file_from_hf(pretrained_model_name_or_path: str,
                          file_name: str) -> str:
    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, file_name)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path,
                                            file_name)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {}".format(
                        file_name,
                        pretrained_model_name_or_path,
                    ))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=file_name,
                revision=None,
                mirror=None,
            )

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=None,
                force_download=False,
                proxies=None,
                resume_download=False,
                local_files_only=False,
            )
        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on "
                f"'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a "
                f"file named {file_name}.\n\n")
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    return resolved_archive_file
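
A hedged usage sketch for the helper above; the model identifier and file name are illustrative values only:

import torch

# Hypothetical call; 'bert-base-uncased' and 'pytorch_model.bin' are example values.
weights_path = download_file_from_hf('bert-base-uncased', 'pytorch_model.bin')
state_dict = torch.load(weights_path, map_location='cpu')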
Example #4
    def _get_config_dict(cls, path, **kw):
        local_files_only = kw.pop("local_files_only", False)
        from_pipeline = kw.pop("_from_pipeline", None)
        user_agent = {
            "file_type": "config",
            "from_auto_class": kw.pop("_from_auto", False)
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline
        if is_offline_mode() and not local_files_only:
            log.info("Offline mode: forcing local_files_only=True")
            local_files_only = True
        path = str(path)
        if os.path.isfile(path) or is_remote_url(path):
            x = path
        else:
            f = kw.pop("_configuration_file", CONFIG_NAME)
            if os.path.isdir(path):
                x = os.path.join(path, f)
            else:
                x = hf_bucket_url(path,
                                  filename=f,
                                  revision=kw.pop("revision", None),
                                  mirror=None)
        try:
            x2 = cached_path(
                x,
                cache_dir=kw.pop("cache_dir", None),
                force_download=kw.pop("force_download", False),
                proxies=kw.pop("proxies", None),
                resume_download=kw.pop("resume_download", False),
                local_files_only=local_files_only,
                use_auth_token=kw.pop("use_auth_token", None),
                user_agent=user_agent,
            )
        except RepositoryNotFoundError as e:
            raise OSError() from e
        except RevisionNotFoundError as e:
            raise OSError() from e
        except EntryNotFoundError as e:
            raise OSError() from e
        except HTTPError as e:
            raise OSError() from e
        except OSError as e:
            raise e
        try:
            y = cls._dict_from_json_file(x2)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            raise OSError() from e
        if x2 == x:
            log.info(f"loading {x}")
        else:
            log.info(f"loading {x} from cache at {x2}")
        return y, kw
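
One behavioral detail worth noting: when `is_offline_mode()` returns True, the helper silently forces `local_files_only=True`. A sketch of how that mode is commonly toggled (the environment variable is the one the transformers library checks; treat the exact spelling as an assumption for your version):

import os

# With this set, is_offline_mode() returns True and only cached files resolve.
os.environ['TRANSFORMERS_OFFLINE'] = '1'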
Example #5
    def __init__(self, path: str = 'small', device=None, **kwargs):
        if device is not None:
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        if path in model_map or is_remote_url(path) or os.path.isfile(path):
            proxies = kwargs.pop("proxies", None)
            cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
            force_download = kwargs.pop("force_download", False)
            resume_download = kwargs.pop("resume_download", False)
            local_files_only = kwargs.pop("local_files_only", False)
            path = cached_path(model_map.get(path, path),
                               cache_dir=cache_dir,
                               force_download=force_download,
                               proxies=proxies,
                               resume_download=resume_download,
                               local_files_only=local_files_only,
                               extract_compressed_file=True)
        elif not os.path.isdir(path):
            raise FileNotFoundError()
        try:
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)
        except Exception:
            # some checkpoints were pickled with pytorch_lightning; stub the
            # module out so torch.load can unpickle them, then retry
            fake_import_pytorch_lightning()
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)

        self.cache_dir = path
        config = AutoConfig.for_model(**ckpt['transformer_config'])
        self.model = Model(ckpt['model_config'], config=config).to(self.device)
        self.model.load_state_dict(ckpt['model'], strict=False)
        self.model.eval()
        self.max_length = self.model.transformer.config.max_position_embeddings
        self.seg_vocab = ckpt.get('seg', [WORD_MIDDLE, WORD_START])
        self.pos_vocab = ckpt.get('pos', [])
        self.ner_vocab = ckpt.get('ner', [])
        self.dep_vocab = ckpt.get('dep', [])
        self.sdp_vocab = ckpt.get('sdp', [])
        self.srl_vocab = [
            # normalize 'ARGn' to 'An' and drop the 'ARGM-' prefix; note that
            # str.lstrip('ARGM-') would strip a character *set*, not the prefix
            re.sub(r'ARG(\d)', r'A\1', re.sub(r'^ARGM-', '', tag))
            for tag in ckpt.get('srl', [])
        ]
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, config=self.model.transformer.config, use_fast=True)
        self.trie = Trie()
Example #6
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        """Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with ``model.train()``.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
                - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it is loaded) and initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop('config', None)
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        output_loading_info = kwargs.pop('output_loading_info', False)
        random_init = kwargs.pop("random_init", False)
        use_cdn = kwargs.pop("use_cdn", True)
        local_files_only = kwargs.pop("local_files_only", False)
        resume_download = kwargs.pop("resume_download", False)
        kwargs_config = kwargs.copy()

        mapping_keys_state_dic = kwargs.pop("mapping_keys_state_dic", None)
        kwargs_config.pop("mapping_keys_state_dic", None)

        if config is None:

            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                **kwargs_config)
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False"
                        .format(
                            [
                                WEIGHTS_NAME, TF2_WEIGHTS_NAME,
                                TF_WEIGHTS_NAME + ".index"
                            ],
                            pretrained_model_name_or_path,
                        ))
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
                    from_tf
                ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                    pretrained_model_name_or_path + ".index")
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                    use_cdn=use_cdn,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.

        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith('.index'):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(
                    model, config,
                    resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model
                    model = load_tf2_checkpoint_in_pytorch_model(
                        model, resolved_archive_file, allow_missing_keys=True)
                except ImportError as e:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise e
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []

            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata
            # assert mapping_keys_state_dic is not None, "ERROR did not found mapping dicts for {} ".format(pretrained_model_name_or_path)
            # mapping_keys_state_dic = {"roberta": "encoder", "lm_head": "head.mlm"}
            if mapping_keys_state_dic is not None:
                assert isinstance(mapping_keys_state_dic, dict), \
                    "ERROR mapping_keys_state_dic must be a dict"
                print(
                    "INFO : from_pretrained assuming an original Google checkpoint: "
                    "need to rename some keys {}".format(mapping_keys_state_dic))
                state_dict = cls.adapt_state_dic_to_multitask(
                    state_dict,
                    keys_mapping=mapping_keys_state_dic,
                    add_prefix=pretrained_model_name_or_path ==
                    "asafaya/bert-base-arabic")
                #pdb.set_trace()

            def load(module, prefix=''):

                local_metadata = {"version": 1}

                if not prefix.startswith("head") or prefix.startswith(
                        "head.mlm"):
                    assert len(
                        missing_keys
                    ) == 0, "ERROR {} missing keys in state_dict {}".format(
                        prefix, missing_keys)
                else:
                    if len(missing_keys) > 0:
                        print(
                            "Warning {} missing keys in state_dict {} (warning expected for task-specific fine-tuning)"
                            .format(prefix, missing_keys))

                module._load_from_state_dict(state_dict, prefix,
                                             local_metadata, True,
                                             missing_keys, unexpected_keys,
                                             error_msgs)
                for name, child in module._modules.items():

                    # load_params_only_ls = kwargs.get("load_params_only_ls ")
                    not_load_params_ls = kwargs.get(
                        "not_load_params_ls") if kwargs.get(
                            "not_load_params_ls") is not None else []
                    assert isinstance(
                        not_load_params_ls, list
                    ), f"Argument error not_load_params_ls should be a list but is {not_load_params_ls}"
                    matching_not_load = []
                    # RANDOM-INIT
                    for pattern in not_load_params_ls:
                        matching = re.match(pattern, prefix + name)
                        if matching is not None:
                            matching_not_load.append(matching)
                    if len(matching_not_load) > 0:
                        # at least one pattern in not_load_params_ls matched --> this child should not be loaded
                        print("MATCH not loading : {} parameters {} ".format(
                            prefix + name, not_load_params_ls))
                    if child is not None and len(matching_not_load) == 0:
                        #print("MODEL loading : child {} full {} ".format(name, prefix + name + '.'))
                        load(child, prefix + name + '.')
                    else:
                        print(
                            "MODEL not loading : child {} matching_not_load {} "
                            .format(child, matching_not_load))

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ''
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                start_prefix = cls.base_model_prefix + '.'
            if hasattr(model, cls.base_model_prefix) and not any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                model_to_load = getattr(model, cls.base_model_prefix)
            if not random_init:
                load(model_to_load, prefix=start_prefix)
            else:
                print("WARNING : RANDOM INITIALIZATION OF BERTMULTITASK")

            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".
                    format(model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError(
                    'Error(s) in loading state_dict for {}:\n\t{}'.format(
                        model.__class__.__name__, "\n\t".join(error_msgs)))

        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        # Set model in evaluation mode to deactivate Dropout modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs
            }
            return model, loading_info

        return model
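
A hedged usage sketch of the extra keyword arguments this override understands. The class name `BertMultiTask` is inferred from the warning string in the code, and the key mapping mirrors the commented-out example above; all values are illustrative:

# Hypothetical call; BertMultiTask and the argument values are illustrative.
model = BertMultiTask.from_pretrained(
    'bert-base-uncased',
    mapping_keys_state_dic={'roberta': 'encoder', 'lm_head': 'head.mlm'},
    not_load_params_ls=[r'head\..*'],  # leave head parameters randomly initialized
)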
Example #7
def get_pretrained_state_dict(pretrained_model_name_or_path, *model_args,
                              **kwargs):
    """Get PyTorch state dict via HuggingFace transformers library."""
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    # from_tf = kwargs.pop("from_tf", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)
    mirror = kwargs.pop("mirror", None)

    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path,
                                            WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {}".format(
                        WEIGHTS_NAME,
                        pretrained_model_name_or_path,
                    ))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            assert False, "Loading TensorFlow checkpoints is not supported"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME,
                use_cdn=use_cdn,
                mirror=mirror,
            )

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            if resolved_archive_file is None:
                raise EnvironmentError
        except EnvironmentError:
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            print("loading weights file {}".format(archive_file))
        else:
            print("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    if state_dict is None:
        try:
            state_dict = torch.load(resolved_archive_file, map_location="cpu")
        except Exception:
            raise OSError(
                "Unable to load weights from pytorch checkpoint file.")
    return state_dict
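
A minimal usage sketch; the identifier is an example value, and WEIGHTS_NAME resolution happens inside the helper:

# Hypothetical call; 'bert-base-uncased' is an example identifier.
state_dict = get_pretrained_state_dict('bert-base-uncased')
print(list(state_dict)[:3])  # peek at the first few parameter names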
Example #8
    def get_config_dict(cls,
                        pretrained_model_name_or_path: str,
                        pretrained_config_archive_map: Optional[Dict] = None,
                        **kwargs) -> Tuple[Dict, Dict]:
        """
        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
        for instantiating a Config using `from_dict`.

        Parameters:
            pretrained_model_name_or_path (:obj:`string`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
            pretrained_config_archive_map: (:obj:`Dict[str, str]`, `optional`) Dict:
                A map of `shortcut names` to `url`. By default, will use the current class attribute.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object, together with the remaining keyword arguments.

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)

        if pretrained_config_archive_map is None:
            pretrained_config_archive_map = cls.pretrained_config_archive_map

        if pretrained_model_name_or_path in pretrained_config_archive_map:
            config_file = pretrained_config_archive_map[
                pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path,
                                       CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(pretrained_model_name_or_path,
                                        postfix=CONFIG_NAME)

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            # Load config dict
            if resolved_config_file is None:
                raise EnvironmentError
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError:
            if pretrained_model_name_or_path in pretrained_config_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                    config_file)
            else:
                msg = (
                    "Can't load '{}'. Make sure that:\n\n"
                    "- '{}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    "- or '{}' is the correct path to a directory containing a '{}' file\n\n"
                    .format(
                        pretrained_model_name_or_path,
                        pretrained_model_name_or_path,
                        pretrained_model_name_or_path,
                        CONFIG_NAME,
                    ))
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(
                    config_file, resolved_config_file))
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info(
                "loading configuration file {} from cache at {}".format(
                    config_file, resolved_config_file))

        return config_dict, kwargs
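
In transformers releases of this era the method is reachable from any config subclass; a sketch under that assumption:

# Assumes a transformers release where BertConfig exposes get_config_dict
# with this signature.
from transformers import BertConfig

config_dict, unused_kwargs = BertConfig.get_config_dict('bert-base-uncased')
print(config_dict['hidden_size'])  # 768 for bert-base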
Example #9
def from_pretrained_detailed(model_class, pretrained_model_name_or_path,
                             *model_args, **kwargs):
    r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
    It is up to you to train those weights with a downstream fine-tuning task.

    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:
            - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
            - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
            - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method

        config: (`optional`) one of:
                - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`

            Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

        from_pt: (`optional`) boolean, default False:
            Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model
            configuration should be cached if the standard cache should not be used.

        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

        resume_download: (`optional`) boolean, default False:
            Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
            The proxies are used on each request.

        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it is loaded) and initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
            
            - If layer pruning is supported, ``layer_pruning`` will be passed as a dictionary containing the layer-pruning configuration, as follows:
                - strategy:
                    one of {`top`, `buttom`, `symmetric`, `alternate`, `custom`} (the code checks the literal spelling `buttom`)
                - k:
                    the number of layers to prune; mandatory if strategy is one of {`top`, `buttom`, `symmetric`, `alternate`}
                - layers_indexes:
                    an array of layer indexes to prune; mandatory if strategy is `custom`
                - is_odd:
                    whether to prune odd alternate layers; mandatory if strategy is `alternate`

    Examples::

        # For example purposes. Not runnable.
        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
        config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
        model = BertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)

    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)

    # mwahdan: Read layer_pruning config if it exists
    layer_pruning = kwargs.pop("layer_pruning", None)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = model_class.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(
                    os.path.join(pretrained_model_name_or_path,
                                 TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path,
                                            TF2_WEIGHTS_NAME)
            elif from_pt and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path,
                                            WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_pt` set to False"
                    .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME],
                            pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
            )

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            if resolved_archive_file is None:
                raise EnvironmentError
        except EnvironmentError:
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)
        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # mwahdan: Modify config
    if layer_pruning:
        layer_pruning_k = layer_pruning_layers_indexes = layer_pruning_is_odd = None
        layer_pruning_strategy = get_mandatory_parameter(
            'strategy', layer_pruning)
        if layer_pruning_strategy in {'top', 'buttom', 'symmetric'}:
            layer_pruning_k = get_mandatory_parameter('k', layer_pruning)
            config, original_num_layers = modify_num_of_layers(
                config, k=layer_pruning_k)
        elif layer_pruning_strategy == 'custom':
            layer_pruning_layers_indexes = get_mandatory_parameter(
                'layers_indexes', layer_pruning)
            config, original_num_layers = modify_num_of_layers(
                config, layers_indexes=layer_pruning_layers_indexes)
        elif layer_pruning_strategy == 'alternate':
            layer_pruning_k = get_mandatory_parameter('k', layer_pruning)
            layer_pruning_is_odd = get_mandatory_parameter(
                'is_odd', layer_pruning)
            config, original_num_layers = modify_num_of_layers(
                config, k=layer_pruning_k, is_alternate=True)
        else:
            raise Exception('`%s` is not a supported layer pruning strategy' %
                            layer_pruning_strategy)

    # Instantiate model.
    model = model_class(config, *model_args, **model_kwargs)

    # mwahdan: Rename layers
    if layer_pruning:
        model = rename_layers_in_strategy(model, layer_pruning_strategy,
                                          original_num_layers, layer_pruning_k,
                                          layer_pruning_layers_indexes,
                                          layer_pruning_is_odd)

    if from_pt:
        # Load from a PyTorch checkpoint
        model = load_pytorch_checkpoint_in_tf2_model(model,
                                                     resolved_archive_file,
                                                     allow_missing_keys=True)
        # mwahdan: Rename layers
        if layer_pruning is not None:
            model = rename_layers(model)
        return model

    model(model.dummy_inputs,
          training=False)  # build the network with dummy inputs

    assert os.path.isfile(
        resolved_archive_file), "Error retrieving file {}".format(
            resolved_archive_file)
    # 'by_name' allow us to do transfer learning by skipping/adding layers
    # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
    try:
        # added skip_mismatch=True because we will prune full layers
        model.load_weights(resolved_archive_file,
                           by_name=True,
                           skip_mismatch=True)
        # mwahdan: Rename layers
    except OSError:
        raise OSError(
            "Unable to load weights from h5 file. "
            "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
        )

    model(model.dummy_inputs, training=False)  # Make sure restore ops are run

    # mwahdan: Rename layers
    if layer_pruning is not None:
        model = rename_layers(model)

    # Check if the models are the same to output loading information
    with h5py.File(resolved_archive_file, "r") as f:
        if "layer_names" not in f.attrs and "model_weights" in f:
            f = f["model_weights"]
        hdf5_layer_names = set(
            hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
    model_layer_names = set(layer.name for layer in model.layers)
    missing_keys = list(model_layer_names - hdf5_layer_names)
    unexpected_keys = list(hdf5_layer_names - model_layer_names)
    error_msgs = []

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.warning(
            f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
        )
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
            f"If your task is similar to the task the model of the ckeckpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )
    if len(error_msgs) > 0:
        raise RuntimeError("Error(s) in loading weights for {}:\n\t{}".format(
            model.__class__.__name__, "\n\t".join(error_msgs)))
    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "error_msgs": error_msgs
        }
        return model, loading_info

    return model
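
A hedged sketch of the layer-pruning hook described in the docstring; the model class and identifier are example values (and note the code matches the literal spelling 'buttom'):

# Hypothetical call; TFBertModel and 'bert-base-uncased' are example values.
from transformers import TFBertModel

model = from_pretrained_detailed(
    TFBertModel,
    'bert-base-uncased',
    layer_pruning={'strategy': 'top', 'k': 4},  # prune 4 layers from the top
)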
Example #10
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    local_files_only=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path), determine which.
    If it's a URL, download the file and cache it, and return the path to the cached
    file. If it's already a local path, make sure the file exists and then return the
    path.

    Args:
        cache_dir: specify a cache directory to save the file to (overwrite the default
            cache dir).
        force_download: if True, re-download the file even if it's already cached in
            the cache dir.
        resume_download: if True, resume the download if an incompletely received
            file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on
            remote requests.
        extract_compressed_file: if True and the path points to a zip or tar file,
            extract the compressed file in a folder along the archive.
        force_extract: if True when extract_compressed_file is True and the archive was
            already extracted, re-extract the archive and override the folder where it
            was extracted.

    Return:
        Local path (string) of file or if networking is off, last version of file
        cached on disk.

    Raises:
        In case of a non-recoverable file (a non-existent or inaccessible URL with
        no cache on disk).
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError(
            "unable to parse {} as a URL or as a local path".format(url_or_filename)
        )

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end:
        # "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if (
            os.path.isdir(output_path_extracted)
            and os.listdir(output_path_extracted)
            and not force_extract
        ):
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
            elif tarfile.is_tarfile(output_path):
                with tarfile.open(output_path) as tar_file:
                    tar_file.extractall(output_path_extracted)
            else:
                raise EnvironmentError(
                    "Archive format of {} could not be identified".format(output_path)
                )

        return output_path_extracted

    return output_path
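
Two illustrative calls against the function above; the URL and the local path are example values (the local file must already exist):

# Remote archive: downloaded, cached, and extracted next to the cache entry.
extracted_dir = cached_path('https://example.com/model.tar.gz',
                            extract_compressed_file=True)

# Existing local file: the path is returned unchanged.
local = cached_path('/tmp/some_local_file.bin')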