def __init__(self,
                 model_name_or_path: str = None,
                 modules: Iterable[nn.Module] = None,
                 device: str = None):
        """Create a SentenceTransformer from a pretrained model or from ``modules``.

        :param model_name_or_path: one of

            * a bare name without ``/`` or ``\\`` that is not a local folder:
              downloaded from ``__DOWNLOAD_SERVER__ + name + '.zip'``;
            * an ``http://`` / ``https://`` URL of a zipped model: downloaded once
              into ``<torch home>/sentence_transformers/<url-derived folder>``;
            * anything else: used as the path of a local model folder.

            The resulting folder must contain ``modules.json``. ``None`` or ``""``
            skips loading entirely.
        :param modules: modules to chain when nothing is loaded; an iterable that
            is not an ``OrderedDict`` is keyed ``"0"``, ``"1"``, ... It is replaced
            by the loaded modules when ``model_name_or_path`` is given.
        :param device: torch device string; defaults to ``"cuda"`` when
            ``torch.cuda.is_available()``, otherwise ``"cpu"``. The model is moved
            to this device.
        """
        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module)
                                   for idx, module in enumerate(modules)])

        if model_name_or_path is not None and model_name_or_path != "":
            logging.info("Load pretrained SentenceTransformer: {}".format(
                model_name_or_path))

            # A bare name that is not an existing folder refers to a model
            # hosted on the download server.
            if '/' not in model_name_or_path and '\\' not in model_name_or_path and not os.path.isdir(
                    model_name_or_path):
                logging.info(
                    "Did not find a / or \\ in the name. Assume to download model from server"
                )
                model_name_or_path = __DOWNLOAD_SERVER__ + model_name_or_path + '.zip'

            if model_name_or_path.startswith(
                    'http://') or model_name_or_path.startswith('https://'):
                model_url = model_name_or_path
                # Cache folder name: URL without scheme, '/' replaced by '_',
                # truncated to 250 characters to stay below filename limits.
                folder_name = model_url.replace("https://", "").replace(
                    "http://", "").replace("/", "_")[:250]

                try:
                    from torch.hub import _get_torch_home
                    torch_cache_home = _get_torch_home()
                except ImportError:
                    torch_cache_home = os.path.expanduser(
                        os.getenv(
                            'TORCH_HOME',
                            os.path.join(
                                os.getenv('XDG_CACHE_HOME', '~/.cache'),
                                'torch')))
                default_cache_path = os.path.join(torch_cache_home,
                                                  'sentence_transformers')
                model_path = os.path.join(default_cache_path, folder_name)
                os.makedirs(model_path, exist_ok=True)

                # An empty cache folder means the model has not been downloaded yet.
                if not os.listdir(model_path):
                    # Compare by value: `is` on strings tests identity and is
                    # not guaranteed to match an equal one-character string.
                    if model_url.endswith("/"):
                        model_url = model_url[:-1]
                    logging.info(
                        "Downloading sentence transformer model from {} and saving it at {}"
                        .format(model_url, model_path))
                    try:
                        zip_save_path = os.path.join(model_path, 'model.zip')
                        http_get(model_url, zip_save_path)
                        with ZipFile(zip_save_path, 'r') as zip_file:
                            zip_file.extractall(model_path)
                        # The archive is not needed once it has been extracted.
                        os.remove(zip_save_path)
                    except Exception:
                        # Remove the partial download so the next call retries
                        # instead of finding a non-empty but broken folder.
                        shutil.rmtree(model_path)
                        raise
            else:
                model_path = model_name_or_path

            #### Load from disk
            if model_path is not None:
                logging.info("Load SentenceTransformer from folder: {}".format(
                    model_path))
                # modules.json lists each module's class ('type'), sub-folder
                # ('path') and key ('name') in load order.
                with open(os.path.join(model_path, 'modules.json')) as fIn:
                    contained_modules = json.load(fIn)

                modules = OrderedDict()
                for module_config in contained_modules:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(
                        os.path.join(model_path, module_config['path']))
                    modules[module_config['name']] = module

        super().__init__(modules)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info("Use pytorch device: {}".format(device))
        self.device = torch.device(device)
        self.to(device)
    def __init__(self, model_name_or_path: str = None, modules: Iterable[nn.Module] = None, device: str = None):
        """Create a SentenceTransformer from a name, URL, local folder, or explicit modules.

        Resolution of ``model_name_or_path``:

        * an existing folder is loaded directly;
        * an ``http://`` / ``https://`` URL of a zipped model is downloaded into the
          cache folder (``$SENTENCE_TRANSFORMERS_HOME`` if set, otherwise
          ``<torch home>/sentence_transformers``) and loaded from there;
        * any other name without ``\\`` and with at most one ``/`` is looked up on
          ``__DOWNLOAD_SERVER__`` as ``<name>.zip``. If the server answers 404, a
          ``Transformer`` + ``Pooling`` model is built from the name instead and
          saved into the cache folder.

        :param model_name_or_path: model name, URL or local folder; ``None`` or
            ``""`` means only ``modules`` is used.
        :param modules: modules to chain when nothing is loaded from disk; an
            iterable that is not an ``OrderedDict`` is keyed ``"0"``, ``"1"``, ...
        :param device: torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``. Stored as ``self._target_device``.
        :raises AttributeError: if a path-like name (containing ``\\`` or more
            than one ``/``) does not exist locally.
        """
        # Must be bound on every path: it is read at the end of __init__ even
        # when no model name is given.
        save_model_to = None

        if model_name_or_path is not None and model_name_or_path != "":
            logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path))
            model_path = model_name_or_path

            if not os.path.isdir(model_path) and not model_path.startswith('http://') and not model_path.startswith('https://'):
                logger.info("Did not find folder {}".format(model_path))

                # A Windows separator or more than one '/' means a real path was
                # intended, not a server model name.
                if '\\' in model_path or model_path.count('/') > 1:
                    raise AttributeError("Path {} not found".format(model_path))

                model_path = __DOWNLOAD_SERVER__ + model_path + '.zip'
                logger.info("Search model on server: {}".format(model_path))

            if model_path.startswith('http://') or model_path.startswith('https://'):
                model_url = model_path
                folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250]
                # Remove a literal ".zip" suffix. str.rstrip('.zip') would strip any
                # run of trailing '.', 'z', 'i', 'p' characters ("x-multi.zip" -> "x-mult").
                if folder_name.endswith('.zip'):
                    folder_name = folder_name[:-len('.zip')]

                cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME')
                if cache_folder is None:
                    try:
                        from torch.hub import _get_torch_home
                        torch_cache_home = _get_torch_home()
                    except ImportError:
                        torch_cache_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))

                    cache_folder = os.path.join(torch_cache_home, 'sentence_transformers')

                model_path = os.path.join(cache_folder, folder_name)

                # Download unless a non-empty cached copy already exists.
                if not os.path.exists(model_path) or not os.listdir(model_path):
                    # Anything existing here is an empty directory (listdir succeeded
                    # and returned nothing); remove it so the final rename can take
                    # its place. os.remove() cannot delete directories.
                    if os.path.exists(model_path):
                        os.rmdir(model_path)

                    model_url = model_url.rstrip("/")
                    logger.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path))

                    # Download and extract into a side folder, then rename it into
                    # place, so an interrupted download never looks like a cached model.
                    model_path_tmp = model_path.rstrip("/").rstrip("\\")+"_part"
                    # Discard leftovers of an earlier interrupted attempt.
                    shutil.rmtree(model_path_tmp, ignore_errors=True)
                    try:
                        os.makedirs(model_path_tmp, exist_ok=True)
                        zip_save_path = os.path.join(model_path_tmp, 'model.zip')
                        http_get(model_url, zip_save_path)
                        with ZipFile(zip_save_path, 'r') as zip_file:
                            zip_file.extractall(model_path_tmp)
                        os.remove(zip_save_path)
                        os.rename(model_path_tmp, model_path)
                    except requests.exceptions.HTTPError as e:
                        # ignore_errors: the temp folder may not exist if the request failed early.
                        shutil.rmtree(model_path_tmp, ignore_errors=True)
                        if e.response.status_code == 429:
                            raise Exception("Too many requests were detected from this IP for the model {}. Please contact [email protected] for more information.".format(model_name_or_path)) from e

                        if e.response.status_code == 404:
                            logger.warning('SentenceTransformer-Model {} not found. Try to create it from scratch'.format(model_url))
                            logger.warning('Try to create Transformer Model {} with mean pooling'.format(model_name_or_path))

                            # Build a fresh model and remember where to cache it.
                            save_model_to = model_path
                            model_path = None
                            transformer_model = Transformer(model_name_or_path)
                            pooling_model = Pooling(transformer_model.get_word_embedding_dimension())
                            modules = [transformer_model, pooling_model]
                        else:
                            raise
                    except Exception:
                        # Clean up the temp folder (model_path itself does not exist yet);
                        # ignore_errors keeps the original exception visible.
                        shutil.rmtree(model_path_tmp, ignore_errors=True)
                        raise


            #### Load from disk
            if model_path is not None:
                logger.info("Load SentenceTransformer from folder: {}".format(model_path))

                if os.path.exists(os.path.join(model_path, 'config.json')):
                    with open(os.path.join(model_path, 'config.json')) as fIn:
                        config = json.load(fIn)
                        # NOTE(review): if both versions are strings (presumably), this is a
                        # lexicographic comparison, e.g. '0.10.0' > '0.9.0' is False.
                        if config['__version__'] > __version__:
                            logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__))

                # modules.json lists each module's class ('type'), sub-folder ('path')
                # and key ('name') in load order.
                with open(os.path.join(model_path, 'modules.json')) as fIn:
                    contained_modules = json.load(fIn)

                modules = OrderedDict()
                for module_config in contained_modules:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(os.path.join(model_path, module_config['path']))
                    modules[module_config['name']] = module


        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])

        super().__init__(modules)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info("Use pytorch device: {}".format(device))

        self._target_device = torch.device(device)

        #We created a new model from scratch based on a Transformer model. Save the SBERT model in the cache folder
        if save_model_to is not None:
            self.save(save_model_to)
