Exemplo n.º 1
0
def main(_):
    """Prepare the rendered-image ShapeNet core dataset and optionally show it.

    Reads the ``synset``, ``renderer``, ``resolution``, ``seed``,
    ``from_cache`` and ``vis`` flags from ``flags.FLAGS``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If the ``synset`` flag is neither a key of ``names`` nor a
            key of ``ids`` as returned by ``core.load_synset_ids()``.
    """
    import tensorflow as tf

    from shape_tfds.shape.shapenet import core

    tf.compat.v1.enable_eager_execution()
    FLAGS = flags.FLAGS
    ids, names = core.load_synset_ids()
    name = FLAGS.synset

    # The flag may already be a synset id (a key of `names`); otherwise it is
    # translated through `ids`. Unknown values are rejected explicitly so the
    # caller sees the intended ValueError instead of a bare KeyError.
    if name in names:
        synset_id = name
    elif name in ids:
        synset_id = ids[name]
    else:
        raise ValueError("Invalid synset_id %s" % name)
    if synset_id not in names:
        raise ValueError("Invalid synset_id %s" % synset_id)

    # The resolution is passed as a 2-tuple repeating the `resolution` flag.
    renderer = core.Renderer(FLAGS.renderer, resolution=(FLAGS.resolution,) * 2)

    builder = core.ShapenetCore(
        config=core.RenderingConfig(
            synset_id=synset_id, renderer=renderer, seed=FLAGS.seed
        ),
        from_cache=FLAGS.from_cache,
    )
    builder.download_and_prepare()

    if FLAGS.vis:
        # Imported lazily so matplotlib is only required when visualising.
        import matplotlib.pyplot as plt

        # Show the "image" entry of every training example, one at a time.
        for example in builder.as_dataset(split="train"):
            plt.imshow(example["image"].numpy())
            plt.show()
Exemplo n.º 2
0
def main(_):
    """Prepare the frustum-voxel ShapeNet core dataset and optionally plot it.

    Reads the ``synset``, ``resolution``, ``seed`` and ``vis`` flags from
    ``flags.FLAGS``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If the ``synset`` flag is neither a key of ``names`` nor a
            key of ``ids`` as returned by ``core.load_synset_ids()``.
    """
    tf.compat.v1.enable_eager_execution()
    FLAGS = flags.FLAGS
    ids, names = core.load_synset_ids()
    name = FLAGS.synset

    # The flag may already be a synset id (a key of `names`); otherwise it is
    # translated through `ids`. Unknown values are rejected explicitly so the
    # caller sees the intended ValueError instead of a bare KeyError.
    if name in names:
        synset_id = name
    elif name in ids:
        synset_id = ids[name]
    else:
        raise ValueError("Invalid synset_id %s" % name)
    if synset_id not in names:
        raise ValueError("Invalid synset_id %s" % synset_id)

    builder = core.ShapenetCore(config=core.FrustumVoxelConfig(
        synset_id=synset_id, resolution=FLAGS.resolution, seed=FLAGS.seed))
    builder.download_and_prepare()

    if FLAGS.vis:
        # Imported lazily so matplotlib is only required when visualising.
        import matplotlib.pyplot as plt
        # Kept for older matplotlib releases, where importing Axes3D is what
        # registers the "3d" projection.
        from mpl_toolkits.mplot3d import Axes3D  # pylint: disable=unused-import

        # Plot the "voxels" entry of every training example, one at a time.
        for example in builder.as_dataset(split="train"):
            # plt.gca(projection=...) is deprecated (keyword arguments to
            # gca() are no longer accepted by recent matplotlib); plt.axes
            # creates the 3D axes explicitly instead.
            ax = plt.axes(projection="3d")
            ax.voxels(example["voxels"].numpy())
            plt.show()
