def test_download_cache_hit(mocker):
    """Check that download is not repeated on cache hit."""
    data = b"Hello, world"
    # Single source of truth for the checksum; previously this literal was
    # duplicated in each download() call below.
    data_checksum = "4ae7c3b6ac0beff671efa8cf57386151c06e58ca53a78d83f36107316cec125f"
    cached_path = cache_path(f"downloads/{data_checksum}")

    # Tidy up from a previous test, if applicable.
    if cached_path.is_file():
        cached_path.unlink()

    def patched_download(*args):
        return data

    mocker.patch.object(download, "_get_url_data", patched_download)
    mocker.spy(download, "_get_url_data")

    # First call performs the download and populates the cache.
    assert download.download("example", sha256=data_checksum) == data
    download._get_url_data.assert_called_once_with("example")
    assert cached_path.is_file()

    # Cache hit: the data is served from disk, no second fetch.
    assert download.download("example", sha256=data_checksum) == data
    assert download._get_url_data.call_count == 1
def test_download_mismatched_checksum(mocker):
    """Check that error is raised when checksum does not match expected."""

    def fake_fetch(*args):
        return b"Hello, world"

    mocker.patch.object(download, "_get_url_data", fake_fetch)

    # The fetched payload's sha256 cannot possibly equal "123".
    with pytest.raises(DownloadFailed, match="Checksum of download does not match"):
        download.download("example", sha256="123")
def test_download_failed_retry_loop(mocker, max_retries: int):
    """Check that download attempts are repeated without sleep() on error."""

    def failing_attempt(*args):
        raise DownloadFailed

    mocker.patch.object(download, "sleep")
    mocker.patch.object(download, "_do_download_attempt", failing_attempt)
    mocker.spy(download, "_do_download_attempt")

    with pytest.raises(DownloadFailed):
        download.download(urls="example", max_retries=max_retries)

    # Every retry re-attempted the download, and none of them slept.
    assert download._do_download_attempt.call_count == max_retries
    assert download.sleep.call_count == 0
def preprocess_poj104_source(src: str) -> str:
    """Pre-process a POJ-104 C++ source file for compilation.

    :param src: The raw source text of a POJ-104 program.
    :return: The source with a normalized main() declaration, a standard
        library header, and common macro definitions prepended.
    """
    # Clean up declaration of main function. Many are missing a return type
    # declaration, or use an incorrect void return type.
    src = src.replace("void main", "int main")
    # Keep the newline in the replacement: dropping it would fuse "int main"
    # onto the end of the preceding line (e.g. onto a #include directive).
    src = src.replace("\nmain", "\nint main")
    if src.startswith("main"):
        src = f"int {src}"

    # Pull in the standard library.
    if sys.platform == "linux":
        header = "#include <bits/stdc++.h>\n" "using namespace std;\n"
    else:
        # Download a bits/stdc++ implementation for macOS.
        header = download(
            "https://raw.githubusercontent.com/tekfyl/bits-stdc-.h-for-mac/e1193f4470514d82ea19c3cc1357116fadaa2a4e/stdc%2B%2B.h",
            sha256="b4d9b031d56d89a2b58b5ed80fa9943aa92420d6aed0835747c9a5584469afeb",
        ).decode("utf-8")

    # These defines provide values for commonly undefined symbols. Defining
    # these macros increases the number of POJ-104 programs that compile
    # from 49,302 to 49,821 (+519) on linux.
    defines = "#define LEN 128\n" "#define MAX_LENGTH 1024\n" "#define MAX 1024\n"

    return header + defines + src
def download_and_unpack_database(db: str, sha256: str) -> Path:
    """Download the given database, unpack it to the local filesystem, and
    return the path.

    :param db: The URL of the database archive to download.
    :param sha256: The expected sha256 checksum of the archive.
    :return: The path of the single file unpacked from the archive.
    :raises OSError: If the unpacked archive does not contain exactly one
        file.
    """
    local_dir = cache_path(f"state_transition_dataset/{sha256}")
    # Thread-level and process-level locks guard against concurrent
    # download/extraction races.
    with _DB_DOWNLOAD_LOCK, InterProcessLock(
        transient_cache_path(".state_transition_database_download.LOCK")
    ):
        if not (local_dir / ".installed").is_file():
            tar_data = io.BytesIO(download(db, sha256))
            local_dir.mkdir(parents=True, exist_ok=True)
            logger.info("Unpacking database to %s ...", local_dir)
            with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc:
                arc.extractall(str(local_dir))
            # Marker file signals that extraction completed successfully.
            (local_dir / ".installed").touch()

    unpacked = [f for f in local_dir.iterdir() if f.name != ".installed"]
    if len(unpacked) != 1:
        # Previously this printed a "fatal" message to stderr and carried on,
        # which crashed with IndexError on an empty archive or returned an
        # arbitrary file otherwise. Raise explicitly instead.
        raise OSError(
            f"fatal: Archive {db} expected to contain one file, contains: {len(unpacked)}"
        )
    return unpacked[0]
