def get_images_train():
    dataset_dir = os.path.join(LLFF_DIR, DATASET)
    image_dir = os.path.join(dataset_dir, 'images')
    _, poses, bds, _, _ = load_llff_data(dataset_dir, factor=None)

    # Split into train data and validation data.
    validation_ids = np.arange(poses.shape[0])
    validation_ids[::8] = -1  # hold out every 8th image: [0, 8, 16, ...]
    validation_ids = validation_ids < 0
    # Pick only poses in train.
    images_path = [
        os.path.join('images', f) for f in sorted(os.listdir(image_dir))
    ]
    images_train = []
    images_valid = []
    for image_id in range(poses.shape[0]):
        # Copy so the in-place sign flip below does not mutate `poses`.
        R = poses[image_id][:3, :3].copy()
        t = poses[image_id][:3, 3].reshape([3, 1])
        # LLFF poses need their rotation and translation inverted to match
        # our format.
        R[:3, 0] *= -1
        R = np.transpose(R)
        t = np.matmul(R, t)
        img_obj = {
            "path": images_path[image_id],
            "center": -R @ t,
            "depth": bds[image_id]
        }
        if not validation_ids[image_id]:
            images_train.append(img_obj)
        else:
            images_valid.append(img_obj)
    return images_train
def get_images_data(dataset, split_val=8):
    dataset_dir = os.path.join(LLFF_DIR, dataset)
    image_dir = os.path.join(dataset_dir, 'images')
    _, poses, bds, _, _ = load_llff_data(dataset_dir, factor=None)

    # Split into train data and validation data.
    validation_ids = np.arange(poses.shape[0])
    validation_ids[::split_val] = -1  # hold out every split_val-th image: [0, split_val, ...]
    validation_ids = validation_ids < 0
    # Pick only poses in train.
    images_path = [
        'images/{}'.format(f) for f in sorted(os.listdir(image_dir))
    ]
    images_train = []
    images_valid = []
    for image_id in range(poses.shape[0]):
        # Copy so the in-place sign flip below does not mutate `poses`.
        R = poses[image_id][:3, :3].copy()
        t = poses[image_id][:3, 3].reshape([3, 1])
        # LLFF poses need their rotation and translation inverted to match
        # our format.
        R[:3, 0] *= -1
        R = np.transpose(R)
        t = np.matmul(R, t)
        # Intrinsics are stored in the last column of the 3x5 LLFF pose.
        H = poses[image_id, 0, -1]
        W = poses[image_id, 1, -1]
        focal = poses[image_id, 2, -1]
        img_obj = {
            "path": images_path[image_id],
            "r": R,
            "t": t,
            "R": R.T,
            "center": -R @ t,
            "planes": bds[image_id],
            "camera": {
                "width": int(W),
                "height": int(H),
                "fx": float(focal),
                "fy": float(focal),
                "px": float(W / 2.0),
                "py": float(H / 2.0)
            },
        }
        if not validation_ids[image_id]:
            images_train.append(img_obj)
        else:
            images_valid.append(img_obj)
    return images_train, images_valid
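# Example usage of get_images_data() (the scene name 'fern' is a placeholder;
# LLFF_DIR must point at a directory of LLFF scenes):
#
#     train_set, valid_set = get_images_data('fern', split_val=8)
#     print(len(train_set), 'train /', len(valid_set), 'validation images')
#     cam = train_set[0]['camera']
#     print(cam['width'], cam['height'], cam['fx'])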
def train():
    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75,
            spherify=args.spherify, rgba=True)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(
        args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        testsavedir = os.path.join(
            basedir, expname, 'renderonly_{}_{:06d}'.format(
                'test' if args.render_test else 'path', start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses, hwf, args.chunk,
                              render_kwargs_test, gt_imgs=images,
                              savedir=testsavedir,
                              render_factor=args.render_factor)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs), fps=30, quality=8)
        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(
            lrate, decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        print('get rays')
        rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays = np.transpose(rays, [0, 2, 3, 1, 4])           # [N, H, W, ro+rd, 3]
        rays = np.reshape(rays, list(rays.shape[:3]) + [6])  # [N, H, W, 6]
        rays_rgba = np.concatenate([rays, images], -1)       # [N, H, W, 6+4]
        rays_rgba = np.stack([rays_rgba[i] for i in i_train],
                             axis=0)  # train images only
        rays_rgba = np.reshape(rays_rgba, [-1, 10])
        rays_rgba = rays_rgba.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgba)
        print('done')
        i_batch = 0

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))

    for i in range(start, N_iters):
        time0 = time.time()

        batch = rays_rgba[i_batch:i_batch + N_rand]
        batch_rays = batch[:, :6]
        target_rgba = batch[:, 6:]
        batch_rays = np.reshape(batch_rays, [-1, 2, 3])
        batch_rays = np.transpose(batch_rays, [1, 0, 2])

        i_batch += N_rand
        if i_batch >= rays_rgba.shape[0]:
            np.random.shuffle(rays_rgba)
            i_batch = 0

        ##### Core optimization loop #####
        with tf.GradientTape() as tape:
            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(H, W, focal,
                                            chunk=args.chunk,
                                            rays=batch_rays,
                                            verbose=i < 10, retraw=True,
                                            **render_kwargs_train)

            # Compute MSE loss between predicted and true RGBA.
            rgba = to_rgba(rgb, acc)
            loss = img2mse(rgba, target_rgba)
            trans = extras['raw'][..., -1]
            psnr = mse2psnr(loss)

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                rgba0 = to_rgba(extras['rgb0'], extras['acc0'])
                loss0 = img2mse(rgba0, target_rgba)
                loss += loss0
                psnr0 = mse2psnr(loss0)

            # RGB-only metrics, tracked for logging but kept out of the
            # gradient computation (tf.stop_gradient returns a new tensor,
            # so the result must be assigned back).
            img_loss = tf.stop_gradient(img2mse(rgb, target_rgba[..., :3]))
            img_psnr = tf.stop_gradient(mse2psnr(img_loss))

        gradients = tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))

        dt = time.time() - time0
        ##### end #####

        # Rest is logging

        def save_weights(net, prefix, i):
            path = os.path.join(basedir, expname,
                                '{}_{:06d}.npy'.format(prefix, i))
            np.save(path, net.get_weights())
            print('saved weights at', path)

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

        if i % args.i_video == 0 and i > 0:
            rgbs, disps = render_path(render_poses, hwf, args.chunk,
                                      render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname,
                                     '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs),
                             fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)), fps=30, quality=8)

            if args.use_viewdirs:
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(render_poses, hwf, args.chunk,
                                            render_kwargs_test)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test], hwf, args.chunk, render_kwargs_test,
                        gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:
            print(expname, i, psnr.numpy(), loss.numpy(),
                  [img_loss.numpy(), img_psnr.numpy()], global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with writer.as_default():
                tf.summary.scalar('loss', loss, step=i + 1)
                tf.summary.scalar('psnr', psnr, step=i + 1)
                tf.summary.histogram('tran', trans, step=i + 1)
                if args.N_importance > 0:
                    tf.summary.scalar('psnr0', psnr0, step=i + 1)

            if i % args.i_img == 0:
                # Log a rendered validation view to Tensorboard
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]

                rgb, disp, acc, extras = render(H, W, focal,
                                                chunk=args.chunk, c2w=pose,
                                                **render_kwargs_test)
                rgba = to_rgba(rgb, acc)
                psnr = mse2psnr(img2mse(rgba, target))

                # Save out the validation image for Tensorboard-free
                # monitoring (unconditional so resumed runs also get the dir).
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(
                    os.path.join(testimgdir, '{:06d}.png'.format(i)),
                    to8b(rgba))
                imageio.imwrite(
                    os.path.join(testimgdir, '{:06d}_acc.png'.format(i)),
                    to8b(acc))

                with writer.as_default():
                    tf.summary.image('rgba', to8b(rgba)[tf.newaxis],
                                     step=i + 1)
                    tf.summary.image('disp',
                                     disp[tf.newaxis, ..., tf.newaxis],
                                     step=i + 1)
                    tf.summary.image('acc',
                                     acc[tf.newaxis, ..., tf.newaxis],
                                     step=i + 1)
                    tf.summary.scalar('psnr_holdout', psnr, step=i + 1)
                    tf.summary.image('rgb_holdout', target[tf.newaxis],
                                     step=i + 1)

                if args.N_importance > 0:
                    with writer.as_default():
                        tf.summary.image('rgb0',
                                         to8b(extras['rgb0'])[tf.newaxis],
                                         step=i + 1)
                        tf.summary.image(
                            'disp0',
                            extras['disp0'][tf.newaxis, ..., tf.newaxis],
                            step=i + 1)
                        tf.summary.image(
                            'z_std',
                            extras['z_std'][tf.newaxis, ..., tf.newaxis],
                            step=i + 1)

        global_step.assign_add(1)
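# `to_rgba` is used above but not defined in this section. A minimal sketch of
# the assumed behavior: append the volume-rendered opacity `acc` as an alpha
# channel so predictions can be compared against RGBA targets. This is an
# assumption, not the confirmed implementation.
def to_rgba(rgb, acc):
    # rgb: [..., 3], acc: [...]; returns [..., 4].
    return tf.concat([rgb, acc[..., tf.newaxis]], axis=-1)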
    I = I.unsqueeze(0).to(device) if args.gpu != 'cpu' else I.unsqueeze(0)
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output, _ = model(I)  # output size: 1 x 21 x 64(H) x 64(W)
        # print('Inference time: {:.4f} s'.format(time.time() - start_time))
    kps_pred_np = get_final_preds(
        output,
        use_softmax=cfg.MODEL.HEATMAP_SOFTMAX).cpu().numpy().squeeze()
    return kps_pred_np


# images: 114 x H x W x 3
# poses (c2w and intrinsic): 114 x 3 x 5
images, poses, bds, render_poses, i_test = load_llff_data(base_dir,
                                                          factor=3,
                                                          recenter=True)

import pickle

fname = './pose2d_pred.txt'
try:
    print('Load')
    with open(fname, 'rb') as f:
        pts = pickle.load(f)
except (OSError, pickle.UnpicklingError):
    print('Failed')
    pts = []

color_lower = (80, 45, 30)
color_upper = (120, 190, 180)
images = (images * 255).astype(np.uint8)
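# The color_lower/color_upper bounds above suggest a per-channel color
# threshold over the uint8 images. A minimal sketch of how such a mask might
# be computed with OpenCV; the use of cv2.inRange here is an assumption, not
# taken from this file.
import cv2

def color_mask(image_u8, lower=color_lower, upper=color_upper):
    # cv2.inRange returns 255 where lower <= pixel <= upper holds per channel.
    return cv2.inRange(image_u8, np.array(lower), np.array(upper))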
def train():
    parser = config_parser()
    args = parser.parse_args()

    # Multi-GPU
    args.n_gpus = torch.cuda.device_count()
    print(f"Using {args.n_gpus} GPU(s).")

    # Load data
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # render_test switches to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format(
                    'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _, _ = render_path(render_poses, hwf, args.chunk,
                                     render_kwargs_test, gt_imgs=images,
                                     savedir=testsavedir,
                                     render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs), fps=30, quality=8)
            return

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    K = args.num_poses

    if use_batching or K >= len(poses):
        # For random ray batching
        print('get rays')
        rays = np.stack(
            [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
            0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]],
                                  1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb,
                                [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            0)  # train images only
        rays_rgb = np.reshape(rays_rgb,
                              [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        i_batch = 0
    else:
        print('We are taking batches of rays from', K, 'image(s) at a time',
              flush=True)
        train_poses = np.stack([poses[i] for i in i_train], axis=0)
        train_images = np.stack([images[i] for i in i_train], axis=0)
        nearest_rays = []
        train_translations = train_poses[:, :3, 3]
        train_translations_indices = np.arange(len(train_translations))
        print('Caching nearest rays for each pose...')
        for i, translation in enumerate(tqdm(train_translations)):
            distances = np.linalg.norm(train_translations - translation,
                                       axis=-1)
            knn_poses_indices = np.array(
                sorted(train_translations_indices,
                       key=lambda idx: distances[idx]))[:K]
            knn_poses = train_poses[knn_poses_indices]
            knn_images = train_images[knn_poses_indices]
            nearest_rays.append(
                get_rays_rgb_for_poses(H, W, knn_poses, knn_images, focal))
        print('Done', flush=True)

    # Move training data to GPU
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = args.N_iters
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        optimizer.zero_grad()
        N_rand_small = 1024
        if not use_batching:
            # Random from the cached K-pose ray sets. Index into
            # `nearest_rays` (which is aligned with `i_train`), not into the
            # full image index space.
            pose_i = np.random.randint(len(nearest_rays))
            rays_rgb = nearest_rays[pose_i]
            assert N_rand <= rays_rgb.shape[0]
            select_inds = np.random.choice(rays_rgb.shape[0],
                                           size=[N_rand],
                                           replace=False)  # (N_rand,)
        else:
            # Sequential slice over the pre-shuffled global ray buffer,
            # wrapping around at the epoch boundary so every step still sees
            # exactly N_rand rays.
            select_inds = np.arange(i_batch, i_batch + N_rand) % rays_rgb.shape[0]
            i_batch = (i_batch + N_rand) % rays_rgb.shape[0]

        num_rays = N_rand
        total_weight = 0.0
        for j in range(0, N_rand, N_rand_small):
            curr_select_inds = select_inds[j:j + N_rand_small]
            batch = rays_rgb[curr_select_inds]
            if not use_batching:
                # rays_rgb was not moved to device for KNN-pose mode, since
                # it may be too big.
                batch = torch.Tensor(batch).to(device)
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            ##### Core optimization loop #####
            rgb, disp, acc, extras = render(H, W, focal,
                                            chunk=args.chunk,
                                            rays=batch_rays,
                                            verbose=i < 10, retraw=True,
                                            **render_kwargs_train)

            img_loss = img2mse(rgb, target_s)
            trans = extras['raw'][..., -1]
            loss = img_loss
            psnr = mse2psnr(img_loss)

            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                loss = loss + img_loss0
                # psnr0 = mse2psnr(img_loss0)

            # Scale each sub-batch so the accumulated gradient matches a
            # single full-batch backward pass.
            weight = (rgb.size(0) / float(num_rays))
            loss_backprop = loss * weight
            total_weight += weight
            loss_backprop.backward()

        assert np.isclose(total_weight, 1.0), total_weight
        optimizer.step()
        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        dt = time.time() - time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        ##### end #####

        # Rest is logging
        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save(
                {
                    'global_step': global_step,
                    'network_fn_state_dict':
                    render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict':
                    render_kwargs_train['network_fine'].state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        # if i % args.i_video == 0 and i > 0:
        #     # Turn on testing mode
        #     with torch.no_grad():
        #         rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
        #     print('Done, saving', rgbs.shape, disps.shape)
        #     moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format('video', i))
        #     imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
        #     imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
        #     if args.use_viewdirs:
        #         render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
        #         with torch.no_grad():
        #             rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
        #         render_kwargs_test['c2w_staticcam'] = None
        #         imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

        adjusted_step = int(
            (global_step - start) * N_rand / float(N_rand_small)) + start

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            with torch.no_grad():
                _, _, test_mse = render_path(
                    torch.Tensor(poses[i_test]).to(device), hwf, args.chunk,
                    render_kwargs_test, gt_imgs=images[i_test],
                    savedir=testsavedir)
            test_psnr = mse2psnr(torch.from_numpy(np.array(test_mse)))
            writer.add_scalar('mse_test', test_mse, global_step=global_step)
            writer.add_scalar('psnr_test', test_psnr, global_step=global_step)
            writer.add_scalar('psnr_test_by_t', test_psnr,
                              global_step=adjusted_step)
            writer.add_scalar('mse_test_by_t', test_mse,
                              global_step=adjusted_step)
            print('Saved test set')

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
            print('', flush=True)
            writer.add_scalar('loss', loss, global_step=global_step)
            writer.add_scalar('psnr', psnr, global_step=global_step)
            writer.add_scalar('lr', new_lrate, global_step=global_step)
            # writer.add_histogram('tran', trans, global_step=global_step)
            writer.add_scalar('mse_by_t', loss, global_step=adjusted_step)
            writer.add_scalar('psnr_by_t', psnr, global_step=adjusted_step)
            # if args.N_importance > 0:
            #     writer.add_scalar('psnr0', psnr0, global_step=global_step)

        if i % args.i_img == 0:
            # Log a rendered validation view to Tensorboard
            # img_i = np.random.choice(i_val)
            img_i = i_val[0]
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            with torch.no_grad():
                rgb, disp, acc, extras = render(H, W, focal,
                                                chunk=args.chunk, c2w=pose,
                                                **render_kwargs_test)

            mse = img2mse(rgb, target)
            psnr = mse2psnr(mse)

            writer.add_image('rgb', rgb, dataformats='HWC',
                             global_step=global_step)
            writer.add_image('disp', torch.stack(3 * [disp], dim=-1),
                             dataformats='HWC', global_step=global_step)
            writer.add_image('acc', torch.stack(3 * [acc], dim=-1),
                             dataformats='HWC', global_step=global_step)
            writer.add_scalar('mse_holdout', mse, global_step=global_step)
            writer.add_scalar('mse_holdout_by_t', mse,
                              global_step=adjusted_step)
            writer.add_image('rgb_holdout', target, dataformats='HWC',
                             global_step=global_step)

            # if args.N_importance > 0:
            #     with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
            #         tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
            #         tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis, ..., tf.newaxis])
            #         tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis, ..., tf.newaxis])

        global_step += 1
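# `get_rays_rgb_for_poses` is called in the KNN-pose branch above but not
# defined in this section. A minimal sketch of the assumed contract: flatten
# (origin, direction, rgb) triples for K poses into one [K*H*W, 3, 3] float32
# array, mirroring the global batching path.
def get_rays_rgb_for_poses(H, W, poses, images, focal):
    rays = np.stack([get_rays_np(H, W, focal, p[:3, :4]) for p in poses],
                    0)                                     # [K, 2, H, W, 3]
    rays_rgb = np.concatenate([rays, images[:, None]], 1)  # [K, 3, H, W, 3]
    rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])     # [K, H, W, 3, 3]
    return np.reshape(rays_rgb, [-1, 3, 3]).astype(np.float32)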
def train():
    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            args.near = tf.reduce_min(bds) * .9
            args.far = tf.reduce_max(bds) * 1.
        else:
            args.near = 0.
            args.far = 1.
        print('NEAR FAR', args.near, args.far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split, extras = load_blender_data(
            args.datadir, args.half_res, args.testskip, args.image_extn,
            mask_directory=args.mask_directory,
            get_depths=args.get_depth_maps
            if args.near is None and args.far is None else False,
            image_field=args.image_fieldname,
            image_dir_override=args.image_dir_override,
            trainskip=args.trainskip,
            train_frames_field=args.frames_field)
        if args.mask_directory is not None:
            masks = extras['masks']
        if args.get_depth_maps:
            depth_maps = extras['depth_maps']
        K = None
        if args.use_K:
            K = extras['K']
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        if args.near is None and args.far is None:
            if args.Z_limits_from_pose:
                Zs = poses[:, 2, 3]
                args.near = np.min(np.abs(Zs)) * 0.5
                args.far = np.max(np.abs(Zs)) * 1.5
            elif args.get_depth_maps and args.mask_directory is not None:
                args.near = np.min(np.mean(depth_maps[masks])) * 0.9
                args.far = np.max(depth_maps[masks]) * 1.1
                print('using masked depth')
            elif args.get_depth_maps:
                args.near = np.min(depth_maps) * 0.9
                args.far = np.max(depth_maps) * 1.1
                print('using unmasked depth')
            else:
                args.near = 0.
                args.far = 2.
        print(f'args.near: {args.near} far: {args.far}')

        if args.mask_directory is not None and images.shape[-1] == 3:
            images = np.concatenate([images, masks[..., np.newaxis]], axis=-1)

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        args.near = hemi_R - 1.
        args.far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    if args.mask_directory and not args.white_bkgd:
        assert os.path.isdir(args.mask_directory), \
            f'args.mask_directory not found at: {args.mask_directory}'
        if args.mask_images:
            images *= masks[..., np.newaxis]

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    filename_splits = extras['filenames']
    enable_keypoints = False
    keypoint_loss_coeff = 0
    if args.colmap_keypoints_filename is not None:
        enable_keypoints = True
        from NERFCO.nerf_keypoint_network import get_keypoint_masks
        import MODELS.random_embeddings
        train_keypoint_masks, args.number_of_keypoints = \
            get_keypoint_masks(args.colmap_keypoints_filename, i_train,
                               filename_splits, H, W)
        train_keypoint_masks = train_keypoint_masks.ravel()
        test_keypoint_masks, test_keypoint_count = \
            get_keypoint_masks(args.colmap_keypoints_filename, i_test,
                               filename_splits, H, W)
        number_of_keypoints = np.sum(train_keypoint_masks > 0)
        assert test_keypoint_count == args.number_of_keypoints, \
            f'test_keypoint_count: {test_keypoint_count} == args.number_of_keypoints: {args.number_of_keypoints}'
        print(train_keypoint_masks.shape, train_keypoint_masks.min(),
              train_keypoint_masks.max(), number_of_keypoints)
        keypoint_embeddings_filename = os.path.join(
            args.basedir, args.expname, 'keypoint_embeddings.pkl')
        random_keypoint_embeddings_object = \
            MODELS.random_embeddings.StaticRandomEmbeddings(
                args.number_of_keypoints + 1, args.keypoint_embedding_size,
                embedding_filename=keypoint_embeddings_filename,
                zero_embedding_origin=args.zero_embedding_origin)
        keypoint_embeddings = random_keypoint_embeddings_object.embeddings

    import NERFCO.extract_keypoints
    if args.autoencoded_keypoints_filename is not None:
        enable_keypoints = True
        assert os.path.isfile(args.autoencoded_keypoints_filename), \
            f'autoencoded_keypoints_filename not found at: {args.autoencoded_keypoints_filename}'
        import pickle
        with open(args.autoencoded_keypoints_filename, 'rb') as pkl_file:
            data = pickle.load(pkl_file)
        train_keypoint_masks = data['train_keypoint_masks']
        test_keypoint_masks = data['test_keypoint_masks']
        keypoint_embeddings = data['encoded_embeddings']
        NERFCO.extract_keypoints.test_keypoint_masks(train_keypoint_masks,
                                                     keypoint_embeddings)
        NERFCO.extract_keypoints.test_keypoint_masks(test_keypoint_masks,
                                                     keypoint_embeddings)
        args.keypoint_embedding_size = keypoint_embeddings.shape[-1]
        train_keypoint_masks = train_keypoint_masks.ravel()
        print(f'loaded {args.keypoint_detector} keypoints: '
              f'{keypoint_embeddings.shape} from: {args.autoencoded_keypoints_filename}')
    elif args.keypoints_filename is not None and args.keypoint_detector in ['SIFT', 'ORB']:
        enable_keypoints = True
        train_keypoint_masks, test_keypoint_masks, keypoint_embeddings = \
            NERFCO.extract_keypoints.get_keypoints_and_maps(
                args.keypoints_filename, i_test, i_train, filename_splits,
                H, W)
        NERFCO.extract_keypoints.test_keypoint_masks(train_keypoint_masks,
                                                     keypoint_embeddings)
        NERFCO.extract_keypoints.test_keypoint_masks(test_keypoint_masks,
                                                     keypoint_embeddings)
        args.keypoint_embedding_size = keypoint_embeddings.shape[-1]
        train_keypoint_masks = train_keypoint_masks.ravel()
        print(f'loaded {args.keypoint_detector} keypoints: '
              f'{keypoint_embeddings.shape} from: {args.keypoints_filename}')
    elif args.learnable_embeddings_filename is not None:
        enable_keypoints = True
        train_keypoint_masks, test_keypoint_masks, static_keypoints = \
            NERFCO.extract_keypoints.get_keypoints_and_maps(
                args.keypoints_filename, i_test, i_train, filename_splits,
                H, W)
        args.number_of_keypoints = static_keypoints.shape[0]

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(
        args)
    if args.use_K:
        K_dict = {
            'K': K,
        }
        render_kwargs_train.update(K_dict)
        render_kwargs_test.update(K_dict)

    if args.learnable_embeddings_filename is not None:
        keypoint_embeddings = models['keypoint_embeddings']
        NERFCO.extract_keypoints.test_keypoint_masks(train_keypoint_masks,
                                                     keypoint_embeddings[:])
        NERFCO.extract_keypoints.test_keypoint_masks(test_keypoint_masks,
                                                     keypoint_embeddings[:])
        args.keypoint_embedding_size = keypoint_embeddings.shape[-1]
        train_keypoint_masks = train_keypoint_masks.ravel()
        print(f'loaded {args.keypoint_detector} keypoints: '
              f'{keypoint_embeddings.shape} from: {args.learnable_embeddings_filename}')

    bds_dict = {
        'near': tf.cast(args.near, tf.float32),
        'far': tf.cast(args.far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Short circuit if only rendering out from trained model
    if args.render_only:
        # render_poses = poses[::args.testskip]
        render_test_data(
            args, render_poses, images, i_test, start, render_kwargs_test,
            hwf, K if K is not None else None,
            embeddings=keypoint_embeddings if enable_keypoints else None,
            gt_keypoint_map=test_keypoint_masks if enable_keypoints else None)
        exit(0)

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(
            lrate, decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    photometric_loss_function = img2mse
    if args.use_huber_loss:
        photometric_loss_function = tf.compat.v1.losses.huber_loss

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching.
        #
        # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where axis=1
        # is interpreted as,
        #   axis=0: ray origin in world space
        #   axis=1: ray direction in world space
        #   axis=2: observed RGB color of pixel
        # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3]
        # for each pixel in the image. This stack() adds a new dimension.
        if args.use_K:
            print('get rays K')
            rays = [NERFCO.nerf_renderer.get_rays_tf_K(H, W, K, p)
                    for p in poses[:, :3, :4]]
        else:
            print('get rays')
            rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None, ...]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        if args.depth_from_camera and args.depth_loss:
            train_depths = np.stack([depth_maps[i] for i in i_train],
                                    axis=0).ravel()  # train images only
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            axis=0)  # train images only
        if args.mask_directory and not args.white_bkgd:
            train_masks = np.stack([masks[i] for i in i_train], axis=0)
            pixel_train_masks = train_masks.ravel()
            if args.ray_masking:
                rays_rgb = rays_rgb[np.where(train_masks)]
        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)

        if args.keypoint_oversample:
            keypoint_mask = train_keypoint_masks.astype(bool)
            over_sample = train_keypoint_masks.shape[0] // np.sum(keypoint_mask) - 1
            keypoint_rays_rgb = rays_rgb[keypoint_mask]
            non_zero_keypoints = train_keypoint_masks[keypoint_mask]
            assert non_zero_keypoints[0] == train_keypoint_masks[np.argmax(keypoint_mask)], \
                'first true keypoint should match'
            # repeat rgb
            repeated_keypoint_rays_rgb = np.repeat(keypoint_rays_rgb,
                                                   over_sample, axis=0)
            assert np.all(keypoint_rays_rgb[0] == repeated_keypoint_rays_rgb[0]), \
                'repeats should be the same'
            rays_rgb = np.concatenate([rays_rgb, repeated_keypoint_rays_rgb])
            # repeat keypoints
            repeated_keypoint_masks = np.repeat(non_zero_keypoints,
                                                over_sample, axis=0)
            assert repeated_keypoint_masks[0] == train_keypoint_masks[np.argmax(keypoint_mask)]
            train_keypoint_masks = np.concatenate(
                [train_keypoint_masks, repeated_keypoint_masks])
            # check rays and keypoint masks are the same size
            assert rays_rgb.shape[0] == train_keypoint_masks.shape[0]

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.contrib.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))
    writer.set_as_default()

    i_batch = 0
    for i in range(start, N_iters):
        time0 = time.time()
        if i >= args.keypoint_iterations_start:
            keypoint_loss_coeff = args.keypoint_loss_coeff

        # Sample random ray batch
        if use_batching:
            # Shuffle at the start and at every epoch boundary, applying the
            # same permutation to all aligned side arrays.
            if i == start or i_batch >= rays_rgb.shape[0]:
                rays_rgb, permutations = shuffler(rays_rgb)
                if args.mask_directory and args.sigma_masking:
                    pixel_train_masks = shuffler(pixel_train_masks,
                                                 permutations)
                if enable_keypoints:
                    train_keypoint_masks = shuffler(train_keypoint_masks,
                                                    permutations)
                if args.depth_from_camera and args.depth_loss:
                    train_depths = shuffler(train_depths, permutations)
                i_batch = 0

            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1, 0, 2])
            if args.sigma_masking:
                in_mask_pixels_batch = pixel_train_masks[i_batch:i_batch + N_rand]
            if enable_keypoints:
                keypoint_masks_batch = train_keypoint_masks[i_batch:i_batch + N_rand]
            if args.depth_from_camera and args.depth_loss:
                train_depths_batch = train_depths[i_batch:i_batch + N_rand]
            i_batch += N_rand

            # batch_rays[i, n, xyz] = ray origin or direction, example_id,
            #   3D position
            # target_s[n, rgb] = example_id, observed color.
            if args.depth_from_camera and args.depth_loss:
                batch_rays, target_s = (batch[0], batch[1],
                                        train_depths_batch), batch[2]
            else:
                batch_rays, target_s = batch[:2], batch[2]
        else:
            # Random from one image
            test_frame_number = np.random.choice(i_train)
            target = images[test_frame_number]
            pose = poses[test_frame_number, :3, :4]

            if N_rand is not None:
                if args.use_K:
                    rays_o, rays_d = NERFCO.nerf_renderer.get_rays_K(
                        H, W, K, pose)
                else:
                    rays_o, rays_d = get_rays(H, W, focal, pose)
                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = tf.stack(
                        tf.meshgrid(tf.range(H // 2 - dH, H // 2 + dH),
                                    tf.range(W // 2 - dW, W // 2 + dW),
                                    indexing='ij'), -1)
                    if i < 10:
                        print('precrop', dH, dW, coords[0, 0], coords[-1, -1])
                else:
                    coords = tf.stack(
                        tf.meshgrid(tf.range(H), tf.range(W), indexing='ij'),
                        -1)
                coords = tf.reshape(coords, [-1, 2])
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand], replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)

        ##### Core optimization loop #####
        with tf.GradientTape() as nerf_gradient_tape, \
             tf.GradientTape() as embedding_gradient_tape:

            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(H, W, focal,
                                            chunk=args.chunk,
                                            rays=batch_rays,
                                            verbose=i < 10, retraw=True,
                                            **render_kwargs_train)

            # Compute MSE loss between predicted and true RGB.
            if args.sigma_masking:
                loss = photometric_loss_function(
                    rgb[in_mask_pixels_batch],
                    target_s[in_mask_pixels_batch])
            else:
                loss = photometric_loss_function(rgb, target_s)
            if args.force_black_background:
                loss += 0.1 * photometric_loss_function(
                    rgb[~in_mask_pixels_batch], 0.)
            if 'keypoint_map' in extras:
                keypoint_loss = NERFCO.nerf_keypoint_network.get_keypoint_loss(
                    keypoint_embeddings, keypoint_masks_batch,
                    extras['keypoint_map'])
                loss += keypoint_loss_coeff * keypoint_loss
            if args.learnable_embeddings_filename is not None:
                embedding_loss = keypoint_loss_coeff * \
                    keypoint_embeddings.distance_correlation_loss(
                        extras['xyz_map'], extras['keypoint_map'])
            trans = extras['raw'][..., -1]
            psnr = mse2psnr(loss)
            if args.sigma_masking:
                loss += 0.01 * tf.reduce_sum(
                    acc[~in_mask_pixels_batch]) / N_rand

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                if args.sigma_masking:
                    img_loss0 = photometric_loss_function(
                        extras['rgb0'][in_mask_pixels_batch],
                        target_s[in_mask_pixels_batch])
                    psnr0 = mse2psnr(img_loss0)
                    loss += img_loss0 + 0.01 * tf.reduce_sum(
                        extras['acc0'][~in_mask_pixels_batch]) / N_rand
                else:
                    img_loss0 = photometric_loss_function(extras['rgb0'],
                                                          target_s)
                    psnr0 = mse2psnr(img_loss0)
                    loss += img_loss0
                if args.force_black_background:
                    loss += 0.1 * photometric_loss_function(
                        extras['rgb0'][~in_mask_pixels_batch], 0.)
            if 'keypoint_map_0' in extras:
                coarse_keypoint_loss = NERFCO.nerf_keypoint_network.get_keypoint_loss(
                    keypoint_embeddings, keypoint_masks_batch,
                    extras['keypoint_map_0'])
                loss += keypoint_loss_coeff * coarse_keypoint_loss
            if args.depth_from_camera and args.depth_loss:
                loss += tf.losses.huber_loss(train_depths_batch,
                                             extras['depth_map'])

        gradients = nerf_gradient_tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))
        if args.learnable_embeddings_filename is not None:
            gradients = embedding_gradient_tape.gradient(
                embedding_loss,
                models['keypoint_embeddings'].trainable_variables)
            optimizer.apply_gradients(
                zip(gradients,
                    models['keypoint_embeddings'].trainable_variables))

        dt = time.time() - time0
        ##### end #####

        # Rest is logging

        def save_weights(net, prefix, i):
            path = os.path.join(basedir, expname,
                                '{}_{:06d}.npy'.format(prefix, i))
            np.save(path, net.get_weights())
            print('saved weights at', path)

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

        if i % args.i_video == 0 and i > 0:
            rgbs, test_extras = render_path(render_poses, hwf, args.chunk,
                                            render_kwargs_test)
            disps = test_extras['disp']
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname,
                                     '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs),
                             fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)), fps=30, quality=8)

            if args.use_viewdirs:
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(render_poses, hwf, args.chunk,
                                            render_kwargs_test)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test], hwf, args.chunk, render_kwargs_test,
                        gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:
            output_line = f'{expname} {i}, {psnr.numpy():.3g}, {loss.numpy():.3g} '
            # report loss
            if enable_keypoints:
                gt_kp_mask = keypoint_masks_batch.astype(bool)
                network_output = np.squeeze(extras['keypoint_map'])
                output_line += f'kp loss:{keypoint_loss:.2g} '
                if len(network_output.shape) == 1:
                    output_line += (
                        f'GT kp#:{int(np.sum(gt_kp_mask.astype(np.float32))):d} '
                        f'TP+TN:{np.sum(np.isclose(gt_kp_mask.astype(np.float32), network_output)) / network_output.shape[0]:.2g} '
                        f'TP:{np.sum(np.isclose(gt_kp_mask[gt_kp_mask].astype(np.float32), network_output[gt_kp_mask])) / np.sum(gt_kp_mask):.2g} '
                        f'TN:{np.sum(np.isclose(gt_kp_mask[~gt_kp_mask].astype(np.float32), network_output[~gt_kp_mask])) / np.sum(~gt_kp_mask):.2g} ')
            print(output_line)

            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)
                if enable_keypoints:
                    tf.contrib.summary.histogram('keypoint output',
                                                 network_output)
                    tf.contrib.summary.scalar('keypoint_loss', keypoint_loss)

            if i % args.i_img == 0:
                # report accuracy
                # Log a rendered validation view to Tensorboard
                test_frame_number = np.random.choice(i_val)
                target = images[test_frame_number]
                pose = poses[test_frame_number, :3, :4]

                rgb, disp, acc, extras = render(H, W, focal,
                                                chunk=args.chunk, c2w=pose,
                                                **render_kwargs_test)
                psnr = mse2psnr(img2mse(rgb, target))

                # Save out the validation image for Tensorboard-free
                # monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(
                    os.path.join(testimgdir, '{:06d}.png'.format(i)),
                    to8b(rgb))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                    output_image = to8b(rgb)[tf.newaxis]
                    tf.contrib.summary.image('rgb', output_image)
                    tf.contrib.summary.image(
                        'disp', disp[tf.newaxis, ..., tf.newaxis])
                    tf.contrib.summary.image(
                        'acc', acc[tf.newaxis, ..., tf.newaxis])
                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])
                    if enable_keypoints:
                        # TODO: test why the train renders seem to be less
                        # lossy than the test renders; the test renders
                        # highlight the entire object whereas the train
                        # renders don't, but have a good loss.
                        test_frame_in_split = np.where(
                            i_val == test_frame_number)[0][0]
                        gt_keypoint_mask = test_keypoint_masks[test_frame_in_split]
                        gt_non_zeros_kp_mask = gt_keypoint_mask.astype(bool)
                        test_keypoint_loss = NERFCO.nerf_keypoint_network.get_keypoint_loss(
                            keypoint_embeddings, gt_keypoint_mask,
                            extras['keypoint_map'])
                        tf.contrib.summary.scalar('test_keypoint_loss',
                                                  test_keypoint_loss)
                        if len(keypoint_embeddings.shape) == 1:
                            network_output = extras['keypoint_map']
                            network_output_image = np.reshape(
                                network_output.numpy(), (H, W))
                            keypoint_accuracy_image = \
                                network_output_image == gt_non_zeros_kp_mask
                            tf.contrib.summary.histogram('keypointyness',
                                                         network_output)
                            tf.contrib.summary.scalar(
                                'keypoint_acc',
                                np.sum(keypoint_accuracy_image) / (H * W))
                            tf.contrib.summary.scalar(
                                'non_zero_keypoint_acc',
                                np.sum(keypoint_accuracy_image[gt_non_zeros_kp_mask])
                                / np.sum(gt_non_zeros_kp_mask))
                            tf.contrib.summary.image(
                                'keypoint_image_acc',
                                to8b(keypoint_accuracy_image)[tf.newaxis, ..., tf.newaxis])
                        else:
                            inferred_keypoint_image, keypoint_accuracy_image = \
                                NERFCO.nerf_keypoint_network.create_embedded_keypoint_image(
                                    extras['keypoint_map'],
                                    keypoint_embeddings, gt_keypoint_mask)
                            tf.contrib.summary.scalar(
                                'keypoint_acc',
                                np.sum(keypoint_accuracy_image) / (H * W))
                            tf.contrib.summary.scalar(
                                'non_zero_keypoint_acc',
                                np.sum(keypoint_accuracy_image[gt_non_zeros_kp_mask])
                                / np.sum(gt_non_zeros_kp_mask))
                            tf.contrib.summary.image(
                                'keypoint_image_acc',
                                to8b(keypoint_accuracy_image)[tf.newaxis, ..., tf.newaxis])
                            tf.contrib.summary.image(
                                'keypoint_image',
                                to8b(inferred_keypoint_image)[tf.newaxis, ..., tf.newaxis])

                if args.N_importance > 0:
                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image(
                            'rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image(
                            'disp0', extras['disp0'][tf.newaxis, ..., tf.newaxis])
                        tf.contrib.summary.image(
                            'z_std', extras['z_std'][tf.newaxis, ..., tf.newaxis])

        global_step.assign_add(1)
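# `shuffler` is used above to co-shuffle the ray buffer with its aligned side
# arrays (masks, depths, keypoint labels) but is not defined in this section.
# A minimal sketch under the assumed contract: with no permutation given,
# shuffle and also return the permutation; otherwise apply the given one.
def shuffler(array, permutations=None):
    if permutations is None:
        permutations = np.random.permutation(array.shape[0])
        return array[permutations], permutations
    return array[permutations]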
# expname = 'fern_test'
config = os.path.join(basedir, expname, 'config.txt')
print('Args:')
print(open(config, 'r').read())
parser = run_nerf_fast.config_parser()

weights_name = 'model_200000.npy'
# weights_name = 'model_000700.npy'
# parse_args expects a list of tokens, hence the split().
args = parser.parse_args('--config {} --ft_path {}'.format(
    config, os.path.join(basedir, expname, weights_name)).split())
print('loaded args')

images, poses, bds, render_poses, i_test = load_llff_data(
    args.datadir, args.factor, recenter=True, bd_factor=.75,
    spherify=args.spherify)

H, W, focal = poses[0, :3, -1].astype(np.float32)
poses = poses[:, :3, :4]
print(f"NUM IMAGES: {poses.shape[0]}")
H = int(H)
W = int(W)
hwf = [H, W, focal]

images = images.astype(np.float32)
poses = poses.astype(np.float32)

if args.no_ndc:
    near = tf.reduce_min(bds) * .9
def load_llff_dataset(
    render_kwargs_train_=None,
    render_kwargs_test_=None,
    return_nerf_volume_extent=False,
):
    datadir = args.datadir
    factor = args.factor
    spherify = args.spherify
    bd_factor = args.bd_factor

    # actual loading
    images, poses, bds, render_poses, i_test = load_llff_data(
        datadir,
        factor=factor,
        recenter=True,
        bd_factor=bd_factor,
        spherify=spherify,
    )
    dataset_extras = _get_multi_view_helper_mappings(images.shape[0])

    # poses
    hwf = poses[0, :3, -1]
    poses = poses[:, :3, :4]  # N x 3 x 4
    all_rotations = poses[:, :3, :3]  # N x 3 x 3
    all_translations = poses[:, :3, 3]  # N x 3

    render_poses = render_poses[:, :3, :4]
    render_rotations = render_poses[:, :3, :3]
    render_translations = render_poses[:, :3, 3]

    # splits
    i_test = []  # [i_test]
    if args.test_block_size > 0 and args.train_block_size > 0:
        print("splitting timesteps into training (" +
              str(args.train_block_size) + ") and test (" +
              str(args.test_block_size) + ") blocks")
        num_timesteps = len(dataset_extras["raw_timesteps"])
        test_timesteps = np.concatenate([
            np.arange(
                min(num_timesteps, blocks_start + args.train_block_size),
                min(
                    num_timesteps,
                    blocks_start + args.train_block_size + args.test_block_size,
                ),
            ) for blocks_start in np.arange(
                0, num_timesteps,
                args.train_block_size + args.test_block_size)
        ])
        i_test = [
            imageid for imageid, timestep in enumerate(
                dataset_extras["imageid_to_timestepid"])
            if timestep in test_timesteps
        ]

    i_test = np.array(i_test)
    i_val = i_test
    i_train = np.array([
        i for i in np.arange(int(images.shape[0]))
        if (i not in i_test and i not in i_val)
    ])

    # near, far
    # if args.no_ndc:
    near = np.ndarray.min(bds) * 0.9
    far = np.ndarray.max(bds) * 1.0
    # else:
    #     near = 0.
    #     far = 1.
    bds_dict = {
        "near": near,
        "far": far,
    }
    if render_kwargs_train_ is not None:
        render_kwargs_train_.update(bds_dict)
    if render_kwargs_test_ is not None:
        render_kwargs_test_.update(bds_dict)

    if return_nerf_volume_extent:
        ray_params = checkpoint_dict["ray_params"]
        min_point, max_point = determine_nerf_volume_extent(
            get_parallelized_render_function(),
            poses,
            hwf[0],
            hwf[1],
            hwf[2],
            ray_params,
            render_kwargs_test_,
        )
        dataset_extras["min_nerf_volume_point"] = min_point.detach()
        dataset_extras["max_nerf_volume_point"] = max_point.detach()

    return (
        images,
        poses,
        all_rotations,
        all_translations,
        bds,
        render_poses,
        render_rotations,
        render_translations,
        i_train,
        i_val,
        i_test,
        near,
        far,
        dataset_extras,
    )
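# A small standalone check of the train/test block split above, with
# hypothetical sizes (10 timesteps, blocks of 3 train / 1 test):
#
#     num_timesteps, train_bs, test_bs = 10, 3, 1
#     test_timesteps = np.concatenate([
#         np.arange(min(num_timesteps, s + train_bs),
#                   min(num_timesteps, s + train_bs + test_bs))
#         for s in np.arange(0, num_timesteps, train_bs + test_bs)
#     ])
#     # -> array([3, 7]): one held-out timestep after every 3 training steps.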
def train(): parser = config_parser() args = parser.parse_args() # Load data K = None if args.dataset_type == 'llff': images, poses, bds, render_poses, i_test = load_llff_data( args.datadir, args.factor, recenter=True, bd_factor=.75, spherify=args.spherify) hwf = poses[0, :3, -1] poses = poses[:, :3, :4] print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) if not isinstance(i_test, list): i_test = [i_test] if args.llffhold > 0: print('Auto LLFF holdout,', args.llffhold) i_test = np.arange(images.shape[0])[::args.llffhold] i_val = i_test i_train = np.array([ i for i in np.arange(int(images.shape[0])) if (i not in i_test and i not in i_val) ]) full_original_track_poses = torch.from_numpy(poses).float().to(device) print('DEFINING BOUNDS') if args.no_ndc: near = np.ndarray.min(bds) * .9 far = np.ndarray.max(bds) * 1. else: near = 0. far = 1. print('NEAR FAR', near, far) elif args.dataset_type == 'blender': images, poses, render_poses, hwf, i_split = load_blender_data( args.datadir, args.half_res, args.testskip) print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) i_train, i_val, i_test = i_split near = 2. far = 6. if args.white_bkgd: images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:]) else: images = images[..., :3] elif args.dataset_type == 'custom': images, poses, render_poses, full_original_track_poses, hwf, i_split = load_custom_data( args.scene, args.half_res, args.testskip, args.inv) print('Loaded RealEstate', images.shape, render_poses.shape, hwf, render_poses.shape) i_train, i_val, i_test = i_split #use ndc near = 0.05 far = args.far images = images[..., :3] elif args.dataset_type == 'LINEMOD': images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data( args.datadir, args.half_res, args.testskip) print( f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}' ) print(f'[CHECK HERE] near: {near}, far: {far}.') i_train, i_val, i_test = i_split if args.white_bkgd: images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:]) else: images = images[..., :3] elif args.dataset_type == 'deepvoxels': images, poses, render_poses, hwf, i_split = load_dv_data( scene=args.shape, basedir=args.datadir, testskip=args.testskip) print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) i_train, i_val, i_test = i_split hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1)) near = hemi_R - 1. far = hemi_R + 1. 
else: print('Unknown dataset type', args.dataset_type, 'exiting') return # images are all of the images needed # poses (c2w) are all of the poses needed, matches images in terms of indexing # i_train, i_val, i_test are the indexes of images/poses that are train, val, test respectively # render_poses (c2w) are the poses to render a novel track video with # hwf is height, width, focal length (not normalized), which are used to construct K if it hasn't been loaded # Cast intrinsics to right types H, W, focal = hwf H, W = int(H), int(W) hwf = [H, W, focal] if K is None: modifier = 0 K = np.array([[focal + modifier * W, 0, 0.5 * W], [0, focal + modifier * H, 0.5 * H], [0, 0, 1]]) if args.render_test: render_poses = np.array(poses[i_test]) # Create log dir and copy the config file basedir = args.basedir basedir = os.path.join(basedir, dset_to_prefix[args.dataset_type]) expname = args.expname os.makedirs(os.path.join(basedir, expname), exist_ok=True) f = os.path.join(basedir, expname, 'args.txt') with open(f, 'w') as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write('{} = {}\n'.format(arg, attr)) if args.config is not None: f = os.path.join(basedir, expname, 'config.txt') with open(f, 'w') as file: file.write(open(args.config, 'r').read()) # Create nerf model render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf( args) global_step = start bds_dict = { 'near': near, 'far': far, } render_kwargs_train.update(bds_dict) render_kwargs_test.update(bds_dict) # Move testing data to GPU render_poses = torch.Tensor(render_poses).to(device) # Short circuit if only rendering out from trained model if args.render_only: print('RENDER ONLY') with torch.no_grad(): if args.render_test: # render_test switches to test poses images = images[i_test] else: # Default is smoother render_poses path images = None testsavedir = os.path.join( basedir, expname, 'renderonly_{}_{:06d}'.format( 'test' if args.render_test else 'path', start)) os.makedirs(testsavedir, exist_ok=True) print('test poses shape', render_poses.shape) rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) print('Done rendering', testsavedir) imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) return # Prepare raybatch tensor if batching random rays N_rand = args.N_rand use_batching = not args.no_batching if use_batching: # For random ray batching print('get rays') rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:, :3, :4]], 0) # [N, ro+rd, H, W, 3] print('done, concats') rays_rgb = np.concatenate([rays, images[:, None]], 1) # [N, ro+rd+rgb, H, W, 3] rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) # [N, H, W, ro+rd+rgb, 3] rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]) # [(N-1)*H*W, ro+rd+rgb, 3] rays_rgb = rays_rgb.astype(np.float32) print('shuffle rays') np.random.shuffle(rays_rgb) print('done') i_batch = 0 # Move training data to GPU if use_batching: images = torch.Tensor(images).to(device) poses = torch.Tensor(poses).to(device) if use_batching: rays_rgb = torch.Tensor(rays_rgb).to(device) N_iters = 200000 + 1 print('Begin') print('TRAIN views are', i_train) print('TEST views are', i_test) print('VAL views are', i_val) # Summary writers # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) start = start + 1 for i in trange(start, N_iters): time0 = time.time() # Sample random 
ray batch if use_batching: # Random over all images batch = rays_rgb[i_batch:i_batch + N_rand] # [B, 2+1, 3*?] batch = torch.transpose(batch, 0, 1) batch_rays, target_s = batch[:2], batch[2] i_batch += N_rand if i_batch >= rays_rgb.shape[0]: print("Shuffle data after an epoch!") rand_idx = torch.randperm(rays_rgb.shape[0]) rays_rgb = rays_rgb[rand_idx] i_batch = 0 else: # Random from one image img_i = np.random.choice(i_train) target = images[img_i] target = torch.Tensor(target).to(device) pose = poses[img_i, :3, :4] if N_rand is not None: rays_o, rays_d = get_rays( H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) if i < args.precrop_iters: dH = int(H // 2 * args.precrop_frac) dW = int(W // 2 * args.precrop_frac) coords = torch.stack( torch.meshgrid( torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH), torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)), -1) if i == start: print( f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}" ) else: coords = torch.stack( torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)), -1) # (H, W, 2) coords = torch.reshape(coords, [-1, 2]) # (H * W, 2) select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) select_coords = coords[select_inds].long() # (N_rand, 2) rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) batch_rays = torch.stack([rays_o, rays_d], 0) target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) ##### Core optimization loop ##### rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, verbose=i < 10, retraw=True, **render_kwargs_train) optimizer.zero_grad() img_loss = img2mse(rgb, target_s) trans = extras['raw'][..., -1] loss = img_loss psnr = mse2psnr(img_loss) if 'rgb0' in extras: img_loss0 = img2mse(extras['rgb0'], target_s) loss = loss + img_loss0 psnr0 = mse2psnr(img_loss0) loss.backward() optimizer.step() # NOTE: IMPORTANT! 
### update learning rate ### decay_rate = 0.1 decay_steps = args.lrate_decay * 1000 new_lrate = args.lrate * (decay_rate**(global_step / decay_steps)) for param_group in optimizer.param_groups: param_group['lr'] = new_lrate ################################ dt = time.time() - time0 # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") ##### end ##### # Rest is logging if i % args.i_weights == 0: path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) torch.save( { 'global_step': global_step, 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 'optimizer_state_dict': optimizer.state_dict(), }, path) print('Saved checkpoints at', path) if i % args.i_video == 0 and i > 0: # Turn on testing mode compare_dir = os.path.join( '/data/vision/billf/intrinsic/neural-render/ericqian/results/', dset_to_prefix[args.dataset_type]) novel_compare_dir = os.path.join(compare_dir, args.scene, 'novel', 'nerf') os.makedirs(novel_compare_dir, exist_ok=True) with torch.no_grad(): rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, savedir=novel_compare_dir) moviebase = os.path.join( basedir, expname, '{}_renderposes_{:06d}_'.format(expname, i)) imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) moviebase_compare = os.path.join(compare_dir, args.scene, 'novel') imageio.mimwrite(moviebase_compare + '/nerf_rgb.mp4', to8b(rgbs), fps=30, quality=8) imageio.mimwrite(moviebase_compare + '/nerf_disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) # Turn on testing mode original_compare_dir = os.path.join(compare_dir, args.scene, 'original', 'nerf') os.makedirs(original_compare_dir, exist_ok=True) with torch.no_grad(): rgbs, disps = render_path(full_original_track_poses, hwf, K, args.chunk, render_kwargs_test, savedir=original_compare_dir) moviebase = os.path.join( basedir, expname, '{}_fulloriginal_{:06d}_'.format(expname, i)) imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) moviebase_compare = os.path.join(compare_dir, args.scene, 'original') imageio.mimwrite(moviebase_compare + '/nerf_rgb.mp4', to8b(rgbs), fps=30, quality=8) imageio.mimwrite(moviebase_compare + '/nerf_disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) if i % args.i_testset == 0 and i > 0: testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) os.makedirs(testsavedir, exist_ok=True) with torch.no_grad(): render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir) if i % args.i_print == 0: tqdm.write( f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") """ print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy()) print('iter time {:.05f}'.format(dt)) with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print): tf.contrib.summary.scalar('loss', loss) tf.contrib.summary.scalar('psnr', psnr) tf.contrib.summary.histogram('tran', trans) if args.N_importance > 0: tf.contrib.summary.scalar('psnr0', psnr0) if i%args.i_img==0: # Log a rendered validation view to Tensorboard img_i=np.random.choice(i_val) target = images[img_i] pose = poses[img_i, :3,:4] with torch.no_grad(): rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose, **render_kwargs_test) psnr = 
mse2psnr(img2mse(rgb, target)) with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis]) tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis]) tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis]) tf.contrib.summary.scalar('psnr_holdout', psnr) tf.contrib.summary.image('rgb_holdout', target[tf.newaxis]) if args.N_importance > 0: with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis]) tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis]) tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis]) """ global_step += 1
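# A minimal usage sketch for the checkpoint format written above: `ckpt_path`,
# `coarse_model`, `fine_model`, and `optim` are hypothetical stand-ins for the
# objects built by create_nerf(); only the dict keys are taken from the
# torch.save() call in the loop above.
import torch

def load_nerf_checkpoint(ckpt_path, coarse_model, fine_model, optim):
    # map_location='cpu' keeps the sketch runnable without a GPU
    ckpt = torch.load(ckpt_path, map_location='cpu')
    coarse_model.load_state_dict(ckpt['network_fn_state_dict'])
    fine_model.load_state_dict(ckpt['network_fine_state_dict'])
    optim.load_state_dict(ckpt['optimizer_state_dict'])
    return ckpt['global_step']  # resume the step counter from the checkpoint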
def cloud_size_vs_performance():
    basedir = './logs'
    expname = 'fern_example'
    config = os.path.join(basedir, expname, 'config.txt')
    print('Args:')
    print(open(config, 'r').read())
    parser = run_nerf.config_parser()
    weights_name = 'model_200000.npy'
    args = parser.parse_args('--config {} --ft_path {}'.format(
        config, os.path.join(basedir, expname, weights_name)))
    print('loaded args')
    images, poses, bds, render_poses, i_test = load_llff_data(
        args.datadir, args.factor, recenter=True, bd_factor=.75,
        spherify=args.spherify)
    H, W, focal = poses[0, :3, -1].astype(np.float32)
    poses = poses[:, :3, :4]
    H = int(H)
    W = int(W)
    hwf = [H, W, focal]
    images = images.astype(np.float32)
    poses = poses.astype(np.float32)
    near = 0.
    far = 1.
    if not isinstance(i_test, list):
        i_test = [i_test]
    if args.llffhold > 0:
        print('Auto LLFF holdout,', args.llffhold)
        i_test = np.arange(images.shape[0])[::args.llffhold]
    _, render_kwargs, start, grad_vars, models = run_nerf.create_nerf(args)
    to_use = i_test[0]
    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs.update(bds_dict)
    print('Render kwargs:')
    pprint.pprint(render_kwargs)
    res_dir = "./cloud_size_test"
    os.makedirs(res_dir, exist_ok=True)  # plots and results are written here
    res = {'cloud_size': [], 'mse': [], 'psnr': [], 'time': []}
    for i in [1, 2, 4, 8, 16, 32]:
        print(f'Running with cloud downsampled {i}x')
        start_time = time.time()
        ret_vals = run_nerf.render(H, W, focal, c2w=poses[to_use], pc=True,
                                   cloudsize=i, **render_kwargs)
        end_time = time.time()
        img = np.clip(ret_vals[0], 0, 1)
        mse = run_nerf.img2mse(images[to_use], img)
        psnr = run_nerf.mse2psnr(mse)
        res['cloud_size'].append((17 * H * W) // (i * i))
        res['mse'].append(float(mse))
        res['psnr'].append(float(psnr))
        res['time'].append(end_time - start_time)
    # make plots: PSNR vs cloud size
    fig, ax = plt.subplots(1, 1)
    fig.suptitle('PSNR vs Point Cloud Size')
    ax.set_xlabel('Cloud Size')
    ax.set_ylabel('PSNR')
    plt.xscale('log')
    ax.plot(res['cloud_size'], res['psnr'])
    plt.savefig(os.path.join(res_dir, 'cs_psnr.png'))
    # PSNR vs running time
    fig, ax = plt.subplots(1, 1)
    fig.suptitle('PSNR vs Running Time')
    ax.set_xlabel('Time')
    ax.set_ylabel('PSNR')
    plt.xscale('log')
    ax.plot(res['time'], res['psnr'])
    plt.savefig(os.path.join(res_dir, 'time_psnr.png'))
    # running time vs cloud size
    fig, ax = plt.subplots(1, 1)
    fig.suptitle('Running Time vs Cloud Size')
    ax.set_xlabel('Cloud Size')
    ax.set_ylabel('Running Time')
    plt.xscale('log')
    plt.yscale('log')
    ax.plot(res['cloud_size'], res['time'])
    plt.savefig(os.path.join(res_dir, 'cs_time.png'))
    with open(os.path.join(res_dir, 'results.txt'), 'w') as outfile:
        json.dump(res, outfile)
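# A minimal follow-up sketch: results.txt above is plain JSON, so the sweep can
# be re-inspected offline. The file name and keys are taken from the function
# above; the formatting is illustrative only.
import json
import os

def summarize_cloud_sweep(res_dir='./cloud_size_test'):
    with open(os.path.join(res_dir, 'results.txt')) as f:
        res = json.load(f)
    for size, psnr, t in zip(res['cloud_size'], res['psnr'], res['time']):
        print(f'cloud_size={size:>9d}  psnr={psnr:6.2f}  time={t:7.2f}s')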
def train_mu(): parser = config_parser() args = parser.parse_args() if args.random_seed is not None: print('Fixing random seed', args.random_seed) np.random.seed(args.random_seed) tf.compat.v1.set_random_seed(args.random_seed) # Load data if args.dataset_type == 'llff': images, poses, bds, render_poses, i_test = load_llff_data( args.datadir, args.factor, recenter=True, bd_factor=.75, spherify=args.spherify) hwf = poses[0, :3, -1] poses = poses[:, :3, :4] print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) if not isinstance(i_test, list): i_test = [i_test] if args.llffhold > 0: print('Auto LLFF holdout,', args.llffhold) i_test = np.arange(images.shape[0])[::args.llffhold] i_val = i_test i_train = np.array([ i for i in np.arange(int(images.shape[0])) if (i not in i_test and i not in i_val) ]) print('DEFINING BOUNDS') if args.no_ndc: near = tf.reduce_min(bds) * .9 far = tf.reduce_max(bds) * 1. else: near = 0. far = 1. print('NEAR FAR', near, far) else: print('Unknown dataset type', args.dataset_type, 'exiting') return # Cast intrinsics to right types H, W, focal = hwf H, W = int(H), int(W) hwf = [H, W, focal] if args.render_test: render_poses = np.array(poses[i_test]) # Create log dir and copy the config file basedir = args.basedir expname = args.expname os.makedirs(os.path.join(basedir, expname), exist_ok=True) mu_workdir = os.path.join(basedir, expname + '_mu') os.makedirs(mu_workdir, exist_ok=True) f = os.path.join(basedir, expname, 'args.txt') with open(f, 'w') as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write('{} = {}\n'.format(arg, attr)) if args.config is not None: f = os.path.join(basedir, expname, 'config.txt') with open(f, 'w') as file: file.write(open(args.config, 'r').read()) # Create nerf model nerf_render_kwargs_train, nerf_render_kwargs_test, models = create_nerf( args) start, mu_grad_vars, model_mu, mu_embed_fn, mu_embeddirs_fn = create_mu_model( mu_workdir, args) bds_dict = { 'near': tf.cast(near, tf.float32), 'far': tf.cast(far, tf.float32), } nerf_render_kwargs_train.update(bds_dict) nerf_render_kwargs_test.update(bds_dict) # Short circuit if only rendering out from trained model if args.render_only: print('RENDER ONLY') if args.render_test: # render_test switches to test poses images = images[i_test] else: # Default is smoother render_poses path images = None testsavedir = os.path.join( basedir, expname, 'renderonly_{}_{:06d}'.format( 'test' if args.render_test else 'path', start)) os.makedirs(testsavedir, exist_ok=True) print('test poses shape', render_poses.shape) rgbs, _ = render_path(render_poses, hwf, args.chunk, nerf_render_kwargs_train, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) print('Done rendering', testsavedir) imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) return # Create optimizer lrate = args.lrate if args.lrate_decay > 0: lrate = tf.keras.optimizers.schedules.ExponentialDecay( lrate, decay_steps=args.lrate_decay * 1000, decay_rate=0.1) optimizer = tf.keras.optimizers.Adam(lrate) models['optimizer'] = optimizer models['model_mu'] = [model_mu, mu_embed_fn, mu_embeddirs_fn] global_step = tf.compat.v1.train.get_or_create_global_step() global_step.assign(start) # Prepare raybatch tensor if batching random rays N_rand = args.N_rand use_batching = not args.no_batching if use_batching: # For random ray batching. 
# # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where axis=1 is # interpreted as, # axis=0: ray origin in world space # axis=1: ray direction in world space # axis=2: observed RGB color of pixel print('get rays') # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3] # for each pixel in the image. This stack() adds a new dimension. rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]] rays = np.stack(rays, axis=0) # [N, ro+rd, H, W, 3] print('done, concats') # [N, ro+rd+rgb, H, W, 3] rays_rgb = np.concatenate([rays, images[:, None, ...]], 1) # [N, H, W, ro+rd+rgb, 3] rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) rays_rgb = np.stack([rays_rgb[i] for i in i_train], axis=0) # train images only # [(N-1)*H*W, ro+rd+rgb, 3] rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]) rays_rgb = rays_rgb.astype(np.float32) print('shuffle rays') np.random.shuffle(rays_rgb) print('done') i_batch = 0 N_iters = 1000000 print('Begin') print('TRAIN views are', i_train) print('TEST views are', i_test) print('VAL views are', i_val) # Summary writers writer = tf.summary.create_file_writer( os.path.join(basedir, 'summaries', expname)) for i in range(start, N_iters): time0 = time.time() # Sample random ray batch if use_batching: # Random over all images batch = rays_rgb[i_batch:i_batch + N_rand] # [B, 2+1, 3*?] batch = tf.transpose(batch, [1, 0, 2]) # batch_rays[i, n, xyz] = ray origin or direction, example_id, 3D position # target_s[n, rgb] = example_id, observed color. batch_rays, target_s = batch[:2], batch[2] i_batch += N_rand if i_batch >= rays_rgb.shape[0]: np.random.shuffle(rays_rgb) i_batch = 0 else: # Random from one image img_i = np.random.choice(i_train) target = images[img_i] pose = poses[img_i, :3, :4] if N_rand is not None: rays_o, rays_d = get_rays(H, W, focal, pose) if i < args.precrop_iters: dH = int(H // 2 * args.precrop_frac) dW = int(W // 2 * args.precrop_frac) coords = tf.stack( tf.meshgrid(tf.range(H // 2 - dH, H // 2 + dH), tf.range(W // 2 - dW, W // 2 + dW), indexing='ij'), -1) if i < 10: print('precrop', dH, dW, coords[0, 0], coords[-1, -1]) else: coords = tf.stack( tf.meshgrid(tf.range(H), tf.range(W), indexing='ij'), -1) coords = tf.reshape(coords, [-1, 2]) select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis]) rays_o = tf.gather_nd(rays_o, select_inds) rays_d = tf.gather_nd(rays_d, select_inds) batch_rays = tf.stack([rays_o, rays_d], 0) target_s = tf.gather_nd(target, select_inds) ##### Core optimization loop ##### _, _, _, extras = nr.render(H, W, focal, chunk=args.chunk, rays=batch_rays, verbose=i < 10, retraw=True, **nerf_render_kwargs_train) mu_exp = extras['mu_exp'] pts = extras['pts'] viewdirs = extras['viewdirs'] with tf.GradientTape(persistent=False) as tape: mu_out = run_mu_model(pts, viewdirs, model_mu, mu_embed_fn, mu_embeddirs_fn) mu_out = tf.reshape(mu_out, mu_exp.shape) mu_out = tf.math.tanh(tf.nn.relu(mu_out)) # print(mu_out.shape, mu_exp.shape) mu_loss = img2mse(mu_exp, mu_out) gradients = tape.gradient(mu_loss, mu_grad_vars) optimizer.apply_gradients(zip(gradients, mu_grad_vars)) dt = time.time() - time0 ##### end ##### # Rest is logging def save_mu_weights(net, i): path = os.path.join(mu_workdir, 'model_mu_{:06d}.npy'.format(i)) np.save(path, net.get_weights()) print('saved weights at', path) if i % args.i_weights == 0: save_mu_weights(model_mu, i) if i % args.i_print == 0 or i < 10: print(expname + '_mu', i, mu_loss.numpy(), global_step.numpy()) 
print('iter time {:.05f}'.format(dt)) with writer.as_default(): tf.summary.scalar('mu_loss', mu_loss, step=i + 1) global_step.assign_add(1)
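# Plausible stand-in definitions for the helpers the mu loss above assumes
# (the real img2mse lives elsewhere in this codebase); the squash line shows
# why tanh(relu(x)) is a natural target range: relu clamps to [0, inf) and
# tanh then maps that onto [0, 1).
import tensorflow as tf

img2mse_sketch = lambda x, y: tf.reduce_mean(tf.square(x - y))  # mean squared error
squash = lambda x: tf.math.tanh(tf.nn.relu(x))  # raw network output -> [0, 1)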
def train(): parser = config_parser() args = parser.parse_args() # Load data if args.dataset_type == "llff": images, poses, bds, render_poses, i_test = load_llff_data( args.datadir, args.factor, recenter=True, bd_factor=0.75, spherify=args.spherify, ) hwf = poses[0, :3, -1] poses = poses[:, :3, :4] print("Loaded llff", images.shape, render_poses.shape, hwf, args.datadir) if not isinstance(i_test, list): i_test = [i_test] if args.llffhold > 0: print("Auto LLFF holdout,", args.llffhold) i_test = np.arange(images.shape[0])[::args.llffhold] i_val = i_test i_train = np.array([ i for i in np.arange(int(images.shape[0])) if (i not in i_test and i not in i_val) ]) print("DEFINING BOUNDS") if args.no_ndc: near = np.ndarray.min(bds) * 0.9 far = np.ndarray.max(bds) * 1.0 else: near = 0.0 far = 1.0 print("NEAR FAR", near, far) elif args.dataset_type == "blender": images, poses, render_poses, hwf, i_split = load_blender_data( args.datadir, args.half_res, args.testskip) print("Loaded blender", images.shape, render_poses.shape, hwf, args.datadir) i_train, i_val, i_test = i_split near = 2.0 far = 6.0 if args.white_bkgd: images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:]) else: images = images[..., :3] elif args.dataset_type == "deepvoxels": images, poses, render_poses, hwf, i_split = load_dv_data( scene=args.shape, basedir=args.datadir, testskip=args.testskip) print("Loaded deepvoxels", images.shape, render_poses.shape, hwf, args.datadir) i_train, i_val, i_test = i_split hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1)) near = hemi_R - 1.0 far = hemi_R + 1.0 else: print("Unknown dataset type", args.dataset_type, "exiting") return # Cast intrinsics to right types H, W, focal = hwf H, W = int(H), int(W) hwf = [H, W, focal] if args.render_test: render_poses = np.array(poses[i_test]) # Create log dir and copy the config file basedir = args.basedir expname = args.expname os.makedirs(os.path.join(basedir, expname), exist_ok=True) f = os.path.join(basedir, expname, "args.txt") with open(f, "w") as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write("{} = {}\n".format(arg, attr)) if args.config is not None: f = os.path.join(basedir, expname, "config.txt") with open(f, "w") as file: file.write(open(args.config, "r").read()) # data prepare ready # Create nerf model render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf( args) global_step = start bds_dict = { "near": near, "far": far, } render_kwargs_train.update(bds_dict) render_kwargs_test.update(bds_dict) # Move testing data to GPU render_poses = torch.Tensor(render_poses).to(device) # Short circuit if only rendering out from trained model if args.render_only: print("RENDER ONLY") with torch.no_grad(): if args.render_test: # render_test switches to test poses images = images[i_test] else: # Default is smoother render_poses path images = None testsavedir = os.path.join( basedir, expname, "renderonly_{}_{:06d}".format( "test" if args.render_test else "path", start), ) os.makedirs(testsavedir, exist_ok=True) print("test poses shape", render_poses.shape) rgbs, _ = render_path( render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor, ) print("Done rendering", testsavedir) imageio.mimwrite(os.path.join(testsavedir, "video.mp4"), to8b(rgbs), fps=30, quality=8) return # Prepare raybatch tensor if batching random rays N_rand = args.N_rand use_batching = not args.no_batching if use_batching: # For random ray batching print("get 
rays") rays = np.stack( [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]], 0) # [N, ro+rd, H, W, 3] print("done, concats") rays_rgb = np.concatenate([rays, images[:, None]], 1) # [N, ro+rd+rgb, H, W, 3] rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) # [N, H, W, ro+rd+rgb, 3] rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]) # [(N-1)*H*W, ro+rd+rgb, 3] rays_rgb = rays_rgb.astype(np.float32) print("shuffle rays") np.random.shuffle(rays_rgb) print("done") i_batch = 0 # Move training data to GPU images = torch.Tensor(images).to(device) poses = torch.Tensor(poses).to(device) if use_batching: rays_rgb = torch.Tensor(rays_rgb).to(device) N_iters = 200000 + 1 print("Begin") print("TRAIN views are", i_train) print("TEST views are", i_test) print("VAL views are", i_val) # Summary writers # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) start = start + 1 for i in trange(start, N_iters): time0 = time.time() # Sample random ray batch if use_batching: # Random over all images batch = rays_rgb[i_batch:i_batch + N_rand] # [B, 2+1, 3*?] batch = torch.transpose(batch, 0, 1) batch_rays, target_s = batch[:2], batch[2] i_batch += N_rand if i_batch >= rays_rgb.shape[0]: print("Shuffle data after an epoch!") rand_idx = torch.randperm(rays_rgb.shape[0]) rays_rgb = rays_rgb[rand_idx] i_batch = 0 else: # Random from one image img_i = np.random.choice(i_train) target = images[img_i] # get one image pose = poses[img_i, :3, :4] # get the pose of the image if N_rand is not None: # for one image, there will be HxW cameras rays, # with the same start points (ro) and different end points (rd) rays_o, rays_d = get_rays( H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) if i < args.precrop_iters: dH = int(H // 2 * args.precrop_frac) dW = int(W // 2 * args.precrop_frac) coords = torch.stack( torch.meshgrid( torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH), torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW), ), -1, ) if i == start: print( f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}" ) else: coords = torch.stack( torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)), -1, ) # (H, W, 2) coords = torch.reshape(coords, [-1, 2]) # (H * W, 2) select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) # random select some 2D coordinates with (x, y) index to select some camera rays randomly select_coords = coords[select_inds].long() # (N_rand, 2) rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) batch_rays = torch.stack([rays_o, rays_d], 0) # get the rgb values from the orginal image target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) ##### Core optimization loop ##### # in this iteration, optimize the model rgb, disp, acc, extras = render( H, W, focal, chunk=args.chunk, rays=batch_rays, verbose=i < 10, retraw=True, **render_kwargs_train, ) optimizer.zero_grad() img_loss = img2mse(rgb, target_s) trans = extras["raw"][..., -1] loss = img_loss psnr = mse2psnr(img_loss) if "rgb0" in extras: img_loss0 = img2mse(extras["rgb0"], target_s) loss = loss + img_loss0 psnr0 = mse2psnr(img_loss0) loss.backward() optimizer.step() # NOTE: IMPORTANT! 
### update learning rate ### decay_rate = 0.1 decay_steps = args.lrate_decay * 1000 new_lrate = args.lrate * (decay_rate**(global_step / decay_steps)) for param_group in optimizer.param_groups: param_group["lr"] = new_lrate ################################ dt = time.time() - time0 # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") ##### end ##### # Rest is logging if i % args.i_weights == 0: path = os.path.join(basedir, expname, "{:06d}.tar".format(i)) torch.save( { "global_step": global_step, "network_fn_state_dict": render_kwargs_train["network_fn"].state_dict(), "network_fine_state_dict": render_kwargs_train["network_fine"].state_dict(), "optimizer_state_dict": optimizer.state_dict(), }, path, ) print("Saved checkpoints at", path) if i % args.i_video == 0 and i > 0: # Turn on testing mode with torch.no_grad(): rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test) print("Done, saving", rgbs.shape, disps.shape) moviebase = os.path.join(basedir, expname, "{}_spiral_{:06d}_".format(expname, i)) imageio.mimwrite(moviebase + "rgb.mp4", to8b(rgbs), fps=30, quality=8) imageio.mimwrite(moviebase + "disp.mp4", to8b(disps / np.max(disps)), fps=30, quality=8) # if args.use_viewdirs: # render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4] # with torch.no_grad(): # rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test) # render_kwargs_test['c2w_staticcam'] = None # imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8) if i % args.i_testset == 0 and i > 0: testsavedir = os.path.join(basedir, expname, "testset_{:06d}".format(i)) os.makedirs(testsavedir, exist_ok=True) print("test poses shape", poses[i_test].shape) with torch.no_grad(): render_path( torch.Tensor(poses[i_test]).to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir, ) print("Saved test set") if i % args.i_print == 0: tqdm.write( f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") """ print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy()) print('iter time {:.05f}'.format(dt)) with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print): tf.contrib.summary.scalar('loss', loss) tf.contrib.summary.scalar('psnr', psnr) tf.contrib.summary.histogram('tran', trans) if args.N_importance > 0: tf.contrib.summary.scalar('psnr0', psnr0) if i%args.i_img==0: # Log a rendered validation view to Tensorboard img_i=np.random.choice(i_val) target = images[img_i] pose = poses[img_i, :3,:4] with torch.no_grad(): rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose, **render_kwargs_test) psnr = mse2psnr(img2mse(rgb, target)) with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis]) tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis]) tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis]) tf.contrib.summary.scalar('psnr_holdout', psnr) tf.contrib.summary.image('rgb_holdout', target[tf.newaxis]) if args.N_importance > 0: with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis]) tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis]) tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis]) """ global_step += 1
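# The in-loop update above is the closed-form exponential decay
# lrate(step) = lrate0 * 0.1 ** (step / (lrate_decay * 1000)); a minimal,
# self-contained check of that formula (the example values are illustrative):
def decayed_lrate(lrate0, lrate_decay, step, decay_rate=0.1):
    return lrate0 * (decay_rate ** (step / (lrate_decay * 1000)))

# with lrate0=5e-4 and lrate_decay=250, the rate drops 10x every 250k steps
assert abs(decayed_lrate(5e-4, 250, 250000) - 5e-5) < 1e-12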
def train(): parser = config_parser() args = parser.parse_args() # Multi-GPU args.n_gpus = torch.cuda.device_count() print("Using {} GPU(s).".format(args.n_gpus)) # Load data if args.dataset_type == 'llff': images, poses, bds, render_poses, i_test, gt_depths = load_llff_data( args.datadir, args.factor, recenter=True, bd_factor=.75, spherify=args.spherify) hwf = poses[0, :3, -1] poses = poses[:, :3, :4] print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) if not isinstance(i_test, list): i_test = [i_test] if args.llffhold > 0: print('Auto LLFF holdout,', args.llffhold) i_test = np.arange(images.shape[0])[::args.llffhold] i_val = i_test i_train = np.array([ i for i in np.arange(int(images.shape[0])) if (i not in i_test and i not in i_val) ]) print('DEFINING BOUNDS') if args.no_ndc: near = np.ndarray.min(bds) * .9 far = np.ndarray.max(bds) * 1. else: near = 0. far = 1. print('NEAR FAR', near, far) DEPTH_RATIO = 0.00335 elif args.dataset_type == 'blender': images, poses, render_poses, hwf, i_split = load_blender_data( args.datadir, args.half_res, args.testskip) print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) i_train, i_val, i_test = i_split near = 2. far = 6. if args.white_bkgd: images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:]) else: images = images[..., :3] elif args.dataset_type == 'deepvoxels': images, poses, render_poses, hwf, i_split = load_dv_data( scene=args.shape, basedir=args.datadir, testskip=args.testskip) print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) i_train, i_val, i_test = i_split hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1)) near = hemi_R - 1. far = hemi_R + 1. else: print('Unknown dataset type', args.dataset_type, 'exiting') return # Cast intrinsics to right types H, W, focal = hwf H, W = int(H), int(W) hwf = [H, W, focal] if args.render_test: render_poses = np.array(poses[i_test]) # Create log dir and copy the config file basedir = args.basedir expname = args.expname os.makedirs(os.path.join(basedir, expname), exist_ok=True) f = os.path.join(basedir, expname, 'args.txt') with open(f, 'w') as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write('{} = {}\n'.format(arg, attr)) if args.config is not None: f = os.path.join(basedir, expname, 'config.txt') with open(f, 'w') as file: file.write(open(args.config, 'r').read()) # Create nerf model render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf( args) global_step = start bds_dict = { 'near': near, 'far': far, } render_kwargs_train.update(bds_dict) render_kwargs_test.update(bds_dict) # Move testing data to GPU render_poses = torch.Tensor(render_poses).to(device) # Short circuit if only rendering out from trained model if args.render_only: print('RENDER ONLY') with torch.no_grad(): if args.render_test: # render_test switches to test poses images = images[i_test] else: # Default is smoother render_poses path images = None testsavedir = os.path.join( basedir, expname, 'renderonly_{}_{:06d}'.format( 'test' if args.render_test else 'path', start)) os.makedirs(testsavedir, exist_ok=True) print('test poses shape', render_poses.shape) rgbs, _, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) print('Done rendering', testsavedir) imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) return # Prepare raybatch tensor if batching random rays N_rand = 
args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
                        0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]], 1)  # [N, ro+rd+rgb, H, W, 3]
        # broadcast each depth map to 3 channels: [N, ro+rd+rgb+depth, H, W, 3]
        gt_depths_channeled = np.zeros(
            (gt_depths.shape[0], gt_depths.shape[1], gt_depths.shape[2], 3))
        for i in range(gt_depths.shape[0]):
            gt_depths_channeled[i] = np.stack([gt_depths[i]] * 3, 2)
        rays_rgb = np.concatenate([rays_rgb, gt_depths_channeled[:, None, ...]], 1)
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb+depth, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)  # train images only
        rays_rgb = np.reshape(rays_rgb, [-1, 4, 3])  # [N_train*H*W, ro+rd+rgb+depth, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    # Move training data to GPU
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    gt_depths = torch.Tensor(gt_depths).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = 200000 + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            # target_depth[n] = observed depth of the sampled pixel
            batch_rays, target_s, target_depth = batch[:2], batch[2], batch[3, :, 0]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0
        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            depth_i = gt_depths[img_i]
            depth_i = depth_i * DEPTH_RATIO  # scale down the depth
            pix_T_cam = pack_intrinsics(4.39947516e+02, 4.39947516e+02, 240, 320)
            # distortion = 2.78915786e-02
            xyz_i = depth2pointcloud(depth_i.unsqueeze(0).unsqueeze(0),
                                     pix_T_cam)  # unproject to get pointcloud
            origin_T_camX = eye_4x4(1)
            origin_T_camX[:, :3, :4] = poses[img_i:img_i + 1, :4, :4]
            xyz_w_i = apply_4x4(origin_T_camX, xyz_i)  # get the world coordinates (NeRF works in C2W)

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
                            torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)), -1)
                    if i == start:
                        print("[Config] Center cropping of size {} x {} is enabled until iter {}"
                              .format(2 * dH, 2 * dW, args.precrop_iters))
                else:
                    coords = torch.stack(
                        torch.meshgrid(torch.linspace(0, H - 1, H),
                                       torch.linspace(0, W - 1, W)), -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)  # (2, N_rand, 3)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                target_depth = depth_i[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 1)
                target_xyz = xyz_w_i[0, select_inds]  # (1, N_rand, 3)
                target_xyz = target_xyz[target_depth > 0.]  # Only select those points where depth is valid

        #####  Core optimization loop  #####
        rgb, disp, acc, depth, extras = render(H, W, focal, chunk=args.chunk,
                                               rays=batch_rays, verbose=i < 10,
                                               retraw=True, **render_kwargs_train)

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][..., -1]
        loss = img_loss
        psnr = mse2psnr(img_loss)
        depth_mask = torch.ones(target_depth[target_depth != 0.].shape)
        avg_depth_ratio = reduce_masked_mean(
            depth[target_depth != 0.] / target_depth[target_depth != 0.], depth_mask)
        free_occ_xyz = fill_ray_single(target_xyz)
        target_labels = torch.zeros(
            (free_occ_xyz.shape[0], free_occ_xyz.shape[1], 1), device='cuda')
        target_labels[:, -1] = 1  # last point along each ray is occupied
        free_occ_xyz = torch.reshape(free_occ_xyz, (-1, 3)).unsqueeze(1)
        target_labels = torch.reshape(target_labels, (-1, 1))
        rays_to_pass = batch_rays[:, target_depth > 0., :]  # (2, N, 3)
        rays_to_pass = rays_to_pass.repeat(1, 100, 1)  # (2, N*samps, 3)
        sigma = eval_depth(pts=free_occ_xyz, rays=rays_to_pass,
                           network_fn=render_kwargs_train['network_fn'],
                           network_query_fn=render_kwargs_train['network_query_fn'])
        occfreespace_loss = F.binary_cross_entropy_with_logits(sigma, target_labels)
        loss += occfreespace_loss

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)
            avg_depth_ratio0 = reduce_masked_mean(
                extras['depth0'][target_depth != 0.] / target_depth[target_depth != 0.],
                depth_mask)

        loss.backward()
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        dt = time.time() - time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####

        # Rest is logging
        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save(
                {
                    'global_step': global_step,
                    'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
                rgbs, disps, depths = render_path(render_poses, hwf, args.chunk,
                                                  render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape, depths.shape)
            moviebase = os.path.join(basedir, expname,
                                     '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)),
                             fps=30, quality=8)
            imageio.mimwrite(moviebase + 'depth.mp4', to8b(depths / np.max(depths)),
                             fps=30, quality=8)

            # if args.use_viewdirs:
            #     render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
            #     with torch.no_grad():
            #         rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            #     render_kwargs_test['c2w_staticcam'] = None
            #     imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', poses[i_test].shape) with torch.no_grad(): render_path(torch.Tensor(poses[i_test]).to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir) print('Saved test set') if i % args.i_print == 0: tqdm.write("[TRAIN] Iter: {} Loss: {} PSNR: {}".format( i, loss.item(), psnr.item())) writer.add_scalar('loss', loss, i) writer.add_scalar('img_loss', img_loss, i) writer.add_scalar('occ_loss', occfreespace_loss, i) writer.add_scalar('psnr', psnr, i) writer.add_histogram('tran', trans, i) writer.add_scalar('avg_depth_ratio', avg_depth_ratio, i) if args.N_importance > 0: writer.add_scalar('psnr0', psnr0, i) if i % args.i_img == 0: # Log a rendered validation view to Tensorboard img_i = np.random.choice(i_val) target = images[img_i] pose = poses[img_i, :3, :4] with torch.no_grad(): rgb, disp, acc, depth, extras = render( H, W, focal, chunk=args.chunk, c2w=pose, **render_kwargs_test) psnr = mse2psnr(img2mse(rgb, target)) writer.add_image('rgb', to8b(rgb.cpu().numpy()), i, dataformats='HWC') writer.add_image('disp', disp.unsqueeze(0), i) writer.add_image('acc', acc.unsqueeze(0), i) writer.add_image('depth', depth.unsqueeze(0), i) writer.add_scalar('psnr_holdout', psnr, i) writer.add_image('rgb_holdout', target, i, dataformats='HWC') global_step += 1
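# A plausible stand-in for the reduce_masked_mean() helper called above (the
# real implementation lives elsewhere in this codebase), assuming it averages
# x over the entries where mask is nonzero. Note that with the all-ones
# depth_mask built above, it reduces to a plain mean of the depth ratios.
import torch

def reduce_masked_mean_sketch(x, mask, eps=1e-8):
    # eps guards against an empty mask producing a divide-by-zero
    return (x * mask).sum() / (mask.sum() + eps)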
def train(): parser = config_parser() args = parser.parse_args() # Load data if args.dataset_type == 'llff': images, poses, bds, render_poses, i_test = load_llff_data( args.datadir, args.factor, recenter=True, bd_factor=.75, spherify=args.spherify) hwf = poses[0, :3, -1] poses = poses[:, :3, :4] print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) if not isinstance(i_test, list): i_test = [i_test] if args.llffhold > 0: print('Auto LLFF holdout,', args.llffhold) i_test = np.arange(images.shape[0])[::args.llffhold] i_val = i_test i_train = np.array([ i for i in np.arange(int(images.shape[0])) if (i not in i_test and i not in i_val) ]) print('DEFINING BOUNDS') if args.no_ndc: near = np.ndarray.min(bds) * .9 far = np.ndarray.max(bds) * 1. else: near = 0. far = 1. print('NEAR FAR', near, far) elif args.dataset_type == 'blender': images, poses, render_poses, hwf, i_split = load_blender_data( args.datadir, args.half_res, args.testskip) print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) i_train, i_val, i_test = i_split near = 2. far = 6. if args.white_bkgd: images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:]) else: images = images[..., :3] elif args.dataset_type == 'deepvoxels': images, poses, render_poses, hwf, i_split = load_dv_data( scene=args.shape, basedir=args.datadir, testskip=args.testskip) print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) i_train, i_val, i_test = i_split hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1)) near = hemi_R - 1. far = hemi_R + 1. else: print('Unknown dataset type', args.dataset_type, 'exiting') return # Cast intrinsics to right types H, W, focal = hwf H, W = int(H), int(W) hwf = [H, W, focal] if args.render_test: render_poses = np.array(poses[i_test]) # Create log dir and copy the config file basedir = args.basedir expname = args.expname os.makedirs(os.path.join(basedir, expname), exist_ok=True) f = os.path.join(basedir, expname, 'args.txt') with open(f, 'w') as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write('{} = {}\n'.format(arg, attr)) if args.config is not None: f = os.path.join(basedir, expname, 'config.txt') with open(f, 'w') as file: file.write(open(args.config, 'r').read()) # Create nerf model render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf( args) global_step = start bds_dict = { 'near': near, 'far': far, } render_kwargs_train.update(bds_dict) render_kwargs_test.update(bds_dict) # Move testing data to GPU render_poses = torch.Tensor(render_poses).to(device) # Short circuit if only rendering out from trained model if args.render_only: print('RENDER ONLY') with torch.no_grad(): if args.render_test: # render_test switches to test poses images = images[i_test] else: # Default is smoother render_poses path images = None testsavedir = os.path.join( basedir, expname, 'renderonly_{}_{:06d}'.format( 'test' if args.render_test else 'path', start)) os.makedirs(testsavedir, exist_ok=True) print('test poses shape', render_poses.shape) rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) print('Done rendering', testsavedir) imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) return # Prepare raybatch tensor if batching random rays N_rand = args.N_rand use_batching = not args.no_batching if use_batching: # For random ray batching print('get rays') rays = np.stack( [get_rays_np(H, 
W, focal, p) for p in poses[:, :3, :4]], 0) # [N, ro+rd, H, W, 3] print('done, concats') rays_rgb = np.concatenate([rays, images[:, None]], 1) # [N, ro+rd+rgb, H, W, 3] rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) # [N, H, W, ro+rd+rgb, 3] rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]) # [(N-1)*H*W, ro+rd+rgb, 3] rays_rgb = rays_rgb.astype(np.float32) print('shuffle rays') np.random.shuffle(rays_rgb) print('done') i_batch = 0 # Move training data to GPU images = torch.Tensor(images).to(device) poses = torch.Tensor(poses).to(device) if use_batching: rays_rgb = torch.Tensor(rays_rgb).to(device) print('Begin') print('TRAIN views are', i_train) print('TEST views are', i_test) print('VAL views are', i_val) # Summary writers # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) start = start + 1 for i in trange(start, args.N_iters + 1): time0 = time.time() # Sample random ray batch if use_batching: # Random over all images batch = rays_rgb[i_batch:i_batch + N_rand] # [B, 2+1, 3*?] batch = torch.transpose(batch, 0, 1) batch_rays, target_s = batch[:2], batch[2] i_batch += N_rand if i_batch >= rays_rgb.shape[0]: print("Shuffle data after an epoch!") rand_idx = torch.randperm(rays_rgb.shape[0]) rays_rgb = rays_rgb[rand_idx] i_batch = 0 else: # Random from one image img_i = np.random.choice(i_train) target = images[img_i] pose = poses[img_i, :3, :4] if N_rand is not None: rays_o, rays_d = get_rays( H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) if i < args.precrop_iters: dH = int(H // 2 * args.precrop_frac) dW = int(W // 2 * args.precrop_frac) coords = torch.stack( torch.meshgrid( torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH), torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)), -1) if i == start: print( f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}" ) else: coords = torch.stack( torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)), -1) # (H, W, 2) coords = torch.reshape(coords, [-1, 2]) # (H * W, 2) select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) select_coords = coords[select_inds].long() # (N_rand, 2) rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) batch_rays = torch.stack([rays_o, rays_d], 0) target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) ##### Core optimization loop ##### rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays, verbose=i < 10, retraw=True, **render_kwargs_train) optimizer.zero_grad() img_loss = img2mse(rgb, target_s) trans = extras['raw'][..., -1] loss = img_loss psnr = mse2psnr(img_loss) if 'rgb0' in extras: img_loss0 = img2mse(extras['rgb0'], target_s) loss = loss + img_loss0 psnr0 = mse2psnr(img_loss0) loss.backward() optimizer.step() # NOTE: IMPORTANT! 
### update learning rate ### decay_rate = 0.1 decay_steps = args.lrate_decay * 1000 new_lrate = args.lrate * (decay_rate**(global_step / decay_steps)) for param_group in optimizer.param_groups: param_group['lr'] = new_lrate ################################ dt = time.time() - time0 # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") ##### end ##### # Rest is logging if i % args.i_weights == 0: path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) torch.save( { 'global_step': global_step, 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 'optimizer_state_dict': optimizer.state_dict(), }, path) print('Saved checkpoints at', path) if i % args.i_video == 0 and i > 0: # Turn on testing mode with torch.no_grad(): rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test) print('Done, saving', rgbs.shape, disps.shape) moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) # if args.use_viewdirs: # render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4] # with torch.no_grad(): # rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test) # render_kwargs_test['c2w_staticcam'] = None # imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8) if i % args.i_testset == 0 and i > 0: testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) os.makedirs(testsavedir, exist_ok=True) print('test poses shape', poses[i_test].shape) with torch.no_grad(): render_path(torch.Tensor(poses[i_test]).to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir) print('Saved test set') if i % args.i_print == 0: tqdm.write( f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") global_step += 1
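# A small side sketch of the coordinate grid built in the sampling code above;
# newer PyTorch (>= 1.10, an assumption) wants an explicit indexing argument to
# torch.meshgrid, and the loops above rely on the default 'ij' (row-major)
# behavior, i.e. coords[r*W + c] == (r, c).
import torch

ys, xs = torch.meshgrid(torch.linspace(0, 3, 4), torch.linspace(0, 3, 4),
                        indexing='ij')
coords = torch.stack([ys, xs], -1).reshape(-1, 2)  # (H*W, 2) rows of (row, col)
assert coords[5].tolist() == [1., 1.]  # flat index 5 -> (row 1, col 1)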
def get_data():
    basedir = './logs'
    expname = 'fern_example'
    config = os.path.join(basedir, expname, 'config.txt')
    print('Args:')
    print(open(config, 'r').read())
    parser = run_nerf.config_parser()
    weights_name = 'model_200000.npy'
    args = parser.parse_args('--config {} --ft_path {}'.format(
        config, os.path.join(basedir, expname, weights_name)))
    print('loaded args')
    images, poses, bds, render_poses, i_test = load_llff_data(
        args.datadir, args.factor, recenter=True, bd_factor=.75,
        spherify=args.spherify)
    H, W, focal = poses[0, :3, -1].astype(np.float32)
    poses = poses[:, :3, :4]
    H = int(H)
    W = int(W)
    hwf = [H, W, focal]
    images = images.astype(np.float32)
    poses = poses.astype(np.float32)
    near = 0.
    far = 1.
    if not isinstance(i_test, list):
        i_test = [i_test]
    if args.llffhold > 0:
        print('Auto LLFF holdout,', args.llffhold)
        i_test = np.arange(images.shape[0])[::args.llffhold]
    _, render_kwargs, start, grad_vars, models = run_nerf.create_nerf(args)
    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs.update(bds_dict)
    print('Render kwargs:')
    pprint.pprint(render_kwargs)

    results = {}
    results['pc'] = {}
    results['no_pc'] = {}
    # NOTE: Where to output results!
    result_directory = "./fern_pc_results"
    img_dir = os.path.join(result_directory, "imgs")
    os.makedirs(img_dir, exist_ok=True)  # create the output tree before saving images
    down = 1
    plt.imsave(os.path.join(img_dir, f"GT{i_test[0]}.png"), images[i_test[0]])
    plt.imsave(os.path.join(img_dir, f"GT{i_test[1]}.png"), images[i_test[1]])
    for num_samps in [4, 8, 16, 32, 64]:
        print(f'Running {num_samps} sample test')
        for pc in [True, False]:
            print(f'{"not " if not pc else ""}using pc')
            results['pc' if pc else 'no_pc'][num_samps] = {}
            render_kwargs['N_samples'] = num_samps
            render_kwargs['N_importance'] = 2 * num_samps
            total_time = 0
            total_mse = 0
            total_psnr = 0
            for i in [i_test[0], i_test[1]]:
                gt = images[i]
                start_time = time.time()
                ret_vals = run_nerf.render(H // down, W // down, focal / down,
                                           c2w=poses[i], pc=pc, cloudsize=16,
                                           **render_kwargs)
                end_time = time.time()
                # add to cum time
                total_time += (end_time - start_time)
                # add to accuracy
                img = np.clip(ret_vals[0], 0, 1)
                # TODO: make sure this is commented out for real results (just used to test that it runs)
                # mse = run_nerf.img2mse(np.zeros((H//down, W//down,3), dtype=np.float32), img)
                mse = run_nerf.img2mse(gt, img)
                psnr = run_nerf.mse2psnr(mse)
                total_mse += float(mse)
                total_psnr += float(psnr)
                plt.imsave(os.path.join(img_dir, f'IMG{i}_{"pc" if pc else "no_pc"}_{num_samps}samples.png'), img)
            total_time /= 2.
            total_mse /= 2.
            total_psnr /= 2.
            results['pc' if pc else 'no_pc'][num_samps]['time'] = total_time
            results['pc' if pc else 'no_pc'][num_samps]['mse'] = total_mse
            results['pc' if pc else 'no_pc'][num_samps]['psnr'] = total_psnr
    with open(os.path.join(result_directory, 'results.txt'), 'w') as outfile:
        json.dump(results, outfile)
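# A minimal follow-up sketch for the results.txt written above: it is JSON
# keyed by 'pc'/'no_pc' and then by sample count (JSON keys become strings,
# hence key=int), so the speed/quality trade-off can be tabulated afterwards.
import json
import os

def compare_pc_results(result_directory='./fern_pc_results'):
    with open(os.path.join(result_directory, 'results.txt')) as f:
        results = json.load(f)
    for num_samps in sorted(results['pc'], key=int):
        pc, no_pc = results['pc'][num_samps], results['no_pc'][num_samps]
        print(f"{num_samps:>3} samples: pc {pc['psnr']:.2f} dB / {pc['time']:.1f}s,"
              f" no_pc {no_pc['psnr']:.2f} dB / {no_pc['time']:.1f}s")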
def train():
    parser = config_parser()
    args = parser.parse_args()

    # Load data
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0]))
                            if (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3]*images[..., -1:] + (1.-images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        # use the loaded step count here; global_step is not created until later
        testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format(
            'test' if args.render_test else 'path', start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test,
                              gt_imgs=images, savedir=testsavedir,
                              render_factor=args.render_factor)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs), fps=30, quality=8)
        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(
            lrate, decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer

    global_step = 
tf.compat.v1.train.get_or_create_global_step() global_step.assign(start) # Prepare raybatch tensor if batching random rays N_rand = args.N_rand use_batching = not args.no_batching if use_batching: # For random ray batching print('get rays') rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] print('done, concats') rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3] rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3] rays_rgb = rays_rgb.astype(np.float32) print('shuffle rays') np.random.shuffle(rays_rgb) print('done') i_batch = 0 N_iters = 1000000 print('Begin') print('TRAIN views are', i_train) print('TEST views are', i_test) print('VAL views are', i_val) # Summary writers writer = tf.contrib.summary.create_file_writer(os.path.join(basedir, 'summaries', expname)) writer.set_as_default() for i in range(start, N_iters): time0 = time.time() # Sample random ray batch if use_batching: # Random over all images batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?] batch = tf.transpose(batch, [1,0,2]) batch_rays, target_s = batch[:2], batch[2] i_batch += N_rand if i_batch >= rays_rgb.shape[0]: np.random.shuffle(rays_rgb) i_batch = 0 else: # Random from one image img_i = np.random.choice(i_train) target = images[img_i] pose = poses[img_i, :3,:4] if N_rand is not None: rays_o, rays_d = get_rays(H, W, focal, pose) coords = tf.stack(tf.meshgrid(tf.range(H), tf.range(W), indexing='ij'), -1) coords = tf.reshape(coords, [-1,2]) select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) select_inds = tf.gather_nd(coords, select_inds[:,tf.newaxis]) rays_o = tf.gather_nd(rays_o, select_inds) rays_d = tf.gather_nd(rays_d, select_inds) batch_rays = tf.stack([rays_o, rays_d], 0) target_s = tf.gather_nd(target, select_inds) ##### Core optimization loop ##### with tf.GradientTape() as tape: rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays, verbose=i < 10, retraw=True, **render_kwargs_train) img_loss = img2mse(rgb, target_s) trans = extras['raw'][...,-1] loss = img_loss psnr = mse2psnr(img_loss) if 'rgb0' in extras: img_loss0 = img2mse(extras['rgb0'], target_s) loss += img_loss0 psnr0 = mse2psnr(img_loss0) gradients = tape.gradient(loss, grad_vars) optimizer.apply_gradients(zip(gradients, grad_vars)) dt = time.time()-time0 ##### end ##### # Rest is logging def save_weights(net, prefix, i): path = os.path.join(basedir, expname, '{}_{:06d}.npy'.format(prefix, i)) np.save(path, net.get_weights()) print('saved weights at', path) if i%args.i_weights==0: for k in models: save_weights(models[k], k, i) if i%args.i_video==0 and i > 0: rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test) print('Done, saving', rgbs.shape, disps.shape) moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) if args.use_viewdirs: render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4] rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test) render_kwargs_test['c2w_staticcam'] = None imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8) if i%args.i_testset==0 and i > 0: testsavedir = 
os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) os.makedirs(testsavedir, exist_ok=True) print('test poses shape', poses[i_test].shape) render_path(poses[i_test], hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir) print('Saved test set') if i%args.i_print==0 or i < 10: print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy()) print('iter time {:.05f}'.format(dt)) with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print): tf.contrib.summary.scalar('loss', loss) tf.contrib.summary.scalar('psnr', psnr) tf.contrib.summary.histogram('tran', trans) if args.N_importance > 0: tf.contrib.summary.scalar('psnr0', psnr0) if i%args.i_img==0: # Log a rendered validation view to Tensorboard img_i=np.random.choice(i_val) target = images[img_i] pose = poses[img_i, :3,:4] rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose, **render_kwargs_test) psnr = mse2psnr(img2mse(rgb, target)) with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis]) tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis]) tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis]) tf.contrib.summary.scalar('psnr_holdout', psnr) tf.contrib.summary.image('rgb_holdout', target[tf.newaxis]) if args.N_importance > 0: with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis]) tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis]) tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis]) global_step.assign_add(1)
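# A minimal sketch of the schedule built above: ExponentialDecay with these
# arguments traces the same curve as the manual per-step update used in the
# PyTorch variants in this document, lrate(step) = lrate * 0.1 ** (step /
# (lrate_decay * 1000)). The numbers below are example values only.
import tensorflow as tf

sched = tf.keras.optimizers.schedules.ExponentialDecay(
    5e-4, decay_steps=250 * 1000, decay_rate=0.1)  # example lrate / lrate_decay
print(float(sched(250000)))  # ~5e-5: one full decay period drops the rate 10x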
def train():
    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data
    sc = 1.
    center = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test, center = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            # i_test = np.arange(images.shape[0])[::args.llffhold]
            i_test = np.arange(images.shape[0])[62::3]

        i_val = i_test
        # train on every view, including the held-out ones
        i_train = np.array([i for i in np.arange(int(images.shape[0]))])
        # i_train = np.array([i for i in np.arange(int(images.shape[0])) if
        #                     (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
            sc = 1. / (bds.min() * .75)
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3]*images[..., -1:] + (1.-images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    if args.render_test:
        render_poses = np.array(poses[i_test])

    if args.render_poses:
        # load the render path from file
        render_poses = np.loadtxt(os.path.join(args.datadir, "render_poses.txt"))
        if render_poses.shape[0] >= 15:
            # integer division: np.reshape needs int dims (a float raises a TypeError)
            render_poses = np.reshape(
                render_poses, (render_poses.shape[0] // 15, 3, 5)).astype(np.float32)
        else:
            render_poses = np.reshape(
                render_poses, (render_poses.shape[0], 3, 5)).astype(np.float32)
        # Correct rotation matrix ordering and move variable dim to axis 0
        render_poses = np.concatenate(
            [render_poses[:, 1:2, :], -render_poses[:, 0:1, :], render_poses[:, 2:, :]], 1)
        render_poses = np.moveaxis(render_poses, -1, 0).astype(np.float32)
        # Rescale if bd_factor is provided; sc was computed above the same way
        # load_llff.py does: sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor)
        print("sc", sc)
        render_poses[:, :3, 3] *= sc  # 1.333, output from load_llff.py
        # render_poses = recenter_poses(render_poses)  # should recenter according to

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        folder = 'path'
        if args.render_test:
            folder = 'test'
        elif args.render_poses:
            folder = 'poses'
        # testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format(
        #     'test' if args.render_test else 'path', start))
        testsavedir = os.path.join(basedir, expname,
                                   'renderonly_{}_{:06d}'.format(folder, start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test,
                              gt_imgs=images, savedir=testsavedir,
                              render_factor=args.render_factor, ndc=args.NDC)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs), fps=30, quality=8)
        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(
            lrate, decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching.
        #
        # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where axis=1 is
        # interpreted as,
        #   axis=0: ray origin in world space
        #   axis=1: ray direction in world space
        #   axis=2: observed RGB color of pixel
        print('get rays')
        # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3]
        # for each pixel in the image. This stack() adds a new dimension.
        rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None, ...]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], axis=0)  # train images only
        # [N_train*H*W, ro+rd+rgb, 3] (every view is in i_train here)
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.contrib.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))
    writer.set_as_default()

    for i in range(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1, 0, 2])

            # batch_rays[i, n, xyz] = ray origin or direction, example_id, 3D position
            # target_s[n, rgb] = example_id, observed color.
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                np.random.shuffle(rays_rgb)
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, pose)
                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H // 2 - dH, H // 2 + dH),
                        tf.range(W // 2 - dW, W // 2 + dW),
                        indexing='ij'), -1)
                    if i < 10:
                        print('precrop', dH, dW, coords[0, 0], coords[-1, -1])
                else:
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H), tf.range(W), indexing='ij'), -1)
                coords = tf.reshape(coords, [-1, 2])
                select_inds = np.random.choice(
                    coords.shape[0], size=[N_rand], replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)

        ##### Core optimization loop #####
        with tf.GradientTape() as tape:
            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(
                H, W, focal, chunk=args.chunk, rays=batch_rays, ndc=args.NDC,
                verbose=i < 10, retraw=True, **render_kwargs_train)

            # Compute MSE loss between predicted and true RGB.
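            # img2mse and mse2psnr are helpers defined elsewhere in this
            # script; a minimal sketch of their assumed (standard NeRF)
            # definitions:
            #   img2mse = lambda x, y: tf.reduce_mean(tf.square(x - y))
            #   mse2psnr = lambda x: -10. * tf.math.log(x) / tf.math.log(10.)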
            img_loss = img2mse(rgb, target_s)
            trans = extras['raw'][..., -1]
            loss = img_loss
            psnr = mse2psnr(img_loss)

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                loss += img_loss0
                psnr0 = mse2psnr(img_loss0)

        gradients = tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))

        dt = time.time() - time0
        ##### end #####

        # Rest is logging

        def save_weights(net, prefix, i):
            path = os.path.join(
                basedir, expname, '{}_{:06d}.npy'.format(prefix, i))
            np.save(path, net.get_weights())
            print('saved weights at', path)

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

        if i % args.i_video == 0 and i > 0:
            rgbs, disps = render_path(
                render_poses, hwf, args.chunk, render_kwargs_test, ndc=args.NDC)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(
                basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)), fps=30, quality=8)

            if args.use_viewdirs:
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(
                    render_poses, hwf, args.chunk, render_kwargs_test, ndc=args.NDC)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(
                basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test], hwf, args.chunk, render_kwargs_test,
                        gt_imgs=images[i_test], savedir=testsavedir, ndc=args.NDC)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:
            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)

            if i % args.i_img == 0:
                # Log a rendered validation view to Tensorboard
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]

                rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk,
                                                c2w=pose, ndc=args.NDC,
                                                **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))

                # Save out the validation image for Tensorboard-free monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                if i == 0:
                    os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(os.path.join(
                    testimgdir, '{:06d}.png'.format(i)), to8b(rgb))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image(
                        'disp', disp[tf.newaxis, ..., tf.newaxis])
                    tf.contrib.summary.image(
                        'acc', acc[tf.newaxis, ..., tf.newaxis])

                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])

                if args.N_importance > 0:
                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image(
                            'rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image(
                            'disp0', extras['disp0'][tf.newaxis, ..., tf.newaxis])
                        tf.contrib.summary.image(
                            'z_std', extras['z_std'][tf.newaxis, ..., tf.newaxis])

        global_step.assign_add(1)
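
# A minimal entry-point sketch, assuming this script is invoked directly (as
# in the reference run_nerf.py); all settings come from config_parser(), e.g.
#   python run_nerf.py --config config_fern.txt --render_only --render_test
if __name__ == '__main__':
    train()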