def load_model(instance, config, expname=None):
    """Build the NeRF described by a config file and return it with its style codes.

    Args:
        instance: Not read by this function; kept so existing callers still work.
        config: Path to the config file, forwarded to the parser as ``--config``.
        expname: Optional experiment name. When given it is forwarded as
            ``--expname`` in addition to the config.

    Returns:
        ``(render_kwargs_train, render_kwargs_test, optimizer, styles)`` as
        produced by ``create_nerf(..., return_styles=True)``. If the loaded
        ``styles`` has exactly one row, rows 1 and up are filled in from the
        model created from the config alone (without ``--expname``).
    """
    parser = config_parser()
    argv = ['--config', config]
    if expname is not None:
        argv += ['--expname', expname]
    args = parser.parse_args(argv)

    created = create_nerf(args, return_styles=True)
    render_kwargs_train, render_kwargs_test, _, _, optimizer, styles = created

    if styles.shape[0] == 1:
        # Only a single style row came back: keep it as row 0 and borrow the
        # remaining rows from the model built from the plain config.
        base_args = parser.parse_args(['--config', config])
        base_styles = create_nerf(base_args, return_styles=True)[-1]
        styles = torch.cat([styles[:1], base_styles[1:]])

    return render_kwargs_train, render_kwargs_test, optimizer, styles
def load_dataset(instance,
                 config,
                 N_instances,
                 num_canvases=9,
                 expname=None,
                 use_cached=True):
    """Fetch poses/intrinsics for one instance and, optionally, its render cache.

    Args:
        instance: Instance index, forwarded to ``get_poses_hwfs`` and ``get_cache``.
        config: Path to the config file, forwarded to the parser as ``--config``.
        N_instances: Number of instances, forwarded to ``get_poses_hwfs``.
            Overridden to 1 whenever ``expname`` is given.
        num_canvases: Number of views requested; also the length the cached
            lists are truncated to.
        expname: Optional experiment name, forwarded as ``--expname``.
        use_cached: When true, compute the cache through ``get_cache``;
            otherwise the returned cache is ``None``.

    Returns:
        ``(poses, hwfs, cache, args)`` where ``cache`` is either ``None`` or a
        dict with the keys ``'alphas'``, ``'features'`` and ``'weights'``.
    """
    argv = ['--config', config]
    if expname is not None:
        argv += ['--expname', expname]
        N_instances = 1
    args = config_parser().parse_args(argv)

    poses, hwfs = get_poses_hwfs(args,
                                 instance,
                                 N_instances,
                                 num_canvases=num_canvases)

    if not use_cached:
        return poses, hwfs, None, args

    # You can choose to save the cache on disk instead.
    # Saving the cache takes a while, and loading it takes a while the first
    # time, but this is faster if you're editing the same instance frequently.
    #
    # cache_dir = os.path.join(args.expname, 'cache')
    # if not os.path.exists(f'{cache_dir}/{instance}_{num_canvases-1}.pt'):
    #     alphas, features, weights = get_cache(config, instance, expname, save=True)
    #     cache['alphas'] = alphas[:num_canvases]
    #     cache['features'] = features[:num_canvases]
    #     cache['weights'] = weights[:num_canvases]
    # else:
    #     for j in range(len(poses)):
    #         data = torch.load(f'{cache_dir}/{instance}_{j}.pt')
    #         cache['alphas'].append(data['alphas'].cuda())
    #         cache['features'].append(data['features'].cuda())
    #         cache['weights'].append(data['weights'].cuda())

    # Compute the cache in memory. Computing the cache doesn't take
    # noticeably longer than rendering.
    alphas, features, weights = get_cache(config, instance, expname,
                                          num_canvases)
    cache = {
        'alphas': alphas[:num_canvases],
        'features': features[:num_canvases],
        'weights': weights[:num_canvases],
    }
    return poses, hwfs, cache, args
