예제 #1
0
def vis(cat, n_images, view_index=5, example_ids=None):
    """Show rendered images for a ShapeNet category, one figure per example.

    Args:
        cat: category description, converted to an id with
            ``cat_desc_to_id``.
        n_images: forwarded to ``RenderConfig(n_images=...)``.
        view_index: forwarded to ``RenderConfig.get_dataset``.
        example_ids: optional ids to display. When non-empty, the dataset is
            restricted to them via ``subset``. When ``None`` or empty, every
            id returned by ``get_example_ids`` for the category is shown.
    """
    import matplotlib.pyplot as plt
    from shapenet.core import cat_desc_to_id, get_example_ids
    from shapenet.core.blender_renderings.config import RenderConfig

    category_id = cat_desc_to_id(cat)
    renderings = RenderConfig(n_images=n_images).get_dataset(
        category_id, view_index)

    # An explicit, non-empty id list narrows the dataset; otherwise fall back
    # to every example id known for the category.
    if example_ids is None or len(example_ids) == 0:
        example_ids = get_example_ids(category_id)
    else:
        renderings = renderings.subset(example_ids)

    with renderings:
        for eid in example_ids:
            plt.imshow(renderings[eid])
            plt.title(eid)
            plt.show()
예제 #2
0
def vis_segmentations(model_id,
                      example_ids=None,
                      vis_mesh=False,
                      edge_length_threshold=0.02,
                      include_wireframe=False,
                      save=False):
    """Visualize (and optionally save) predicted segmentations for a model.

    For each example, the predicted ``probs`` and ``dp`` arrays are passed to
    the builder's segmentation function. Examples for which it returns
    ``None`` are skipped. Otherwise two mayavi figures are created:

    * ``f0``: the inferred segmentation, either as a mesh (``vis_mesh``) or
      as a point cloud;
    * ``f1``: ``original_points`` coloured by ``original_segmentation`` in
      mesh mode, or by the predicted segmentation in cloud mode.

    Args:
        model_id: passed to ``get_builder`` and ``get_predictions_dataset``.
        example_ids: examples to process. Defaults to the category's
            ``'eval'`` example ids. Must be given when ``save`` is True.
        vis_mesh: segment/show a mesh if True, otherwise a point cloud.
        edge_length_threshold: forwarded to
            ``builder.get_segmented_mesh_fn`` (mesh mode only).
        include_wireframe: forwarded to ``vis_segmented_mesh``.
        save: if True, write ``inferred_{mesh,cloud}.png``,
            ``annotated_cloud.png`` and ``query_image.png`` under
            ``_paper_dir/segmentations/<model_id>/<example_id>/`` instead of
            displaying the figures interactively.

    Raises:
        ValueError: if ``save`` is True and ``example_ids`` is None.
    """
    if save and example_ids is None:
        raise ValueError('Cannot save without specifying example_ids')
    if save:
        # scipy.misc.imsave was removed in SciPy 1.2. Only require an image
        # writer when actually saving, and fall back to matplotlib's.
        try:
            from scipy.misc import imsave
        except ImportError:
            from matplotlib.pyplot import imsave
    builder = get_builder(model_id)
    cat_id = builder.cat_id
    if example_ids is None:
        example_ids = get_example_ids(cat_id, 'eval')
    if vis_mesh:
        segmented_fn = builder.get_segmented_mesh_fn(edge_length_threshold)
    else:
        segmented_fn = builder.get_segmented_cloud_fn()
    config = RenderConfig()

    with get_predictions_dataset(model_id) as predictions:
        with config.get_dataset(cat_id, builder.view_index) as image_ds:
            for example_id in example_ids:
                example = predictions[example_id]
                probs, dp = (np.array(example[k]) for k in ('probs', 'dp'))
                result = segmented_fn(probs, dp)
                if result is None:
                    # Segmentation function produced nothing for this example.
                    continue
                image = image_ds[example_id]
                print(example_id)
                segmentation = result['segmentation']
                original_points = result['original_points']
                if vis_mesh:
                    vertices = result['vertices']
                    faces = result['faces']
                    original_seg = result['original_segmentation']
                    f0 = mlab.figure(bgcolor=(1, 1, 1))
                    vis_segmented_mesh(vertices,
                                       segmented_cloud(faces, segmentation),
                                       include_wireframe=include_wireframe,
                                       opacity=0.2)
                    f1 = mlab.figure(bgcolor=(1, 1, 1))
                    vis_clouds(segmented_cloud(original_points, original_seg))
                else:
                    points = result['points']
                    f0 = mlab.figure(bgcolor=(1, 1, 1))
                    vis_clouds(segmented_cloud(points, segmentation))
                    f1 = mlab.figure(bgcolor=(1, 1, 1))
                    vis_clouds(segmented_cloud(original_points, segmentation))

                if save:
                    folder = os.path.join(_paper_dir, 'segmentations',
                                          model_id, example_id)
                    if not os.path.isdir(folder):
                        os.makedirs(folder)
                    fn = 'inferred_%s.png' % ('mesh' if vis_mesh else 'cloud')
                    mlab.savefig(os.path.join(folder, fn), figure=f0)
                    mlab.savefig(os.path.join(folder, 'annotated_cloud.png'),
                                 figure=f1)
                    imsave(os.path.join(folder, 'query_image.png'), image)
                    # A bare mlab.close() only closes the current figure (f1),
                    # which would leave f0 open for every saved example.
                    mlab.close(f0)
                    mlab.close(f1)
                else:
                    plt.imshow(image)
                    plt.show(block=False)
                    mlab.show()
                    plt.close()
