Example #1
0
 def __init__(
     self, data, transform: Callable, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0
 ):
     """
     Args:
         data (Iterable): input data to load and transform to generate dataset for model.
         transform: transforms to execute operations on input data.
         cache_num: number of items to be cached. Default is `sys.maxsize`.
             will take the minimum of (cache_num, data_length x cache_rate, data_length).
         cache_rate: percentage of cached data in total, default is 1.0 (cache all).
             will take the minimum of (cache_num, data_length x cache_rate, data_length).
         num_workers: the number of worker threads to use.
             If 0 a single thread will be used. Default is 0.
     """
     # normalize into a Compose so `.transforms` is available for the cache loaders
     if not isinstance(transform, Compose):
         transform = Compose(transform)
     super().__init__(data, transform)
     # effective cache size: bounded by the explicit request, the rate, and the data length
     self.cache_num = min(cache_num, int(len(self) * cache_rate), len(self))
     if self.cache_num <= 0:
         return
     self._cache = [None] * self.cache_num
     if num_workers > 0:
         # threaded warm-up: workers fill self._cache in place and report
         # shared progress under self._thread_lock
         self._item_processed = 0
         self._thread_lock = threading.Lock()
         tasks = [(idx, data[idx], transform.transforms) for idx in range(self.cache_num)]
         with ThreadPool(num_workers) as pool:
             pool.map(self._load_cache_item_thread, tasks)
     else:
         # single-threaded warm-up with inline progress reporting
         for idx in range(self.cache_num):
             self._cache[idx] = self._load_cache_item(data[idx], transform.transforms)
             progress_bar(idx + 1, self.cache_num, "Load and cache transformed data: ")
Example #2
0
 def _load_cache_item_thread(self, args: Tuple[int, Any, Sequence[Callable]]) -> None:
     """
     Load, transform and cache a single item; used as a ThreadPool worker.

     Args:
         args: tuple with contents (i, item, transforms).
             i: the index to load the cached item to.
             item: input item to load and transform to generate dataset for model.
             transforms: transforms to execute operations on input item.
     """
     i, item, transforms = args
     self._cache[i] = self._load_cache_item(item, transforms)
     # the lock serializes the shared progress counter across worker threads
     with self._thread_lock:
         self._item_processed += 1
         progress_bar(self._item_processed, self.cache_num,
                      "Load and cache transformed data: ")
Example #3
0
 def _load_cache_item_thread(self, args: Tuple[int, Any, Sequence[Callable]]) -> None:
     """
     Worker-thread entry point: transform one item, store it in the cache,
     then update and display the shared progress counter.

     Args:
         args: tuple with contents (i, item, transforms).
             i: the index to load the cached item to.
             item: input item to load and transform to generate dataset for model.
             transforms: transforms to execute operations on input item.
     """
     index, raw_item, item_transforms = args
     cached = self._load_cache_item(raw_item, item_transforms)
     self._cache[index] = cached
     # lock guards the counter shared by all ThreadPool workers
     with self._thread_lock:
         self._item_processed += 1
         done = self._item_processed
         progress_bar(done, self.cache_num, "Load and cache transformed data: ")
Example #4
0
 def _process_hook(blocknum: int, blocksize: int, totalsize: int):
     # urlretrieve-style reporthook: `blocknum` blocks of `blocksize` bytes
     # transferred so far, out of `totalsize` total bytes.
     # NOTE(review): `filepath` and `progress_bar` come from the enclosing scope.
     progress_bar(blocknum * blocksize, totalsize,
                  f"Downloading {filepath.split('/')[-1]}:")
Example #5
0
def _download_resumable(url: str, filepath: str, md5_value: Optional[str]) -> None:
    """
    Download ``url`` to ``filepath`` in 1 MB HTTP Range chunks, resuming from an
    existing ``filepath + ".part"`` file if present.

    Raises:
        RuntimeError: When the server's Content-Length cannot be obtained, or
            the completed download fails MD5 validation against ``md5_value``.
    """
    block_size = 1024 * 1024
    tmp_file_path = filepath + ".part"
    # resume from however many bytes a previous attempt already wrote
    first_byte = os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0
    file_size = -1

    try:
        # close the probe response promptly instead of leaking the connection
        with urlopen(url) as response:
            file_size = int(response.info().get("Content-Length", -1))
        progress_bar(index=first_byte, count=file_size)

        while first_byte < file_size:
            # Range headers are inclusive, so the final chunk ends at file_size - 1
            last_byte = first_byte + block_size if first_byte + block_size < file_size else file_size - 1

            req = Request(url)
            req.headers["Range"] = "bytes=%s-%s" % (first_byte, last_byte)
            with urlopen(req, timeout=10) as chunk_response:
                data_chunk = chunk_response.read()
            with open(tmp_file_path, "ab") as f:
                f.write(data_chunk)
            progress_bar(index=last_byte, count=file_size)
            first_byte = last_byte + 1
    except IOError as e:
        # best-effort: a partial .part file is kept so a later call can resume
        logging.debug("IO Error - %s" % e)
    finally:
        # guard getsize: the .part file may not exist if the download failed
        # before any byte was written (e.g. the Content-Length probe itself failed)
        if os.path.exists(tmp_file_path) and file_size == os.path.getsize(tmp_file_path):
            if md5_value and not check_md5(tmp_file_path, md5_value):
                raise RuntimeError(
                    "Error validating the file against its MD5 hash")
            shutil.move(tmp_file_path, filepath)
        elif file_size == -1:
            raise RuntimeError(
                "Error getting Content-Length from server: %s" % url)


