Code Example #1
File: component.py  Project: ljj7975/relogic
  def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
    if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
      archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
      config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
    else:
      archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
      config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

    # redirect to the cache, if necessary
    try:
      resolved_archive_file = cached_path(archive_file, cache_dir=RELOGIC_CACHE)
    except EnvironmentError:
      if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
        logger.error(
          "Couldn't reach server at '{}' to download pretrained weights.".format(
            archive_file))
      else:
        logger.error(
          "Model name '{}' was not found in model name list ({}). "
          "We assumed '{}' was a path or url but couldn't find any file "
          "associated to this path or url.".format(
            pretrained_model_name_or_path,
            ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
            archive_file))
      return None

    try:
      resolved_config_file = cached_path(config_file, cache_dir=RELOGIC_CACHE)
    except EnvironmentError:
      if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
        logger.error(
          "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
            config_file))
      else:
        logger.error(
          "Model name '{}' was not found in model name list ({}). "
          "We assumed '{}' was a path or url but couldn't find any file "
          "associated to this path or url.".format(
            pretrained_model_name_or_path,
            ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
            config_file))
      return None

    if resolved_archive_file == archive_file and resolved_config_file == config_file:
      logger.info("loading weights file {}".format(archive_file))
      logger.info("loading configuration file {}".format(config_file))
    else:
      logger.info("loading weights file {} from cache at {}".format(
        archive_file, resolved_archive_file))
      logger.info("loading configuration file {} from cache at {}".format(
        config_file, resolved_config_file))


    with open(resolved_config_file) as f:
      restore_config = SimpleNamespace(**json.load(f))

    return cls(config=restore_config, predictor=Predictor(restore_config))
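
For reference, here is a minimal, self-contained sketch of the `SimpleNamespace(**json.load(f))` pattern used above to turn a JSON configuration file into an attribute-style config object. The JSON content below is invented for illustration; the real config is read from the resolved config file.

import json
from types import SimpleNamespace

# Illustrative config; the real one comes from resolved_config_file.
config_json = '{"hidden_size": 768, "num_hidden_layers": 12}'
restore_config = SimpleNamespace(**json.loads(config_json))
print(restore_config.hidden_size)  # 768
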
Code Example #2
File: encoder.py  Project: vrmpx/relogic
 def from_pretrained(cls,
                     pretrained_model_name_or_path,
                     cache_dir=None,
                     output_attentions=False):
     if pretrained_model_name_or_path in PRETRAINED_VECTOR_ARCHIVE_MAP:
         embedding_file = PRETRAINED_VECTOR_ARCHIVE_MAP[
             pretrained_model_name_or_path]
     else:
         embedding_file = pretrained_model_name_or_path
     try:
         resolved_embedding_file = cached_path(embedding_file,
                                               cache_dir=cache_dir)
     except EnvironmentError:
         logger.error(
             "Model name '{}' was not found in model name list ({}). "
             "We assumed '{}' was a path or url but couldn't find any file "
             "associated to this path or url.".format(
                 pretrained_model_name_or_path,
                 ', '.join(PRETRAINED_VECTOR_ARCHIVE_MAP.keys()),
                 embedding_file))
         return None
     if resolved_embedding_file == embedding_file:
         logger.info(
             "will load embedding file from {}".format(embedding_file))
     else:
         logger.info("will load embedding file {} from cache at {}".format(
             embedding_file, resolved_embedding_file))
     return cls(pretrained_model_name_or_path,
                embedding_file_path=resolved_embedding_file)
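
The name-or-path resolution at the top of this method is a recurring pattern in these examples. A minimal, self-contained sketch of it follows; the archive map entry below is invented for illustration.

PRETRAINED_VECTOR_ARCHIVE_MAP = {
    "example-vectors": "https://example.com/vectors.txt"  # invented entry
}

def resolve(name_or_path):
    # Known shortcut names map to a remote URL; anything else is treated as a local path or URL.
    if name_or_path in PRETRAINED_VECTOR_ARCHIVE_MAP:
        return PRETRAINED_VECTOR_ARCHIVE_MAP[name_or_path]
    return name_or_path

print(resolve("example-vectors"))   # https://example.com/vectors.txt
print(resolve("./my_vectors.txt"))  # ./my_vectors.txt
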
Code Example #3
 def from_pretrained(cls,
                     pretrained_model_name_or_path,
                     cache_dir=None,
                     *inputs,
                     **kwargs):
     if pretrained_model_name_or_path in PRETRAINED_VECTOR_ARCHIVE_MAP:
         vocab_file = PRETRAINED_VECTOR_ARCHIVE_MAP[
             pretrained_model_name_or_path]
     else:
         vocab_file = pretrained_model_name_or_path
     try:
         resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
     except EnvironmentError:
         logger.error(
             "Model name '{}' was not found in model name list ({}). "
             "We assumed '{}' was a path or url but couldn't find any file "
             "associated to this path or url.".format(
                 pretrained_model_name_or_path,
                 ', '.join(PRETRAINED_VECTOR_ARCHIVE_MAP.keys()),
                 vocab_file))
         return None
     if resolved_vocab_file == vocab_file:
         logger.info("loading vocabulary file {}".format(vocab_file))
     else:
         logger.info("loading vocabulary file {} from cache at {}".format(
             vocab_file, resolved_vocab_file))
     tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
     return tokenizer
Code Example #4
    def __init__(self,
                 index_names: List,
                 index_paths: Dict = None,
                 index_language: Dict = None):

        self.index_names = index_names
        self.retrievers = {}
        if index_paths is not None and index_language is not None:
            INDEX_PATHS.update(index_paths)
            INDEX_LANGUAGE.update(index_language)
        for index_name in index_names:
            index_zip_or_dir_path = cached_path(INDEX_PATHS[index_name],
                                                cache_dir=RELOGIC_CACHE)
            if os.path.isdir(index_zip_or_dir_path):
                index_path = index_zip_or_dir_path
            else:
                index_path = index_zip_or_dir_path + "." + index_name
            if not os.path.exists(index_path):
                with ZipFile(index_zip_or_dir_path, 'r') as zipObj:
                    zipObj.extractall(index_path)
                    logger.info("Extract Index {} to {}".format(
                        INDEX_PATHS[index_name], index_path))
            self.retrievers[index_name] = JSearcher(JString(index_path))
            self.retrievers[index_name].setLanguage(INDEX_LANGUAGE[index_name])
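
A minimal, self-contained illustration of the archive-or-directory handling above, using a throwaway zip file. All paths here are placeholders; in the real code the path comes from cached_path.

import os
import tempfile
import zipfile

tmp_dir = tempfile.mkdtemp()
zip_path = os.path.join(tmp_dir, "index.zip")
with zipfile.ZipFile(zip_path, "w") as z:
    z.writestr("doc.txt", "hello")  # stand-in for real index files

index_path = zip_path + ".my_index"
if not os.path.isdir(zip_path):         # cached_path returned an archive, not a directory
    if not os.path.exists(index_path):  # extract once; later runs reuse the extracted directory
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall(index_path)
print(os.listdir(index_path))  # ['doc.txt']
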
Code Example #5
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading.

                - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
                - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            return_unused_kwargs: (`optional`) bool:

                - If False, then this function returns just the final configuration object.
                - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e. the part of kwargs which has not been used to update `config` and is otherwise ignored.

        Examples::

            # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
            # derived class: BertConfig
            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
            assert config.output_attention == True
            config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attention == True
            assert unused_kwargs == {'foo': False}

        """
        cache_dir = kwargs.pop('cache_dir', None)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

        if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
            config_file = cls.pretrained_config_archive_map[
                pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path,
                                       CONFIG_NAME)
        else:
            config_file = pretrained_model_name_or_path
        # redirect to the cache, if necessary
        try:
            resolved_config_file = cached_path(config_file,
                                               cache_dir=cache_dir,
                                               force_download=force_download,
                                               proxies=proxies)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                    config_file)
            else:
                msg = "Model name '{}' was not found in model name list ({}). " \
                      "We assumed '{}' was a path or url to a configuration file named {} or " \
                      "a directory containing such a file but couldn't find any such file at this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_config_archive_map.keys()),
                        config_file, CONFIG_NAME)
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info(
                "loading configuration file {} from cache at {}".format(
                    config_file, resolved_config_file))

        # Load config
        config = cls.from_json_file(resolved_config_file)

        if hasattr(config, 'pruned_heads'):
            config.pruned_heads = dict(
                (int(key), value)
                for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config
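
A self-contained sketch of the kwargs-override step near the end of this method: keys that match configuration attributes update the config in place, and the leftover keys are what is returned when `return_unused_kwargs=True`. The config object and kwargs below are invented for illustration.

from types import SimpleNamespace

config = SimpleNamespace(output_attention=False, hidden_size=768)
kwargs = {"output_attention": True, "foo": False}

to_remove = []
for key, value in kwargs.items():
    if hasattr(config, key):
        setattr(config, key, value)  # override a real configuration attribute
        to_remove.append(key)
for key in to_remove:
    kwargs.pop(key, None)            # what remains are the "unused" kwargs

print(config.output_attention, kwargs)  # True {'foo': False}
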
Code Example #6
File: __init__.py  Project: vrmpx/relogic
import os
import sys
from relogic.utils.file_utils import cached_path, RELOGIC_CACHE



PACKAGE_PATH = {
  "Anserini": "https://git.uwaterloo.ca/p8shi/data-server/raw/master/anserini-0.6.0-SNAPSHOT-fatjar.jar"
}

anserini_cache_path = cached_path(PACKAGE_PATH['Anserini'], cache_dir=RELOGIC_CACHE)


if sys.platform == 'win32':
  separator = ';'
else:
  separator = ':'

jar = separator + anserini_cache_path  # prepend the separator so the jar can be appended to an existing CLASSPATH

if 'CLASSPATH' not in os.environ:
  os.environ['CLASSPATH'] = jar
else:
  os.environ['CLASSPATH'] += jar
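
A minimal, self-contained illustration of the CLASSPATH handling above; the jar path below is a placeholder, whereas the real one is returned by cached_path.

import os
import sys

separator = ';' if sys.platform == 'win32' else ':'
jar = separator + "/tmp/anserini-0.6.0-SNAPSHOT-fatjar.jar"  # placeholder path

# Prepending the separator lets the same string be appended to an existing CLASSPATH.
if 'CLASSPATH' not in os.environ:
    os.environ['CLASSPATH'] = jar
else:
    os.environ['CLASSPATH'] += jar
print(os.environ['CLASSPATH'])
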
Code Example #7
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
    The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
    To train the model, you should first set it back in training mode with ``model.train()``.
    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
    It is up to you to train those weights with a downstream fine-tuning task.
    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
    Parameters:
        pretrained_model_name_or_path: either:
            - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
            - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method
        config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
            - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
            - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
            - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
        state_dict: (`optional`) dict:
            an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
            This option can be used if you want to create a model from a pretrained configuration but load your own weights.
            In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model
            configuration should be cached if the standard cache should not be used.
        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
            The proxies are used on each request.
        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:
            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
    Examples::
        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
        config = kwargs.pop('config', None)
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        output_loading_info = kwargs.pop('output_loading_info', False)

        # Load config
        if config is None:
            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                **kwargs)
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[
                    pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False"
                        .format([
                            WEIGHTS_NAME, TF2_WEIGHTS_NAME,
                            TF_WEIGHTS_NAME + ".index"
                        ], pretrained_model_name_or_path))
            elif os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(
                    pretrained_model_name_or_path)
                archive_file = pretrained_model_name_or_path + ".index"

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies)
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                        archive_file)
                else:
                    msg = "Model name '{}' was not found in model name list ({}). " \
                          "We assumed '{}' was a path or url to model weight files named one of {} but " \
                          "couldn't find any such file at this path or url.".format(
                      pretrained_model_name_or_path,
                      ', '.join(cls.pretrained_model_archive_map.keys()),
                      archive_file,
                      [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith('.index'):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(
                    model, config,
                    resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model
                    model = load_tf2_checkpoint_in_pytorch_model(
                        model, resolved_archive_file, allow_missing_keys=True)
                except ImportError as e:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise e
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})
                module._load_from_state_dict(state_dict, prefix,
                                             local_metadata, True,
                                             missing_keys, unexpected_keys,
                                             error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ''
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                start_prefix = cls.base_model_prefix + '.'
            if hasattr(model, cls.base_model_prefix) and not any(
                    s.startswith(cls.base_model_prefix)
                    for s in state_dict.keys()):
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)
            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".
                    format(model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError(
                    'Error(s) in loading state_dict for {}:\n\t{}'.format(
                        model.__class__.__name__, "\n\t".join(error_msgs)))

        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        # Set model in evaluation mode to deactivate Dropout modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs
            }
            return model, loading_info

        return model
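
A self-contained sketch of the old-format key renaming in the PyTorch branch above: legacy 'gamma'/'beta' parameter names are rewritten to the newer 'weight'/'bias' names. The state_dict below is a toy stand-in.

state_dict = {"encoder.layer.0.LayerNorm.gamma": 1.0,
              "encoder.layer.0.LayerNorm.beta": 0.0}

old_keys, new_keys = [], []
for key in state_dict.keys():
    new_key = None
    if 'gamma' in key:
        new_key = key.replace('gamma', 'weight')
    if 'beta' in key:
        new_key = key.replace('beta', 'bias')
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

print(state_dict)  # {'encoder.layer.0.LayerNorm.weight': 1.0, 'encoder.layer.0.LayerNorm.bias': 0.0}
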
Code Example #8
File: branching_encoder.py  Project: vrmpx/relogic
  def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
                      from_tf=False, *inputs, **kwargs):
    if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
      archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
    else:
      archive_file = pretrained_model_name_or_path
    try:
      resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
    except EnvironmentError:
      logger.error(
        "Model name '{}' was not found in model name list ({}). "
        "We assumed '{}' was a path or url but couldn't find any file "
        "associated to this path or url.".format(
          pretrained_model_name_or_path,
          ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
          archive_file))
      return None
    if resolved_archive_file == archive_file:
      logger.info("loading archive file {}".format(archive_file))
    else:
      logger.info("loading archive file {} from cache at {}".format(
        archive_file, resolved_archive_file))

    tempdir = None
    if os.path.isdir(resolved_archive_file) or from_tf:
      serialization_dir = resolved_archive_file
    else:
      # Extract archive to temp dir
      tempdir = tempfile.mkdtemp()
      logger.info("extracting archive file {} to temp dir {}".format(
        resolved_archive_file, tempdir))
      with tarfile.open(resolved_archive_file, 'r:gz') as archive:
        archive.extractall(tempdir)
      serialization_dir = tempdir
    # Load config
    config_file = os.path.join(serialization_dir, CONFIG_NAME)
    config = BertConfig.from_json_file(config_file)
    logger.info("Model config {}".format(config))
    # Instantiate model.
    model = cls(config, *inputs, **kwargs)

    if state_dict is None and not from_tf:
      logger.info("Load model from torch model")
      weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
      state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
    if tempdir:
      # Clean up temp dir
      shutil.rmtree(tempdir)
    if from_tf:
      logger.info("Load model from tensorflow model")
      # Directly load from a TensorFlow checkpoint
      weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
      return load_tf_weights_in_bert(model, weights_path)
    # Load from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
      new_key = None
      if 'gamma' in key:
        new_key = key.replace('gamma', 'weight')
      if 'beta' in key:
        new_key = key.replace('beta', 'bias')
      if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
      state_dict[new_key] = state_dict.pop(old_key)

    # Replicate the parameter for layer
    keys = list(state_dict.keys())
    for idx, width in enumerate(kwargs["encoder_structure"]):
      prefix = "bert.encoder.layer.{}.".format(idx)
      for key in keys:
        if key.startswith(prefix):
          for i in range(width):
            new_prefix = "bert.encoder.layer.{}.{}.".format(idx, i)
            new_key = key.replace(prefix, new_prefix)
            state_dict[new_key] = copy.deepcopy(state_dict[key])
          state_dict.pop(key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
      state_dict._metadata = metadata

    def load(module, prefix=''):
      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
      module._load_from_state_dict(
        state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
      for name, child in module._modules.items():
        if child is not None:
          load(child, prefix + name + '.')

    start_prefix = ''
    if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
      start_prefix = 'bert.'
    load(model, prefix=start_prefix)
    if len(missing_keys) > 0:
      logger.info("Weights of {} not initialized from pretrained model: {}".format(
        model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
      logger.info("Weights from pretrained model not used in {}: {}".format(
        model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
      raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
        model.__class__.__name__, "\n\t".join(error_msgs)))
    return model
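
A self-contained illustration of the layer-replication step above: each pretrained layer's parameters are copied into every branch position given by encoder_structure. The state_dict and structure below are toy values, not real model weights.

import copy

state_dict = {"bert.encoder.layer.0.attention.weight": [1.0],
              "bert.encoder.layer.1.attention.weight": [2.0]}
encoder_structure = [2, 3]  # layer 0 branches into 2 copies, layer 1 into 3

keys = list(state_dict.keys())
for idx, width in enumerate(encoder_structure):
    prefix = "bert.encoder.layer.{}.".format(idx)
    for key in keys:
        if key.startswith(prefix):
            for i in range(width):
                new_prefix = "bert.encoder.layer.{}.{}.".format(idx, i)
                new_key = key.replace(prefix, new_prefix)
                state_dict[new_key] = copy.deepcopy(state_dict[key])
            state_dict.pop(key)

print(sorted(state_dict))
# ['bert.encoder.layer.0.0.attention.weight', 'bert.encoder.layer.0.1.attention.weight',
#  'bert.encoder.layer.1.0.attention.weight', 'bert.encoder.layer.1.1.attention.weight',
#  'bert.encoder.layer.1.2.attention.weight']
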
Code Example #9
  def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
    if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
      archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
      config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
    else:
      archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
      config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

    # redirect to the cache, if necessary
    try:
      resolved_archive_file = cached_path(archive_file, cache_dir=RELOGIC_CACHE)
    except EnvironmentError:
      if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
        logger.error(
          "Couldn't reach server at '{}' to download pretrained weights.".format(
            archive_file))
      else:
        logger.error(
          "Model name '{}' was not found in model name list ({}). "
          "We assumed '{}' was a path or url but couldn't find any file "
          "associated to this path or url.".format(
            pretrained_model_name_or_path,
            ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
            archive_file))
      return None

    try:
      resolved_config_file = cached_path(config_file, cache_dir=RELOGIC_CACHE)
    except EnvironmentError:
      if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
        logger.error(
          "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
            config_file))
      else:
        logger.error(
          "Model name '{}' was not found in model name list ({}). "
          "We assumed '{}' was a path or url but couldn't find any file "
          "associated to this path or url.".format(
            pretrained_model_name_or_path,
            ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
            config_file))
      return None

    if resolved_archive_file == archive_file and resolved_config_file == config_file:
      logger.info("loading weights file {}".format(archive_file))
      logger.info("loading configuration file {}".format(config_file))
    else:
      logger.info("loading weights file {} from cache at {}".format(
        archive_file, resolved_archive_file))
      logger.info("loading configuration file {} from cache at {}".format(
        config_file, resolved_config_file))


    resolved_config_file_dir = os.path.dirname(resolved_config_file)
    resolved_config_file_name = os.path.basename(resolved_config_file)
    restore_config = Argument.restore_from_model_path(model_path=resolved_config_file_dir, config_name=resolved_config_file_name)
    predictor = Predictor(restore_config)
    resolved_model_file_dir = os.path.dirname(resolved_archive_file)
    resolved_model_file_name = os.path.basename(resolved_archive_file)
    predictor.restore(model_path=resolved_model_file_dir, model_name=resolved_model_file_name)

    # The model name here is not always meaningful; sometimes it is a user-defined path.
    return cls(model_name=pretrained_model_name_or_path, config=restore_config, predictor=predictor)