def get_cache(config, instance, expname=None, num_canvases=9, save=False):
    """Render one instance and return the cached alphas, features and weights.

    Args:
        config: Path to the config file, forwarded to the parser as ``--config``.
        instance: Index into the style codes; also used in output file names.
        expname: Optional experiment name, forwarded as ``--expname``. When
            given, the instance count passed to ``get_poses_hwfs`` is 1;
            otherwise it is the number of style rows.
        num_canvases: Number of views requested from ``get_poses_hwfs``.
        save: When true, write one ``{instance}_{j}.pt`` file per rendered view
            into ``<args.expname>/cache``.

    Returns:
        ``(alphas, features, weights)`` as returned by ``render_path``.

    Side effects:
        Creates ``<args.expname>/cache`` and always writes the first rendered
        RGB frame to ``<args.expname>/<instance:03d>.png``.
    """
    argv = ['--config', config]
    if expname is not None:
        argv += ['--expname', expname]
    args = config_parser().parse_args(argv)
    N_instances = None if expname is None else 1

    cachedir = os.path.join(args.expname, 'cache')
    with torch.no_grad():
        _, render_kwargs_test, _, _, _, styles = create_nerf(
            args, return_styles=True)
        if N_instances is None:
            N_instances = styles.shape[0]

        # NOTE(review): the near/far list has a fixed length of 10 regardless
        # of num_canvases -- confirm render_path never needs more entries.
        near_far = [args.blender_near, args.blender_far]
        nfs = [list(near_far) for _ in range(10)]

        # Also guarantees the parent experiment directory exists for the
        # preview image written below.
        os.makedirs(cachedir, exist_ok=True)

        poses, hwfs = get_poses_hwfs(args, instance, N_instances, num_canvases)
        # One copy of the instance's style row per pose.
        style = styles[instance].repeat((poses.shape[0], 1))
        rendered = render_path(
            poses.cuda(),
            style,
            hwfs,
            args.chunk,
            render_kwargs_test,
            nfs=nfs,
            verbose=False,
            get_cached='color')
        rgbs, _, _, alphas, features, weights = rendered

        if save:
            per_view = zip(alphas, features, weights)
            for j, (alpha_j, feature_j, weight_j) in enumerate(per_view):
                payload = {
                    'alphas': alpha_j,
                    'features': feature_j,
                    'weights': weight_j
                }
                torch.save(payload, f'{cachedir}/{instance}_{j}.pt')

        preview = torch.tensor(rgbs[:1]).permute(0, 3, 1, 2).cpu()
        preview_path = os.path.join(args.expname, '{:03d}.png'.format(instance))
        utils.save_image(preview, preview_path)
        return alphas, features, weights
# Example #4
# 0
def _weight_drift(network, reference_state):
    """Sum the mean squared change of every 'weight' parameter of `network`
    relative to the tensors stored under the same names in `reference_state`."""
    drift = 0.
    for name, param in network.named_parameters():
        if 'weight' in name:
            drift += (reference_state[name] - param).pow(2).mean()
    return drift