Exemplo n.º 3
0
def main(_):
    """Prepare image and frustum-voxel datasets and optionally overlay them.

    Reads the ``synset``, ``seed``, ``image_res``, ``vox_res`` and ``vis``
    flags from ``flags.FLAGS``. When ``vis`` is set, each rendered image is
    shown with the pixels outside the resized voxel mask set to zero.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If the ``synset`` flag is neither a key of ``names`` nor a
            key of ``ids`` as returned by ``core.load_synset_ids()``.
    """
    tf.compat.v1.enable_eager_execution()
    FLAGS = flags.FLAGS
    ids, names = core.load_synset_ids()
    name = FLAGS.synset
    seed = FLAGS.seed

    # The flag may already be a synset id (a key of `names`); otherwise it is
    # translated through `ids`. Unknown values are rejected explicitly so the
    # caller sees the intended ValueError instead of a bare KeyError.
    if name in names:
        synset_id = name
    elif name in ids:
        synset_id = ids[name]
    else:
        raise ValueError("Invalid synset_id %s" % name)
    if synset_id not in names:
        raise ValueError("Invalid synset_id %s" % synset_id)

    # Both configs share the synset and seed so they can be zipped below.
    # NOTE(review): this presumably makes the two datasets use matching
    # views per example -- confirm against the config implementations.
    configs = dict(
        image=core.TrimeshRenderingConfig(synset_id=synset_id,
                                          resolution=(FLAGS.image_res, ) * 2,
                                          seed=seed),
        voxels=core.FrustumVoxelConfig(synset_id=synset_id,
                                       resolution=FLAGS.vox_res,
                                       seed=seed),
    )
    builders = {
        k: core.ShapenetCore(config=config)
        for k, config in configs.items()
    }
    for b in builders.values():
        b.download_and_prepare()

    if FLAGS.vis:
        # Imported lazily so matplotlib is only required when visualising.
        import matplotlib.pyplot as plt

        def vis(example):
            """Show the image with everything outside the voxel mask zeroed."""
            image = example["image"].numpy()
            # Collapse the last voxel axis into a 2D occupancy mask.
            mask = tf.reduce_any(example["voxels"], axis=-1)
            # Resize the mask to the first two dimensions of the image;
            # nearest-neighbour keeps the values binary.
            mask = tf.image.resize(
                tf.expand_dims(tf.cast(mask, tf.uint8), axis=-1),
                image.shape[:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
            )
            mask = tf.cast(tf.squeeze(mask, axis=-1), tf.bool).numpy()
            image[np.logical_not(mask)] = 0
            plt.imshow(image)
            plt.show()

        # Each dataset is reduced to its own key ("image" or "voxels") and
        # read without file shuffling so the zipped elements stay in step.
        datasets = {
            k: b.as_dataset(split="train",
                            shuffle_files=False).map(lambda x: x[k])
            for k, b in builders.items()
        }
        dataset = tf.data.Dataset.zip(datasets)
        for example in dataset:
            vis(example)
Exemplo n.º 4
0
    # ax.axis("square")
    plt.show()


# Look up the synset id for the requested synset name.
synset_id = ids[synset_name]

# One config per modality: rendered images with a random view function, and
# voxels at a fixed resolution of 32.
image_config = core.TrimeshRenderingConfig(
    synset_id=synset_id,
    resolution=resolution,
    view_fn=core.views.random_view_fn(seed_offset),
)
voxel_config = core.VoxelConfig(synset_id, resolution=32)
configs = {"image": image_config, "voxels": voxel_config}

# Construct every builder first, then prepare each of them.
builders = {}
for k, config in configs.items():
    builders[k] = core.ShapenetCore(config=config)
for b in builders.values():
    b.download_and_prepare()

# Reduce each training split to the single feature named by its key, reading
# files unshuffled so both datasets iterate in the same order.
datasets = {}
for k, b in builders.items():
    train = b.as_dataset(split="train", shuffle_files=False)
    datasets[k] = train.map(lambda x: x[k])

# Pair the two feature streams element by element.
dataset = tf.data.Dataset.zip(datasets)

for example in dataset:
    # Pull both features out as numpy arrays.
    image, voxels = (example[key].numpy() for key in ("image", "voxels"))
Exemplo n.º 5
0
import os

import tensorflow as tf

from shape_tfds.shape.shapenet import core

# Switch TensorFlow to eager mode before anything else is built.
tf.compat.v1.enable_eager_execution()

resolution = 32
split = "train"
names = ("suitcase", "telephone", "table")
ids, _ = core.load_synset_ids()

# One voxel config / builder per synset name, all at the same resolution.
configs = (core.VoxelConfig(ids[n], resolution) for n in names)
builders = tuple(core.ShapenetCore(config=config) for config in configs)

# Collect the record files of the chosen split from every builder.
all_fns = []
for builder in builders:
    builder.download_and_prepare()
    # Record files are matched by "<builder name before any '/'>-<split>".
    prefix = "%s-%s" % (builder.name.split("/")[0], split)
    data_dir = builder.data_dir
    # os.listdir returns entries in arbitrary order; sorting keeps the file
    # order (and therefore the record order) reproducible between runs.
    record_fns = tuple(
        os.path.join(data_dir, fn)
        for fn in sorted(os.listdir(data_dir))
        if fn.startswith(prefix)
    )
    all_fns.extend(record_fns)

print(all_fns)
if not all_fns:
    # Without files there is nothing to read, and len(all_fns) == 0 would be
    # passed as num_parallel_reads below; fail with a clear message instead.
    raise RuntimeError(
        "No %s record files found for synsets: %s" % (split, ", ".join(names))
    )

# The first builder's parser is applied to the records of every synset.
# NOTE(review): presumably the feature spec is identical across these
# builders since they only differ by synset id -- confirm.
builder = builders[0]
dataset = tf.data.TFRecordDataset(all_fns, num_parallel_reads=len(all_fns))
dataset = dataset.map(builder._file_format_adapter._parser.parse_example)