def main():
    """Demo: create a small Hub dataset, expose it to PyTorch, and iterate it.

    Writes a 640-sample dataset with two (512, 512) float tensors per sample,
    wraps it in a ``torch.utils.data.DataLoader``, and prints each batch.
    """
    # Build the dataset in write mode.
    dataset = Dataset(
        "davitb/pytorch_example",
        shape=(640,),
        mode="w",
        schema={
            "image": schema.Tensor((512, 512), dtype="float"),
            "label": schema.Tensor((512, 512), dtype="float"),
        },
    )

    # Optional fill step (left disabled, as in the original benchmark):
    # dataset["image"][:] = 1
    # dataset["label"][:] = 2

    # Convert to a PyTorch-compatible dataset and wrap it in a DataLoader.
    torch_dataset = dataset.to_pytorch()
    loader = torch.utils.data.DataLoader(
        torch_dataset,
        batch_size=8,
        num_workers=2,
    )

    # Drain the loader, printing every batch.
    for batch in loader:
        print(batch["image"], batch["label"])
def benchmark_iterate_hub_local_pytorch_setup(
    dataset_name, dataset_split, batch_size, prefetch_factor, num_workers=1
):
    """Prepare a DataLoader over a torchvision dataset converted to a local Hub store.

    Downloads the named torchvision dataset, converts it to a Hub dataset stored
    under ``./hub_data/tfds``, reopens it read-only with caching disabled, and
    wraps it in a ``torch.utils.data.DataLoader``.

    Args:
        dataset_name: Attribute name of a class in ``torchvision.datasets``.
        dataset_split: Split selector; any string containing "test" selects the
            test split, anything else selects the train split.
        batch_size: DataLoader batch size.
        prefetch_factor: Batches prefetched per worker (must be > 0).
        num_workers: DataLoader worker processes.

    Returns:
        A 1-tuple ``(loader,)`` holding the configured DataLoader.
    """
    trans = transforms.Compose([transforms.ToTensor()])
    data_path = os.path.join(".", "torch_data")
    dset_type = getattr(torchvision.datasets, dataset_name)
    path = os.path.join(".", "hub_data", "tfds")
    # BUG FIX: the original passed train=None for non-test splits; torchvision
    # datasets treat a falsy `train` as the test set, so requesting the train
    # split silently loaded test data. Pass an explicit boolean instead.
    dset = dset_type(
        data_path,
        transform=trans,
        train="test" not in dataset_split,
        download=True,
    )
    # Convert the torch dataset to Hub format and persist it locally.
    Dataset.from_pytorch(HubAdapter(dset)).store(path)
    # Reopen read-only with caching disabled so the benchmark measures raw reads.
    dset = Dataset(path, cache=False, storage_cache=False, mode="r")
    loader = torch.utils.data.DataLoader(
        dset.to_pytorch(),
        batch_size=batch_size,
        prefetch_factor=prefetch_factor,
        num_workers=num_workers,
    )
    return (loader,)
def benchmark():
    """Benchmark full iteration over a remote Hub dataset via ``.to_pytorch()``.

    Opens ``s3://snark-test/superficial_dataset`` read-only, wraps it in a
    DataLoader (batch_size=32, 16 workers), and drains it under ``tqdm`` so the
    progress bar reports throughput. Returns nothing; the timing output is the
    result.

    Note: the dataset can be (re)created by writing 5000 random uint8
    (256, 256, 3) "image" tensors to the same S3 path in write mode.
    """
    # Dead setup code and its unused locals (`schema`, `arr`) removed — they
    # were only referenced by commented-out dataset-creation code.
    ds = Dataset("s3://snark-test/superficial_dataset")
    tds = ds.to_pytorch()
    dl = torch.utils.data.DataLoader(tds, batch_size=32, num_workers=16)
    # Drain the loader; tqdm provides the measurement.
    for _ in tqdm(dl):
        pass
