Ejemplo n.º 1
0
def main(_):
    """Prepare the rendered ShapeNet core dataset for one synset.

    ``FLAGS.synset`` is used directly as the synset id when it is a key of
    ``names``; otherwise it is translated through ``ids``. When ``FLAGS.vis``
    is set, every image of the "train" split is displayed with matplotlib.

    Raises:
        KeyError: from ``ids[...]`` when ``FLAGS.synset`` is a key of neither
            mapping.
        ValueError: when the resolved id is not a key of ``names``.
    """
    import tensorflow as tf

    from shape_tfds.shape.shapenet import core

    tf.compat.v1.enable_eager_execution()
    FLAGS = flags.FLAGS
    ids, names = core.load_synset_ids()
    requested = FLAGS.synset

    if requested in names:
        synset_id = requested
    else:
        synset_id = ids[requested]
    if synset_id not in names:
        raise ValueError("Invalid synset_id %s" % synset_id)

    renderer_name = FLAGS.renderer
    side = FLAGS.resolution
    # The same edge length is passed for both entries of the resolution tuple.
    renderer = core.Renderer(renderer_name, resolution=(side, side))

    config = core.RenderingConfig(
        synset_id=synset_id, renderer=renderer, seed=FLAGS.seed
    )
    builder = core.ShapenetCore(config=config, from_cache=FLAGS.from_cache)
    builder.download_and_prepare()

    if not FLAGS.vis:
        return

    def show(example):
        # Imported lazily so matplotlib is only needed when visualising.
        import matplotlib.pyplot as plt

        pixels = example["image"].numpy()
        plt.imshow(pixels)
        plt.show()

    for example in builder.as_dataset(split="train"):
        show(example)
Ejemplo n.º 2
0
def main(_):
    """Prepare the frustum-voxel ShapeNet core dataset for one synset.

    ``FLAGS.synset`` is used directly as the synset id when it is a key of
    ``names``; otherwise it is translated through ``ids``. When ``FLAGS.vis``
    is set, the voxels of every "train" example are shown in a 3D plot.

    Raises:
        KeyError: from ``ids[...]`` when ``FLAGS.synset`` is a key of neither
            mapping.
        ValueError: when the resolved id is not a key of ``names``.
    """
    tf.compat.v1.enable_eager_execution()
    FLAGS = flags.FLAGS
    ids, names = core.load_synset_ids()
    name = FLAGS.synset

    synset_id = name if name in names else ids[name]
    if synset_id not in names:
        raise ValueError("Invalid synset_id %s" % synset_id)

    builder = core.ShapenetCore(config=core.FrustumVoxelConfig(
        synset_id=synset_id, resolution=FLAGS.resolution, seed=FLAGS.seed))
    builder.download_and_prepare()

    if FLAGS.vis:

        def vis(example):
            """Display the ``"voxels"`` entry of one example in a 3D plot."""
            import matplotlib.pyplot as plt
            # Kept for older Matplotlib releases, where importing Axes3D is
            # what registers the "3d" projection.
            from mpl_toolkits.mplot3d import Axes3D  # pylint: disable=unused-import

            # plt.gca() no longer accepts keyword arguments (deprecated in
            # Matplotlib 3.4, later removed); plt.axes() is the documented
            # way to create axes with a non-default projection.
            ax = plt.axes(projection="3d")
            ax.voxels(example["voxels"].numpy())
            plt.show()

        dataset = builder.as_dataset(split="train")
        for example in dataset:
            vis(example)
Ejemplo n.º 3
0
def main(_):
    """Prepare image and frustum-voxel datasets for one synset and overlay them.

    Two builders are created with the same synset id and seed: a trimesh
    rendering config (key ``"image"``) and a frustum voxel config (key
    ``"voxels"``). When ``FLAGS.vis`` is set, the two "train" datasets are
    zipped and each image is shown with the pixels outside the resized voxel
    mask set to zero.

    Raises:
        KeyError: from ``ids[...]`` when ``FLAGS.synset`` is a key of neither
            mapping.
        ValueError: when the resolved id is not a key of ``names``.
    """
    tf.compat.v1.enable_eager_execution()
    FLAGS = flags.FLAGS
    ids, names = core.load_synset_ids()
    requested = FLAGS.synset
    seed = FLAGS.seed

    if requested in names:
        synset_id = requested
    else:
        synset_id = ids[requested]
    if synset_id not in names:
        raise ValueError("Invalid synset_id %s" % synset_id)

    image_config = core.TrimeshRenderingConfig(
        synset_id=synset_id,
        resolution=(FLAGS.image_res, ) * 2,
        seed=seed,
    )
    voxel_config = core.FrustumVoxelConfig(
        synset_id=synset_id,
        resolution=FLAGS.vox_res,
        seed=seed,
    )
    builders = {
        "image": core.ShapenetCore(config=image_config),
        "voxels": core.ShapenetCore(config=voxel_config),
    }
    for builder in builders.values():
        builder.download_and_prepare()

    if not FLAGS.vis:
        return

    def show(example):
        import matplotlib.pyplot as plt

        image = example["image"].numpy()
        # Collapse the last voxel axis: True where any voxel along it is set.
        mask = tf.reduce_any(example["voxels"], axis=-1)
        # Resize expects a channel axis and a numeric dtype, so go through
        # uint8 with a trailing singleton dimension and undo both afterwards.
        mask = tf.expand_dims(tf.cast(mask, tf.uint8), axis=-1)
        mask = tf.image.resize(
            mask,
            image.shape[:2],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        )
        mask = tf.cast(tf.squeeze(mask, axis=-1), tf.bool).numpy()
        # Zero every pixel that is not covered by the mask.
        image[np.logical_not(mask)] = 0
        plt.imshow(image)
        plt.show()

    datasets = {}
    for key, builder in builders.items():
        split_data = builder.as_dataset(split="train", shuffle_files=False)
        # Keep only this builder's own feature from each example.
        datasets[key] = split_data.map(lambda x: x[key])

    for example in tf.data.Dataset.zip(datasets):
        show(example)