def train():
    """Run the training loop configured by the command-line arguments.

    Steps, in order:
      1. Parse the CLI arguments and load the data through ``load_data``.
      2. Create ``<basedir>/<expname>`` and write ``poses.npy``, ``hwfs.npy``,
         ``args.txt`` and (when a config file was given) ``config.txt`` there.
         ``expname`` is ``args.savedir`` when that is set, else ``args.expname``.
      3. Build the model with ``create_nerf`` and snapshot the initial coarse
         and fine network weights for the weight-change penalty.
      4. Iterate from the restored step + 1 up to ``args.n_iters`` inclusive,
         periodically saving checkpoints, rendering the test/train sets and
         appending to ``log.txt``.

    When ``args.real_image_dir`` is set, the network optimizer step is skipped
    until ``args.n_iters_code_only`` iterations have run (after which the
    dataset's style optimizer is replaced by Adam and the network is optimized
    jointly), and the function returns after ``args.n_iters_real`` iterations.

    Returns:
        None.
    """
    parser = config_parser()
    args = parser.parse_args()

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.savedir if args.savedir else args.expname
    print('Experiment dir:', expname)
    logdir = os.path.join(basedir, expname)
    log_path = os.path.join(logdir, 'log.txt')

    # Load data
    images, poses, style, i_test, i_train, bds_dict, dataset, hwfs, near_fars, _ = load_data(
        args)
    _, poses_test, style_test, hwfs_test, nf_test = images[i_test], poses[
        i_test], style[i_test], hwfs[i_test], near_fars[i_test]
    _, poses_train, style_train, hwfs_train, nf_train = images[i_train], poses[
        i_train], style[i_train], hwfs[i_train], near_fars[i_train]

    os.makedirs(logdir, exist_ok=True)

    np.save(os.path.join(logdir, 'poses.npy'), poses_train.cpu())
    np.save(os.path.join(logdir, 'hwfs.npy'), hwfs_train.cpu())
    with open(os.path.join(logdir, 'args.txt'), 'w') as args_file:
        for arg in sorted(vars(args)):
            args_file.write('{} = {}\n'.format(arg, getattr(args, arg)))
    if args.config is not None:
        # Read the source completely (and close it) before opening the
        # destination for writing, so no handle is leaked and the destination
        # is never truncated before the source has been read.
        with open(args.config, 'r') as config_src:
            config_text = config_src.read()
        with open(os.path.join(logdir, 'config.txt'), 'w') as config_dst:
            config_dst.write(config_text)

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, _, optimizer = create_nerf(
        args)

    print(render_kwargs_train['network_fine'])
    # Frozen copies of the initial weights, used as the reference point of the
    # weight-change penalty below.
    old_coarse_network = copy.deepcopy(
        render_kwargs_train['network_fn']).state_dict()
    old_fine_network = copy.deepcopy(
        render_kwargs_train['network_fine']).state_dict()

    global_step = start
    real_image_application = (args.real_image_dir is not None)
    # In real-image mode the network weights stay fixed at first; see the
    # switch at the bottom of the loop.
    optimize_mlp = not real_image_application
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)
    loss = None

    if start == 0:
        # if we're starting from scratch, delete all the logs in that directory.
        if os.path.exists(log_path):
            os.remove(log_path)
    start = start + 1

    for i in range(start, args.n_iters + 1):
        # Sample random ray batch
        batch_rays, target_s, batch_style, H, W, focal, near, far, viewdirs_reg = dataset.get_data_batch(
            train_fn=render_kwargs_train, optimizer=optimizer, loss=loss)
        render_kwargs_train.update({'near': near, 'far': far})

        #####  Core optimization loop  #####
        rgb, _, _, extras = render(H,
                                   W,
                                   focal,
                                   style=batch_style,
                                   chunk=args.chunk,
                                   rays=batch_rays,
                                   viewdirs_reg=viewdirs_reg,
                                   **render_kwargs_train)
        optimizer.zero_grad()

        img_loss = img2mse(rgb, target_s)
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if args.var_param > 0:
            var_loss = extras['var'].mean(dim=0)
            var_loss_coarse = extras['var0'].mean(dim=0)

            # Out-of-place adds: `loss` still aliases `img_loss` here, and an
            # in-place `+=` would silently mutate the tensor PSNR was built on.
            loss = loss + args.var_param * var_loss
            loss = loss + args.var_param * var_loss_coarse
            var_loss = var_loss.item()
            var_loss_coarse = var_loss_coarse.item()
        else:
            var_loss = 0
            var_loss_coarse = 0

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0).item()
        else:
            psnr0 = -1

        if args.weight_change_param >= 0:
            # Penalize drifting away from the weights the run started with.
            weight_change_loss = (
                _weight_drift(render_kwargs_train['network_fn'],
                              old_coarse_network) +
                _weight_drift(render_kwargs_train['network_fine'],
                              old_fine_network))
            loss = loss + args.weight_change_param * weight_change_loss
        else:
            weight_change_loss = torch.tensor(0.)

        loss.backward()
        if optimize_mlp:
            optimizer.step()

        # NOTE: IMPORTANT!
        # Exponential learning-rate decay driven by global_step.
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))

        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################
        #####           end            #####

        if i % args.i_weights == 0:
            path = os.path.join(logdir, '{:06d}.tar'.format(i))
            state_dict = {
                'global_step':
                global_step,
                'network_fn_state_dict':
                render_kwargs_train['network_fn'].state_dict(),
                'optimizer_state_dict':
                optimizer.state_dict(),
                'styles':
                dataset.style,
                'style_optimizer':
                dataset.style_optimizer.state_dict()
            }
            if args.N_importance > 0:
                state_dict['network_fine_state_dict'] = render_kwargs_train[
                    'network_fine'].state_dict()
            torch.save(state_dict, path)
            print('Saved checkpoints at', path)

        if i % args.i_testset == 0 and i > 0:
            if real_image_application:
                # Render every test pose with the code currently being fitted.
                style_test = dataset.get_features().repeat(
                    (poses_test.shape[0], 1))
            testsavedir = os.path.join(logdir, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            with torch.no_grad():
                render_path(poses_test.to(device),
                            style_test,
                            hwfs_test,
                            args.chunk,
                            render_kwargs_test,
                            nfs=nf_test,
                            savedir=testsavedir,
                            maximum=100)
            print('Saved test set')

        if i % args.i_trainset == 0 and i > 0:
            if real_image_application:
                style_train = dataset.get_features().repeat(
                    (poses_train.shape[0], 1))
            trainsavedir = os.path.join(logdir, 'trainset_{:06d}'.format(i))
            os.makedirs(trainsavedir, exist_ok=True)
            with torch.no_grad():
                render_path(poses_train.to(device),
                            style_train,
                            hwfs_train,
                            args.chunk,
                            render_kwargs_test,
                            nfs=nf_train,
                            savedir=trainsavedir,
                            maximum=100)
            print('Saved train set')

        if i % args.i_print == 0 or i == 1:
            log_str = f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()} PSNR0: {psnr0} Var loss: {var_loss} Var loss coarse: {var_loss_coarse} Weight change loss: {weight_change_loss}"
            with open(log_path, 'a+') as log_file:
                log_file.write(log_str + '\n')
            print(log_str)

        global_step += 1

        if real_image_application and global_step - start == args.n_iters_real:
            return

        if real_image_application and global_step - start == args.n_iters_code_only:
            # End of the code-only phase: start stepping the network optimizer
            # too and give the dataset a fresh Adam optimizer for its params.
            optimize_mlp = True
            dataset.optimizer_name = 'adam'
            dataset.style_optimizer = torch.optim.Adam(dataset.params,
                                                       lr=dataset.lr)
            print('Starting to jointly optimize weights with code')