def benchmark_iterate_hub_pytorch_setup(dataset_name, batch_size, prefetch_factor, num_workers=1):
    """Open a remote Hub dataset read-only and wrap it in a PyTorch DataLoader.

    Caching is disabled on the Hub side so iteration measures raw remote reads.

    Returns:
        A 1-tuple ``(loader,)`` holding the configured DataLoader.
    """
    hub_dataset = Dataset(dataset_name, cache=False, storage_cache=False, mode="r")
    torch_dataset = hub_dataset.to_pytorch()
    loader = torch.utils.data.DataLoader(
        torch_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
    )
    return (loader,)
def time_iter_hub_wasabi_pytorch(
    dataset_info,
    batch_size=BATCH_SIZE,
    prefetch_factor=PREFETCH_SIZE,
    num_workers=NUM_WORKERS,
    process=None,
):
    """Time a full pass over a Wasabi-hosted Hub dataset through ``.to_pytorch()``.

    Args:
        dataset_info: Mapping with a ``"hub_name"`` key identifying the dataset.
        batch_size: DataLoader batch size.
        prefetch_factor: Batches prefetched per worker.
        num_workers: DataLoader worker processes.
        process: Optional callable invoked as ``process(image, label)`` per batch.
    """
    # Read-only open with caching off so the timing reflects remote reads.
    remote_dataset = Dataset(dataset_info["hub_name"], cache=False, storage_cache=False, mode="r")
    loader = torch.utils.data.DataLoader(
        remote_dataset.to_pytorch(),
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
    )
    with Timer("Hub (remote - Wasabi) `.to_pytorch()`"):
        for image, label in loader:
            if process is None:
                continue
            process(image, label)
def time_iter_pytorch(dataset_name="activeloop/mnist", batch_size=1, prefetch_factor=2, process=None):
    """Time a full pass over a remote Hub dataset through ``.to_pytorch()``.

    Args:
        dataset_name: Hub dataset identifier to open read-only.
        batch_size: DataLoader batch size.
        prefetch_factor: Batches prefetched per worker. BUG FIX: the previous
            default of 0 is rejected by ``torch.utils.data.DataLoader`` when
            ``num_workers > 0`` (prefetch_factor must be positive), so calling
            with defaults always failed; 2 is torch's own default.
        process: Optional callable invoked as ``process(idx, image, label)``
            per batch.
    """
    dset = Dataset(dataset_name, cache=False, storage_cache=False, mode="r")
    loader = torch.utils.data.DataLoader(
        dset.to_pytorch(),
        batch_size=batch_size,
        prefetch_factor=prefetch_factor,
        num_workers=1,
    )
    with Timer(
        f"{dataset_name} PyTorch prefetch {prefetch_factor:03} in batches of {batch_size:03}"
    ):
        for idx, (image, label) in enumerate(loader):
            if process is not None:
                process(idx, image, label)
def time_iter_hub_local_pytorch(
    dataset_info,
    batch_size=BATCH_SIZE,
    prefetch_factor=PREFETCH_SIZE,
    num_workers=NUM_WORKERS,
    process=None,
):
    """Time a full pass over a locally stored Hub copy of a torch dataset.

    Prepares the torch dataset described by ``dataset_info``, converts it to a
    Hub dataset under ``<ROOT>/Hub_data/torch``, reopens it read-only, and
    iterates it through a DataLoader under a timer.

    Args:
        dataset_info: Descriptor understood by ``prepare_torch_dataset``.
        batch_size: DataLoader batch size.
        prefetch_factor: Batches prefetched per worker.
        num_workers: DataLoader worker processes.
        process: Optional callable invoked as ``process(image, label)`` per batch.
    """
    source_dataset = prepare_torch_dataset(dataset_info)
    store_path = os.path.join(ROOT, "Hub_data", "torch")
    # Persist a Hub-format copy, then reopen it with caching disabled.
    Dataset.from_pytorch(HubAdapter(source_dataset)).store(store_path)
    local_dataset = Dataset(store_path, cache=False, storage_cache=False, mode="r")
    loader = torch.utils.data.DataLoader(
        local_dataset.to_pytorch(),
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
    )
    with Timer("Hub (local) `.to_pytorch()`"):
        for image, label in loader:
            if process is None:
                continue
            process(image, label)