def checksum(self, download_path):
    """
    Checksum on a given file.

    Args:
        download_path (string): path to the downloaded file.
    """
    if self._hashcode is None:
        print(f"[ Checksum not provided, skipping for {self._file_name}]")
        return

    sha256_hash = hashlib.sha256()
    destination = os.path.join(download_path, self._file_name)

    if not PathManager.isfile(destination):
        # File is not present, nothing to checksum
        return

    with PathManager.open(destination, "rb") as f:
        print(f"[ Starting checksum for {self._file_name}]")
        for byte_block in iter(lambda: f.read(65536), b""):
            sha256_hash.update(byte_block)
        if sha256_hash.hexdigest() != self._hashcode:
            # remove_dir(download_path)
            raise AssertionError(
                f"[ Checksum for {self._file_name} from \n{self._url}\n"
                "does not match the expected checksum. Please try again. ]"
            )
        else:
            print(f"[ Checksum successful for {self._file_name}]")
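# A minimal, standalone sketch of the chunked SHA-256 pattern used above,
# assuming a plain local file (no PathManager). The 65536-byte block size
# mirrors the read size in checksum(); `sha256_of_file` is a hypothetical
# helper name, not part of this module.
import hashlib

def sha256_of_file(path, block_size=65536):
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # read() returns b"" at EOF, which terminates iter()
        for block in iter(lambda: f.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()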
def download_from_google_drive(gd_id, destination, redownload=True):
    """
    Use the requests package to download a file from Google Drive.
    """
    download = not PathManager.isfile(destination) or redownload

    URL = "https://docs.google.com/uc?export=download"

    if not download:
        return download
    else:
        # Check first if link is live
        check_header(gd_id, from_google=True)

    with requests.Session() as session:
        response = session.get(URL, params={"id": gd_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            response.close()
            params = {"id": gd_id, "confirm": token}
            response = session.get(URL, params=params, stream=True)

        CHUNK_SIZE = 32768
        with PathManager.open(destination, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

        response.close()

    return download
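# `_get_confirm_token` is used above but not shown in this excerpt. A hedged
# sketch of the usual Google Drive pattern (the `_sketch` name is
# hypothetical): large files trigger a virus-scan interstitial, and the
# confirmation token arrives as a cookie whose key starts with
# "download_warning".
def _get_confirm_token_sketch(response):
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            return value
    return None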
def load_pretrained_model(model_name_or_path_or_checkpoint, *args, **kwargs):
    # If this is a file, load it directly; otherwise download and load
    if PathManager.isfile(model_name_or_path_or_checkpoint):
        return _load_pretrained_checkpoint(
            model_name_or_path_or_checkpoint, args, kwargs
        )
    else:
        return _load_pretrained_model(model_name_or_path_or_checkpoint, args, kwargs)
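# Hypothetical usage sketch (the key and path below are illustrative):
#
#   output = load_pretrained_model("/checkpoints/model_final.ckpt")  # local file
#   output = load_pretrained_model("some_model.pretrained")          # remote key
#   config, checkpoint = output["config"], output["checkpoint"]
#
# Both branches return the same dict shape consumed by from_pretrained below.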
def get_possible_image_paths(path):
    image_path = path.split(".")
    # Image path might contain a file extension (e.g. .jpg);
    # in this case, we want the path without the extension
    image_path = image_path if len(image_path) == 1 else image_path[:-1]
    for ext in tv_helpers.IMG_EXTENSIONS:
        image_ext = ".".join(image_path) + ext
        if PathManager.isfile(image_ext):
            path = image_ext
            break
    return path
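# Worked illustration of the stem handling above, assuming the repo's
# tv_helpers.IMG_EXTENSIONS matches torchvision's dot-prefixed list
# (".jpg", ".jpeg", ".png", ...):
#
#   "images/coco/000001.jpg".split(".")  ->  ["images/coco/000001", "jpg"]
#   stem after dropping the extension    ->  "images/coco/000001"
#   candidates tried: "images/coco/000001.jpg", "images/coco/000001.jpeg", ...
#
# If no candidate exists on disk, the original `path` is returned unchanged.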
def built(path, version_string=None):
    """
    Check if the '.built' flag has been set for that task.

    If a version_string is provided, it has to match, or the version is
    regarded as not built.

    Version strings are generally the dataset version plus the date the file
    was last updated; if they don't match, the dataset is marked as not
    built. This makes sure that if we update our features or anything else,
    they are updated for the end user.
    """
    if version_string:
        fname = os.path.join(path, ".built.json")
        if not PathManager.isfile(fname):
            return False
        else:
            with PathManager.open(fname, "r") as read:
                text = json.load(read)
            return text.get("version", None) == version_string
    else:
        return PathManager.isfile(os.path.join(path, ".built.json"))
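# The '.built.json' flag checked by built() is written by a companion
# "mark done" step once a build finishes. A minimal sketch of such a writer
# (`_mark_built_sketch` is a hypothetical name), producing exactly the
# {"version": ...} layout that built() reads back:
def _mark_built_sketch(path, version_string=None):
    with PathManager.open(os.path.join(path, ".built.json"), "w") as f:
        json.dump({"version": version_string}, f)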
def load_yaml(f):
    # Convert to absolute path for loading includes
    abs_f = get_absolute_path(f)

    try:
        mapping = OmegaConf.load(abs_f)
        f = abs_f
    except FileNotFoundError as e:
        # Check if this file might be relative to root?
        # TODO: Later test if this can be removed
        relative = os.path.abspath(os.path.join(get_multimodelity_root(), f))
        if not PathManager.isfile(relative):
            raise e
        else:
            f = relative
            mapping = OmegaConf.load(f)

    if mapping is None:
        mapping = OmegaConf.create()

    includes = mapping.get("includes", [])

    if not isinstance(includes, collections.abc.Sequence):
        raise AttributeError(
            "Includes must be a list, {} provided".format(type(includes))
        )

    include_mapping = OmegaConf.create()

    multimodelity_root_dir = get_multimodelity_root()

    for include in includes:
        original_include_path = include
        include = os.path.join(multimodelity_root_dir, include)

        # If path doesn't exist relative to multimodelity root,
        # try relative to current file
        if not PathManager.exists(include):
            include = os.path.join(os.path.dirname(f), original_include_path)

        current_include_mapping = load_yaml(include)
        include_mapping = OmegaConf.merge(include_mapping, current_include_mapping)

    mapping.pop("includes", None)

    mapping = OmegaConf.merge(include_mapping, mapping)

    return mapping
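# Hedged illustration of the "includes" mechanism handled above, with a
# made-up config file:
#
#   # experiment.yaml
#   includes:
#     - configs/datasets/vqa2/defaults.yaml  # resolved against the repo root,
#     - local_overrides.yaml                 # else against this file's directory
#   training:
#     batch_size: 64
#
# Each include is itself loaded through load_yaml (so includes nest), the
# results are merged left to right, and the including file's own keys win in
# the final OmegaConf.merge.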
@classmethod
def from_pretrained(cls, model_name_or_path, *args, **kwargs):
    if not PathManager.isfile(model_name_or_path):
        model_key = model_name_or_path.split(".")[0]
        model_cls = registry.get_model_class(model_key)
        assert model_cls == cls, (
            f"Incorrect pretrained model key {model_name_or_path} "
            f"for class {cls.__name__}"
        )
    output = load_pretrained_model(model_name_or_path, *args, **kwargs)
    config, checkpoint = output["config"], output["checkpoint"]

    # Some models need registry updates to load a pretrained model.
    # If they have this method, call it so they can update accordingly.
    if hasattr(cls, "update_registry_for_pretrained"):
        cls.update_registry_for_pretrained(config, checkpoint, output)

    instance = cls(config)
    instance.is_pretrained = True
    instance.build()
    incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

    if len(incompatible_keys.missing_keys) != 0:
        logger.warning(
            f"Missing keys {incompatible_keys.missing_keys} in the checkpoint.\n"
            "If this is not your checkpoint, please open up an "
            "issue on multimodelity GitHub.\n"
            f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
        )

    if len(incompatible_keys.unexpected_keys) != 0:
        logger.warning(
            "Unexpected keys in state dict: "
            f"{incompatible_keys.unexpected_keys}\n"
            "This is usually not a problem with pretrained models, but "
            "if this is your own model, please double check.\n"
            "If you think this is an issue, please open up a "
            "bug at multimodelity GitHub."
        )

    instance.eval()
    return instance
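# Hypothetical usage sketch (class and key names are made up):
#
#   model = SomeModel.from_pretrained("some_model.pretrained")
#
# The key's prefix before the first "." must map, via the registry, to the
# class the method is called on, otherwise the assert above fires. The
# returned instance is already built, loaded, and switched to eval mode.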
def remove(self, update):
    ckpt_filepath = os.path.join(
        self.models_foldername, "model_%d.ckpt" % update
    )
    if PathManager.isfile(ckpt_filepath):
        PathManager.rm(ckpt_filepath)
def download(url, path, fname, redownload=True, disable_tqdm=False):
    """
    Download file using `requests`.

    If ``redownload`` is set to false, the file is not downloaded again if it
    is already present (default ``True``).

    Returns whether the download actually happened or not.
    """
    outfile = os.path.join(path, fname)
    download = not PathManager.isfile(outfile) or redownload
    retry = 5
    exp_backoff = [2 ** r for r in reversed(range(retry))]

    pbar = None
    if download:
        # First test if the link is actually downloadable
        check_header(url)
        if not disable_tqdm:
            print("[ Downloading: " + url + " to " + outfile + " ]")
        pbar = tqdm.tqdm(
            unit="B", unit_scale=True, desc=f"Downloading {fname}", disable=disable_tqdm
        )

    while download and retry >= 0:
        resume_file = outfile + ".part"
        resume = PathManager.isfile(resume_file)

        if resume:
            resume_pos = os.path.getsize(resume_file)
            mode = "ab"
        else:
            resume_pos = 0
            mode = "wb"

        response = None

        with requests.Session() as session:
            try:
                header = (
                    {"Range": "bytes=%d-" % resume_pos, "Accept-Encoding": "identity"}
                    if resume
                    else {}
                )
                response = session.get(url, stream=True, timeout=5, headers=header)

                # negative reply could be 'none' or just missing
                if resume and response.headers.get("Accept-Ranges", "none") == "none":
                    resume_pos = 0
                    mode = "wb"

                CHUNK_SIZE = 32768
                total_size = int(response.headers.get("Content-Length", -1))
                # server returns remaining size if resuming, so adjust total
                total_size += resume_pos
                pbar.total = total_size
                done = resume_pos

                with PathManager.open(resume_file, mode) as f:
                    for chunk in response.iter_content(CHUNK_SIZE):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                        if total_size > 0:
                            done += len(chunk)
                            if total_size < done:
                                # don't freak out if content-length was too small
                                total_size = done
                                pbar.total = total_size
                            pbar.update(len(chunk))
                    break
            except (
                requests.exceptions.ConnectionError,
                requests.exceptions.ReadTimeout,
            ):
                retry -= 1
                pbar.clear()
                if retry >= 0:
                    print("Connection error, retrying. (%d retries left)" % retry)
                    time.sleep(exp_backoff[retry])
                else:
                    print("Retried too many times, stopped retrying.")
            finally:
                if response:
                    response.close()

    if retry < 0:
        raise RuntimeWarning("Connection broken too many times. Stopped retrying.")

    if download and retry > 0:
        pbar.update(done - pbar.n)
        if done < total_size:
            raise RuntimeWarning(
                "Received less data than specified in "
                + "Content-Length header for "
                + url
                + ". There may be a download problem."
            )
        move(resume_file, outfile)

    if pbar:
        pbar.close()

    return download
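# Hedged usage sketch (URL and paths are illustrative): an interrupted
# transfer leaves "<fname>.part" on disk, and the next call resumes from that
# offset via an HTTP Range request before moving the finished file into place.
#
#   happened = download(
#       "https://example.com/data.tar.gz", "/tmp/datasets", "data.tar.gz",
#       redownload=False,
#   )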