import os
import shutil
import tempfile
import fnmatch
from functools import wraps
from hashlib import sha256
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

# Root of torch's cache. Without torch.hub's helper, use $TORCH_HOME, or
# $XDG_CACHE_HOME/torch (XDG_CACHE_HOME defaulting to ~/.cache).
try:
    from torch.hub import _get_torch_home
    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(os.getenv(
        'TORCH_HOME',
        os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'),
    ))

# pytorch_transformers downloads are cached in a sub-folder of that root.
default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(
        os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
示例#4
0
import torch
import wget
from torch.hub import _get_torch_home

from nemo.collections.nlp.modules.common.megatron.megatron_bert import MegatronBertEncoder
from nemo.utils import AppState, logging

# Public API of this module.
__all__ = [
    "get_megatron_lm_model", "get_megatron_lm_models_list",
    "get_megatron_checkpoint", "is_lower_cased_megatron",
    "get_megatron_tokenizer",
]

# Megatron downloads are cached in a "megatron" folder under torch's home.
MEGATRON_CACHE = os.path.join(_get_torch_home(), "megatron")

# Transformer hyper-parameters, keyed by model size.
CONFIGS = {
    "345m": dict(
        hidden_size=1024,
        num_attention_heads=16,
        num_layers=24,
        max_position_embeddings=512,
    ),
}

MEGATRON_CONFIG_MAP = {
    "megatron-bert-345m-uncased": {
        "config": CONFIGS["345m"],
        "checkpoint":
        "https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.0/files/release/mp_rank_00/model_optim_rng.pt",
示例#5
0
    def __init__(self,
                 model_name_or_path: str = None,
                 modules: Iterable[nn.Module] = None,
                 device: str = None,
                 logfile: str = None,
                 tboard_logdir: str = None):
        """Create a SentenceTransformer, optionally with a log file and TensorBoard writer.

        :param model_name_or_path: one of

            * a bare name without ``/`` or ``\\`` that is not a local folder:
              downloaded from ``__DOWNLOAD_SERVER__ + name + '.zip'``;
            * an ``http://`` / ``https://`` URL of a zipped model: downloaded once
              into ``<torch home>/sentence_transformers/<url-derived folder>``;
            * anything else: used as the path of a local model folder.

            ``None`` or ``""`` skips loading.
        :param modules: modules to chain when nothing is loaded; an iterable that
            is not an ``OrderedDict`` is keyed ``"0"``, ``"1"``, ...
        :param device: torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``. The model is moved to this device.
        :param logfile: if given, this file is opened for writing (truncated) and
            kept as ``self.logfile``; the attribute is only set in that case.
        :param tboard_logdir: if given, a ``SummaryWriter`` on this directory is kept
            as ``self.tboard_logger``; the attribute is only set in that case.
        """
        if logfile is not None:
            print("Logs go to file %s" % logfile)
            # NOTE(review): kept open for the object's lifetime; nothing visible
            # here closes it.
            self.logfile = open(logfile, "w")

        if tboard_logdir is not None:
            print("Tensorboard logs go to dir %s" % tboard_logdir)
            self.tboard_logger = SummaryWriter(tboard_logdir)

        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module)
                                   for idx, module in enumerate(modules)])

        if model_name_or_path is not None and model_name_or_path != "":
            logging.info("Load pretrained SentenceTransformer: {}".format(
                model_name_or_path))

            # A bare name that is not an existing folder refers to a model
            # hosted on the download server.
            if '/' not in model_name_or_path and '\\' not in model_name_or_path and not os.path.isdir(
                    model_name_or_path):
                logging.info(
                    "Did not find a / or \\ in the name. Assume to download model from server"
                )
                model_name_or_path = __DOWNLOAD_SERVER__ + model_name_or_path + '.zip'

            if model_name_or_path.startswith(
                    'http://') or model_name_or_path.startswith('https://'):
                model_url = model_name_or_path
                # Cache folder name: URL without scheme, '/' replaced by '_',
                # truncated to 250 characters to stay below filename limits.
                folder_name = model_url.replace("https://", "").replace(
                    "http://", "").replace("/", "_")[:250]

                try:
                    from torch.hub import _get_torch_home
                    torch_cache_home = _get_torch_home()
                except ImportError:
                    torch_cache_home = os.path.expanduser(
                        os.getenv(
                            'TORCH_HOME',
                            os.path.join(
                                os.getenv('XDG_CACHE_HOME', '~/.cache'),
                                'torch')))
                default_cache_path = os.path.join(torch_cache_home,
                                                  'sentence_transformers')
                model_path = os.path.join(default_cache_path, folder_name)
                os.makedirs(model_path, exist_ok=True)

                # An empty cache folder means the model has not been downloaded yet.
                if not os.listdir(model_path):
                    # Compare by value: `is` on strings tests identity.
                    if model_url.endswith("/"):
                        model_url = model_url[:-1]
                    logging.info(
                        "Downloading sentence transformer model from {} and saving it at {}"
                        .format(model_url, model_path))
                    try:
                        zip_save_path = os.path.join(model_path, 'model.zip')
                        http_get(model_url, zip_save_path)
                        with ZipFile(zip_save_path, 'r') as zip_file:
                            zip_file.extractall(model_path)
                        # The archive is not needed once it has been extracted.
                        os.remove(zip_save_path)
                    except Exception:
                        # Remove the partial download so the next call retries.
                        shutil.rmtree(model_path)
                        raise
            else:
                model_path = model_name_or_path

            #### Load from disk
            if model_path is not None:
                logging.info("Load SentenceTransformer from folder: {}".format(
                    model_path))

                if os.path.exists(os.path.join(model_path, 'config.json')):
                    with open(os.path.join(model_path, 'config.json')) as fIn:
                        config = json.load(fIn)
                        # NOTE(review): if both versions are strings (presumably), this is a
                        # lexicographic comparison, e.g. '0.10.0' > '0.9.0' is False.
                        if config['__version__'] > __version__:
                            logging.warning(
                                "You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n"
                                .format(config['__version__'], __version__))

                # modules.json lists each module's class ('type'), sub-folder
                # ('path') and key ('name') in load order.
                with open(os.path.join(model_path, 'modules.json')) as fIn:
                    contained_modules = json.load(fIn)

                modules = OrderedDict()
                for module_config in contained_modules:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(
                        os.path.join(model_path, module_config['path']))
                    modules[module_config['name']] = module

        super().__init__(modules)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info("Use pytorch device: {}".format(device))
        self.device = torch.device(device)
        self.to(device)
