Ejemplo n.º 1
0
def fit_uv_mesh(initial_mesh: dict,
                target_dataset,
                max_iterations: int = 5000,
                resolution: int = 4,
                log_interval: int = 10,
                dispaly_interval=1000,
                display_res=512,
                out_dir=None,
                mp4save_interval=None):
    """Render ``initial_mesh`` once with a flat mid-grey UV texture; save 'test.png'.

    Despite its name this is currently a rendering smoke test, not an optimizer:
    only ``initial_mesh`` is read. ``target_dataset``, ``max_iterations``,
    ``resolution``, ``log_interval``, ``dispaly_interval`` (sic), ``display_res``,
    ``out_dir`` and ``mp4save_interval`` are accepted for signature parity with
    the other ``fit_*`` functions and are ignored.

    Args:
        initial_mesh: mapping with ``'pos_idx'`` (triangle vertex indices) and
            ``'vtx_pos'`` (vertex positions). NumPy arrays and torch tensors
            (CPU or CUDA) are both accepted.
    """

    def _to_numpy(x):
        # Accept NumPy arrays as well as (possibly CUDA) torch tensors.
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    glctx = dr.RasterizeGLContext()

    # Random model pose, pushed back along -z, for the rendered view.
    r_rot = util.random_rotation_translation(0.25)
    dist = 2
    proj = util.projection(x=0.4, n=1.0, f=200.0)
    r_mv = np.matmul(util.translate(0, 0, -1.5 - dist), r_rot)
    r_mvp = np.matmul(proj, r_mv).astype(np.float32)

    # Convert the mesh's own arrays. Calling .cuda() first and then .astype()
    # fails, because tensors have no .astype().
    pos_idx_np = _to_numpy(initial_mesh['pos_idx'])
    vtx_pos_np = _to_numpy(initial_mesh['vtx_pos'])

    uv, uv_idx = init_uv()
    # Truncate the UV triangle list to the mesh's triangle count.
    uv_idx = uv_idx[:pos_idx_np.shape[0]]

    pos_idx = torch.from_numpy(pos_idx_np.astype(np.int32)).cuda()
    vtx_pos = torch.from_numpy(vtx_pos_np.astype(np.float32)).cuda()
    uv_idx = torch.from_numpy(uv_idx.astype(np.int32)).cuda()
    vtx_uv = torch.from_numpy(uv.astype(np.float32)).cuda()
    # Uniform 0.5 grey, 1024x1024 RGB texture.
    tex = torch.full((1024, 1024, 3), 0.5, dtype=torch.float32, device='cuda')

    # Render at 1024 with mipmapping disabled (enable_mip=False, max_mip_level=0).
    color = render(glctx, r_mvp, vtx_pos, pos_idx, vtx_uv, uv_idx, tex, 1024,
                   False, 0)
    Image.fromarray((color[0].detach().cpu().numpy() * 255).astype(
        np.uint8)).save('test.png')
Ejemplo n.º 2
0
def fit_mesh(initial_mesh: dict,
             target_dataset_dir: str,
             max_iterations: int = 10000,
             resolution: int = 256,
             log_interval: int = 1000,
             display_interval=None,
             display_res=512,
             out_dir=None,
             mp4save_interval=None):
    """Fit per-frame vertex deformations and a UV texture to reference images.

    Frame ``j`` of the dataset is modelled as ``vtx_pos + M3 @ M2 @ M1 @ e_j``,
    where ``e_j`` is the one-hot frame selector, ``M1``/``M2`` start as identity
    matrices and ``M3`` as zeros. The deformation matrices and the texture are
    optimized with Adam on an image-space L2 loss. Afterwards one OBJ per frame
    (``frame_{i}.obj``) and the texture (``diff_render_tex.png``) are written to
    the current working directory.

    Args:
        initial_mesh: mapping with NumPy arrays ``'pos_idx'`` and ``'vtx_pos'``.
        target_dataset_dir: directory passed to ``util.ReferenceImages``.
        max_iterations: number of passes over the whole dataset.
        resolution: render and reference image resolution.
        log_interval: print the loss every this many passes (falsy disables).
        display_interval: plot estimate/reference every this many passes; the
            last pass is always plotted.
        display_res, out_dir, mp4save_interval: currently unused.
    """
    distance = 3

    target_dataset = util.ReferenceImages(target_dataset_dir, resolution,
                                          resolution)
    n_frames = len(target_dataset)

    pos_idx = torch.from_numpy(initial_mesh['pos_idx'].astype(np.int32))
    vtx_pos = torch.from_numpy(initial_mesh['vtx_pos'].astype(np.float32))

    laplace = util.compute_laplace_matrix(vtx_pos, pos_idx).cuda()
    pos_idx = pos_idx.cuda()
    vtx_pos = vtx_pos.cuda()

    init_rot = util.rotate_z(-math.pi / 2).cuda()
    vtx_pos = transform_pos(init_rot, vtx_pos)[0][:, 0:3]
    vtx_pos.requires_grad = True

    uv, uv_idx = init_uv()
    uv_idx = uv_idx[:pos_idx.shape[0]]
    uv_idx = torch.from_numpy(uv_idx.astype(np.int32)).cuda()
    vtx_uv = torch.from_numpy(uv.astype(np.float32)).cuda()
    vtx_uv.requires_grad = True

    # Mid-grey 1024x1024 RGB texture to be optimized.
    tex = torch.ones((1024, 1024, 3)).float() / 2
    tex = tex.cuda()
    tex.requires_grad = True

    glctx = dr.RasterizeGLContext()

    # Low-rank-style per-frame deformation: deltas = M3 @ M2 @ M1 @ e_frame.
    M1 = torch.eye(n_frames).cuda()
    M1.requires_grad = True
    M2 = torch.eye(n_frames).cuda()
    M2.requires_grad = True
    M3 = torch.zeros((3 * vtx_pos.shape[0], n_frames)).cuda()
    M3.requires_grad = True

    optimizer = torch.optim.Adam([
        {'params': [M1, M2, M3], 'lr': 1e-3},
        {'params': tex, 'lr': 1e-2},
    ])

    # Row ``j`` is the one-hot selector for frame ``j``.
    frame_basis = torch.eye(n_frames, device='cuda')

    def deform(frame):
        # Vertex positions of ``frame`` as a (V, 3) tensor.
        deltas = torch.matmul(
            M3, torch.matmul(M2, torch.matmul(M1,
                                              frame_basis[frame]))).flatten()
        return (vtx_pos.flatten() + deltas).reshape((vtx_pos.shape[0], 3))

    # The camera is identical for every frame (the per-angle rotation is
    # disabled), so build the model-view-projection matrix once.
    rot = torch.eye(4)
    tr = util.translate(z=-distance)
    proj = util.projection(x=0.4)
    mtx = proj.matmul(tr.matmul(rot)).cuda()

    for i in range(max_iterations):
        for j, (img, angle) in enumerate(target_dataset):
            img = img.cuda().permute(2, 1, 0)

            deformed_vtxs = deform(j)

            estimate = render(glctx,
                              mtx,
                              deformed_vtxs,
                              pos_idx,
                              vtx_uv,
                              uv_idx,
                              tex,
                              resolution,
                              enable_mip=False,
                              max_mip_level=4)[0]

            # Image loss plus a curvature-preservation regularizer. The
            # regularizer is currently weighted 0, so it has no effect on the
            # optimized parameters.
            loss = torch.mean((estimate - img)**2)
            reg = torch.mean((util.compute_curvature(deformed_vtxs, laplace) -
                              util.compute_curvature(vtx_pos, laplace))**2)
            loss = 5 * loss + 0 * reg

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # Keep texel values in the displayable [0, 1] range.
                tex.clamp_(0, 1)

            if (display_interval and
                    (i % display_interval == 0)) or (i == max_iterations - 1):
                with torch.no_grad():
                    estimate = render(
                        glctx,
                        mtx,
                        deformed_vtxs,
                        pos_idx,
                        vtx_uv,
                        uv_idx,
                        tex,
                        resolution,
                        enable_mip=True,
                        max_mip_level=4)[0].detach().cpu().numpy()
                    plt.imshow(estimate)
                    plt.show()
                    plt.imshow(img.detach().cpu().numpy())
                    plt.show()

            if log_interval and i % log_interval == 0:
                print(f"Loss: {loss}")
                print(M1.grad)

    with torch.no_grad():
        # Index each export by its own frame number, so every file gets that
        # frame's deformation. This no longer touches the image data.
        for i in range(n_frames):
            deformed_vtxs = deform(i)
            write_obj(f"frame_{i}.obj",
                      deformed_vtxs.detach().cpu().tolist(),
                      pos_idx.detach().cpu().tolist())
    Image.fromarray((tex.detach().cpu().numpy() * 255).astype(
        np.uint8)).save('diff_render_tex.png')
    print("Outputted texture to diff_render_tex.png")
