@classmethod
def load(cls, model_path: Union[str, Path]):
    """Loads the model from the given file.

    :param model_path: the model file
    :return: the loaded text classifier model
    """
    model_file = cls._fetch_model(str(model_path))

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        # load_big_file is a workaround by https://github.com/highway11git
        # to load models on some Mac/Windows setups
        # see https://github.com/zalandoresearch/flair/issues/351
        f = file_utils.load_big_file(str(model_file))
        state = torch.load(f, map_location="cpu")

    model = cls._init_model_with_state_dict(state)

    if "model_card" in state:
        model.model_card = state["model_card"]

    model.eval()
    model.to(flair.device)

    return model
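# Illustrative usage sketch: a thin wrapper showing how the generic `load`
# classmethod above is normally invoked through a concrete flair model class.
# The model name "en-sentiment" is an assumption for demonstration and is not
# part of the snippet above.
def load_flair_sentiment_example():
    from flair.models import TextClassifier

    # TextClassifier inherits `load`, which fetches the model file, loads the
    # state dict on CPU, and then moves the model to flair.device.
    return TextClassifier.load("en-sentiment")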
def load_flair_upos_fast():
    """Loads flair 'upos-fast' SequenceTagger.

    This is a temporary workaround for flair v0.6. Will be fixed when
    flair pushes the bug fix.
    """
    import pathlib
    import warnings

    import torch
    from flair import file_utils
    from flair.models import SequenceTagger

    import textattack

    hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"
    upos_path = "/".join([hu_path, "upos-fast", "en-upos-ontonotes-fast-v0.4.pt"])
    model_path = file_utils.cached_path(upos_path, cache_dir=pathlib.Path("models"))
    model_file = SequenceTagger._fetch_model(model_path)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        # load_big_file is a workaround by https://github.com/highway11git
        # to load models on some Mac/Windows setups
        # see https://github.com/zalandoresearch/flair/issues/351
        f = file_utils.load_big_file(str(model_file))
        state = torch.load(f, map_location="cpu")

    model = SequenceTagger._init_model_with_state_dict(state)
    model.eval()
    model.to(textattack.shared.utils.device)
    return model
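# A minimal usage sketch, assuming `load_flair_upos_fast()` above is called in
# place of `SequenceTagger.load("upos-fast")` until the flair v0.6 bug is fixed.
# The example sentence is arbitrary.
if __name__ == "__main__":
    from flair.data import Sentence

    tagger = load_flair_upos_fast()
    sentence = Sentence("Universal POS tags are predicted for each token .")
    tagger.predict(sentence)
    print(sentence.to_tagged_string())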