def test_download_cache_hit(mocker):
    """Check that download is not repeated on cache hit."""
    data = b"Hello, world"
    data_checksum = "4ae7c3b6ac0beff671efa8cf57386151c06e58ca53a78d83f36107316cec125f"
    cached_path = cache_path(f"downloads/{data_checksum}")

    # Tidy up from a previous test, if applicable.
    if cached_path.is_file():
        cached_path.unlink()

    def patched_download(*args):
        return data

    mocker.patch.object(download, "_get_url_data", patched_download)
    mocker.spy(download, "_get_url_data")

    # First call: cache miss. The underlying fetch must run exactly once and
    # the payload must be written to the cache file. Reuse data_checksum
    # rather than repeating the literal so the three copies cannot drift.
    assert download.download("example", sha256=data_checksum) == data
    download._get_url_data.assert_called_once_with("example")
    assert cached_path.is_file()

    # Second call: cache hit. The fetch count must not increase.
    assert download.download("example", sha256=data_checksum) == data
    assert download._get_url_data.call_count == 1
# --- Example 2 ---
def test_download_mismatched_checksum(mocker):
    """Check that error is raised when checksum does not match expected."""

    def patched_download(*args):
        return b"Hello, world"

    mocker.patch.object(download, "_get_url_data", patched_download)

    # The payload does not hash to "123", so the download must be rejected.
    expected_error = "Checksum of download does not match"
    with pytest.raises(DownloadFailed, match=expected_error):
        download.download("example", sha256="123")
# --- Example 3 ---
def test_download_failed_retry_loop(mocker, max_retries: int):
    """Check that download attempts are repeated without sleep() on error."""

    def always_fail(*args):
        raise DownloadFailed

    mocker.patch.object(download, "sleep")
    mocker.patch.object(download, "_do_download_attempt", always_fail)
    mocker.spy(download, "_do_download_attempt")

    with pytest.raises(DownloadFailed):
        download.download(urls="example", max_retries=max_retries)

    # Every retry made an attempt, and none of the retries slept.
    assert download._do_download_attempt.call_count == max_retries
    assert download.sleep.call_count == 0
# --- Example 4 ---
    def preprocess_poj104_source(src: str) -> str:
        """Pre-process a POJ-104 C++ source file for compilation.

        :param src: The raw source text of a POJ-104 submission.
        :return: The source with a normalized main() declaration, a standard
            library header, and definitions for common undefined macros
            prepended.
        """
        # Clean up declaration of main function. Many are missing a return type
        # declaration, or use an incorrect void return type.
        src = src.replace("void main", "int main")
        # Keep the newline in the replacement: dropping it would glue the
        # preceding line (which may be a comment or preprocessor directive)
        # onto the main() declaration.
        src = src.replace("\nmain", "\nint main")
        if src.startswith("main"):
            src = f"int {src}"

        # Pull in the standard library.
        if sys.platform == "linux":
            header = "#include <bits/stdc++.h>\n" "using namespace std;\n"
        else:
            # Download a bits/stdc++ implementation for macOS.
            header = download(
                "https://raw.githubusercontent.com/tekfyl/bits-stdc-.h-for-mac/e1193f4470514d82ea19c3cc1357116fadaa2a4e/stdc%2B%2B.h",
                sha256="b4d9b031d56d89a2b58b5ed80fa9943aa92420d6aed0835747c9a5584469afeb",
            ).decode("utf-8")

        # These defines provide values for commonly undefined symbols. Defining
        # these macros increases the number of POJ-104 programs that compile
        # from 49,302 to 49,821 (+519) on linux.
        defines = "#define LEN 128\n" "#define MAX_LENGTH 1024\n" "#define MAX 1024\n"

        return header + defines + src
# --- Example 5 ---
def download_and_unpack_database(db: str, sha256: str) -> Path:
    """Download the given database, unpack it to the local filesystem, and
    return the path.

    :param db: The URL of the database archive to download.
    :param sha256: The expected sha256 checksum of the archive.
    :return: The path of the single unpacked file.
    :raises OSError: If the unpacked archive does not contain exactly one
        file (excluding the ".installed" marker).
    """
    local_dir = cache_path(f"state_transition_dataset/{sha256}")
    # Thread-level and process-level locks prevent concurrent callers racing
    # to download and unpack into the same directory.
    with _DB_DOWNLOAD_LOCK, InterProcessLock(
        transient_cache_path(".state_transition_database_download.LOCK")
    ):
        if not (local_dir / ".installed").is_file():
            tar_data = io.BytesIO(download(db, sha256))

            local_dir.mkdir(parents=True, exist_ok=True)
            logger.info("Unpacking database to %s ...", local_dir)
            with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc:
                arc.extractall(str(local_dir))

            # Marker file signals that the unpack completed successfully.
            (local_dir / ".installed").touch()

    unpacked = [f for f in local_dir.iterdir() if f.name != ".installed"]
    if len(unpacked) != 1:
        # Previously this printed a "fatal" message to stderr but carried on,
        # producing an opaque IndexError below when the archive was empty.
        # Raise instead so callers see the real failure.
        raise OSError(
            f"Archive {db} expected to contain one file, contains: {len(unpacked)}"
        )

    return unpacked[0]