Ejemplo n.º 3
0
def fit_earth(max_iter=20000,
              log_interval=10,
              display_interval=None,
              display_res=1024,
              enable_mip=True,
              res=512,
              ref_res=4096,
              lr_base=1e-2,
              lr_ramp=0.1,
              out_dir=None,
              log_fn=None,
              texsave_interval=None,
              texsave_fn=None,
              imgsave_interval=None,
              imgsave_fn=None):
    """Optimize a texture so renders of the Earth mesh match reference renders.

    The reference texture from ``data/earth.npz`` is rendered at ``ref_res``
    with mipmapping always on, then bilinearly downsampled to ``res``. A texture
    initialized to 0.2 is rendered at ``res`` from the same random view and
    fitted with Adam on an L2 pixel loss. The learning rate decays from
    ``lr_base`` to ``lr_base * lr_ramp`` over ``max_iter`` steps.

    Args:
        max_iter: number of optimization steps (``max_iter + 1`` iterations run).
        log_interval: print/log the averaged masked texture RMSE and PSNR every
            N steps (falsy disables).
        display_interval: show a render every N steps (falsy disables).
        display_res: size passed to ``util.display_image``.
        enable_mip: enable mipmapping for the optimized renders.
        res: resolution of the optimized renders.
        ref_res: resolution at which the reference is rendered before
            downsampling.
        lr_base: initial learning rate.
        lr_ramp: final learning-rate multiplier.
        out_dir: output directory. When falsy, image and texture saving is
            disabled.
        log_fn: log file name inside ``out_dir``.
        texsave_interval, texsave_fn: save the optimized texture every N steps to
            ``out_dir/(texsave_fn % it)``.
        imgsave_interval, imgsave_fn: save a display render every N steps to
            ``out_dir/(imgsave_fn % it)``.
    """
    log_file = None
    try:
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            if log_fn:
                log_file = open(out_dir + '/' + log_fn, 'wt')
        else:
            imgsave_interval, texsave_interval = None, None

        # Mesh and texture adapted from "3D Earth Photorealistic 2K" model at
        # https://www.turbosquid.com/3d-models/3d-realistic-earth-photorealistic-2k-1279125
        datadir = f'{pathlib.Path(__file__).absolute().parents[1]}/data'
        with np.load(f'{datadir}/earth.npz') as f:
            pos_idx, pos, uv_idx, uv, tex = f.values()
        tex = tex.astype(np.float32) / 255.0
        max_mip_level = 9  # Texture is a 4x3 atlas of 512x512 maps.
        print("Mesh has %d triangles and %d vertices." %
              (pos_idx.shape[0], pos.shape[0]))

        # Some input geometry contains vertex positions in (N, 4) (with
        # v[:,3]==1). Drop the last column in that case.
        if pos.shape[1] == 4:
            pos = pos[:, 0:3]

        # Create position/triangle index tensors.
        pos_idx = torch.from_numpy(pos_idx.astype(np.int32)).cuda()
        vtx_pos = torch.from_numpy(pos.astype(np.float32)).cuda()
        uv_idx = torch.from_numpy(uv_idx.astype(np.int32)).cuda()
        vtx_uv = torch.from_numpy(uv.astype(np.float32)).cuda()

        tex = torch.from_numpy(tex.astype(np.float32)).cuda()
        tex_opt = torch.full(tex.shape, 0.2, device='cuda', requires_grad=True)
        glctx = dr.RasterizeGLContext()

        # Adam optimizer for texture with a learning rate ramp.
        optimizer = torch.optim.Adam([tex_opt], lr=lr_base)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda x: lr_ramp**(float(x) / float(max_iter)))

        # Loop invariants: the projection and the mask of atlas regions scored by
        # the texture RMSE/PSNR metric.
        proj = util.projection(x=0.4, n=1.0, f=200.0)
        texmask = torch.zeros_like(tex)
        tr = tex.shape[1] // 4
        texmask[tr + 13:2 * tr - 13, 25:-25, :] += 1.0
        texmask[25:-25, tr + 13:2 * tr - 13, :] += 1.0
        texmask_sum = torch.sum(texmask)

        ang = 0.0
        texloss_avg = []
        for it in range(max_iter + 1):
            # Random rotation/translation and distance for the training view.
            r_rot = util.random_rotation_translation(0.25)
            dist = np.random.uniform(0.0, 48.5)
            r_mv = np.matmul(util.translate(0, 0, -1.5 - dist), r_rot)
            r_mvp = np.matmul(proj, r_mv).astype(np.float32)

            # Texture-space RMSE within the masked area (logging only).
            with torch.no_grad():
                texloss = (torch.sum(texmask * (tex - tex_opt)**2) /
                           texmask_sum)**0.5
                texloss_avg.append(float(texloss))

            # Render reference and optimized frames. Always enable mipmapping
            # for the reference.
            color = render(glctx, r_mvp, vtx_pos, pos_idx, vtx_uv, uv_idx, tex,
                           ref_res, True, max_mip_level)
            color_opt = render(glctx, r_mvp, vtx_pos, pos_idx, vtx_uv, uv_idx,
                               tex_opt, res, enable_mip, max_mip_level)

            # Reduce the reference to the optimized render's size.
            while color.shape[1] > res:
                color = util.bilinear_downsample(color)

            # Compute loss and perform a training step.
            loss = torch.mean((color - color_opt)**2)  # L2 pixel loss.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Print/save log.
            if log_interval and (it % log_interval == 0):
                texloss_val = np.mean(np.asarray(texloss_avg))
                texloss_avg = []
                # PSNR based on average RMSE.
                psnr = -10.0 * np.log10(texloss_val**2)
                s = "iter=%d,loss=%f,psnr=%f" % (it, texloss_val, psnr)
                print(s)
                if log_file:
                    log_file.write(s + '\n')

            display_image = display_interval and (it % display_interval == 0)
            save_image = imgsave_interval and (it % imgsave_interval == 0)
            save_texture = texsave_interval and (it % texsave_interval) == 0

            if display_image or save_image:
                # Smoothly rotating display view, built from the current angle
                # before it advances.
                a_rot = np.matmul(util.rotate_x(-0.4), util.rotate_y(ang))
                a_mv = np.matmul(util.translate(0, 0, -3.5), a_rot)
                a_mvp = np.matmul(proj, a_mv).astype(np.float32)
                ang = ang + 0.1

                with torch.no_grad():
                    result_image = render(glctx, a_mvp, vtx_pos, pos_idx,
                                          vtx_uv, uv_idx, tex_opt, res,
                                          enable_mip,
                                          max_mip_level)[0].cpu().numpy()

                if display_image:
                    util.display_image(result_image,
                                       size=display_res,
                                       title='%d / %d' % (it, max_iter))
                if save_image:
                    util.save_image(out_dir + '/' + (imgsave_fn % it),
                                    result_image)

            # Texture saving follows its own interval, independent of the
            # display/image-save steps.
            if save_texture:
                with torch.no_grad():
                    texture = tex_opt.detach().cpu().numpy()[::-1]
                util.save_image(out_dir + '/' + (texsave_fn % it), texture)
    finally:
        # Close the log even if training raises.
        if log_file:
            log_file.close()
