def from_pretrained(cls, pretrained_path, **kwargs):
    # load weights from hf hub
    if not os.path.isfile(pretrained_path):
        # retrieve correct hub url
        download_url = hf_hub_url(repo_id=pretrained_path, filename=PROCESSOR_FILE_NAME)

        pretrained_path = str(
            cached_download(
                url=download_url,
                library_name=LIBRARY_NAME,
                library_version=VERSION,
                cache_dir=CACHE_DIRECTORY,
            )
        )

    with open(pretrained_path, "r") as f:
        config = json.load(f)

    try:
        processor_name = config["processor_name"]
        processor_class = CONFIG_MAPPING[processor_name]
        processor_class = processor_class(data_dir=None, loaded_mapper_path=pretrained_path)
        return processor_class
    except Exception:
        raise ValueError(
            "Unrecognized processor in {}. "
            "Should have a `processor_name` key in its config.json, or contain one of the following strings "
            "in its name: {}".format(pretrained_path, ", ".join(CONFIG_MAPPING.keys()))
        )
def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
    url = hf_hub_url(checkpoint, "config.json", revision=revision)
    cached_file = cached_download(
        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
    )
    return load_cfg_from_json(cached_file)
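# A minimal usage sketch for load_hf_checkpoint_config above; the repo id is an
# illustrative assumption (any hub repo that ships a config.json would do).
cfg = load_hf_checkpoint_config("sgugger/resnet50d")
print(cfg.get("architecture"), cfg.get("num_classes"))  # keys assumed present in timm hub configs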
def from_pretrained(cls, pretrained_path, **kwargs):
    # load weights from hf hub
    if not os.path.isfile(pretrained_path):
        # retrieve correct hub url
        download_url = hf_hub_url(repo_id=pretrained_path, filename=CONFIG_FILE_NAME)

        pretrained_path = str(
            cached_download(
                url=download_url,
                library_name=LIBRARY_NAME,
                library_version=VERSION,
                cache_dir=CACHE_DIRECTORY,
            )
        )

    with open(pretrained_path) as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    try:
        model_type = config["model_type"]
        config_class = CONFIG_MAPPING[model_type]
        config_class = config_class(**config[model_type + "_params"], **kwargs)
        config_class.set_pretrained_config(config)
        return config_class
    except Exception:
        raise ValueError(
            "Unrecognized config in {}. "
            "Should have a `model_type` key in its config.yaml, or contain one of the following strings "
            "in its name: {}".format(pretrained_path, ", ".join(CONFIG_MAPPING.keys()))
        )
def get_dpt_config(checkpoint_url):
    config = DPTConfig()

    if "large" in checkpoint_url:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
        config.backbone_out_indices = [5, 11, 17, 23]
        config.neck_hidden_sizes = [256, 512, 1024, 1024]
        expected_shape = (1, 384, 384)

    if "ade" in checkpoint_url:
        config.use_batch_norm_in_fusion_residual = True
        config.num_labels = 150
        repo_id = "datasets/huggingface/label-files"
        filename = "ade20k-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        expected_shape = [1, 150, 480, 480]

    return config, expected_shape
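# Hypothetical call to get_dpt_config: the function only keys on substrings of
# the URL ("large", "ade"), so this illustrative URL exercises both branches.
# Note that expected_shape is only assigned inside those branches, so a URL
# matching neither would raise UnboundLocalError on return.
config, expected_shape = get_dpt_config("https://example.com/dpt_large_ade20k.pt")
print(config.num_labels, expected_shape)  # 150, [1, 150, 480, 480]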
def convert_cvt_checkpoint(cvt_file, pytorch_dump_folder):
    """
    Function to convert the microsoft cvt checkpoint to huggingface checkpoint
    """
    img_labels_file = "imagenet-1k-id2label.json"
    num_labels = 1000
    repo_id = "datasets/huggingface/label-files"
    id2label = json.load(open(cached_download(hf_hub_url(repo_id, img_labels_file)), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)

    # For depth size 13 (13 = 1+2+10)
    if cvt_file.rsplit("/", 1)[-1][4:6] == "13":
        config.depth = [1, 2, 10]
    # For depth size 21 (21 = 1+4+16)
    elif cvt_file.rsplit("/", 1)[-1][4:6] == "21":
        config.depth = [1, 4, 16]
    # For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2+2+20)
    else:
        config.depth = [2, 2, 20]
        config.num_heads = [3, 12, 16]
        config.embed_dim = [192, 768, 1024]

    model = CvtForImageClassification(config)
    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")

    original_weights = torch.load(cvt_file, map_location=torch.device("cpu"))

    huggingface_weights = OrderedDict()
    list_of_state_dict = []

    for idx in range(config.num_stages):
        if config.cls_token[idx]:
            list_of_state_dict = list_of_state_dict + cls_token(idx)
        list_of_state_dict = list_of_state_dict + embeddings(idx)
        for cnt in range(config.depth[idx]):
            list_of_state_dict = list_of_state_dict + attention(idx, cnt)

    list_of_state_dict = list_of_state_dict + final()
    for gg in list_of_state_dict:
        print(gg)
    for i in range(len(list_of_state_dict)):
        huggingface_weights[list_of_state_dict[i][0]] = original_weights[list_of_state_dict[i][1]]

    model.load_state_dict(huggingface_weights)
    model.save_pretrained(pytorch_dump_folder)
    feature_extractor.save_pretrained(pytorch_dump_folder)
def snapshot_download(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    ignore_files: Optional[List[str]] = None,
) -> str:
    """
    Method derived from huggingface_hub.
    Adds a new parameter 'ignore_files', which allows certain files / file-patterns to be ignored.
    """
    if cache_dir is None:
        cache_dir = HUGGINGFACE_HUB_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    _api = HfApi()
    model_info = _api.model_info(repo_id=repo_id, revision=revision)

    storage_folder = os.path.join(cache_dir, repo_id.replace("/", REPO_ID_SEPARATOR) + "." + model_info.sha)

    for model_file in model_info.siblings:
        if ignore_files is not None:
            skip_download = False
            for pattern in ignore_files:
                if fnmatch.fnmatch(model_file.rfilename, pattern):
                    skip_download = True
                    break
            if skip_download:
                continue

        url = hf_hub_url(repo_id, filename=model_file.rfilename, revision=model_info.sha)
        relative_filepath = os.path.join(*model_file.rfilename.split("/"))

        # Create potential nested dir
        nested_dirname = os.path.dirname(os.path.join(storage_folder, relative_filepath))
        os.makedirs(nested_dirname, exist_ok=True)

        path = cached_download(
            url,
            cache_dir=storage_folder,
            force_filename=relative_filepath,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
        )

        if os.path.exists(path + ".lock"):
            os.remove(path + ".lock")

    return storage_folder
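# A hedged usage sketch for the snapshot_download variant above; the repo id
# and glob patterns are illustrative assumptions.
folder = snapshot_download(
    "sentence-transformers/all-MiniLM-L6-v2",
    ignore_files=["*.onnx", "*.h5"],  # fnmatch-style patterns, checked per repo file
)
print(folder)  # local folder named <repo id>.<commit sha> inside the cache dir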
def cached_download(filename_or_url):
    """Download from URL and cache the result in ASTEROID_CACHE.

    Args:
        filename_or_url (str): Name of a model as named on the Zenodo Community page
            (ex: ``"mpariente/ConvTasNet_WHAM!_sepclean"``), or model id from the Hugging Face
            model hub (ex: ``"julien-c/DPRNNTasNet-ks16_WHAM_sepclean"``), or a URL to a model
            file (ex: ``"https://zenodo.org/.../model.pth"``), or a filename that exists locally
            (ex: ``"local/tmp_model.pth"``)

    Returns:
        str, normalized path to the downloaded (or not) model
    """
    from .. import __version__ as asteroid_version  # Avoid circular imports

    if os.path.isfile(filename_or_url):
        return filename_or_url

    if filename_or_url.startswith(huggingface_hub.HUGGINGFACE_CO_URL_HOME):
        filename_or_url = filename_or_url[len(huggingface_hub.HUGGINGFACE_CO_URL_HOME):]

    if filename_or_url.startswith(("http://", "https://")):
        url = filename_or_url
    elif filename_or_url in MODELS_URLS_HASHTABLE:
        url = MODELS_URLS_HASHTABLE[filename_or_url]
    else:
        # Finally, let's try to find it on Hugging Face model hub
        # e.g. julien-c/DPRNNTasNet-ks16_WHAM_sepclean is a valid model id
        # and julien-c/DPRNNTasNet-ks16_WHAM_sepclean@main supports specifying a commit/branch/tag.
        if "@" in filename_or_url:
            model_id = filename_or_url.split("@")[0]
            revision = filename_or_url.split("@")[1]
        else:
            model_id = filename_or_url
            revision = None
        url = huggingface_hub.hf_hub_url(
            model_id, filename=huggingface_hub.PYTORCH_WEIGHTS_NAME, revision=revision
        )
        return huggingface_hub.cached_download(
            url,
            cache_dir=get_cache_dir(),
            library_name="asteroid",
            library_version=asteroid_version,
        )

    cached_filename = url_to_filename(url)
    cached_dir = os.path.join(get_cache_dir(), cached_filename)
    cached_path = os.path.join(cached_dir, "model.pth")

    os.makedirs(cached_dir, exist_ok=True)
    if not os.path.isfile(cached_path):
        hub.download_url_to_file(url, cached_path)
        return cached_path
    # It was already downloaded
    print(f"Using cached model `{filename_or_url}`")
    return cached_path
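# cached_download above resolves four kinds of inputs; an illustrative sketch
# (the Zenodo URL reuses the docstring's elided example, the model ids are the
# ones named in the docstring):
cached_download("local/tmp_model.pth")                           # existing file: returned unchanged
cached_download("https://zenodo.org/.../model.pth")              # raw URL: cached under ASTEROID_CACHE
cached_download("julien-c/DPRNNTasNet-ks16_WHAM_sepclean")       # HF model id
cached_download("julien-c/DPRNNTasNet-ks16_WHAM_sepclean@main")  # HF model id pinned to a revision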
def has_file(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    revision: Optional[str] = None,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
):
    """
    Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.

    <Tip warning={false}>

    This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist
    for this repo, but will return False for regular connection errors.

    </Tip>
    """
    if os.path.isdir(path_or_repo):
        return os.path.isfile(os.path.join(path_or_repo, filename))

    url = hf_hub_url(path_or_repo, filename=filename, revision=revision)

    headers = {"user-agent": http_user_agent()}
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"

    r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
    try:
        huggingface_hub.utils._errors._raise_for_status(r)
        return True
    except RepositoryNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(
            f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'."
        )
    except RevisionNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
            f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
        )
    except requests.HTTPError:
        # We return False for EntryNotFoundError (logical) as well as any connection error.
        return False
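# A minimal sketch of calling has_file; repo and filenames are illustrative.
if has_file("bert-base-uncased", "pytorch_model.bin"):
    print("PyTorch weights available")
if not has_file("bert-base-uncased", "no_such_file.bin"):
    print("missing entries return False rather than raising")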
def load_sparse_weight(self, model_name_or_path):
    sparse_weight_path = os.path.join(model_name_or_path, 'sparse_weight.pt')
    # check file exists
    if not os.path.isfile(sparse_weight_path):
        # download from huggingface hub and cache it
        sparse_weight_url = hf_hub_url(model_name_or_path, filename="sparse_weight.pt")
        sparse_weight_path = cached_download(sparse_weight_url)

    self.sparse_weight = torch.load(sparse_weight_path)

    return self.sparse_weight
def load_sparse_encoder(self, model_name_or_path):
    sparse_encoder_path = os.path.join(model_name_or_path, 'sparse_encoder.pk')
    # check file exists
    if not os.path.isfile(sparse_encoder_path):
        # download from huggingface hub and cache it
        sparse_encoder_url = hf_hub_url(model_name_or_path, filename="sparse_encoder.pk")
        sparse_encoder_path = cached_download(sparse_encoder_url)

    self.sparse_encoder = SparseEncoder().load_encoder(path=sparse_encoder_path)

    return self.sparse_encoder
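# Hypothetical usage of the two loaders above, assuming `biosyn` is an instance
# of the class defining them; the model id is an illustrative BioSyn checkpoint.
# Both fall back to the Hub when the files are not present locally.
biosyn.load_sparse_encoder("dmis-lab/biosyn-sapbert-bc5cdr-disease")
biosyn.load_sparse_weight("dmis-lab/biosyn-sapbert-bc5cdr-disease")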
def from_pretrained(cls, config=None, pretrained_path=None, **kwargs):
    is_build = kwargs.pop("is_build", True)

    # load weights from hf hub
    if pretrained_path is not None:
        if not os.path.isfile(pretrained_path):
            # retrieve correct hub url
            download_url = hf_hub_url(repo_id=pretrained_path, filename=MODEL_FILE_NAME)

            downloaded_file = str(
                cached_download(
                    url=download_url,
                    library_name=LIBRARY_NAME,
                    library_version=VERSION,
                    cache_dir=CACHE_DIRECTORY,
                )
            )

            # load config from repo as well
            if config is None:
                from tensorflow_tts.inference import AutoConfig

                config = AutoConfig.from_pretrained(pretrained_path)

            pretrained_path = downloaded_file

    assert config is not None, "Please make sure to pass a config along to load a model from a local file"

    for config_class, model_class in TF_MODEL_MAPPING.items():
        if isinstance(config, config_class) and str(config_class.__name__) in str(config):
            model = model_class(config=config, **kwargs)
            if is_build:
                model._build()
            if pretrained_path is not None and ".h5" in pretrained_path:
                try:
                    model.load_weights(pretrained_path)
                except Exception:
                    model.load_weights(pretrained_path, by_name=True, skip_mismatch=True)
            return model

    raise ValueError(
        "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
        "Model type should be one of {}.".format(
            config.__class__,
            cls.__name__,
            ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys()),
        )
    )
def _hf_hub_download(
    url, model_identifier: str, filename: Optional[str], cache_dir: Union[str, Path]
) -> str:
    revision: Optional[str]
    if "@" in model_identifier:
        repo_id = model_identifier.split("@")[0]
        revision = model_identifier.split("@")[1]
    else:
        repo_id = model_identifier
        revision = None

    if filename is not None:
        hub_url = hf_hub.hf_hub_url(repo_id=repo_id, filename=filename, revision=revision)
        cache_path = str(
            hf_hub.cached_download(
                url=hub_url,
                library_name="allennlp",
                library_version=VERSION,
                cache_dir=cache_dir,
            )
        )
        # HF writes its own meta '.json' file which uses the same format we used to use and still
        # support, but is missing some fields that we like to have.
        # So we overwrite it when we can.
        with FileLock(cache_path + ".lock", read_only_ok=True):
            meta = _Meta.from_path(cache_path + ".json")
            # The file HF writes will have 'resource' set to the 'http' URL corresponding to the 'hf://' URL,
            # but we want 'resource' to be the original 'hf://' URL.
            if meta.resource != url:
                meta.resource = url
                meta.to_file()
    else:
        cache_path = str(hf_hub.snapshot_download(repo_id, revision=revision, cache_dir=cache_dir))
        # Need to write the meta file for snapshot downloads if it doesn't exist.
        with FileLock(cache_path + ".lock", read_only_ok=True):
            if not os.path.exists(cache_path + ".json"):
                meta = _Meta(
                    resource=url,
                    cached_path=cache_path,
                    creation_time=time.time(),
                    extraction_dir=True,
                    size=_get_resource_size(cache_path),
                )
                meta.to_file()

    return cache_path
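# Sketch of how _hf_hub_download might be invoked; the 'hf://' resource URL and
# cache dir are illustrative assumptions (the url is only recorded in the meta
# file, the actual download is driven by model_identifier and filename).
cache_path = _hf_hub_download(
    url="hf://bert-base-uncased/config.json",
    model_identifier="bert-base-uncased",
    filename="config.json",
    cache_dir="/tmp/allennlp_cache",
)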
def get_swin_config(swin_name):
    config = SwinConfig()
    name_split = swin_name.split("_")

    model_size = name_split[1]
    img_size = int(name_split[4])
    window_size = int(name_split[3][-1])

    if model_size == "tiny":
        embed_dim = 96
        depths = (2, 2, 6, 2)
        num_heads = (3, 6, 12, 24)
    elif model_size == "small":
        embed_dim = 96
        depths = (2, 2, 18, 2)
        num_heads = (3, 6, 12, 24)
    elif model_size == "base":
        embed_dim = 128
        depths = (2, 2, 18, 2)
        num_heads = (4, 8, 16, 32)
    else:
        embed_dim = 192
        depths = (2, 2, 18, 2)
        num_heads = (6, 12, 24, 48)

    if "in22k" in swin_name:
        num_classes = 21841
    else:
        num_classes = 1000
        repo_id = "datasets/huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}

    config.image_size = img_size
    config.num_labels = num_classes
    config.embed_dim = embed_dim
    config.depths = depths
    config.num_heads = num_heads
    config.window_size = window_size

    return config
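# get_swin_config parses everything from the timm-style model name, e.g.
# "swin_tiny_patch4_window7_224" -> size "tiny", window 7 (taken from the last
# character of name_split[3]), image size 224.
config = get_swin_config("swin_tiny_patch4_window7_224")
assert (config.embed_dim, config.window_size, config.image_size) == (96, 7, 224)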
def get_convnext_config(checkpoint_url):
    config = ConvNextConfig()

    if "tiny" in checkpoint_url:
        depths = [3, 3, 9, 3]
        hidden_sizes = [96, 192, 384, 768]
    if "small" in checkpoint_url:
        depths = [3, 3, 27, 3]
        hidden_sizes = [96, 192, 384, 768]
    if "base" in checkpoint_url:
        depths = [3, 3, 27, 3]
        hidden_sizes = [128, 256, 512, 1024]
    if "large" in checkpoint_url:
        depths = [3, 3, 27, 3]
        hidden_sizes = [192, 384, 768, 1536]
    if "xlarge" in checkpoint_url:
        depths = [3, 3, 27, 3]
        hidden_sizes = [256, 512, 1024, 2048]

    if "1k" in checkpoint_url:
        num_labels = 1000
        filename = "imagenet-1k-id2label.json"
        expected_shape = (1, 1000)
    else:
        num_labels = 21841
        filename = "imagenet-22k-id2label.json"
        expected_shape = (1, 21841)

    repo_id = "datasets/huggingface/label-files"
    config.num_labels = num_labels
    id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    if "1k" not in checkpoint_url:
        # this dataset contains 21843 labels but the model only has 21841
        # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
        del id2label[9205]
        del id2label[15027]
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    config.hidden_sizes = hidden_sizes
    config.depths = depths

    return config, expected_shape
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000
    expected_shape = (1, num_labels)

    repo_id = "datasets/huggingface/label-files"
    id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    ImageNetPreTrainedConfig = partial(ResNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)

    names_to_config = {
        "resnet18": ImageNetPreTrainedConfig(
            depths=[2, 2, 2, 2], hidden_sizes=[64, 128, 256, 512], layer_type="basic"
        ),
        "resnet26": ImageNetPreTrainedConfig(
            depths=[2, 2, 2, 2], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
        ),
        "resnet34": ImageNetPreTrainedConfig(
            depths=[3, 4, 6, 3], hidden_sizes=[64, 128, 256, 512], layer_type="basic"
        ),
        "resnet50": ImageNetPreTrainedConfig(
            depths=[3, 4, 6, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
        ),
        "resnet101": ImageNetPreTrainedConfig(
            depths=[3, 4, 23, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
        ),
        "resnet152": ImageNetPreTrainedConfig(
            depths=[3, 8, 36, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
        ),
    }

    if model_name:
        convert_weight_and_push(model_name, names_to_config[model_name], save_directory, push_to_hub)
    else:
        for model_name, config in names_to_config.items():
            convert_weight_and_push(model_name, config, save_directory, push_to_hub)

    return config, expected_shape
def test_full_deserialization_hub(self):
    # Check we can read this file.
    # This used to fail because of BufReader that would fail because the
    # file exceeds the buffer capacity
    api = HfApi()

    not_loadable = []
    invalid_pre_tokenizer = []

    # models = api.list_models(filter="transformers")
    # for model in tqdm.tqdm(models):
    #     model_id = model.modelId
    #     for model_file in model.siblings:
    #         filename = model_file.rfilename
    #         if filename == "tokenizer.json":
    #             all_models.append((model_id, filename))

    all_models = [("HueyNemud/das22-10-camembert_pretrained", "tokenizer.json")]

    for (model_id, filename) in tqdm.tqdm(all_models):
        tokenizer_file = cached_download(hf_hub_url(model_id, filename=filename))

        is_ok = check(tokenizer_file)
        if not is_ok:
            print(f"{model_id} is affected by no type")
            invalid_pre_tokenizer.append(model_id)
        try:
            Tokenizer.from_file(tokenizer_file)
        except Exception as e:
            print(f"{model_id} is not loadable: {e}")
            not_loadable.append(model_id)
        except:  # noqa: E722 -- catches non-Exception errors raised from Rust
            print(f"{model_id} is not loadable: Rust error")
            not_loadable.append(model_id)

    self.assertEqual(invalid_pre_tokenizer, [])
    self.assertEqual(not_loadable, [])
def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, architecture="MLM"):
    """
    Copy/paste/tweak model's weights to our Perceiver structure.
    """
    # load parameters as FlatMapping data structure
    with open(pickle_file, "rb") as f:
        checkpoint = pickle.loads(f.read())

    state = None
    if isinstance(checkpoint, dict) and architecture in [
        "image_classification",
        "image_classification_fourier",
        "image_classification_conv",
    ]:
        # the image classification_conv checkpoint also has batchnorm states (running_mean and running_var)
        params = checkpoint["params"]
        state = checkpoint["state"]
    else:
        params = checkpoint

    # turn into initial state dict
    state_dict = dict()
    for scope_name, parameters in hk.data_structures.to_mutable_dict(params).items():
        for param_name, param in parameters.items():
            state_dict[scope_name + "/" + param_name] = param

    if state is not None:
        # add state variables
        for scope_name, parameters in hk.data_structures.to_mutable_dict(state).items():
            for param_name, param in parameters.items():
                state_dict[scope_name + "/" + param_name] = param

    # rename keys
    rename_keys(state_dict, architecture=architecture)

    # load HuggingFace model
    config = PerceiverConfig()
    subsampling = None
    repo_id = "datasets/huggingface/label-files"
    if architecture == "MLM":
        config.qk_channels = 8 * 32
        config.v_channels = 1280
        model = PerceiverForMaskedLM(config)
    elif "image_classification" in architecture:
        config.num_latents = 512
        config.d_latents = 1024
        config.d_model = 512
        config.num_blocks = 8
        config.num_self_attends_per_block = 6
        config.num_cross_attention_heads = 1
        config.num_self_attention_heads = 8
        config.qk_channels = None
        config.v_channels = None
        # set labels
        config.num_labels = 1000
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        if architecture == "image_classification":
            config.image_size = 224
            model = PerceiverForImageClassificationLearned(config)
        elif architecture == "image_classification_fourier":
            config.d_model = 261
            model = PerceiverForImageClassificationFourier(config)
        elif architecture == "image_classification_conv":
            config.d_model = 322
            model = PerceiverForImageClassificationConvProcessing(config)
        else:
            raise ValueError(f"Architecture {architecture} not supported")
    elif architecture == "optical_flow":
        config.num_latents = 2048
        config.d_latents = 512
        config.d_model = 322
        config.num_blocks = 1
        config.num_self_attends_per_block = 24
        config.num_self_attention_heads = 16
        config.num_cross_attention_heads = 1
        model = PerceiverForOpticalFlow(config)
    elif architecture == "multimodal_autoencoding":
        config.num_latents = 28 * 28 * 1
        config.d_latents = 512
        config.d_model = 704
        config.num_blocks = 1
        config.num_self_attends_per_block = 8
        config.num_self_attention_heads = 8
        config.num_cross_attention_heads = 1
        config.num_labels = 700
        # define dummy inputs + subsampling (as each forward pass is only on a chunk of image + audio data)
        images = torch.randn((1, 16, 3, 224, 224))
        audio = torch.randn((1, 30720, 1))
        nchunks = 128
        image_chunk_size = np.prod((16, 224, 224)) // nchunks
        audio_chunk_size = audio.shape[1] // config.samples_per_patch // nchunks
        # process the first chunk
        chunk_idx = 0
        subsampling = {
            "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)),
            "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)),
            "label": None,
        }
        model = PerceiverForMultimodalAutoencoding(config)
        # set labels
        filename = "kinetics700-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
    else:
        raise ValueError(f"Architecture {architecture} not supported")
    model.eval()

    # load weights
    model.load_state_dict(state_dict)

    # prepare dummy input
    input_mask = None
    if architecture == "MLM":
        tokenizer = PerceiverTokenizer.from_pretrained("/Users/NielsRogge/Documents/Perceiver/Tokenizer files")
        text = "This is an incomplete sentence where some words are missing."
        encoding = tokenizer(text, padding="max_length", return_tensors="pt")
        # mask " missing.". Note that the model performs much better if the masked chunk starts with a space.
        encoding.input_ids[0, 51:60] = tokenizer.mask_token_id
        inputs = encoding.input_ids
        input_mask = encoding.attention_mask
    elif architecture in [
        "image_classification",
        "image_classification_fourier",
        "image_classification_conv",
    ]:
        feature_extractor = PerceiverFeatureExtractor()
        image = prepare_img()
        encoding = feature_extractor(image, return_tensors="pt")
        inputs = encoding.pixel_values
    elif architecture == "optical_flow":
        inputs = torch.randn(1, 2, 27, 368, 496)
    elif architecture == "multimodal_autoencoding":
        images = torch.randn((1, 16, 3, 224, 224))
        audio = torch.randn((1, 30720, 1))
        inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))

    # forward pass
    if architecture == "multimodal_autoencoding":
        outputs = model(inputs=inputs, attention_mask=input_mask, subsampled_output_points=subsampling)
    else:
        outputs = model(inputs=inputs, attention_mask=input_mask)
    logits = outputs.logits

    # verify logits
    if not isinstance(logits, dict):
        print("Shape of logits:", logits.shape)
    else:
        for k, v in logits.items():
            print(f"Shape of logits of modality {k}", v.shape)

    if architecture == "MLM":
        expected_slice = torch.tensor(
            [[-11.8336, -11.6850, -11.8483], [-12.8149, -12.5863, -12.7904], [-12.8440, -12.6410, -12.8646]]
        )
        assert torch.allclose(logits[0, :3, :3], expected_slice)
        masked_tokens_predictions = logits[0, 51:60].argmax(dim=-1).tolist()
        expected_list = [38, 115, 111, 121, 121, 111, 116, 109, 52]
        assert masked_tokens_predictions == expected_list
        print("Greedy predictions:")
        print(masked_tokens_predictions)
        print()
        print("Predicted string:")
        print(tokenizer.decode(masked_tokens_predictions))
    elif architecture in [
        "image_classification",
        "image_classification_fourier",
        "image_classification_conv",
    ]:
        print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])

    # Finally, save files
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our BEiT structure.
    """
    # define default BEiT configuration
    config = BeitConfig()
    has_lm_head = False
    is_semantic = False
    repo_id = "datasets/huggingface/label-files"
    # set config parameters based on URL
    if checkpoint_url[-9:-4] == "pt22k":
        # masked image modeling
        config.use_shared_relative_position_bias = True
        config.use_mask_token = True
        has_lm_head = True
    elif checkpoint_url[-9:-4] == "ft22k":
        # intermediate fine-tuning on ImageNet-22k
        config.use_relative_position_bias = True
        config.num_labels = 21841
        filename = "imagenet-22k-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        # this dataset contains 21843 labels but the model only has 21841
        # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
        del id2label[9205]
        del id2label[15027]
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
    elif checkpoint_url[-8:-4] == "to1k":
        # fine-tuning on ImageNet-1k
        config.use_relative_position_bias = True
        config.num_labels = 1000
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        if "384" in checkpoint_url:
            config.image_size = 384
        if "512" in checkpoint_url:
            config.image_size = 512
    elif "ade20k" in checkpoint_url:
        # fine-tuning
        config.use_relative_position_bias = True
        config.num_labels = 150
        filename = "ade20k-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        config.image_size = 640
        is_semantic = True
    else:
        raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")

    # size of the architecture
    if "base" in checkpoint_url:
        pass
    elif "large" in checkpoint_url:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
        if "ade20k" in checkpoint_url:
            config.image_size = 640
            config.out_indices = [7, 11, 15, 23]
    else:
        raise ValueError("Should either find 'base' or 'large' in checkpoint URL")

    # load state_dict of original model, remove and rename some keys
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
    state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]

    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
    if is_semantic:
        # add prefix to decoder keys
        for key, val in state_dict.copy().items():
            val = state_dict.pop(key)
            if key.startswith("backbone.fpn"):
                key = key.replace("backbone.fpn", "fpn")
            state_dict[key] = val

    # load HuggingFace model
    if checkpoint_url[-9:-4] == "pt22k":
        model = BeitForMaskedImageModeling(config)
    elif "ade20k" in checkpoint_url:
        model = BeitForSemanticSegmentation(config)
    else:
        model = BeitForImageClassification(config)
    model.eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image
    if is_semantic:
        feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False)
        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
        image = Image.open(ds[0]["file"])
    else:
        feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
        image = prepare_img()

    encoding = feature_extractor(images=image, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    outputs = model(pixel_values)
    logits = outputs.logits

    # verify logits
    expected_shape = torch.Size([1, 1000])
    if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
        expected_shape = torch.Size([1, 196, 8192])
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
        expected_shape = torch.Size([1, 196, 8192])
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
        expected_shape = torch.Size([1, 21841])
        expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
        expected_class_idx = 2397
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
        expected_shape = torch.Size([1, 21841])
        expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
        expected_class_idx = 2396
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
        expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
        expected_class_idx = 285
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
        expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
        expected_class_idx = 281
    elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
        expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
        expected_class_idx = 761
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
        expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
        expected_class_idx = 761
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
        expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
        expected_class_idx = 761
    elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
        expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
        expected_class_idx = 761
    elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
        expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
        expected_class_idx = 761
    elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
        expected_shape = (1, 150, 160, 160)
        expected_logits = torch.tensor(
            [
                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
            ]
        )
    elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
        expected_shape = (1, 150, 160, 160)
        expected_logits = torch.tensor(
            [
                [[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
                [[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
                [[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
            ]
        )
    else:
        raise ValueError("Can't verify logits as model is not supported")

    assert logits.shape == expected_shape, "Shape of logits not as expected"
    if not has_lm_head:
        if is_semantic:
            assert torch.allclose(
                logits[0, :3, :3, :3], expected_logits, atol=1e-3
            ), "First elements of logits not as expected"
        else:
            print("Predicted class idx:", logits.argmax(-1).item())
            assert torch.allclose(
                logits[0, :3], expected_logits, atol=1e-3
            ), "First elements of logits not as expected"
            assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
def convert_deit_checkpoint(deit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our DeiT structure.
    """
    # define default DeiT configuration
    config = DeiTConfig()
    # all deit models have fine-tuned heads
    base_model = False
    # dataset (fine-tuned on ImageNet 2012), patch_size and image_size
    config.num_labels = 1000
    repo_id = "datasets/huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    config.patch_size = int(deit_name[-6:-4])
    config.image_size = int(deit_name[-3:])
    # size of the architecture
    if deit_name[9:].startswith("tiny"):
        config.hidden_size = 192
        config.intermediate_size = 768
        config.num_hidden_layers = 12
        config.num_attention_heads = 3
    elif deit_name[9:].startswith("small"):
        config.hidden_size = 384
        config.intermediate_size = 1536
        config.num_hidden_layers = 12
        config.num_attention_heads = 6
    if deit_name[9:].startswith("base"):
        pass
    elif deit_name[4:].startswith("large"):
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16

    # load original model from timm
    timm_model = timm.create_model(deit_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = timm_model.state_dict()
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    model = DeiTForImageClassificationWithTeacher(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by DeiTFeatureExtractor
    size = int(
        (256 / 224) * config.image_size
    )  # to maintain same ratio w.r.t. 224 images, see https://github.com/facebookresearch/deit/blob/ab5715372db8c6cad5740714b2216d55aeae052e/datasets.py#L103
    feature_extractor = DeiTFeatureExtractor(size=size, crop_size=config.image_size)
    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values)

    timm_logits = timm_model(pixel_values)
    assert timm_logits.shape == outputs.logits.shape
    assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model {deit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our PoolFormer structure.
    """
    # load default PoolFormer configuration
    config = PoolFormerConfig()

    # set attributes based on model_name
    repo_id = "datasets/huggingface/label-files"
    size = model_name[-3:]
    config.num_labels = 1000
    filename = "imagenet-1k-id2label.json"
    expected_shape = (1, 1000)

    # set config attributes
    id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    if size == "s12":
        config.depths = [2, 2, 6, 2]
        config.hidden_sizes = [64, 128, 320, 512]
        config.mlp_ratio = 4.0
        crop_pct = 0.9
    elif size == "s24":
        config.depths = [4, 4, 12, 4]
        config.hidden_sizes = [64, 128, 320, 512]
        config.mlp_ratio = 4.0
        crop_pct = 0.9
    elif size == "s36":
        config.depths = [6, 6, 18, 6]
        config.hidden_sizes = [64, 128, 320, 512]
        config.mlp_ratio = 4.0
        config.layer_scale_init_value = 1e-6
        crop_pct = 0.9
    elif size == "m36":
        config.depths = [6, 6, 18, 6]
        config.hidden_sizes = [96, 192, 384, 768]
        config.mlp_ratio = 4.0
        config.layer_scale_init_value = 1e-6
        crop_pct = 0.95
    elif size == "m48":
        config.depths = [8, 8, 24, 8]
        config.hidden_sizes = [96, 192, 384, 768]
        config.mlp_ratio = 4.0
        config.layer_scale_init_value = 1e-6
        crop_pct = 0.95
    else:
        raise ValueError(f"Size {size} not supported")

    # load feature extractor
    feature_extractor = PoolFormerFeatureExtractor(crop_pct=crop_pct)

    # Prepare image
    image = prepare_img()
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    logger.info(f"Converting model {model_name}...")

    # load original state dict
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # rename keys
    state_dict = rename_keys(state_dict)

    # create HuggingFace model and load state dict
    model = PoolFormerForImageClassification(config)
    model.load_state_dict(state_dict)
    model.eval()

    # Define feature extractor
    feature_extractor = PoolFormerFeatureExtractor(crop_pct=crop_pct)
    pixel_values = feature_extractor(images=prepare_img(), return_tensors="pt").pixel_values

    # forward pass
    outputs = model(pixel_values)
    logits = outputs.logits

    # define expected logit slices for different models
    if size == "s12":
        expected_slice = torch.tensor([-0.3045, -0.6758, -0.4869])
    elif size == "s24":
        expected_slice = torch.tensor([0.4402, -0.1374, -0.8045])
    elif size == "s36":
        expected_slice = torch.tensor([-0.6080, -0.5133, -0.5898])
    elif size == "m36":
        expected_slice = torch.tensor([0.3952, 0.2263, -1.2668])
    elif size == "m48":
        expected_slice = torch.tensor([0.1167, -0.0656, -0.3423])
    else:
        raise ValueError(f"Size {size} not supported")

    # verify logits
    assert logits.shape == expected_shape
    assert torch.allclose(logits[0, :3], expected_slice, atol=1e-2)

    # finally, save model and feature extractor
    logger.info(f"Saving PyTorch model and feature extractor to {pytorch_dump_folder_path}...")
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
def _download_from_hf(model_id: str, filename: str):
    hf_model_id, hf_revision = hf_split(model_id)
    url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
    return cached_download(url, cache_dir=get_cache_dir('hf'))
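# Usage sketch: hf_split (defined elsewhere in timm) strips an optional
# '@revision' suffix from the model id; the repo id here is illustrative.
path = _download_from_hf("julien-c/timm-dpn92@main", "pytorch_model.bin")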
def fetch(
    filename,
    source,
    savedir="./pretrained_model_checkpoints",
    overwrite=False,
    save_filename=None,
    use_auth_token=False,
):
    """Ensures you have a local copy of the file, returns its path

    In case the source is an external location, downloads the file.  In case
    the source is already accessible on the filesystem, creates a symlink in
    the savedir. Thus, the side effects of this function always look similar:
    savedir/save_filename can be used to access the file. And save_filename
    defaults to the filename arg.

    Arguments
    ---------
    filename : str
        Name of the file including extensions.
    source : str
        Where to look for the file. This is interpreted in special ways:
        First, if the source begins with "http://" or "https://", it is
        interpreted as a web address and the file is downloaded.
        Second, if the source is a valid directory path, a symlink is
        created to the file.
        Otherwise, the source is interpreted as a Huggingface model hub ID, and
        the file is downloaded from there.
    savedir : str
        Path where to save downloads/symlinks.
    overwrite : bool
        If True, always overwrite existing savedir/filename file and download
        or recreate the link. If False (as by default), if savedir/filename
        exists, assume it is correct and don't download/relink. Note that
        Huggingface local cache is always used - with overwrite=True we just
        relink from the local cache.
    save_filename : str
        The filename to use for saving this file. Defaults to filename if not
        given.
    use_auth_token : bool (default: False)
        If true Huggingface's auth_token will be used to load private models
        from the HuggingFace Hub, default is False because the majority of
        models are public.

    Returns
    -------
    pathlib.Path
        Path to file on local file system.
    """
    if save_filename is None:
        save_filename = filename
    savedir = pathlib.Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)
    sourcefile = f"{source}/{filename}"
    destination = savedir / save_filename
    if destination.exists() and not overwrite:
        MSG = f"Fetch {filename}: Using existing file/symlink in {str(destination)}."
        logger.info(MSG)
        return destination
    if str(source).startswith("http:") or str(source).startswith("https:"):
        # Interpret source as web address.
        MSG = f"Fetch {filename}: Downloading from normal URL {str(sourcefile)}."
        logger.info(MSG)
        # Download
        try:
            urllib.request.urlretrieve(sourcefile, destination)
        except urllib.error.URLError:
            raise ValueError(f"Interpreted {source} as web address, but could not download.")
    elif pathlib.Path(source).is_dir():
        # Interpret source as local directory path
        # Just symlink
        sourcepath = pathlib.Path(sourcefile).absolute()
        MSG = f"Fetch {filename}: Linking to local file in {str(sourcepath)}."
        logger.info(MSG)
        _missing_ok_unlink(destination)
        destination.symlink_to(sourcepath)
    else:
        # Interpret source as huggingface hub ID
        # Use huggingface hub's fancy cached download.
        MSG = f"Fetch {filename}: Delegating to Huggingface hub, source {str(source)}."
        logger.info(MSG)
        url = huggingface_hub.hf_hub_url(source, filename)
        # Pass the token by keyword: cached_download's second positional
        # parameter is library_name, not use_auth_token.
        fetched_file = huggingface_hub.cached_download(url, use_auth_token=use_auth_token)
        # Huggingface hub downloads to etag filename, symlink to the expected one:
        sourcepath = pathlib.Path(fetched_file).absolute()
        _missing_ok_unlink(destination)
        destination.symlink_to(sourcepath)
    return destination
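# fetch above resolves three kinds of sources; a hedged sketch of each (the
# model id, URL and local dir are illustrative assumptions):
p1 = fetch("hyperparams.yaml", source="speechbrain/asr-crdnn-rnnlm-librispeech")  # HF Hub id
p2 = fetch("model.ckpt", source="https://example.com/checkpoints")                # web address
p3 = fetch("model.ckpt", source="local/pretrained")                               # local dir -> symlink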
def __init__(self, model_id: str):
    self.model = joblib.load(open(cached_download(hf_hub_url(model_id, DEFAULT_FILENAME)), "rb"))
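# Hypothetical instantiation of the wrapper above, assuming DEFAULT_FILENAME
# names a joblib artifact (e.g. "sklearn_model.joblib") in the repo; the class
# name and repo id are illustrative.
reg = SklearnPredictor("julien-c/wine-quality")
predictions = reg.model.predict(X)  # X: your feature matrix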
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000
    repo_id = "datasets/huggingface/label-files"

    id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)

    names_to_config = {
        "regnet-y-10b-seer": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
        # finetuned on imagenet
        "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
    }

    # add seer weights logic
    def load_using_classy_vision(checkpoint_url: str) -> Tuple[Dict, Dict]:
        files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
        # check if we have a head, if yes add it
        model_state_dict = files["classy_state_dict"]["base_model"]["model"]
        return model_state_dict["trunk"], model_state_dict["heads"]

    names_to_from_model = {
        "regnet-y-10b-seer": partial(
            load_using_classy_vision,
            "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
        ),
        "regnet-y-10b-seer-in1k": partial(
            load_using_classy_vision,
            "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
        ),
    }

    from_to_ours_keys = get_from_to_our_keys(model_name)

    if not (save_directory / f"{model_name}.pth").exists():
        logger.info("Loading original state_dict.")
        from_state_dict_trunk, from_state_dict_head = names_to_from_model[model_name]()
        from_state_dict = from_state_dict_trunk
        if "in1k" in model_name:
            # add the head
            from_state_dict = {**from_state_dict_trunk, **from_state_dict_head}
        logger.info("Done!")

        converted_state_dict = {}

        not_used_keys = list(from_state_dict.keys())
        regex = r"\.block.-part."
        # this is "interesting", so the original checkpoints have `block[0,1]-part` in each key name, we remove it
        for key in from_state_dict.keys():
            # remove the weird "block[0,1]-part" from the key
            src_key = re.sub(regex, "", key)
            # now src_key from the model checkpoints is the one we got from the original model after tracing, so use it to get the correct destination key
            dest_key = from_to_ours_keys[src_key]
            # store the parameter with our key
            converted_state_dict[dest_key] = from_state_dict[key]
            not_used_keys.remove(key)
        # check that all keys have been updated
        assert len(not_used_keys) == 0, f"Some keys were not used {','.join(not_used_keys)}"

        logger.info(f"The following keys were not used: {','.join(not_used_keys)}")

        # save our state dict to disk
        torch.save(converted_state_dict, save_directory / f"{model_name}.pth")

        del converted_state_dict
    else:
        logger.info("The state_dict was already stored on disk.")

    if push_to_hub:
        logger.info(f"Token is {os.environ['HF_TOKEN']}")
        logger.info("Loading our model.")
        # create our model
        our_config = names_to_config[model_name]
        our_model_func = RegNetModel
        if "in1k" in model_name:
            our_model_func = RegNetForImageClassification
        our_model = our_model_func(our_config)
        # place our model to the meta device (so remove all the weights)
        our_model.to(torch.device("meta"))
        logger.info("Loading state_dict in our model.")
        # load state dict
        state_dict_keys = our_model.state_dict().keys()
        PreTrainedModel._load_pretrained_model_low_mem(
            our_model, state_dict_keys, [save_directory / f"{model_name}.pth"]
        )
        logger.info("Finally, pushing!")
        # push it to hub
        our_model.push_to_hub(
            repo_path_or_name=save_directory / model_name,
            commit_message="Add model",
            output_dir=save_directory / model_name,
        )
        size = 384
        # we can use the convnext one
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size)
        feature_extractor.push_to_hub(
            repo_path_or_name=save_directory / model_name,
            commit_message="Add feature extractor",
            output_dir=save_directory / model_name,
        )
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our DETR structure.
    """
    # load default config
    config = DetrConfig()
    # set backbone and dilation attributes
    if "resnet101" in model_name:
        config.backbone = "resnet101"
    if "dc5" in model_name:
        config.dilation = True
    is_panoptic = "panoptic" in model_name
    if is_panoptic:
        config.num_labels = 250
    else:
        config.num_labels = 91
        repo_id = "datasets/huggingface/label-files"
        filename = "coco-detection-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}

    # load feature extractor
    format = "coco_panoptic" if is_panoptic else "coco_detection"
    feature_extractor = DetrFeatureExtractor(format=format)

    # prepare image
    img = prepare_img()
    encoding = feature_extractor(images=img, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    logger.info(f"Converting model {model_name}...")

    # load original model from torch hub
    detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
    state_dict = detr.state_dict()
    # rename keys
    for src, dest in rename_keys:
        if is_panoptic:
            src = "detr." + src
        rename_key(state_dict, src, dest)
    state_dict = rename_backbone_keys(state_dict)
    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
    prefix = "detr.model." if is_panoptic else "model."
    for key in state_dict.copy().keys():
        if is_panoptic:
            if (
                key.startswith("detr")
                and not key.startswith("class_labels_classifier")
                and not key.startswith("bbox_predictor")
            ):
                val = state_dict.pop(key)
                state_dict["detr.model" + key[4:]] = val
            elif "class_labels_classifier" in key or "bbox_predictor" in key:
                val = state_dict.pop(key)
                state_dict["detr." + key] = val
            elif key.startswith("bbox_attention") or key.startswith("mask_head"):
                continue
            else:
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
        else:
            if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
    # finally, create HuggingFace model and load state dict
    model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()
    # verify our conversion
    original_outputs = detr(pixel_values)
    outputs = model(pixel_values)
    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
    if is_panoptic:
        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)

    # Save model and feature extractor
    logger.info(f"Saving PyTorch model and feature extractor to {pytorch_dump_folder_path}...")
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000
    expected_shape = (1, num_labels)

    repo_id = "datasets/huggingface/label-files"
    id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)

    names_to_config = {
        "regnet-x-002": ImageNetPreTrainedConfig(
            depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8, layer_type="x"
        ),
        "regnet-x-004": ImageNetPreTrainedConfig(
            depths=[1, 2, 7, 12], hidden_sizes=[32, 64, 160, 384], groups_width=16, layer_type="x"
        ),
        "regnet-x-006": ImageNetPreTrainedConfig(
            depths=[1, 3, 5, 7], hidden_sizes=[48, 96, 240, 528], groups_width=24, layer_type="x"
        ),
        "regnet-x-008": ImageNetPreTrainedConfig(
            depths=[1, 3, 7, 5], hidden_sizes=[64, 128, 288, 672], groups_width=16, layer_type="x"
        ),
        "regnet-x-016": ImageNetPreTrainedConfig(
            depths=[2, 4, 10, 2], hidden_sizes=[72, 168, 408, 912], groups_width=24, layer_type="x"
        ),
        "regnet-x-032": ImageNetPreTrainedConfig(
            depths=[2, 6, 15, 2], hidden_sizes=[96, 192, 432, 1008], groups_width=48, layer_type="x"
        ),
        "regnet-x-040": ImageNetPreTrainedConfig(
            depths=[2, 5, 14, 2], hidden_sizes=[80, 240, 560, 1360], groups_width=40, layer_type="x"
        ),
        "regnet-x-064": ImageNetPreTrainedConfig(
            depths=[2, 4, 10, 1], hidden_sizes=[168, 392, 784, 1624], groups_width=56, layer_type="x"
        ),
        "regnet-x-080": ImageNetPreTrainedConfig(
            depths=[2, 5, 15, 1], hidden_sizes=[80, 240, 720, 1920], groups_width=120, layer_type="x"
        ),
        "regnet-x-120": ImageNetPreTrainedConfig(
            depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112, layer_type="x"
        ),
        "regnet-x-160": ImageNetPreTrainedConfig(
            depths=[2, 6, 13, 1], hidden_sizes=[256, 512, 896, 2048], groups_width=128, layer_type="x"
        ),
        "regnet-x-320": ImageNetPreTrainedConfig(
            depths=[2, 7, 13, 1], hidden_sizes=[336, 672, 1344, 2520], groups_width=168, layer_type="x"
        ),
        # y variant
        "regnet-y-002": ImageNetPreTrainedConfig(
            depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8
        ),
        "regnet-y-004": ImageNetPreTrainedConfig(
            depths=[1, 3, 6, 6], hidden_sizes=[48, 104, 208, 440], groups_width=8
        ),
        "regnet-y-006": ImageNetPreTrainedConfig(
            depths=[1, 3, 7, 4], hidden_sizes=[48, 112, 256, 608], groups_width=16
        ),
        "regnet-y-008": ImageNetPreTrainedConfig(
            depths=[1, 3, 8, 2], hidden_sizes=[64, 128, 320, 768], groups_width=16
        ),
        "regnet-y-016": ImageNetPreTrainedConfig(
            depths=[2, 6, 17, 2], hidden_sizes=[48, 120, 336, 888], groups_width=24
        ),
        "regnet-y-032": ImageNetPreTrainedConfig(
            depths=[2, 5, 13, 1], hidden_sizes=[72, 216, 576, 1512], groups_width=24
        ),
        "regnet-y-040": ImageNetPreTrainedConfig(
            depths=[2, 6, 12, 2], hidden_sizes=[128, 192, 512, 1088], groups_width=64
        ),
        "regnet-y-064": ImageNetPreTrainedConfig(
            depths=[2, 7, 14, 2], hidden_sizes=[144, 288, 576, 1296], groups_width=72
        ),
        "regnet-y-080": ImageNetPreTrainedConfig(
            depths=[2, 4, 10, 1], hidden_sizes=[168, 448, 896, 2016], groups_width=56
        ),
        "regnet-y-120": ImageNetPreTrainedConfig(
            depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112
        ),
        "regnet-y-160": ImageNetPreTrainedConfig(
            depths=[2, 4, 11, 1], hidden_sizes=[224, 448, 1232, 3024], groups_width=112
        ),
        "regnet-y-320": ImageNetPreTrainedConfig(
            depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232
        ),
        # models created by SEER -> https://arxiv.org/abs/2202.08360
        "regnet-y-320-seer": RegNetConfig(
            depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232
        ),
        "regnet-y-640-seer": RegNetConfig(
            depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328
        ),
        "regnet-y-1280-seer": RegNetConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264
        ),
        "regnet-y-2560-seer": RegNetConfig(
            depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640
        ),
        "regnet-y-10b-seer": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
        # finetuned on imagenet
        "regnet-y-320-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232
        ),
        "regnet-y-640-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328
        ),
        "regnet-y-1280-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264
        ),
        "regnet-y-2560-seer-in1k": ImageNetPreTrainedConfig(
            depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640
        ),
        "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
    }

    names_to_ours_model_map = NameToOurModelFuncMap()
    names_to_from_model_map = NameToFromModelFuncMap()

    # add seer weights logic
    def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Module]) -> Tuple[nn.Module, Dict]:
        files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
        model = model_func()
        # check if we have a head, if yes add it
        model_state_dict = files["classy_state_dict"]["base_model"]["model"]
        state_dict = model_state_dict["trunk"]
        model.load_state_dict(state_dict)
        return model.eval(), model_state_dict["heads"]

    # pretrained
    names_to_from_model_map["regnet-y-320-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
    )
    names_to_from_model_map["regnet-y-640-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
    )
    names_to_from_model_map["regnet-y-1280-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
    )
    names_to_from_model_map["regnet-y-10b-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
        lambda: FakeRegNetVisslWrapper(
            RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
        ),
    )

    # IN1K finetuned
    names_to_from_model_map["regnet-y-320-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
    )
    names_to_from_model_map["regnet-y-640-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
    )
    names_to_from_model_map["regnet-y-1280-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
    )
    names_to_from_model_map["regnet-y-10b-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
        lambda: FakeRegNetVisslWrapper(
            RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
        ),
    )

    if model_name:
        convert_weight_and_push(
            model_name,
            names_to_from_model_map[model_name],
            names_to_ours_model_map[model_name],
            names_to_config[model_name],
            save_directory,
            push_to_hub,
        )
    else:
        for model_name, config in names_to_config.items():
            convert_weight_and_push(
                model_name,
                names_to_from_model_map[model_name],
                names_to_ours_model_map[model_name],
                config,
                save_directory,
                push_to_hub,
            )
    return config, expected_shape
def create_markdown_model_card(model_id: str):
    """Creates rich Markdown wrapper for hub readme"""
    readme_url = hf_hub_url(model_id, filename="README.md")
    r = requests.get(readme_url)
    r.raise_for_status()
    return Markdown(r.text)
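# Hypothetical usage of create_markdown_model_card (the repo id is illustrative;
# requests and IPython.display.Markdown are assumed to be the imports behind it):
#
#   card = create_markdown_model_card("gpt2")
#   display(card)  # renders the fetched README.md in a notebook cell
#
# Note that r.raise_for_status() makes a missing README surface as an HTTPError
# instead of silently wrapping an error page in Markdown.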
def _fetch_model(model_name) -> str:
    # core Flair models on Huggingface ModelHub
    huggingface_model_map = {
        "ner": "flair/ner-english",
        "ner-fast": "flair/ner-english-fast",
        "ner-ontonotes": "flair/ner-english-ontonotes",
        "ner-ontonotes-fast": "flair/ner-english-ontonotes-fast",
        # Large NER models,
        "ner-large": "flair/ner-english-large",
        "ner-ontonotes-large": "flair/ner-english-ontonotes-large",
        "de-ner-large": "flair/ner-german-large",
        "nl-ner-large": "flair/ner-dutch-large",
        "es-ner-large": "flair/ner-spanish-large",
        # Multilingual NER models
        "ner-multi": "flair/ner-multi",
        "multi-ner": "flair/ner-multi",
        "ner-multi-fast": "flair/ner-multi-fast",
        # English POS models
        "upos": "flair/upos-english",
        "upos-fast": "flair/upos-english-fast",
        "pos": "flair/pos-english",
        "pos-fast": "flair/pos-english-fast",
        # Multilingual POS models
        "pos-multi": "flair/upos-multi",
        "multi-pos": "flair/upos-multi",
        "pos-multi-fast": "flair/upos-multi-fast",
        "multi-pos-fast": "flair/upos-multi-fast",
        # English SRL models
        "frame": "flair/frame-english",
        "frame-fast": "flair/frame-english-fast",
        # English chunking models
        "chunk": "flair/chunk-english",
        "chunk-fast": "flair/chunk-english-fast",
        # Language-specific NER models
        "da-ner": "flair/ner-danish",
        "de-ner": "flair/ner-german",
        "de-ler": "flair/ner-german-legal",
        "de-ner-legal": "flair/ner-german-legal",
        "fr-ner": "flair/ner-french",
        "nl-ner": "flair/ner-dutch",
    }

    hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

    hu_model_map = {
        # English NER models
        "ner": "/".join([hu_path, "ner", "en-ner-conll03-v0.4.pt"]),
        "ner-pooled": "/".join([hu_path, "ner-pooled", "en-ner-conll03-pooled-v0.5.pt"]),
        "ner-fast": "/".join([hu_path, "ner-fast", "en-ner-fast-conll03-v0.4.pt"]),
        "ner-ontonotes": "/".join([hu_path, "ner-ontonotes", "en-ner-ontonotes-v0.4.pt"]),
        "ner-ontonotes-fast": "/".join([hu_path, "ner-ontonotes-fast", "en-ner-ontonotes-fast-v0.4.pt"]),
        # Multilingual NER models
        "ner-multi": "/".join([hu_path, "multi-ner", "quadner-large.pt"]),
        "multi-ner": "/".join([hu_path, "multi-ner", "quadner-large.pt"]),
        "ner-multi-fast": "/".join([hu_path, "multi-ner-fast", "ner-multi-fast.pt"]),
        # English POS models
        "upos": "/".join([hu_path, "upos", "en-pos-ontonotes-v0.4.pt"]),
        "upos-fast": "/".join([hu_path, "upos-fast", "en-upos-ontonotes-fast-v0.4.pt"]),
        "pos": "/".join([hu_path, "pos", "en-pos-ontonotes-v0.5.pt"]),
        "pos-fast": "/".join([hu_path, "pos-fast", "en-pos-ontonotes-fast-v0.5.pt"]),
        # Multilingual POS models
        "pos-multi": "/".join([hu_path, "multi-pos", "pos-multi-v0.1.pt"]),
        "multi-pos": "/".join([hu_path, "multi-pos", "pos-multi-v0.1.pt"]),
        "pos-multi-fast": "/".join([hu_path, "multi-pos-fast", "pos-multi-fast.pt"]),
        "multi-pos-fast": "/".join([hu_path, "multi-pos-fast", "pos-multi-fast.pt"]),
        # English SRL models
        "frame": "/".join([hu_path, "frame", "en-frame-ontonotes-v0.4.pt"]),
        "frame-fast": "/".join([hu_path, "frame-fast", "en-frame-ontonotes-fast-v0.4.pt"]),
        # English chunking models
        "chunk": "/".join([hu_path, "chunk", "en-chunk-conll2000-v0.4.pt"]),
        "chunk-fast": "/".join([hu_path, "chunk-fast", "en-chunk-conll2000-fast-v0.4.pt"]),
        # Danish models
        "da-pos": "/".join([hu_path, "da-pos", "da-pos-v0.1.pt"]),
        "da-ner": "/".join([hu_path, "NER-danish", "da-ner-v0.1.pt"]),
        # German models
        "de-pos": "/".join([hu_path, "de-pos", "de-pos-ud-hdt-v0.5.pt"]),
        "de-pos-tweets": "/".join([hu_path, "de-pos-tweets", "de-pos-twitter-v0.1.pt"]),
        "de-ner": "/".join([hu_path, "de-ner", "de-ner-conll03-v0.4.pt"]),
        "de-ner-germeval": "/".join([hu_path, "de-ner-germeval", "de-ner-germeval-0.4.1.pt"]),
        "de-ler": "/".join([hu_path, "de-ner-legal", "de-ner-legal.pt"]),
        "de-ner-legal": "/".join([hu_path, "de-ner-legal", "de-ner-legal.pt"]),
        # French models
        "fr-ner": "/".join([hu_path, "fr-ner", "fr-ner-wikiner-0.4.pt"]),
        # Dutch models
        "nl-ner": "/".join([hu_path, "nl-ner", "nl-ner-bert-conll02-v0.8.pt"]),
        "nl-ner-rnn": "/".join([hu_path, "nl-ner-rnn", "nl-ner-conll02-v0.5.pt"]),
        # Malayalam models
        "ml-pos": "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-xpos-model.pt",
        "ml-upos": "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-upos-model.pt",
        # Portuguese models
        "pt-pos-clinical": "/".join([hu_path, "pt-pos-clinical", "pucpr-flair-clinical-pos-tagging-best-model.pt"]),
        # Keyphrase models
        "keyphrase": "/".join([hu_path, "keyphrase", "keyphrase-en-scibert.pt"]),
        "negation-speculation": "/".join([hu_path, "negation-speculation", "negation-speculation-model.pt"]),
        # Biomedical models
        "hunflair-paper-cellline": "/".join([hu_path, "hunflair_smallish_models", "cellline", "hunflair-celline-v1.0.pt"]),
        "hunflair-paper-chemical": "/".join([hu_path, "hunflair_smallish_models", "chemical", "hunflair-chemical-v1.0.pt"]),
        "hunflair-paper-disease": "/".join([hu_path, "hunflair_smallish_models", "disease", "hunflair-disease-v1.0.pt"]),
        "hunflair-paper-gene": "/".join([hu_path, "hunflair_smallish_models", "gene", "hunflair-gene-v1.0.pt"]),
        "hunflair-paper-species": "/".join([hu_path, "hunflair_smallish_models", "species", "hunflair-species-v1.0.pt"]),
        "hunflair-cellline": "/".join([hu_path, "hunflair_smallish_models", "cellline", "hunflair-celline-v1.0.pt"]),
        "hunflair-chemical": "/".join([hu_path, "hunflair_allcorpus_models", "huner-chemical", "hunflair-chemical-full-v1.0.pt"]),
        "hunflair-disease": "/".join([hu_path, "hunflair_allcorpus_models", "huner-disease", "hunflair-disease-full-v1.0.pt"]),
        "hunflair-gene": "/".join([hu_path, "hunflair_allcorpus_models", "huner-gene", "hunflair-gene-full-v1.0.pt"]),
        "hunflair-species": "/".join([hu_path, "hunflair_allcorpus_models", "huner-species", "hunflair-species-full-v1.1.pt"]),
    }

    cache_dir = Path("models")

    get_from_model_hub = False

    # check if model name is a valid local file
    if Path(model_name).exists():
        model_path = model_name

    # check if model key is remapped to HF key - if so, print out information
    elif model_name in huggingface_model_map:
        # get mapped name
        hf_model_name = huggingface_model_map[model_name]
        # use mapped name instead
        model_name = hf_model_name
        get_from_model_hub = True

    # if not, check if model key is remapped to a direct download location; if so, download the model
    elif model_name in hu_model_map:
        model_path = cached_path(hu_model_map[model_name], cache_dir=cache_dir)

    # special handling for the taggers by the @redewiedergabe project (TODO: move to model hub)
    elif model_name == "de-historic-indirect":
        model_file = flair.cache_root / cache_dir / "indirect" / "final-model.pt"
        if not model_file.exists():
            cached_path("http://www.redewiedergabe.de/models/indirect.zip", cache_dir=cache_dir)
            unzip_file(flair.cache_root / cache_dir / "indirect.zip", flair.cache_root / cache_dir)
        model_path = str(flair.cache_root / cache_dir / "indirect" / "final-model.pt")

    elif model_name == "de-historic-direct":
        model_file = flair.cache_root / cache_dir / "direct" / "final-model.pt"
        if not model_file.exists():
            cached_path("http://www.redewiedergabe.de/models/direct.zip", cache_dir=cache_dir)
            unzip_file(flair.cache_root / cache_dir / "direct.zip", flair.cache_root / cache_dir)
        model_path = str(flair.cache_root / cache_dir / "direct" / "final-model.pt")

    elif model_name == "de-historic-reported":
        model_file = flair.cache_root / cache_dir / "reported" / "final-model.pt"
        if not model_file.exists():
            cached_path("http://www.redewiedergabe.de/models/reported.zip", cache_dir=cache_dir)
            unzip_file(flair.cache_root / cache_dir / "reported.zip", flair.cache_root / cache_dir)
        model_path = str(flair.cache_root / cache_dir / "reported" / "final-model.pt")

    elif model_name == "de-historic-free-indirect":
        model_file = flair.cache_root / cache_dir / "freeIndirect" / "final-model.pt"
        if not model_file.exists():
            cached_path("http://www.redewiedergabe.de/models/freeIndirect.zip", cache_dir=cache_dir)
            unzip_file(flair.cache_root / cache_dir / "freeIndirect.zip", flair.cache_root / cache_dir)
        model_path = str(flair.cache_root / cache_dir / "freeIndirect" / "final-model.pt")

    # for all other cases (not local file or special download location), use HF model hub
    else:
        get_from_model_hub = True

    # if not a local file, get from model hub
    if get_from_model_hub:
        hf_model_name = "pytorch_model.bin"
        revision = "main"

        if "@" in model_name:
            model_name_split = model_name.split("@")
            revision = model_name_split[-1]
            model_name = model_name_split[0]

        # use model name as subfolder
        if "/" in model_name:
            model_folder = model_name.split("/", maxsplit=1)[1]
        else:
            model_folder = model_name

        # Lazy import
        from huggingface_hub import cached_download, hf_hub_url

        url = hf_hub_url(model_name, revision=revision, filename=hf_model_name)

        try:
            model_path = cached_download(
                url=url,
                library_name="flair",
                library_version=flair.__version__,
                cache_dir=flair.cache_root / "models" / model_folder,
            )
        except HTTPError:
            # output information
            log.error("-" * 80)
            log.error(
                f"ACHTUNG: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
            )
            # log.error(f" - Error message: {e}")
            log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
            log.error(" -> Alternatively, point to a model file on your local drive.")
            log.error("-" * 80)
            Path(flair.cache_root / "models" / model_folder).rmdir()  # remove folder again if not valid

    return model_path
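# Hedged usage notes for _fetch_model, inferred from the code above (not from
# library docs); the keys and paths are illustrative:
#
#   _fetch_model("ner")                       # remapped to "flair/ner-english", fetched from the Hub
#   _fetch_model("flair/ner-english@main")    # "@" suffix selects a specific Hub revision
#   _fetch_model("ml-pos")                    # direct-download key, cached via cached_path
#   _fetch_model("/path/to/final-model.pt")   # existing local files are returned as-is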
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our ViT structure.
    """
    # define default ViT configuration
    config = ViTConfig()
    base_model = False
    # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size
    if vit_name[-5:] == "in21k":
        base_model = True
        config.patch_size = int(vit_name[-12:-10])
        config.image_size = int(vit_name[-9:-6])
    else:
        config.num_labels = 1000
        repo_id = "datasets/huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        config.patch_size = int(vit_name[-6:-4])
        config.image_size = int(vit_name[-3:])
    # size of the architecture
    if "deit" in vit_name:
        if vit_name[9:].startswith("tiny"):
            config.hidden_size = 192
            config.intermediate_size = 768
            config.num_hidden_layers = 12
            config.num_attention_heads = 3
        elif vit_name[9:].startswith("small"):
            config.hidden_size = 384
            config.intermediate_size = 1536
            config.num_hidden_layers = 12
            config.num_attention_heads = 6
        else:
            pass
    else:
        if vit_name[4:].startswith("small"):
            config.hidden_size = 768
            config.intermediate_size = 2304
            config.num_hidden_layers = 8
            config.num_attention_heads = 8
        elif vit_name[4:].startswith("base"):
            pass
        elif vit_name[4:].startswith("large"):
            config.hidden_size = 1024
            config.intermediate_size = 4096
            config.num_hidden_layers = 24
            config.num_attention_heads = 16
        elif vit_name[4:].startswith("huge"):
            config.hidden_size = 1280
            config.intermediate_size = 5120
            config.num_hidden_layers = 32
            config.num_attention_heads = 16

    # load original model from timm
    timm_model = timm.create_model(vit_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = timm_model.state_dict()
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    if vit_name[-5:] == "in21k":
        model = ViTModel(config).eval()
    else:
        model = ViTForImageClassification(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by ViTFeatureExtractor/DeiTFeatureExtractor
    if "deit" in vit_name:
        feature_extractor = DeiTFeatureExtractor(size=config.image_size)
    else:
        feature_extractor = ViTFeatureExtractor(size=config.image_size)
    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values)

    if base_model:
        timm_pooled_output = timm_model.forward_features(pixel_values)
        assert timm_pooled_output.shape == outputs.pooler_output.shape
        assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
    else:
        timm_logits = timm_model(pixel_values)
        assert timm_logits.shape == outputs.logits.shape
        assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
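# Hedged usage sketch for convert_vit_checkpoint: the timm model name encodes
# the variant, patch size and resolution that the string slicing above relies
# on, e.g. "vit_base_patch16_224" -> patch_size 16, image_size 224. The dump
# folder below is illustrative:
#
#   convert_vit_checkpoint("vit_base_patch16_224", "./vit-base-patch16-224")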
from rudalle.pipelines import generate_images, show
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.utils import seed_everything
from huggingface_hub import hf_hub_url, cached_download
import os
import torch

has_gpu = torch.cuda.is_available()
fp16 = has_gpu

# load models
model_filename = "pytorch_model.bin"
device = "cuda" if has_gpu else "cpu"
cache_dir = os.getenv("cache_dir", "../../models")

config_file_url = hf_hub_url(repo_id="minimaxir/ai-generated-pokemon-rudalle", filename=model_filename)
cached_download(config_file_url, cache_dir=cache_dir, force_filename=model_filename)

model = get_rudalle_model("Malevich", cache_dir=cache_dir, pretrained=False, fp16=fp16, device=device)
model.load_state_dict(torch.load(os.path.join(cache_dir, model_filename), map_location="cpu"))
vae = get_vae().to(device)
tokenizer = get_tokenizer()

# generate
images_per_row = 4
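# Hedged sketch of the generation step this script builds toward; the
# generate_images/show argument names are assumed from published ruDALL-E
# examples and are not verified here:
#
#   seed_everything(42)
#   pil_images, _ = generate_images(
#       "Pokemon", tokenizer, model, vae, top_k=2048, top_p=0.995, images_num=images_per_row
#   )
#   show(pil_images, images_per_row)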