def load_model(instance, config, expname=None):
    '''Returns the NeRF render kwargs, optimizer, and per-instance style codes.'''
    parser = config_parser()
    if expname is None:
        args = parser.parse_args(['--config', config])
    else:
        args = parser.parse_args(['--config', config, '--expname', expname])
    render_kwargs_train, render_kwargs_test, _, _, optimizer, styles = create_nerf(
        args, return_styles=True)
    if styles.shape[0] == 1:
        # A single-instance checkpoint only stores one style code; fall back to
        # the base experiment's codes for the remaining instances.
        basedir_args = parser.parse_args(['--config', config])
        basedir_styles = create_nerf(basedir_args, return_styles=True)[-1]
        styles = torch.cat([styles[:1], basedir_styles[1:]])
    return render_kwargs_train, render_kwargs_test, optimizer, styles

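# Minimal usage sketch (the config path below is hypothetical; any config
# accepted by config_parser() works):
#
#   _, render_kwargs_test, optimizer, styles = load_model(0, 'configs/example.txt')
#   style = styles[0:1]  # (1, style_dim) code for instance 0
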
def load_dataset(instance, config, N_instances, num_canvases=9, expname=None,
                 use_cached=True):
    parser = config_parser()
    if expname is None:
        args = parser.parse_args(['--config', config])
    else:
        args = parser.parse_args(['--config', config, '--expname', expname])
        N_instances = 1
    poses, hwfs = get_poses_hwfs(args, instance, N_instances,
                                 num_canvases=num_canvases)
    if use_cached:
        cache = {k: [] for k in ['alphas', 'features', 'weights']}
        # You can choose to save the cache on disk. Saving it takes a while,
        # and loading it takes a while the first time, but this is faster if
        # you're editing the same instance frequently.
        # cache_dir = os.path.join(args.expname, 'cache')
        # if not os.path.exists(f'{cache_dir}/{instance}_{num_canvases-1}.pt'):
        #     alphas, features, weights = get_cache(config, instance, expname, save=True)
        #     cache['alphas'] = alphas[:num_canvases]
        #     cache['features'] = features[:num_canvases]
        #     cache['weights'] = weights[:num_canvases]
        # else:
        #     for j in range(len(poses)):
        #         data = torch.load(f'{cache_dir}/{instance}_{j}.pt')
        #         cache['alphas'].append(data['alphas'].cuda())
        #         cache['features'].append(data['features'].cuda())
        #         cache['weights'].append(data['weights'].cuda())

        # Computes and loads the cache in memory. Computing the cache doesn't
        # take noticeably longer than rendering.
        alphas, features, weights = get_cache(config, instance, expname,
                                              num_canvases)
        cache['alphas'] = alphas[:num_canvases]
        cache['features'] = features[:num_canvases]
        cache['weights'] = weights[:num_canvases]
    else:
        cache = None
    return poses, hwfs, cache, args

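# Minimal usage sketch (hypothetical config path; N_instances should match
# the number of instances the checkpoint was trained on):
#
#   poses, hwfs, cache, args = load_dataset(0, 'configs/example.txt',
#                                           N_instances=100, num_canvases=9)
#   # cache['alphas'], cache['features'], cache['weights'] each hold one
#   # entry per canvas view, ready for reuse by the editing loop.
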
def get_cache(config, instance, expname=None, num_canvases=9, save=False):
    '''Renders one instance and returns its per-view alphas, features, and weights.'''
    parser = config_parser()
    if expname is None:
        args = parser.parse_args(['--config', config])
        N_instances = None
    else:
        args = parser.parse_args(['--config', config, '--expname', expname])
        N_instances = 1
    cachedir = os.path.join(args.expname, 'cache')
    with torch.no_grad():
        render_kwargs_train, render_kwargs_test, _, _, optimizer, styles = create_nerf(
            args, return_styles=True)
        if N_instances is None:
            N_instances = styles.shape[0]
        nfs = [[args.blender_near, args.blender_far] for _ in range(10)]
        os.makedirs(cachedir, exist_ok=True)
        poses, hwfs = get_poses_hwfs(args, instance, N_instances, num_canvases)
        style = styles[instance].repeat((poses.shape[0], 1))
        rgbs, disps, _, alphas, features, weights = render_path(
            poses.cuda(), style, hwfs, args.chunk, render_kwargs_test,
            nfs=nfs, verbose=False, get_cached='color')
        if save:
            for j, (a, f, w) in enumerate(zip(alphas, features, weights)):
                torch.save({
                    'alphas': a,
                    'features': f,
                    'weights': w
                }, f'{cachedir}/{instance}_{j}.pt')
        utils.save_image(
            torch.tensor(rgbs[:1]).permute(0, 3, 1, 2).cpu(),
            os.path.join(args.expname, '{:03d}.png'.format(instance)))
    return alphas, features, weights

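# Sketch of the on-disk round trip (paths follow the f-strings above;
# '<expname>/cache' is created by get_cache itself):
#
#   alphas, features, weights = get_cache('configs/example.txt', 0, save=True)
#   data = torch.load('<expname>/cache/0_0.pt')  # view 0 of instance 0
#   # data['alphas'] matches alphas[0]; same for 'features' and 'weights'
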
def train():
    parser = config_parser()
    args = parser.parse_args()

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.savedir if args.savedir else args.expname
    print('Experiment dir:', expname)

    # Load data
    images, poses, style, i_test, i_train, bds_dict, dataset, hwfs, near_fars, style_inds = load_data(
        args)
    _, poses_test, style_test, hwfs_test, nf_test = images[i_test], poses[
        i_test], style[i_test], hwfs[i_test], near_fars[i_test]
    _, poses_train, style_train, hwfs_train, nf_train = images[i_train], poses[
        i_train], style[i_train], hwfs[i_train], near_fars[i_train]

    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    np.save(os.path.join(basedir, expname, 'poses.npy'), poses_train.cpu())
    np.save(os.path.join(basedir, expname, 'hwfs.npy'), hwfs_train.cpu())
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    print(render_kwargs_train['network_fine'])
    old_coarse_network = copy.deepcopy(
        render_kwargs_train['network_fn']).state_dict()
    old_fine_network = copy.deepcopy(
        render_kwargs_train['network_fine']).state_dict()

    global_step = start
    real_image_application = (args.real_image_dir is not None)
    optimize_mlp = not real_image_application

    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    loss = None
    if start == 0:
        # If we're starting from scratch, delete all the logs in that directory.
        if os.path.exists(os.path.join(basedir, expname, 'log.txt')):
            os.remove(os.path.join(basedir, expname, 'log.txt'))

    start = start + 1
    for i in range(start, args.n_iters + 1):
        # Sample random ray batch
        batch_rays, target_s, style, H, W, focal, near, far, viewdirs_reg = dataset.get_data_batch(
            train_fn=render_kwargs_train, optimizer=optimizer, loss=loss)
        render_kwargs_train.update({'near': near, 'far': far})

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H,
                                        W,
                                        focal,
                                        style=style,
                                        chunk=args.chunk,
                                        rays=batch_rays,
                                        viewdirs_reg=viewdirs_reg,
                                        **render_kwargs_train)

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if args.var_param > 0:
            var = extras['var']
            var0 = extras['var0']
            var_loss = var.mean(dim=0)
            var_loss_coarse = var0.mean(dim=0)
            loss += args.var_param * var_loss
            loss += args.var_param * var_loss_coarse
            var_loss = var_loss.item()
            var_loss_coarse = var_loss_coarse.item()
        else:
            var_loss = 0
            var_loss_coarse = 0

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0).item()
        else:
            psnr0 = -1

        if args.weight_change_param >= 0:
            weight_change_loss_coarse = 0.
            for k, v in render_kwargs_train['network_fn'].named_parameters():
                if 'weight' in k:
                    diff = (old_coarse_network[k] - v).pow(2).mean()
                    weight_change_loss_coarse += diff
            weight_change_loss_fine = 0.
            for k, v in render_kwargs_train['network_fine'].named_parameters():
                if 'weight' in k:
                    diff = (old_fine_network[k] - v).pow(2).mean()
                    weight_change_loss_fine += diff
            weight_change_loss = weight_change_loss_coarse + weight_change_loss_fine
            loss = loss + args.weight_change_param * weight_change_loss
        else:
            weight_change_loss = torch.tensor(0.)

        loss.backward()
        if optimize_mlp:
            optimizer.step()

        # NOTE: IMPORTANT!
        ###   Update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        #####  end  #####

        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            state_dict = {
                'global_step': global_step,
                'network_fn_state_dict':
                    render_kwargs_train['network_fn'].state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'styles': dataset.style,
                'style_optimizer': dataset.style_optimizer.state_dict()
            }
            if args.N_importance > 0:
                state_dict['network_fine_state_dict'] = render_kwargs_train[
                    'network_fine'].state_dict()
            torch.save(state_dict, path)
            print('Saved checkpoints at', path)

        if i % args.i_testset == 0 and i > 0:
            if real_image_application:
                style_test = dataset.get_features().repeat(
                    (poses_test.shape[0], 1))
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            with torch.no_grad():
                render_path(poses_test.to(device),
                            style_test,
                            hwfs_test,
                            args.chunk,
                            render_kwargs_test,
                            nfs=nf_test,
                            savedir=testsavedir,
                            maximum=100)
            print('Saved test set')

        if i % args.i_trainset == 0 and i > 0:
            if real_image_application:
                style_train = dataset.get_features().repeat(
                    (poses_train.shape[0], 1))
            trainsavedir = os.path.join(basedir, expname,
                                        'trainset_{:06d}'.format(i))
            os.makedirs(trainsavedir, exist_ok=True)
            with torch.no_grad():
                render_path(poses_train.to(device),
                            style_train,
                            hwfs_train,
                            args.chunk,
                            render_kwargs_test,
                            nfs=nf_train,
                            savedir=trainsavedir,
                            maximum=100)
            print('Saved train set')

        if i % args.i_print == 0 or i == 1:
            log_str = f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()} PSNR0: {psnr0} Var loss: {var_loss} Var loss coarse: {var_loss_coarse} Weight change loss: {weight_change_loss}"
            with open(os.path.join(basedir, expname, 'log.txt'), 'a+') as f:
                f.write(log_str + '\n')
            print(log_str)

        global_step += 1
        if real_image_application and global_step - start == args.n_iters_real:
            return
        if real_image_application and global_step - start == args.n_iters_code_only:
            optimize_mlp = True
            dataset.optimizer_name = 'adam'
            dataset.style_optimizer = torch.optim.Adam(dataset.params,
                                                       lr=dataset.lr)
            print('Starting to jointly optimize weights with code')

log_str = f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()} PSNR0: {psnr0} Var loss: {var_loss} Var loss coarse: {var_loss_coarse} Weight change loss: {weight_change_loss}" with open(os.path.join(basedir, expname, 'log.txt'), 'a+') as f: f.write(log_str + '\n') print(log_str) global_step += 1 if real_image_application and global_step - start == args.n_iters_real: return if real_image_application and global_step - start == args.n_iters_code_only: optimize_mlp = True dataset.optimizer_name = 'adam' dataset.style_optimizer = torch.optim.Adam(dataset.params, lr=dataset.lr) print('Starting to jointly optimize weights with code') if __name__ == '__main__': parser = config_parser() args = parser.parse_args() if args.instance != -1: # Allows for scripting over single instance experiments. exit_if_job_done(os.path.join(args.basedir, args.expname)) torch.set_default_tensor_type('torch.cuda.FloatTensor') train() mark_job_done(os.path.join(args.basedir, args.expname)) else: torch.set_default_tensor_type('torch.cuda.FloatTensor') train()
def test():
    parser = config_parser()
    args = parser.parse_args()

    images, poses, style, i_test, i_train, bds_dict, dataset, hwfs, near_fars, _ = load_data(
        args)
    images_test, poses_test, style_test, hwfs_test, nf_test = images[
        i_test], poses[i_test], style[i_test], hwfs[i_test], near_fars[i_test]
    images_train, poses_train, style_train, hwfs_train, nf_train = images[
        i_train], poses[i_train], style[i_train], hwfs[i_train], near_fars[
            i_train]

    # Resolve the experiment directory
    basedir = args.basedir
    expname = args.expname

    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    np.save(os.path.join(basedir, expname, 'poses.npy'), poses_train.cpu())
    np.save(os.path.join(basedir, expname, 'hwfs.npy'), hwfs_train.cpu())
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    with torch.no_grad():
        if args.render_test:
            if args.shuffle_poses:
                print('Shuffling test poses')
                permutation = list(range(len(poses_test)))
                random.shuffle(permutation)
                poses_test = poses_test[permutation]
            testsavedir = os.path.join(basedir, expname,
                                       'test_imgs{:06d}'.format(start))
            os.makedirs(testsavedir, exist_ok=True)
            _, _, psnr = render_path(poses_test.to(device),
                                     style_test,
                                     hwfs_test,
                                     args.chunk,
                                     render_kwargs_test,
                                     nfs=nf_test,
                                     gt_imgs=images_test,
                                     savedir=testsavedir)
            print('Saved test set w/ psnr', psnr)
        if args.render_train:
            if args.shuffle_poses:
                print('Shuffling train poses')
                permutation = list(range(len(poses_train)))
                random.shuffle(permutation)
                poses_train = poses_train[permutation]
            trainsavedir = os.path.join(basedir, expname,
                                        'train_imgs{:06d}'.format(start))
            os.makedirs(trainsavedir, exist_ok=True)
            _, _, psnr = render_path(poses_train.to(device),
                                     style_train,
                                     hwfs_train,
                                     args.chunk,
                                     render_kwargs_test,
                                     nfs=nf_train,
                                     gt_imgs=images_train,
                                     savedir=trainsavedir)
            print('Saved train set w/ psnr', psnr)

def transfer_codes(config):
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    N_per_transfer = 8
    N_transfer = 10
    parser = config_parser()
    args = parser.parse_args(['--config', config])
    render_kwargs_train, render_kwargs_test, _, _, optimizer, styles = create_nerf(
        args, return_styles=True)
    nfs = [[args.blender_near, args.blender_far]
           for _ in range(N_per_transfer)]
    os.makedirs(f'{args.expname}/transfer_codes', exist_ok=True)
    with torch.no_grad():
        for i1 in range(N_transfer):
            for i2 in tqdm(range(N_transfer)):
                if i1 == i2:
                    continue
                s1, s2 = styles[i1].unsqueeze(dim=0), styles[i2].unsqueeze(
                    dim=0)
                poses_1, hwfs_1 = get_poses_hwfs(args, i1, styles.shape[0],
                                                 num_canvases=N_per_transfer)
                poses_2, hwfs_2 = get_poses_hwfs(args, i2, styles.shape[0],
                                                 num_canvases=N_per_transfer)
                rgb1 = render_path(poses_1, s1, hwfs_1, 4096,
                                   render_kwargs_test, nfs=nfs,
                                   verbose=False)[0][0]
                rgb2 = render_path(poses_2, s2, hwfs_2, 4096,
                                   render_kwargs_test, nfs=nfs,
                                   verbose=False)[0][0]
                utils.save_image(
                    torch.tensor(rgb1).permute(2, 0, 1),
                    f'{args.expname}/transfer_codes/{i1}.png')
                utils.save_image(
                    torch.tensor(rgb2).permute(2, 0, 1),
                    f'{args.expname}/transfer_codes/{i2}.png')

                take_color = torch.cat([s1[:, :32], s2[:, 32:]], dim=1)
                take_shape = torch.cat([s2[:, :32], s1[:, 32:]], dim=1)
                # i1 with shape from i2 is the same as i2 with color from i1,
                # so don't duplicate.
                color_from = render_path(poses_1,
                                         take_color.repeat(
                                             (N_per_transfer, 1)),
                                         hwfs_1, 4096, render_kwargs_test,
                                         nfs=nfs, verbose=False)[0]
                shape_from = render_path(poses_2,
                                         take_shape.repeat(
                                             (N_per_transfer, 1)),
                                         hwfs_2, 4096, render_kwargs_test,
                                         nfs=nfs, verbose=False)[0]
                utils.save_image(
                    torch.tensor(color_from).permute(0, 3, 1, 2),
                    f'{args.expname}/transfer_codes/{i1}_color_from_{i2}.png')
                utils.save_image(
                    torch.tensor(shape_from).permute(0, 3, 1, 2),
                    f'{args.expname}/transfer_codes/{i1}_shape_from_{i2}.png')

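# The concatenations above imply a style-code layout where the first 32
# dimensions carry shape and the remainder carries color (take_color keeps
# i1's first 32 dims and swaps in i2's tail, and is saved as
# '{i1}_color_from_{i2}.png'). A minimal sketch of composing a hybrid code
# under that assumption:
#
#   hybrid = torch.cat([styles[i1][:32], styles[i2][32:]]).unsqueeze(0)
#   # renders instance i1's shape with instance i2's color
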
def load_config(config):
    parser = config_parser()
    return parser.parse_args(['--config', config])