Ejemplo n.º 4
0
def fit_pose(max_iter=10000,
             repeats=1,
             log_interval=10,
             display_interval=None,
             display_res=512,
             lr_base=0.01,
             lr_falloff=1.0,
             nr_base=1.0,
             nr_falloff=1e-4,
             grad_phase_start=0.5,
             resolution=256,
             out_dir=None,
             log_fn=None,
             mp4save_interval=None,
             mp4save_fn=None):
    """Recover a random target orientation of the ``cube_p`` mesh from renders.

    For each of ``repeats`` trials a random target quaternion and a random
    initial quaternion are drawn. While ``it / max_iter < grad_phase_start`` the
    pose is perturbed with random noise and snapped to the best pose found so
    far. Afterwards Adam runs on a tanh-compressed per-pixel max-channel
    squared error. Learning rate and noise scale decay geometrically by
    ``lr_falloff`` / ``nr_falloff`` over the run.

    Args:
        max_iter: iterations per repeat (``max_iter + 1`` are run).
        repeats: number of independent trials.
        log_interval: print/log every N iterations (falsy disables).
        display_interval: show reference/best/current renders every N iterations.
        display_res: size passed to ``util.display_image``.
        lr_base, lr_falloff: initial learning rate and its final multiplier.
        nr_base, nr_falloff: initial noise scale and its final multiplier.
        grad_phase_start: fraction of the run after which gradient steps begin.
        resolution: render resolution.
        out_dir: output directory; when falsy, video saving is disabled.
        log_fn: log file name inside ``out_dir``.
        mp4save_interval, mp4save_fn: append a frame to ``out_dir/mp4save_fn``
            every N iterations.
    """
    log_file = None
    writer = None
    try:
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            if log_fn:
                log_file = open(out_dir + '/' + log_fn, 'wt')
            # Open a video writer only when frames will actually be written.
            # A `!= 0` test would also open one (named "None") for the
            # default mp4save_interval=None.
            if mp4save_interval:
                writer = imageio.get_writer(f'{out_dir}/{mp4save_fn}',
                                            mode='I',
                                            fps=30,
                                            codec='libx264',
                                            bitrate='16M')
        else:
            mp4save_interval = None

        datadir = f'{pathlib.Path(__file__).absolute().parents[1]}/data'
        with np.load(f'{datadir}/cube_p.npz') as f:
            pos_idx, pos, col_idx, col = f.values()
        print("Mesh has %d triangles and %d vertices." %
              (pos_idx.shape[0], pos.shape[0]))

        # Some input geometry contains vertex positions in (N, 4) (with
        # v[:,3]==1). Drop the last column in that case.
        if pos.shape[1] == 4:
            pos = pos[:, 0:3]

        # Create position/triangle index tensors.
        pos_idx = torch.from_numpy(pos_idx.astype(np.int32)).cuda()
        vtx_pos = torch.from_numpy(pos.astype(np.float32)).cuda()
        col_idx = torch.from_numpy(col_idx.astype(np.int32)).cuda()
        vtx_col = torch.from_numpy(col.astype(np.float32)).cuda()

        glctx = dr.RasterizeGLContext()

        # Fixed modelview + projection: model translated 3.5 units along -z.
        mvp = torch.tensor(np.matmul(util.projection(x=0.4),
                                     util.translate(0, 0,
                                                    -3.5)).astype(np.float32),
                           device='cuda')

        for rep in range(repeats):
            pose_target = torch.tensor(q_rnd(), device='cuda')
            pose_init = q_rnd()
            pose_opt = torch.tensor(pose_init / np.sum(pose_init**2)**0.5,
                                    dtype=torch.float32,
                                    device='cuda',
                                    requires_grad=True)

            loss_best = np.inf
            pose_best = pose_opt.detach().clone()

            # The reference render depends only on the target pose.
            color = render(glctx, torch.matmul(mvp, q_to_mtx(pose_target)),
                           vtx_pos, pos_idx, vtx_col, col_idx, resolution)

            # Adam optimizer for the pose quaternion.
            optimizer = torch.optim.Adam([pose_opt],
                                         betas=(0.9, 0.999),
                                         lr=lr_base)

            for it in range(max_iter + 1):
                # Geometric decay of learning rate and noise scale.
                itf = 1.0 * it / max_iter
                nr = nr_base * nr_falloff**itf
                lr = lr_base * lr_falloff**itf
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # Noise input: random perturbation during the search phase,
                # identity during the gradient phase.
                if itf >= grad_phase_start:
                    noise = q_unit()
                else:
                    noise = q_scale(q_rnd(), nr)
                    noise = q_mul(noise, q_rnd_S4())  # Orientation noise.

                pose_total_opt = q_mul_torch(pose_opt, noise)
                mtx_total_opt = torch.matmul(mvp, q_to_mtx(pose_total_opt))
                color_opt = render(glctx, mtx_total_opt, vtx_pos, pos_idx,
                                   vtx_col, col_idx, resolution)

                # Image-space loss: per-pixel max-channel L2, tanh-compressed.
                diff = (color_opt - color)**2
                diff = torch.tanh(5.0 * torch.max(diff, dim=-1)[0])
                loss = torch.mean(diff)

                # Track the best pose; during the search phase snap to it.
                loss_val = float(loss)
                if (loss_val < loss_best) and (loss_val > 0.0):
                    pose_best = pose_total_opt.detach().clone()
                    loss_best = loss_val
                    if itf < grad_phase_start:
                        with torch.no_grad():
                            pose_opt[:] = pose_best

                # Print/save log.
                if log_interval and (it % log_interval == 0):
                    err = q_angle_deg(pose_opt, pose_target)
                    ebest = q_angle_deg(pose_best, pose_target)
                    s = "rep=%d,iter=%d,err=%f,err_best=%f,loss=%f,loss_best=%f,lr=%f,nr=%f" % (
                        rep, it, err, ebest, loss_val, loss_best, lr, nr)
                    print(s)
                    if log_file:
                        log_file.write(s + "\n")

                # Run gradient training step.
                if itf >= grad_phase_start:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Keep the quaternion unit-length.
                with torch.no_grad():
                    pose_opt /= torch.sum(pose_opt**2)**0.5

                # Show/save image.
                display_image = display_interval and (it % display_interval
                                                      == 0)
                save_mp4 = mp4save_interval and (it % mp4save_interval == 0)

                if display_image or save_mp4:
                    img_ref = color[0].detach().cpu().numpy()
                    img_opt = color_opt[0].detach().cpu().numpy()
                    img_best = render(glctx,
                                      torch.matmul(mvp, q_to_mtx(pose_best)),
                                      vtx_pos, pos_idx, vtx_col, col_idx,
                                      resolution)[0].detach().cpu().numpy()
                    result_image = np.concatenate([img_ref, img_best, img_opt],
                                                  axis=1)

                    if display_image:
                        util.display_image(result_image,
                                           size=display_res,
                                           title='(%d) %d / %d' %
                                           (rep, it, max_iter))
                    if save_mp4:
                        writer.append_data(
                            np.clip(np.rint(result_image * 255.0), 0,
                                    255).astype(np.uint8))
    finally:
        # Release the video writer and log even if training raises.
        if writer is not None:
            writer.close()
        if log_file:
            log_file.close()
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import imageio
import numpy as np
import torch
import nvdiffrast.torch as dr