def test_download_timeout_retry_loop(mocker, max_retries: int):
    """Check that download attempts are repeated with sleep() on error."""

    def throttled_attempt(*args):
        raise download.TooManyRequests

    mocker.patch.object(download, "sleep")
    mocker.patch.object(download, "_do_download_attempt", throttled_attempt)
    mocker.spy(download, "_do_download_attempt")

    with pytest.raises(download.TooManyRequests):
        download.download(urls="example", max_retries=max_retries)

    # One attempt and one back-off sleep per retry.
    assert download._do_download_attempt.call_count == max_retries
    assert download.sleep.call_count == max_retries

    # The final sleep uses exponential back-off from the initial wait time.
    starting_wait_time = 10  # The initial wait time in seconds.
    download.sleep.assert_called_with(starting_wait_time * 1.5 ** (max_retries - 1))
def download_cBench_runtime_data() -> bool:
    """Download and unpack the cBench runtime dataset.

    :return: True if the dataset was downloaded and unpacked, False if it
        was already present.
    """
    if _CBENCH_DATA.is_dir():
        return False
    tar_contents = io.BytesIO(download(_CBENCH_DATA_URL, sha256=_CBENCH_DATA_SHA256))
    with tarfile.open(fileobj=tar_contents, mode="r:bz2") as tar:
        # exist_ok avoids a FileExistsError when the parent directory was
        # already created (e.g. by a sibling dataset install).
        _CBENCH_DATA.parent.mkdir(parents=True, exist_ok=True)
        tar.extractall(_CBENCH_DATA.parent)
    assert _CBENCH_DATA.is_dir()
    return True
def install(self):
    """Install the OpenCL headers required by this dataset.

    Uses double-checked locking: a cheap cached-flag/marker-file check
    first, then the same marker check again inside thread- and
    process-level locks before performing the downloads.
    """
    super().install()

    # Refresh the cached flag from the on-disk marker file.
    if not self._opencl_installed:
        self._opencl_installed = self._opencl_headers_installed_marker.is_file()

    if self._opencl_installed:
        return

    with _CLGEN_INSTALL_LOCK, InterProcessLock(self._tar_lockfile):
        # Repeat install check now that we are in the locked region.
        if self._opencl_headers_installed_marker.is_file():
            return

        # Download the libclc headers. Remove any partial prior extraction
        # before unpacking.
        shutil.rmtree(self.libclc_dir, ignore_errors=True)
        logger.info("Downloading OpenCL headers ...")
        tar_data = io.BytesIO(
            download(
                "https://dl.fbaipublicfiles.com/compiler_gym/libclc-v0.tar.bz2",
                sha256="f1c511f2ac12adf98dcc0fbfc4e09d0f755fa403c18f1fb1ffa5547e1fa1a499",
            )
        )
        with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc:
            arc.extractall(str(self.site_data_path / "libclc"))

        # Download the OpenCL header.
        with open(self.opencl_h_path, "wb") as f:
            f.write(
                download(
                    "https://github.com/ChrisCummins/clgen/raw/463c0adcd8abcf2432b24df0aca594b77a69e9d3/deeplearning/clgen/data/include/opencl.h",
                    sha256="f95b9f4c8b1d09114e491846d0d41425d24930ac167e024f45dab8071d19f3f7",
                )
            )

        # Marker file signals to future install() calls that we are done.
        self._opencl_headers_installed_marker.touch()
def download_and_unpack_archive(url: str, sha256: Optional[str] = None) -> Dataset:
    """Download a tar.bz2 dataset archive and unpack it.

    :param url: The URL of the archive to download.
    :param sha256: The expected sha256 checksum of the archive, if known.
    :return: A Dataset built from the metadata JSON file in the archive.
    :raises OSError: If the archive contains no metadata JSON file.
    """
    # Snapshot the JSON files before extraction so the newly-added metadata
    # file can be identified afterwards by set difference.
    json_files_before = {
        f
        for f in env.inactive_datasets_site_path.iterdir()
        if f.is_file() and f.name.endswith(".json")
    }
    tar_data = io.BytesIO(download(url, sha256))
    with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc:
        arc.extractall(str(env.inactive_datasets_site_path))
    json_files_after = {
        f
        for f in env.inactive_datasets_site_path.iterdir()
        if f.is_file() and f.name.endswith(".json")
    }
    new_json = json_files_after - json_files_before
    # Idiomatic emptiness test rather than len().
    if not new_json:
        raise OSError(f"Downloaded dataset {url} contains no metadata JSON file")
    # Take one element without materializing a throwaway list.
    return Dataset.from_json_file(next(iter(new_json)))