# Example #5
# 0
            log_str = f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()} PSNR0: {psnr0} Var loss: {var_loss} Var loss coarse: {var_loss_coarse} Weight change loss: {weight_change_loss}"
            with open(os.path.join(basedir, expname, 'log.txt'), 'a+') as f:
                f.write(log_str + '\n')
            print(log_str)

        global_step += 1

        if real_image_application and global_step - start == args.n_iters_real:
            return

        if real_image_application and global_step - start == args.n_iters_code_only:
            optimize_mlp = True
            dataset.optimizer_name = 'adam'
            dataset.style_optimizer = torch.optim.Adam(dataset.params,
                                                       lr=dataset.lr)
            print('Starting to jointly optimize weights with code')


if __name__ == '__main__':
    parser = config_parser()
    args = parser.parse_args()

    # A concrete instance id (anything but -1) means this is a scripted
    # single-instance job: skip it if it already finished, and mark it as
    # finished once training returns.
    track_job = args.instance != -1
    if track_job:
        exit_if_job_done(os.path.join(args.basedir, args.expname))

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    train()

    if track_job:
        mark_job_done(os.path.join(args.basedir, args.expname))
# Example #6
# 0
def test():
    """Render the test and/or train split of a trained model and print PSNR.

    Parses the CLI arguments, loads the data and the model, writes
    ``poses.npy`` and ``hwfs.npy`` for the train split into
    ``<basedir>/<expname>``, then, under ``torch.no_grad()``:

    * if ``args.render_test`` is set, renders the test split into
      ``test_imgs<start:06d>``;
    * if ``args.render_train`` is set, renders the train split into
      ``train_imgs<start:06d>``.

    With ``args.shuffle_poses`` the poses of a split are randomly permuted
    before rendering; the styles, intrinsics, near/far bounds and ground-truth
    images of that split keep their original order.
    """
    args = config_parser().parse_args()

    images, poses, style, i_test, i_train, bds_dict, _, hwfs, near_fars, _ = load_data(
        args)
    images_test, poses_test, style_test = images[i_test], poses[i_test], style[i_test]
    hwfs_test, nf_test = hwfs[i_test], near_fars[i_test]
    images_train, poses_train, style_train = images[i_train], poses[i_train], style[i_train]
    hwfs_train, nf_train = hwfs[i_train], near_fars[i_train]

    # Output location and model
    basedir = args.basedir
    expname = args.expname
    render_kwargs_train, render_kwargs_test, start, _, _ = create_nerf(args)

    np.save(os.path.join(basedir, expname, 'poses.npy'), poses_train.cpu())
    np.save(os.path.join(basedir, expname, 'hwfs.npy'), hwfs_train.cpu())

    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    def render_split(label, gt_images, split_poses, split_styles, split_hwfs,
                     split_nfs):
        # Renders one split ('test' or 'train') and reports its PSNR.
        if args.shuffle_poses:
            print(f'Shuffling {label} poses')
            order = list(range(len(split_poses)))
            random.shuffle(order)
            split_poses = split_poses[order]
        savedir = os.path.join(basedir, expname, f'{label}_imgs{start:06d}')
        os.makedirs(savedir, exist_ok=True)
        _, _, psnr = render_path(split_poses.to(device),
                                 split_styles,
                                 split_hwfs,
                                 args.chunk,
                                 render_kwargs_test,
                                 nfs=split_nfs,
                                 gt_imgs=gt_images,
                                 savedir=savedir)
        print(f'Saved {label} set w/ psnr', psnr)

    with torch.no_grad():
        if args.render_test:
            render_split('test', images_test, poses_test, style_test,
                         hwfs_test, nf_test)

        if args.render_train:
            render_split('train', images_train, poses_train, style_train,
                         hwfs_train, nf_train)