def tensor(*args, **kwargs):
    """Build a tensor with ``torch.tensor`` and place it on the ``'cuda'`` device.

    All positional and keyword arguments are forwarded to ``torch.tensor``.
    Supplying ``device`` yourself raises ``TypeError`` (duplicate keyword).
    """
    target_device = 'cuda'
    return torch.tensor(*args, **kwargs, device=target_device)


# Minimal nvdiffrast example: rasterize a single triangle with per-vertex RGB
# colors at 256x256 and write the result to 'tri.png'.
# NOTE(review): these statements run at import time, require a CUDA device and
# an OpenGL context, and bind module-level names (pos, col, tri, glctx, rast,
# out, img).

# One batch of three vertices as homogeneous (x, y, z, w) with w = 1: (1, 3, 4).
pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]],
             dtype=torch.float32)
# Per-vertex colors (red, green, blue), one batch: shape (1, 3, 3).
col = tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=torch.float32)
# A single triangle referencing vertices 0, 1, 2: shape (1, 3), int32.
tri = tensor([[0, 1, 2]], dtype=torch.int32)

glctx = dr.RasterizeGLContext()
# Rasterize to a 256x256 buffer; the second return value is discarded.
rast, _ = dr.rasterize(glctx, pos, tri, resolution=[256, 256])
# Interpolate the per-vertex colors over the rasterized pixels.
out, _ = dr.interpolate(col, rast, tri)

