Example #1
    def fixed_sampling(self):
        z = np.random.uniform(low=-0.01, high=0.01, size=1)[0]
        z = -0.01  # fixed value overrides the random draw above
        theta_sample = np.random.uniform(low=0.0, high=2.0 * np.pi, size=1)[0]
        #theta_sample = -0.07*np.pi
        x = np.sqrt((self.dist**2 - z**2)) * np.cos(theta_sample)
        y = np.sqrt((self.dist**2 - z**2)) * np.sin(theta_sample)

        cam_position = torch.tensor([x, y, z], dtype=torch.float32).unsqueeze(0)  # float32 to match rot_mat below
        if (z < 0):
            R = look_at_rotation(cam_position, up=((0, 0, -1), )).squeeze()
        else:
            R = look_at_rotation(cam_position, up=((0, 0, 1), )).squeeze()

        # Rotate in-plane
        if True:  # not self.simple_pose_sampling
            rot_degrees = np.random.uniform(low=-90.0, high=90.0, size=1)
            rot_degrees = 30.0  # fixed in-plane rotation overrides the random draw above
            rot = scipyR.from_euler('z', rot_degrees, degrees=True)
            rot_mat = torch.tensor(rot.as_matrix(), dtype=torch.float32)
            R = torch.matmul(R, rot_mat)
            R = R.squeeze()

        t = torch.tensor([0.0, 0.0, self.dist])
        return R, t
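All of these sampling functions rely on the same PyTorch3D pattern: look_at_rotation(cam_position) returns the rotation of a camera placed at cam_position and looking at the origin, and the matching world-to-view translation is T = -R^T @ cam_position. For a camera looking at the origin from distance dist that product works out to exactly [0, 0, dist], which is why the samplers can simply return that constant as t. A minimal hedged sketch of the pattern (dist is a placeholder value):

import torch
from pytorch3d.renderer import look_at_rotation

dist = 2.0                                                        # placeholder camera distance
cam_position = torch.tensor([[0.0, 0.0, dist]])                   # (1, 3) camera center in world space
R = look_at_rotation(cam_position)                                # (1, 3, 3)
T = -torch.bmm(R.transpose(1, 2), cam_position[:, :, None])[:, :, 0]  # (1, 3), equals [[0., 0., dist]]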
Example #2
    def tless_sampling(self):
        theta_sample = np.random.uniform(low=0.0, high=2.0 * np.pi, size=1)[0]
        phi_sample = np.random.uniform(low=0.0, high=2.0 * np.pi, size=1)[0]

        x = self.dist * np.sin(theta_sample) * np.cos(phi_sample)
        y = self.dist * np.sin(theta_sample) * np.sin(phi_sample)
        z = self.dist * np.cos(theta_sample)

        cam_position = torch.tensor([float(x), float(y),
                                     float(z)]).unsqueeze(0)
        if (z < 0):
            R = look_at_rotation(cam_position, up=((0, 0, -1), )).squeeze()
        else:
            R = look_at_rotation(cam_position, up=((0, 0, 1), )).squeeze()

        # Rotate in-plane
        if (not self.simple_pose_sampling):
            rot_degrees = np.random.uniform(low=-90.0, high=90.0, size=1)
            rot = scipyR.from_euler('z', rot_degrees, degrees=True)
            rot_mat = torch.tensor(rot.as_matrix(), dtype=torch.float32)
            R = torch.matmul(R, rot_mat)
            R = R.squeeze()

        t = torch.tensor([0.0, 0.0, self.dist])
        return R, t
Example #3
    def sphere_wolfram_sampling(self):
        x1 = np.random.uniform(low=-1.0, high=1.0, size=1)[0]
        x2 = np.random.uniform(low=-1.0, high=1.0, size=1)[0]
        test = x1**2 + x2**2

        while (test >= 1.0):
            x1 = np.random.uniform(low=-1.0, high=1.0, size=1)[0]
            x2 = np.random.uniform(low=-1.0, high=1.0, size=1)[0]
            test = x1**2 + x2**2

        x = 2.0 * x1 * (1.0 - x1**2 - x2**2)**(0.5)
        y = 2.0 * x2 * (1.0 - x1**2 - x2**2)**(0.5)
        z = 1.0 - 2.0 * (x1**2 + x2**2)

        cam_position = torch.tensor([x, y, z], dtype=torch.float32).unsqueeze(0)  # float32 to match rot_mat below
        if (z < 0):
            R = look_at_rotation(cam_position, up=((0, 0, -1), )).squeeze()
        else:
            R = look_at_rotation(cam_position, up=((0, 0, 1), )).squeeze()

        # Rotate in-plane
        if (not self.simple_pose_sampling):
            rot_degrees = np.random.uniform(low=-90.0, high=90.0, size=1)
            rot = scipyR.from_euler('z', rot_degrees, degrees=True)
            rot_mat = torch.tensor(rot.as_matrix(), dtype=torch.float32)
            R = torch.matmul(R, rot_mat)
            R = R.squeeze()

        t = torch.tensor([0.0, 0.0, self.dist])
        return R, t
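Example #3 implements Marsaglia's method for sampling points uniformly on the unit sphere (the "wolfram" in the name presumably refers to the Sphere Point Picking article on Wolfram MathWorld, which describes it). A rejection-free alternative with the same distribution, shown here only as a hedged sketch, is to normalize an isotropic Gaussian sample:

import numpy as np

# A normalized isotropic Gaussian draw is also uniform on the unit sphere.
v = np.random.normal(size=3)
v /= np.linalg.norm(v)
x, y, z = v   # scale by the desired camera distance (e.g. self.dist) before building cam_position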
Example #4
    def forward(self, mesh):
        # Render the image using the updated camera position. Based on the new position of the
        # camera we calculate the rotation and translation matrices
        # (1,3)  -> (n,3)
        R = look_at_rotation(self.camera_position, device=self.device)  # (1, 3, 3) -> (n,3,3)
        # (1,3,1) -> (n,3,1)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[:, :, None])[:, :, 0]  # (1, 3) -> (n,3)

        if self.light:
            images = torch.empty(self.nviews * len(mesh), 224, 224, 4, device=self.device)
            # the loop is needed because, for now, PyTorch3D does not allow a batch of lights
            for i in range(self.nviews):
                self.lights.location = self.camera_position[i]
                imgs = self.renderer(meshes_world=mesh.clone(), R=R[None, i], T=T[None, i],
                                     lights=self.lights)
                for k, j in zip(range(len(imgs)), range(0, len(imgs) * self.nviews, self.nviews)):
                    images[i + j] = imgs[k]
        else:
            meshes = mesh.extend(self.nviews)
            # because now we have n elements in R and T, we need to expand them to match the batch of meshes
            R = R.repeat(len(mesh), 1, 1)
            T = T.repeat(len(mesh), 1)

            images = self.renderer(meshes_world=meshes.clone(), R=R, T=T)

        images = images.permute(0, 3, 1, 2)
        y = self.net_1(images[:, :3, :, :])
        y = y.view((int(images.shape[0] / self.nviews), self.nviews, y.shape[-3], y.shape[-2], y.shape[-1]))
        return self.net_2(torch.max(y, 1)[0].view(y.shape[0], -1))
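The last three lines of Example #4 perform view pooling: the per-view features from net_1 are regrouped to (batch, nviews, C, H, W), reduced by an element-wise max over the view dimension (MVCNN-style), and flattened for net_2. A shape walk-through with hypothetical sizes:

import torch

B, V, C, H, W = 2, 4, 512, 7, 7        # hypothetical batch size, number of views, feature-map shape
y = torch.randn(B * V, C, H, W)        # stand-in for net_1's output over all rendered views
y = y.view(B, V, C, H, W)              # group the V views belonging to each mesh
pooled = torch.max(y, 1)[0]            # element-wise max across views -> (B, C, H, W)
flat = pooled.view(B, -1)              # (B, C*H*W), the shape expected by net_2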
Example #5
def campos_to_R_T_det(campos, theta, dx, dy, device='cpu', at=((0, 0, 0),), up=((0, 1, 0), )):
    R = look_at_rotation(campos, at=at, device=device, up=up)  # (n, 3, 3)
    # translation = translation_matrix(dx, dy, device=device).unsqueeze(0)
    R = torch.bmm(R, rotation_theta(theta, device_=device))
    # R = torch.bmm(translation, R)
    T = -torch.bmm(R.transpose(1, 2), campos.unsqueeze(2))[:, :, 0]
    # T = T.T.unsqueeze(0)
    # T = torch.bmm(translation, T)[0].T  # (1, 3)
    return R, T
Example #6
    def forward(self):
        R = look_at_rotation(self.camera_position[None, :], device=self.device)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :,
                                                               None])[:, :, 0]

        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)

        loss = torch.sum((image[..., :3] - self.image_ref)**2)
        return loss, image
Example #7
def campos_to_R_T(campos,
                  theta,
                  device='cpu',
                  at=((0, 0, 0), ),
                  up=((0, 1, 0), )):
    R = look_at_rotation(campos, at=at, device=device, up=up)  # (n, 3, 3)
    R = torch.bmm(R, rotation_theta(theta, device_=device))
    T = -torch.bmm(R.transpose(1, 2), campos.unsqueeze(2))[:, :, 0]  # (1, 3)
    return R, T
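A hedged usage sketch for this kind of conversion: look_at_rotation accepts a whole batch of camera positions, and the resulting (R, T) can be passed directly to a PyTorch3D camera. OpenGLPerspectiveCameras is used only because the other examples use it, the camera centers are made up, and the project-specific rotation_theta in-plane rotation is left out:

import torch
from pytorch3d.renderer import OpenGLPerspectiveCameras, look_at_rotation

campos = torch.tensor([[0.0, 0.0, 5.0],
                       [3.0, 0.0, 4.0]])                          # (n, 3) hypothetical camera centers
R = look_at_rotation(campos)                                      # (n, 3, 3)
T = -torch.bmm(R.transpose(1, 2), campos.unsqueeze(2))[:, :, 0]   # (n, 3)
cameras = OpenGLPerspectiveCameras(R=R, T=T)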
Example #8
    def forward(self, mesh):
        # R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        # T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]  # (1, 3)

        t = Transform3d(device=self.device).scale(
            self.camera_position[3] * self.distance_range).rotate_axis_angle(
                self.camera_position[0] * self.angle_range,
                axis="X",
                degrees=False).rotate_axis_angle(
                    self.camera_position[1] * self.angle_range,
                    axis="Y",
                    degrees=False).rotate_axis_angle(self.camera_position[2] *
                                                     self.angle_range,
                                                     axis="Z",
                                                     degrees=False)
        # translation = Translate(T[0][0], T[0][1], T[0][2], device=self.device)

        # t = Transform3d(matrix=self.camera_position)
        vertices = t.transform_points(self.vertices)

        R = look_at_rotation(vertices[:self.nviews], device=self.device)
        T = -torch.bmm(R.transpose(1, 2), vertices[:self.nviews, :,
                                                   None])[:, :, 0]

        if self.light:
            images = torch.empty(self.nviews * len(mesh),
                                 224,
                                 224,
                                 4,
                                 device=self.device)
            # the loop is needed because, for now, PyTorch3D does not allow a batch of lights
            for i in range(self.nviews):
                self.lights.location = vertices[i]
                imgs = self.renderer(meshes_world=mesh.clone(),
                                     R=R[None, i],
                                     T=T[None, i],
                                     lights=self.lights)
                for k, j in zip(range(len(imgs)),
                                range(0,
                                      len(imgs) * self.nviews, self.nviews)):
                    images[i + j] = imgs[k]
        else:
            # self.lights.location = self.light_position
            meshes = mesh.extend(self.nviews)
            # because now we have n elements in R and T, we need to expand them to match the batch of meshes
            R = R.repeat(len(mesh), 1, 1)
            T = T.repeat(len(mesh), 1)

            images = self.renderer(meshes_world=meshes.clone(), R=R,
                                   T=T)  # , lights=self.lights)

        images = images.permute(0, 3, 1, 2)
        y = self.net_1(images[:, :3, :, :])
        y = y.view((int(images.shape[0] / self.nviews), self.nviews,
                    y.shape[-3], y.shape[-2], y.shape[-1]))
        return self.net_2(torch.max(y, 1)[0].view(y.shape[0], -1))
Example #9
def image_fit_loss(face_mesh_alt):
    R = look_at_rotation(camera_translation[None, :],
                         device=device)  # (1, 3, 3)
    T = -torch.bmm(R.transpose(1, 2),
                   camera_translation[None, :, None])[:, :, 0]  # (1, 3)
    silhouette = silhouette_renderer(meshes_world=face_mesh_alt.clone(),
                                     R=R,
                                     T=T)
    print(type(silhouette))
    return torch.sum((silhouette[..., 3] - image_ref_torch)**2) / (factor**2)
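The silhouette_renderer used here (and in several later examples) is not shown. A common way to build one, following the PyTorch3D tutorials, is sketched below; every parameter is illustrative rather than taken from the original code. The [..., 3] indexing above then reads the alpha channel, which holds the soft silhouette.

import numpy as np
import torch
from pytorch3d.renderer import (
    BlendParams, MeshRasterizer, MeshRenderer, OpenGLPerspectiveCameras,
    RasterizationSettings, SoftSilhouetteShader,
)

device = torch.device("cpu")                      # placeholder device
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
raster_settings = RasterizationSettings(
    image_size=256,                               # illustrative resolution
    blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
    faces_per_pixel=100,
)
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=OpenGLPerspectiveCameras(device=device),
        raster_settings=raster_settings,
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params),
)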
Example #10
    def calc(self, cam_pos):
        # Render the image using the updated camera position. Based on the new position of the
        # camera we calculate the rotation and translation matrices
        R = look_at_rotation(cam_pos[None, :], device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), cam_pos[None, :, None])[:, :,
                                                                  0]  # (1, 3)

        image = self.renderer.render(self.meshes, R, T)

        loss = torch.sum((image[..., 3] - self.image_ref)**2)
        return loss, image
Example #11
    def forward(self):
        # Render the image using the updated camera position. Based on the new position of the
        # camera we calculate the rotation and translation matrices
        R = look_at_rotation(self.camera_position[None, :],
                             device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2),
                       self.camera_position[None, :, None])[:, :, 0]  # (1, 3)
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)

        # Calculate the silhouette loss
        loss = torch.sum((image[..., 3] - self.image_ref)**2)
        return loss, image
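The forward passes in Examples #10 and #11 assume a module whose camera_position is a learnable nn.Parameter, in the spirit of the PyTorch3D camera-position optimization tutorial, so the silhouette loss can be backpropagated into the pose. A hedged sketch of that setup; the class name, initial position and buffer handling are illustrative, not taken from the original code:

import numpy as np
import torch
import torch.nn as nn

class CameraPoseModel(nn.Module):                 # hypothetical name
    def __init__(self, meshes, renderer, image_ref, device):
        super().__init__()
        self.meshes = meshes
        self.renderer = renderer
        self.device = device
        # Reference silhouette that the rendered alpha channel is compared against.
        self.register_buffer('image_ref', image_ref)
        # Learnable camera position; gradients reach it through the differentiable renderer.
        self.camera_position = nn.Parameter(
            torch.from_numpy(np.array([3.0, 6.9, 2.5], dtype=np.float32)).to(device))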
Example #12
    def __init__(self, cfgs):
        super().__init__()
        self.device = cfgs.get('device', 'cpu')
        self.image_size = cfgs.get('image_size', 64)
        self.min_depth = cfgs.get('min_depth', 0.9)
        self.max_depth = cfgs.get('max_depth', 1.1)
        self.rot_center_depth = cfgs.get('rot_center_depth',
                                         (self.min_depth + self.max_depth) / 2)
        self.border_depth = cfgs.get(
            'border_depth', 0.3 * self.min_depth + 0.7 * self.max_depth)
        self.fov = cfgs.get('fov', 10)

        #### camera intrinsics
        #             (u)   (x)
        #    d * K^-1 (v) = (y)
        #             (1)   (z)

        ## renderer for visualization
        R = [[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]
        R = torch.FloatTensor(R).to(self.device)
        t = torch.zeros(1, 3, dtype=torch.float32).to(self.device)
        fx = (self.image_size - 1) / 2 / (math.tan(
            self.fov / 2 * math.pi / 180))
        fy = (self.image_size - 1) / 2 / (math.tan(
            self.fov / 2 * math.pi / 180))
        cx = (self.image_size - 1) / 2
        cy = (self.image_size - 1) / 2
        K = [[fx, 0., cx], [0., fy, cy], [0., 0., 1.]]
        K = torch.FloatTensor(K).to(self.device)
        self.inv_K = torch.inverse(K).unsqueeze(0)
        self.K = K.unsqueeze(0)

        # Initialize an OpenGL perspective camera.
        R = look_at_rotation(((0, 0, 0), ),
                             at=((0, 0, 1), ),
                             up=((0, -1, 0), ))
        cameras = OpenGLPerspectiveCameras(device=self.device,
                                           fov=self.fov,
                                           R=R)
        lights = DirectionalLights(
            ambient_color=((1.0, 1.0, 1.0), ),
            diffuse_color=((0.0, 0.0, 0.0), ),
            specular_color=((0.0, 0.0, 0.0), ),
            direction=((0, 1, 0), ),
            device=self.device,
        )
        raster_settings = RasterizationSettings(
            image_size=self.image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
        )
        self.rasterizer_torch = MeshRasterizer(cameras=cameras,
                                               raster_settings=raster_settings)
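For reference, with the defaults image_size=64 and fov=10 the intrinsics built above come out to fx = fy ≈ 360 and cx = cy = 31.5; a quick numeric check:

import math

image_size, fov = 64, 10
fx = (image_size - 1) / 2 / math.tan(fov / 2 * math.pi / 180)
cx = (image_size - 1) / 2
print(fx, cx)   # ~360.05, 31.5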
Example #13
def differentiable_face_render(vert, tri, colors, bg_img, h, w):
    """
    vert: (N, nver, 3)
    tri: (ntri, 3)
    colors: (N, nver, 3)
    bg_img: (N, 3, H, W)
    """
    assert h == w
    N, nver, _ = vert.shape
    ntri = tri.shape[0]
    tri = torch.from_numpy(tri).to(vert.device).unsqueeze(0).expand(N, ntri, 3)
    # Transform to Pytorch3D world space
    vert_t = vert + torch.tensor((0.5, 0.5, 0), dtype=torch.float, device=vert.device).view(1, 1, 3)
    vert_t = vert_t * torch.tensor((-1, 1, -1), dtype=torch.float, device=vert.device).view(1, 1, 3)
    mesh_torch = Meshes(verts=vert_t, faces=tri, textures=TexturesVertex(verts_features=colors))
    # Render
    R = look_at_rotation(camera_position=((0, 0, -300),)).to(vert.device).expand(N, 3, 3)
    T = torch.tensor((0, 0, 300), dtype=torch.float, device=vert.device).view(1, 3).expand(N, 3)
    focal = torch.tensor((2. / float(w), 2. / float(h)), dtype=torch.float, device=vert.device).view(1, 2).expand(N, 2)
    cameras = OrthographicCameras(device=vert.device, R=R, T=T, focal_length=focal)
    raster_settings = RasterizationSettings(image_size=h, blur_radius=0.0, faces_per_pixel=1)
    lights = DirectionalLights(ambient_color=((1., 1., 1.),), diffuse_color=((0., 0., 0.),),
                               specular_color=((0., 0., 0.),), direction=((0, 0, 1),), device=vert.device)
    blend_params = BlendParams(background_color=(0, 0, 0))
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=SoftPhongShader(
            device=vert.device,
            cameras=cameras,
            lights=lights,
            blend_params=blend_params
        )
    )
    images = renderer(mesh_torch)[:, :, :, :3]        # (N, H, W, 3)
    # Add background
    if bg_img is not None:
        bg_img = bg_img.permute(0, 2, 3, 1)         # (N, H, W, 3)
        images = torch.where(torch.eq(images.sum(dim=3, keepdim=True).expand(N, h, w, 3), 0), bg_img, images)
    return images
Example #14
def cam_trajectory_rotation(num_points: int = 4, device: str = gendevice):
    """
    Returns: list of camera poses (R,T) from trajectory along a spherical spiral
    """

    shape = SphericalSpiral(c=6,
                            a=3,
                            t_min=1 * math.pi,
                            t_max=1.05 * math.pi,
                            num_points=num_points)
    up = torch.tensor([[1.0, 0.0, 0.0]])
    for count, cp in enumerate(shape._tuples):
        cp_tensor = torch.tensor(cp).to(device)
        R_new = look_at_rotation(cp_tensor[None, :], device=device)
        T_new = -torch.bmm(R_new.transpose(1, 2), cp_tensor[None, :,
                                                            None])[:, :, 0]
        if not count:
            R = [R_new]
            T = [T_new]
        else:
            R.append(R_new)
            T.append(T_new)
    return (torch.stack(R)[:, 0, :], torch.stack(T)[:, 0, :])
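Two notes on Example #14: the up tensor defined at the top is never passed to look_at_rotation (so the default (0, 1, 0) up-axis applies), and since look_at_rotation accepts a batch of positions the per-point loop can be collapsed into a single call. A hedged vectorized sketch, with placeholder points standing in for shape._tuples:

import torch
from pytorch3d.renderer import look_at_rotation

pts = torch.tensor([[1.0, 0.0, 0.5],
                    [0.9, 0.4, 0.6],
                    [0.7, 0.7, 0.7]])                             # stand-in for the trajectory points
R = look_at_rotation(pts)                                         # (num_points, 3, 3)
T = -torch.bmm(R.transpose(1, 2), pts[:, :, None])[:, :, 0]       # (num_points, 3)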
Example #15
    loss = image_fit_loss(face_mesh_alt)
    # obj = obj1 + flame_regularizer_loss

    loss_model, _ = model()
    print('loss - ', loss)
    print('loss model - ', loss_model)
    obj = loss
    if obj.requires_grad:
        obj.backward()
    return obj


optimizer.step(fit_closure)
loss = image_fit_loss(face_mesh_alt)
R = look_at_rotation(camera_translation[None, :], device=model.device)
T = -torch.bmm(R.transpose(1, 2), camera_translation[None, :,
                                                     None])[:, :, 0]  # (1, 3)
image = silhouette_renderer(meshes_world=face_mesh_alt.clone(), R=R, T=T)
image = image[..., 3].detach().squeeze().cpu().numpy()
image = img_as_ubyte(image)
writer.append_data(image)
print('long LBFGS')
plt.subplot(121)
plt.imshow(image)
plt.title("iter: %d, loss: %0.2f" % (1, loss.data))
plt.grid("off")
plt.axis("off")
plt.subplot(122)
plt.imshow(image_ref)
plt.grid("off")
# plt.grid("off")
# plt.title("Reference silhouette")
# plt.show()

for i in np.arange(1000):
    optimizer.zero_grad()
    loss, _ = model()
    loss.backward()
    optimizer.step()

    print("{0} - loss: {1}".format(i, loss.data))
    #loop.set_description('Optimizing (loss %.4f)' % loss.data)

    # Save outputs to create a GIF.
    if True:  #i % 10 == 0:
        R = look_at_rotation(model.camera_position[None, :],
                             device=model.device)
        T = -torch.bmm(R.transpose(1, 2),
                       model.camera_position[None, :, None])[:, :, 0]  # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        image = img_as_ubyte(image)
        #writer.append_data(image)

        if (i % 10 == 0):
            fig = plt.figure(figsize=(6, 6))
            plt.imshow(image[..., :3])
            plt.title("iter: %d, loss: %0.2f" % (i, loss.data))
            plt.grid("off")
            plt.axis("off")
            fig.tight_layout()
            fig.savefig("iteration{0}.png".format(i), dpi=fig.dpi)
Example #17
# show model silhouette using pytorch code
verts, _, _ = flamelayer()
# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
textures = Textures(verts_rgb=verts_rgb.to(device))
faces = torch.tensor(np.int32(flamelayer.faces), dtype=torch.long).cuda()

my_mesh = Meshes(verts=[verts.to(device)],
                 faces=[faces.to(device)],
                 textures=textures)

camera_position = nn.Parameter(
    torch.from_numpy(np.array([1, 0.1, 0.1],
                              dtype=np.float32)).to(my_mesh.device))

R = look_at_rotation(camera_position[None, :], device=device)  # (1, 3, 3)
T = -torch.bmm(R.transpose(1, 2), camera_position[None, :, None])[:, :,
                                                                  0]  # (1, 3)

silhouette = silhouette_renderer(meshes_world=my_mesh.clone(), R=R, T=T)
silhouette = silhouette.detach().cpu().numpy()[..., 3]

print(silhouette.shape)
plt.imshow(silhouette.squeeze())
plt.show()


class Model(nn.Module):
    def __init__(self, meshes, renderer, image_ref):
        super().__init__()
        self.meshes = meshes
Example #18
def pred_synth(segpose,
               params: Params,
               mesh_type: str = "dolphin",
               device: str = "cuda"):

    if not params.pred_dir or not os.path.exists(params.pred_dir):
        raise FileNotFoundError(
            "Prediction directory has not been set or the file does not exist, please set using cli args or params"
        )
    pred_folders = [
        join(params.pred_dir, f) for f in os.listdir(params.pred_dir)
    ]
    count = 1
    for p in sorted(pred_folders):
        try:
            print(p)
            manager = RenderManager.from_path(p)
            manager.rectify_paths(base_folder=params.pred_dir)
        except FileNotFoundError:
            continue
        # Run Silhouette Prediction Network
        logging.info(f"Starting mask predictions")
        mask_priors = []
        R_pred, T_pred = [], []
        q_loss, t_loss = 0, 0
        # Collect Translation stats
        R_gt, T_gt = manager._trajectory
        poses_gt = EvMaskPoseDataset.preprocess_poses(manager._trajectory)
        std_T, mean_T = torch.std_mean(T_gt)
        for idx in range(len(manager)):
            try:
                ev_frame = manager.get_event_frame(idx)
            except Exception as e:
                print(e)
                break
            mask_pred, pose_pred = predict_segpose(segpose, ev_frame,
                                                   params.threshold_conf,
                                                   params.img_size)
            # mask_pred = smooth_predicted_mask(mask_pred)
            manager.add_pred(idx, mask_pred, "silhouette")
            mask_priors.append(torch.from_numpy(mask_pred))

            # Make qexp a torch function
            # q_pred = qexp(pose_pred[:, 3:])
            # q_targ = qexp(poses_gt[idx, 3:].unsqueeze(0))
            ####  SHOULD THIS BE NORMALIZED ??
            q_pred = pose_pred[:, 3:]
            q_targ = poses_gt[idx, 3:]

            q_pred_unit = q_pred / torch.norm(q_pred)
            q_targ_unit = q_targ / torch.norm(q_targ)
            # print("learnt: ", q_pred_unit, q_targ_unit)

            t_pred = pose_pred[:, :3] * std_T + mean_T
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            T_pred.append(t_pred)

            q_loss += quaternion_angular_error(q_pred_unit, q_targ_unit)
            t_loss += t_error(t_pred, t_targ)

            r_pred = rc.quaternion_to_matrix(q_pred).unsqueeze(0)
            R_pred.append(r_pred.squeeze(0))

        q_loss_mean = q_loss / (idx + 1)
        t_loss_mean = t_loss / (idx + 1)

        # Convert R, T to world-to-view transforms (the PyTorch3D convention)

        R_pred_abs = torch.cat(R_pred)
        T_pred_abs = torch.cat(T_pred)
        # Take inverse of view-to-world (output of net) to obtain w2v
        wtv_trans = (get_world_to_view_transform(
            R=R_pred_abs, T=T_pred_abs).inverse().get_matrix())
        T_pred = wtv_trans[:, 3, :3]
        R_pred = wtv_trans[:, :3, :3]
        R_pred_test = look_at_rotation(T_pred_abs)
        T_pred_test = -torch.bmm(R_pred_test.transpose(1, 2),
                                 T_pred_abs[:, :, None])[:, :, 0]
        # Convert back to view-to-world to get absolute
        vtw_trans = (get_world_to_view_transform(
            R=R_pred_test, T=T_pred_test).inverse().get_matrix())
        T_pred_trans = vtw_trans[:, 3, :3]
        R_pred_trans = vtw_trans[:, :3, :3]

        # Calc pose error for this:
        q_loss_mean_test = 0
        t_loss_mean_test = 0
        for idx in range(len(R_pred_test)):
            q_pred_trans = rc.matrix_to_quaternion(R_pred_trans[idx]).squeeze()
            q_targ = poses_gt[idx, 3:]
            q_targ_unit = q_targ / torch.norm(q_targ)
            # print("look: ", q_test, q_targ)
            q_loss_mean_test += quaternion_angular_error(
                q_pred_trans, q_targ_unit)
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            t_loss_mean_test += t_error(T_pred_trans[idx], t_targ)
        q_loss_mean_test /= idx + 1
        t_loss_mean_test /= idx + 1

        logging.info(
            f"Mean Translation Error: {t_loss_mean}; Mean Rotation Error: {q_loss_mean}"
        )
        logging.info(
            f"Mean Translation Error: {t_loss_mean_test}; Mean Rotation Error: {q_loss_mean_test}"
        )

        # Plot estimated cameras
        logging.info(f"Plotting pose map")
        idx = random.sample(range(len(R_gt)), k=2)
        pose_plot = plot_cams_from_poses(
            (R_gt[idx], T_gt[idx]), (R_pred[idx], T_pred[idx]), params.device)
        pose_plot_test = plot_cams_from_poses(
            (R_gt[idx], T_gt[idx]), (R_pred_test[idx], T_pred_test[idx]),
            params.device)
        manager.add_pose_plot(pose_plot, "rot+trans")
        manager.add_pose_plot(pose_plot_test, "trans")

        count += 1
        groundtruth_silhouettes = manager._images("silhouette") / 255.0
        print(groundtruth_silhouettes.shape)
        print(torch.stack((mask_priors)).shape)
        seg_iou = neg_iou_loss(groundtruth_silhouettes,
                               torch.stack((mask_priors)) / 255.0)
        print("Seg IoU", seg_iou)

        # RUN MESH DEFORMATION

        # RUN MESH DEFORMATION
        # Run it 3 times: w/ Rot+Trans - w/ Trans+LookAt - w/ GT Pose
        experiments = {
            "GT-Pose": [R_gt, T_gt],
            # "Rot+Trans": [R_pred, T_pred],
            # "Trans+LookAt": [R_pred_test, T_pred_test]
        }

        results = {}
        input_m = torch.stack((mask_priors))

        for i in range(len(experiments.keys())):

            logging.info(
                f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
            # The MeshDeformation model will return silhouettes across all view by default

            mesh_model = MeshDeformationModel(device=device, params=params)

            R, T = list(experiments.values())[i]
            experiment_results = mesh_model.run_optimization(input_m, R, T)
            renders = mesh_model.render_final_mesh((R, T), "predict",
                                                   input_m.shape[-2:])

            mesh_silhouettes = renders["silhouettes"].squeeze(1)
            mesh_images = renders["images"].squeeze(1)
            experiment_name = list(experiments.keys())[i]
            for idx in range(len(mesh_silhouettes)):
                manager.add_pred(
                    idx,
                    mesh_silhouettes[idx].cpu().numpy(),
                    "silhouette",
                    destination=f"mesh_{experiment_name}",
                )
                manager.add_pred(
                    idx,
                    mesh_images[idx].cpu().numpy(),
                    "phong",
                    destination=f"mesh_{experiment_name}",
                )

            # Calculate chamfer loss:
            mesh_pred = mesh_model._final_mesh
            if mesh_type == "dolphin":
                path = "data/meshes/dolphin/dolphin.obj"
                mesh_gt = load_objs_as_meshes(
                    [path],
                    create_texture_atlas=False,
                    load_textures=True,
                    device=device,
                )
            # Shapenet Cars
            elif mesh_type == "shapenet":
                mesh_info = manager.metadata["mesh_info"]
                path = os.path.join(
                    params.gt_mesh_path,
                    f"/ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj"
                )
                # path = f"data/ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj"
                try:
                    verts, faces, aux = load_obj(path,
                                                 load_textures=True,
                                                 create_texture_atlas=True)

                    mesh_gt = Meshes(
                        verts=[verts],
                        faces=[faces.verts_idx],
                        textures=TexturesAtlas(atlas=[aux.texture_atlas]),
                    ).to(device)
                except Exception:
                    mesh_gt = None
                    print("CANNOT COMPUTE CHAMFER LOSS")
            if mesh_gt:
                mesh_pred_compute, mesh_gt_compute = scale_meshes(
                    mesh_pred.clone(), mesh_gt.clone())
                pcl_pred = sample_points_from_meshes(mesh_pred_compute,
                                                     num_samples=5000)
                pcl_gt = sample_points_from_meshes(mesh_gt_compute,
                                                   num_samples=5000)
                chamfer_loss = chamfer_distance(pcl_pred,
                                                pcl_gt,
                                                point_reduction="mean")
                print("CHAMFER LOSS: ", chamfer_loss)
                experiment_results["chamfer_loss"] = (
                    chamfer_loss[0].cpu().detach().numpy().tolist())

            mesh_iou = neg_iou_loss(groundtruth_silhouettes, mesh_silhouettes)

            experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()

            results[experiment_name] = experiment_results

            manager.add_pred_mesh(mesh_pred, experiment_name)
        # logging.info(f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
        # # The MeshDeformation model will return silhouettes across all view by default
        #
        #
        #
        # experiment_results = models["mesh"].run_optimization(input_m, R_gt, T_gt)
        # renders = models["mesh"].render_final_mesh(
        #     (R_gt, T_gt), "predict", input_m.shape[-2:]
        # )
        #
        # mesh_silhouettes = renders["silhouettes"].squeeze(1)
        # mesh_images = renders["images"].squeeze(1)
        # experiment_name = params.name
        # for idx in range(len(mesh_silhouettes)):
        #     manager.add_pred(
        #         idx,
        #         mesh_silhouettes[idx].cpu().numpy(),
        #         "silhouette",
        #         destination=f"mesh_{experiment_name}",
        #     )
        #     manager.add_pred(
        #         idx,
        #         mesh_images[idx].cpu().numpy(),
        #         "phong",
        #         destination=f"mesh_{experiment_name}",
        #     )
        #
        # # Calculate chamfer loss:
        # mesh_pred = models["mesh"]._final_mesh
        # if mesh_type == "dolphin":
        #     path = params.gt_mesh_path
        #     mesh_gt = load_objs_as_meshes(
        #         [path],
        #         create_texture_atlas=False,
        #         load_textures=True,
        #         device=device,
        #     )
        # # Shapenet Cars
        # elif mesh_type == "shapenet":
        #     mesh_info = manager.metadata["mesh_info"]
        #     path = params.gt_mesh_path
        #     try:
        #         verts, faces, aux = load_obj(
        #             path, load_textures=True, create_texture_atlas=True
        #         )
        #
        #         mesh_gt = Meshes(
        #             verts=[verts],
        #             faces=[faces.verts_idx],
        #             textures=TexturesAtlas(atlas=[aux.texture_atlas]),
        #         ).to(device)
        #     except:
        #         mesh_gt = None
        #         print("CANNOT COMPUTE CHAMFER LOSS")
        # if mesh_gt and params.is_real_data:
        #     mesh_pred_compute, mesh_gt_compute = scale_meshes(
        #         mesh_pred.clone(), mesh_gt.clone()
        #     )
        #     pcl_pred = sample_points_from_meshes(
        #         mesh_pred_compute, num_samples=5000
        #     )
        #     pcl_gt = sample_points_from_meshes(mesh_gt_compute, num_samples=5000)
        #     chamfer_loss = chamfer_distance(
        #         pcl_pred, pcl_gt, point_reduction="mean"
        #     )
        #     print("CHAMFER LOSS: ", chamfer_loss)
        #     experiment_results["chamfer_loss"] = (
        #         chamfer_loss[0].cpu().detach().numpy().tolist()
        #     )
        #
        # mesh_iou = neg_iou_loss_all(groundtruth_silhouettes, mesh_silhouettes)
        #
        # experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()
        #
        # results[experiment_name] = experiment_results
        #
        # manager.add_pred_mesh(mesh_pred, experiment_name)

        seg_iou = neg_iou_loss_all(groundtruth_silhouettes, input_m / 255.0)
        gt_iou = neg_iou_loss_all(groundtruth_silhouettes,
                                  groundtruth_silhouettes)

        results["mesh_iou"] = mesh_iou.detach().cpu().numpy().tolist()
        results["seg_iou"] = seg_iou.detach().cpu().numpy().tolist()
        logging.info(f"Mesh IOU list & results: {mesh_iou}")
        logging.info(f"Seg IOU list & results: {seg_iou}")
        logging.info(f"GT IOU list & results: {gt_iou} ")

        # results["mean_iou"] = IOULoss().forward(groundtruth, mesh_silhouettes).detach().cpu().numpy().tolist()
        # results["mean_dice"] = DiceCoeffLoss().forward(groundtruth, mesh_silhouettes)

        manager.set_pred_results(results)
        manager.close()
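Example #18 hops between absolute (view-to-world) poses and PyTorch3D's world-to-view (R, T) with get_world_to_view_transform(...).inverse(). When only the camera center is needed, the T = -R^T @ C relation used throughout these examples can be inverted directly; a small hedged sketch with made-up centers:

import torch
from pytorch3d.renderer import look_at_rotation

C = torch.tensor([[0.0, 0.0, 3.0],
                  [1.0, 2.0, 2.0]])                               # (n, 3) hypothetical camera centers
R = look_at_rotation(C)                                           # (n, 3, 3)
T = -torch.bmm(R.transpose(1, 2), C[:, :, None])[:, :, 0]         # (n, 3) world-to-view translation
C_recovered = -torch.bmm(R, T[:, :, None])[:, :, 0]               # equals C up to floating point error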
Example #19
def camera_calibration(flamelayer, target_silhouette, cam_pos, optimizer,
                       renderer):
    '''
    Fit FLAME to 2D landmarks
    :param flamelayer           Flame parametric model
    :param scale                Camera scale parameter (weak perspective camera)
    :param target_2d_lmks:      target 2D landmarks provided as (num_lmks x 3) matrix
    :return: The mesh vertices and the weak perspective camera parameter (scale)
    '''
    # torch_target_2d_lmks = torch.from_numpy(target_2d_lmks).cuda()
    # factor = max(max(target_2d_lmks[:,0]) - min(target_2d_lmks[:,0]),max(target_2d_lmks[:,1]) - min(target_2d_lmks[:,1]))

    # def image_fit_loss(landmarks_3D):
    #     landmarks_2D = torch_project_points_weak_perspective(landmarks_3D, scale)
    #     return flamelayer.weights['lmk']*torch.sum(torch.sub(landmarks_2D,torch_target_2d_lmks)**2) / (factor ** 2)

    # Set the cuda device
    device = torch.device("cuda:0")
    verts, _, _ = flamelayer()
    verts = verts.detach()
    # Initialize each vertex to be white in color.
    verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
    textures = Textures(verts_rgb=verts_rgb.to(device))
    faces = torch.tensor(np.int32(flamelayer.faces), dtype=torch.long).cuda()

    mesh = Meshes(verts=[verts.to(device)],
                  faces=[faces.to(device)],
                  textures=textures)
    silhouette_err = SilhouetteErr(mesh, renderer, target_silhouette)

    def fit_closure():
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        loss, sil = silhouette_err.calc(cam_pos)

        obj = loss
        # print(loss)
        # print('cam pos', cam_pos)
        if obj.requires_grad:
            obj.backward()
        return obj

    def log_obj(str):
        if FIT_2D_DEBUG_MODE:
            vertices, landmarks_3D, flame_regularizer_loss = flamelayer()
            # print (str + ' obj = ', image_fit_loss(landmarks_3D))

    def log(str):
        if FIT_2D_DEBUG_MODE:
            print(str)

    loss, sil = silhouette_err.calc(cam_pos)
    sil1 = sil[..., 3].detach().squeeze().cpu().numpy()
    sil2 = silhouette_err.image_ref.detach().cpu().numpy()
    plt.subplot(121)
    plt.imshow(sil1)
    plt.subplot(122)
    plt.imshow(sil2)
    plt.show()
    print(cam_pos, loss)
    log('Optimizing rigid transformation')
    log_obj('Before optimization obj')
    optimizer.step(fit_closure)
    log_obj('After optimization obj')
    loss, sil = silhouette_err.calc(cam_pos)

    sil1 = sil[..., 3].detach().squeeze().cpu().numpy()
    sil2 = silhouette_err.image_ref.detach().cpu().numpy()
    plt.subplot(121)
    plt.imshow(sil1)
    plt.subplot(122)
    plt.imshow(sil2)
    plt.show()
    R = look_at_rotation(cam_pos[None, :], device=device)  # (1, 3, 3)
    T = -torch.bmm(R.transpose(1, 2), cam_pos[None, :, None])[:, :,
                                                              0]  # (1, 3)
    print('R,T')
    print(R)
    print(T)
    return cam_pos
Example #20
def train(model, criterion, optimizer, train_loader, val_loader, args):
    best_prec1 = 0
    epoch_no_improve = 0

    for epoch in range(1000):

        statistics = Statistics()
        model.train()
        start_time = time.time()

        for i, (input, target) in enumerate(train_loader):
            loss, (prec1, prec5), y_pred, y_true = execute_batch(
                model, criterion, input, target, args.device)

            statistics.update(loss.detach().cpu().numpy(), prec1, prec5,
                              y_pred, y_true)
            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # if args.net_version == 2:
            #    model.camera_position = model.camera_position.clamp(0, 1)
            del loss
            torch.cuda.empty_cache()

        elapsed_time = time.time() - start_time

        # Evaluate on validation set
        val_statistics = validate(val_loader, model, criterion, args.device)

        log_data(statistics, "train", val_loader.dataset.dataset.classes,
                 epoch)
        log_data(val_statistics, "internal_val",
                 val_loader.dataset.dataset.classes, epoch)

        wandb.log({"Epoch elapsed time": elapsed_time}, step=epoch)
        # print(model.camera_position)
        if epoch % 1 == 0:
            vertices = []
            if args.net_version == 1:
                R = look_at_rotation(model.camera_position, device=args.device)
                T = -torch.bmm(R.transpose(1, 2),
                               model.camera_position[:, :, None])[:, :, 0]
            else:
                t = Transform3d(device=model.device).scale(
                    model.camera_position[3] *
                    model.distance_range).rotate_axis_angle(
                        model.camera_position[0] * model.angle_range,
                        axis="X",
                        degrees=False).rotate_axis_angle(
                            model.camera_position[1] * model.angle_range,
                            axis="Y",
                            degrees=False).rotate_axis_angle(
                                model.camera_position[2] * model.angle_range,
                                axis="Z",
                                degrees=False)

                vertices = t.transform_points(model.vertices)

                R = look_at_rotation(vertices[:model.nviews],
                                     device=model.device)
                T = -torch.bmm(R.transpose(1, 2), vertices[:model.nviews, :,
                                                           None])[:, :, 0]

            cameras = OpenGLPerspectiveCameras(R=R, T=T, device=args.device)
            wandb.log(
                {
                    "Cameras":
                    [wandb.Image(plot_camera_scene(cameras, args.device))]
                },
                step=epoch)
            plt.close()
            images = render_shape(model, R, T, args, vertices)
            wandb.log(
                {
                    "Views": [
                        wandb.Image(
                            image_grid(images,
                                       rows=int(np.ceil(args.nviews / 2)),
                                       cols=2))
                    ]
                },
                step=epoch)
            plt.close()
        #  Save best model and best prediction
        if val_statistics.top1.avg > best_prec1:
            best_prec1 = val_statistics.top1.avg
            save_model("views_net", model, optimizer, args.fname_best)
            epoch_no_improve = 0
        else:
            # Early stopping
            epoch_no_improve += 1
            if epoch_no_improve == 20:
                wandb.run.summary[
                    "best_internal_val_top1_accuracy"] = best_prec1
                wandb.run.summary[
                    "best_internal_val_top1_accuracy_epoch"] = epoch - 20

                return