def test_download_timeout_retry_loop(mocker, max_retries: int):
    """Check that download attempts are repeated with sleep() on error."""

    def raise_too_many_requests(*args):
        raise download.TooManyRequests

    mocker.patch.object(download, "sleep")
    mocker.patch.object(download, "_do_download_attempt", raise_too_many_requests)
    mocker.spy(download, "_do_download_attempt")

    with pytest.raises(download.TooManyRequests):
        download.download(urls="example", max_retries=max_retries)

    # One attempt and one backoff sleep per retry.
    assert download._do_download_attempt.call_count == max_retries
    assert download.sleep.call_count == max_retries
    # Exponential backoff grows by 1.5x from a 10-second start, so the final
    # sleep duration is 10 * 1.5 ** (max_retries - 1).
    starting_wait_time = 10  # The initial wait time in seconds.
    download.sleep.assert_called_with(starting_wait_time * 1.5 ** (max_retries - 1))
# --- Example 7 ---
def download_cBench_runtime_data() -> bool:
    """Download and unpack the cBench runtime dataset.

    :return: True if the dataset was downloaded and unpacked, False if it
        was already present.
    """
    if _CBENCH_DATA.is_dir():
        return False

    tar_contents = io.BytesIO(download(_CBENCH_DATA_URL, sha256=_CBENCH_DATA_SHA256))
    with tarfile.open(fileobj=tar_contents, mode="r:bz2") as tar:
        # exist_ok prevents a FileExistsError when the parent directory
        # already exists from an earlier (possibly interrupted) run.
        _CBENCH_DATA.parent.mkdir(parents=True, exist_ok=True)
        tar.extractall(_CBENCH_DATA.parent)
    assert _CBENCH_DATA.is_dir()
    return True
# --- Example 8 ---
    def install(self):
        """Download and unpack the libclc and OpenCL headers, once."""
        super().install()

        if not self._opencl_installed:
            self._opencl_installed = self._opencl_headers_installed_marker.is_file()

        if self._opencl_installed:
            return

        with _CLGEN_INSTALL_LOCK, InterProcessLock(self._tar_lockfile):
            # Repeat install check now that we are in the locked region.
            if self._opencl_headers_installed_marker.is_file():
                return

            # Download the libclc headers.
            shutil.rmtree(self.libclc_dir, ignore_errors=True)
            logger.info("Downloading OpenCL headers ...")
            libclc_tar = io.BytesIO(
                download(
                    "https://dl.fbaipublicfiles.com/compiler_gym/libclc-v0.tar.bz2",
                    sha256="f1c511f2ac12adf98dcc0fbfc4e09d0f755fa403c18f1fb1ffa5547e1fa1a499",
                )
            )
            with tarfile.open(fileobj=libclc_tar, mode="r:bz2") as arc:
                arc.extractall(str(self.site_data_path / "libclc"))

            # Download the OpenCL header.
            opencl_h_data = download(
                "https://github.com/ChrisCummins/clgen/raw/463c0adcd8abcf2432b24df0aca594b77a69e9d3/deeplearning/clgen/data/include/opencl.h",
                sha256="f95b9f4c8b1d09114e491846d0d41425d24930ac167e024f45dab8071d19f3f7",
            )
            with open(self.opencl_h_path, "wb") as f:
                f.write(opencl_h_data)

            # Marker file signals completion to future install() calls.
            self._opencl_headers_installed_marker.touch()
# --- Example 9 ---
 def download_and_unpack_archive(url: str, sha256: Optional[str] = None) -> Dataset:
     """Download a tar archive of datasets, unpack it, and return the dataset
     described by the metadata JSON file it added.

     :param url: The URL of the .tar.bz2 archive to download.
     :param sha256: An optional sha256 checksum to validate the download.
     :return: A Dataset built from the newly-added metadata JSON file.
     :raises OSError: If the archive adds no new metadata JSON file.
     """
     def metadata_json_files() -> set:
         # Snapshot of the metadata JSON files currently on disk.
         return {
             f
             for f in env.inactive_datasets_site_path.iterdir()
             if f.is_file() and f.name.endswith(".json")
         }

     json_files_before = metadata_json_files()
     tar_data = io.BytesIO(download(url, sha256))
     with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc:
         arc.extractall(str(env.inactive_datasets_site_path))
     # Any JSON file that appeared during extraction describes the new dataset.
     new_json = metadata_json_files() - json_files_before
     if not new_json:
         raise OSError(f"Downloaded dataset {url} contains no metadata JSON file")
     # NOTE(review): if the archive adds several JSON files, an arbitrary one
     # is chosen (set ordering) — this matches the original behavior.
     return Dataset.from_json_file(next(iter(new_json)))