img = out.cpu().numpy()[0, ::-1, :, :]  # Flip vertically.
img = np.clip(np.rint(img * 255), 0,
              255).astype(np.uint8)  # Quantize to np.uint8

print("Saving to 'tri.png'.")
# NOTE(review): imageio.imsave is the imageio v2-style API; newer imageio
# releases steer users to imageio.v3.imwrite — confirm the pinned version.
imageio.imsave('tri.png', img)
def fit_mesh_col(
        initial_mesh: dict,
        target_dataset_dir: str,
        max_iterations: int = 10000,
        resolution: int = 256,
        log_interval: int = None,
        display_interval=None,
        display_res=512,
        out_dir=None,
        mp4save_interval=None):
    """Fit per-frame vertex deformations and per-vertex colors to reference images.

    Frame ``j`` is modelled as ``vtx_pos + M3 @ M2 @ M1 @ e_j``, where ``e_j`` is
    the one-hot frame selector, ``M1``/``M2`` start as identities and ``M3`` as
    zeros. The deformation matrices and ``vtx_col`` are optimized with Adam on
    an image-space L2 loss plus ``5 *`` a regularizer. The regularizer is the
    change in curvature plus the mean squared delta.

    Afterwards each frame is written to ``frame_{i}.obj``, with positions
    clamped to [-1, 1] and including the vertex colors. The colors are also
    written to ``vtx_col.npz``. Debug images go to ``estimate.png`` / ``img.png``.

    Args:
        initial_mesh: mapping with NumPy arrays ``'pos_idx'`` and ``'vtx_pos'``.
        target_dataset_dir: directory passed to ``util.ReferenceImages``.
        max_iterations: number of passes over the whole dataset.
        resolution: render and reference image resolution.
        display_interval: dump debug images every this many passes; the last
            pass always dumps them.
        log_interval, display_res, out_dir, mp4save_interval: currently unused.
    """
    distance = 3

    target_dataset = util.ReferenceImages(target_dataset_dir, resolution,
                                          resolution)
    n_frames = len(target_dataset)

    pos_idx = torch.from_numpy(initial_mesh['pos_idx'].astype(np.int32))
    vtx_pos = torch.from_numpy(initial_mesh['vtx_pos'].astype(np.float32))

    laplace = util.compute_laplace_matrix(vtx_pos, pos_idx).cuda()
    pos_idx = pos_idx.cuda()
    vtx_pos = vtx_pos.cuda()

    init_rot = util.rotate_z(-math.pi / 2).cuda()
    vtx_pos = transform_pos(init_rot, vtx_pos)[0][:, 0:3]
    vtx_pos.requires_grad = True

    # Colors reuse the position topology and start at mid-grey.
    col_idx = torch.from_numpy(initial_mesh['pos_idx'].astype(np.int32)).cuda()
    vtx_col = torch.ones_like(vtx_pos) * 0.5
    vtx_col.requires_grad = True

    glctx = dr.RasterizeGLContext()

    # Per-frame deformation: deltas = M3 @ M2 @ M1 @ e_frame.
    M1 = torch.eye(n_frames).cuda()
    M1.requires_grad = True
    M2 = torch.eye(n_frames).cuda()
    M2.requires_grad = True
    M3 = torch.zeros((3 * vtx_pos.shape[0], n_frames)).cuda()
    M3.requires_grad = True

    optimizer = torch.optim.Adam([
        {'params': [M1, M2, M3], 'lr': 1e-3},
        {'params': vtx_col, 'lr': 1e-2},
    ])

    # Row ``j`` is the one-hot selector for frame ``j``.
    frame_basis = torch.eye(n_frames, device='cuda')

    def deform(frame):
        # Returns (flat deltas, (V, 3) deformed vertex positions) for ``frame``.
        deltas = torch.matmul(
            M3, torch.matmul(M2, torch.matmul(M1,
                                              frame_basis[frame]))).flatten()
        return deltas, (vtx_pos.flatten() + deltas).reshape(
            (vtx_pos.shape[0], 3))

    loss_hist = []

    for i in range(max_iterations):
        for j, (img, angle) in enumerate(target_dataset):
            img = img.cuda().permute(2, 1, 0)

            deltas, deformed_vtxs = deform(j)

            # Camera: rotate about y by the frame's angle, translate along -z,
            # then apply a perspective projection.
            rot = util.rotate_y(angle)
            tr = util.translate(z=-distance)
            proj = util.projection(x=0.4)
            mtx = proj.matmul(tr.matmul(rot)).cuda()

            # NOTE(review): arguments are passed as (..., col_idx, vtx_col, ...),
            # while other fit_* functions in this file pass
            # (..., vtx_col, col_idx, ...) — confirm against this script's render().
            estimate = render(glctx, mtx, deformed_vtxs, pos_idx, col_idx,
                              vtx_col, resolution)[0]

            loss = torch.mean((estimate - img)**2)
            reg = torch.mean(
                (util.compute_curvature(deformed_vtxs, laplace) -
                 util.compute_curvature(vtx_pos, laplace))**2) + torch.mean(
                     deltas**2)
            loss = loss + 5 * reg

            # .item() is required: calling .numpy() on a tensor that requires
            # grad raises a RuntimeError.
            loss_hist.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # Keep colors in the displayable [0, 1] range.
                vtx_col.clamp_(0, 1)

            if (display_interval and
                    (i % display_interval == 0)) or (i == max_iterations - 1):
                print(loss)
                with torch.no_grad():
                    estimate_np = render(glctx, mtx, deformed_vtxs, pos_idx,
                                         col_idx, vtx_col,
                                         resolution)[0].detach().cpu().numpy()
                    Image.fromarray((estimate_np * 255).astype(
                        np.uint8)).save('estimate.png')
                    img_np = img.detach().cpu().numpy()
                    Image.fromarray((img_np * 255).astype(
                        np.uint8)).save('img.png')

    with torch.no_grad():
        # Index each export by its own frame number, so every file gets that
        # frame's deformation.
        for i in range(n_frames):
            _, deformed_vtxs = deform(i)
            deformed_vtxs = torch.clamp(deformed_vtxs, -1.0, 1.0)
            util.write_obj(f"frame_{i}.obj",
                           deformed_vtxs.detach().cpu().tolist(),
                           pos_idx.detach().cpu().tolist(),
                           vtx_col.detach().cpu().tolist())

    np.savez('vtx_col.npz', vtx_col=vtx_col.cpu().detach().numpy())