示例#6
0
    def __init__(self,
                 model_name_or_path: str = None,
                 modules: Iterable[nn.Module] = None,
                 device: str = None):
        """Create a SentenceTransformer from a name, URL, local folder, or explicit modules.

        :param model_name_or_path: an existing local folder is loaded directly; an
            ``http://`` / ``https://`` URL of a zipped model is downloaded once into
            ``<torch home>/sentence_transformers/<url-derived folder>``; any other
            value is looked up on ``__DOWNLOAD_SERVER__`` as ``<name>.zip``. If the
            server answers 404, a ``Transformer`` + ``Pooling`` model is built from
            the name instead. ``None`` or ``""`` skips loading.
        :param modules: modules to chain when nothing is loaded from disk; an
            iterable that is not an ``OrderedDict`` is keyed ``"0"``, ``"1"``, ...
        :param device: torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``. Stored as ``self._target_device``.
        """
        if model_name_or_path is not None and model_name_or_path != "":
            logging.info("Load pretrained SentenceTransformer: {}".format(
                model_name_or_path))
            model_path = model_name_or_path

            # Neither a local folder nor a URL: treat it as a server model name.
            if not os.path.isdir(model_path) and not model_path.startswith(
                    'http://') and not model_path.startswith('https://'):
                logging.info(
                    "Did not find folder {}. Assume to download model from server."
                    .format(model_path))
                model_path = __DOWNLOAD_SERVER__ + model_path + '.zip'

            if model_path.startswith('http://') or model_path.startswith(
                    'https://'):
                model_url = model_path
                folder_name = model_url.replace("https://", "").replace(
                    "http://", "").replace("/", "_")[:250]
                # Remove a literal ".zip" suffix. str.rstrip('.zip') would strip any
                # run of trailing '.', 'z', 'i', 'p' characters ("x-multi.zip" -> "x-mult").
                if folder_name.endswith('.zip'):
                    folder_name = folder_name[:-len('.zip')]

                try:
                    from torch.hub import _get_torch_home
                    torch_cache_home = _get_torch_home()
                except ImportError:
                    torch_cache_home = os.path.expanduser(
                        os.getenv(
                            'TORCH_HOME',
                            os.path.join(
                                os.getenv('XDG_CACHE_HOME', '~/.cache'),
                                'torch')))
                default_cache_path = os.path.join(torch_cache_home,
                                                  'sentence_transformers')
                model_path = os.path.join(default_cache_path, folder_name)
                os.makedirs(model_path, exist_ok=True)

                # An empty cache folder means the model has not been downloaded yet.
                if not os.listdir(model_path):
                    if model_url[-1] == "/":
                        model_url = model_url[:-1]
                    logging.info(
                        "Downloading sentence transformer model from {} and saving it at {}"
                        .format(model_url, model_path))
                    try:
                        zip_save_path = os.path.join(model_path, 'model.zip')
                        http_get(model_url, zip_save_path)
                        with ZipFile(zip_save_path, 'r') as zip_file:
                            zip_file.extractall(model_path)
                        os.remove(zip_save_path)
                    except requests.exceptions.HTTPError as e:
                        shutil.rmtree(model_path)
                        if e.response.status_code == 404:
                            logging.warning(
                                'SentenceTransformer-Model {} not found. Try to create it from scratch'
                                .format(model_url))
                            logging.warning(
                                'Try to create Transformer Model {} with mean pooling'
                                .format(model_name_or_path))

                            # Nothing to load from disk; build the model instead.
                            model_path = None
                            transformer_model = Transformer(model_name_or_path)
                            pooling_model = Pooling(
                                transformer_model.get_word_embedding_dimension(
                                ))
                            modules = [transformer_model, pooling_model]

                        else:
                            raise
                    except Exception:
                        # Remove the partial download so the next call retries.
                        shutil.rmtree(model_path)
                        raise

            #### Load from disk
            if model_path is not None:
                logging.info("Load SentenceTransformer from folder: {}".format(
                    model_path))

                if os.path.exists(os.path.join(model_path, 'config.json')):
                    with open(os.path.join(model_path, 'config.json')) as fIn:
                        config = json.load(fIn)
                        # NOTE(review): if both versions are strings (presumably), this is a
                        # lexicographic comparison, e.g. '0.10.0' > '0.9.0' is False.
                        if config['__version__'] > __version__:
                            logging.warning(
                                "You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n"
                                .format(config['__version__'], __version__))

                # modules.json lists each module's class ('type'), sub-folder
                # ('path') and key ('name') in load order.
                with open(os.path.join(model_path, 'modules.json')) as fIn:
                    contained_modules = json.load(fIn)

                modules = OrderedDict()
                for module_config in contained_modules:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(
                        os.path.join(model_path, module_config['path']))
                    modules[module_config['name']] = module

        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module)
                                   for idx, module in enumerate(modules)])

        super().__init__(modules)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info("Use pytorch device: {}".format(device))

        self._target_device = torch.device(device)
