Example #1
def main():
    log.info('Getting all filenames.')
    syndirs = sorted(glob.glob('/data/render_for_cnn/data/syn_images_cropped_bkg_overlaid/*'))
    random.seed(42)
    filenames = []
    for syndir in syndirs:
        modeldirs = sorted(glob.glob(path.join(syndir, '*')))
        if is_subset:
            modeldirs = modeldirs[:10]
        for modeldir in modeldirs:
            renderings = sorted(glob.glob(path.join(modeldir, '*')))
            if is_subset:
                renderings = renderings[:7]
            filenames.extend(renderings)

    log.info('{} files'.format(len(filenames)))

    data_base_dir = '/data/mvshape'

    start_time = time.time()

    random.seed(42)
    log.info('Processing rgb images.')
    for i, filename in enumerate(filenames):
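        # Parse the synset id, model name, and view index out of the rendering path.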
        m = re.search(r'syn_images_cropped_bkg_overlaid/(.*?)/(.*?)/[^_]+?_[^_]+?_v(\d{4,})_a', filename)
        synset = m.group(1)
        model_name = m.group(2)
        v = m.group(3)
        image_num = int(v)

        vc_rendering_name = '{}_{:04d}'.format(model_name, image_num)
        out_filename = path.join(data_base_dir, 'out/shapenetcore/single_rgb_128/{}.png'.format(vc_rendering_name))

        assert path.isfile(filename)
        io_utils.ensure_dir_exists(path.dirname(out_filename))

        img = io_utils.read_jpg(filename)
        assert img.shape[0] == img.shape[1]
        assert img.shape[2] == 3

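        # Pick one of three interpolation filters at random.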
        resize_method = {0: 'bilinear',
                         1: 'bicubic',
                         2: 'lanczos'}[random.randint(0, 2)]

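        # Note: scipy.misc.imresize/imsave only exist in older SciPy releases
        # (deprecated in 1.0, removed in 1.2/1.3); newer code would use Pillow or imageio.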
        resized_img = scipy.misc.imresize(img, (128, 128), interp=resize_method)
        assert resized_img.dtype == np.uint8

        # io_utils.save_array_compressed(out_filename, resized_img)
        scipy.misc.imsave(out_filename, resized_img)

        if i % 100 == 0:
            t_elapsed = (time.time() - start_time)
            t_remaining = (t_elapsed / (i + 1) * (len(filenames) - i))
            log.info('Processing rgb images. {} of {}. elapsed: {:.1f} min, remaining: {:.1f} min'.format(i, len(filenames), t_elapsed / 60, t_remaining / 60))

    t_elapsed = (time.time() - start_time)
    log.info('total elapsed: {:.1f} min'.format(t_elapsed / 60))
Example #2
def save_model(model, save_dir):
    current_epoch = model['metadata']['current_epoch']
    current_epoch_step = model['metadata']['current_epoch_step']
    global_step = model['metadata']['global_step']

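    # The filename encodes epoch, step within the epoch, and global step,
    # zero-padded so checkpoints sort lexicographically.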
    filename = path.join(
        save_dir,
        'models0_{:05}_{:07}_{:08}.pth'.format(current_epoch,
                                               current_epoch_step,
                                               global_step))

    io_utils.ensure_dir_exists(path.dirname(filename))

    with open(filename, 'wb') as f:
        log.info('Saving.. {}'.format(filename))
        torch.save(model, f)
        log.info('Saved.')
Example #3
    def save_input_image_as_aligned_mesh(self, tag, i, outdir):
        self.save_input_image_as_mesh(tag, i, path.join(outdir, 'input'))
        saved_mesh = path.join(outdir, 'input/depth_meshes/mesh_0000.ply')
        out0_mesh = path.join(outdir, 'depth_meshes/mesh_0000.ply')
        target_mesh = path.join(outdir, 'input/transformed_input.ply')

        # view0 camera
        cam = self.target_camera_objects(tag, i)[0]

        tmp_filename0 = io_utils.temp_filename('/tmp/mvshape_tmp',
                                               suffix='_transformed.ply')
        tmp_filename1 = io_utils.temp_filename('/tmp/mvshape_tmp',
                                               suffix='_transformed.ply')

        io_utils.ensure_dir_exists(path.dirname(tmp_filename0))
        io_utils.ensure_dir_exists(path.dirname(tmp_filename1))

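        # Bring both meshes into the view-0 camera frame before estimating the alignment.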
        Rt = cam.Rt()
        mesh_utils.transform_ply(saved_mesh, tmp_filename0, Rt)
        mesh_utils.transform_ply(out0_mesh, tmp_filename1, Rt)

        source = io_utils.read_ply_pcl(tmp_filename0)['v']
        target = io_utils.read_ply_pcl(tmp_filename1)['v']

        offset, scale = pcl_utils.find_aligning_transformation(source, target)

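        # 4x4 similarity transform (uniform scale, then translation) in the camera frame.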
        M = np.eye(4)
        M[:3, :3] *= scale
        M[:3, 3] = offset

        Rt44 = np.eye(4)
        Rt44[:3, :] = Rt

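        # Compose: world -> camera (Rt44), align in camera space (M), then camera -> world (Rt_inv).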
        final_transform = cam.Rt_inv().dot(M.dot(Rt44))

        mesh_utils.transform_ply(saved_mesh,
                                 target_mesh,
                                 final_transform,
                                 confidence_scale=2.0,
                                 value_scale=1.7)

        return target_mesh
Example #4
def _truncate_images_worker(params):
    filename = params['filename']
    out_dir = params['out_dir']
    ignore_overwrite = params['ignore_overwrite']
    filename_parts = filename.split(os.sep)
    path_suffix = os.sep.join(filename_parts[-3:])
    out_filename = os.path.join(out_dir, path_suffix)
    if ignore_overwrite and os.path.isfile(out_filename):
        # Skip if file already exists.
        return None

    image = io_utils.read_png(filename)
    truncated = make_randomized_square_image(image)
    assert truncated.shape[0] == truncated.shape[1]
    assert truncated.shape[2] == 4
    assert truncated.dtype == np.uint8

    out_filename_parent_dir = os.path.dirname(out_filename)
    io_utils.ensure_dir_exists(out_filename_parent_dir, log_mkdir=False)

    io_utils.save_png(truncated, out_filename)

    return out_filename
Example #5
    def fssr_recon(self, out_dir):
        ply_files = self._depth_meshes(out_dir=out_dir)

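        # Fuse the per-view depth meshes with FSSR, then clean out low-confidence geometry.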
        fssr_recon_file = mve.fssr_pcl_files(ply_files, scale=0.6)
        fssr_recon_clean_file = mve.meshclean(fssr_recon_file, threshold=0.1)

        recon_dir = io_utils.ensure_dir_exists(path.join(out_dir, 'recon'))

        new_fssr_recon_file = path.join(recon_dir,
                                        path.basename(fssr_recon_file))
        new_fssr_recon_clean_file = path.join(
            recon_dir, path.basename(fssr_recon_clean_file))

        shutil.move(fssr_recon_file, new_fssr_recon_file)
        shutil.move(fssr_recon_clean_file, new_fssr_recon_clean_file)
Example #6
    def fssr_recon_using_input(self, out_dir, aligned_depth_mesh_filename):
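        # Same as fssr_recon, but the first depth mesh is replaced by the aligned input-view mesh.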
        ply_files = [aligned_depth_mesh_filename] + sorted(
            glob.glob(path.join(out_dir, 'depth_meshes/*')))[1:]

        fssr_recon_file = mve.fssr_pcl_files(ply_files, scale=0.3)
        fssr_recon_clean_file = mve.meshclean(fssr_recon_file, threshold=0.25)

        recon_dir = io_utils.ensure_dir_exists(path.join(out_dir, 'recon'))

        new_fssr_recon_file = path.join(recon_dir,
                                        path.basename(fssr_recon_file))
        new_fssr_recon_clean_file = path.join(
            recon_dir, path.basename(fssr_recon_clean_file))

        shutil.move(fssr_recon_file, new_fssr_recon_file + '.fused.ply')
        shutil.move(fssr_recon_clean_file,
                    new_fssr_recon_clean_file + '.fused.ply')
Example #7
def main():
    syn_images_dir = '/data/mvshape/shapenetcore/single_rgb_128/'
    shapenetcore_dir = '/data/shapenetcore/ShapeNetCore.v1/'

    log.info('Getting all filenames.')
    syndirs = sorted(glob.glob(path.join(syn_images_dir, '*')))
    filenames = []
    for syndir in syndirs:
        modeldirs = sorted(glob.glob(path.join(syndir, '*')))
        if is_subset:
            modeldirs = modeldirs[:10]
        for modeldir in modeldirs:
            renderings = sorted(glob.glob(path.join(modeldir, '*.png')))
            if is_subset:
                renderings = renderings[:7]
            filenames.extend(renderings)

    # random.seed(42)
    # if not is_subset:
    #     random.shuffle(filenames)
    #     filenames = filenames[:1000000]

    random.seed(42)

    log.info('{} files'.format(len(filenames)))

    # TODO
    target_dir = '/data/mvshape/database'

    if is_subset:
        sqlite_file_path = join(target_dir, 'shapenetcore_subset.sqlite')
    else:
        sqlite_file_path = join(target_dir, 'shapenetcore.sqlite')
    output_cam_distance_from_origin = 2

    log.info('Setting up output directory.')
    # set up debugging directory.
    if path.isfile(sqlite_file_path):
        os.remove(sqlite_file_path)
    io_utils.ensure_dir_exists(target_dir)

    # used for making sure there is no duplicate.
    duplicate_name_check_set = set()

    log.info('Checking for duplicates. And making sure params.txt exists.')
    for i, filename in enumerate(filenames):
        m = re.search(r'single_rgb_128/(.*?)/(.*?)/[^_]+?_[^_]+?_v(\d{4,})_a',
                      filename)
        synset = m.group(1)
        model_name = m.group(2)
        v = m.group(3)
        image_num = int(v)
        vc_rendering_name = '{}_{}_{:04d}'.format(synset, model_name,
                                                  image_num)
        if vc_rendering_name in duplicate_name_check_set:
            print('duplicate found: ', (filename, vc_rendering_name))
        duplicate_name_check_set.add(vc_rendering_name)
        params_filename = join(syn_images_dir, synset, model_name,
                               'params.txt')
        assert path.isfile(params_filename)

    # Create the database
    dbm.init(sqlite_file_path)

    with dbm.db.transaction() as txn:
        log.info('Creating common objects.')
        make_dataset('shapenetcore')

        make_rendering_type('rgb')
        make_rendering_type('depth')
        make_rendering_type('normal')
        make_rendering_type('voxels')

        make_tag('novelview')
        make_tag('novelmodel')
        make_tag('novelclass')
        make_tag('perspective_input')
        make_tag('orthographic_input')
        make_tag('perspective_output')
        make_tag('orthographic_output')
        make_tag('viewer_centered')
        make_tag('object_centered')
        make_tag('real_world')

        make_split('train')
        make_split('test')
        make_split('validation')

        # Quote from http://shapenet.cs.stanford.edu/shapenet/obj-zip/ShapeNetCore.v1/README.txt
        #   "The OBJ files have been pre-aligned so that the up direction is the +Y axis, and the front is the +X axis.  In addition each model is normalized to fit within a unit cube centered at the origin."
        oc_output_cam = camera.OrthographicCamera.from_Rt(
            transforms.lookat_matrix(cam_xyz=(0, 0,
                                              output_cam_distance_from_origin),
                                     obj_xyz=(0, 0, 0),
                                     up=(0, 1, 0)),
            wh=(128, 128),
            is_world_to_cam=True)
        db_oc_output_cam = get_db_camera(oc_output_cam, fov=None)

        # Prepare all category objects.
        log.info('Preparing categories.')
        synset_db_category_map = {}
        for synset, synset_name in synset_name_pairs:
            db_category_i, _ = dbm.Category.get_or_create(name=synset_name)
            synset_db_category_map[synset] = db_category_i

        txn.commit()

        # Prepare all mesh model objects.
        # ---------------------------------------------
        db_object_map = {}
        # model_name -> {rendering_type_name -> rendering}
        db_object_centered_renderings = {}
        log.info('Preparing mesh model objects.')
        start_time = time.time()
        count = 0
        for i, filename in enumerate(filenames):
            m = re.search(
                r'single_rgb_128/(.*?)/(.*?)/[^_]+?_[^_]+?_v(\d{4,})_a',
                filename)
            synset = m.group(1)
            model_name = m.group(2)
            if model_name not in db_object_map:
                mesh_filename = join(shapenetcore_dir, synset, model_name,
                                     'model.obj')
                assert path.isfile(mesh_filename)

                mesh_filename_suffix = join(
                    '/mesh/shapenetcore/v1',
                    '/'.join(mesh_filename.split('/')[-3:]))

                db_category = synset_db_category_map[synset]
                # Must be unique.
                db_object = dbm.Object.create(
                    name=model_name,
                    category=db_category,
                    dataset=datasets['shapenetcore'],
                    num_vertices=0,  # Not needed for now. Easy to fill in later.
                    num_faces=0,
                    mesh_filename=mesh_filename_suffix,
                )
                db_object_map[model_name] = db_object

                oc_rendering_name = '{}_{}'.format(synset, model_name)

                assert model_name not in db_object_centered_renderings
                db_object_centered_renderings[model_name] = {
                    'output_rgb':
                    dbm.ObjectRendering.create(
                        type=rendering_types['rgb'],
                        camera=db_oc_output_cam,
                        object=db_object,
                        # JPG
                        filename='/shapenetcore/mv20_rgb_128/{}.bin'.format(
                            oc_rendering_name),
                        resolution=128,
                        num_channels=3,
                        set_size=20,
                        is_normalized=False,
                    ),
                    'output_depth':
                    dbm.ObjectRendering.create(
                        type=rendering_types['depth'],
                        camera=db_oc_output_cam,
                        object=db_object,
                        # Since there is only one gt rendering per model, their id is the same as the model name.
                        filename='/shapenetcore/mv20_depth_128/{}.bin'.format(
                            oc_rendering_name),
                        resolution=128,
                        num_channels=1,
                        set_size=20,
                        is_normalized=False,
                    ),
                    'output_normal':
                    dbm.ObjectRendering.create(
                        type=rendering_types['normal'],
                        camera=db_oc_output_cam,
                        object=db_object,
                        filename='/shapenetcore/mv20_normal_128/{}.bin'.format(
                            oc_rendering_name),
                        resolution=128,
                        num_channels=3,
                        set_size=20,
                        is_normalized=False,
                    ),
                    'output_voxels':
                    dbm.ObjectRendering.create(
                        type=rendering_types['voxels'],
                        camera=db_oc_output_cam,
                        object=db_object,
                        filename='/shapenetcore/voxels_32/{}.bin'.format(
                            oc_rendering_name),
                        resolution=32,
                        num_channels=1,
                        set_size=1,
                        is_normalized=False,
                    )
                }

                if count % 5000 == 0:
                    txn.commit()
                    t_elapsed = (time.time() - start_time)
                    t_remaining = (t_elapsed / (i + 1) * (len(filenames) - i))
                    log.info(
                        'Creating mesh objects in db. {} of {}. elapsed: {:.1f} min, remaining: {:.1f} min'
                        .format(i, len(filenames), t_elapsed / 60,
                                t_remaining / 60))

                count += 1
        txn.commit()
        t_elapsed = time.time() - start_time
        log.info('created {} mesh objects in db. elapsed: {:.1f} min'.format(
            count, t_elapsed / 60))

        start_time = time.time()

        log.info('Processing rgb images.')
        for i, filename in enumerate(filenames):
            m = re.search(
                r'single_rgb_128/(.*?)/(.*?)/[^_]+?_[^_]+?_v(\d{4,})_a',
                filename)
            synset = m.group(1)
            model_name = m.group(2)
            v = m.group(3)
            image_num = int(v)

            params_filename = join(syn_images_dir, synset, model_name,
                                   'params.txt')
            assert path.isfile(params_filename)

            lines = render_for_cnn_utils.read_params_file(params_filename)
            Rt = render_for_cnn_utils.get_Rt_from_RenderForCNN_parameters(
                lines[image_num])

            # Input and output cameras
            # -------------------
            input_cam = camera.OrthographicCamera.from_Rt(Rt,
                                                          wh=(128, 128),
                                                          is_world_to_cam=True)
            # 49.1343 degrees is the default fov in blender.
            db_input_cam = get_db_camera(input_cam, fov=49.1343)

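            # Depth input camera: same viewing ray as the rgb camera, placed at distance 1.5 from the origin.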
            input_cam_depth_xyz = (input_cam.pos /
                                   la.norm(input_cam.pos)) * 1.5
            input_cam_depth_Rt = transforms.lookat_matrix(
                cam_xyz=input_cam_depth_xyz,
                obj_xyz=(0, 0, 0),
                up=input_cam.up_vector)
            input_cam_depth = camera.OrthographicCamera.from_Rt(
                input_cam_depth_Rt, wh=(128, 128), is_world_to_cam=True)
            db_input_cam_depth = get_db_camera(input_cam_depth, fov=49.1343)

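            # Viewer-centered output camera: same ray, at output_cam_distance_from_origin.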
            output_cam_xyz = (input_cam.pos / la.norm(
                input_cam.pos)) * output_cam_distance_from_origin
            output_Rt = transforms.lookat_matrix(cam_xyz=output_cam_xyz,
                                                 obj_xyz=(0, 0, 0),
                                                 up=input_cam.up_vector)
            vc_output_cam = camera.OrthographicCamera.from_Rt(
                output_Rt, wh=(128, 128), is_world_to_cam=True)
            db_vc_output_cam = get_db_camera(vc_output_cam, fov=None)

            # ---
            db_object = db_object_map[model_name]

            vc_rendering_name = '{}_{}_{:04d}'.format(synset, model_name,
                                                      image_num)

            # Viewer centered renderings.
            # --------------------------------

            # Input rgb image:
            db_object_rendering_input_rgb = dbm.ObjectRendering.create(
                type=rendering_types['rgb'],
                camera=db_input_cam,
                object=db_object,
                # This should already exist.
                filename='/shapenetcore/single_rgb_128/{}.png'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=3,  # The single_rgb pngs are written as 3-channel images.
                set_size=1,
                is_normalized=False,  # False for rgb.
            )

            db_object_rendering_input_depth = dbm.ObjectRendering.create(
                type=rendering_types['depth'],
                camera=db_input_cam_depth,
                object=db_object,
                filename='/shapenetcore/single_depth_128/{}.bin'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=1,
                set_size=1,
                is_normalized=True,
            )

            db_object_rendering_vc_output_rgb = dbm.ObjectRendering.create(
                type=rendering_types['rgb'],
                camera=db_vc_output_cam,
                object=db_object,
                filename='/shapenetcore/mv20_rgb_128/{}.bin'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=3,
                set_size=20,
                is_normalized=False,
            )

            db_object_rendering_vc_output_depth = dbm.ObjectRendering.create(
                type=rendering_types['depth'],
                camera=db_vc_output_cam,
                object=db_object,
                filename='/shapenetcore/mv20_depth_128/{}.bin'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=1,
                set_size=20,
                is_normalized=False,
            )

            db_object_rendering_vc_output_normal = dbm.ObjectRendering.create(
                type=rendering_types['normal'],
                camera=db_vc_output_cam,
                object=db_object,
                filename='/shapenetcore/mv20_normal_128/{}.bin'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=3,
                set_size=20,
                is_normalized=False,
            )

            db_object_rendering_vc_output_voxels = dbm.ObjectRendering.create(
                type=rendering_types['voxels'],
                camera=db_vc_output_cam,
                object=db_object,
                filename='/shapenetcore/voxels_32/{}.bin'.format(
                    vc_rendering_name),
                resolution=32,
                num_channels=1,
                set_size=1,
                is_normalized=False,
            )

            # Examples
            # ----------------

            # A row in the `Example` table is just an id for many-to-many references.

            # View centered
            example_viewer_centered = dbm.Example.create()
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_input_rgb)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_input_depth)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_vc_output_depth)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_vc_output_normal)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_vc_output_rgb)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_vc_output_voxels)
            dbm.ExampleDataset.create(example=example_viewer_centered,
                                      dataset=datasets['shapenetcore'])
            dbm.ExampleSplit.create(example=example_viewer_centered,
                                    split=splits['train'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['real_world'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['viewer_centered'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['perspective_input'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['orthographic_output'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['novelmodel'])

            # Object centered
            example_object_centered = dbm.Example.create()
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_rendering_input_rgb)
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_rendering_input_depth)
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_centered_renderings[model_name]
                ['output_depth'])
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_centered_renderings[model_name]
                ['output_normal'])
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_centered_renderings[model_name]
                ['output_rgb'])
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_centered_renderings[model_name]
                ['output_voxels'])
            dbm.ExampleDataset.create(example=example_object_centered,
                                      dataset=datasets['shapenetcore'])
            dbm.ExampleSplit.create(example=example_object_centered,
                                    split=splits['train'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['real_world'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['object_centered'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['perspective_input'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['orthographic_output'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['novelmodel'])

            if i % 5000 == 0:
                txn.commit()
                t_elapsed = (time.time() - start_time)
                t_remaining = (t_elapsed / (i + 1) * (len(filenames) - i))
                log.info(
                    'Creating examples in db. {} of {}. elapsed: {:.1f} min, remaining: {:.1f} min'
                    .format(i, len(filenames), t_elapsed / 60,
                            t_remaining / 60))
        txn.commit()

    dbm.db.commit()

    t_elapsed = (time.time() - start_time)
    log.info('total elapsed: {:.1f} min'.format(t_elapsed / 60))
Example #8
def main():
    base = '/data/mvshape'
    batch_size = 50
    np.random.seed(42)

    loaders_o = mvshape.data.dataset.ExampleLoader2(
        '/data/mvshape/out/splits/pascal3d_test_examples_opo/all_examples.cbor',
        tensors_to_read=('input_rgb', 'target_depth', 'target_voxels'),
        shuffle=True,
        batch_size=batch_size)
    loaders_v = mvshape.data.dataset.ExampleLoader2(
        '/data/mvshape/out/splits/pascal3d_test_examples_vpo/all_examples.cbor',
        tensors_to_read=('input_rgb', 'target_depth', 'target_voxels'),
        shuffle=True,
        batch_size=batch_size)
    loaders = [loaders_o, loaders_v]

    both_models = [
        mvshape.models.encoderdecoder.load_model(
            '/data/mvshape/out/pytorch/shapenetcore_rgb_mv6/opo/0/models0_00005_0018323_00109115.pth'
        ),
        mvshape.models.encoderdecoder.load_model(
            '/data/mvshape/out/pytorch/shapenetcore_rgb_mv6/vpo/0/models0_00005_0018323_00109115.pth'
        ),
    ]

    exps = ['o', 'v']

    # #### TODO
    # mode = 1
    #
    # loaders = [loaders[mode]]
    # both_models = [both_models[mode]]
    # exps = [exps[mode]]
    # ####

    counter = 0

    for L, M, exp in zip(loaders, both_models, exps):
        torch_utils.recursive_module_apply(M, lambda m: m.cuda())
        torch_utils.recursive_train_setter(M, is_training=False)

        loader = L

        while True:
            next_batch = loader.next()
            if next_batch is None:
                print('END ################################')
                break
            batch_data_np = mvshape.models.encoderdecoder.prepare_data_rgb_mv(
                next_batch=next_batch)
            im = batch_data_np['in_image']
            helper_torch_modules = mvshape.models.encoderdecoder.build_helper_torch_modules(
            )
            out = mvshape.models.encoderdecoder.get_final_images_from_model(
                M, im, helper_torch_modules=helper_torch_modules)

            masked_depth = out['masked_depth']

            recon_basedir = '/data/mvshape/out/pascal3d_recon/'
            out_basedir = '/data/mvshape/out/pascal3d_figures/'

            for bi in range(len(next_batch[0])):
                image_name = path.basename(
                    next_batch[0][bi]['input_rgb']['filename']).split('.')[0]
                recon_dir = recon_basedir + '/{}/{}/'.format(exp, image_name)
                # if path.isdir(recon_dir):
                #     print('{} exists. skipping'.format(recon_dir))
                #     continue
                eye = next_batch[0][bi]['target_camera']['eye']
                up = next_batch[0][bi]['target_camera']['up']
                lookat = next_batch[0][bi]['target_camera']['lookat']

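                # Six output view matrices around the target camera pose.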
                Rt_list = mvshape.camera_utils.make_six_views(
                    camera_xyz=eye, object_xyz=lookat, up=up)
                cams = [
                    dshin.camera.OrthographicCamera.from_Rt(Rt_list[i],
                                                            sRt_scale=1.75,
                                                            wh=(128, 128))
                    for i in range(len(Rt_list))
                ]

                mv = mvshape.shapes.MVshape(masked_images=masked_depth[bi],
                                            cameras=cams)

                depth_mesh_filenames = glob.glob(recon_dir +
                                                 'depth_meshes/*.ply')
                pcl = []
                for item in depth_mesh_filenames:
                    pcl.append(io_utils.read_mesh(item)['v'])
                pcl = np.concatenate(pcl, axis=0)
                print(pcl.shape)

                fig_dir = path.join(out_basedir,
                                    '{}/{}/'.format(exp, image_name))
                io_utils.ensure_dir_exists(fig_dir)

                pt.figure(figsize=(5, 5))
                ax = pt.gca(projection='3d')
                color = pcl[:, 0] + 0.6  # +0.6 to force the values to be positive. not necessary.
                rotmat = transforms.rotation_matrix(angle=45,
                                                    direction=np.array(
                                                        (0, 1, 0)))
                rotmat2 = transforms.rotation_matrix(angle=-30,
                                                     direction=np.array(
                                                         (0, 0, 1)))
                pcl = transforms.apply44(rotmat2.dot(rotmat), pcl)
                index_array = np.argsort(pcl[:, 0])
                pcl = pcl[index_array]
                color = color[index_array]
                geom3d.pts(pcl,
                           markersize=45,
                           color=color,
                           zdir='y',
                           show_labels=False,
                           cmap='viridis',
                           cam_sph=(1, 90, 0),
                           ax=ax)
                ax.axis('off')
                pt.savefig(fig_dir + 'pcl.png',
                           bbox_inches='tight',
                           transparent=True,
                           pad_inches=0)
                pt.close()

                pt.figure(figsize=(5, 5))
                ax = pt.gca()
                geom2d.draw_depth(out['silhouette_prob'][bi],
                                  cmap='gray',
                                  nan_color=(1.0, 1.0, 1.0),
                                  grid=128,
                                  grid_width=3,
                                  ax=ax,
                                  show_colorbar=False,
                                  show_colorbar_ticks=False)
                pt.savefig(fig_dir + '/silhouette.png',
                           bbox_inches='tight',
                           transparent=False,
                           pad_inches=0)
                pt.close()

                pt.figure(figsize=(10, 10))
                ax = pt.gca()
                geom2d.draw_depth(masked_depth[bi],
                                  cmap='viridis',
                                  nan_color=(1.0, 1.0, 1.0),
                                  grid=128,
                                  grid_width=6,
                                  ax=ax,
                                  show_colorbar=False,
                                  show_colorbar_ticks=False)
                pt.savefig(fig_dir + '/masked-depth.png',
                           bbox_inches='tight',
                           transparent=False,
                           pad_inches=0)
                pt.close()

                rgb_filename = base + next_batch[0][bi]['input_rgb']['filename']
                assert path.isfile(rgb_filename)

                rgb_link_target = fig_dir + '/input.png'
                if path.islink(rgb_link_target):
                    os.remove(rgb_link_target)
                os.symlink(rgb_filename, rgb_link_target)
                print(counter, fig_dir, rgb_filename)

                counter += 1
Example #9
def blend_and_resize(params):
    object_image_filename = params['object_image_filename']
    bkg_filenames = params['bkg_filenames']
    out_image_filename = params['out_image_filename']

    object_image = io_utils.read_png(object_image_filename)

    resolution = 128
    bkg_clutter_ratio = 0.8
    scale_max = 4

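    # With probability bkg_clutter_ratio, composite over a real background image;
    # otherwise use a flat random gray.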
    use_background_image = random.random() < bkg_clutter_ratio
    resize_method = {
        0: 'bilinear',
        1: 'bicubic',
        2: 'lanczos'
    }[random.randint(0, 2)]

    def force_uint8(arr):
        if arr.dtype in (np.float32, np.float64):
            arr = (arr * 255).round().astype(np.uint8)
        assert arr.dtype == np.uint8
        return arr

    def force_float(arr):
        if arr.dtype == np.uint8:
            arr = arr.astype(np.float32) / 255.0
        elif arr.dtype == np.float64:
            arr = arr.astype(np.float32)
        assert arr.dtype == np.float32
        return arr

    def resize(arr: np.ndarray, res):
        assert arr.shape[0] == arr.shape[1]
        arr = force_uint8(arr)
        resized = scipy.misc.imresize(arr,
                                      size=(res, res),
                                      interp=resize_method)
        return force_float(resized)

    # Crop and pad.
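    # Tight bounding box of all non-zero pixels (any channel).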
    iy, ix = np.where(object_image)[:2]
    y0 = np.min(iy)
    x0 = np.min(ix)
    y1 = np.max(iy) + 1
    x1 = np.max(ix) + 1
    src_cropped = object_image[y0:y1, x0:x1]
    src_cropped_square = make_randomized_square_image(src_cropped)

    src_s = src_cropped_square.shape[0]
    target_resolution = min(src_s, resolution)

    if use_background_image:
        # Read random file.
        bkg_image = None
        while True:
            bkg_image = _load_image(random.choice(bkg_filenames))
            s = min(bkg_image.shape[:2])
            if s >= target_resolution:
                break

        if bkg_image.ndim == 2:
            bkg_image = np.tile(bkg_image[:, :, None], (1, 1, 3))

        # Resize the background image. If the image is smaller than 128, don't resize.
        res = random.randint(
            target_resolution,
            min(target_resolution * scale_max, min(bkg_image.shape[:2])))
        y = random.randint(0, bkg_image.shape[0] - res)
        x = random.randint(0, bkg_image.shape[1] - res)
        bkg_cropped = bkg_image[y:y + res, x:x + res]
        assert bkg_cropped.shape[0] == bkg_cropped.shape[1]
        assert bkg_cropped.shape[0] == res
        assert bkg_cropped.shape[2] == 3

        bkg = resize(bkg_cropped, res=target_resolution)
    else:
        color_gray = random.random()
        bkg = color_gray * np.ones(
            (target_resolution, target_resolution, 3), dtype=np.float32)

    src_image_rgba = resize(src_cropped_square, res=target_resolution)
    src_image = src_image_rgba[:, :, :3]
    mask = src_image_rgba[:, :, 3]

    assert bkg.dtype == np.float32
    assert bkg.shape[0] == bkg.shape[1]
    assert bkg.shape[2] == 3
    assert bkg.shape[0] == target_resolution
    assert src_image.dtype == np.float32
    assert src_image.shape == bkg.shape
    assert mask.dtype == np.float32

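    # Alpha-composite the resized object over the background using the mask.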
    blended = ((1.0 - mask)[:, :, None] * bkg) + (
        (mask)[:, :, None] * src_image)

    blended_final = force_uint8(resize(blended, res=resolution))

    assert blended_final.shape[0] == blended_final.shape[1]
    assert blended_final.shape[0] == resolution
    assert blended_final.dtype == np.uint8

    io_utils.ensure_dir_exists(path.dirname(out_image_filename),
                               log_mkdir=False)
    scipy.misc.imsave(out_image_filename, blended_final)

    return out_image_filename