def _download_llvm_files(destination: Path) -> Path:
    """Download and unpack the LLVM data pack."""
    logger.warning(
        "Installing the CompilerGym LLVM environment runtime. This may take a few moments ..."
    )

    # Tidy up an incomplete unpack.
    shutil.rmtree(destination, ignore_errors=True)

    archive = io.BytesIO(download(_LLVM_URL, sha256=_LLVM_SHA256))
    destination.parent.mkdir(parents=True, exist_ok=True)
    with tarfile.open(fileobj=archive, mode="r:bz2") as unpacker:
        unpacker.extractall(destination)

    # Sanity-check the unpacked layout before handing it back.
    assert destination.is_dir()
    assert (destination / "LICENSE").is_file()
    return destination
def download_cBench_runtime_data() -> bool:
    """Download and unpack the cBench runtime dataset."""
    cbench_data = site_data_path("llvm-v0/cbench-v1-runtime-data/runtime_data")
    marker = cbench_data / "unpacked"
    if marker.is_file():
        return False

    # Clean up any partially-extracted data directory.
    if cbench_data.is_dir():
        shutil.rmtree(cbench_data)

    url, sha256 = _CBENCH_RUNTOME_DATA
    archive = io.BytesIO(download(url, sha256))
    with tarfile.open(fileobj=archive, mode="r:bz2") as unpacker:
        cbench_data.parent.mkdir(parents=True, exist_ok=True)
        unpacker.extractall(cbench_data.parent)
    assert cbench_data.is_dir()

    # Create the marker file to indicate that the directory is unpacked
    # and ready to go.
    marker.touch()
    return True
def _benchmark_uris(self) -> List[str]:
    """Fetch or download the URI list.

    :return: The list of benchmark URIs from the manifest.
    :raises TypeError: If the manifest compression scheme is unknown.
    """
    if self._manifest_path.is_file():
        return self._read_manifest_file()

    # Thread-level and process-level locks to prevent races.
    with _TAR_MANIFEST_INSTALL_LOCK, InterProcessLock(self._manifest_lockfile):
        # Now that we have acquired the lock, repeat the check, since
        # another thread may have downloaded the manifest.
        if self._manifest_path.is_file():
            return self._read_manifest_file()

        # Determine how to decompress the manifest data.
        decompressor = {
            "bz2": lambda compressed_data: bz2.BZ2File(compressed_data),
            # gzip.GzipFile's first positional parameter is a *filename*;
            # an in-memory stream must be passed via the fileobj keyword.
            "gz": lambda compressed_data: gzip.GzipFile(fileobj=compressed_data),
        }.get(self.manifest_compression, None)
        if not decompressor:
            raise TypeError(
                f"Unknown manifest compression: {self.manifest_compression}"
            )

        # Decompress the manifest data.
        logger.debug("Downloading %s manifest", self.name)
        manifest_data = io.BytesIO(
            download(self.manifest_urls, self.manifest_sha256)
        )
        with decompressor(manifest_data) as f:
            manifest_data = f.read()

        # Although we have exclusive-execution locks, we still need to
        # create the manifest atomically to prevent calls to _benchmark_uris
        # racing to read an incompletely written manifest.
        with atomic_file_write(self._manifest_path, fileobj=True) as f:
            f.write(manifest_data)

        uris = self._read_manifest(manifest_data.decode("utf-8"))
        logger.debug("Downloaded %s manifest, %d entries", self.name, len(uris))
        return uris
def install(self) -> None:
    """Download and unpack the dataset tarball if not already installed."""
    super().install()

    if self.installed:
        return

    # Thread-level and process-level locks to prevent races.
    with _TAR_INSTALL_LOCK, InterProcessLock(self._tar_lockfile):
        # Another thread or process may have finished the install while we
        # waited for the locks, so check again before doing any work.
        if self.installed:
            return

        # Remove any partially-completed prior extraction.
        shutil.rmtree(self.site_data_path / "contents", ignore_errors=True)

        logger.warning(
            "Installing the %s dataset. This may take a few moments ...", self.name
        )

        archive = io.BytesIO(download(self.tar_urls, self.tar_sha256))
        logger.info("Unpacking %s dataset to %s", self.name, self.site_data_path)
        with tarfile.open(
            fileobj=archive, mode=f"r:{self.tar_compression}"
        ) as unpacker:
            unpacker.extractall(str(self.site_data_path / "contents"))

        # We're done. The last thing we do is create the marker file to
        # signal to any other install() invocations that the dataset is
        # ready.
        self._tar_extracted_marker.touch()

    if self.strip_prefix and not self.dataset_root.is_dir():
        raise FileNotFoundError(
            f"Directory prefix '{self.strip_prefix}' not found in dataset '{self.name}'"
        )
def test_download_no_urls():
    """An empty URL list is rejected with a ValueError."""
    with pytest.raises(ValueError, match="No URLs to download"):
        download.download(urls=[])