Ejemplo n.º 4
0
def main(_):
    """Plot renderings next to their frustum-voxel silhouettes for one synset.

    For every key of the rendering mapping a 2x2 figure is shown:
    the normalised image, the voxel silhouette (any voxel set along the last
    axis), the image with pixels outside the silhouette set to 1, and the
    image with pixels inside the silhouette set to 1.

    Raises:
        KeyError: from ``ids[...]`` when ``FLAGS.synset`` is not a key of
            ``ids``.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    from shape_tfds.shape.shapenet import core

    synset = FLAGS.synset
    render_resolution = (FLAGS.resolution,) * 2
    renderer = core.Renderer.named(
        FLAGS.renderer, resolution=render_resolution)

    ids, _ = core.load_synset_ids()
    synset_id = ids[synset]
    # The voxel grid resolution is fixed here and is independent of the
    # rendering resolution taken from the flags.
    voxel_resolution = 128

    config = core.RenderingConfig(synset_id, renderer)
    frustum_config = core.FrustumVoxelConfig(
        synset_id, resolution=voxel_resolution, use_cached_voxels=False
    )
    with config.lazy_mapping() as renderings, \
            frustum_config.lazy_mapping() as voxels:
        for k in renderings:
            _, axes = plt.subplots(2, 2)
            axes = axes.reshape((-1,))
            vox = np.any(voxels[k]["voxels"], axis=-1)

            image = np.array(renderings[k]["image"]).astype(np.float32)
            # Scale to [0, 1]; skip the division for an all-zero rendering,
            # which would otherwise turn every pixel into NaN.
            peak = image.max()
            if peak > 0:
                image /= peak

            outside_filled = image.copy()
            outside_filled[np.logical_not(vox)] = 1
            inside_filled = image.copy()
            inside_filled[vox] = 1

            axes[0].imshow(image)
            axes[1].imshow(vox)
            axes[2].imshow(outside_filled)
            axes[3].imshow(inside_filled)
            plt.show()
Ejemplo n.º 5
0
from shape_tfds.shape.shapenet import core

# Lookup tables from the project helper; ``ids`` is indexed by synset name.
ids, names = core.load_synset_ids()

# Synset to visualise. Other names previously tried here:
# 'watercraft', 'aeroplane', 'table', 'rifle'.
name = "suitcase"

# Voxel config for the chosen synset, and the context manager that yields
# its data mapping when entered.
config = core.VoxelConfig(resolution=32, synset_id=ids[name])
mapping_context = core.get_data_mapping_context(config)


def vis(voxels):
    """Show a voxel grid in a 3D matplotlib plot.

    Args:
        voxels: passed unchanged to ``Axes3D.voxels`` -- presumably a 3D
            boolean occupancy array; confirm against the data mapping.
    """
    import matplotlib.pyplot as plt
    # Kept for older Matplotlib releases, where importing Axes3D is what
    # registers the "3d" projection.
    from mpl_toolkits.mplot3d import Axes3D  # pylint: disable=unused-import

    # plt.gca() no longer accepts keyword arguments (deprecated in
    # Matplotlib 3.4, later removed); plt.axes() is the documented way to
    # create axes with a non-default projection.
    ax = plt.axes(projection="3d")
    ax.voxels(voxels)
    plt.show()


# Entering the context yields the mapping; plot the voxels of every entry.
with mapping_context as mapping:
    for _key, entry in mapping.items():
        vis(entry["voxels"])
Ejemplo n.º 6
0
import os

import tensorflow as tf

from shape_tfds.shape.shapenet import core

tf.compat.v1.enable_eager_execution()

# Script: prepare voxel datasets for several synsets and read all of their
# "train" record files through a single TFRecordDataset.
resolution = 32
split = "train"
names = ("suitcase", "telephone", "table")
ids, _ = core.load_synset_ids()

# Materialised as tuples so both can be iterated again later; a generator
# here would be silently exhausted after building ``builders``.
configs = tuple(core.VoxelConfig(ids[n], resolution) for n in names)
builders = tuple(core.ShapenetCore(config=config) for config in configs)

all_fns = []
for builder in builders:
    builder.download_and_prepare()
    # Record files are selected by "<first component of builder.name>-<split>".
    prefix = "%s-%s" % (builder.name.split("/")[0], split)
    data_dir = builder.data_dir
    # os.listdir returns entries in arbitrary order; sort so the file order
    # (and therefore the record order) is reproducible across runs.
    record_fns = tuple(
        os.path.join(data_dir, fn) for fn in sorted(os.listdir(data_dir))
        if fn.startswith(prefix))
    all_fns.extend(record_fns)

print(all_fns)
# NOTE(review): the parser of the first builder is applied to records from
# every builder -- this assumes all builders share one example schema; confirm.
# It is also reached through underscore-prefixed (private) attributes, which
# may not exist in other tensorflow_datasets versions.
builder = builders[0]
dataset = tf.data.TFRecordDataset(all_fns, num_parallel_reads=len(all_fns))
dataset = dataset.map(builder._file_format_adapter._parser.parse_example)