def transfer_codes(config):
    """Render shape/color code swaps between pairs of instances.

    For every ordered pair ``(i1, i2)`` with ``i1 != i2`` among the first
    instances (at most 10, and never more than the number of style rows), this
    writes into ``<args.expname>/transfer_codes``:

    * ``{i1}.png`` and ``{i2}.png``: the first rendered view of each instance
      with its own style code;
    * ``{i1}_color_from_{i2}.png``: views along instance ``i1``'s poses using
      the first 32 style dims of ``i1`` and the remaining dims of ``i2``;
    * ``{i1}_shape_from_{i2}.png``: views along instance ``i2``'s poses using
      the first 32 style dims of ``i2`` and the remaining dims of ``i1``.

    Args:
        config: Path to the config file, forwarded to the parser as ``--config``.

    Side effects:
        Sets the global default tensor type to ``torch.cuda.FloatTensor``.
    """
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    N_per_transfer = 8
    N_transfer = 10

    parser = config_parser()
    args = parser.parse_args(['--config', config])
    _, render_kwargs_test, _, _, _, styles = create_nerf(args,
                                                         return_styles=True)
    # Never index past the available style rows: with fewer than 10 instances
    # the fixed count would raise IndexError part-way through the sweep.
    N_transfer = min(N_transfer, styles.shape[0])

    nfs = [[args.blender_near, args.blender_far]
           for _ in range(N_per_transfer)]
    outdir = f'{args.expname}/transfer_codes'
    os.makedirs(outdir, exist_ok=True)

    def render_rgb(poses, codes, hwfs):
        # First element of render_path's result is the rendered RGB frames.
        return render_path(poses,
                           codes,
                           hwfs,
                           4096,
                           render_kwargs_test,
                           nfs=nfs,
                           verbose=False)[0]

    with torch.no_grad():
        for i1 in range(N_transfer):
            for i2 in tqdm(range(N_transfer)):
                if i1 == i2:
                    continue

                s1 = styles[i1].unsqueeze(dim=0)
                s2 = styles[i2].unsqueeze(dim=0)
                poses_1, hwfs_1 = get_poses_hwfs(args,
                                                 i1,
                                                 styles.shape[0],
                                                 num_canvases=N_per_transfer)
                poses_2, hwfs_2 = get_poses_hwfs(args,
                                                 i2,
                                                 styles.shape[0],
                                                 num_canvases=N_per_transfer)

                # Reference renders of both instances with their own codes.
                rgb1 = render_rgb(poses_1, s1, hwfs_1)[0]
                rgb2 = render_rgb(poses_2, s2, hwfs_2)[0]
                utils.save_image(
                    torch.tensor(rgb1).permute(2, 0, 1),
                    f'{args.expname}/transfer_codes/{i1}.png')
                utils.save_image(
                    torch.tensor(rgb2).permute(2, 0, 1),
                    f'{args.expname}/transfer_codes/{i2}.png')

                # Swap the two halves of the style code at dim 32.
                take_color = torch.cat([s1[:, :32], s2[:, 32:]], dim=1)
                take_shape = torch.cat([s2[:, :32], s1[:, 32:]], dim=1)
                # i1 with shape from i2 is the same as i2 color from i1, so don't duplicate
                color_from = render_rgb(
                    poses_1, take_color.repeat((N_per_transfer, 1)), hwfs_1)
                shape_from = render_rgb(
                    poses_2, take_shape.repeat((N_per_transfer, 1)), hwfs_2)
                utils.save_image(
                    torch.tensor(color_from).permute(0, 3, 1, 2),
                    f'{args.expname}/transfer_codes/{i1}_color_from_{i2}.png')
                utils.save_image(
                    torch.tensor(shape_from).permute(0, 3, 1, 2),
                    f'{args.expname}/transfer_codes/{i1}_shape_from_{i2}.png')
def load_config(config):
    """Parse the given config file and return the resulting argument namespace.

    Args:
        config: Path to the config file, forwarded to the parser as ``--config``.
    """
    argv = ['--config', config]
    return config_parser().parse_args(argv)