def benchmark_iterate_hub_local_pytorch_setup(
    dataset_name, dataset_split, batch_size, prefetch_factor, num_workers=1
):
    """Build a PyTorch DataLoader over a torchvision dataset stored in local Hub format.

    The torchvision dataset is downloaded to ``./torch_data``. It is converted to a
    Hub dataset via ``HubAdapter`` and ``Dataset.from_pytorch``, then stored under
    ``./hub_data/tfds``. Finally it is re-opened read-only with caching disabled and
    wrapped in a ``torch.utils.data.DataLoader``.

    Args:
        dataset_name: Name of a class in ``torchvision.datasets`` (looked up
            with ``getattr``), e.g. ``"MNIST"``.
        dataset_split: Split identifier. Any value containing ``"test"``
            selects the test split; everything else selects the training split.
        batch_size: Passed straight to the DataLoader.
        prefetch_factor: Passed straight to the DataLoader.
        num_workers: Number of DataLoader worker processes (default 1).

    Returns:
        A 1-tuple ``(loader,)``. Presumably the benchmark harness unpacks
        setup results as a tuple — confirm against the caller.
    """
    trans = transforms.Compose([transforms.ToTensor()])
    data_path = os.path.join(".", "torch_data")
    dset_type = getattr(torchvision.datasets, dataset_name)
    # NOTE(review): the "tfds" directory name looks copied from a TFDS
    # benchmark. If another setup stores to the same path, runs will
    # overwrite each other's data — confirm this is intended.
    path = os.path.join(".", "hub_data", "tfds")

    # torchvision's ``train`` argument is a bool. The previous code passed
    # ``None`` for non-test splits. ``None`` is falsy, so requesting the
    # training split silently loaded the test set instead.
    is_train = "test" not in dataset_split
    dset = dset_type(
        data_path,
        transform=trans,
        train=is_train,
        download=True,
    )

    # Round-trip through Hub storage so the loader reads from local disk.
    # Caches are disabled so the benchmark measures actual storage reads.
    Dataset.from_pytorch(HubAdapter(dset)).store(path)
    dset = Dataset(path, cache=False, storage_cache=False, mode="r")

    loader = torch.utils.data.DataLoader(
        dset.to_pytorch(),
        batch_size=batch_size,
        prefetch_factor=prefetch_factor,
        num_workers=num_workers,
    )
    return (loader,)
def time_iter_hub_local_pytorch(
    dataset_info,
    batch_size=BATCH_SIZE,
    prefetch_factor=PREFETCH_SIZE,
    num_workers=NUM_WORKERS,
    process=None,
):
    """Time one full pass over a locally stored Hub dataset via ``.to_pytorch()``.

    The source dataset comes from ``prepare_torch_dataset(dataset_info)``. It is
    written to ``ROOT/Hub_data/torch`` through ``HubAdapter`` and re-opened
    read-only with caching disabled. Only the iteration loop runs inside the
    ``Timer`` context; storing and re-opening the dataset happen before timing
    starts.

    Args:
        dataset_info: Forwarded unchanged to ``prepare_torch_dataset``.
        batch_size: DataLoader batch size.
        prefetch_factor: DataLoader prefetch factor.
        num_workers: Number of DataLoader worker processes.
        process: Optional callable. When given, it is called as
            ``process(image, label)`` for every batch the loader yields.
    """
    source_dataset = prepare_torch_dataset(dataset_info)
    storage_path = os.path.join(ROOT, "Hub_data", "torch")

    # Persist to local Hub storage, then read it back with caches off so the
    # timed loop reflects actual storage reads.
    Dataset.from_pytorch(HubAdapter(source_dataset)).store(storage_path)
    hub_dataset = Dataset(storage_path, cache=False, storage_cache=False, mode="r")

    loader_options = {
        "batch_size": batch_size,
        "prefetch_factor": prefetch_factor,
        "num_workers": num_workers,
    }
    batches = torch.utils.data.DataLoader(hub_dataset.to_pytorch(), **loader_options)

    with Timer("Hub (local) `.to_pytorch()`"):
        for image, label in batches:
            if process is None:
                continue
            process(image, label)