def main():
    """Downscale RenderForCNN background-overlaid renderings to 128x128 PNGs.

    Gathers rendering files from the RenderForCNN
    ``syn_images_cropped_bkg_overlaid/<synset>/<model>/`` tree (only the first
    10 models per synset and 7 files per model when the module-level
    ``is_subset`` flag is set). Each file is resized to 128x128 and written to
    ``/data/mvshape/out/shapenetcore/single_rgb_128/<model>_<view:04d>.png``.
    The interpolation method for each image is drawn from a ``random`` stream
    seeded with 42, so reruns pick the same methods.
    """
    src_root_glob = '/data/render_for_cnn/data/syn_images_cropped_bkg_overlaid/*'
    name_pattern = r'syn_images_cropped_bkg_overlaid/(.*?)/(.*?)/[^_]+?_[^_]+?_v(\d{4,})_a'
    interp_choices = ('bilinear', 'bicubic', 'lanczos')

    log.info('Getting all filenames.')
    synset_dirs = sorted(glob.glob(src_root_glob))
    random.seed(42)
    sources = []
    for synset_dir in synset_dirs:
        model_dirs = sorted(glob.glob(path.join(synset_dir, '*')))
        if is_subset:
            model_dirs = model_dirs[:10]
        for model_dir in model_dirs:
            images = sorted(glob.glob(path.join(model_dir, '*')))
            sources.extend(images[:7] if is_subset else images)
    total = len(sources)
    log.info('{} files'.format(total))

    data_base_dir = '/data/mvshape'
    t0 = time.time()
    # Re-seed right before processing so the per-image interpolation choice
    # below is reproducible.
    random.seed(42)

    log.info('Processing rgb images.')
    for idx, src in enumerate(sources):
        found = re.search(name_pattern, src)
        model_name = found.group(2)
        view_num = int(found.group(3))
        rendering_name = '{}_{:04d}'.format(model_name, view_num)
        dst = path.join(
            data_base_dir,
            'out/shapenetcore/single_rgb_128/{}.png'.format(rendering_name))

        assert path.isfile(src)
        io_utils.ensure_dir_exists(path.dirname(dst))
        img = io_utils.read_jpg(src)
        # The source must be a square, 3-channel image.
        assert img.shape[0] == img.shape[1]
        assert img.shape[2] == 3

        method = interp_choices[random.randint(0, 2)]
        small = scipy.misc.imresize(img, (128, 128), interp=method)
        assert small.dtype == np.uint8
        scipy.misc.imsave(dst, small)

        if idx % 100 == 0:
            elapsed = time.time() - t0
            remaining = elapsed / (idx + 1) * (total - idx)
            log.info(
                'Creating examples in db. {} of {}. elapsed: {:.1f} min, remaining: {:.1f} min'
                .format(idx, total, elapsed / 60, remaining / 60))

    log.info('total elapsed: {:.1f} min'.format((time.time() - t0) / 60))
def save_model(model, save_dir):
    """Serialize ``model`` with ``torch.save`` into a checkpoint under ``save_dir``.

    The checkpoint name encodes ``current_epoch``, ``current_epoch_step`` and
    ``global_step`` from ``model['metadata']``, zero-padded to 5, 7 and 8
    digits: ``models0_EEEEE_SSSSSSS_GGGGGGGG.pth``.
    """
    meta = model['metadata']
    epoch = meta['current_epoch']
    epoch_step = meta['current_epoch_step']
    step = meta['global_step']
    checkpoint_name = 'models0_{:05}_{:07}_{:08}.pth'.format(epoch, epoch_step, step)
    filename = path.join(save_dir, checkpoint_name)
    io_utils.ensure_dir_exists(path.dirname(filename))
    with open(filename, 'wb') as out_file:
        log.info('Saving.. {}'.format(filename))
        torch.save(model, out_file)
        log.info('Saved.')