示例#7
0
import tempfile
from functools import wraps
from hashlib import sha256
from io import open

import boto3
import numpy as np
import requests
from botocore.exceptions import ClientError
from dotmap import DotMap
from tqdm import tqdm

# Root of torch's cache as a Path. Without torch.hub's helper, use $TORCH_HOME,
# or $XDG_CACHE_HOME/torch (XDG_CACHE_HOME defaulting to ~/.cache).
try:
    from torch.hub import _get_torch_home

    torch_cache_home = Path(_get_torch_home())
except ImportError:
    torch_cache_home = Path(os.path.expanduser(os.getenv(
        "TORCH_HOME",
        Path(os.getenv("XDG_CACHE_HOME", "~/.cache")).joinpath("torch"),
    )))

# FARM downloads are cached in a "farm" folder under that root.
default_cache_path = torch_cache_home.joinpath("farm")

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    def __init__(self, model_name_or_path: str = None, sentence_transformer_config: SentenceTransformerConfig = None):
        """
        Creates a Sentence BERT model from a pretrained model (downloaded from the internet or read from the file
        system) or from a config describing a new model.

        Resolution of ``model_name_or_path``:

        * a name without '/' or '\\' that is not an existing local folder is expanded to a URL on the UKP
          sentence-transformers v0.1 server;
        * an http(s) URL is downloaded once (weights, transformer config and sentence_transformer_config.json)
          into ``<torch home>/sentence_transformers/<url-derived folder>`` and loaded from there;
        * anything else is used as a local model folder.

        When ``model_name_or_path`` is None, a new model is created from ``sentence_transformer_config``.

        :param model_name_or_path:
            A pre-trained model name, a URL or a path on the file system
        :param sentence_transformer_config:
            configuration for a new model; replaced by the stored config when a model is loaded
        :raises Exception: if the model folder lacks any of the expected files
        :raises ValueError: if both arguments are None
        """
        model_path = None
        if model_name_or_path is not None:
            # Only names that are not an existing local folder refer to the server.
            if '/' not in model_name_or_path and '\\' not in model_name_or_path and not os.path.isdir(model_name_or_path):
                model_name_or_path = 'https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.1/' + model_name_or_path

            if model_name_or_path.startswith('http://') or model_name_or_path.startswith('https://'):
                model_url = model_name_or_path
                # Cache folder name: URL without scheme, '/' replaced by '_', truncated to 250 characters.
                folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250]

                try:
                    from torch.hub import _get_torch_home
                    torch_cache_home = _get_torch_home()
                except ImportError:
                    torch_cache_home = os.path.expanduser(
                        os.getenv('TORCH_HOME', os.path.join(
                            os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
                default_cache_path = os.path.join(torch_cache_home, 'sentence_transformers')
                model_path = os.path.join(default_cache_path, folder_name)
                os.makedirs(model_path, exist_ok=True)

                # An empty cache folder means the model has not been downloaded yet.
                if not os.listdir(model_path):
                    # Compare by value: `is` on strings tests identity.
                    if model_url.endswith("/"):
                        model_url = model_url[:-1]
                    logging.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path))
                    try:
                        http_get(model_url + "/" + WEIGHTS_NAME, os.path.join(model_path, WEIGHTS_NAME))
                        http_get(model_url + "/" + CONFIG_NAME, os.path.join(model_path, CONFIG_NAME))
                        http_get(model_url + "/" + 'sentence_transformer_config.json', os.path.join(model_path, 'sentence_transformer_config.json'))
                    except Exception:
                        # Remove the partial download so the next call retries.
                        shutil.rmtree(model_path)
                        raise
            else:
                model_path = model_name_or_path


        if model_path is not None:
            logging.info("Loading model from {}".format(model_path))
            output_model_file = os.path.join(model_path, WEIGHTS_NAME)
            output_transformer_config_file = os.path.join(model_path, CONFIG_NAME)
            output_sentence_transformer_config_file = os.path.join(model_path, 'sentence_transformer_config.json')

            if not os.path.exists(output_model_file) or not os.path.exists(output_transformer_config_file) or not os.path.exists(output_sentence_transformer_config_file):
                raise Exception("It appears that files are missing in {}. The sentence transformer model cannot be loaded".format(model_path))


            sentence_transformer_config = SentenceTransformerConfig.from_json_file(output_sentence_transformer_config_file)
            logging.info("Transformer Model config {}".format(sentence_transformer_config))

            transformer_config = PretrainedConfig.from_json_file(output_transformer_config_file)
            model_class = self.import_from_string(sentence_transformer_config.model)
            self.transformer_model = model_class(transformer_config, sentence_transformer_config=sentence_transformer_config)
            # Load weights onto the GPU when available, otherwise onto the CPU.
            self.transformer_model.load_state_dict(torch.load(output_model_file, map_location='cuda' if torch.cuda.is_available() else 'cpu'))

        elif sentence_transformer_config is not None:
            logging.info("Creating a new {} model with config {}".format(sentence_transformer_config.model, sentence_transformer_config))
            model_class = self.import_from_string(sentence_transformer_config.model)
            self.transformer_model = model_class.from_pretrained(sentence_transformer_config.tokenizer_model)
            self.transformer_model.set_config(sentence_transformer_config)

        else:
            raise ValueError("model_url, model_path and config can not be all None.")

        self.transformer_model.set_tokenizer(sentence_transformer_config.tokenizer_model, sentence_transformer_config.do_lower_case)
        self.encoder = SentenceEncoder(self.transformer_model, sentence_transformer_config)
        self.trainer = SentenceTrainer(self.transformer_model)
示例#9
0
import wget
from torch.hub import _get_torch_home

from nemo.collections.nlp.modules.common.megatron.megatron_bert import MegatronBertEncoder
from nemo.utils import AppState, logging

# Public API of this module.
__all__ = [
    "get_megatron_lm_model", "get_megatron_lm_models_list",
    "get_megatron_checkpoint", "is_lower_cased_megatron",
    "get_megatron_tokenizer",
]


# Root for cached Megatron files: torch's home when it is reported as a
# string, otherwise the current working directory.
torch_home = _get_torch_home()
if not isinstance(torch_home, str):
    logging.info("Torch home not found, caching megatron in cwd")
    torch_home = os.getcwd()

MEGATRON_CACHE = os.path.join(torch_home, "megatron")


# Transformer hyper-parameters, keyed by model size.
CONFIGS = {
    "345m": dict(hidden_size=1024, num_attention_heads=16, num_layers=24, max_position_embeddings=512),
}

MEGATRON_CONFIG_MAP = {
    "megatron-bert-345m-uncased": {
        "config": CONFIGS["345m"],
        "checkpoint": "https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.0/files/release/mp_rank_00/model_optim_rng.pt",
        "vocab": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
示例#10
0
import os
import pathlib
from typing import Optional, Union

from torch.hub import _get_torch_home

HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"


def home(root: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
    """Return (and optionally set) the root folder for vision datasets.

    Lookup order:

    1. ``root`` given: it is ``~``-expanded and resolved, stored as the new
       module-level ``HOME``, and returned.
    2. ``$TORCHVISION_DATASETS_HOME`` set: that path, ``~``-expanded.
    3. Otherwise the module-level ``HOME`` (initially
       ``<torch home>/datasets/vision``).

    Note that the environment variable takes precedence over a ``HOME``
    previously set through ``home(root)``.
    """
    global HOME
    if root is not None:
        HOME = pathlib.Path(root).expanduser().resolve()
        return HOME

    root = os.getenv("TORCHVISION_DATASETS_HOME")
    if root is not None:
        # Expand "~" like an explicitly passed root does; the variable can hold
        # a literal "~" when it is set from a config file or a quoted string.
        return pathlib.Path(root).expanduser()

    return HOME
    def __init__(self,
                 model_name_or_path: str = None,
                 modules: Iterable[nn.Module] = None,
                 device: str = None):
        """Create a SentenceTransformer from a pretrained model or from ``modules``.

        :param model_name_or_path: one of

            * a bare name without ``/`` or ``\\`` that is not a local folder:
              downloaded from ``__DOWNLOAD_SERVER__ + name + '.zip'``;
            * an ``http://`` / ``https://`` URL of a zipped model: downloaded once
              into ``<torch home>/sentence_transformers/<url-derived folder>``;
            * anything else: used as the path of a local model folder.

            ``None`` or ``""`` skips loading.
        :param modules: modules to chain when nothing is loaded; an iterable that
            is not an ``OrderedDict`` is keyed ``"0"``, ``"1"``, ...
        :param device: torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``. Stored as ``self._target_device``.
        """
        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module)
                                   for idx, module in enumerate(modules)])

        if model_name_or_path is not None and model_name_or_path != "":
            logging.info("Load pretrained SentenceTransformer: {}".format(
                model_name_or_path))

            # A bare name that is not an existing folder refers to a model
            # hosted on the download server.
            if '/' not in model_name_or_path and '\\' not in model_name_or_path and not os.path.isdir(
                    model_name_or_path):
                logging.info(
                    "Did not find a '/' or '\\' in the name. Assume to download model from server."
                )
                model_name_or_path = __DOWNLOAD_SERVER__ + model_name_or_path + '.zip'

            if model_name_or_path.startswith(
                    'http://') or model_name_or_path.startswith('https://'):
                model_url = model_name_or_path
                # Cache folder name: URL without scheme, '/' replaced by '_',
                # truncated to 250 characters to stay below filename limits.
                folder_name = model_url.replace("https://", "").replace(
                    "http://", "").replace("/", "_")[:250]

                try:
                    from torch.hub import _get_torch_home
                    torch_cache_home = _get_torch_home()
                except ImportError:
                    torch_cache_home = os.path.expanduser(
                        os.getenv(
                            'TORCH_HOME',
                            os.path.join(
                                os.getenv('XDG_CACHE_HOME', '~/.cache'),
                                'torch')))
                default_cache_path = os.path.join(torch_cache_home,
                                                  'sentence_transformers')
                model_path = os.path.join(default_cache_path, folder_name)
                os.makedirs(model_path, exist_ok=True)

                # An empty cache folder means the model has not been downloaded yet.
                if not os.listdir(model_path):
                    if model_url[-1] == "/":
                        model_url = model_url[:-1]
                    logging.info(
                        "Downloading sentence transformer model from {} and saving it at {}"
                        .format(model_url, model_path))
                    try:
                        zip_save_path = os.path.join(model_path, 'model.zip')
                        http_get(model_url, zip_save_path)
                        with ZipFile(zip_save_path, 'r') as zip_file:
                            zip_file.extractall(model_path)
                        # The archive is not needed once it has been extracted.
                        os.remove(zip_save_path)
                    except Exception:
                        # Remove the partial download so the next call retries.
                        shutil.rmtree(model_path)
                        raise
            else:
                model_path = model_name_or_path

            #### Load from disk
            if model_path is not None:
                logging.info("Load SentenceTransformer from folder: {}".format(
                    model_path))

                if os.path.exists(os.path.join(model_path, 'config.json')):
                    with open(os.path.join(model_path, 'config.json')) as fIn:
                        config = json.load(fIn)
                        # NOTE(review): if both versions are strings (presumably), this is a
                        # lexicographic comparison, e.g. '0.10.0' > '0.9.0' is False.
                        if config['__version__'] > __version__:
                            logging.warning(
                                "You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n"
                                .format(config['__version__'], __version__))

                # modules.json lists each module's class ('type'), sub-folder
                # ('path') and key ('name') in load order.
                with open(os.path.join(model_path, 'modules.json')) as fIn:
                    contained_modules = json.load(fIn)

                modules = OrderedDict()
                for module_config in contained_modules:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(
                        os.path.join(model_path, module_config['path']))
                    modules[module_config['name']] = module

        super().__init__(modules)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info("Use pytorch device: {}".format(device))

        self._target_device = torch.device(device)
        # Parallel tokenization is only enabled when the multiprocessing start
        # method is 'fork' (the original note: it only works if the OS supports fork).
        self.parallel_tokenization = multiprocessing.get_start_method(
        ) == 'fork'
        # Number of tokenization processes, capped at 4; can be raised up to
        # cpu_count() for faster tokenization.
        self.parallel_tokenization_processes = min(4, cpu_count())
        # Number of sentences sent per chunk to each process; larger is faster.
        self.parallel_tokenization_chunksize = 5000
示例#12
0
def load_state_dict_from_url(url,
                             model_dir=None,
                             map_location=None,
                             progress=True,
                             check_hash=False):
    r"""Download the file at ``url`` into a local cache and return its path.

    Unlike ``torch.hub.load_state_dict_from_url``, this variant does **not**
    deserialize the file and does not decompress zip archives: it only makes
    sure the file is present in ``model_dir`` and returns the local path.

    If a file with the same name is already present in `model_dir`, it is not
    downloaded again (and its hash is not re-checked).
    The default value of `model_dir` is ``<torch home>/checkpoints``, where the
    torch home comes from ``torch.hub._get_torch_home()``: ``$TORCH_HOME``,
    which defaults to ``$XDG_CACHE_HOME/torch``. ``$XDG_CACHE_HOME`` follows the
    XDG Base Directory specification of the Linux filesystem layout, with a
    default value ``~/.cache`` if not set.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object;
            created if it does not exist
        map_location (optional): ignored; accepted only for signature
            compatibility with ``torch.hub.load_state_dict_from_url``
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False

    Returns:
        str: path of the cached file inside ``model_dir``.

    Raises:
        ValueError: if ``check_hash`` is True, the file has to be downloaded,
            and its filename does not match ``HASH_REGEX``.

    Example:
        >>> path = load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv("TORCH_MODEL_ZOO"):
        warnings.warn(
            "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead")

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, "checkpoints")

    # An already existing directory is fine.
    os.makedirs(model_dir, exist_ok=True)

    # The cache key is the last path component of the URL (query string excluded).
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            match = HASH_REGEX.search(filename)
            if match is None:
                # Fail clearly instead of with an AttributeError on None, and
                # never silently skip a requested hash check.
                raise ValueError(
                    "check_hash=True but filename {!r} does not follow the "
                    "'filename-<sha256>.ext' convention".format(filename))
            hash_prefix = match.group(1)
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return cached_file