Example #1
def test_time_optimize(args, model, meta_state_dict, tto_view):
    """
    quickly optimize the meta-trained model to a target appearance
    and return the corresponding network weights
    """

    # start from the meta-learned initialization and fine-tune it with plain SGD
    model.load_state_dict(meta_state_dict)
    optim = torch.optim.SGD(model.parameters(), args.tto_lr)

    pixels = tto_view['img'].reshape(-1, 3)
    rays_o, rays_d = get_rays_tourism(tto_view['H'], tto_view['W'],
                                      tto_view['kinv'], tto_view['pose'])
    rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    num_rays = rays_d.shape[0]
    for step in range(args.tto_steps):
        # sample a random batch of rays and their target pixel colors
        indices = torch.randint(num_rays, size=[args.tto_batchsize])
        raybatch_o, raybatch_d = rays_o[indices], rays_d[indices]
        pixelbatch = pixels[indices]
        t_vals, xyz = sample_points(raybatch_o,
                                    raybatch_d,
                                    tto_view['bound'][0],
                                    tto_view['bound'][1],
                                    args.num_samples,
                                    perturb=True)

        optim.zero_grad()
        rgbs, sigmas = model(xyz)
        colors = volume_render(rgbs, sigmas, t_vals)
        loss = F.mse_loss(colors, pixelbatch)
        loss.backward()
        optim.step()

    state_dict = copy.deepcopy(model.state_dict())
    return state_dict
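The function above reads several keys from the `tto_view` dictionary. A minimal sketch of how such a view might be assembled and passed in, assuming an existing `args`, `model`, and `meta_state_dict`; the tensor shapes and values are illustrative only:

import torch

# hypothetical view dictionary; the keys mirror what test_time_optimize reads
tto_view = {
    'img':   torch.rand(400, 600, 3),    # H x W x 3 target image
    'H':     400,                         # image height in pixels
    'W':     600,                         # image width in pixels
    'kinv':  torch.eye(3),                # inverse camera intrinsics
    'pose':  torch.eye(4)[:3],            # 3 x 4 camera-to-world pose
    'bound': torch.tensor([2.0, 6.0]),    # near / far bounds along each ray
}

# adapted_weights = test_time_optimize(args, model, meta_state_dict, tto_view)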
Example #2
def val_meta(args, model, val_loader, device):
    """
    validate the meta-trained model for phototourism
    """
    # adapt a copy per view so the meta-trained weights stay untouched
    meta_trained_state = model.state_dict()
    val_model = copy.deepcopy(model)
    
    val_psnrs = []
    for img, pose, kinv, bound in val_loader:
        img, pose, kinv, bound = img.to(device), pose.to(device), kinv.to(device), bound.to(device)
        img, pose, kinv, bound = img.squeeze(), pose.squeeze(), kinv.squeeze(), bound.squeeze()
        rays_o, rays_d = get_rays_tourism(img.shape[0], img.shape[1], kinv, pose)
        
        # optimize on the left half, test on the right half
        left_width = img.shape[1]//2
        right_width = img.shape[1] - left_width
        tto_img, test_img = torch.split(img, [left_width, right_width], dim=1)
        tto_rays_o, test_rays_o = torch.split(rays_o, [left_width, right_width], dim=1)
        tto_rays_d, test_rays_d = torch.split(rays_d, [left_width, right_width], dim=1)

        val_model.load_state_dict(meta_trained_state)
        val_optim = torch.optim.SGD(val_model.parameters(), args.inner_lr)

        inner_loop(val_model, val_optim, tto_img, tto_rays_o, tto_rays_d,
                    bound, args.num_samples, args.train_batchsize, args.inner_steps)
        
        psnr = report_result(val_model, test_img, test_rays_o, test_rays_d, bound, 
                                args.num_samples, args.test_batchsize)
        val_psnrs.append(psnr)

    val_psnr = torch.stack(val_psnrs).mean()
    return val_psnr
Example #3
def test():
    parser = argparse.ArgumentParser(description='phototourism with meta-learning')
    parser.add_argument('--config', type=str, required=True,
                        help='config file for the scene')
    parser.add_argument('--weight-path', type=str, required=True,
                        help='path to the meta-trained weight file')
    args = parser.parse_args()

    # merge the scene-specific settings from the JSON config into args
    with open(args.config) as config:
        info = json.load(config)
        for key, value in info.items():
            args.__dict__[key] = value

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    test_set = build_tourism(image_set="test", args=args)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    model = build_nerf(args)
    model.to(device)

    checkpoint = torch.load(args.weight_path, map_location=device)
    meta_state_dict = checkpoint['meta_model_state_dict']
    
    test_psnrs = []
    for idx, (img, pose, kinv, bound) in enumerate(test_loader):
        img, pose, kinv, bound = img.to(device), pose.to(device), kinv.to(device), bound.to(device)
        img, pose, kinv, bound = img.squeeze(), pose.squeeze(), kinv.squeeze(), bound.squeeze()
        rays_o, rays_d = get_rays_tourism(img.shape[0], img.shape[1], kinv, pose)
        
        # optimize on the left half, test on the right half
        left_width = img.shape[1]//2
        right_width = img.shape[1] - left_width
        tto_img, test_img = torch.split(img, [left_width, right_width], dim=1)
        tto_rays_o, test_rays_o = torch.split(rays_o, [left_width, right_width], dim=1)
        tto_rays_d, test_rays_d = torch.split(rays_d, [left_width, right_width], dim=1)

        model.load_state_dict(meta_state_dict)
        optim = torch.optim.SGD(model.parameters(), args.tto_lr)

        inner_loop(model, optim, tto_img, tto_rays_o, tto_rays_d,
                    bound, args.num_samples, args.tto_batchsize, args.tto_steps)
        
        psnr = report_result(model, test_img, test_rays_o, test_rays_d, bound, 
                            args.num_samples, args.test_batchsize)
        
        print(f"test view {idx+1}, psnr:{psnr:.3f}")
        test_psnrs.append(psnr)

    test_psnrs = torch.stack(test_psnrs)
    print("----------------------------------")
    print(f"test dataset mean psnr: {test_psnrs.mean():.3f}")

    print("\ncreating interpolation video ...\n")
    create_interpolation_video(args, model, meta_state_dict, test_set, device)
    print("\ninterpolation video created!")
Example #4
def train_meta(args, meta_model, meta_optim, data_loader, device):
    """
    train the meta_model for one epoch using Reptile meta-learning
    https://arxiv.org/abs/1803.02999
    """
    for img, pose, kinv, bound in data_loader:
        img, pose, kinv, bound = img.to(device), pose.to(device), kinv.to(device), bound.to(device)
        img, pose, kinv, bound = img.squeeze(), pose.squeeze(), kinv.squeeze(), bound.squeeze()
        rays_o, rays_d = get_rays_tourism(img.shape[0], img.shape[1], kinv, pose)

        meta_optim.zero_grad()

        inner_model = copy.deepcopy(meta_model)
        inner_optim = torch.optim.SGD(inner_model.parameters(), args.inner_lr)

        inner_loop(inner_model, inner_optim, img, rays_o, rays_d, bound,
                    args.num_samples, args.train_batchsize, args.inner_steps)
        
        # Reptile update: use (meta - inner) as the meta-gradient so that
        # meta_optim.step() moves the meta weights toward the adapted weights
        with torch.no_grad():
            for meta_param, inner_param in zip(meta_model.parameters(), inner_model.parameters()):
                meta_param.grad = meta_param - inner_param
        
        meta_optim.step()
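Setting `meta_param.grad` to `(meta_param - inner_param)` makes the following `meta_optim.step()` move the meta weights toward the inner-adapted weights; if `meta_optim` is plain SGD with learning rate ε, this is exactly the Reptile update θ ← θ + ε(φ − θ). A self-contained sketch of that equivalence on a toy parameter (the names below are illustrative, not from the repository):

import torch

meta_lr = 0.1
theta = torch.nn.Parameter(torch.tensor([1.0, 2.0]))    # meta weights
phi = torch.tensor([0.5, 2.5])                          # weights after inner-loop adaptation

optim = torch.optim.SGD([theta], lr=meta_lr)
optim.zero_grad()
theta.grad = theta.detach() - phi    # Reptile pseudo-gradient (theta - phi)
optim.step()

# SGD computes theta <- theta - lr * grad = theta + lr * (phi - theta),
# i.e. the meta weights move a fraction meta_lr of the way toward phi
print(theta.data)    # tensor([0.9500, 2.0500])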
Example #5
def synthesize_view(args, model, H, W, kinv, pose, bound):
    """
    given camera intrinsics and camera pose, synthesize a novel view
    """
    rays_o, rays_d = get_rays_tourism(H, W, kinv, pose)
    rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    t_vals, xyz = sample_points(rays_o,
                                rays_d,
                                bound[0],
                                bound[1],
                                args.num_samples,
                                perturb=False)

    synth = []
    num_rays = rays_d.shape[0]
    # render the rays in test-batch-sized chunks to keep memory usage bounded
    with torch.no_grad():
        for i in range(0, num_rays, args.test_batchsize):
            rgbs_batch, sigmas_batch = model(xyz[i:i + args.test_batchsize])
            color_batch = volume_render(rgbs_batch, sigmas_batch,
                                        t_vals[i:i + args.test_batchsize])
            synth.append(color_batch)
        synth = torch.cat(synth, dim=0).reshape(H, W, 3)
    return synth
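A sketch of how a synthesized view might be written to disk, assuming the rendered colors lie in [0, 1] and that Pillow is available; the file name and the stand-in tensor are illustrative only:

import numpy as np
import torch
from PIL import Image

# synth = synthesize_view(args, model, H, W, kinv, pose, bound)   # H x W x 3 colors in [0, 1]
synth = torch.rand(400, 600, 3)    # stand-in tensor so the snippet runs on its own

img = (synth.clamp(0, 1).cpu().numpy() * 255).astype(np.uint8)
Image.fromarray(img).save('novel_view.png')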