def save_input_image_as_aligned_mesh(self, tag, i, outdir):
    """Write the input-image depth mesh, aligned to the predicted view-0 mesh.

    Saves the input image's depth mesh under ``outdir/input`` (via
    ``self.save_input_image_as_mesh``). Both that mesh and
    ``outdir/depth_meshes/mesh_0000.ply`` are transformed by the view-0
    camera's ``Rt``. ``pcl_utils.find_aligning_transformation`` then fits an
    offset and uniform scale mapping the first vertex set onto the second.
    The final transform ``Rt_inv . M . Rt`` (M = scale + offset) is applied
    to the saved input mesh.

    Args:
        tag, i: forwarded to ``save_input_image_as_mesh`` and
            ``target_camera_objects`` to select the example.
        outdir: directory holding ``depth_meshes/mesh_0000.ply``; the input
            mesh and the aligned result are written under ``outdir/input``.

    Returns:
        Path of the aligned mesh, ``outdir/input/transformed_input.ply``.
    """
    self.save_input_image_as_mesh(tag, i, path.join(outdir, 'input'))
    saved_mesh = path.join(outdir, 'input/depth_meshes/mesh_0000.ply')
    out0_mesh = path.join(outdir, 'depth_meshes/mesh_0000.ply')
    target_mesh = path.join(outdir, 'input/transformed_input.ply')

    # view0 camera
    cam = self.target_camera_objects(tag, i)[0]

    tmp_filename0 = io_utils.temp_filename('/tmp/mvshape_tmp', suffix='_transformed.ply')
    tmp_filename1 = io_utils.temp_filename('/tmp/mvshape_tmp', suffix='_transformed.ply')
    io_utils.ensure_dir_exists(path.dirname(tmp_filename0))
    io_utils.ensure_dir_exists(path.dirname(tmp_filename1))
    Rt = cam.Rt()
    try:
        mesh_utils.transform_ply(saved_mesh, tmp_filename0, Rt)
        mesh_utils.transform_ply(out0_mesh, tmp_filename1, Rt)
        source = io_utils.read_ply_pcl(tmp_filename0)['v']
        target = io_utils.read_ply_pcl(tmp_filename1)['v']
    finally:
        # The transformed copies are only read for their vertices; remove
        # them so repeated calls do not accumulate files in /tmp.
        for tmp_filename in (tmp_filename0, tmp_filename1):
            if path.isfile(tmp_filename):
                os.remove(tmp_filename)

    offset, scale = pcl_utils.find_aligning_transformation(source, target)

    # 4x4 similarity: uniform scale on the 3x3 block, then translation.
    M = np.eye(4)
    M[:3, :3] *= scale
    M[:3, 3] = offset

    # Promote the 3x4 camera extrinsics to a homogeneous 4x4 matrix.
    Rt44 = np.eye(4)
    Rt44[:3, :] = Rt

    final_transform = cam.Rt_inv().dot(M.dot(Rt44))
    # NOTE(review): meaning of confidence_scale / value_scale is defined by
    # mesh_utils.transform_ply, which is not visible here.
    mesh_utils.transform_ply(saved_mesh, target_mesh, final_transform,
                             confidence_scale=2.0, value_scale=1.7)
    return target_mesh
def _truncate_images_worker(params):
    """Square-crop one RGBA PNG and write it under ``out_dir``.

    Args:
        params: dict with keys
            ``filename``: source PNG path;
            ``out_dir``: output root; the last three path components of
                ``filename`` are recreated beneath it;
            ``ignore_overwrite``: when true and the output file already
                exists, the image is skipped.

    Returns:
        The written output path, or ``None`` when the image was skipped.
    """
    filename = params['filename']
    out_dir = params['out_dir']
    ignore_overwrite = params['ignore_overwrite']

    # Mirror the trailing three path components so the output tree keeps
    # the input layout.
    path_suffix = os.sep.join(filename.split(os.sep)[-3:])
    out_filename = os.path.join(out_dir, path_suffix)
    if ignore_overwrite and os.path.isfile(out_filename):
        # Skip if file already exists.
        return None

    image = io_utils.read_png(filename)
    truncated = make_randomized_square_image(image)
    # The result must be a square (H == W), 4-channel, uint8 image.
    assert truncated.shape[0] == truncated.shape[1]
    assert truncated.shape[2] == 4
    assert truncated.dtype == np.uint8

    io_utils.ensure_dir_exists(os.path.dirname(out_filename), log_mkdir=False)
    io_utils.save_png(truncated, out_filename)
    return out_filename
def fssr_recon(self, out_dir):
    """Fuse the depth meshes with FSSR and move the results into ``out_dir/recon``.

    Runs ``mve.fssr_pcl_files`` (scale 0.6) over ``self._depth_meshes(out_dir=...)``.
    The result is cleaned with ``mve.meshclean`` (threshold 0.1). Both the
    raw and the cleaned mesh are moved into ``recon``, keeping their basenames.
    """
    fused = mve.fssr_pcl_files(self._depth_meshes(out_dir=out_dir), scale=0.6)
    cleaned = mve.meshclean(fused, threshold=0.1)
    recon_dir = io_utils.ensure_dir_exists(path.join(out_dir, 'recon'))
    for produced in (fused, cleaned):
        shutil.move(produced, path.join(recon_dir, path.basename(produced)))
def fssr_recon_using_input(self, out_dir, aligned_depth_mesh_filename):
    """FSSR-fuse the aligned input mesh with all but the first predicted depth mesh.

    The inputs are the sorted ``out_dir/depth_meshes/*`` files minus the
    first one, with ``aligned_depth_mesh_filename`` prepended. Fusion uses
    ``mve.fssr_pcl_files`` (scale 0.3) and cleaning uses ``mve.meshclean``
    (threshold 0.25). Both results are moved into ``out_dir/recon`` with
    ``.fused.ply`` appended to their full basenames.
    """
    predicted = sorted(glob.glob(path.join(out_dir, 'depth_meshes/*')))
    ply_files = [aligned_depth_mesh_filename] + predicted[1:]
    fused = mve.fssr_pcl_files(ply_files, scale=0.3)
    cleaned = mve.meshclean(fused, threshold=0.25)
    recon_dir = io_utils.ensure_dir_exists(path.join(out_dir, 'recon'))
    for produced in (fused, cleaned):
        destination = path.join(recon_dir, path.basename(produced))
        shutil.move(produced, destination + '.fused.ply')
