def __init__(
    self, data, transform: Callable, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0
):
    """
    Initialize the dataset and eagerly cache the transformed items.

    Args:
        data (Iterable): input data to load and transform to generate dataset for model.
            NOTE(review): the caching loop indexes ``data[i]``, so it must be a Sequence,
            not just an Iterable — confirm against callers.
        transform: transforms to execute operations on input data.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_workers: the number of worker threads to use.
            If 0 a single thread will be used. Default is 0.
    """
    if not isinstance(transform, Compose):
        transform = Compose(transform)
    super().__init__(data, transform)
    self.cache_num = min(cache_num, int(len(self) * cache_rate), len(self))
    # Always define the cache list (empty when cache_num == 0) so that later
    # attribute access never raises AttributeError on a zero-cache instance.
    self._cache = [None] * self.cache_num
    if self.cache_num > 0:
        if num_workers > 0:
            # shared progress counter, guarded by a lock because the thread
            # pool workers update it concurrently
            self._item_processed = 0
            self._thread_lock = threading.Lock()
            with ThreadPool(num_workers) as p:
                p.map(
                    self._load_cache_item_thread,
                    [(i, data[i], transform.transforms) for i in range(self.cache_num)],
                )
        else:
            # single-threaded fallback: fill the cache in order
            for i in range(self.cache_num):
                self._cache[i] = self._load_cache_item(data[i], transform.transforms)
                progress_bar(i + 1, self.cache_num, "Load and cache transformed data: ")
def _load_cache_item_thread(self, args) -> None:
    """Worker entry point for the caching thread pool.

    ``args`` is a ``(index, item, transforms)`` tuple; the transformed item
    is stored at ``index`` in the cache and the shared progress counter is
    advanced under the lock before the progress bar is refreshed.
    """
    idx, item, xforms = args
    self._cache[idx] = self._load_cache_item(item, xforms)
    with self._thread_lock:
        self._item_processed += 1
        progress_bar(self._item_processed, self.cache_num, "Load and cache transformed data: ")
def _load_cache_item_thread(self, args: Tuple[int, Any, Sequence[Callable]]) -> None:
    """
    Args:
        args: tuple with contents (i, item, transforms).
            i: the index to load the cached item to.
            item: input item to load and transform to generate dataset for model.
            transforms: transforms to execute operations on input item.
    """
    position, data_item, transform_chain = args
    result = self._load_cache_item(data_item, transform_chain)
    self._cache[position] = result
    # the counter and progress bar are shared across pool workers, so both
    # the increment and the read happen inside the lock
    with self._thread_lock:
        self._item_processed += 1
        progress_bar(self._item_processed, self.cache_num, "Load and cache transformed data: ")
def _process_hook(blocknum, blocksize, totalsize):
    # Report-hook for urllib's urlretrieve: called after each block is fetched
    # with (block count so far, block size, total size); renders a progress bar
    # labeled with the target file name.
    # NOTE(review): `filepath` is a free variable captured from the enclosing
    # scope — this function is only meaningful as a closure inside a download
    # routine.
    progress_bar(blocknum * blocksize, totalsize, f"Downloading {filepath.split('/')[-1]}:")
