Example #1
    def checksum(self, download_path):
        """
        Checksum on a given file.

        Args:
            download_path (string): path to the downloaded file.
        """
        if self._hashcode is None:
            print(f"[ Checksum not provided, skipping for {self._file_name}]")
            return

        sha256_hash = hashlib.sha256()
        destination = os.path.join(download_path, self._file_name)

        if not PathManager.isfile(destination):
            # File is not present, nothing to checksum
            return

        with PathManager.open(destination, "rb") as f:
            print(f"[ Starting checksum for {self._file_name}]")
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
            if sha256_hash.hexdigest() != self._hashcode:
                # remove_dir(download_path)
                raise AssertionError(
                    f"[ Checksum for {self._file_name} from \n{self._url}\n"
                    "does not match the expected checksum. Please try again. ]"
                )
            else:
                print(f"[ Checksum successful for {self._file_name}]")
Example #2
def download_from_google_drive(gd_id, destination, redownload=True):
    """
    Use the requests package to download a file from Google Drive.
    """
    download = not PathManager.isfile(destination) or redownload

    URL = "https://docs.google.com/uc?export=download"

    if not download:
        return download
    else:
        # Check first if link is live
        check_header(gd_id, from_google=True)

    with requests.Session() as session:
        response = session.get(URL, params={"id": gd_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            response.close()
            params = {"id": gd_id, "confirm": token}
            response = session.get(URL, params=params, stream=True)

        CHUNK_SIZE = 32768
        with PathManager.open(destination, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
        response.close()

    return download
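The _get_confirm_token helper is not shown in this example; a plausible implementation inspects the session cookies for the virus-scan warning token that Google Drive sets on large files:

def _get_confirm_token(response):
    # Google Drive sets a "download_warning*" cookie when a file is too
    # large to virus-scan; its value is the confirmation token.
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            return value
    return None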
Example #3
def load_pretrained_model(model_name_or_path_or_checkpoint, *args, **kwargs):
    # If this is a local file, load it directly; otherwise download and load
    if PathManager.isfile(model_name_or_path_or_checkpoint):
        return _load_pretrained_checkpoint(model_name_or_path_or_checkpoint,
                                           args, kwargs)
    else:
        return _load_pretrained_model(model_name_or_path_or_checkpoint, args,
                                      kwargs)
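Both branches return the same structure (see Example #7, which unpacks "config" and "checkpoint" from it). A usage sketch with hypothetical inputs:

# Hypothetical model key registered with the framework...
output = load_pretrained_model("visual_bert.pretrained.coco")
# ...or a direct path to a checkpoint file on disk
output = load_pretrained_model("/checkpoints/model_final.ckpt")
config, checkpoint = output["config"], output["checkpoint"]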
Example #4
def get_possible_image_paths(path):
    image_path = path.split(".")
    # The image path may contain a file extension (e.g. ".jpg");
    # in that case, we want the path without the extension.
    image_path = image_path if len(image_path) == 1 else image_path[:-1]
    for ext in tv_helpers.IMG_EXTENSIONS:
        image_ext = ".".join(image_path) + ext
        if PathManager.isfile(image_ext):
            path = image_ext
            break
    return path
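Assuming tv_helpers.IMG_EXTENSIONS holds the usual image extensions (".jpg", ".png", ...), the function resolves an extension-less path to whichever candidate exists on disk; the input paths below are hypothetical:

# If "data/images/cat.jpg" exists, both calls resolve to it; if no
# candidate matches, the input path is returned unchanged.
resolved = get_possible_image_paths("data/images/cat")
resolved = get_possible_image_paths("data/images/cat.png")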
Example #5
def built(path, version_string=None):
    """
    Check if '.built' flag has been set for that task.

    If a version_string is provided, this has to match, or the version
    is regarded as not built.

    Version_string are generally the dataset version + the date the file was
    last updated. If this doesn't match, dataset will be mark not built. This makes
    sure that if we update our features or anything else features are updated
    for the end user.
    """
    if version_string:
        fname = os.path.join(path, ".built.json")
        if not PathManager.isfile(fname):
            return False
        else:
            with PathManager.open(fname, "r") as read:
                text = json.load(read)
            return text.get("version", None) == version_string
    else:
        return PathManager.isfile(os.path.join(path, ".built.json"))
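The flag file itself is written by the build step. A plausible counterpart (the mark_done name is hypothetical, mirroring the JSON schema read above):

import datetime
import json
import os

def mark_done(path, version_string=None):
    # Hypothetical counterpart to built(): write the .built.json flag
    # with the version that built() compares against.
    data = {
        "created_at": str(datetime.datetime.today()),
        "version": version_string,
    }
    with PathManager.open(os.path.join(path, ".built.json"), "w") as f:
        json.dump(data, f)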
Example #6
def load_yaml(f):
    # Convert to absolute path for loading includes
    abs_f = get_absolute_path(f)

    try:
        mapping = OmegaConf.load(abs_f)
        f = abs_f
    except FileNotFoundError as e:
        # Check if this file might be relative to the multimodelity root
        # TODO: Later test if this can be removed
        relative = os.path.abspath(os.path.join(get_multimodelity_root(), f))
        if not PathManager.isfile(relative):
            raise e
        else:
            f = relative
            mapping = OmegaConf.load(f)

    if mapping is None:
        mapping = OmegaConf.create()

    includes = mapping.get("includes", [])

    if not isinstance(includes, collections.abc.Sequence):
        raise AttributeError("Includes must be a list, {} provided".format(
            type(includes)))

    include_mapping = OmegaConf.create()

    multimodelity_root_dir = get_multimodelity_root()

    for include in includes:
        original_include_path = include
        include = os.path.join(multimodelity_root_dir, include)

        # If path doesn't exist relative to multimodelity root, try relative to current file
        if not PathManager.exists(include):
            include = os.path.join(os.path.dirname(f), original_include_path)

        current_include_mapping = load_yaml(include)
        include_mapping = OmegaConf.merge(include_mapping,
                                          current_include_mapping)

    mapping.pop("includes", None)

    mapping = OmegaConf.merge(include_mapping, mapping)

    return mapping
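The important detail is the merge order: includes are merged first and the current file is merged on top, so the including file wins on conflicts. A quick sketch of that OmegaConf behavior:

from omegaconf import OmegaConf

include_mapping = OmegaConf.create({"model": {"hidden_size": 512, "dropout": 0.1}})
mapping = OmegaConf.create({"model": {"dropout": 0.3}})

merged = OmegaConf.merge(include_mapping, mapping)
assert merged.model.hidden_size == 512  # inherited from the include
assert merged.model.dropout == 0.3      # overridden by the including file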
Example #7
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        if not PathManager.isfile(model_name_or_path):
            model_key = model_name_or_path.split(".")[0]
            model_cls = registry.get_model_class(model_key)
            assert model_cls == cls, (
                f"Incorrect pretrained model key {model_name_or_path} "
                f"for class {cls.__name__}"
            )
        output = load_pretrained_model(model_name_or_path, *args, **kwargs)
        config, checkpoint = output["config"], output["checkpoint"]
        # Some models need registry updates to be able to load a pretrained model
        # If they have this method, call it so they can update accordingly
        if hasattr(cls, "update_registry_for_pretrained"):
            cls.update_registry_for_pretrained(config, checkpoint, output)

        instance = cls(config)
        instance.is_pretrained = True
        instance.build()
        incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

        if len(incompatible_keys.missing_keys) != 0:
            logger.warning(
                f"Missing keys {incompatible_keys.missing_keys} in the" +
                " checkpoint.\n" +
                "If this is not your checkpoint, please open up an " +
                "issue on multimodelity GitHub. \n" +
                f"Unexpected keys if any: {incompatible_keys.unexpected_keys}")

        if len(incompatible_keys.unexpected_keys) != 0:
            logger.warning(
                "Unexpected keys in state dict: " +
                f"{incompatible_keys.unexpected_keys} \n" +
                "This is usually not a problem with pretrained models, but " +
                "if this is your own model, please double check. \n" +
                "If you think this is an issue, please open up a " +
                "bug at multimodelity GitHub.")

        instance.eval()

        return instance
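Usage is a single call on the model class; the key before the first dot has to match the class's registered name. A sketch with a hypothetical class and key:

# Hypothetical: VisualBERT is registered under the "visual_bert" key
model = VisualBERT.from_pretrained("visual_bert.pretrained.coco")
# The returned instance is already built, loaded, and in eval mode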
Example #8
    def remove(self, update):
        ckpt_filepath = os.path.join(self.models_foldername,
                                     "model_%d.ckpt" % update)
        if PathManager.isfile(ckpt_filepath):
            PathManager.rm(ckpt_filepath)
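A typical caller uses remove to cap how many checkpoints are kept on disk. A hypothetical rotation sketch:

def _rotate(self, current_update, max_to_keep=5, interval=1000):
    # Hypothetical policy: drop the checkpoint that falls out of the
    # window of the max_to_keep most recent saves.
    stale = current_update - max_to_keep * interval
    if stale > 0:
        self.remove(stale)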
Example #9
def download(url, path, fname, redownload=True, disable_tqdm=False):
    """
    Download file using `requests`.

    If ``redownload`` is set to false, then will not download tar file again if it is
    present (default ``True``).

    Returns whether download actually happened or not
    """
    outfile = os.path.join(path, fname)
    download = not PathManager.isfile(outfile) or redownload
    retry = 5
    exp_backoff = [2**r for r in reversed(range(retry))]

    pbar = None
    if download:
        # First test if the link is actually downloadable
        check_header(url)
        if not disable_tqdm:
            print("[ Downloading: " + url + " to " + outfile + " ]")
        pbar = tqdm.tqdm(unit="B",
                         unit_scale=True,
                         desc=f"Downloading {fname}",
                         disable=disable_tqdm)

    while download and retry >= 0:
        resume_file = outfile + ".part"
        resume = PathManager.isfile(resume_file)
        if resume:
            resume_pos = os.path.getsize(resume_file)
            mode = "ab"
        else:
            resume_pos = 0
            mode = "wb"
        response = None

        with requests.Session() as session:
            try:
                header = ({
                    "Range": "bytes=%d-" % resume_pos,
                    "Accept-Encoding": "identity"
                } if resume else {})
                response = session.get(url,
                                       stream=True,
                                       timeout=5,
                                       headers=header)

                # negative reply could be 'none' or just missing
                if resume and response.headers.get("Accept-Ranges",
                                                   "none") == "none":
                    resume_pos = 0
                    mode = "wb"

                CHUNK_SIZE = 32768
                total_size = int(response.headers.get("Content-Length", -1))
                # server returns remaining size if resuming, so adjust total
                total_size += resume_pos
                pbar.total = total_size
                done = resume_pos

                with PathManager.open(resume_file, mode) as f:
                    for chunk in response.iter_content(CHUNK_SIZE):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                        if total_size > 0:
                            done += len(chunk)
                            if total_size < done:
                                # don't freak out if content-length was too small
                                total_size = done
                                pbar.total = total_size
                            pbar.update(len(chunk))
                    break
            except (
                    requests.exceptions.ConnectionError,
                    requests.exceptions.ReadTimeout,
            ):
                retry -= 1
                pbar.clear()
                if retry >= 0:
                    print("Connection error, retrying. (%d retries left)" %
                          retry)
                    time.sleep(exp_backoff[retry])
                else:
                    print("Retried too many times, stopped retrying.")
            finally:
                if response:
                    response.close()
    if retry < 0:
        raise RuntimeWarning(
            "Connection broken too many times. Stopped retrying.")

    if download and retry > 0:
        pbar.update(done - pbar.n)
        if done < total_size:
            raise RuntimeWarning("Received less data than specified in " +
                                 "Content-Length header for " + url +
                                 ". There may be a download problem.")
        move(resume_file, outfile)

    if pbar:
        pbar.close()

    return download
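check_header, used here and in Example #2, is not shown; a plausible implementation issues a HEAD request and asserts the link is live before the full download starts:

def check_header(url, from_google=False):
    # Plausible sketch: verify the URL (or Google Drive ID) answers
    # with HTTP 200 before committing to a full download.
    with requests.Session() as session:
        if from_google:
            drive_url = "https://docs.google.com/uc?export=download"
            response = session.head(drive_url, params={"id": url},
                                    allow_redirects=True)
        else:
            response = session.head(url, allow_redirects=True)
    assert response.status_code == 200, (
        f"The url {url} is broken; check the link or your connection."
    )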