def fit_closure():
    """Evaluate the weighted fitting objective; usable as an ``optimizer.step`` closure.

    The objective is
    ``lmks_factor * lmks_fit_loss(landmarks) + silh_factor * silh_fit_loss(mesh)
    + flame_regularizer_loss``.

    NOTE(review): this is a detached fragment. ``optimizer``, ``flamelayer``,
    ``make_mesh``, ``lmks_factor``, ``silh_factor``, ``lmks_fit_loss`` and
    ``silh_fit_loss`` must be provided by the surrounding scope.

    Returns:
        The total objective tensor. When it requires grad, ``backward()`` has
        already been called on it.
    """
    # Only reset gradients when autograd is recording (closure may be
    # re-evaluated under no_grad by some optimizers).
    if torch.is_grad_enabled():
        optimizer.zero_grad()
    _, lmks_3d, reg_loss = flamelayer()

    mesh = make_mesh(flamelayer, detach=False)
    lmks_term = lmks_factor * lmks_fit_loss(lmks_3d)
    silh_term = silh_factor * silh_fit_loss(mesh)
    total = (lmks_term + silh_term) + reg_loss
    print('obj - ', total)
    if total.requires_grad:
        total.backward()
    return total
def fit_flame_silhouette_perspectiv(flamelayer, cam, renderer, target_silh, optimizer, device, target_2d_lmks,
                                    n_iters=200, silh_factor=1e-6):
    """Fit FLAME parameters to target 2D landmarks and a target silhouette.

    Runs ``n_iters`` steps of ``optimizer`` on
    ``lmks_fit_loss + silh_factor * silh_fit_loss + flame_regularizer_loss``.

    Args:
        flamelayer: FLAME module. Calling it returns a 3-tuple whose second and
            third items are used as the 3D landmarks and the regularizer loss.
            ``flamelayer.weights['lmk']`` scales the landmark loss.
        cam: object with ``transform_points``. The first two output columns are
            compared against ``target_2d_lmks``.
        renderer: object with ``render_sil``. Channel 3 of its squeezed output
            is compared against ``target_silh``.
        target_silh: numpy array holding the target silhouette.
        optimizer: torch optimizer over the FLAME parameters.
        device: device the numpy targets are moved to. ``None`` keeps the old
            behaviour of moving them to the default CUDA device.
        target_2d_lmks: numpy array of target 2D landmarks.
        n_iters: number of optimizer steps (default 200, as before).
        silh_factor: weight of the silhouette term in the optimized loss
            (default 1e-6, as before).
    """
    # Previously the targets were always sent to CUDA, ignoring ``device``.
    target_device = torch.device('cuda') if device is None else device
    torch_target_2d_lmks = torch.from_numpy(target_2d_lmks).to(target_device)
    torch_target_silh = torch.from_numpy(target_silh).to(target_device)
    factor = 1  # TODO what should factor be???

    def lmks_fit_loss(landmarks_3D):
        # Project with the camera and keep only the (x, y) columns.
        landmarks_2D = cam.transform_points(landmarks_3D)[:, :2]
        return flamelayer.weights['lmk'] * torch.sum(torch.sub(landmarks_2D, torch_target_2d_lmks) ** 2) / (factor ** 2)

    def silh_fit_loss(my_mesh):
        # Channel 3 of the squeezed render is compared with the target silhouette.
        silhouette = renderer.render_sil(my_mesh).squeeze()[..., 3]
        return torch.sum(torch.sub(silhouette, torch_target_silh) ** 2) / (factor ** 2)

    def fit_closure():
        # Closure form of the objective (unweighted silhouette term), for
        # optimizers such as LBFGS that call ``step(closure)``.
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        _, landmarks_3D, flame_regularizer_loss = flamelayer()

        my_mesh = make_mesh(flamelayer, detach=False)
        obj1 = lmks_fit_loss(landmarks_3D) + silh_fit_loss(my_mesh)
        obj = obj1 + flame_regularizer_loss
        print('obj - ', obj)
        if obj.requires_grad:
            obj.backward()
        return obj

    def log_obj(msg):
        if FIT_2D_DEBUG_MODE:
            # lmks_fit_loss expects the 3D landmarks, not the mesh.
            _, landmarks_3D, _ = flamelayer()
            my_mesh = make_mesh(flamelayer)
            print(msg + ' obj = ', lmks_fit_loss(landmarks_3D) + silh_fit_loss(my_mesh))

    def log(msg):
        if FIT_2D_DEBUG_MODE:
            print(msg)

    # log('Optimizing rigid transformation')
    # log_obj('Before optimization obj')
    # optimizer.step(fit_closure)
    # log_obj('After optimization obj')

    for _ in range(n_iters):
        optimizer.zero_grad()

        _, landmarks_3D, flame_regularizer_loss = flamelayer()

        my_mesh = make_mesh(flamelayer, detach=False)
        obj1 = lmks_fit_loss(landmarks_3D) + silh_factor * silh_fit_loss(my_mesh)
        loss = obj1 + flame_regularizer_loss
        loss.backward()
        print(flamelayer.transl.grad)
        optimizer.step()