def download_url(url: str, filepath: str, md5_value: Optional[str] = None) -> None:
    """
    Download file from specified URL link, support process bar and MD5 check.

    Args:
        url: source URL link to download file.
        filepath: target filepath to save the downloaded file.
        md5_value: expected MD5 value to validate the downloaded file.
            if None, skip MD5 validation.

    Raises:
        RuntimeError: When the MD5 validation of the ``filepath`` existing file fails.
        RuntimeError: When a network issue or denied permission prevents the
            file download from ``url`` to ``filepath``.
        URLError: See urllib.request.urlretrieve.
        HTTPError: See urllib.request.urlretrieve.
        ContentTooShortError: See urllib.request.urlretrieve.
        IOError: See urllib.request.urlretrieve.
        RuntimeError: When the MD5 validation of the ``url`` downloaded file fails.

    """
    # Short-circuit: if the target already exists and passes MD5, skip the download.
    if os.path.exists(filepath):
        if not check_md5(filepath, md5_value):
            raise RuntimeError(
                f"MD5 check of existing file failed: filepath={filepath}, expected MD5={md5_value}."
            )
        print(f"file {filepath} exists, skip downloading.")
        return

    if url.startswith("https://drive.google.com"):
        # Google Drive links need gdown (plain urlretrieve cannot follow them).
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        gdown.download(url, filepath, quiet=False)
        if not os.path.exists(filepath):
            raise RuntimeError(
                f"Download of file from {url} to {filepath} failed due to network issue or denied permission."
            )
    elif url.startswith("https://msd-for-monai.s3-us-west-2.amazonaws.com"):
        # Resumable chunked download: append to a ".part" file and continue
        # from where a previous attempt stopped via HTTP Range requests.
        block_size = 1024 * 1024
        tmp_file_path = filepath + ".part"
        first_byte = os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0
        file_size = -1
        try:
            file_size = int(urlopen(url).info().get("Content-Length", -1))
            progress_bar(index=first_byte, count=file_size)
            while first_byte < file_size:
                last_byte = first_byte + block_size if first_byte + block_size < file_size else file_size - 1
                req = Request(url)
                req.headers["Range"] = "bytes=%s-%s" % (first_byte, last_byte)
                data_chunk = urlopen(req, timeout=10).read()
                with open(tmp_file_path, "ab") as f:
                    f.write(data_chunk)
                progress_bar(index=last_byte, count=file_size)
                first_byte = last_byte + 1
        except IOError as e:
            # Best-effort: keep the partial ".part" file so a later call can resume.
            logging.debug("IO Error - %s" % e)
        finally:
            if file_size == os.path.getsize(tmp_file_path):
                # Full file received; validate (if requested) then atomically move into place.
                if md5_value and not check_md5(tmp_file_path, md5_value):
                    # RuntimeError (not bare Exception) to honor the documented Raises contract.
                    raise RuntimeError("Error validating the file against its MD5 hash")
                shutil.move(tmp_file_path, filepath)
            elif file_size == -1:
                raise RuntimeError("Error getting Content-Length from server: %s" % url)
    else:
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        def _process_hook(blocknum: int, blocksize: int, totalsize: int):
            # reporthook for urlretrieve: render download progress.
            progress_bar(blocknum * blocksize, totalsize, f"Downloading {filepath.split('/')[-1]}:")

        try:
            urlretrieve(url, filepath, reporthook=_process_hook)
            print(f"\ndownloaded file: {filepath}.")
        except (URLError, HTTPError, ContentTooShortError, IOError) as e:
            print(f"download failed from {url} to {filepath}.")
            raise e

    # Final validation applies to every download path above.
    if not check_md5(filepath, md5_value):
        raise RuntimeError(
            f"MD5 check of downloaded file failed: URL={url}, filepath={filepath}, expected MD5={md5_value}."
        )
# Switch to inference mode so layers like dropout/batchnorm behave deterministically.
net.eval()
metric_vals = []
# test our network using the validation dataset
with torch.no_grad():  # no gradients needed during evaluation
    for bimages, bsegs in val_loader:
        # move the batch to the same device as the network
        bimages = bimages.to(device)
        bsegs = bsegs.to(device)
        prediction = net(bimages)
        pred_metric = metric(prediction, bsegs)
        # .item() pulls the scalar metric back to the host as a Python float
        metric_vals.append(pred_metric.item())
    # record (step, mean metric over the validation set) for this epoch
    epoch_metrics.append((total_step, np.average(metric_vals)))
    progress_bar(epoch + 1, num_epochs, f"Validation Metric: {epoch_metrics[-1][1]:.3}")
#%%
#Graph the results
fig, ax = plt.subplots(1, 2, figsize=(20, 6))
# left panel: per-step training loss on a log scale
ax[0].semilogy(*zip(*step_losses))
ax[0].set_title("Step Loss")
# right panel: validation metric recorded per epoch (indexed by step)
ax[1].plot(*zip(*epoch_metrics))
ax[1].set_title("Per-Step Validation Results")
# render the figure into an in-memory PNG buffer rather than a file
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)