Ejemplo n.º 1
0
def load_data(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    random_crop=False,
    random_flip=True,
):
    """
    For a dataset, create a generator over (images, kwargs) pairs.

    Each images is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which map to a batched Tensor of their own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       label. If classes are not available and this is true, an
                       exception will be raised.
    :param deterministic: if True, yield results in a deterministic order.
    :param random_crop: if True, randomly crop the images for augmentation.
    :param random_flip: if True, randomly flip the images for augmentation.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)
    classes = None
    if class_cond:
        # Assume classes are the first part of the filename,
        # before an underscore.
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        classes = [sorted_classes[x] for x in class_names]
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
        random_crop=random_crop,
        random_flip=random_flip,
    )
    # The two code paths differed only in `shuffle`, so build the loader once
    # instead of duplicating the DataLoader construction.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not deterministic,
        num_workers=1,
        drop_last=True,
    )
    # Loop forever so callers can draw an unbounded stream of batches.
    while True:
        yield from loader
Ejemplo n.º 2
0
def test_create_local_intermediate_dirs():
    """Writing through bf.BlobFile should create any missing parent dirs."""
    payload = b"meow"
    with _get_temp_local_path() as path:
        parent = bf.dirname(path)
        with chdir(parent):
            # Cover absolute, parent-relative, dot-relative, and bare paths.
            targets = (
                bf.join(parent, "dirname", "file.name"),
                bf.join("..", bf.basename(parent), "file.name"),
                "./file.name",
                "file.name",
            )
            for target in targets:
                with bf.BlobFile(target, "wb") as handle:
                    handle.write(payload)
Ejemplo n.º 3
0
def test_cache_dir(ctx):
    """LocalBlobFile with cache_dir should serve repeat reads from the cache."""
    cache_dir = tempfile.mkdtemp()
    original = b"meow!"
    replacement = b"purr!"
    with ctx() as path:
        # Seed the remote path, then read it once to populate the cache.
        with bf.BlobFile(path, mode="wb") as f:
            f.write(original)
        with bf.LocalBlobFile(path, mode="rb", cache_dir=cache_dir) as f:
            assert f.read() == original
        # Cache layout: <cache_dir>/<md5-of-contents>/<basename>.
        digest = hashlib.md5(original).hexdigest()
        cached_path = bf.join(cache_dir, digest, bf.basename(path))
        with open(cached_path, "rb") as f:
            assert f.read() == original
        # alter the cached file to make sure we are not re-reading the remote file
        with open(cached_path, "wb") as f:
            f.write(replacement)
        with bf.LocalBlobFile(path, mode="rb", cache_dir=cache_dir) as f:
            assert f.read() == replacement
Ejemplo n.º 4
0
def main():
    """Download the procgen wheels for a version and upload them with twine."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", required=True)
    parser.add_argument("--for-real", action="store_true")
    args = parser.parse_args()

    with tempfile.TemporaryDirectory() as tmpdir:
        # Stage every matching wheel into a local temp dir first.
        pattern = f"gs://openai-procgen/builds/procgen-{args.version}-*.whl"
        for remote_path in bf.glob(pattern):
            print(remote_path)
            bf.copy(remote_path, bf.join(tmpdir, bf.basename(remote_path)))
        # Without --for-real, target the test PyPI index instead of the real one.
        if args.for_real:
            extra_options = []
        else:
            extra_options = ["--repository-url", "https://test.pypi.org/legacy/"]
        command = ["python", "-m", "twine", "upload"]
        command.extend(extra_options)
        command.extend(bf.glob(bf.join(tmpdir, "*.whl")))
        subprocess.run(command, check=True)
Ejemplo n.º 5
0
def test_basename():
    """bf.basename should extract the final component for local, gs, and azure paths."""
    # Mapping of input path -> expected basename (all keys are distinct).
    expectations = {
        "/": "",
        "a/": "",
        "a": "a",
        "a/b": "b",
        "": "",
        "gs://a": "",
        "gs://a/": "",
        "gs://a/b/": "",
        "gs://a/b": "b",
        "gs://a/b/c/test.filename": "test.filename",
        "https://a.blob.core.windows.net/b": "",
        "https://a.blob.core.windows.net/b/": "",
        "https://a.blob.core.windows.net/b/c/": "",
        "https://a.blob.core.windows.net/b/c": "c",
        "https://a.blob.core.windows.net/b/c/test.filename": "test.filename",
    }
    for path, expected in expectations.items():
        assert bf.basename(path) == expected