def fit_cube(max_iter=5000,
             resolution=4,
             discontinuous=False,
             repeats=1,
             log_interval=10,
             display_interval=None,
             display_res=512,
             out_dir=None,
             log_fn=None,
             mp4save_interval=None,
             mp4save_fn=None):
    """Fit vertex positions and colors of a cube mesh to renders of the reference.

    Vertex positions start as the reference positions plus uniform noise in
    [-0.5, 0.5]. Colors start as uniform noise in [0, 1]. Both are optimized
    with Adam on an L2 pixel loss between low-resolution renders from random
    views. The learning rate decays as ``max(0.01, 10**(-0.0005 * step))``.

    Args:
        max_iter: iterations per repeat (``max_iter + 1`` are run).
        resolution: training render resolution.
        discontinuous: load ``cube_d.npz`` instead of ``cube_c.npz``.
        repeats: number of independent trials.
        log_interval: print/log the averaged geometric error every N iterations.
        display_interval: show a comparison grid every N iterations.
        display_res: resolution of the display renders and the display window.
        out_dir: output directory; when falsy, video saving is disabled.
        log_fn: log file name inside ``out_dir``.
        mp4save_interval, mp4save_fn: append a frame to ``out_dir/mp4save_fn``
            every N iterations.
    """
    log_file = None
    writer = None
    try:
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            if log_fn:
                log_file = open(f'{out_dir}/{log_fn}', 'wt')
            # Open a video writer only when frames will actually be written.
            # A `!= 0` test would also open one (named "None") for the
            # default mp4save_interval=None.
            if mp4save_interval:
                writer = imageio.get_writer(f'{out_dir}/{mp4save_fn}',
                                            mode='I',
                                            fps=30,
                                            codec='libx264',
                                            bitrate='16M')
        else:
            mp4save_interval = None

        datadir = f'{pathlib.Path(__file__).absolute().parents[1]}/data'
        fn = 'cube_%s.npz' % ('d' if discontinuous else 'c')
        with np.load(f'{datadir}/{fn}') as f:
            pos_idx, vtxp, col_idx, vtxc = f.values()
        print("Mesh has %d triangles and %d vertices." %
              (pos_idx.shape[0], vtxp.shape[0]))

        # Create position/triangle index tensors.
        pos_idx = torch.from_numpy(pos_idx.astype(np.int32)).cuda()
        col_idx = torch.from_numpy(col_idx.astype(np.int32)).cuda()
        vtx_pos = torch.from_numpy(vtxp.astype(np.float32)).cuda()
        vtx_col = torch.from_numpy(vtxc.astype(np.float32)).cuda()

        glctx = dr.RasterizeGLContext()

        for rep in range(repeats):
            ang = 0.0
            gl_avg = []

            vtx_pos_rand = np.random.uniform(-0.5, 0.5, size=vtxp.shape) + vtxp
            vtx_col_rand = np.random.uniform(0.0, 1.0, size=vtxc.shape)
            vtx_pos_opt = torch.tensor(vtx_pos_rand,
                                       dtype=torch.float32,
                                       device='cuda',
                                       requires_grad=True)
            vtx_col_opt = torch.tensor(vtx_col_rand,
                                       dtype=torch.float32,
                                       device='cuda',
                                       requires_grad=True)

            # Adam optimizer for vertex position and color with a learning
            # rate ramp.
            optimizer = torch.optim.Adam([vtx_pos_opt, vtx_col_opt], lr=1e-2)
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer, lr_lambda=lambda x: max(0.01, 10**(-x * 0.0005)))

            for it in range(max_iter + 1):
                # Random rotation/translation matrix for optimization.
                r_rot = util.random_rotation_translation(0.25)

                # Smooth rotation for display.
                a_rot = np.matmul(util.rotate_x(-0.4), util.rotate_y(ang))

                # Modelview and modelview + projection matrices.
                proj = util.projection(x=0.4)
                r_mv = np.matmul(util.translate(0, 0, -3.5), r_rot)
                r_mvp = np.matmul(proj, r_mv).astype(np.float32)
                a_mv = np.matmul(util.translate(0, 0, -3.5), a_rot)
                a_mvp = np.matmul(proj, a_mv).astype(np.float32)

                # Geometric error for logging: mean distance of |p| from the
                # 0.5 corner magnitude.
                with torch.no_grad():
                    geom_loss = torch.mean(
                        torch.sum((torch.abs(vtx_pos_opt) - .5)**2,
                                  dim=1)**0.5)
                    gl_avg.append(float(geom_loss))

                # Print/save log.
                if log_interval and (it % log_interval == 0):
                    gl_val = np.mean(np.asarray(gl_avg))
                    gl_avg = []
                    s = ("rep=%d," % rep) if repeats > 1 else ""
                    s += "iter=%d,err=%f" % (it, gl_val)
                    print(s)
                    if log_file:
                        log_file.write(s + "\n")

                color = render(glctx, r_mvp, vtx_pos, pos_idx, vtx_col,
                               col_idx, resolution)
                color_opt = render(glctx, r_mvp, vtx_pos_opt, pos_idx,
                                   vtx_col_opt, col_idx, resolution)

                # Compute loss and train.
                loss = torch.mean((color - color_opt)**2)  # L2 pixel loss.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Show/save image.
                display_image = display_interval and (it % display_interval
                                                      == 0)
                save_mp4 = mp4save_interval and (it % mp4save_interval == 0)

                if display_image or save_mp4:
                    ang = ang + 0.01

                    # Display renders need no autograd graph.
                    with torch.no_grad():
                        img_b = color[0].cpu().numpy()
                        img_o = color_opt[0].detach().cpu().numpy()
                        img_d = render(glctx, a_mvp, vtx_pos_opt, pos_idx,
                                       vtx_col_opt, col_idx,
                                       display_res)[0].cpu().numpy()
                        img_r = render(glctx, a_mvp, vtx_pos, pos_idx, vtx_col,
                                       col_idx, display_res)[0].cpu().numpy()

                    # Upscale the low-res training renders to display size.
                    scl = display_res // img_o.shape[0]
                    img_b = np.repeat(np.repeat(img_b, scl, axis=0), scl,
                                      axis=1)
                    img_o = np.repeat(np.repeat(img_o, scl, axis=0), scl,
                                      axis=1)
                    result_image = make_grid(
                        np.stack([img_o, img_b, img_d, img_r]))

                    if display_image:
                        util.display_image(result_image,
                                           size=display_res,
                                           title='%d / %d' % (it, max_iter))
                    if save_mp4:
                        writer.append_data(
                            np.clip(np.rint(result_image * 255.0), 0,
                                    255).astype(np.uint8))
    finally:
        # Release the video writer and log even if training raises.
        if writer is not None:
            writer.close()
        if log_file:
            log_file.close()