예제 #3
0
import os
import matplotlib.pyplot as plt
from shapenet.image import with_background
from shapenet.core.blender_renderings.config import RenderConfig
from shapenet.core import cat_desc_to_id, get_example_ids


cat_desc = 'plane'
view_index = 5


def show_renderings(cat_desc, view_index):
    """Display each rendering of a category on a white background.

    Args:
        cat_desc: category description, converted with ``cat_desc_to_id``.
        view_index: forwarded to ``RenderConfig.get_dataset``.

    Raises:
        IOError: if the renderings zip for the category does not exist.
    """
    config = RenderConfig()
    cat_id = cat_desc_to_id(cat_desc)

    path = config.get_zip_path(cat_id)
    if not os.path.isfile(path):
        raise IOError('No renderings at %s' % path)

    with config.get_dataset(cat_id, view_index) as raw_ds:
        # Composite each image onto a background value of 255 before display.
        ds = raw_ds.map(lambda image: with_background(image, 255))
        for example_id in ds:
            plt.imshow(ds[example_id])
            plt.title(example_id)
            plt.show()


if __name__ == '__main__':
    show_renderings(cat_desc, view_index)
예제 #4
0
def vis_segmentations(
        model_id, example_ids=None, vis_mesh=False,
        edge_length_threshold=0.02, include_wireframe=False,
        save=False):
    """Visualize (and optionally save) predicted segmentations for a model.

    For each example, the predicted ``probs`` and ``dp`` arrays are passed to
    the builder's segmentation function. Examples for which it returns
    ``None`` are skipped. Otherwise two mayavi figures are created:

    * ``f0``: the inferred segmentation, either as a mesh (``vis_mesh``) or
      as a point cloud;
    * ``f1``: ``original_points`` coloured by ``original_segmentation`` in
      mesh mode, or by the predicted segmentation in cloud mode.

    Args:
        model_id: passed to ``get_builder`` and ``get_predictions_dataset``.
        example_ids: examples to process. Defaults to the category's
            ``'eval'`` example ids. Must be given when ``save`` is True.
        vis_mesh: segment/show a mesh if True, otherwise a point cloud.
        edge_length_threshold: forwarded to
            ``builder.get_segmented_mesh_fn`` (mesh mode only).
        include_wireframe: forwarded to ``vis_segmented_mesh``.
        save: if True, write ``inferred_{mesh,cloud}.png``,
            ``annotated_cloud.png`` and ``query_image.png`` under
            ``_paper_dir/segmentations/<model_id>/<example_id>/`` instead of
            displaying the figures interactively.

    Raises:
        ValueError: if ``save`` is True and ``example_ids`` is None.
    """
    if save and example_ids is None:
        raise ValueError('Cannot save without specifying example_ids')
    if save:
        # scipy.misc.imsave was removed in SciPy 1.2. Only require an image
        # writer when actually saving, and fall back to matplotlib's.
        try:
            from scipy.misc import imsave
        except ImportError:
            from matplotlib.pyplot import imsave
    builder = get_builder(model_id)
    cat_id = builder.cat_id
    if example_ids is None:
        example_ids = get_example_ids(cat_id, 'eval')
    if vis_mesh:
        segmented_fn = builder.get_segmented_mesh_fn(edge_length_threshold)
    else:
        segmented_fn = builder.get_segmented_cloud_fn()
    config = RenderConfig()

    with get_predictions_dataset(model_id) as predictions:
        with config.get_dataset(cat_id, builder.view_index) as image_ds:
            for example_id in example_ids:
                example = predictions[example_id]
                probs, dp = (np.array(example[k]) for k in ('probs', 'dp'))
                result = segmented_fn(probs, dp)
                if result is None:
                    # Segmentation function produced nothing for this example.
                    continue
                image = image_ds[example_id]
                print(example_id)
                segmentation = result['segmentation']
                original_points = result['original_points']
                if vis_mesh:
                    vertices = result['vertices']
                    faces = result['faces']
                    original_seg = result['original_segmentation']
                    f0 = mlab.figure(bgcolor=(1, 1, 1))
                    vis_segmented_mesh(
                        vertices, segmented_cloud(faces, segmentation),
                        include_wireframe=include_wireframe,
                        opacity=0.2)
                    f1 = mlab.figure(bgcolor=(1, 1, 1))
                    vis_clouds(
                        segmented_cloud(original_points, original_seg))
                else:
                    points = result['points']
                    f0 = mlab.figure(bgcolor=(1, 1, 1))
                    vis_clouds(segmented_cloud(points, segmentation))
                    f1 = mlab.figure(bgcolor=(1, 1, 1))
                    vis_clouds(
                        segmented_cloud(original_points, segmentation))

                if save:
                    folder = os.path.join(
                        _paper_dir, 'segmentations', model_id, example_id)
                    if not os.path.isdir(folder):
                        os.makedirs(folder)
                    fn = 'inferred_%s.png' % (
                        'mesh' if vis_mesh else 'cloud')
                    mlab.savefig(os.path.join(folder, fn), figure=f0)
                    mlab.savefig(
                        os.path.join(folder, 'annotated_cloud.png'),
                        figure=f1)
                    imsave(os.path.join(folder, 'query_image.png'), image)
                    # A bare mlab.close() only closes the current figure (f1),
                    # which would leave f0 open for every saved example.
                    mlab.close(f0)
                    mlab.close(f1)
                else:
                    plt.imshow(image)
                    plt.show(block=False)
                    mlab.show()
                    plt.close()