import blobfile as bf
from mpi4py import MPI
from torch.utils.data import DataLoader

# _list_image_files_recursively and ImageDataset are helpers defined elsewhere
# in this module.


def load_data(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    random_crop=False,
    random_flip=True,
):
    """
    For a dataset, create a generator over (images, kwargs) pairs.

    The images element of each pair is an NCHW float tensor, and the kwargs
    dict contains zero or more keys, each of which maps to a batched Tensor of
    its own. The kwargs dict can be used for class labels, in which case the
    key is "y" and the values are integer tensors of class labels.

    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       labels. If classes are not available and this is True,
                       an exception will be raised.
    :param deterministic: if True, yield results in a deterministic order.
    :param random_crop: if True, randomly crop the images for augmentation.
    :param random_flip: if True, randomly flip the images for augmentation.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)
    classes = None
    if class_cond:
        # Assume classes are the first part of the filename,
        # before an underscore.
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        classes = [sorted_classes[x] for x in class_names]
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
        random_crop=random_crop,
        random_flip=random_flip,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not deterministic,
        num_workers=1,
        drop_last=True,
    )
    while True:
        yield from loader
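
# A minimal usage sketch for load_data. The "./images" directory and the
# "<class>_<id>.png" filename convention are assumptions for illustration,
# not part of the module itself:
if __name__ == "__main__":
    data = load_data(
        data_dir="./images",
        batch_size=4,
        image_size=64,
        class_cond=True,
    )
    batch, cond = next(data)
    print(batch.shape)  # NCHW, e.g. torch.Size([4, 3, 64, 64])
    print(cond["y"])    # integer class labels, one per image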
import blobfile as bf

# _get_temp_local_path and chdir are helpers from the test module; hedged
# sketches of plausible implementations follow the test.


def test_create_local_intermediate_dirs():
    # Writing through bf.BlobFile should create missing intermediate
    # directories, and relative paths should resolve against the current
    # working directory.
    contents = b"meow"
    with _get_temp_local_path() as path:
        dirpath = bf.dirname(path)
        with chdir(dirpath):
            for filepath in [
                bf.join(dirpath, "dirname", "file.name"),
                bf.join("..", bf.basename(dirpath), "file.name"),
                "./file.name",
                "file.name",
            ]:
                with bf.BlobFile(filepath, "wb") as f:
                    f.write(contents)
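
# Minimal sketches of the two test helpers used above, assuming only the
# behavior the test implies; the real test suite may define them differently:
import contextlib
import os
import tempfile
import uuid


@contextlib.contextmanager
def _get_temp_local_path():
    # yield a path to a not-yet-created file inside a fresh temporary directory
    with tempfile.TemporaryDirectory() as tmpdir:
        yield os.path.join(tmpdir, uuid.uuid4().hex)


@contextlib.contextmanager
def chdir(path):
    # temporarily change the working directory, restoring it on exit
    original = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original)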
import hashlib
import tempfile

import blobfile as bf


def test_cache_dir(ctx):
    cache_dir = tempfile.mkdtemp()
    contents = b"meow!"
    alternative_contents = b"purr!"
    with ctx() as path:
        with bf.BlobFile(path, mode="wb") as f:
            f.write(contents)
        with bf.LocalBlobFile(path, mode="rb", cache_dir=cache_dir) as f:
            assert f.read() == contents
        # cached copies are stored under a directory named after the md5 of
        # the file contents
        content_hash = hashlib.md5(contents).hexdigest()
        cache_path = bf.join(cache_dir, content_hash, bf.basename(path))
        with open(cache_path, "rb") as f:
            assert f.read() == contents
        # alter the cached file to make sure we are not re-reading the remote file
        with open(cache_path, "wb") as f:
            f.write(alternative_contents)
        with bf.LocalBlobFile(path, mode="rb", cache_dir=cache_dir) as f:
            assert f.read() == alternative_contents
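
# A standalone illustration of the cache layout the test exercises: a cached
# copy lives at <cache_dir>/<md5 of contents>/<basename>. The helper name and
# the paths below are hypothetical:
import hashlib

import blobfile as bf


def expected_cache_path(cache_dir, contents, path):
    return bf.join(cache_dir, hashlib.md5(contents).hexdigest(), bf.basename(path))


print(expected_cache_path("/tmp/cache", b"meow!", "gs://bucket/dir/file.bin"))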
import argparse
import subprocess
import tempfile

import blobfile as bf


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", required=True)
    parser.add_argument("--for-real", action="store_true")
    args = parser.parse_args()

    with tempfile.TemporaryDirectory() as tmpdir:
        # download the wheels for this version into a scratch directory
        for filepath in bf.glob(f"gs://openai-procgen/builds/procgen-{args.version}-*.whl"):
            print(filepath)
            bf.copy(filepath, bf.join(tmpdir, bf.basename(filepath)))
        # upload to test PyPI unless --for-real is passed
        if args.for_real:
            options = []
        else:
            options = ["--repository-url", "https://test.pypi.org/legacy/"]
        subprocess.run(
            ["python", "-m", "twine", "upload", *options, *bf.glob(bf.join(tmpdir, "*.whl"))],
            check=True,
        )
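
# Hypothetical entry-point guard and invocations for the script above; the
# script name "upload.py" and the version number are assumptions:
if __name__ == "__main__":
    main()

# Dry run against test PyPI (the default):
#   python upload.py --version 1.0.0
# Real upload:
#   python upload.py --version 1.0.0 --for-real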
import blobfile as bf


def test_basename():
    testcases = [
        ("/", ""),
        ("a/", ""),
        ("a", "a"),
        ("a/b", "b"),
        ("", ""),
        ("gs://a", ""),
        ("gs://a/", ""),
        ("gs://a/b/", ""),
        ("gs://a/b", "b"),
        ("gs://a/b/c/test.filename", "test.filename"),
        ("https://a.blob.core.windows.net/b", ""),
        ("https://a.blob.core.windows.net/b/", ""),
        ("https://a.blob.core.windows.net/b/c/", ""),
        ("https://a.blob.core.windows.net/b/c", "c"),
        ("https://a.blob.core.windows.net/b/c/test.filename", "test.filename"),
    ]
    for input_, desired_output in testcases:
        actual_output = bf.basename(input_)
        assert desired_output == actual_output