Ejemplo n.º 8
0
def fit_env_phong(max_iter          = 1000,
                  log_interval      = 10,
                  display_interval  = None,
                  display_res       = 1024,
                  res               = 1024,
                  lr_base           = 1e-2,
                  lr_ramp           = 1.0,
                  out_dir           = None,
                  log_fn            = None,
                  mp4save_interval  = None,
                  mp4save_fn        = None):
    """Fit an environment cube map and Phong parameters to reference renders.

    Each iteration renders the mesh loaded from ``data/envphong.npz`` from a
    random viewpoint with a random light direction twice: once with the stored
    environment map plus a fixed Phong highlight (the target), and once with a
    learned environment map and learned Phong colour/exponent. Adam minimises
    the mean squared pixel difference between the two.

    Args:
        max_iter: Last iteration index; the loop runs iterations ``0..max_iter``.
        log_interval: Print (and write to the log file) averaged errors every
            this many iterations; falsy disables logging.
        display_interval: Show a preview via ``util.display_image`` every this
            many iterations; falsy disables previews.
        display_res: ``size`` passed to ``util.display_image``.
        res: Square rasterization resolution used for training and previews.
        lr_base: Initial Adam learning rate.
        lr_ramp: Learning-rate factor reached at ``max_iter``; the rate is
            scaled by ``lr_ramp ** (step / max_iter)``.
        out_dir: Directory for the log file and video. When falsy nothing is
            written to disk and video saving is disabled.
        log_fn: Log file name inside ``out_dir``; falsy disables the file log.
        mp4save_interval: Append a preview frame to the video every this many
            iterations; falsy (``None`` or ``0``) disables the video.
        mp4save_fn: Video file name inside ``out_dir``; required when video
            saving is enabled.

    Raises:
        ValueError: If ``out_dir`` and ``mp4save_interval`` are set but
            ``mp4save_fn`` is not.
    """
    # Validate output options before opening anything, so a bad combination
    # cannot leave a half-opened log file behind.
    if not out_dir:
        mp4save_interval = None  # Nowhere to write a video.
    elif mp4save_interval and not mp4save_fn:
        raise ValueError('mp4save_fn must be given when mp4save_interval is set')

    log_file = None
    writer = None
    try:
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            if log_fn:
                log_file = open(out_dir + '/' + log_fn, 'wt')
            # Truthiness rather than `!= 0`: the default mp4save_interval=None
            # must not open a writer on a file literally named 'None'.
            if mp4save_interval:
                writer = imageio.get_writer(f'{out_dir}/{mp4save_fn}', mode='I', fps=30, codec='libx264', bitrate='16M')

        # Texture adapted from https://github.com/WaveEngine/Samples/tree/master/Materials/EnvironmentMap/Content/Assets/CubeMap.cubemap
        datadir = f'{pathlib.Path(__file__).absolute().parents[1]}/data'
        with np.load(f'{datadir}/envphong.npz') as f:
            pos_idx, pos, normals, env = f.values()
        env = env.astype(np.float32)/255.0
        env = np.stack(env)[:, ::-1].copy()
        print("Mesh has %d triangles and %d vertices." % (pos_idx.shape[0], pos.shape[0]))

        # Move all the stuff to GPU.
        pos_idx = torch.as_tensor(pos_idx, dtype=torch.int32, device='cuda')
        pos = torch.as_tensor(pos, dtype=torch.float32, device='cuda')
        normals = torch.as_tensor(normals, dtype=torch.float32, device='cuda')
        env = torch.as_tensor(env, dtype=torch.float32, device='cuda')

        # Target Phong parameters.
        phong_rgb = np.asarray([1.0, 0.8, 0.6], np.float32)
        phong_exp = 25.0
        phong_rgb_t = torch.as_tensor(phong_rgb, dtype=torch.float32, device='cuda')

        # Learned variables: environment maps, phong color, phong exponent.
        env_var = torch.ones_like(env) * .5
        env_var.requires_grad_()
        phong_var_raw = torch.as_tensor(np.random.uniform(size=[4]), dtype=torch.float32, device='cuda')
        phong_var_raw.requires_grad_()
        # Scales the raw parameters; the 4th entry (exponent) is scaled by 10.
        phong_var_mul = torch.as_tensor([1.0, 1.0, 1.0, 10.0], dtype=torch.float32, device='cuda')

        # Render state.
        ang = 0.0
        imgloss_avg, phong_avg = [], []
        glctx = dr.RasterizeGLContext()
        zero_tensor = torch.as_tensor(0.0, dtype=torch.float32, device='cuda')
        one_tensor = torch.as_tensor(1.0, dtype=torch.float32, device='cuda')
        proj = util.projection(x=0.4, n=1.0, f=200.0)  # Same every iteration.

        def render_refl(ldir, cpos, mvp):
            """Rasterize and return (reflection dirs, their derivatives, L.R, background mask)."""
            # Transform and rasterize.
            viewvec = pos[..., :3] - cpos[np.newaxis, np.newaxis, :] # View vectors at vertices.
            reflvec = viewvec - 2.0 * normals[np.newaxis, ...] * torch.sum(normals[np.newaxis, ...] * viewvec, -1, keepdim=True) # Reflection vectors at vertices.
            reflvec = reflvec / torch.sum(reflvec**2, -1, keepdim=True)**0.5 # Normalize.
            pos_clip = torch.matmul(pos, mvp.t())[np.newaxis, ...]
            rast_out, rast_out_db = dr.rasterize(glctx, pos_clip, pos_idx, [res, res])
            refl, refld = dr.interpolate(reflvec, rast_out, pos_idx, rast_db=rast_out_db, diff_attrs='all') # Interpolated reflection vectors.

            # Phong light.
            refl = refl / (torch.sum(refl**2, -1, keepdim=True) + 1e-8)**0.5  # Normalize.
            ldotr = torch.sum(-ldir * refl, -1, keepdim=True) # L dot R.

            # Pixels whose last rasterizer channel is 0 are treated as background.
            return refl, refld, ldotr, (rast_out[..., -1:] == 0)

        def shade(env_map, rgb, exponent, refl, refld, ldotr, mask):
            """Cube-map lookup plus Phong highlight, white where mask is set."""
            color = dr.texture(env_map[np.newaxis, ...], refl, uv_da=refld, filter_mode='linear-mipmap-linear', boundary_mode='cube')
            color = color + rgb * torch.max(zero_tensor, ldotr) ** exponent # Phong.
            return torch.where(mask, one_tensor, color) # White background.

        # Adam optimizer for environment map and phong with a learning rate ramp.
        # max(max_iter, 1) keeps the ramp defined (no 0/0) when max_iter == 0.
        ramp_len = float(max(max_iter, 1))
        optimizer = torch.optim.Adam([env_var, phong_var_raw], lr=lr_base)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_ramp**(float(x)/ramp_len))

        for it in range(max_iter + 1):
            phong_var = phong_var_raw * phong_var_mul

            # Random rotation/translation matrix for optimization.
            r_rot = util.random_rotation_translation(0.25)

            # Smooth rotation for display.
            ang = ang + 0.01
            a_rot = np.matmul(util.rotate_x(-0.4), util.rotate_y(ang))

            # Modelview and modelview + projection matrices.
            r_mv  = np.matmul(util.translate(0, 0, -3.5), r_rot)
            a_mv  = np.matmul(util.translate(0, 0, -3.5), a_rot)
            a_mvc = np.matmul(proj, a_mv).astype(np.float32)  # NumPy copy, reused for the display light.
            r_mvp = torch.as_tensor(np.matmul(proj, r_mv).astype(np.float32), dtype=torch.float32, device='cuda')
            a_mvp = torch.as_tensor(a_mvc, dtype=torch.float32, device='cuda')

            # Solve camera positions.
            a_campos = torch.as_tensor(np.linalg.inv(a_mv)[:3, 3], dtype=torch.float32, device='cuda')
            r_campos = torch.as_tensor(np.linalg.inv(r_mv)[:3, 3], dtype=torch.float32, device='cuda')

            # Random light direction.
            lightdir = np.random.normal(size=[3])
            lightdir /= np.linalg.norm(lightdir) + 1e-8
            lightdir = torch.as_tensor(lightdir, dtype=torch.float32, device='cuda')

            # Render the reflections.
            refl, refld, ldotr, mask = render_refl(lightdir, r_campos, r_mvp)

            # Reference color. No need for AA because we are not learning geometry.
            color = shade(env, phong_rgb_t, phong_exp, refl, refld, ldotr, mask)

            # Candidate rendering same up to this point, but uses learned texture and Phong parameters instead.
            color_opt = shade(env_var, phong_var[:3], phong_var[3], refl, refld, ldotr, mask)

            # Compute loss and train.
            loss = torch.mean((color - color_opt)**2) # L2 pixel loss.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Collect losses.
            imgloss_avg.append(loss.detach().cpu().numpy())
            phong_avg.append(phong_var.detach().cpu().numpy())

            # Print/save log.
            if log_interval and (it % log_interval == 0):
                imgloss_val, imgloss_avg = np.mean(np.asarray(imgloss_avg, np.float32)), []
                phong_val, phong_avg = np.mean(np.asarray(phong_avg, np.float32), axis=0), []
                phong_rgb_rmse = np.mean((phong_val[:3] - phong_rgb)**2)**0.5
                phong_exp_rel_err = np.abs(phong_val[3] - phong_exp)/phong_exp
                s = "iter=%d,phong_rgb_rmse=%f,phong_exp_rel_err=%f,img_rmse=%f" % (it, phong_rgb_rmse, phong_exp_rel_err, imgloss_val)
                print(s)
                if log_file:
                    log_file.write(s + '\n')

            # Show/save result image.
            display_image = display_interval and (it % display_interval == 0)
            save_mp4 = mp4save_interval and (it % mp4save_interval == 0)

            if display_image or save_mp4:
                # Preview only: no autograd graph needed.
                with torch.no_grad():
                    # Fixed light direction, transformed with the display matrix.
                    disp_ldir = np.asarray([.8, -1., .5, 0.0])
                    disp_ldir = np.matmul(a_mvc, disp_ldir)[:3]
                    disp_ldir /= np.linalg.norm(disp_ldir)
                    disp_ldir = torch.as_tensor(disp_ldir, dtype=torch.float32, device='cuda')
                    d_refl, d_refld, d_ldotr, d_mask = render_refl(disp_ldir, a_campos, a_mvp)
                    # Use parameters after this step's update, consistent with env_var.
                    phong_now = phong_var_raw * phong_var_mul
                    color_disp = shade(env_var, phong_now[:3], phong_now[3], d_refl, d_refld, d_ldotr, d_mask)
                result_image = color_disp[0].cpu().numpy()
                if display_image:
                    util.display_image(result_image, size=display_res, title='%d / %d' % (it, max_iter))
                if save_mp4:
                    writer.append_data(np.clip(np.rint(result_image*255.0), 0, 255).astype(np.uint8))
    finally:
        # Done (or failed): release output files either way.
        if writer is not None:
            writer.close()
        if log_file:
            log_file.close()