def main():
    """Build the ShapeNetCore example database (SQLite) from 128px RGB renderings.

    Steps, in order:
      1. Collect every ``*.png`` under ``syn_images_dir/<synset>/<model>/``
         (only the first 10 models per synset and 7 renderings per model when
         the module-level ``is_subset`` flag is set).
      2. Delete any existing SQLite file at the target path.
      3. Check rendering names for duplicates (printed, not fatal) and assert
         that every model directory has a ``params.txt``.
      4. Inside one ``dbm.db.transaction()``, create:
           - the shared dataset, rendering-type, tag and split rows;
           - the object-centered output camera and the category rows;
           - one ``dbm.Object`` per model, plus its four object-centered
             output renderings;
           - for each input image, six viewer-centered renderings and two
             ``dbm.Example`` rows (viewer- and object-centered), linked to
             their renderings, dataset, split and tags.

    Every example created here is assigned to the ``train`` split.
    """
    syn_images_dir = '/data/mvshape/shapenetcore/single_rgb_128/'
    shapenetcore_dir = '/data/shapenetcore/ShapeNetCore.v1/'

    log.info('Getting all filenames.')
    syndirs = sorted(glob.glob(path.join(syn_images_dir, '*')))
    filenames = []
    for syndir in syndirs:
        modeldirs = sorted(glob.glob(path.join(syndir, '*')))
        if is_subset:
            modeldirs = modeldirs[:10]
        for modeldir in modeldirs:
            renderings = sorted(glob.glob(path.join(modeldir, '*.png')))
            if is_subset:
                renderings = renderings[:7]
            filenames.extend(renderings)

    # random.seed(42)
    # if not is_subset:
    #     random.shuffle(filenames)
    #     filenames = filenames[:1000000]

    random.seed(42)
    log.info('{} files'.format(len(filenames)))

    # TODO
    target_dir = '/data/mvshape/database'
    if is_subset:
        sqlite_file_path = join(target_dir, 'shapenetcore_subset.sqlite')
    else:
        sqlite_file_path = join(target_dir, 'shapenetcore.sqlite')
    # Distance of every output camera (object- and viewer-centered) from the origin.
    output_cam_distance_from_origin = 2

    log.info('Setting up output directory.')
    # Start from a fresh database file and make sure the target directory exists.
    if path.isfile(sqlite_file_path):
        os.remove(sqlite_file_path)
    io_utils.ensure_dir_exists(target_dir)

    # used for making sure there is no duplicate.
    duplicate_name_check_set = set()

    log.info('Checking for duplicates. And making sure params.txt exists.')
    for i, filename in enumerate(filenames):
        # Path layout: single_rgb_128/<synset>/<model_name>/<..>_<..>_v<NNNN>_a...
        m = re.search(r'single_rgb_128/(.*?)/(.*?)/[^_]+?_[^_]+?_v(\d{4,})_a', filename)
        synset = m.group(1)
        model_name = m.group(2)
        v = m.group(3)
        image_num = int(v)
        vc_rendering_name = '{}_{}_{:04d}'.format(synset, model_name, image_num)
        # NOTE(review): duplicates are only printed; processing continues.
        if vc_rendering_name in duplicate_name_check_set:
            print('duplicate found: ', (filename, vc_rendering_name))
        duplicate_name_check_set.add(vc_rendering_name)
        params_filename = join(syn_images_dir, synset, model_name, 'params.txt')
        assert path.isfile(params_filename)

    # Create the database
    dbm.init(sqlite_file_path)

    with dbm.db.transaction() as txn:
        # The make_* helpers are defined elsewhere; presumably they populate
        # the module-level `datasets`, `rendering_types`, `tags` and `splits`
        # dicts used below — not visible here, TODO confirm.
        log.info('Creating common objects.')
        make_dataset('shapenetcore')

        make_rendering_type('rgb')
        make_rendering_type('depth')
        make_rendering_type('normal')
        make_rendering_type('voxels')

        make_tag('novelview')
        make_tag('novelmodel')
        make_tag('novelclass')
        make_tag('perspective_input')
        make_tag('orthographic_input')
        make_tag('perspective_output')
        make_tag('orthographic_output')
        make_tag('viewer_centered')
        make_tag('object_centered')
        make_tag('real_world')

        make_split('train')
        make_split('test')
        make_split('validation')

        # Quote from http://shapenet.cs.stanford.edu/shapenet/obj-zip/ShapeNetCore.v1/README.txt
        # "The OBJ files have been pre-aligned so that the up direction is the +Y axis, and the front is the +X axis. In addition each model is normalized to fit within a unit cube centered at the origin."
        # Single shared object-centered output camera: on the +Z axis looking
        # at the origin with +Y up.
        oc_output_cam = camera.OrthographicCamera.from_Rt(
            transforms.lookat_matrix(cam_xyz=(0, 0, output_cam_distance_from_origin),
                                     obj_xyz=(0, 0, 0),
                                     up=(0, 1, 0)),
            wh=(128, 128),
            is_world_to_cam=True)
        db_oc_output_cam = get_db_camera(oc_output_cam, fov=None)

        # Prepare all category objects.
        log.info('Preparing categories.')
        synset_db_category_map = {}
        for synset, synset_name in synset_name_pairs:
            db_category_i, _ = dbm.Category.get_or_create(name=synset_name)
            synset_db_category_map[synset] = db_category_i
        txn.commit()

        # Prepare all mesh model objects.
        # ---------------------------------------------
        # model_name -> dbm.Object row
        db_object_map = {}
        # model_name -> {'output_rgb'|'output_depth'|'output_normal'|'output_voxels' -> ObjectRendering row}
        db_object_centered_renderings = {}
        log.info('Preparing mesh model objects.')
        start_time = time.time()
        count = 0
        for i, filename in enumerate(filenames):
            m = re.search(
                r'single_rgb_128/(.*?)/(.*?)/[^_]+?_[^_]+?_v(\d{4,})_a', filename)
            synset = m.group(1)
            model_name = m.group(2)
            # Many renderings share a model; only create its rows once.
            if model_name not in db_object_map:
                mesh_filename = join(shapenetcore_dir, synset, model_name, 'model.obj')
                assert path.isfile(mesh_filename)
                # Stored path: '/mesh/shapenetcore/v1/<synset>/<model_name>/model.obj'.
                mesh_filename_suffix = join(
                    '/mesh/shapenetcore/v1', '/'.join(mesh_filename.split('/')[-3:]))
                db_category = synset_db_category_map[synset]
                # Must be unique.
                db_object = dbm.Object.create(
                    name=model_name,
                    category=db_category,
                    dataset=datasets['shapenetcore'],
                    num_vertices=0,  # Not needed for now. Easy to fill in later.
                    num_faces=0,
                    mesh_filename=mesh_filename_suffix,
                )
                db_object_map[model_name] = db_object

                oc_rendering_name = '{}_{}'.format(synset, model_name)

                assert model_name not in db_object_centered_renderings
                db_object_centered_renderings[model_name] = {
                    'output_rgb': dbm.ObjectRendering.create(
                        type=rendering_types['rgb'],
                        camera=db_oc_output_cam,
                        object=db_object,
                        # NOTE(review): an earlier comment here said "JPG", but the stored path ends in .bin.
                        filename='/shapenetcore/mv20_rgb_128/{}.bin'.format(
                            oc_rendering_name),
                        resolution=128,
                        num_channels=3,
                        set_size=20,
                        is_normalized=False,
                    ),
                    'output_depth': dbm.ObjectRendering.create(
                        type=rendering_types['depth'],
                        camera=db_oc_output_cam,
                        object=db_object,
                        # Since there is only one gt rendering per model, their id is the same as the model name.
                        filename='/shapenetcore/mv20_depth_128/{}.bin'.format(
                            oc_rendering_name),
                        resolution=128,
                        num_channels=1,
                        set_size=20,
                        is_normalized=False,
                    ),
                    'output_normal': dbm.ObjectRendering.create(
                        type=rendering_types['normal'],
                        camera=db_oc_output_cam,
                        object=db_object,
                        filename='/shapenetcore/mv20_normal_128/{}.bin'.format(
                            oc_rendering_name),
                        resolution=128,
                        num_channels=3,
                        set_size=20,
                        is_normalized=False,
                    ),
                    'output_voxels': dbm.ObjectRendering.create(
                        type=rendering_types['voxels'],
                        camera=db_oc_output_cam,
                        object=db_object,
                        filename='/shapenetcore/voxels_32/{}.bin'.format(
                            oc_rendering_name),
                        resolution=32,
                        num_channels=1,
                        set_size=1,
                        is_normalized=False,
                    )
                }

                # Commit periodically (every 5000 new objects) and log progress.
                if count % 5000 == 0:
                    txn.commit()
                    t_elapsed = (time.time() - start_time)
                    t_remaining = (t_elapsed / (i + 1) * (len(filenames) - i))
                    log.info(
                        'Creating mesh objects in db. {} of {}. elapsed: {:.1f} min, remaining: {:.1f} min'
                        .format(i, len(filenames), t_elapsed / 60, t_remaining / 60))
                count += 1
        txn.commit()
        t_elapsed = time.time() - start_time
        log.info('created {} mesh objects in db. elapsed: {:.1f} min'.format(
            count, t_elapsed / 60))

        start_time = time.time()
        log.info('Processing rgb images.')
        for i, filename in enumerate(filenames):
            m = re.search(
                r'single_rgb_128/(.*?)/(.*?)/[^_]+?_[^_]+?_v(\d{4,})_a', filename)
            synset = m.group(1)
            model_name = m.group(2)
            v = m.group(3)
            image_num = int(v)

            # Camera for this rendering comes from line `image_num` of the
            # model's RenderForCNN params.txt.
            params_filename = join(syn_images_dir, synset, model_name, 'params.txt')
            assert path.isfile(params_filename)
            lines = render_for_cnn_utils.read_params_file(params_filename)
            Rt = render_for_cnn_utils.get_Rt_from_RenderForCNN_parameters(
                lines[image_num])

            # Input and output cameras
            # -------------------
            input_cam = camera.OrthographicCamera.from_Rt(Rt, wh=(128, 128),
                                                          is_world_to_cam=True)
            # 49.1343 degrees is the default fov in blender.
            db_input_cam = get_db_camera(input_cam, fov=49.1343)

            # Depth-input camera: same viewing direction as the input camera,
            # placed at distance 1.5 from the origin.
            input_cam_depth_xyz = (input_cam.pos / la.norm(input_cam.pos)) * 1.5
            input_cam_depth_Rt = transforms.lookat_matrix(
                cam_xyz=input_cam_depth_xyz,
                obj_xyz=(0, 0, 0),
                up=input_cam.up_vector)
            input_cam_depth = camera.OrthographicCamera.from_Rt(
                input_cam_depth_Rt, wh=(128, 128), is_world_to_cam=True)
            db_input_cam_depth = get_db_camera(input_cam_depth, fov=49.1343)

            # Viewer-centered output camera: same direction, placed at
            # output_cam_distance_from_origin.
            output_cam_xyz = (input_cam.pos / la.norm(
                input_cam.pos)) * output_cam_distance_from_origin
            output_Rt = transforms.lookat_matrix(cam_xyz=output_cam_xyz,
                                                 obj_xyz=(0, 0, 0),
                                                 up=input_cam.up_vector)
            vc_output_cam = camera.OrthographicCamera.from_Rt(
                output_Rt, wh=(128, 128), is_world_to_cam=True)
            db_vc_output_cam = get_db_camera(vc_output_cam, fov=None)
            # ---

            db_object = db_object_map[model_name]

            vc_rendering_name = '{}_{}_{:04d}'.format(synset, model_name, image_num)

            # Viewer centered renderings.
            # --------------------------------

            # Input rgb image:
            # NOTE(review): num_channels=1 here, while every other rgb
            # rendering in this function uses 3 — confirm this is intended.
            db_object_rendering_input_rgb = dbm.ObjectRendering.create(
                type=rendering_types['rgb'],
                camera=db_input_cam,
                object=db_object,
                # This should already exist.
                filename='/shapenetcore/single_rgb_128/{}.png'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=1,
                set_size=1,
                is_normalized=False,  # False for rgb.
            )

            db_object_rendering_input_depth = dbm.ObjectRendering.create(
                type=rendering_types['depth'],
                camera=db_input_cam_depth,
                object=db_object,
                filename='/shapenetcore/single_depth_128/{}.bin'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=1,
                set_size=1,
                is_normalized=True,
            )

            db_object_rendering_vc_output_rgb = dbm.ObjectRendering.create(
                type=rendering_types['rgb'],
                camera=db_vc_output_cam,
                object=db_object,
                filename='/shapenetcore/mv20_rgb_128/{}.bin'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=3,
                set_size=20,
                is_normalized=False,
            )

            db_object_rendering_vc_output_depth = dbm.ObjectRendering.create(
                type=rendering_types['depth'],
                camera=db_vc_output_cam,
                object=db_object,
                filename='/shapenetcore/mv20_depth_128/{}.bin'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=1,
                set_size=20,
                is_normalized=False,
            )

            db_object_rendering_vc_output_normal = dbm.ObjectRendering.create(
                type=rendering_types['normal'],
                camera=db_vc_output_cam,
                object=db_object,
                filename='/shapenetcore/mv20_normal_128/{}.bin'.format(
                    vc_rendering_name),
                resolution=128,
                num_channels=3,
                set_size=20,
                is_normalized=False,
            )

            db_object_rendering_vc_output_voxels = dbm.ObjectRendering.create(
                type=rendering_types['voxels'],
                camera=db_vc_output_cam,
                object=db_object,
                filename='/shapenetcore/voxels_32/{}.bin'.format(
                    vc_rendering_name),
                resolution=32,
                num_channels=1,
                set_size=1,
                is_normalized=False,
            )

            # Examples
            # ----------------

            # A row in the `Example` table is just an id for many-to-many references.

            # View centered
            # Input rgb/depth plus the four viewer-centered outputs.
            example_viewer_centered = dbm.Example.create()
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_input_rgb)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_input_depth)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_vc_output_depth)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_vc_output_normal)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_vc_output_rgb)
            dbm.ExampleObjectRendering.create(
                example=example_viewer_centered,
                rendering=db_object_rendering_vc_output_voxels)
            dbm.ExampleDataset.create(example=example_viewer_centered,
                                      dataset=datasets['shapenetcore'])
            dbm.ExampleSplit.create(example=example_viewer_centered,
                                    split=splits['train'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['real_world'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['viewer_centered'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['perspective_input'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['orthographic_output'])
            dbm.ExampleTag.create(example=example_viewer_centered,
                                  tag=tags['novelmodel'])

            # Object centered
            # Same inputs, but the outputs are the model's shared object-centered renderings.
            example_object_centered = dbm.Example.create()
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_rendering_input_rgb)
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_rendering_input_depth)
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_centered_renderings[model_name]
                ['output_depth'])
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_centered_renderings[model_name]
                ['output_normal'])
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_centered_renderings[model_name]
                ['output_rgb'])
            dbm.ExampleObjectRendering.create(
                example=example_object_centered,
                rendering=db_object_centered_renderings[model_name]
                ['output_voxels'])
            dbm.ExampleDataset.create(example=example_object_centered,
                                      dataset=datasets['shapenetcore'])
            dbm.ExampleSplit.create(example=example_object_centered,
                                    split=splits['train'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['real_world'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['object_centered'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['perspective_input'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['orthographic_output'])
            dbm.ExampleTag.create(example=example_object_centered,
                                  tag=tags['novelmodel'])

            # Commit every 5000 images and log progress / ETA.
            if i % 5000 == 0:
                txn.commit()
                t_elapsed = (time.time() - start_time)
                t_remaining = (t_elapsed / (i + 1) * (len(filenames) - i))
                log.info(
                    'Creating examples in db. {} of {}. elapsed: {:.1f} min, remaining: {:.1f} min'
                    .format(i, len(filenames), t_elapsed / 60, t_remaining / 60))

        txn.commit()
    dbm.db.commit()

    t_elapsed = (time.time() - start_time)
    log.info('total elapsed: {:.1f} min'.format(t_elapsed / 60))
def main():
    """Render figures for PASCAL3D test predictions of two trained models.

    For each (loader, model) pair — tagged 'o' and 'v', presumably the
    object-/viewer-centered ('opo'/'vpo') variants, TODO confirm — runs the
    model on every batch of the PASCAL3D test split.

    For each image, writes the following into
    ``/data/mvshape/out/pascal3d_figures/<exp>/<image_name>/``:
      - ``pcl.png``: a scatter plot of the point cloud from previously saved
        ``pascal3d_recon/<exp>/<image_name>/depth_meshes/*.ply`` files;
      - ``silhouette.png``: the predicted silhouette probability;
      - ``masked-depth.png``: the predicted masked depth;
      - ``input.png``: a symlink to the input RGB image.
    """
    base = '/data/mvshape'
    batch_size = 50
    np.random.seed(42)

    loaders_o = mvshape.data.dataset.ExampleLoader2(
        '/data/mvshape/out/splits/pascal3d_test_examples_opo/all_examples.cbor',
        tensors_to_read=('input_rgb', 'target_depth', 'target_voxels'),
        shuffle=True,
        batch_size=batch_size)
    loaders_v = mvshape.data.dataset.ExampleLoader2(
        '/data/mvshape/out/splits/pascal3d_test_examples_vpo/all_examples.cbor',
        tensors_to_read=('input_rgb', 'target_depth', 'target_voxels'),
        shuffle=True,
        batch_size=batch_size)
    loaders = [loaders_o, loaders_v]

    both_models = [
        mvshape.models.encoderdecoder.load_model(
            '/data/mvshape/out/pytorch/shapenetcore_rgb_mv6/opo/0/models0_00005_0018323_00109115.pth'
        ),
        mvshape.models.encoderdecoder.load_model(
            '/data/mvshape/out/pytorch/shapenetcore_rgb_mv6/vpo/0/models0_00005_0018323_00109115.pth'
        ),
    ]
    exps = ['o', 'v']

    # #### TODO
    # mode = 1
    #
    # loaders = [loaders[mode]]
    # both_models = [both_models[mode]]
    # exps = [exps[mode]]
    # ####

    # Running index of processed images, printed for progress only.
    counter = 0
    for L, M, exp in zip(loaders, both_models, exps):
        # Move the model to GPU and switch it to inference mode.
        torch_utils.recursive_module_apply(M, lambda m: m.cuda())
        torch_utils.recursive_train_setter(M, is_training=False)
        loader = L
        while True:
            # loader.next() returns None once the split is exhausted.
            next_batch = loader.next()
            if next_batch is None:
                print('END ################################')
                break

            batch_data_np = mvshape.models.encoderdecoder.prepare_data_rgb_mv(
                next_batch=next_batch)
            im = batch_data_np['in_image']
            # NOTE(review): rebuilt for every batch; could be hoisted out of the
            # loop if build_helper_torch_modules() is stateless — confirm.
            helper_torch_modules = mvshape.models.encoderdecoder.build_helper_torch_modules(
            )
            out = mvshape.models.encoderdecoder.get_final_images_from_model(
                M, im, helper_torch_modules=helper_torch_modules)
            masked_depth = out['masked_depth']
            recon_basedir = '/data/mvshape/out/pascal3d_recon/'
            out_basedir = '/data/mvshape/out/pascal3d_figures/'

            for bi in range(len(next_batch[0])):
                image_name = path.basename(
                    next_batch[0][bi]['input_rgb']['filename']).split('.')[0]
                recon_dir = recon_basedir + '/{}/{}/'.format(exp, image_name)
                # if path.isdir(recon_dir):
                #     print('{} exists. skipping'.format(recon_dir))
                #     continue

                eye = next_batch[0][bi]['target_camera']['eye']
                up = next_batch[0][bi]['target_camera']['up']
                lookat = next_batch[0][bi]['target_camera']['lookat']

                Rt_list = mvshape.camera_utils.make_six_views(
                    camera_xyz=eye, object_xyz=lookat, up=up)
                cams = [
                    dshin.camera.OrthographicCamera.from_Rt(Rt_list[i],
                                                            sRt_scale=1.75,
                                                            wh=(128, 128))
                    for i in range(len(Rt_list))
                ]
                # NOTE(review): `mv` is not used later in this function.
                mv = mvshape.shapes.MVshape(masked_images=masked_depth[bi],
                                            cameras=cams)

                # Point cloud = vertices of all previously saved depth meshes.
                # NOTE(review): np.concatenate raises ValueError if no
                # depth_meshes/*.ply exist for this image.
                depth_mesh_filenames = glob.glob(recon_dir + 'depth_meshes/*.ply')
                pcl = []
                for item in depth_mesh_filenames:
                    pcl.append(io_utils.read_mesh(item)['v'])
                pcl = np.concatenate(pcl, axis=0)
                print(pcl.shape)

                fig_dir = path.join(out_basedir, '{}/{}/'.format(exp, image_name))
                io_utils.ensure_dir_exists(fig_dir)

                # 3D point-cloud scatter, colored by each point's pre-rotation x coordinate.
                pt.figure(figsize=(5, 5))
                # NOTE(review): passing projection= to gca() is deprecated and
                # removed in newer matplotlib versions.
                ax = pt.gca(projection='3d')
                color = pcl[:, 0] + 0.6  # +0.6 to force the values to be positive. not necessary.
                rotmat = transforms.rotation_matrix(angle=45, direction=np.array(
                    (0, 1, 0)))
                rotmat2 = transforms.rotation_matrix(angle=-30, direction=np.array(
                    (0, 0, 1)))
                pcl = transforms.apply44(rotmat2.dot(rotmat), pcl)
                # Sort by rotated x so points are drawn in a fixed order.
                index_array = np.argsort(pcl[:, 0])
                pcl = pcl[index_array]
                color = color[index_array]
                geom3d.pts(pcl, markersize=45, color=color, zdir='y', show_labels=False,
                           cmap='viridis', cam_sph=(1, 90, 0), ax=ax)
                ax.axis('off')
                pt.savefig(fig_dir + 'pcl.png', bbox_inches='tight', transparent=True,
                           pad_inches=0)
                pt.close()

                # Predicted silhouette probability (gray).
                pt.figure(figsize=(5, 5))
                ax = pt.gca()
                geom2d.draw_depth(out['silhouette_prob'][bi], cmap='gray',
                                  nan_color=(1.0, 1.0, 1.0), grid=128, grid_width=3,
                                  ax=ax, show_colorbar=False, show_colorbar_ticks=False)
                pt.savefig(fig_dir + '/silhouette.png', bbox_inches='tight',
                           transparent=False, pad_inches=0)
                pt.close()

                # Predicted masked depth (viridis).
                pt.figure(figsize=(10, 10))
                ax = pt.gca()
                geom2d.draw_depth(masked_depth[bi], cmap='viridis',
                                  nan_color=(1.0, 1.0, 1.0), grid=128, grid_width=6,
                                  ax=ax, show_colorbar=False, show_colorbar_ticks=False)
                pt.savefig(fig_dir + '/masked-depth.png',
                           bbox_inches='tight', transparent=False, pad_inches=0)
                pt.close()

                # Symlink the input image into the figure dir, replacing a stale link.
                rgb_filename = base + next_batch[0][bi]['input_rgb']['filename']
                assert path.isfile(rgb_filename)
                rgb_link_target = fig_dir + '/input.png'
                if path.islink(rgb_link_target):
                    os.remove(rgb_link_target)
                os.symlink(rgb_filename, rgb_link_target)

                print(counter, fig_dir, rgb_filename)
                counter += 1
def blend_and_resize(params):
    """Composite an RGBA object rendering over a background and save a 128px RGB image.

    Steps:
      1. Crop the object to the bounding box of its non-zero pixels.
      2. Make the crop square with ``make_randomized_square_image``.
      3. Resize the square to ``target_resolution = min(square side, 128)``.
      4. Choose a background. With probability 0.8 it is a random square crop
         from one of ``bkg_filenames``; the crop side is drawn between
         ``target_resolution`` and 4x that, capped by the image size, and the
         crop is resized to ``target_resolution``. Otherwise it is a uniform
         gray of random intensity. The gray background is also used when no
         sufficiently large background is found within ``max_bkg_attempts``
         draws, or when ``bkg_filenames`` is empty.
      5. Alpha-blend the object over the background, resize the result to
         128x128 and save it as uint8.

    A single interpolation method (bilinear / bicubic / lanczos, chosen at
    random) is used for every resize in one call.

    Args:
        params: dict with keys ``object_image_filename`` (RGBA PNG path),
            ``bkg_filenames`` (sequence of candidate background image paths)
            and ``out_image_filename`` (destination path).

    Returns:
        ``out_image_filename``.

    Raises:
        ValueError: if the object image has no non-zero pixels.
    """
    object_image_filename = params['object_image_filename']
    bkg_filenames = params['bkg_filenames']
    out_image_filename = params['out_image_filename']
    object_image = io_utils.read_png(object_image_filename)

    resolution = 128
    bkg_clutter_ratio = 0.8
    scale_max = 4
    # Cap on random background draws, so a list with no sufficiently large
    # image cannot hang the worker forever.
    max_bkg_attempts = 100

    use_background_image = random.random() < bkg_clutter_ratio
    resize_method = {
        0: 'bilinear',
        1: 'bicubic',
        2: 'lanczos'
    }[random.randint(0, 2)]

    def force_uint8(arr):
        # Float input is treated as [0, 1] and scaled to [0, 255].
        if arr.dtype in (np.float32, np.float64):
            arr = (arr * 255).round().astype(np.uint8)
        assert arr.dtype == np.uint8
        return arr

    def force_float(arr):
        # Return float32; uint8 input is scaled to [0, 1].
        if arr.dtype == np.uint8:
            arr = arr.astype(np.float32) / 255.0
        elif arr.dtype == np.float64:
            arr = arr.astype(np.float32)
        assert arr.dtype == np.float32
        return arr

    def resize(arr: np.ndarray, res):
        # Square image -> (res, res) float32 using the method chosen above.
        assert arr.shape[0] == arr.shape[1]
        arr = force_uint8(arr)
        resized = scipy.misc.imresize(arr, size=(res, res), interp=resize_method)
        return force_float(resized)

    # Crop to the bounding box of non-zero pixels (any channel), then pad to a square.
    iy, ix = np.where(object_image)[:2]
    if iy.size == 0:
        raise ValueError(
            'Object image has no non-zero pixels: {}'.format(object_image_filename))
    y0 = np.min(iy)
    x0 = np.min(ix)
    y1 = np.max(iy) + 1
    x1 = np.max(ix) + 1
    src_cropped = object_image[y0:y1, x0:x1]
    src_cropped_square = make_randomized_square_image(src_cropped)
    src_s = src_cropped_square.shape[0]
    # Blend at min(square side, 128); the blend is resized to `resolution` at the end.
    target_resolution = min(src_s, resolution)

    bkg = None
    if use_background_image and bkg_filenames:
        # Draw random files until one is at least target_resolution on its short side.
        bkg_image = None
        for _ in range(max_bkg_attempts):
            candidate = _load_image(random.choice(bkg_filenames))
            if min(candidate.shape[:2]) >= target_resolution:
                bkg_image = candidate
                break

        if bkg_image is not None:
            if bkg_image.ndim == 2:
                bkg_image = np.tile(bkg_image[:, :, None], (1, 1, 3))
            # Drop channels beyond RGB (e.g. alpha in RGBA backgrounds).
            bkg_image = bkg_image[:, :, :3]

            # Random square crop of side in [target_resolution,
            # min(target_resolution * scale_max, short side)], then resize it
            # to target_resolution. Images whose short side equals
            # target_resolution are cropped at that size (the resize is then a
            # no-op in size).
            res = random.randint(
                target_resolution,
                min(target_resolution * scale_max, min(bkg_image.shape[:2])))
            y = random.randint(0, bkg_image.shape[0] - res)
            x = random.randint(0, bkg_image.shape[1] - res)
            bkg_cropped = bkg_image[y:y + res, x:x + res]
            assert bkg_cropped.shape[0] == bkg_cropped.shape[1]
            assert bkg_cropped.shape[0] == res
            assert bkg_cropped.shape[2] == 3
            bkg = resize(bkg_cropped, res=target_resolution)

    if bkg is None:
        # Uniform gray background of random intensity in [0, 1).
        color_gray = random.random()
        bkg = color_gray * np.ones(
            (target_resolution, target_resolution, 3), dtype=np.float32)

    src_image_rgba = resize(src_cropped_square, res=target_resolution)
    src_image = src_image_rgba[:, :, :3]
    mask = src_image_rgba[:, :, 3]

    assert bkg.dtype == np.float32
    assert bkg.shape[0] == bkg.shape[1]
    assert bkg.shape[2] == 3
    assert bkg.shape[0] == target_resolution
    assert src_image.dtype == np.float32
    assert src_image.shape == bkg.shape
    assert mask.dtype == np.float32

    # Standard alpha compositing: out = (1 - a) * background + a * foreground.
    blended = ((1.0 - mask)[:, :, None] * bkg) + ((mask)[:, :, None] * src_image)

    blended_final = force_uint8(resize(blended, res=resolution))
    assert blended_final.shape[0] == blended_final.shape[1]
    assert blended_final.shape[0] == resolution
    assert blended_final.dtype == np.uint8

    io_utils.ensure_dir_exists(path.dirname(out_image_filename), log_mkdir=False)
    scipy.misc.imsave(out_image_filename, blended_final)

    return out_image_filename