# --- Example 10 ---
def _download_llvm_files(destination: Path) -> Path:
    """Download and unpack the LLVM data pack.

    :param destination: The directory to unpack the data pack into.
    :return: The destination directory.
    """
    logger.warning(
        "Installing the CompilerGym LLVM environment runtime. This may take a few moments ..."
    )

    # Tidy up an incomplete unpack.
    shutil.rmtree(destination, ignore_errors=True)

    archive = io.BytesIO(download(_LLVM_URL, sha256=_LLVM_SHA256))
    destination.parent.mkdir(parents=True, exist_ok=True)
    with tarfile.open(fileobj=archive, mode="r:bz2") as tar:
        tar.extractall(destination)

    # Sanity-check the unpacked contents.
    assert destination.is_dir()
    assert (destination / "LICENSE").is_file()

    return destination
# --- Example 11 ---
def download_cBench_runtime_data() -> bool:
    """Download and unpack the cBench runtime dataset."""
    cbench_data = site_data_path("llvm-v0/cbench-v1-runtime-data/runtime_data")
    if (cbench_data / "unpacked").is_file():
        # Already downloaded and unpacked; nothing to do.
        return False

    # Clean up any partially-extracted data directory.
    if cbench_data.is_dir():
        shutil.rmtree(cbench_data)

    url, sha256 = _CBENCH_RUNTOME_DATA
    archive = io.BytesIO(download(url, sha256))
    with tarfile.open(fileobj=archive, mode="r:bz2") as tar:
        cbench_data.parent.mkdir(parents=True, exist_ok=True)
        tar.extractall(cbench_data.parent)
    assert cbench_data.is_dir()
    # Create the marker file to indicate that the directory is unpacked
    # and ready to go.
    (cbench_data / "unpacked").touch()
    return True
# --- Example 12 ---
    def _benchmark_uris(self) -> List[str]:
        """Fetch or download the URI list.

        :return: The list of benchmark URIs read from the manifest.
        :raises TypeError: If the manifest compression format is unknown.
        """
        if self._manifest_path.is_file():
            return self._read_manifest_file()

        # Thread-level and process-level locks to prevent races.
        with _TAR_MANIFEST_INSTALL_LOCK, InterProcessLock(self._manifest_lockfile):
            # Now that we have acquired the lock, repeat the check, since
            # another thread may have downloaded the manifest.
            if self._manifest_path.is_file():
                return self._read_manifest_file()

            # Determine how to decompress the manifest data. GzipFile's first
            # positional parameter is a *filename*, so the in-memory buffer
            # must be passed via the fileobj keyword; passing it positionally
            # raises when GzipFile tries to open() it. BZ2File accepts a file
            # object as its first argument directly.
            decompressor = {
                "bz2": lambda compressed_data: bz2.BZ2File(compressed_data),
                "gz": lambda compressed_data: gzip.GzipFile(fileobj=compressed_data),
            }.get(self.manifest_compression, None)
            if not decompressor:
                raise TypeError(
                    f"Unknown manifest compression: {self.manifest_compression}"
                )

            # Decompress the manifest data.
            logger.debug("Downloading %s manifest", self.name)
            manifest_data = io.BytesIO(
                download(self.manifest_urls, self.manifest_sha256)
            )
            with decompressor(manifest_data) as f:
                manifest_data = f.read()

            # Although we have exclusive-execution locks, we still need to
            # create the manifest atomically to prevent calls to _benchmark_uris
            # racing to read an incompletely written manifest.
            with atomic_file_write(self._manifest_path, fileobj=True) as f:
                f.write(manifest_data)

            uris = self._read_manifest(manifest_data.decode("utf-8"))
            logger.debug("Downloaded %s manifest, %d entries", self.name, len(uris))
            return uris
# --- Example 13 ---
    def install(self) -> None:
        """Download and unpack the dataset tarball if not already installed."""
        super().install()

        if self.installed:
            return

        # Thread-level and process-level locks to prevent races.
        with _TAR_INSTALL_LOCK, InterProcessLock(self._tar_lockfile):
            # Repeat the check to see if we have already installed the
            # dataset now that we have acquired the lock.
            if self.installed:
                return

            # Remove any partially-completed prior extraction.
            shutil.rmtree(self.site_data_path / "contents", ignore_errors=True)

            logger.warning(
                "Installing the %s dataset. This may take a few moments ...", self.name
            )

            archive = io.BytesIO(download(self.tar_urls, self.tar_sha256))
            logger.info("Unpacking %s dataset to %s", self.name, self.site_data_path)
            tar_mode = f"r:{self.tar_compression}"
            with tarfile.open(fileobj=archive, mode=tar_mode) as arc:
                arc.extractall(str(self.site_data_path / "contents"))

            # We're done. The last thing we do is create the marker file to
            # signal to any other install() invocations that the dataset is
            # ready.
            self._tar_extracted_marker.touch()

        if self.strip_prefix and not self.dataset_root.is_dir():
            raise FileNotFoundError(
                f"Directory prefix '{self.strip_prefix}' not found in dataset '{self.name}'"
            )
def test_download_no_urls():
    """An empty URL list must be rejected with a ValueError."""
    with pytest.raises(ValueError, match="No URLs to download"):
        download.download(urls=[])