def convert_disparity_to_saveable(disparity, normalize=True):
    # takes in a single disparity map of shape height x width.
    # can be saved via: imageio.imwrite(filename, convert_disparity_to_saveable(disparity))
    converted_disparity = (disparity / np.max(disparity)
                           if normalize else disparity.copy())
    converted_disparity = to8b(
        converted_disparity)  # height x width. int values in [0,255].
    return converted_disparity

def save_depthmap(filename, depth, H, W):
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    disp_map_label = 1 - 1. / torch.max(
        torch.ones_like(depth).cuda() * 1e-10, depth)
    disp_map_label = torch.reshape(disp_map_label, (H, W))
    rgb8_label = to8b(disp_map_label.cpu().detach().numpy())
    imageio.imwrite(filename, rgb8_label)

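
# Hedged usage sketch (illustrative, not part of the original code): one way
# save_depthmap above might be called. The tensor shape and output path are
# assumptions; save_depthmap expects a flat tensor of H*W positive depth
# values and requires a CUDA device.
def _example_save_depthmap_usage():
    H, W = 8, 8
    depth = torch.rand(H * W).cuda() + 0.1  # strictly positive depths
    save_depthmap("example_output/depth_000000.png", depth, H, W)
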
def update_canvas(self, images, disps=None):
    for i, image in enumerate(images):
        resized_rgb = F.interpolate(image.unsqueeze(dim=0),
                                    size=(self.size,
                                          self.size)).squeeze(dim=0)
        self.real_images_array[i].src = renormalize.as_url(resized_rgb)
        self.real_canvas_array[i].image = renormalize.as_url(image)
        self.real_canvas_array[i].resized_image = renormalize.as_url(
            resized_rgb)
        if disps is not None:
            disp_img = torch.from_numpy(to8b(to_disp_img(
                disps[i]))).unsqueeze(dim=0) / 255.
            resized_disp = F.interpolate(disp_img.unsqueeze(dim=0),
                                         size=(self.size,
                                               self.size)).squeeze(dim=0)
            self.real_canvas_array[i].resized_disp = renormalize.as_url(
                resized_disp)
            self.real_canvas_array[i].disp = renormalize.as_url(disp_img)
    if self.editing_canvas.index >= 0:
        self.editing_canvas.image = self.real_canvas_array[
            self.editing_canvas.index].image

def convert_disparity_to_phong(disparity, normalize=True):
    converted_disparity = (disparity / np.max(disparity)
                           if normalize else disparity.copy())
    converted_disparity = to8b(
        visualize_disparity_with_blinn_phong(converted_disparity))
    return converted_disparity  # height x width x 3. int values in [0,255].

def convert_disparity_to_jet(disparity, normalize=True):
    converted_disparity = (disparity / np.max(disparity)
                           if normalize else disparity.copy())
    converted_disparity = to8b(
        visualize_disparity_with_jet_color_scheme(converted_disparity))
    return converted_disparity  # height x width x 3. int values in [0,255].

def convert_rgb_to_saveable(rgb):
    # input: float values in [0,1]
    # output: int values in [0,255]
    return to8b(rgb)

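
# Hedged usage sketch (illustrative, not part of the original code): storing a
# single rendered frame with the conversion helpers above, mirroring the
# per-frame saving loop in free_viewpoint_rendering below. The output folder
# and file names are assumptions.
def _example_save_frame(rgb, disparity, out_folder="example_output"):
    os.makedirs(out_folder, exist_ok=True)
    imageio.imwrite(os.path.join(out_folder, "rgb.png"),
                    convert_rgb_to_saveable(rgb))  # float [0,1] -> uint8
    imageio.imwrite(os.path.join(out_folder, "disp.png"),
                    convert_disparity_to_saveable(disparity))
    imageio.imwrite(os.path.join(out_folder, "disp_jet.png"),
                    convert_disparity_to_jet(disparity))
    imageio.imwrite(os.path.join(out_folder, "disp_phong.png"),
                    convert_disparity_to_phong(disparity))
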
def free_viewpoint_rendering(args):

    # memory vs. speed and quality
    frames_at_a_time = 10  # set to 1 to reduce memory requirements
    only_rgb = False  # set to True to reduce memory requirements. Needs to be False for some scene editing to work.

    # determine output name
    if args.camera_path == "spiral":
        output_name = args.deformations + "_" + args.camera_path
    elif args.camera_path == "fixed":
        output_name = (args.deformations + "_" + args.camera_path + "_" +
                       str(args.fixed_view))
    elif args.camera_path == "input_reconstruction":
        output_name = args.deformations + "_" + args.camera_path
    else:
        raise RuntimeError("invalid --camera_path argument")
    if args.forced_background_stabilization is not None:
        output_name += "_fbs_" + str(args.forced_background_stabilization)
    if args.motion_factor is not None:
        output_name += "_exaggeration_" + str(args.motion_factor)
    if args.foreground_removal is not None:
        output_name += "_removal_" + str(args.foreground_removal)
    if args.render_canonical:
        output_name += "_canonical"
    output_folder = os.path.join(args.input, "output", output_name)
    create_folder(output_folder)

    # load Nerf network
    (
        render_kwargs_train,
        render_kwargs_test,
        start,
        grad_vars,
        load_weights_into_network,
        checkpoint_dict,
        get_training_ray_bending_latents,
        load_llff_dataset,
        raw_render_path,
        render_convenient,
        convert_rgb_to_saveable,
        convert_disparity_to_saveable,
        convert_disparity_to_jet,
        convert_disparity_to_phong,
        store_ray_bending_mesh_visualization,
        to8b,
    ) = _setup_nonrigid_nerf_network(args.input)
    print("successfully loaded nerf network", flush=True)

    # load dataset
    ray_bending_latents = (
        get_training_ray_bending_latents())  # shape: frames x latent_size
    (
        images,
        poses,
        all_rotations,
        all_translations,
        bds,
        render_poses,
        render_rotations,
        render_translations,
        i_train,
        i_val,
        i_test,
        near,
        far,
        dataset_extras,
    ) = load_llff_dataset(
        render_kwargs_train_=render_kwargs_train,
        render_kwargs_test_=render_kwargs_test
    )  # load dataset that this nerf was trained on
    print("successfully loaded dataset", flush=True)

    # determine subset
    if args.deformations == "train":
        indices = i_train
        poses = poses[i_train]
        ray_bending_latents = ray_bending_latents[i_train]
        images = images[i_train]
        print("rendering training set")
    elif args.deformations == "test":
        indices = i_test
        poses = poses[i_test]
        ray_bending_latents = ray_bending_latents[i_test]
        images = images[i_test]
        print("rendering test set")
    elif args.deformations == "all":
        print("rendering training and test set")
    else:
        raise RuntimeError("invalid --deformations argument")

    copy_over_groundtruth_images = False
    if copy_over_groundtruth_images:
        groundtruth_images_folder = os.path.join(output_folder, "groundtruth")
        create_folder(groundtruth_images_folder)
        for i, rgb in enumerate(images):
            rgb = convert_rgb_to_saveable(rgb)
            file_prefix = os.path.join(groundtruth_images_folder,
                                       str(i).zfill(6))
            imageio.imwrite(file_prefix + ".png", rgb)

    # determine camera poses and latent codes
    num_poses = poses.shape[0]
    intrinsics = dataset_extras["intrinsics"]
    if args.camera_path == "input_reconstruction":
        poses = poses
        intrinsics = [
            intrinsics[dataset_extras["imageid_to_viewid"][i]]
            for i in range(num_poses)
        ]
    elif args.camera_path == "fixed":
        poses = torch.stack(
            [torch.Tensor(poses[args.fixed_view]) for _ in range(num_poses)],
            0)  # N x 3 x 4
        intrinsics = [
            intrinsics[dataset_extras["imageid_to_viewid"][args.fixed_view]]
            for _ in range(num_poses)
        ]
    elif args.camera_path == "spiral":
        # poses = np.stack(_spiral_poses(poses, bds, num_poses), axis=0)
        poses = []
        while len(poses) < num_poses:
            poses += [render_pose for render_pose in render_poses]
        poses = np.stack(poses, axis=0)[:num_poses]
        intrinsics = [
            intrinsics[dataset_extras["imageid_to_viewid"][0]]
            for _ in range(num_poses)
        ]
    else:
        # poses has shape N x ... and ray_bending_latents has shape N x ...
        # Can design custom camera paths here.
        # poses is indexed with imageid
        # ray_bending_latents is indexed with timestepid
        # intrinsics is indexed with viewid
        # images is indexed with imageid
        raise RuntimeError

        # example with time interpolation from a fixed camera view
        num_target_frames = 500
        latent_indices = np.linspace(0,
                                     ray_bending_latents.shape[0] - 1,
                                     num=num_target_frames)
        start_indices = np.floor(latent_indices).astype(int)
        end_indices = np.ceil(latent_indices).astype(int)
        start_latents = ray_bending_latents[
            start_indices]  # num_target_frames x latent_size
        end_latents = ray_bending_latents[
            end_indices]  # num_target_frames x latent_size
        interpolation_factors = latent_indices - start_indices  # shape: num_target_frames. should be in [0,1]
        interpolation_factors = torch.Tensor(interpolation_factors).reshape(
            -1, 1)  # num_target_frames x 1
        ray_bending_latents = end_latents * interpolation_factors + start_latents * (
            1. - interpolation_factors)
        fixed_camera = 0
        poses = torch.stack([
            torch.Tensor(poses[fixed_camera])
            for _ in range(num_target_frames)
        ], 0)  # N x 3 x 4
        intrinsics = [
            intrinsics[dataset_extras["imageid_to_viewid"][fixed_camera]]
            for _ in range(num_target_frames)
        ]

    latents = ray_bending_latents
    latents = latents.detach().cuda()

    # rendering
    correspondence_rgbs = []
    rigidities = []
    rgbs = []
    disps = []
    num_output_frames = poses.shape[0]
    for start_index in range(0, num_output_frames, frames_at_a_time):
        end_index = np.min(
            [start_index + frames_at_a_time, num_output_frames])
        print(
            "rendering " + str(start_index) + " to " + str(end_index) +
            " out of " + str(num_output_frames),
            flush=True,
        )
        subposes = poses[start_index:end_index]
        sublatents = [latents[i] for i in range(start_index, end_index)]

        # render
        returned = render_convenient(
            poses=subposes,
            ray_bending_latents=sublatents,
            intrinsics=intrinsics,
            with_ray_bending=not args.render_canonical,
            detailed_output=not only_rgb,
            rigidity_test_time_cutoff=args.forced_background_stabilization,
            motion_factor=args.motion_factor,
            foreground_removal=args.foreground_removal)
        if only_rgb:
            subrgbs, subdisps = returned
        else:
            subrgbs, subdisps, details_and_rest = returned
        print("finished rendering", flush=True)

        rgbs += [image for image in subrgbs]
        disps += [image for image in subdisps]

        if only_rgb:
            correspondence_rgbs += [None for _ in subrgbs]
            rigidities += [None for _ in subrgbs]
            continue

        # determine correspondences
        # details_and_rest: list, one entry per image. each image has first two dimensions height x width.
        min_point = np.array(
            checkpoint_dict["scripts_dict"]["min_nerf_volume_point"]).reshape(
                1, 1, 3)
        max_point = np.array(
            checkpoint_dict["scripts_dict"]["max_nerf_volume_point"]).reshape(
                1, 1, 3)
        for i, image_details in enumerate(details_and_rest):
            # visibility_weight is the weight of the influence that each sample has on the final rgb value. so they sum to at most 1.
            accumulated_visibility = torch.cumsum(
                torch.Tensor(image_details["fine_visibility_weights"]).cuda(),
                dim=-1)  # height x width x point samples
            median_indices = torch.min(
                torch.abs(accumulated_visibility - 0.5), dim=-1
            )[1]  # height x width. visibility goes from 0 to 1. 0.5 is the median, so treat it as "most likely to be on the actually visible surface"

            # visualize canonical correspondences as RGB
            height, width = median_indices.shape
            surface_pixels = (
                image_details["fine_input_pts"].reshape(
                    height * width, -1,
                    3)[np.arange(height * width),
                       median_indices.cpu().reshape(-1), :].reshape(
                           height, width, 3)
            )  # height x width x 3. median_indices contains the index of one ray sample for each pixel. this ray sample is selected in this line of code.
            correspondence_rgb = (surface_pixels - min_point) / (max_point -
                                                                 min_point)
            number_of_small_rgb_voxels = 100  # break the canonical space into smaller voxels. each voxel covers the entire RGB space [0,1]^3. makes it easier to visualize small changes. leads to a 3D checkerboard pattern
            if number_of_small_rgb_voxels > 1:
                correspondence_rgb *= number_of_small_rgb_voxels
                correspondence_rgb = correspondence_rgb - correspondence_rgb.astype(
                    int)
            correspondence_rgbs.append(correspondence_rgb)

            # visualize rigidity
            if "fine_rigidity_mask" in image_details:
                rigidity = (image_details["fine_rigidity_mask"].reshape(
                    height * width,
                    -1)[np.arange(height * width),
                        median_indices.cpu().reshape(-1)].reshape(
                            height, width))  # height x width. values in [0,1]
                rigidities.append(rigidity)
            else:
                rigidities.append(None)

    rgbs = np.stack(rgbs, axis=0)
    disps = np.stack(disps, axis=0)
    correspondence_rgbs = np.stack(correspondence_rgbs, axis=0)
    use_rigidity = rigidities[0] is not None

    # store results
    # for i, (rgb, disp, correspondence_rgb, rigidity) in zip(indices, (zip(rgbs, disps, correspondence_rgbs, rigidities))):
    for i, (rgb, disp, correspondence_rgb, rigidity) in enumerate(
            zip(rgbs, disps, correspondence_rgbs, rigidities)):
        print("storing image " + str(i) + " / " + str(rgbs.shape[0]),
              flush=True)
        rgb = convert_rgb_to_saveable(rgb)
        disp_saveable = convert_disparity_to_saveable(disp)
        disp_jet = convert_disparity_to_jet(disp)
        disp_phong = convert_disparity_to_phong(disp)
        if not only_rgb:
            correspondence_rgb = convert_rgb_to_saveable(correspondence_rgb)
            if use_rigidity:
                rigidity_saveable = convert_disparity_to_saveable(
                    rigidity, normalize=False)
                rigidity_jet = convert_disparity_to_jet(rigidity,
                                                        normalize=False)
        file_postfix = "_" + str(i).zfill(6) + ".png"
        imageio.imwrite(os.path.join(output_folder, "rgb" + file_postfix),
                        rgb)
        if not only_rgb:
            imageio.imwrite(
                os.path.join(output_folder, "correspondences" + file_postfix),
                correspondence_rgb,
            )
            if use_rigidity:
                imageio.imwrite(
                    os.path.join(output_folder, "rigidity" + file_postfix),
                    rigidity_saveable,
                )
                imageio.imwrite(
                    os.path.join(output_folder,
                                 "rigidity_jet" + file_postfix),
                    rigidity_jet)
        imageio.imwrite(os.path.join(output_folder, "disp" + file_postfix),
                        disp_saveable)
        imageio.imwrite(
            os.path.join(output_folder, "disp_jet" + file_postfix), disp_jet)
        imageio.imwrite(
            os.path.join(output_folder, "disp_phong" + file_postfix),
            disp_phong)

    # movies
    file_prefix = os.path.join(output_folder, "video_")
    try:
        print("storing RGB video...", flush=True)
        imageio.mimwrite(
            file_prefix + "rgb.mp4",
            convert_rgb_to_saveable(rgbs),
            fps=args.output_video_fps,
            quality=9,
        )
        if not only_rgb:
            print("storing correspondence RGB video...", flush=True)
            imageio.mimwrite(
                file_prefix + "correspondences.mp4",
                convert_rgb_to_saveable(correspondence_rgbs),
                fps=args.output_video_fps,
                quality=9,
            )
        print("storing disparity video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp.mp4",
            convert_disparity_to_saveable(disps),
            fps=args.output_video_fps,
            quality=9,
        )
        print("storing disparity jet video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp_jet.mp4",
            np.stack([convert_disparity_to_jet(disp) for disp in disps],
                     axis=0),
            fps=args.output_video_fps,
            quality=9,
        )
        print("storing disparity phong video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp_phong.mp4",
            np.stack([convert_disparity_to_phong(disp) for disp in disps],
                     axis=0),
            fps=args.output_video_fps,
            quality=9,
        )
        if use_rigidity:
            rigidities = np.stack(rigidities, axis=0)
            print("storing rigidity video...", flush=True)
            imageio.mimwrite(
                file_prefix + "rigidity.mp4",
                convert_disparity_to_saveable(rigidities, normalize=False),
                fps=args.output_video_fps,
                quality=9,
            )
            print("storing rigidity jet video...", flush=True)
            imageio.mimwrite(
                file_prefix + "rigidity_jet.mp4",
                np.stack(
                    [
                        convert_disparity_to_jet(rigidity, normalize=False)
                        for rigidity in rigidities
                    ],
                    axis=0,
                ),
                fps=args.output_video_fps,
                quality=9,
            )
    except:
        print(
            "imageio.mimwrite() failed. maybe ffmpeg is not installed properly?"
        )

    # evaluation of background stability
    if args.camera_path == "fixed":
        standard_deviations = np.std(rgbs, axis=0)
        averaged_standard_devations = 10 * np.mean(standard_deviations,
                                                   axis=-1)
        from matplotlib import cm
        color_mapping = np.array([cm.jet(i)[:3] for i in range(256)])
        max_value = 1
        min_value = 0
        averaged_standard_devations = np.clip(
            averaged_standard_devations, a_max=max_value, a_min=min_value
        ) / max_value  # cut off above max_value. result is normalized to [0,1]
        averaged_standard_devations = (
            255. * averaged_standard_devations).astype(
                'uint8')  # now contains int in [0,255]
        original_shape = averaged_standard_devations.shape
        averaged_standard_devations = color_mapping[
            averaged_standard_devations.flatten()]
        averaged_standard_devations = averaged_standard_devations.reshape(
            original_shape + (3, ))
        imageio.imwrite(
            os.path.join(output_folder, "standard_deviations.png"),
            averaged_standard_devations)

    # quantitative evaluation
    if args.camera_path == "input_reconstruction":
        try:
            from PerceptualSimilarity import lpips
            perceptual_metric = lpips.LPIPS(net='alex')
        except:
            print(
                "Perceptual LPIPS metric not found. Please see the README for installation instructions"
            )
            perceptual_metric = None

        create_error_maps = True  # whether to write out error images instead of just computing scores
        naive_error_folder = os.path.join(output_folder, "naive_errors/")
        create_folder(naive_error_folder)
        ssim_error_folder = os.path.join(output_folder, "ssim_errors/")
        create_folder(ssim_error_folder)
        to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)

        def visualize_with_jet_color_scheme(image):
            from matplotlib import cm
            color_mapping = np.array([cm.jet(i)[:3] for i in range(256)])
            max_value = 1.0
            min_value = 0.0
            intermediate = np.clip(
                image, a_max=max_value, a_min=min_value
            ) / max_value  # cut off above max_value. result is normalized to [0,1]
            intermediate = (255. * intermediate).astype(
                'uint8')  # now contains int in [0,255]
            original_shape = intermediate.shape
            intermediate = color_mapping[intermediate.flatten()]
            intermediate = intermediate.reshape(original_shape + (3, ))
            return intermediate

        mask = None
        scores = {}
        from skimage.metrics import structural_similarity as ssim
        for i, (groundtruth, generated) in enumerate(zip(images, rgbs)):
            if mask is None:
                # undistortion leads to masked-out black pixels in groundtruth
                mask = (np.sum(groundtruth, axis=-1) == 0.)
            groundtruth[mask] = 0.
            generated[mask] = 0.

            # PSNR
            mse = np.mean((groundtruth - generated)**2)
            psnr = -10. * np.log10(mse)

            # SSIM
            # https://scikit-image.org/docs/dev/api/skimage.metrics.html#skimage.metrics.structural_similarity
            returned = ssim(groundtruth,
                            generated,
                            data_range=1.0,
                            multichannel=True,
                            gaussian_weights=True,
                            sigma=1.5,
                            use_sample_covariance=False,
                            full=create_error_maps)
            if create_error_maps:
                ssim_error, ssim_error_image = returned
            else:
                ssim_error = returned

            # perceptual metric
            if perceptual_metric is None:
                lpips = 1.
            else:

                def numpy_to_pytorch(np_image):
                    torch_image = 2 * torch.from_numpy(
                        np_image) - 1  # height x width x 3. must be in [-1,+1]
                    torch_image = torch_image.permute(
                        2, 0, 1)  # 3 x height x width
                    return torch_image.unsqueeze(0)  # 1 x 3 x height x width

                lpips = perceptual_metric.forward(
                    numpy_to_pytorch(groundtruth),
                    numpy_to_pytorch(generated))
                lpips = float(lpips.detach().reshape(1).numpy()[0])

            scores[i] = {"psnr": psnr, "ssim": ssim_error, "lpips": lpips}

            if create_error_maps:
                # MSE-style
                error = np.linalg.norm(groundtruth - generated,
                                       axis=-1) / np.sqrt(
                                           1 + 1 + 1)  # height x width
                error *= 10.  # exaggerate error
                error = np.clip(error, 0.0, 1.0)
                error = to8b(visualize_with_jet_color_scheme(
                    error))  # height x width x 3. int values in [0,255]
                filename = os.path.join(naive_error_folder,
                                        'error_{:03d}.png'.format(i))
                imageio.imwrite(filename, error)
                # SSIM
                filename = os.path.join(ssim_error_folder,
                                        'error_{:03d}.png'.format(i))
                ssim_error_image = to8b(
                    visualize_with_jet_color_scheme(
                        1. - np.mean(ssim_error_image, axis=-1)))
                imageio.imwrite(filename, ssim_error_image)

        averaged_scores = {}
        averaged_scores["average_psnr"] = np.mean(
            [score["psnr"] for score in scores.values()])
        averaged_scores["average_ssim"] = np.mean(
            [score["ssim"] for score in scores.values()])
        averaged_scores["average_lpips"] = np.mean(
            [score["lpips"] for score in scores.values()])
        print(averaged_scores, flush=True)
        scores.update(averaged_scores)

        import json
        with open(os.path.join(output_folder, "scores.json"),
                  "w") as json_file:
            json.dump(scores, json_file, indent=4)

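
# Hedged sketch of a command-line entry point (an assumption, not the
# repository's actual parser): it only wires up the attributes that
# free_viewpoint_rendering() reads above. Flag types and defaults are
# illustrative.
def _example_parse_args_and_render():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--deformations", type=str, default="train",
                        choices=["train", "test", "all"])
    parser.add_argument("--camera_path", type=str, default="fixed",
                        choices=["spiral", "fixed", "input_reconstruction"])
    parser.add_argument("--fixed_view", type=int, default=0)
    parser.add_argument("--forced_background_stabilization", type=float,
                        default=None)
    parser.add_argument("--motion_factor", type=float, default=None)
    parser.add_argument("--foreground_removal", type=float, default=None)
    parser.add_argument("--render_canonical", action="store_true")
    parser.add_argument("--output_video_fps", type=int, default=30)
    args = parser.parse_args()
    free_viewpoint_rendering(args)
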
def render_path(render_poses,
                styles,
                hwfs,
                chunk,
                render_kwargs,
                nfs=None,
                gt_imgs=None,
                alpha_cache=None,
                feature_cache=None,
                weights_cache=None,
                savedir=None,
                maximum=1000,
                get_cached=None,
                get_rgbs=False,
                verbose=True,
                cb=None,
                update_cb=None):
    render_kwargs['network_fine'].get_cached = get_cached
    rgbs = []
    disps = []
    alphas = []
    features = []
    weights = []
    total_psnr = 0.
    total_psnr0 = 0
    N = len(render_poses)
    s = N // maximum if len(render_poses) > maximum else 1
    if gt_imgs is not None:
        gt_imgs = gt_imgs[::s].cuda()
    render_poses = render_poses[::s].cuda()
    styles = styles[::s].cuda()
    hwfs = hwfs[::s].cuda()

    iterator = zip(render_poses, styles)
    if verbose:
        iterator = tqdm(iterator, total=len(styles))

    for i, (c2w, style) in enumerate(iterator):
        if cb is not None:
            cb(i)
        H, W, focal = hwfs[i]
        H, W = int(H), int(W)
        if nfs is not None:
            near, far = nfs[i]
            render_kwargs.update({'near': near, 'far': far})
        alpha = alpha_cache[i] if alpha_cache is not None else None
        feature = feature_cache[i] if feature_cache is not None else None
        weight = weights_cache[i] if weights_cache is not None else None
        rgb, disp, acc, additional = render(H,
                                            W,
                                            focal,
                                            style=style,
                                            chunk=chunk,
                                            weights=weight,
                                            c2w=c2w[:3, :4],
                                            alpha=alpha,
                                            feature=feature,
                                            **render_kwargs)
        if 'rgb0' in additional:
            rgb0 = additional['rgb0']
        if gt_imgs is not None:
            gt_img = gt_imgs[i]
            mse_loss = img2mse(rgb, gt_img)
            psnr = mse2psnr(mse_loss)
            total_psnr += psnr.item()
            if 'rgb0' in additional:
                mse_loss0 = img2mse(rgb0, gt_img)
                psnr0 = mse2psnr(mse_loss0)
                total_psnr0 += psnr0.item()
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        if update_cb:
            update_cb(i, rgbs[-1])
        if get_cached:
            alphas.append(additional['raw_alpha'])
            features.append(additional['features'])
            weights.append(additional['weights0'])

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            disp8 = to8b(to_disp_img(disps[-1]))
            imageio.imwrite(
                os.path.join(savedir, '{:04d}_rgb.png'.format(i)), rgb8)
            imageio.imwrite(
                os.path.join(savedir, '{:04d}_disp.png'.format(i)), disp8)
            if gt_imgs is not None:
                gt_img = to8b((gt_imgs[i]).cpu().numpy())
                imageio.imwrite(
                    os.path.join(savedir, '{:04d}_gt.png'.format(i)), gt_img)

    if gt_imgs is not None:
        with open(os.path.join(savedir, 'log.txt'), 'a+') as f:
            torchvision.utils.save_image(
                torch.tensor(rgbs).cpu().permute(0, 3, 1, 2), 'rgbs.png')
            torchvision.utils.save_image(
                torch.tensor(gt_imgs).cpu().permute(0, 3, 1, 2), 'gt.png')
            msg = f'psnr0, psnr1, {total_psnr0/len(render_poses)}, {total_psnr/len(render_poses)}'
            f.write(msg + '\n')
            print(msg)

    total_psnr = total_psnr / len(rgbs)
    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)
    if get_cached:
        render_kwargs['network_fine'].get_cached = None
        return rgbs, disps, total_psnr, alphas, features, weights
    else:
        return rgbs, disps, total_psnr

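
# Hedged usage sketch (illustrative, not part of the original code): writing
# the frames returned by render_path above to mp4 files, in the same way the
# saving snippet at the end of this section does. savedir and fps are assumed
# to be provided by the caller.
def _example_write_render_path_videos(rgbs, disps, savedir, fps=30):
    imageio.mimwrite(os.path.join(savedir, 'video.mp4'),
                     to8b(rgbs), fps=fps, quality=8)
    imageio.mimwrite(os.path.join(savedir, 'disps.mp4'),
                     to8b(disps / np.max(disps)), fps=fps, quality=8)
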
def log_qualitative_results(writer,
                            metalearning_iter,
                            scene_id,
                            save_dir,
                            render_poses,
                            poses,
                            i_split,
                            hwf,
                            chunk,
                            render_kwargs_test,
                            images,
                            N_importance,
                            timesteps,
                            use_viewdirs=True,
                            render_test_set=False):
    H, W, focal = hwf
    testsavedir = os.path.join(
        save_dir,
        'testset_iter/{:06d}/scene{}'.format(metalearning_iter, scene_id))
    os.makedirs(testsavedir, exist_ok=True)
    split_timesteps = []
    for i in i_split:
        split_timesteps.extend([timesteps[i]] * H * W)
    split_timesteps = np.asarray(split_timesteps)
    render_path(poses[i_split],
                hwf,
                split_timesteps,
                chunk,
                render_kwargs_test,
                gt_imgs=images[i_split],
                savedir=testsavedir)

    # Log a rendered validation view to Tensorboard
    img_i = np.random.choice(i_split)
    target = images[img_i]
    pose = poses[img_i, :3, :4]
    split_timestep = [timesteps[img_i]] * (H * W)
    rgb, disp, acc, extras = render(H,
                                    W,
                                    focal,
                                    split_timestep,
                                    chunk=chunk,
                                    c2w=pose,
                                    **render_kwargs_test)
    psnr = mse2psnr(img2mse(rgb, target))

    # Save out the validation image for Tensorboard-free monitoring
    testimgdir = os.path.join(save_dir, 'tboard_val_imgs')
    if not os.path.exists(testimgdir):
        os.makedirs(testimgdir, exist_ok=True)
    imageio.imwrite(
        os.path.join(testimgdir,
                     '{:06d}_{}.png'.format(metalearning_iter, scene_id)),
        to8b(rgb))

    writer.add_image(f'rgb/{scene_id}',
                     np.squeeze(to8b(rgb)[tf.newaxis], axis=0),
                     metalearning_iter,
                     dataformats='HWC')
    writer.add_image(f'disp/{scene_id}',
                     np.squeeze(disp[tf.newaxis, ..., tf.newaxis], axis=0),
                     metalearning_iter,
                     dataformats='HWC')
    writer.add_image(f'acc/{scene_id}',
                     np.squeeze(acc[tf.newaxis, ..., tf.newaxis], axis=0),
                     metalearning_iter,
                     dataformats='HWC')
    writer.add_image(f'rgb_holdout/{scene_id}',
                     np.squeeze(target[tf.newaxis], axis=0),
                     metalearning_iter,
                     dataformats='HWC')
    if N_importance > 0:
        writer.add_image(f'rgb0/{scene_id}',
                         np.squeeze(to8b(extras['rgb0'])[tf.newaxis], axis=0),
                         metalearning_iter,
                         dataformats='HWC')
        writer.add_image(f'disp0/{scene_id}',
                         np.squeeze(extras['disp0'][tf.newaxis, ...,
                                                    tf.newaxis],
                                    axis=0),
                         metalearning_iter,
                         dataformats='HWC')
        writer.add_image(f'z_std/{scene_id}',
                         np.squeeze(extras['z_std'][tf.newaxis, ...,
                                                    tf.newaxis],
                                    axis=0),
                         metalearning_iter,
                         dataformats='HWC')

    # Save videos
    if render_test_set:
        sorted_timesteps = sorted(list(set(timesteps)))
        rgbs, disps = render_timesteps(poses[i_split[0]],
                                       hwf,
                                       sorted_timesteps,
                                       chunk,
                                       render_kwargs_test,
                                       savedir=testsavedir)
        moviebase = os.path.join(
            save_dir, '{}_temporal_{:06d}_'.format(scene_id,
                                                   metalearning_iter))
        imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
        imageio.mimwrite(moviebase + 'disp.mp4',
                         to8b(disps / np.max(disps)),
                         fps=30,
                         quality=8)
        if use_viewdirs:
            render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
            rgbs_still, _ = render_path(render_poses, hwf, sorted_timesteps,
                                        chunk, render_kwargs_test)
            render_kwargs_test['c2w_staticcam'] = None
            imageio.mimwrite(moviebase + 'rgb_still.mp4',
                             to8b(rgbs_still),
                             fps=30,
                             quality=8)

def __init__(self,
             instance,
             config,
             use_cached=True,
             expname=None,
             edit_type=None,
             num_canvases=9,
             shape_params='fusion_shape_branch',
             color_params='color_branch',
             randneg=8192,
             device='cuda:0'):
    super().__init__(style=dict(border="3px solid gray",
                                padding="8px",
                                display="inline-block"))
    torch.set_default_tensor_type(
        'torch.cuda.FloatTensor' if device == 'cuda:0' else 'torch.FloatTensor')
    self.edit_type = edit_type
    self.instance = instance
    self.num_canvases = num_canvases
    self.shape_params = shape_params
    self.color_params = color_params
    self.size = IMG_SIZE
    self.randneg = randneg
    self.device = device

    self.msg_out = labwidget.Div()
    self.editing_canvas = paintwidget.PaintWidget(
        image='', width=self.size * 3,
        height=self.size * 3).on('mask', self.change_mask)
    self.editing_canvas.index = -1
    self.copy_canvas = paintwidget.PaintWidget(
        image='', width=self.size * 2,
        height=self.size * 2).on('mask', self.copy)
    self.copy_mask = None

    inline = dict(display='inline', border="2px solid gray")
    self.toggle_rgbs_disps_btn = labwidget.Button(
        'show depth', style=inline).on('click', self.toggle_rgb_disps)
    self.positive_mask_btn = labwidget.Button(
        self.pad('edit color'), style=inline).on('click', self.positive_mask)
    self.addition_mask_btn = labwidget.Button(
        self.pad('add shape'), style=inline).on('click', self.add)
    self.sigma_mask_btn = labwidget.Button(
        self.pad('remove shape'), style=inline).on('click', self.sigma_mask)
    self.color_from_btn = labwidget.Button(
        self.pad('transfer color'), style=inline).on('click', self.color_from)
    self.shape_from_btn = labwidget.Button(
        self.pad('transfer shape'), style=inline).on('click', self.shape_from)
    self.execute_btn = labwidget.Button(
        self.pad('execute'), style=inline).on('click', self.execute_edit)
    self.brushsize_textbox = labwidget.Textbox(
        5, desc='brushsize: ', size=3).on('value', self.change_brushsize)
    self.target = None
    self.use_color_cache = True

    self.color_style = dict(display='inline', border="2px solid white")
    trn = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])
    bg_img = trn(Image.open('bg.png').convert('RGB'))
    bg_img = renormalize.as_url(bg_img * 2 - 1)
    self.color_pallete = [
        labwidget.Image(src=bg_img,
                        style=self.color_style).on('click', self.set_color)
    ]
    self.color_pallete[-1].index = 0
    self.color_pallete[-1].color_type = 'bg'
    for color in mean_colors.colors.values():
        image = torch.zeros(3, 32, 32)
        image[0, :, :] = color[0]
        image[1, :, :] = color[1]
        image[2, :, :] = color[2]
        image = image / 255. * 2 - 1
        self.color_pallete.append(
            labwidget.Image(src=renormalize.as_url(image),
                            style=self.color_style).on(
                                'click', self.set_color))
        self.color_pallete[-1].index = len(self.color_pallete) - 1
        self.color_pallete[-1].color_type = 'color'
    # TODO: Highlight the white box with black for clarity
    self.color = None
    self.mask_type = None

    self.real_canvas_array = []
    self.real_images_array = []
    self.positive_masks = []
    train, test, optimizer, styles = load_model(instance,
                                                config,
                                                expname=expname)
    poses, hwfs, cache, args = load_dataset(instance,
                                            config,
                                            num_canvases=num_canvases,
                                            N_instances=styles.shape[0],
                                            expname=expname,
                                            use_cached=use_cached)
    self.parentdir = load_config(config).expname
    self.expname = expname if expname else self.parentdir
    self.savedir = os.path.join(self.expname, str(instance))
    os.makedirs(self.savedir, exist_ok=True)
    self.poses = poses.to(device)
    self.cache = cache
    self.chunk = args.chunk
    self.near = args.blender_near
    self.far = args.blender_far
    self.nfs = [[self.near, self.far]] * self.poses.shape[0]
    self.hwfs = hwfs.to(device)
    self.old_fine_network = dict(
        copy.deepcopy(test['network_fine']).named_parameters())
    self.train_kwargs = train
    self.test_kwargs = test
    self.optimizer = None
    self.all_instance_styles = styles
    self.instance_style = styles[instance].unsqueeze(dim=0).to(device)

    if cache is not None:
        self.weights = cache['weights']
        self.alphas = cache['alphas']
        self.features = cache['features']
    else:
        self.weights = None
        self.alphas = None
        self.features = None

    self.trn = transforms.Compose(
        [transforms.Resize(128), transforms.ToTensor()])
    self.transfer_instances_array = [
        labwidget.Image(src='').on('click', self.change_target)
        for _ in range(12)
    ]
    self.addition_instances_array = [
        labwidget.Image(src='').on('click', self.change_target)
        for _ in range(12)
    ]

    images, disps = self.render(self.poses,
                                self.instance_style,
                                verbose=False,
                                get_disps=True)
    for i, image in enumerate(images):
        resized = F.interpolate(image.unsqueeze(dim=0),
                                size=(self.size, self.size)).squeeze(dim=0)
        disp_img = torch.from_numpy(to8b(to_disp_img(
            disps[i]))).unsqueeze(dim=0) / 255.
        resized_disp = F.interpolate(disp_img.unsqueeze(dim=0),
                                     size=(self.size,
                                           self.size)).squeeze(dim=0)
        self.real_images_array.append(
            labwidget.Image(src=renormalize.as_url(resized)).on(
                'click', self.set_editing_canvas))
        self.real_images_array[-1].index = i
        self.real_canvas_array.append(
            paintwidget.PaintWidget(image=renormalize.as_url(image),
                                    width=self.size * 3,
                                    height=self.size * 3).on(
                                        'mask', self.change_mask))
        self.real_canvas_array[-1].index = i
        self.real_canvas_array[-1].negative_mask = ''
        self.real_canvas_array[-1].resized_image = renormalize.as_url(resized)
        self.real_canvas_array[-1].resized_disp = renormalize.as_url(
            resized_disp)
        self.real_canvas_array[-1].disp = renormalize.as_url(disp_img)
        self.real_canvas_array[-1].orig = renormalize.as_url(image)
        self.positive_masks.append(torch.zeros(image.shape).cpu())

    self.show_rgbs = True
    self.change_brushsize()
    self.editname_textbox = labwidget.Datalist(choices=self.saved_names(),
                                               style=inline)
    self.save_btn = labwidget.Button('save',
                                     style=inline).on('click', self.save)
    self.load_btn = labwidget.Button('load',
                                     style=inline).on('click', self.load)

1) * N_per_instance
all_poses = all_poses[ps:pe]
if args.video:
    all_poses, all_hwfs = generate_flythrough(all_poses[0].cpu(),
                                              all_hwfs[0],
                                              num_poses=100)
nfs = [[writer.near, writer.far]] * all_poses.shape[0]
styles = writer.instance_style.repeat((all_poses.shape[0], 1))
with torch.no_grad():
    print(f'Saving samples in {savedir}')
    rgbs, disps, psnr = render_path(all_poses,
                                    styles,
                                    all_hwfs,
                                    writer.chunk,
                                    writer.test_kwargs,
                                    nfs=nfs,
                                    savedir=savedir,
                                    verbose=True)
if args.video:
    imageio.mimwrite(os.path.join(savedir, 'video.mp4'),
                     to8b(rgbs),
                     fps=30,
                     quality=8)
    imageio.mimwrite(os.path.join(savedir, 'disps.mp4'),
                     to8b(disps / np.max(disps)),
                     fps=30,
                     quality=8)