Example #1
def get_images_train():
    dataset_dir = os.path.join(LLFF_DIR, DATASET)
    image_dir = os.path.join(dataset_dir, 'images')
    _, poses, bds, _, _ = load_llff_data(dataset_dir, factor=None)
    # split into train data and validation data
    validation_ids = np.arange(poses.shape[0])
    validation_ids[::8] = -1  # mark every 8th view (0, 8, 16, ...) for validation
    validation_ids = validation_ids < 0
    # image paths, listed in the same order as the poses
    images_path = [
        os.path.join('images', f) for f in sorted(os.listdir(image_dir))
    ]
    images_train = []
    images_valid = []
    for image_id in range(poses.shape[0]):
        R = poses[image_id][:3, :3]
        t = poses[image_id][:3, 3].reshape([3, 1])
        # LLFF poses need their rotation and translation inverted to match our format
        R[:3, 0] *= -1
        R = np.transpose(R)
        t = np.matmul(R, t)
        img_obj = {
            "path": images_path[image_id],
            "center": -R @ t,
            "depth": bds[image_id]
        }
        if not validation_ids[image_id]:
            images_train.append(img_obj)
        else:
            images_valid.append(img_obj)
    return images_train
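
The train/validation split above hinges on a boolean holdout mask. A minimal, self-contained sketch of just that pattern (plain NumPy, independent of the LLFF loader):

import numpy as np

num_poses = 20
validation_ids = np.arange(num_poses)
validation_ids[::8] = -1                  # flag every 8th index (0, 8, 16, ...)
is_validation = validation_ids < 0

train_ids = np.where(~is_validation)[0]   # e.g. [1, 2, ..., 7, 9, ...]
valid_ids = np.where(is_validation)[0]    # e.g. [0, 8, 16]
print(train_ids, valid_ids)
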
Example #2
def get_images_data(dataset, split_val=8):
    dataset_dir = os.path.join(LLFF_DIR, dataset)
    image_dir = os.path.join(dataset_dir, 'images')
    _, poses, bds, _, _ = load_llff_data(dataset_dir, factor=None)
    # split into train data and validation data
    validation_ids = np.arange(poses.shape[0])
    validation_ids[::split_val] = -1  # mark every split_val-th view (0, split_val, ...) for validation
    validation_ids = validation_ids < 0
    # image paths, listed in the same order as the poses
    images_path = [
        'images/{}'.format(f) for f in sorted(os.listdir(image_dir))
    ]
    images_train = []
    images_valid = []
    for image_id in range(poses.shape[0]):
        R = poses[image_id][:3, :3]
        t = poses[image_id][:3, 3].reshape([3, 1])
        # LLFF poses need their rotation and translation inverted to match our format
        R[:3, 0] *= -1
        R = np.transpose(R)
        t = np.matmul(R, t)
        H = poses[image_id, 0, -1]
        W = poses[image_id, 1, -1]
        focal = poses[image_id, 2, -1]
        img_obj = {
            "path": images_path[image_id],
            "r": R,
            "t": t,
            "R": R.T,
            "center": -R @ t,
            "planes": bds[image_id],
            "camera": {
                "width": int(W),
                "height": int(H),
                "fx": float(focal),
                "fy": float(focal),
                "px": float(W / 2.0),
                "py": float(H / 2.0)
            },
        }
        if not validation_ids[image_id]:
            images_train.append(img_obj)
        else:
            images_valid.append(img_obj)
    return images_train, images_valid
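
The per-image "camera" dictionary above describes a plain pinhole model. For reference, a small sketch (camera_to_K is a hypothetical helper, not part of the loader) that assembles the usual 3x3 intrinsics matrix from those fields:

import numpy as np

def camera_to_K(camera):
    # Focal lengths on the diagonal, principal point in the last column.
    return np.array([
        [camera["fx"], 0.0,          camera["px"]],
        [0.0,          camera["fy"], camera["py"]],
        [0.0,          0.0,          1.0],
    ])
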
Example #3
def train():
    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data

    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify,
            rgba=True)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)
    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(
        args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        testsavedir = os.path.join(
            basedir, expname, 'renderonly_{}_{:06d}'.format(
                'test' if args.render_test else 'path', start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses,
                              hwf,
                              args.chunk,
                              render_kwargs_test,
                              gt_imgs=images,
                              savedir=testsavedir,
                              render_factor=args.render_factor)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs),
                         fps=30,
                         quality=8)

        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(
            lrate, decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        print('get rays')
        rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays = np.transpose(rays, [0, 2, 3, 1, 4])
        rays = np.reshape(rays, list(rays.shape[:3]) + [6])
        rays_rgba = np.concatenate([rays, images], -1)
        rays_rgba = np.stack([rays_rgba[i] for i in i_train],
                             axis=0)  # train images only
        rays_rgba = np.reshape(rays_rgba, [-1, 10])
        rays_rgba = rays_rgba.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgba)
        print('done')
        i_batch = 0

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))

    for i in range(start, N_iters):
        time0 = time.time()

        batch = rays_rgba[i_batch:i_batch + N_rand]
        batch_rays = batch[:, :6]
        target_rgba = batch[:, 6:]

        batch_rays = np.reshape(batch_rays, [-1, 2, 3])
        batch_rays = np.transpose(batch_rays, [1, 0, 2])

        i_batch += N_rand
        if i_batch >= rays_rgba.shape[0]:
            np.random.shuffle(rays_rgba)
            i_batch = 0

        #####  Core optimization loop  #####

        with tf.GradientTape() as tape:

            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(H,
                                            W,
                                            focal,
                                            chunk=args.chunk,
                                            rays=batch_rays,
                                            verbose=i < 10,
                                            retraw=True,
                                            **render_kwargs_train)

            # Compute MSE loss between predicted and true RGBA.
            rgba = to_rgba(rgb, acc)
            loss = img2mse(rgba, target_rgba)
            trans = extras['raw'][..., -1]
            psnr = mse2psnr(loss)

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                rgba0 = to_rgba(extras['rgb0'], extras['acc0'])
                loss0 = img2mse(rgba0, target_rgba)
                loss += loss0
                psnr0 = mse2psnr(loss0)

            img_loss = img2mse(rgb, target_rgba[..., :3])
            img_psnr = mse2psnr(img_loss)
            # Monitoring-only metrics: keep them out of the gradient computation.
            img_loss = tf.stop_gradient(img_loss)
            img_psnr = tf.stop_gradient(img_psnr)

        gradients = tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))

        dt = time.time() - time0

        #####           end            #####

        # Rest is logging

        def save_weights(net, prefix, i):
            path = os.path.join(basedir, expname,
                                '{}_{:06d}.npy'.format(prefix, i))
            np.save(path, net.get_weights())
            print('saved weights at', path)

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

        if i % args.i_video == 0 and i > 0:

            rgbs, disps = render_path(render_poses, hwf, args.chunk,
                                      render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname,
                                     '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)

            if args.use_viewdirs:
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(render_poses, hwf, args.chunk,
                                            render_kwargs_test)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still),
                                 fps=30,
                                 quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test],
                        hwf,
                        args.chunk,
                        render_kwargs_test,
                        gt_imgs=images[i_test],
                        savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:
            print(expname, i, psnr.numpy(), loss.numpy(),
                  [img_loss.numpy(), img_psnr.numpy()], global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with writer.as_default():
                tf.summary.scalar('loss', loss, step=i + 1)
                tf.summary.scalar('psnr', psnr, step=i + 1)
                tf.summary.histogram('tran', trans, step=i + 1)
                if args.N_importance > 0:
                    tf.summary.scalar('psnr0', psnr0, step=i + 1)

            if i % args.i_img == 0:

                # Log a rendered validation view to Tensorboard
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]

                rgb, disp, acc, extras = render(H,
                                                W,
                                                focal,
                                                chunk=args.chunk,
                                                c2w=pose,
                                                **render_kwargs_test)
                rgba = to_rgba(rgb, acc)
                psnr = mse2psnr(img2mse(rgba, target))

                # Save out the validation image for Tensorboard-free monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                if i == 0:
                    os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(
                    os.path.join(testimgdir, '{:06d}.png'.format(i)),
                    to8b(rgba))
                imageio.imwrite(
                    os.path.join(testimgdir, '{:06d}_acc.png'.format(i)),
                    to8b(acc))

                with writer.as_default():
                    tf.summary.image('rgba',
                                     to8b(rgba)[tf.newaxis],
                                     step=i + 1)
                    tf.summary.image('disp',
                                     disp[tf.newaxis, ..., tf.newaxis],
                                     step=i + 1)
                    tf.summary.image('acc',
                                     acc[tf.newaxis, ..., tf.newaxis],
                                     step=i + 1)

                    tf.summary.scalar('psnr_holdout', psnr, step=i + 1)
                    tf.summary.image('rgb_holdout',
                                     target[tf.newaxis],
                                     step=i + 1)

                if args.N_importance > 0:
                    with writer.as_default():
                        tf.summary.image('rgb0',
                                         to8b(extras['rgb0'])[tf.newaxis],
                                         step=i + 1)
                        tf.summary.image('disp0',
                                         extras['disp0'][tf.newaxis, ...,
                                                         tf.newaxis],
                                         step=i + 1)
                        tf.summary.image('z_std',
                                         extras['z_std'][tf.newaxis, ...,
                                                         tf.newaxis],
                                         step=i + 1)

        global_step.assign_add(1)
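
Example #3 trains against RGBA targets through a to_rgba helper that is not shown here. A plausible minimal sketch, assuming it merely appends the accumulated opacity as an alpha channel to the predicted color:

import tensorflow as tf

def to_rgba(rgb, acc):
    # rgb: [..., 3] predicted color; acc: [...] accumulated opacity along each ray.
    return tf.concat([rgb, acc[..., tf.newaxis]], axis=-1)
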
Example #4
    I = I.unsqueeze(0).to(device) if args.gpu != 'cpu' else I.unsqueeze(0)
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output, _ = model(I)  # output size: 1 x 21 x 64(H) x 64(W)
        # print('Inference time: {:.4f} s'.format(time.time()-start_time))
        kps_pred_np = get_final_preds(
            output,
            use_softmax=cfg.MODEL.HEATMAP_SOFTMAX).cpu().numpy().squeeze()
    return kps_pred_np


# images: 114 x H x W x 3
# poses(c2w and intrinsic): 114 x 3 x 5
images, poses, bds, render_poses, i_test = load_llff_data(base_dir,
                                                          factor=3,
                                                          recenter=True)

import pickle
fname = './pose2d_pred.txt'
try:
    print('Load')
    with open(fname, 'rb') as f:
        pts = pickle.load(f)
except Exception:
    print('Failed')
    pts = []
    color_lower = (80, 45, 30)
    color_upper = (120, 190, 180)

    images = (images * 255).astype(np.uint8)
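
The snippet above stops after defining the color bounds and converting the images to uint8. A hedged sketch of how such bounds are commonly applied, assuming OpenCV-style thresholding (detect_points, the inRange call, and the centroid heuristic are illustrative assumptions, not part of the original code):

import cv2
import numpy as np

def detect_points(images_u8, color_lower, color_upper):
    pts = []
    for img in images_u8:
        # Keep pixels whose channels fall inside [color_lower, color_upper].
        mask = cv2.inRange(img, np.array(color_lower), np.array(color_upper))
        ys, xs = np.nonzero(mask)
        # Crude 2D detection: centroid of the mask, or None if nothing matched.
        pts.append((xs.mean(), ys.mean()) if xs.size else None)
    return pts
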
Example #5
def train():

    parser = config_parser()
    args = parser.parse_args()

    # Multi-GPU
    args.n_gpus = torch.cuda.device_count()
    print(f"Using {args.n_gpus} GPU(s).")

    # Load data
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.min(bds) * .9
            far = np.max(bds) * 1.

        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)

        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # render_test switches to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format(
                    'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _, _ = render_path(render_poses,
                                     hwf,
                                     args.chunk,
                                     render_kwargs_test,
                                     gt_imgs=images,
                                     savedir=testsavedir,
                                     render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs),
                             fps=30,
                             quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    K = args.num_poses
    if use_batching or K >= len(poses):
        # For random ray batching
        print('get rays')
        rays = np.stack(
            [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
            0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]],
                                  1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb,
                                [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            0)  # train images only
        rays_rgb = np.reshape(rays_rgb,
                              [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        i_batch = 0
    else:
        print('We are taking batches of rays from',
              K,
              'image(s) at a time',
              flush=True)
        train_poses = np.stack([poses[i] for i in i_train], axis=0)
        train_images = np.stack([images[i] for i in i_train], axis=0)
        nearest_rays = []
        train_translations = train_poses[:, :3, 3]
        train_translations_indices = np.arange(len(train_translations))
        print('Caching nearest rays for each pose...')
        for i, _ in enumerate(tqdm(train_translations)):
            translation = train_translations[i]
            distances = np.linalg.norm(train_translations - translation,
                                       axis=-1)
            knn_poses_indices = np.array(
                sorted(train_translations_indices,
                       key=lambda i: distances[i]))[:K]
            knn_poses = train_poses[knn_poses_indices]
            knn_images = train_images[knn_poses_indices]
            nearest_rays.append(
                get_rays_rgb_for_poses(H, W, knn_poses, knn_images, focal))
    print('Done', flush=True)

    # Move training data to GPU
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = args.N_iters
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    start = start
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        optimizer.zero_grad()
        N_rand_small = 1024

        if not use_batching:
            # Random from K-poses.
            pose_i = np.random.randint(len(nearest_rays))  # index into the cached per-pose ray list
            rays_rgb = nearest_rays[pose_i]
            assert N_rand <= rays_rgb.shape[0]

        select_inds = np.random.choice(rays_rgb.shape[0],
                                       size=[N_rand],
                                       replace=False)  # (N_rand,)
        # epoch_ended = False
        # num_rays = N_rand if not use_batching else min(N_rand, rays_rgb.size(0) - i_batch)
        num_rays = N_rand

        total_weight = 0.0
        for j in range(0, N_rand, N_rand_small):
            # if epoch_ended:
            #     break

            # Next chunk of the without-replacement selection.
            curr_select_inds = select_inds[j:j + N_rand_small]
            batch = rays_rgb[curr_select_inds]

            if not use_batching:
                # rays_rgb was not moved to device for KNN-pose mode, since it may be too big.
                batch = torch.Tensor(batch).to(device)

            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            #####  Core optimization loop  #####
            rgb, disp, acc, extras = render(H,
                                            W,
                                            focal,
                                            chunk=args.chunk,
                                            rays=batch_rays,
                                            verbose=i < 10,
                                            retraw=True,
                                            **render_kwargs_train)

            # optimizer.zero_grad()
            img_loss = img2mse(rgb, target_s)
            trans = extras['raw'][..., -1]
            loss = img_loss

            psnr = mse2psnr(img_loss)

            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                loss = loss + img_loss0
                # psnr0 = mse2psnr(img_loss0)

            weight = (rgb.size(0) / float(num_rays))
            loss_backprop = loss * weight
            total_weight += weight
            loss_backprop.backward()

        assert np.isclose(total_weight, 1.0), total_weight
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        dt = time.time() - time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####

        # Rest is logging
        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save(
                {
                    'global_step':
                    global_step,
                    'network_fn_state_dict':
                    render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict':
                    render_kwargs_train['network_fine'].state_dict(),
                    'optimizer_state_dict':
                    optimizer.state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        # if i%args.i_video==0 and i > 0:
        #     # Turn on testing mode
        #     with torch.no_grad():
        #         rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
        #     print('Done, saving', rgbs.shape, disps.shape)
        #     moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format('video', i))
        #     imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
        #     imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

        # if args.use_viewdirs:
        #     render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
        #     with torch.no_grad():
        #         rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
        #     render_kwargs_test['c2w_staticcam'] = None
        #     imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)
        adjusted_step = int(
            (global_step - start) * N_rand / float(N_rand_small)) + start

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            with torch.no_grad():
                _, _, test_mse = render_path(torch.Tensor(
                    poses[i_test]).to(device),
                                             hwf,
                                             args.chunk,
                                             render_kwargs_test,
                                             gt_imgs=images[i_test],
                                             savedir=testsavedir)
                test_psnr = mse2psnr(torch.from_numpy(np.array(test_mse)))
                writer.add_scalar('mse_test',
                                  test_mse,
                                  global_step=global_step)
                writer.add_scalar('psnr_test',
                                  test_psnr,
                                  global_step=global_step)
                writer.add_scalar('psnr_test_by_t',
                                  test_psnr,
                                  global_step=adjusted_step)
                writer.add_scalar('mse_test_by_t',
                                  test_mse,
                                  global_step=adjusted_step)
            print('Saved test set')

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
            print('', flush=True)
            writer.add_scalar('loss', loss, global_step=global_step)
            writer.add_scalar('psnr', psnr, global_step=global_step)
            writer.add_scalar('lr', new_lrate, global_step=global_step)
            # writer.add_histogram('tran', trans, global_step=global_step)

            writer.add_scalar('mse_by_t', loss, global_step=adjusted_step)
            writer.add_scalar('psnr_by_t', psnr, global_step=adjusted_step)

            # if args.N_importance > 0:
            #     writer.add_scalar('psnr0', psnr0, global_step=global_step)

        if i % args.i_img == 0:
            # Log a rendered validation view to Tensorboard
            # img_i = np.random.choice(i_val)
            img_i = i_val[0]
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            with torch.no_grad():
                rgb, disp, acc, extras = render(H,
                                                W,
                                                focal,
                                                chunk=args.chunk,
                                                c2w=pose,
                                                **render_kwargs_test)

            mse = img2mse(rgb, target)
            psnr = mse2psnr(mse)
            writer.add_image('rgb',
                             rgb,
                             dataformats='HWC',
                             global_step=global_step)
            writer.add_image('disp',
                             torch.stack(3 * [disp], dim=-1),
                             dataformats='HWC',
                             global_step=global_step)
            writer.add_image('acc',
                             torch.stack(3 * [acc], dim=-1),
                             dataformats='HWC',
                             global_step=global_step)
            writer.add_scalar('mse_holdout', mse, global_step=global_step)
            writer.add_scalar('mse_holdout_by_t',
                              mse,
                              global_step=adjusted_step)
            writer.add_image('rgb_holdout',
                             target,
                             dataformats='HWC',
                             global_step=global_step)

            # if args.N_importance > 0:
            #     with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
            #         tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
            #         tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis, ..., tf.newaxis])
            #         tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis, ..., tf.newaxis])

        global_step += 1
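
Example #5 caches per-pose ray sets via get_rays_rgb_for_poses, which is not shown. A minimal sketch of what it might look like, mirroring the batching path earlier in the same function (it reuses get_rays_np from the surrounding module):

import numpy as np

def get_rays_rgb_for_poses(H, W, poses, images, focal):
    # One [ro+rd, H, W, 3] ray bundle per pose.
    rays = np.stack([get_rays_np(H, W, focal, p[:3, :4]) for p in poses], 0)
    rays_rgb = np.concatenate([rays, images[:, None]], 1)       # [N, ro+rd+rgb, H, W, 3]
    rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])          # [N, H, W, ro+rd+rgb, 3]
    return np.reshape(rays_rgb, [-1, 3, 3]).astype(np.float32)  # [N*H*W, ro+rd+rgb, 3]
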
Example #6
def train():

    parser = config_parser()
    args = parser.parse_args()
    
    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data

    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                  recenter=True, bd_factor=.75,
                                                                  spherify=args.spherify)

        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape,
              render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                            (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            args.near = tf.reduce_min(bds) * .9
            args.far = tf.reduce_max(bds) * 1.
        else:
            args.near = 0.
            args.far = 1.
        print('NEAR FAR', args.near, args.far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split, extras = load_blender_data(
            args.datadir,
            args.half_res,
            args.testskip,
            args.image_extn,
            mask_directory=args.mask_directory,
            get_depths=args.get_depth_maps if args.near is None and args.far is None else False,
            image_field=args.image_fieldname,
            image_dir_override=args.image_dir_override,
            trainskip=args.trainskip,
            train_frames_field=args.frames_field)

        if args.mask_directory is not None:
            masks = extras['masks']

        if args.get_depth_maps:
            depth_maps = extras['depth_maps']

        K = None
        if args.use_K:
            K = extras['K']

        print('Loaded blender', images.shape,
              render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split
        if args.near is None and args.far is None:
            if args.Z_limits_from_pose:
                Zs = poses[:, 2, 3]
                args.near = np.min(np.abs(Zs)) * 0.5
                args.far = np.max(np.abs(Zs)) * 1.5

            elif args.get_depth_maps and args.mask_directory is not None:
                args.near = np.min(np.mean(depth_maps[masks])) * 0.9
                args.far = np.max(depth_maps[masks]) * 1.1
                print('using masked depth')

            elif args.get_depth_maps:
                args.near = np.min(depth_maps) * 0.9
                args.far = np.max(depth_maps) * 1.1
                print('using depth maps')
            else:
                args.near = 0.
                args.far = 2.

        print(f'args.near: {args.near} far: {args.far}')

        if args.mask_directory is not None and images.shape[-1] == 3:
            images = np.concatenate([images, masks[..., np.newaxis]], axis=-1)

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':


        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
                                                                 basedir=args.datadir,
                                                                 testskip=args.testskip)

        print('Loaded deepvoxels', images.shape,
              render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        args.near = hemi_R-1.
        args.far = hemi_R+1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    if args.mask_directory and not args.white_bkgd:
        assert os.path.isdir(args.mask_directory), f'args.mask_directory not found at: {args.mask_directory}'
        if args.mask_images:
            images *= masks[..., np.newaxis]

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    filename_splits = extras['filenames']
    enable_keypoints = False
    keypoint_loss_coeff = 0
    if args.colmap_keypoints_filename is not None:
        enable_keypoints = True
        from NERFCO.nerf_keypoint_network import get_keypoint_masks
        import MODELS.random_embeddings

        train_keypoint_masks, args.number_of_keypoints = \
            get_keypoint_masks(args.colmap_keypoints_filename,
                               i_train,
                               filename_splits,
                               H, W)

        train_keypoint_masks = train_keypoint_masks.ravel()
        test_keypoint_masks, test_keypoint_count = \
            get_keypoint_masks(args.colmap_keypoints_filename,
                               i_test,
                               filename_splits,
                               H, W)

        number_of_keypoints = np.sum(train_keypoint_masks > 0)

        assert test_keypoint_count == args.number_of_keypoints, \
            f'test_keypoint_count: {test_keypoint_count} == args.number_of_keypoints: {args.number_of_keypoints}'

        print(train_keypoint_masks.shape, train_keypoint_masks.min(), train_keypoint_masks.max(), number_of_keypoints)
        keypoint_embeddings_filename = os.path.join(args.basedir, args.expname, 'keypoint_embeddings.pkl')
        random_keypoint_embeddings_object = MODELS.random_embeddings.StaticRandomEmbeddings(
            args.number_of_keypoints + 1,
            args.keypoint_embedding_size,
            embedding_filename=keypoint_embeddings_filename,
            zero_embedding_origin=args.zero_embedding_origin)

        keypoint_embeddings = random_keypoint_embeddings_object.embeddings

    import NERFCO.extract_keypoints
    if args.autoencoded_keypoints_filename is not None:
        enable_keypoints = True
        assert os.path.isfile(args.autoencoded_keypoints_filename), \
            f'autoencoded_keypoints_filename not found at: {args.autoencoded_keypoints_filename}'

        import pickle
        with open(args.autoencoded_keypoints_filename, 'rb') as keypoints_file:
            data = pickle.load(keypoints_file)

        train_keypoint_masks = data['train_keypoint_masks']
        test_keypoint_masks = data['test_keypoint_masks']
        keypoint_embeddings = data['encoded_embeddings']

        NERFCO.extract_keypoints.test_keypoint_masks(train_keypoint_masks, keypoint_embeddings)
        NERFCO.extract_keypoints.test_keypoint_masks(test_keypoint_masks, keypoint_embeddings)

        args.keypoint_embedding_size = keypoint_embeddings.shape[-1]
        train_keypoint_masks = train_keypoint_masks.ravel()
        print(f'loaded {args.keypoint_detector} keypoints: {keypoint_embeddings.shape} from: {args.autoencoded_keypoints_filename}')

    elif args.keypoints_filename is not None and args.keypoint_detector in ['SIFT', 'ORB']:
        enable_keypoints = True

        train_keypoint_masks, test_keypoint_masks, keypoint_embeddings = \
            NERFCO.extract_keypoints.get_keypoints_and_maps(
                args.keypoints_filename,
                i_test,
                i_train,
                filename_splits,
                H, W)

        NERFCO.extract_keypoints.test_keypoint_masks(train_keypoint_masks, keypoint_embeddings)
        NERFCO.extract_keypoints.test_keypoint_masks(test_keypoint_masks, keypoint_embeddings)

        args.keypoint_embedding_size = keypoint_embeddings.shape[-1]
        train_keypoint_masks = train_keypoint_masks.ravel()
        print(f'loaded {args.keypoint_detector} keypoints: {keypoint_embeddings.shape} from: {args.keypoints_filename}')

    elif args.learnable_embeddings_filename is not None:
        enable_keypoints = True

        train_keypoint_masks, test_keypoint_masks, static_keypoints = \
            NERFCO.extract_keypoints.get_keypoints_and_maps(
                args.keypoints_filename,
                i_test,
                i_train,
                filename_splits,
                H, W)

        args.number_of_keypoints = static_keypoints.shape[0]

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(
        args)

    if args.use_K:
        K_dict = {
            'K': K,
        }

        render_kwargs_train.update(K_dict)
        render_kwargs_test.update(K_dict)

    if args.learnable_embeddings_filename is not None:
        keypoint_embeddings = models['keypoint_embeddings']

        NERFCO.extract_keypoints.test_keypoint_masks(train_keypoint_masks, keypoint_embeddings[:])
        NERFCO.extract_keypoints.test_keypoint_masks(test_keypoint_masks, keypoint_embeddings[:])

        args.keypoint_embedding_size = keypoint_embeddings.shape[-1]
        train_keypoint_masks = train_keypoint_masks.ravel()
        print(f'loaded {args.keypoint_detector} keypoints: {keypoint_embeddings.shape} from: {args.learnable_embeddings_filename}')


    bds_dict = {
        'near': tf.cast(args.near, tf.float32),
        'far': tf.cast(args.far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Short circuit if only rendering out from trained model
    if args.render_only:
        # render_poses = poses[::args.testskip]
        render_test_data(args,
                         render_poses,
                         images,
                         i_test,
                         start,
                         render_kwargs_test,
                         hwf,
                         K if K is not None else None,
                         embeddings=keypoint_embeddings if enable_keypoints else None,
                         gt_keypoint_map=test_keypoint_masks if enable_keypoints else None)
        exit(0)

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(lrate,
                                                               decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    photometric_loss_function = img2mse
    if args.use_huber_loss:
        photometric_loss_function = tf.compat.v1.losses.huber_loss

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching.
        #
        # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where axis=1 is
        # interpreted as,
        #   axis=0: ray origin in world space
        #   axis=1: ray direction in world space
        #   axis=2: observed RGB color of pixel

        # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3]
        # for each pixel in the image. This stack() adds a new dimension.
        if args.use_K:
            print('get rays K')
            rays = [NERFCO.nerf_renderer.get_rays_tf_K(H, W, K, p) for p in poses[:, :3, :4]]
        else:
            print('get rays')
            rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]

        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None, ...]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], axis=0)  # train images only
        #
        if args.depth_from_camera and args.depth_loss:
            train_depths = np.stack([depth_maps[i] for i in i_train], axis=0).ravel()  # train images only

        if args.mask_directory and not args.white_bkgd:
            train_masks = np.stack([masks[i] for i in i_train], axis=0)
            pixel_train_masks = train_masks.ravel()
            if args.ray_masking:
                rays_rgb = rays_rgb[np.where(train_masks)]

        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)

        if args.keypoint_oversample:
            keypoint_mask = train_keypoint_masks.astype(np.bool)
            over_sample = train_keypoint_masks.shape[0] // np.sum(keypoint_mask) - 1
            keypoint_rays_rgb = rays_rgb[keypoint_mask]
            non_zero_keypoints = train_keypoint_masks[keypoint_mask]
            assert non_zero_keypoints[0] == train_keypoint_masks[np.argmax(keypoint_mask)], 'first true keypoint should match'

            # repeat rgb
            repeated_keypoint_rays_rgb = np.repeat(keypoint_rays_rgb, over_sample, axis=0)
            assert np.all(keypoint_rays_rgb[0] == repeated_keypoint_rays_rgb[0]), 'repeats should be the same'
            rays_rgb = np.concatenate([rays_rgb, repeated_keypoint_rays_rgb])

            # repeat keypoints
            repeated_keypoint_masks = np.repeat(non_zero_keypoints, over_sample, axis=0)
            assert repeated_keypoint_masks[0] == train_keypoint_masks[np.argmax(keypoint_mask)]
            train_keypoint_masks = np.concatenate([train_keypoint_masks, repeated_keypoint_masks])

            # check rays and keypoint masks are the same size
            assert rays_rgb.shape[0] == train_keypoint_masks.shape[0]

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.contrib.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))
    writer.set_as_default()
    i_batch = 0
    for i in range(start, N_iters):
        time0 = time.time()
        if i >= args.keypoint_iterations_start:
            keypoint_loss_coeff = args.keypoint_loss_coeff
        # Sample random ray batch

        if use_batching:
            # shuffle
            if i == start or i_batch >= rays_rgb.shape[0]:
                rays_rgb, permutations = shuffler(rays_rgb)
                if args.mask_directory and args.sigma_masking:
                    pixel_train_masks = shuffler(pixel_train_masks, permutations)

                if enable_keypoints:
                    train_keypoint_masks = shuffler(train_keypoint_masks, permutations)

                if args.depth_from_camera and args.depth_loss:
                    train_depths = shuffler(train_depths, permutations)

                i_batch = 0

            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1, 0, 2])


            if args.sigma_masking:
                in_mask_pixels_batch = pixel_train_masks[i_batch: i_batch + N_rand]

            if enable_keypoints:
                keypoint_masks_batch = train_keypoint_masks[i_batch: i_batch + N_rand]

            if args.depth_from_camera and args.depth_loss:
                train_depths_batch = train_depths[i_batch: i_batch + N_rand]

            i_batch += N_rand

            # batch_rays[i, n, xyz] = ray origin or direction, example_id, 3D position
            # target_s[n, rgb] = example_id, observed color.
            if args.depth_from_camera and args.depth_loss:
                batch_rays, target_s = (batch[0], batch[1], train_depths_batch), batch[2]
            else:
                batch_rays, target_s = batch[:2], batch[2]



        else:
            # Random from one image
            test_frame_number = np.random.choice(i_train)
            target = images[test_frame_number]
            pose = poses[test_frame_number, :3, :4]

            if N_rand is not None:
                if args.use_K:
                    rays_o, rays_d = NERFCO.nerf_renderer.get_rays_K(H, W, K, pose)
                else:
                    rays_o, rays_d = get_rays(H, W, focal, pose)

                if i < args.precrop_iters:
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H//2 - dH, H//2 + dH), 
                        tf.range(W//2 - dW, W//2 + dW), 
                        indexing='ij'), -1)
                    if i < 10:
                        print('precrop', dH, dW, coords[0,0], coords[-1,-1])
                else:
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H), tf.range(W), indexing='ij'), -1)
                coords = tf.reshape(coords, [-1, 2])
                select_inds = np.random.choice(
                    coords.shape[0], size=[N_rand], replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)

        #####  Core optimization loop  #####

        with tf.GradientTape() as nerf_gradient_tape, tf.GradientTape() as embedding_gradient_tape:

            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(
                H, W, focal, chunk=args.chunk, rays=batch_rays,
                verbose=i < 10, retraw=True, **render_kwargs_train)
            # Compute MSE loss between predicted and true RGB.

            if args.sigma_masking:
                loss = photometric_loss_function(rgb[in_mask_pixels_batch], target_s[in_mask_pixels_batch])
            else:
                loss = photometric_loss_function(rgb, target_s)

            if args.force_black_background:
                loss += 0.1 * photometric_loss_function(rgb[~in_mask_pixels_batch], 0.)

            if 'keypoint_map' in extras:
                keypoint_loss = NERFCO.nerf_keypoint_network.get_keypoint_loss(keypoint_embeddings,
                                                                               keypoint_masks_batch,
                                                                               extras['keypoint_map'])
                loss += keypoint_loss_coeff * keypoint_loss

                if args.learnable_embeddings_filename is not None:
                    embedding_loss = keypoint_loss_coeff * keypoint_embeddings.distance_correlation_loss(extras['xyz_map'],
                                                                                                         extras['keypoint_map'])

            trans = extras['raw'][..., -1]
            psnr = mse2psnr(loss)

            if args.sigma_masking:
                loss += 0.01 * tf.reduce_sum(acc[~in_mask_pixels_batch]) / N_rand

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                if args.sigma_masking:
                    img_loss0 = photometric_loss_function(extras['rgb0'][in_mask_pixels_batch], target_s[in_mask_pixels_batch])
                    psnr0 = mse2psnr(img_loss0)
                    loss += img_loss0 + 0.01 * tf.reduce_sum(extras['acc0'][~in_mask_pixels_batch]) / N_rand
                else:
                    img_loss0 = photometric_loss_function(extras['rgb0'], target_s)
                    psnr0 = mse2psnr(img_loss0)
                    loss += img_loss0

                if args.force_black_background:
                    loss += 0.1 * photometric_loss_function(extras['rgb0'][~in_mask_pixels_batch], 0.)

                if 'keypoint_map_0' in extras:
                    coarse_keypoint_loss = NERFCO.nerf_keypoint_network.get_keypoint_loss(keypoint_embeddings,
                                                                                          keypoint_masks_batch,
                                                                                          extras['keypoint_map_0'])
                    loss += keypoint_loss_coeff * coarse_keypoint_loss

                if args.depth_from_camera and args.depth_loss:
                    loss += tf.losses.huber_loss(train_depths_batch, extras['depth_map'])

        gradients = nerf_gradient_tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))

        if args.learnable_embeddings_filename is not None:
            gradients = embedding_gradient_tape.gradient(embedding_loss, models['keypoint_embeddings'].trainable_variables)
            optimizer.apply_gradients(zip(gradients, models['keypoint_embeddings'].trainable_variables))

        dt = time.time()-time0

        #####           end            #####

        # Rest is logging

        def save_weights(net, prefix, i):
            path = os.path.join(
                basedir, expname, '{}_{:06d}.npy'.format(prefix, i))
            np.save(path, net.get_weights())
            print('saved weights at', path)

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

        if i % args.i_video == 0 and i > 0:

            rgbs, test_extras = render_path(
                render_poses, hwf, args.chunk, render_kwargs_test)
            disps = test_extras['disp']
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(
                basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)), fps=30, quality=8)

            if args.use_viewdirs:
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(
                    render_poses, hwf, args.chunk, render_kwargs_test)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(
                basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test], hwf, args.chunk, render_kwargs_test,
                        gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:
            output_line = f'{expname} {i}, {psnr.numpy():.3g}, {loss.numpy():.3g} '
            # report loss
            if enable_keypoints:
                gt_kp_mask = keypoint_masks_batch.astype(bool)
                network_output = np.squeeze(extras['keypoint_map'])
                output_line += f'kp loss:{keypoint_loss:.2g} '
                if len(network_output.shape) == 1:
                    # Binary keypoint head: log the ground-truth keypoint count and
                    # pixel-wise agreement overall, on keypoints, and on background.
                    output_line += f'GT kp#:{int(np.sum(gt_kp_mask)):d} '\
                                   f'TP+TN:{np.sum(np.isclose(gt_kp_mask.astype(np.float32), network_output)) / network_output.shape[0]:.2g} '\
                                   f'TP:{np.sum(np.isclose(gt_kp_mask[gt_kp_mask].astype(np.float32), network_output[gt_kp_mask])) / np.sum(gt_kp_mask):.2g} '\
                                   f'TN:{np.sum(np.isclose(gt_kp_mask[~gt_kp_mask].astype(np.float32), network_output[~gt_kp_mask])) / np.sum(~gt_kp_mask):.2g} '

            print(output_line)
            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)

                if enable_keypoints:
                    tf.contrib.summary.histogram('keypoint output', network_output)
                    tf.contrib.summary.scalar('keypoint_loss', keypoint_loss)

            if i % args.i_img == 0:
                # report accuracy

                # Log a rendered validation view to Tensorboard
                test_frame_number = np.random.choice(i_val)
                target = images[test_frame_number]
                pose = poses[test_frame_number, :3, :4]

                rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
                                                **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))
                
                # Save out the validation image for Tensorboard-free monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(os.path.join(testimgdir, '{:06d}.png'.format(i)), to8b(rgb))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                    output_image = to8b(rgb)[tf.newaxis]
                    tf.contrib.summary.image('rgb', output_image)
                    tf.contrib.summary.image(
                        'disp', disp[tf.newaxis, ..., tf.newaxis])
                    tf.contrib.summary.image(
                        'acc', acc[tf.newaxis, ..., tf.newaxis])

                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])
                    if enable_keypoints:
                        # TODO: investigate why the train renders look less lossy than the test renders:
                        # the test renders highlight the entire object, whereas the train renders
                        # do not, yet still achieve a good loss
                        test_frame_in_split = np.where(i_val == test_frame_number)[0][0]
                        gt_keypoint_mask = test_keypoint_masks[test_frame_in_split]
                        gt_non_zeros_kp_mask = gt_keypoint_mask.astype(bool)
                        test_keypoint_loss = NERFCO.nerf_keypoint_network.get_keypoint_loss(keypoint_embeddings,
                                                                                            gt_keypoint_mask,
                                                                                            extras['keypoint_map'])
                        tf.contrib.summary.scalar('test_keypoint_loss', test_keypoint_loss)

                        if len(keypoint_embeddings.shape) == 1:
                            network_output = extras['keypoint_map']
                            network_output_image = np.reshape(network_output.numpy(), (H, W))
                            keypoint_accuracy_image = network_output_image == gt_non_zeros_kp_mask
                            tf.contrib.summary.histogram('keypointyness', network_output)
                            tf.contrib.summary.scalar('keypoint_acc',
                                                      np.sum(keypoint_accuracy_image) / (H*W))
                            tf.contrib.summary.scalar('non_zero_keypoint_acc',
                                                      np.sum(keypoint_accuracy_image[gt_non_zeros_kp_mask]) / np.sum(
                                                          gt_non_zeros_kp_mask))

                            tf.contrib.summary.image('keypoint_image_acc',
                                                     to8b(keypoint_accuracy_image)[tf.newaxis, ..., tf.newaxis])

                        else:
                            inferred_keypoint_image, keypoint_accuracy_image = \
                                NERFCO.nerf_keypoint_network.create_embedded_keypoint_image(extras['keypoint_map'],
                                                                                            keypoint_embeddings,
                                                                                            gt_keypoint_mask)

                            tf.contrib.summary.scalar('keypoint_acc', np.sum(keypoint_accuracy_image) / (H * W))
                            tf.contrib.summary.scalar('non_zero_keypoint_acc',
                                                      np.sum(keypoint_accuracy_image[gt_non_zeros_kp_mask]) / np.sum(gt_non_zeros_kp_mask))

                            tf.contrib.summary.image('keypoint_image_acc',
                                                     to8b(keypoint_accuracy_image)[tf.newaxis, ..., tf.newaxis])
                            tf.contrib.summary.image('keypoint_image',
                                                     to8b(inferred_keypoint_image)[tf.newaxis, ..., tf.newaxis])

                if args.N_importance > 0:
                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image(
                            'rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image(
                            'disp0', extras['disp0'][tf.newaxis, ..., tf.newaxis])
                        tf.contrib.summary.image(
                            'z_std', extras['z_std'][tf.newaxis, ..., tf.newaxis])

        global_step.assign_add(1)
Example #7
0
# expname = 'fern_test'

config = os.path.join(basedir, expname, 'config.txt')
print('Args:')
print(open(config, 'r').read())
parser = run_nerf_fast.config_parser()

weights_name = 'model_200000.npy'
# weights_name = 'model_000700.npy'
args = parser.parse_args('--config {} --ft_path {}'.format(
    config, os.path.join(basedir, expname, weights_name)))
print('loaded args')

images, poses, bds, render_poses, i_test = load_llff_data(
    args.datadir,
    args.factor,
    recenter=True,
    bd_factor=.75,
    spherify=args.spherify)
H, W, focal = poses[0, :3, -1].astype(np.float32)
poses = poses[:, :3, :4]
print(f"NUM IMAGES: {poses.shape[0]}")

H = int(H)
W = int(W)
hwf = [H, W, focal]

images = images.astype(np.float32)
poses = poses.astype(np.float32)

if args.no_ndc:
    near = tf.reduce_min(bds) * .9
Example #8
0
    def load_llff_dataset(
        render_kwargs_train_=None,
        render_kwargs_test_=None,
        return_nerf_volume_extent=False,
    ):

        datadir = args.datadir
        factor = args.factor
        spherify = args.spherify
        bd_factor = args.bd_factor

        # actual loading
        images, poses, bds, render_poses, i_test = load_llff_data(
            datadir,
            factor=factor,
            recenter=True,
            bd_factor=bd_factor,
            spherify=spherify,
        )
        extras = _get_multi_view_helper_mappings(images.shape[0])

        # poses
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]  # N x 3 x 4
        all_rotations = poses[:, :3, :3]  # N x 3 x 3
        all_translations = poses[:, :3, 3]  # N x 3

        render_poses = render_poses[:, :3, :4]
        render_rotations = render_poses[:, :3, :3]
        render_translations = render_poses[:, :3, 3]

        # splits
        i_test = []  # [i_test]
        if args.test_block_size > 0 and args.train_block_size > 0:
            print("splitting timesteps into training (" +
                  str(args.train_block_size) + ") and test (" +
                  str(args.test_block_size) + ") blocks")
            num_timesteps = len(dataset_extras["raw_timesteps"])
            test_timesteps = np.concatenate([
                np.arange(
                    min(num_timesteps, blocks_start + args.train_block_size),
                    min(
                        num_timesteps,
                        blocks_start + args.train_block_size +
                        args.test_block_size,
                    ),
                ) for blocks_start in np.arange(
                    0, num_timesteps, args.train_block_size +
                    args.test_block_size)
            ])
            i_test = [
                imageid for imageid, timestep in enumerate(
                    dataset_extras["imageid_to_timestepid"])
                if timestep in test_timesteps
            ]

        i_test = np.array(i_test)
        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])
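        # Worked example of the block split above (illustrative values, not
        # from the original config): with train_block_size=3,
        # test_block_size=1 and 8 raw timesteps, np.arange(0, 8, 4) gives
        # block starts [0, 4], so test_timesteps becomes [3, 7]; timesteps
        # 0-2 and 4-6 stay in training, and every image whose timestep falls
        # in [3, 7] lands in i_test.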

        # near, far
        # if args.no_ndc:
        near = np.ndarray.min(bds) * 0.9
        far = np.ndarray.max(bds) * 1.0
        # else:
        #    near = 0.
        #    far = 1.
        bds_dict = {
            "near": near,
            "far": far,
        }
        if render_kwargs_train_ is not None:
            render_kwargs_train_.update(bds_dict)
        if render_kwargs_test_ is not None:
            render_kwargs_test_.update(bds_dict)

        if return_nerf_volume_extent:
            ray_params = checkpoint_dict["ray_params"]
            min_point, max_point = determine_nerf_volume_extent(
                get_parallelized_render_function(),
                poses,
                hwf[0],
                hwf[1],
                hwf[2],
                ray_params,
                render_kwargs_test,
            )
            extras["min_nerf_volume_point"] = min_point.detach()
            extras["max_nerf_volume_point"] = max_point.detach()

        return (
            images,
            poses,
            all_rotations,
            all_translations,
            bds,
            render_poses,
            render_rotations,
            render_translations,
            i_train,
            i_val,
            i_test,
            near,
            far,
            extras,
        )
Example #9
0
def train():

    parser = config_parser()
    args = parser.parse_args()

    # Load data
    K = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        full_original_track_poses = torch.from_numpy(poses).float().to(device)

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.

        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'custom':
        images, poses, render_poses, full_original_track_poses, hwf, i_split = load_custom_data(
            args.scene, args.half_res, args.testskip, args.inv)
        print('Loaded RealEstate', images.shape, render_poses.shape, hwf,
              render_poses.shape)
        i_train, i_val, i_test = i_split

        #use ndc
        near = 0.05
        far = args.far

        images = images[..., :3]

    elif args.dataset_type == 'LINEMOD':
        images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(
            args.datadir, args.half_res, args.testskip)
        print(
            f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}'
        )
        print(f'[CHECK HERE] near: {near}, far: {far}.')
        i_train, i_val, i_test = i_split

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)

        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # images: all of the images needed for training and evaluation
    # poses (c2w): camera-to-world poses, indexed the same way as images
    # i_train, i_val, i_test: indices of the images/poses in the train, val, and test splits
    # render_poses (c2w): poses used to render a novel camera-track video
    # hwf: height, width, focal length (in pixels, not normalized), used to construct K if it wasn't loaded

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if K is None:
        modifier = 0
        K = np.array([[focal + modifier * W, 0, 0.5 * W],
                      [0, focal + modifier * H, 0.5 * H], [0, 0, 1]])
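    # Sketch (an assumption about the usual pinhole convention, not stated in
    # this file): K is the standard intrinsics matrix, so a camera-space
    # point (X, Y, Z) projects to pixel coordinates roughly as
    #   u = focal * X / Z + 0.5 * W
    #   v = focal * Y / Z + 0.5 * H
    # i.e. both axes share one focal length and the principal point sits at
    # the image center.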

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    basedir = os.path.join(basedir, dset_to_prefix[args.dataset_type])
    expname = args.expname

    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # render_test switches to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format(
                    'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _ = render_path(render_poses,
                                  hwf,
                                  K,
                                  args.chunk,
                                  render_kwargs_test,
                                  gt_imgs=images,
                                  savedir=testsavedir,
                                  render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs),
                             fps=30,
                             quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:, :3, :4]],
                        0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]],
                                  1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb,
                                [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            0)  # train images only
        rays_rgb = np.reshape(rays_rgb,
                              [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    # Move training data to GPU
    if use_batching:
        images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = 200000 + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            target = torch.Tensor(target).to(device)
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(
                    H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H // 2 - dH, H // 2 + dH - 1,
                                           2 * dH),
                            torch.linspace(W // 2 - dW, W // 2 + dW - 1,
                                           2 * dW)), -1)
                    if i == start:
                        print(
                            f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}"
                        )
                else:
                    coords = torch.stack(
                        torch.meshgrid(torch.linspace(0, H - 1, H),
                                       torch.linspace(0, W - 1, W)),
                        -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand],
                                               replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0],
                                  select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H,
                                        W,
                                        K,
                                        chunk=args.chunk,
                                        rays=batch_rays,
                                        verbose=i < 10,
                                        retraw=True,
                                        **render_kwargs_train)

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][..., -1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################
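        # Illustrative check of the schedule above (numbers are an assumption,
        # not the original config): with lrate=5e-4 and lrate_decay=250,
        # decay_steps is 250,000, so after 250k steps the learning rate is
        # 5e-4 * 0.1 ** 1 = 5e-5; it shrinks smoothly in between because
        # global_step / decay_steps is a float ratio, not a step function.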

        dt = time.time() - time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####

        # Rest is logging
        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save(
                {
                    'global_step':
                    global_step,
                    'network_fn_state_dict':
                    render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict':
                    render_kwargs_train['network_fine'].state_dict(),
                    'optimizer_state_dict':
                    optimizer.state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            compare_dir = os.path.join(
                '/data/vision/billf/intrinsic/neural-render/ericqian/results/',
                dset_to_prefix[args.dataset_type])
            novel_compare_dir = os.path.join(compare_dir, args.scene, 'novel',
                                             'nerf')
            os.makedirs(novel_compare_dir, exist_ok=True)
            with torch.no_grad():
                rgbs, disps = render_path(render_poses,
                                          hwf,
                                          K,
                                          args.chunk,
                                          render_kwargs_test,
                                          savedir=novel_compare_dir)
            moviebase = os.path.join(
                basedir, expname, '{}_renderposes_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)

            moviebase_compare = os.path.join(compare_dir, args.scene, 'novel')
            imageio.mimwrite(moviebase_compare + '/nerf_rgb.mp4',
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase_compare + '/nerf_disp.mp4',
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)

            # Turn on testing mode
            original_compare_dir = os.path.join(compare_dir, args.scene,
                                                'original', 'nerf')
            os.makedirs(original_compare_dir, exist_ok=True)
            with torch.no_grad():
                rgbs, disps = render_path(full_original_track_poses,
                                          hwf,
                                          K,
                                          args.chunk,
                                          render_kwargs_test,
                                          savedir=original_compare_dir)
            moviebase = os.path.join(
                basedir, expname, '{}_fulloriginal_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)

            moviebase_compare = os.path.join(compare_dir, args.scene,
                                             'original')
            imageio.mimwrite(moviebase_compare + '/nerf_rgb.mp4',
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase_compare + '/nerf_disp.mp4',
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            with torch.no_grad():
                render_path(torch.Tensor(poses[i_test]).to(device),
                            hwf,
                            K,
                            args.chunk,
                            render_kwargs_test,
                            gt_imgs=images[i_test],
                            savedir=testsavedir)

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
        """
            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))

            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)


            if i%args.i_img==0:

                # Log a rendered validation view to Tensorboard
                img_i=np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3,:4]
                with torch.no_grad():
                    rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
                                                        **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):

                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
                    tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])

                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])


                if args.N_importance > 0:

                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
                        tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
        """

        global_step += 1
Example #10
0
def cloud_size_vs_performance():
    basedir = './logs'
    expname = 'fern_example'

    config = os.path.join(basedir, expname, 'config.txt')
    print('Args:')
    print(open(config, 'r').read())
    parser = run_nerf.config_parser()

    weights_name = 'model_200000.npy'
    args = parser.parse_args('--config {} --ft_path {}'.format(config, os.path.join(basedir, expname, weights_name)))
    print('loaded args')

    images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 
                                                            recenter=True, bd_factor=.75, 
                                                            spherify=args.spherify)
    H, W, focal = poses[0,:3,-1].astype(np.float32)
    poses = poses[:, :3, :4]
    H = int(H)
    W = int(W)
    hwf = [H, W, focal]

    images = images.astype(np.float32)
    poses = poses.astype(np.float32)

    near = 0.
    far = 1.

    if not isinstance(i_test, list):
        i_test = [i_test]

    if args.llffhold > 0:
        print('Auto LLFF holdout,', args.llffhold)
        i_test = np.arange(images.shape[0])[::args.llffhold]

    _, render_kwargs, start, grad_vars, models = run_nerf.create_nerf(args)
    to_use = i_test[0]

    bds_dict = {
        'near' : tf.cast(near, tf.float32),
        'far' : tf.cast(far, tf.float32),
    }
    render_kwargs.update(bds_dict)

    print('Render kwargs:')
    pprint.pprint(render_kwargs)

    res_dir = "./cloud_size_test"

    res = {}
    res['cloud_size'] = []
    res['mse'] = []
    res['psnr'] = []
    res['time'] = []

    for i in [1,2,4,8,16,32]:
        print(f'Running with cloud downsampled {i}x')
        start_time = time.time()
        ret_vals = run_nerf.render(H, W, focal, c2w=poses[to_use], pc=True, cloudsize=i, **render_kwargs)
        end_time = time.time()
        img = np.clip(ret_vals[0],0,1)
        mse = run_nerf.img2mse(images[to_use], img)
        psnr = run_nerf.mse2psnr(mse)
        res['cloud_size'].append((17 * H * W) // (i * i))
        res['mse'].append(float(mse))
        res['psnr'].append(float(psnr))
        res['time'].append(end_time - start_time)
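    # Note on the metric above (standard definition, assumed to match
    # run_nerf.mse2psnr): PSNR = -10 * log10(MSE), so an MSE of 1e-3
    # corresponds to 30 dB and halving the MSE buys roughly 3 dB.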

    # a = [1,2,4,8,16,32]
    # b = [1/x for x in a]

    # make plots
    # cs vs psnr
    fig, ax = plt.subplots(1,1)
    fig.suptitle('PSNR vs Point Cloud Size')
    ax.set_xlabel('Cloud Size')
    ax.set_ylabel('PSNR')
    plt.xscale('log')
    ax.plot(res['cloud_size'],res['psnr'])
    plt.savefig(os.path.join(res_dir, 'cs_psnr.png'))

    fig, ax = plt.subplots(1,1)
    fig.suptitle('PSNR vs Running Time')
    ax.set_xlabel('Time')
    ax.set_ylabel('PSNR')
    plt.xscale('log')
    ax.plot(res['time'],res['psnr'])
    plt.savefig(os.path.join(res_dir, 'time_psnr.png'))

    fig, ax = plt.subplots(1,1)
    fig.suptitle('Running Time vs Cloud Size')
    ax.set_xlabel('Cloud Size')
    ax.set_ylabel('Running Time')
    plt.xscale('log')
    plt.yscale('log')
    ax.plot(res['cloud_size'],res['time'])
    plt.savefig(os.path.join(res_dir, 'cs_time.png'))
    
    with open(os.path.join(res_dir, 'results.txt'), 'w') as outfile:
        json.dump(res,outfile)
Example #11
0
def train_mu():
    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data

    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)
    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    mu_workdir = os.path.join(basedir, expname + '_mu')
    os.makedirs(mu_workdir, exist_ok=True)

    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    nerf_render_kwargs_train, nerf_render_kwargs_test, models = create_nerf(
        args)
    start, mu_grad_vars, model_mu, mu_embed_fn, mu_embeddirs_fn = create_mu_model(
        mu_workdir, args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    nerf_render_kwargs_train.update(bds_dict)
    nerf_render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        testsavedir = os.path.join(
            basedir, expname, 'renderonly_{}_{:06d}'.format(
                'test' if args.render_test else 'path', start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses,
                              hwf,
                              args.chunk,
                              nerf_render_kwargs_train,
                              gt_imgs=images,
                              savedir=testsavedir,
                              render_factor=args.render_factor)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs),
                         fps=30,
                         quality=8)

        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(
            lrate, decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer
    models['model_mu'] = [model_mu, mu_embed_fn, mu_embeddirs_fn]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching.
        #
        # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where axis=1 is
        # interpreted as,
        #   axis=0: ray origin in world space
        #   axis=1: ray direction in world space
        #   axis=2: observed RGB color of pixel
        print('get rays')
        # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3]
        # for each pixel in the image. This stack() adds a new dimension.
        rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None, ...]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            axis=0)  # train images only
        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0
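        # Sanity-check sketch (an added assumption about shapes, matching the
        # comments above): after the reshape, every row of rays_rgb packs one
        # pixel's (ray origin, ray direction, observed RGB) triple, so the
        # array should hold exactly H * W rows per training image.
        assert rays_rgb.shape == (len(i_train) * H * W, 3, 3)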

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))

    for i in range(start, N_iters):
        time0 = time.time()

        # Sample random ray batch

        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1, 0, 2])

            # batch_rays[i, n, xyz] = ray origin or direction, example_id, 3D position
            # target_s[n, rgb] = example_id, observed color.
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                np.random.shuffle(rays_rgb)
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, pose)
                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = tf.stack(
                        tf.meshgrid(tf.range(H // 2 - dH, H // 2 + dH),
                                    tf.range(W // 2 - dW, W // 2 + dW),
                                    indexing='ij'), -1)
                    if i < 10:
                        print('precrop', dH, dW, coords[0, 0], coords[-1, -1])
                else:
                    coords = tf.stack(
                        tf.meshgrid(tf.range(H), tf.range(W), indexing='ij'),
                        -1)
                coords = tf.reshape(coords, [-1, 2])
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand],
                                               replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)

        #####  Core optimization loop  #####
        _, _, _, extras = nr.render(H,
                                    W,
                                    focal,
                                    chunk=args.chunk,
                                    rays=batch_rays,
                                    verbose=i < 10,
                                    retraw=True,
                                    **nerf_render_kwargs_train)

        mu_exp = extras['mu_exp']
        pts = extras['pts']
        viewdirs = extras['viewdirs']

        with tf.GradientTape(persistent=False) as tape:
            mu_out = run_mu_model(pts, viewdirs, model_mu, mu_embed_fn,
                                  mu_embeddirs_fn)
            mu_out = tf.reshape(mu_out, mu_exp.shape)
            mu_out = tf.math.tanh(tf.nn.relu(mu_out))
            # print(mu_out.shape, mu_exp.shape)
            mu_loss = img2mse(mu_exp, mu_out)

        gradients = tape.gradient(mu_loss, mu_grad_vars)
        optimizer.apply_gradients(zip(gradients, mu_grad_vars))

        dt = time.time() - time0

        #####           end            #####

        # Rest is logging

        def save_mu_weights(net, i):
            path = os.path.join(mu_workdir, 'model_mu_{:06d}.npy'.format(i))
            np.save(path, net.get_weights())
            print('saved weights at', path)

        if i % args.i_weights == 0:
            save_mu_weights(model_mu, i)

        if i % args.i_print == 0 or i < 10:
            print(expname + '_mu', i, mu_loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with writer.as_default():
                tf.summary.scalar('mu_loss', mu_loss, step=i + 1)

        global_step.assign_add(1)
Example #12
0
def train():

    parser = config_parser()
    args = parser.parse_args()

    # Load data

    if args.dataset_type == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=0.75,
            spherify=args.spherify,
        )
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print("Loaded llff", images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print("Auto LLFF holdout,", args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print("DEFINING BOUNDS")
        if args.no_ndc:
            near = np.ndarray.min(bds) * 0.9
            far = np.ndarray.max(bds) * 1.0

        else:
            near = 0.0
            far = 1.0
        print("NEAR FAR", near, far)

    elif args.dataset_type == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print("Loaded blender", images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.0
        far = 6.0

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1.0 -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == "deepvoxels":

        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)

        print("Loaded deepvoxels", images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.0
        far = hemi_R + 1.0

    else:
        print("Unknown dataset type", args.dataset_type, "exiting")
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, "args.txt")
    with open(f, "w") as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write("{} = {}\n".format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, "config.txt")
        with open(f, "w") as file:
            file.write(open(args.config, "r").read())
    # data preparation done

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)

    global_step = start

    bds_dict = {
        "near": near,
        "far": far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print("RENDER ONLY")
        with torch.no_grad():
            if args.render_test:
                # render_test switches to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(
                basedir,
                expname,
                "renderonly_{}_{:06d}".format(
                    "test" if args.render_test else "path", start),
            )
            os.makedirs(testsavedir, exist_ok=True)
            print("test poses shape", render_poses.shape)

            rgbs, _ = render_path(
                render_poses,
                hwf,
                args.chunk,
                render_kwargs_test,
                gt_imgs=images,
                savedir=testsavedir,
                render_factor=args.render_factor,
            )
            print("Done rendering", testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, "video.mp4"),
                             to8b(rgbs),
                             fps=30,
                             quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print("get rays")
        rays = np.stack(
            [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
            0)  # [N, ro+rd, H, W, 3]
        print("done, concats")
        rays_rgb = np.concatenate([rays, images[:, None]],
                                  1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb,
                                [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            0)  # train images only
        rays_rgb = np.reshape(rays_rgb,
                              [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print("shuffle rays")
        np.random.shuffle(rays_rgb)

        print("done")
        i_batch = 0

    # Move training data to GPU
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = 200000 + 1
    print("Begin")
    print("TRAIN views are", i_train)
    print("TEST views are", i_test)
    print("VAL views are", i_val)

    # Summary writers
    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]  # get one image
            pose = poses[img_i, :3, :4]  # get the pose of the image

            if N_rand is not None:
                # for one image there are H x W camera rays, all sharing the
                # same origin (rays_o) but with different directions (rays_d)
                rays_o, rays_d = get_rays(
                    H, W, focal, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H // 2 - dH, H // 2 + dH - 1,
                                           2 * dH),
                            torch.linspace(W // 2 - dW, W // 2 + dW - 1,
                                           2 * dW),
                        ),
                        -1,
                    )
                    if i == start:
                        print(
                            f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}"
                        )
                else:
                    coords = torch.stack(
                        torch.meshgrid(torch.linspace(0, H - 1, H),
                                       torch.linspace(0, W - 1, W)),
                        -1,
                    )  # (H, W, 2)
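                # Worked example of the center crop above (illustrative values,
                # assuming precrop_frac=0.5 on a 378x504 image): dH = 94 and
                # dW = 126, so early iterations sample rays only from the
                # central 188x252 window before falling back to the full grid.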

                coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand],
                                               replace=False)  # (N_rand,)
                # randomly select N_rand pixel coordinates (row, col) and keep only those camera rays
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                # get the RGB values of the selected pixels from the original image
                target_s = target[select_coords[:, 0],
                                  select_coords[:, 1]]  # (N_rand, 3)
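                # Sketch of what get_rays is expected to return (an assumption
                # about the usual pinhole convention, not taken from this
                # repo): for pixel (row, col) the ray origin is the camera
                # center from pose[:, 3] and the direction is roughly
                #   dir_cam = ((col - 0.5*W) / focal, -(row - 0.5*H) / focal, -1)
                # rotated into world space by pose[:, :3].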

        #####  Core optimization loop  #####
        # in this iteration, optimize the model
        rgb, disp, acc, extras = render(
            H,
            W,
            focal,
            chunk=args.chunk,
            rays=batch_rays,
            verbose=i < 10,
            retraw=True,
            **render_kwargs_train,
        )

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras["raw"][..., -1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if "rgb0" in extras:
            img_loss0 = img2mse(extras["rgb0"], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lrate
        ################################

        dt = time.time() - time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####

        # Rest is logging
        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, "{:06d}.tar".format(i))
            torch.save(
                {
                    "global_step":
                    global_step,
                    "network_fn_state_dict":
                    render_kwargs_train["network_fn"].state_dict(),
                    "network_fine_state_dict":
                    render_kwargs_train["network_fine"].state_dict(),
                    "optimizer_state_dict":
                    optimizer.state_dict(),
                },
                path,
            )
            print("Saved checkpoints at", path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, args.chunk,
                                          render_kwargs_test)
            print("Done, saving", rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname,
                                     "{}_spiral_{:06d}_".format(expname, i))
            imageio.mimwrite(moviebase + "rgb.mp4",
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase + "disp.mp4",
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)

            # if args.use_viewdirs:
            #     render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
            #     with torch.no_grad():
            #         rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            #     render_kwargs_test['c2w_staticcam'] = None
            #     imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       "testset_{:06d}".format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print("test poses shape", poses[i_test].shape)
            with torch.no_grad():
                render_path(
                    torch.Tensor(poses[i_test]).to(device),
                    hwf,
                    args.chunk,
                    render_kwargs_test,
                    gt_imgs=images[i_test],
                    savedir=testsavedir,
                )
            print("Saved test set")

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
        """
            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))

            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)


            if i%args.i_img==0:

                # Log a rendered validation view to Tensorboard
                img_i=np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3,:4]
                with torch.no_grad():
                    rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
                                                        **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):

                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
                    tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])

                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])


                if args.N_importance > 0:

                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
                        tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
        """

        global_step += 1
Example #13
def train():

    parser = config_parser()
    args = parser.parse_args()

    # Multi-GPU
    args.n_gpus = torch.cuda.device_count()
    print("Using {} GPU(s).".format(args.n_gpus))

    # Load data
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test, gt_depths = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.

        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)
        DEPTH_RATIO = 0.00335

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)

        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # render_test switches to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format(
                    'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _, _ = render_path(render_poses,
                                     hwf,
                                     args.chunk,
                                     render_kwargs_test,
                                     gt_imgs=images,
                                     savedir=testsavedir,
                                     render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs),
                             fps=30,
                             quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack(
            [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
            0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]],
                                  1)  # [N, ro+rd+rgb, H, W, 3]
        # [N, ro+rd+rgb+depth, H, W, 3]
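        # Tile the single-channel ground-truth depth to 3 channels so it can be
        # concatenated with the (H, W, 3) ray-origin / ray-direction / RGB slices.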
        gt_depths_channeled = np.zeros(
            (gt_depths.shape[0], gt_depths.shape[1], gt_depths.shape[2], 3))
        for i in range(gt_depths.shape[0]):
            gt_depths_channeled[i] = np.stack([gt_depths[i]] * 3, 2)
        rays_rgb = np.concatenate(
            [rays_rgb, gt_depths_channeled[:, None, ...]], 1)
        rays_rgb = np.transpose(
            rays_rgb, [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb+depth, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            0)  # train images only
        rays_rgb = np.reshape(rays_rgb,
                              [-1, 4, 3])  # [(N-1)*H*W, ro+rd+rgb+depth, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    # Move training data to GPU
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    gt_depths = torch.Tensor(gt_depths).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = 200000 + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            # batch_rays: ray origins + directions, target_s: target RGB colors,
            # target_depth: ground-truth depth per ray (first channel of the tiled depth row).
            batch_rays, target_s, target_depth = batch[:2], batch[2], batch[
                3, :, 0]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            depth_i = gt_depths[img_i]
            depth_i = depth_i * DEPTH_RATIO  # scale down the depth
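            # DEPTH_RATIO (set above) rescales the raw sensor depth into the same
            # units as the recentered LLFF poses; the exact value is presumably
            # hand-tuned for this scene.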
            pix_T_cam = pack_intrinsics(4.39947516e+02, 4.39947516e+02, 240,
                                        320)
            # distortion = 2.78915786e-02
            xyz_i = depth2pointcloud(
                depth_i.unsqueeze(0).unsqueeze(0),
                pix_T_cam)  #unproject to get pointcloud
            origin_T_camX = eye_4x4(1)
            origin_T_camX[:, :3, :4] = poses[img_i:img_i + 1, :4, :4]
            xyz_w_i = apply_4x4(
                origin_T_camX,
                xyz_i)  # get the world coordinates (NeRF works in C2W)

            if N_rand is not None:
                rays_o, rays_d = get_rays(
                    H, W, focal, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H // 2 - dH, H // 2 + dH - 1,
                                           2 * dH),
                            torch.linspace(W // 2 - dW, W // 2 + dW - 1,
                                           2 * dW)), -1)
                    if i == start:
                        print(
                            "[Config] Center cropping of size {} x {} is enabled until iter {}"
                            .format(2 * dH, 2 * dW, args.precrop_iters))
                else:
                    coords = torch.stack(
                        torch.meshgrid(torch.linspace(0, H - 1, H),
                                       torch.linspace(0, W - 1, W)),
                        -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand],
                                               replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)  # (2, N_rand, 3)
                target_s = target[select_coords[:, 0],
                                  select_coords[:, 1]]  # (N_rand, 3)
                target_depth = depth_i[select_coords[:, 0],
                                       select_coords[:, 1]]  # (N_rand, 1)
                target_xyz = xyz_w_i[0, select_inds]  # (1, N_rand, 3)
                target_xyz = target_xyz[
                    target_depth >
                    0.]  # Only select those points where depth is valid

        #####  Core optimization loop  #####
        rgb, disp, acc, depth, extras = render(H,
                                               W,
                                               focal,
                                               chunk=args.chunk,
                                               rays=batch_rays,
                                               verbose=i < 10,
                                               retraw=True,
                                               **render_kwargs_train)

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][..., -1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        depth_mask = torch.ones(target_depth[target_depth != 0.].shape)
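        # Diagnostic only: average ratio of rendered depth to ground-truth depth
        # over rays with valid (non-zero) depth; logged to TensorBoard below.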
        avg_depth_ratio = reduce_masked_mean(
            depth[target_depth != 0.] / target_depth[target_depth != 0.],
            depth_mask)

        # st()
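        # Free-space / occupancy supervision: fill_ray_single appears to sample
        # points along each ray up to the observed surface; all of them are
        # labelled free (0) except the last, surface point (1), and the density
        # predictions at those points are trained with binary cross-entropy.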
        free_occ_xyz = fill_ray_single(target_xyz)
        target_labels = torch.zeros(
            (free_occ_xyz.shape[0], free_occ_xyz.shape[1], 1), device='cuda')
        target_labels[:, -1] = 1  # last point is occupied
        free_occ_xyz = torch.reshape(free_occ_xyz, (-1, 3)).unsqueeze(1)
        target_labels = torch.reshape(target_labels, (-1, 1))
        rays_to_pass = batch_rays[:, target_depth > 0., :]  # (2, N, 3)
        rays_to_pass = rays_to_pass.repeat(1, 100, 1)  # (2, N*samps, 3)
        sigma = eval_depth(
            pts=free_occ_xyz,
            rays=rays_to_pass,
            network_fn=render_kwargs_train['network_fn'],
            network_query_fn=render_kwargs_train['network_query_fn'])
        occfreespace_loss = F.binary_cross_entropy_with_logits(
            sigma, target_labels)
        loss += occfreespace_loss

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)
            avg_depth_ratio0 = reduce_masked_mean(
                extras['depth0'][target_depth != 0.] /
                target_depth[target_depth != 0.], depth_mask)

        loss.backward()
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
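        # Exponential decay: lr(step) = lrate * 0.1 ** (step / (lrate_decay * 1000)),
        # i.e. the learning rate drops by 10x every lrate_decay * 1000 iterations.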
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        dt = time.time() - time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####

        # Rest is logging
        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save(
                {
                    'global_step':
                    global_step,
                    'network_fn_state_dict':
                    render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict':
                    render_kwargs_train['network_fine'].state_dict(),
                    'optimizer_state_dict':
                    optimizer.state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
                rgbs, disps, depths = render_path(render_poses, hwf,
                                                  args.chunk,
                                                  render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape, depths.shape)
            moviebase = os.path.join(basedir, expname,
                                     '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase + 'depth.mp4',
                             to8b(depths / np.max(depths)),
                             fps=30,
                             quality=8)

            # if args.use_viewdirs:
            #     render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
            #     with torch.no_grad():
            #         rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            #     render_kwargs_test['c2w_staticcam'] = None
            #     imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            with torch.no_grad():
                render_path(torch.Tensor(poses[i_test]).to(device),
                            hwf,
                            args.chunk,
                            render_kwargs_test,
                            gt_imgs=images[i_test],
                            savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0:
            tqdm.write("[TRAIN] Iter: {} Loss: {}  PSNR: {}".format(
                i, loss.item(), psnr.item()))
            writer.add_scalar('loss', loss, i)
            writer.add_scalar('img_loss', img_loss, i)
            writer.add_scalar('occ_loss', occfreespace_loss, i)
            writer.add_scalar('psnr', psnr, i)
            writer.add_histogram('tran', trans, i)
            writer.add_scalar('avg_depth_ratio', avg_depth_ratio, i)
            if args.N_importance > 0:
                writer.add_scalar('psnr0', psnr0, i)

            if i % args.i_img == 0:
                # Log a rendered validation view to Tensorboard
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]
                with torch.no_grad():
                    rgb, disp, acc, depth, extras = render(
                        H,
                        W,
                        focal,
                        chunk=args.chunk,
                        c2w=pose,
                        **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))

                writer.add_image('rgb',
                                 to8b(rgb.cpu().numpy()),
                                 i,
                                 dataformats='HWC')
                writer.add_image('disp', disp.unsqueeze(0), i)
                writer.add_image('acc', acc.unsqueeze(0), i)
                writer.add_image('depth', depth.unsqueeze(0), i)

                writer.add_scalar('psnr_holdout', psnr, i)
                writer.add_image('rgb_holdout', target, i, dataformats='HWC')

        global_step += 1
Example #14
def train():

    parser = config_parser()
    args = parser.parse_args()

    # Load data

    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.

        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)

        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # render_test switches to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format(
                    'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _ = render_path(render_poses,
                                  hwf,
                                  args.chunk,
                                  render_kwargs_test,
                                  gt_imgs=images,
                                  savedir=testsavedir,
                                  render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs),
                             fps=30,
                             quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack(
            [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
            0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]],
                                  1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb,
                                [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            0)  # train images only
        rays_rgb = np.reshape(rays_rgb,
                              [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    # Move training data to GPU
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    start = start + 1
    for i in trange(start, args.N_iters + 1):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(
                    H, W, focal, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H // 2 - dH, H // 2 + dH - 1,
                                           2 * dH),
                            torch.linspace(W // 2 - dW, W // 2 + dW - 1,
                                           2 * dW)), -1)
                    if i == start:
                        print(
                            f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}"
                        )
                else:
                    coords = torch.stack(
                        torch.meshgrid(torch.linspace(0, H - 1, H),
                                       torch.linspace(0, W - 1, W)),
                        -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand],
                                               replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0],
                                  select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H,
                                        W,
                                        focal,
                                        chunk=args.chunk,
                                        rays=batch_rays,
                                        verbose=i < 10,
                                        retraw=True,
                                        **render_kwargs_train)

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][..., -1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        dt = time.time() - time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####

        # Rest is logging
        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save(
                {
                    'global_step':
                    global_step,
                    'network_fn_state_dict':
                    render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict':
                    render_kwargs_train['network_fine'].state_dict(),
                    'optimizer_state_dict':
                    optimizer.state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, args.chunk,
                                          render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname,
                                     '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)

            # if args.use_viewdirs:
            #     render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
            #     with torch.no_grad():
            #         rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            #     render_kwargs_test['c2w_staticcam'] = None
            #     imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            with torch.no_grad():
                render_path(torch.Tensor(poses[i_test]).to(device),
                            hwf,
                            args.chunk,
                            render_kwargs_test,
                            gt_imgs=images[i_test],
                            savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")

        global_step += 1
Example #15
def get_data():
    basedir = './logs'
    expname = 'fern_example'

    config = os.path.join(basedir, expname, 'config.txt')
    print('Args:')
    print(open(config, 'r').read())
    parser = run_nerf.config_parser()

    weights_name = 'model_200000.npy'
    args = parser.parse_args('--config {} --ft_path {}'.format(config, os.path.join(basedir, expname, weights_name)))
    print('loaded args')

    images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 
                                                            recenter=True, bd_factor=.75, 
                                                            spherify=args.spherify)
    H, W, focal = poses[0,:3,-1].astype(np.float32)
    poses = poses[:, :3, :4]
    H = int(H)
    W = int(W)
    hwf = [H, W, focal]

    images = images.astype(np.float32)
    poses = poses.astype(np.float32)

    near = 0.
    far = 1.

    if not isinstance(i_test, list):
        i_test = [i_test]

    if args.llffhold > 0:
        print('Auto LLFF holdout,', args.llffhold)
        i_test = np.arange(images.shape[0])[::args.llffhold]

    _, render_kwargs, start, grad_vars, models = run_nerf.create_nerf(args)

    bds_dict = {
        'near' : tf.cast(near, tf.float32),
        'far' : tf.cast(far, tf.float32),
    }
    render_kwargs.update(bds_dict)

    print('Render kwargs:')
    pprint.pprint(render_kwargs)

    results = {}
    results['pc'] = {}
    results['no_pc'] = {}

    # NOTE: Where to output results!
    result_directory = "./fern_pc_results"
    img_dir = os.path.join(result_directory, "imgs")
    os.makedirs(img_dir, exist_ok=True)  # create the output directories before saving images

    down = 1
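    # Render at 1/down of the training resolution (down = 1 keeps full resolution).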

    plt.imsave(os.path.join(img_dir, f"GT{i_test[0]}.png"), images[i_test[0]])
    plt.imsave(os.path.join(img_dir, f"GT{i_test[1]}.png"), images[i_test[1]])

    for num_samps in [4,8,16,32,64]:
        print(f'Running {num_samps} sample test')
        for pc in [True, False]:
            print(f'{"not " if not pc else ""}using pc')
            results['pc' if pc else 'no_pc'][num_samps] = {}
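            # Render with num_samps coarse samples and twice as many fine
            # (importance) samples per ray for this benchmark setting.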
            render_kwargs['N_samples'] = num_samps
            render_kwargs['N_importance'] = 2*num_samps

            total_time = 0
            total_mse = 0
            total_psnr = 0
            for i in [i_test[0], i_test[1]]:
                gt = images[i]
                start_time = time.time()
                ret_vals = run_nerf.render(H//down, W//down, focal/down, c2w=poses[i], pc=pc, cloudsize=16, **render_kwargs)
                end_time = time.time()
                
                # add to cum time
                total_time += (end_time - start_time)
                
                # add to accuracy
                img = np.clip(ret_vals[0],0,1)
                # TODO: make sure this is commented out for real results (just used to test that it runs)
                # mse = run_nerf.img2mse(np.zeros((H//down, W//down,3), dtype=np.float32), img)
                mse = run_nerf.img2mse(gt, img)
                psnr = run_nerf.mse2psnr(mse)
                total_mse += float(mse)
                total_psnr += float(psnr)

                plt.imsave(os.path.join(img_dir, f'IMG{i}_{"pc" if pc else "no_pc"}_{num_samps}samples.png'), img)

            total_time /= 2.
            total_mse /= 2.
            total_psnr /= 2.
            results['pc' if pc else 'no_pc'][num_samps]['time'] = total_time
            results['pc' if pc else 'no_pc'][num_samps]['mse'] = total_mse
            results['pc' if pc else 'no_pc'][num_samps]['psnr'] = total_psnr

    with open(os.path.join(result_directory, 'results.txt'), 'w') as outfile:
        json.dump(results,outfile)
Example #16
def train():
    
    parser = config_parser()
    args = parser.parse_args()
    
    
    # Load data
    
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 
                                                                  recenter=True, bd_factor=.75, 
                                                                  spherify=args.spherify)
        hwf = poses[0,:3,-1]
        poses = poses[:,:3,:4]
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]
        
        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]
            
        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0])) if 
                        (i not in i_test and i not in i_val)])
        
        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)
            

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split
        
        near = 2.
        far = 6.
        
        if args.white_bkgd:
            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]
        
        
    elif args.dataset_type == 'deepvoxels':
        
        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, 
                                                                 basedir=args.datadir, 
                                                                 testskip=args.testskip)
        
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split
        
        hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
        near = hemi_R-1.
        far = hemi_R+1.
        

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return
    
    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]
    
    if args.render_test:
        render_poses = np.array(poses[i_test])

    
    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())
    
    
    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(args)
    
    bds_dict = {
        'near' : tf.cast(near, tf.float32),
        'far' : tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)
    
    
    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None
            
        testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))  # global_step is only created further down, so use the start step returned by create_nerf
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)
        
        rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
        print('Done rendering', testsavedir) 
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
        
        return
        
        
    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(lrate,
            decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer
    
    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)
    
    
    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]    
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
        rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3]    
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0
        
    
    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)
    
    # Summary writers
    writer = tf.contrib.summary.create_file_writer(os.path.join(basedir, 'summaries', expname))
    writer.set_as_default()
        
        
    for i in range(start, N_iters):
        time0 = time.time()
        
        # Sample random ray batch
        
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1,0,2])
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                np.random.shuffle(rays_rgb)
                i_batch = 0
            
        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3,:4]
            
            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, pose)
                coords = tf.stack(tf.meshgrid(tf.range(H), tf.range(W), indexing='ij'), -1)
                coords = tf.reshape(coords, [-1,2])
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:,tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)
        
        
        #####  Core optimization loop  #####
        
        with tf.GradientTape() as tape:
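            # Everything executed inside the tape is recorded so that
            # tape.gradient() below can backpropagate the photometric loss.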
            
            rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays, 
                                                   verbose=i < 10, retraw=True, 
                                                   **render_kwargs_train)
            
            img_loss = img2mse(rgb, target_s)
            trans = extras['raw'][...,-1]
            loss = img_loss 
            psnr = mse2psnr(img_loss)
            
            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                loss += img_loss0
                psnr0 = mse2psnr(img_loss0)

        gradients = tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))
        
        dt = time.time()-time0
        
        #####           end            #####
        
        
        # Rest is logging

        def save_weights(net, prefix, i): 
            path = os.path.join(basedir, expname, '{}_{:06d}.npy'.format(prefix, i))
            np.save(path, net.get_weights())
            print('saved weights at', path)
    
        if i%args.i_weights==0:
            for k in models:
                save_weights(models[k], k, i)
                
            
        if i%args.i_video==0 and i > 0:
            
            rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
            
            if args.use_viewdirs:
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
                rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)
                
                
        if i%args.i_testset==0 and i > 0:
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test], hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set') 
            
            

        if i%args.i_print==0 or i < 10:
            
            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)
                    

            if i%args.i_img==0:
                
                # Log a rendered validation view to Tensorboard
                img_i=np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3,:4]
                
                rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose, 
                                                       **render_kwargs_test)
                
                psnr = mse2psnr(img2mse(rgb, target))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                    
                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
                    tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])
                
                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])
                
                
                if args.N_importance > 0:

                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
                        tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
                        
                    

        global_step.assign_add(1)
Example #17
def train():

    parser = config_parser()
    args = parser.parse_args()
    
    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data
    sc = 1.
    center = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test, center = load_llff_data(args.datadir, args.factor,
                                                                  recenter=True, bd_factor=.75,
                                                                  spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape,
              render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            # i_test = np.arange(images.shape[0])[::args.llffhold]
            i_test = np.arange(images.shape[0])[62::3]
            #i_test = np.concatenate([np.arange(images.shape[0])[-4::],
             #                        np.arange(images.shape[0])[0:int(images.shape[0] /# args.llffhold):]
              #                       ])

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0]))])
        # i_train = np.array([i for i in np.arange(int(images.shape[0])) if
        #                     (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        sc = 1. / (bds.min() * .75)
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape,
              render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3]*images[..., -1:] + (1.-images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
                                                                 basedir=args.datadir,
                                                                 testskip=args.testskip)

        print('Loaded deepvoxels', images.shape,
              render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R-1.
        far = hemi_R+1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    if args.render_test:
        render_poses = np.array(poses[i_test])
    if args.render_poses:
        #     load the render path from file
        render_poses = np.loadtxt(os.path.join(args.datadir, "render_poses.txt"))
        if render_poses.shape[0] >= 15:
            # render_poses.reshape((render_poses.shape[0] / 15, 3,5))
            render_poses = np.reshape(render_poses, (render_poses.shape[0] // 15, 3, 5)).astype(np.float32)
        else:
            # render_poses.reshape((render_poses.shape[0], 3, 5))
            render_poses = np.reshape(render_poses, (render_poses.shape[0], 3, 5)).astype(np.float32)
        # Correct rotation matrix ordering and move variable dim to axis 0
        render_poses = np.concatenate([render_poses[:, 1:2, :], -render_poses[:, 0:1, :], render_poses[:, 2:, :]], 1)
        render_poses = np.moveaxis(render_poses, -1, 0).astype(np.float32)
        # Rescale if bd_factor is provided
        # sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor)
        print("sc", sc)
        render_poses[:, :3, 3] *= sc #1.333 # output from load_llff.py
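        # Scale the translations by the same factor that load_llff_data applies
        # to the training poses (sc, computed from the scene bounds above), so
        # the custom render path presumably lives in the same normalized space.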
        # render_poses = recenter_poses(render_poses) # should recenter according to

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(
        args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        folder = 'path'
        if args.render_test:
            folder = 'test'
        elif args.render_poses:
            folder = 'poses'
        # testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format(
        #     'test' if args.render_test else 'path', start))
        testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format(
            folder, start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test,
                              gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor, ndc=args.NDC)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs), fps=30, quality=8)

        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(lrate,
                                                               decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching.
        #
        # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where the three
        # entries along axis=1 are interpreted as:
        #   entry 0: ray origin in world space
        #   entry 1: ray direction in world space
        #   entry 2: observed RGB color of the pixel
        print('get rays')
        # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3]
        # for each pixel in the image. This stack() adds a new dimension.
        rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None, ...]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i]
                             for i in i_train], axis=0)  # train images only
        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.contrib.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))
    writer.set_as_default()

    for i in range(start, N_iters):
        time0 = time.time()

        # Sample random ray batch

        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1, 0, 2])

            # batch_rays[i, n, xyz] = ray origin or direction, example_id, 3D position
            # target_s[n, rgb] = example_id, observed color.
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                np.random.shuffle(rays_rgb)
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, pose)
                if i < args.precrop_iters:
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H//2 - dH, H//2 + dH), 
                        tf.range(W//2 - dW, W//2 + dW), 
                        indexing='ij'), -1)
                    if i < 10:
                        print('precrop', dH, dW, coords[0,0], coords[-1,-1])
                else:
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H), tf.range(W), indexing='ij'), -1)
                coords = tf.reshape(coords, [-1, 2])
                select_inds = np.random.choice(
                    coords.shape[0], size=[N_rand], replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)

        #####  Core optimization loop  #####

        with tf.GradientTape() as tape:

            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(
                H, W, focal, chunk=args.chunk, rays=batch_rays, ndc=args.NDC,
                verbose=i < 10, retraw=True, **render_kwargs_train)

            # Compute MSE loss between predicted and true RGB.
            img_loss = img2mse(rgb, target_s)
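            # Last channel of the raw per-sample network output (the volume
            # density in NeRF), logged below as the 'tran' histogram.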
            trans = extras['raw'][..., -1]
            loss = img_loss
            psnr = mse2psnr(img_loss)

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                loss += img_loss0
                psnr0 = mse2psnr(img_loss0)

        gradients = tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))

        dt = time.time()-time0

        #####           end            #####

        # Rest is logging

        def save_weights(net, prefix, i):
            path = os.path.join(
                basedir, expname, '{}_{:06d}.npy'.format(prefix, i))
            np.save(path, net.get_weights())
            print('saved weights at', path)

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

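        # Periodically render the spiral camera path and write RGB and
        # normalized-disparity videos.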
        if i % args.i_video == 0 and i > 0:

            rgbs, disps = render_path(
                render_poses, hwf, args.chunk, render_kwargs_test, ndc=args.NDC)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(
                basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)), fps=30, quality=8)

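            # Also render the path with the camera held at a fixed pose while the
            # view directions vary, which visualizes view-dependent effects.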
            if args.use_viewdirs:
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(
                    render_poses, hwf, args.chunk, render_kwargs_test, ndc=args.NDC)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still), fps=30, quality=8)

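        # Periodically render the held-out test poses and save the results to
        # disk next to the ground-truth images.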
        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(
                basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test], hwf, args.chunk, render_kwargs_test,
                        gt_imgs=images[i_test], savedir=testsavedir, ndc=args.NDC)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:

            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
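            # Scalar and histogram summaries, written every i_print steps.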
            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)

            if i % args.i_img == 0:

                # Log a rendered validation view to Tensorboard
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]

                rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk,
                                                c2w=pose, ndc=args.NDC,
                                                **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))
                
                # Save out the validation image for Tensorboard-free monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(os.path.join(testimgdir, '{:06d}.png'.format(i)), to8b(rgb))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):

                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image(
                        'disp', disp[tf.newaxis, ..., tf.newaxis])
                    tf.contrib.summary.image(
                        'acc', acc[tf.newaxis, ..., tf.newaxis])

                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])

                if args.N_importance > 0:

                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image(
                            'rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image(
                            'disp0', extras['disp0'][tf.newaxis, ..., tf.newaxis])
                        tf.contrib.summary.image(
                            'z_std', extras['z_std'][tf.newaxis, ..., tf.newaxis])

        global_step.assign_add(1)