Beispiel #3
0
def plot_landmarks(renderer,
                   target_img,
                   target_lmks,
                   flamelayer,
                   cam,
                   device,
                   lmk_dist=0.0,
                   shape_reg=0.0,
                   exp_reg=0.0,
                   neck_pose_reg=0.0,
                   jaw_pose_reg=0.0,
                   eyeballs_pose_reg=0.0):
    """Show target vs. fitted landmarks, next to a Phong rendering of the fit.

    Target landmarks are drawn in red (BGR ``(0, 0, 255)``) and fitted ones in
    blue. Under Python 3 the Phong rendering, resized to the target image and
    annotated in green, is stacked to the right. The result is shown with
    ``cv2.imshow``, and the function blocks on ``cv2.waitKey``.

    Args:
        renderer: object with ``render_phong``.
        target_img: image the landmarks are drawn on. A copy is drawn on, so
            the caller's array is left unmodified.
        target_lmks: target landmarks, mapped to pixels via
            ``CoordTransformer.cam2screen``.
        flamelayer: FLAME module; its second output gives the 3D landmarks.
        cam: object with ``transform_points``.
        device: unused.
        lmk_dist, shape_reg, exp_reg, neck_pose_reg, jaw_pose_reg,
        eyeballs_pose_reg: loss values. They are printed when any is positive.
    """
    loss_terms = (lmk_dist, shape_reg, exp_reg, neck_pose_reg, jaw_pose_reg, eyeballs_pose_reg)
    if any(term > 0.0 for term in loss_terms):
        print(
            'lmk_dist: %f, shape_reg: %f, exp_reg: %f, neck_pose_reg: %f, jaw_pose_reg: %f, eyeballs_pose_reg: %f'
            % loss_terms)

    # cv2.circle draws in place; work on a copy so the caller's image is untouched.
    target_img = target_img.copy()

    _, landmarks_3D, _ = flamelayer()
    optim_lmks = cam.transform_points(landmarks_3D)[:, :2]
    optim_lmks = optim_lmks.detach().cpu().numpy().squeeze()
    my_mesh = make_mesh(flamelayer, True)
    # Transform coordinates from [-1, 1] to pixel coordinates of the target image.
    coord_transformer = CoordTransformer(target_img.shape)
    plt_target_lmks = coord_transformer.cam2screen(target_lmks.copy())
    plt_opt_lmks = coord_transformer.cam2screen(optim_lmks.copy())

    for (x, y) in plt_target_lmks:
        cv2.circle(target_img, (int(x), int(y)), 4, (0, 0, 255), -1)

    for (x, y) in plt_opt_lmks:
        cv2.circle(target_img, (int(x), int(y)), 4, (255, 0, 0), -1)

    if sys.version_info >= (3, 0):
        rendered_img = renderer.render_phong(my_mesh)
        rendered_img = rendered_img.detach().cpu().numpy().squeeze()
        # cv2.resize takes dsize as (width, height); passing (rows, cols) broke
        # non-square images and made the hstack below fail.
        height, width = target_img.shape[:2]
        rendered_img = cv2.resize(rendered_img, (width, height))
        for (x, y) in plt_opt_lmks:
            cv2.circle(rendered_img, (int(x), int(y)), 4, (0, 255, 0), -1)

        # Assumes the rendering has >= 3 channels with values in [0, 1] -- TODO confirm.
        target_img = np.hstack((target_img / 255, rendered_img[:, :, :3]))

    cv2.imshow('target_img', target_img)
    cv2.waitKey()
    def optimize_Adam(self, optimizer, lmks_factor, silh_factor, n_iters=200):
        """Run ``n_iters`` first-order optimizer steps on the weighted fitting loss.

        The loss is ``lmks_factor * self._lmks_fit_loss(landmarks)
        + silh_factor * self._silh_fit_loss(mesh) + flame_regularizer_loss``.
        The landmarks and regularizer come from ``self._flamelayer()``. The mesh
        is built with ``make_mesh(self._flamelayer, detach=False)`` so that the
        gradients flow through it.

        Args:
            optimizer: torch optimizer over the FLAME parameters.
            lmks_factor: weight of the landmark term.
            silh_factor: weight of the silhouette term.
            n_iters: number of optimizer steps (default 200, as before).
        """
        for _ in range(n_iters):
            optimizer.zero_grad()

            _, landmarks_3D, flame_regularizer_loss = self._flamelayer()

            my_mesh = make_mesh(self._flamelayer, detach=False)
            obj1 = lmks_factor * self._lmks_fit_loss(landmarks_3D) + \
                   silh_factor * self._silh_fit_loss(my_mesh)

            loss = obj1 + flame_regularizer_loss
            loss.backward()
            optimizer.step()
Beispiel #5
0
def plot_silhouette(flamelayer, renderer, target_silh):
    """Plot (rendered alpha - target) beside the target silhouette.

    The mesh is built with ``make_mesh(flamelayer, detach=True)`` and rendered
    with ``renderer.render_sil``. Only channel 3 (alpha) of the squeezed render
    is used. Both panels are shown in one 10x10 matplotlib figure with the grid
    off.
    """
    target = target_silh.squeeze()
    rendered = renderer.render_sil(make_mesh(flamelayer, detach=True))
    # Keep only the alpha channel of the RGBA render.
    rendered_alpha = rendered.detach().cpu().numpy().squeeze()[..., 3]

    plt.figure(figsize=(10, 10))
    panels = (rendered_alpha - target, target)
    for position, image in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.grid(False)
    plt.show()
 def log_obj(str):
     """Print the landmark + silhouette objective when FIT_2D_DEBUG_MODE is set.

     NOTE(review): fragment. ``flamelayer``, ``make_mesh``, ``lmks_fit_loss`` and
     ``silh_fit_loss`` must come from the enclosing scope. The parameter name
     shadows the builtin ``str``; it is kept for interface compatibility.
     """
     if FIT_2D_DEBUG_MODE:
         # lmks_fit_loss expects the 3D landmarks from flamelayer(), not the mesh.
         _, landmarks_3D, _ = flamelayer()
         my_mesh = make_mesh(flamelayer, )
         print(str + ' obj = ', lmks_fit_loss(landmarks_3D) + silh_fit_loss(my_mesh))