def download_url(url: str,
                 filepath: str,
                 md5_value: Optional[str] = None) -> None:
    """
    Download file from specified URL link, support process bar and MD5 check.

    Args:
        url: source URL link to download file.
        filepath: target filepath to save the downloaded file.
        md5_value: expected MD5 value to validate the downloaded file.
            if None, skip MD5 validation.

    Raises:
        RuntimeError: When the MD5 validation of the ``filepath`` existing file fails.
        RuntimeError: When a network issue or denied permission prevents the
            file download from ``url`` to ``filepath``.
        RuntimeError: When a resumable download cannot obtain Content-Length
            from the server, or its MD5 validation fails.
        URLError: See urllib.request.urlretrieve.
        HTTPError: See urllib.request.urlretrieve.
        ContentTooShortError: See urllib.request.urlretrieve.
        IOError: See urllib.request.urlretrieve.
        RuntimeError: When the MD5 validation of the ``url`` downloaded file fails.

    """
    if os.path.exists(filepath):
        if not check_md5(filepath, md5_value):
            raise RuntimeError(
                f"MD5 check of existing file failed: filepath={filepath}, expected MD5={md5_value}."
            )
        print(f"file {filepath} exists, skip downloading.")
        return

    # ensure the target directory exists for every download branch; dirname is
    # "" when filepath is a bare filename, and makedirs("") would raise
    target_dir = os.path.dirname(filepath)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    if url.startswith("https://drive.google.com"):
        gdown.download(url, filepath, quiet=False)
        if not os.path.exists(filepath):
            raise RuntimeError(
                f"Download of file from {url} to {filepath} failed due to network issue or denied permission."
            )
    elif url.startswith("https://msd-for-monai.s3-us-west-2.amazonaws.com"):
        # large MSD archives: chunked, resumable transfer
        _download_resumable(url, filepath, md5_value)
    else:
        def _process_hook(blocknum: int, blocksize: int, totalsize: int):
            # urlretrieve reporthook: blocknum * blocksize bytes transferred so far
            progress_bar(blocknum * blocksize, totalsize,
                         f"Downloading {filepath.split('/')[-1]}:")

        try:
            urlretrieve(url, filepath, reporthook=_process_hook)
            print(f"\ndownloaded file: {filepath}.")
        except (URLError, HTTPError, ContentTooShortError, IOError) as e:
            print(f"download failed from {url} to {filepath}.")
            raise e

    if not check_md5(filepath, md5_value):
        raise RuntimeError(
            f"MD5 check of downloaded file failed: URL={url}, filepath={filepath}, expected MD5={md5_value}."
        )
Example #6
0
    # switch to inference mode (affects dropout / batch-norm layers)
    net.eval()
    metric_vals = []

    # test our network using the validation dataset
    with torch.no_grad():
        for bimages, bsegs in val_loader:
            # move the image/segmentation batch onto the compute device
            bimages = bimages.to(device)
            bsegs = bsegs.to(device)

            prediction = net(bimages)
            # per-batch metric; .item() assumes `metric` returns a scalar tensor
            pred_metric = metric(prediction, bsegs)
            metric_vals.append(pred_metric.item())

    # record (step, mean validation metric) for this evaluation pass
    epoch_metrics.append((total_step, np.average(metric_vals)))

    # NOTE(review): `epoch`, `num_epochs`, `total_step` come from the enclosing
    # training loop, which is outside this view
    progress_bar(epoch + 1, num_epochs, f"Validation Metric: {epoch_metrics[-1][1]:.3}")

#%%

# Graph the training and validation results side by side

fig, ax = plt.subplots(1, 2, figsize=(20, 6))

# left panel: per-step training loss on a log y-axis
# (step_losses holds (step, loss) pairs; zip(*...) splits them into x/y series)
ax[0].semilogy(*zip(*step_losses))
ax[0].set_title("Step Loss")

# right panel: validation metric recorded at each logged step
ax[1].plot(*zip(*epoch_metrics))
ax[1].set_title("Per-Step Validation Results")
# render the figure into an in-memory PNG buffer, rewound for downstream reads
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)