Exemplo n.º 1
0
    def keypoint_transfer_tex(self, kp_gt_source, kp_gt_tgt, tex_flow_source, tex_flow_tgt):
        """Transfer keypoints from a source to a target view via texture flow.

        Each ground-truth source keypoint is rendered as a heat map; the
        source texture flow samples that heat map per face, and the face with
        the strongest mean response is assigned to the keypoint. The target
        flow of that face is then averaged over its texels to get the
        predicted target location, which is scored with PCK.

        Returns the (kps_err12, kps_vis12) pair from ``self.compute_pck``.
        """
        # One heat-map channel per ground-truth source keypoint.
        source_heatmaps = make_heat_map(kp_gt_source)

        batch, n_kp = source_heatmaps.shape[:2]
        n_faces, tex_res = tex_flow_source.shape[1:3]

        # Sample the heat maps through the predicted source flow, then average
        # the per-texel responses so each face gets one score per keypoint.
        face_scores = geom_utils.sample_textures(tex_flow_source, source_heatmaps)
        face_scores = face_scores.view(batch, n_faces, tex_res, tex_res, n_kp)
        face_scores = face_scores.mean([-2, -3])  # B x F x numkp

        # For every keypoint, pick the face where its response peaks.
        kp_face = face_scores.argmax(1).long()  # B x numkp

        # Gather the target flow of each keypoint's face and average its texel
        # coordinates to obtain the predicted target keypoint location.
        batch_idx = torch.arange(batch).unsqueeze(-1).repeat(1, n_kp).view(-1).long()
        kp_pred_tgt = tex_flow_tgt[batch_idx, kp_face.view(-1)]
        kp_pred_tgt = kp_pred_tgt.view(batch, n_kp, tex_res, tex_res, 2).mean([-2, -3])  # B x numkp x 2

        kp_pred_tgt = kp_pred_tgt.detach().cpu().type_as(kp_gt_tgt).numpy()
        kp_gt_tgt = kp_gt_tgt.cpu().numpy()
        kps_err12, kps_vis12 = self.compute_pck(kp_pred_tgt, kp_gt_tgt)

        return kps_err12, kps_vis12
Exemplo n.º 2
0
Arquivo: main.py Projeto: neka-nat/cmr
    def forward(self):
        """Run one training forward pass.

        Predicts shape deformation, camera, and (optionally) texture from
        ``self.input_imgs``; renders keypoints, silhouette mask, and texture;
        then accumulates every weighted loss term into ``self.total_loss``
        and logs scalar metrics. All intermediates are stored on ``self``.
        """
        opts = self.opts
        if opts.texture:
            pred_codes, self.textures = self.model(self.input_imgs)
        else:
            pred_codes = self.model(self.input_imgs)
        self.delta_v, scale, trans, quat = pred_codes

        self.cam_pred = torch.cat([scale, trans, quat], 1)

        if opts.only_mean_sym:
            del_v = self.delta_v
        else:
            del_v = self.model.symmetrize(self.delta_v)

        # Deform mean shape:
        self.mean_shape = self.model.get_mean_shape()
        self.pred_v = self.mean_shape + del_v

        # Compute keypoints via a soft vertex-to-keypoint assignment.
        self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp, dim=1)
        self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Decide which camera to use for projection.
        if opts.use_gtpose:
            proj_cam = self.cams
        else:
            proj_cam = self.cam_pred

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts, proj_cam)

        # Render mask.
        self.mask_pred = self.renderer(self.pred_v, self.faces, proj_cam)
        if opts.renderer == "smr":
            # smr renderer returns a tuple; channel 0 holds the silhouette.
            self.mask_pred = self.mask_pred[0][:, 0, :, :]

        if opts.texture:
            self.texture_flow = self.textures
            self.textures = geom_utils.sample_textures(self.texture_flow,
                                                       self.imgs)
            if opts.renderer == "smr":
                self.textures = self.textures.contiguous()
                bs, fs, ts, _, _ = self.textures.size()
                self.textures = self.textures.view(bs, fs, -1, 3)
                texture_rgba, p2f_info, _ = self.tex_renderer.forward(
                    self.pred_v.detach(), self.faces, proj_cam.detach(),
                    self.textures)
                self.texture_pred = texture_rgba[:, 0:3, :, :]
            else:
                tex_size = self.textures.size(2)
                self.textures = self.textures.unsqueeze(4).repeat(
                    1, 1, 1, 1, tex_size, 1)
                self.texture_pred = self.tex_renderer(self.pred_v.detach(),
                                                      self.faces,
                                                      proj_cam.detach(),
                                                      textures=self.textures)
        else:
            self.textures = None

        # Compute losses for this instance.
        self.kp_loss = self.projection_loss(self.kp_pred, self.kps)
        self.mask_loss = self.mask_loss_fn(self.mask_pred, self.masks)
        self.cam_loss = self.camera_loss(self.cam_pred, self.cams, 0)

        if opts.texture:
            self.tex_loss = self.texture_loss(self.texture_pred, self.imgs,
                                              self.mask_pred, self.masks)
            self.tex_dt_loss = self.texture_dt_loss_fn(self.texture_flow,
                                                       self.dts_barrier)

        # Priors:
        self.vert2kp_loss = self.entropy_loss(self.vert2kp)
        self.deform_reg = self.deform_reg_fn(self.delta_v)
        self.triangle_loss = self.triangle_loss_fn(self.pred_v)
        if opts.add_smr_loss:
            self.flatten_loss = self.flatten_loss_fn(self.pred_v).mean()
            self.ori_loss = self.ori_reg_fn(self.pred_v)

        # Finally sum up the loss.
        # Instance loss:
        self.total_loss = opts.kp_loss_wt * self.kp_loss
        self.total_loss += opts.mask_loss_wt * self.mask_loss
        self.total_loss += opts.cam_loss_wt * self.cam_loss
        if opts.texture:
            self.total_loss += opts.tex_loss_wt * self.tex_loss
            # BUG FIX: tex_dt_loss is only computed when opts.texture is set;
            # adding it unconditionally raised AttributeError otherwise.
            self.total_loss += opts.tex_dt_loss_wt * self.tex_dt_loss

        # Priors:
        self.total_loss += opts.vert2kp_loss_wt * self.vert2kp_loss
        self.total_loss += opts.deform_reg_wt * self.deform_reg
        self.total_loss += opts.triangle_reg_wt * self.triangle_loss

        if opts.add_smr_loss:
            self.total_loss += self.flatten_loss * opts.flatten_reg_wt
            if self.curr_epoch < opts.stop_ori_epoch:
                # constrain prediction to be symmetric on the given axis
                self.total_loss += self.ori_loss * opts.ori_reg_wt

        # .item() is the idiomatic scalar extraction; for a 0-dim tensor it is
        # equivalent to float(t.cpu().detach().numpy()).
        log_metric("kp_loss", self.kp_loss.item())
        log_metric("mask_loss", self.mask_loss.item())
        log_metric("vert2kp_loss", self.vert2kp_loss.item())
        log_metric("deform_reg", self.deform_reg.item())
        log_metric("triangle_loss", self.triangle_loss.item())
        log_metric("cam_loss", self.cam_loss.item())
        if opts.texture:
            # BUG FIX: texture metrics only exist when opts.texture is set;
            # logging them unconditionally raised AttributeError.
            log_metric("tex_loss", self.tex_loss.item())
            log_metric("tex_dt_loss", self.tex_dt_loss.item())
        log_metric("total_loss", self.total_loss.item())
Exemplo n.º 3
0
    def forward(self):
        """Run one training forward pass (non-smr variant).

        Predicts shape deformation, camera, and (optionally) texture from
        ``self.input_imgs``; renders keypoints, silhouette mask, and texture;
        then accumulates the weighted loss terms into ``self.total_loss``.
        All intermediates are stored on ``self``.
        """
        opts = self.opts
        if opts.texture:
            pred_codes, self.textures = self.model(self.input_imgs)
        else:
            pred_codes = self.model(self.input_imgs)
        self.delta_v, scale, trans, quat = pred_codes

        self.cam_pred = torch.cat([scale, trans, quat], 1)

        if opts.only_mean_sym:
            del_v = self.delta_v
        else:
            del_v = self.model.symmetrize(self.delta_v)

        # Deform mean shape:
        self.mean_shape = self.model.get_mean_shape()
        self.pred_v = self.mean_shape + del_v

        # Compute keypoints via a soft vertex-to-keypoint assignment.
        self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp, dim=1)
        self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Decide which camera to use for projection.
        if opts.use_gtpose:
            proj_cam = self.cams
        else:
            proj_cam = self.cam_pred

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts, proj_cam)

        # Render mask.
        self.mask_pred = self.renderer(self.pred_v, self.faces, proj_cam)

        if opts.texture:
            self.texture_flow = self.textures
            self.textures = geom_utils.sample_textures(self.texture_flow,
                                                       self.imgs)
            tex_size = self.textures.size(2)
            # Expand per-face texels to the layout the renderer expects.
            self.textures = self.textures.unsqueeze(4).repeat(
                1, 1, 1, 1, tex_size, 1)

            self.texture_pred = self.tex_renderer(self.pred_v.detach(),
                                                  self.faces,
                                                  proj_cam.detach(),
                                                  textures=self.textures)
        else:
            self.textures = None

        # Compute losses for this instance.
        self.kp_loss = self.projection_loss(self.kp_pred, self.kps)
        self.mask_loss = self.mask_loss_fn(self.mask_pred, self.masks)
        self.cam_loss = self.camera_loss(self.cam_pred, self.cams, 0)

        if opts.texture:
            self.tex_loss = self.texture_loss(self.texture_pred, self.imgs,
                                              self.mask_pred, self.masks)
            self.tex_dt_loss = self.texture_dt_loss_fn(self.texture_flow,
                                                       self.dts_barrier)

        # Priors:
        self.vert2kp_loss = self.entropy_loss(self.vert2kp)
        self.deform_reg = self.deform_reg_fn(self.delta_v)
        self.triangle_loss = self.triangle_loss_fn(self.pred_v)

        # Finally sum up the loss.
        # Instance loss:
        self.total_loss = opts.kp_loss_wt * self.kp_loss
        self.total_loss += opts.mask_loss_wt * self.mask_loss
        self.total_loss += opts.cam_loss_wt * self.cam_loss
        if opts.texture:
            self.total_loss += opts.tex_loss_wt * self.tex_loss
            # BUG FIX: tex_dt_loss is only computed when opts.texture is set;
            # adding it unconditionally raised AttributeError otherwise.
            self.total_loss += opts.tex_dt_loss_wt * self.tex_dt_loss

        # Priors:
        self.total_loss += opts.vert2kp_loss_wt * self.vert2kp_loss
        self.total_loss += opts.deform_reg_wt * self.deform_reg
        self.total_loss += opts.triangle_reg_wt * self.triangle_loss
Exemplo n.º 4
0
    def forward(self):
        """Predict shape, camera, and texture for the input batch and render
        keypoint projections, silhouette mask, and (optionally) texture.

        Results are stored on ``self`` (pred_v, kp_pred, mask_pred,
        texture_pred, uv_images, ...) rather than returned.
        """
        if self.opts.texture:
            # Texture head also returns the predicted textures; flow textures
            # are B x F x T x T x 2 (see the size(-1) == 2 check below).
            pred_codes, self.textures = self.model.forward(self.input_imgs)
        else:
            pred_codes = self.model.forward(self.input_imgs)

        self.delta_v, scale, trans, quat = pred_codes

        # Either trust the SfM camera or assemble the predicted one.
        if self.opts.use_sfm_camera:
            self.cam_pred = self.sfm_cams
        else:
            self.cam_pred = torch.cat([scale, trans, quat], 1)

        del_v = self.model.symmetrize(self.delta_v)
        # Deform mean shape:
        self.mean_shape = self.model.get_mean_shape()

        if self.opts.use_sfm_ms:
            # Use the SfM mean shape directly, skipping deformation entirely.
            self.pred_v = self.sfm_mean_shape
        elif self.opts.ignore_pred_delta_v:
            # Keep del_v in the graph but zero out its contribution.
            self.pred_v = self.mean_shape + del_v * 0
        else:
            self.pred_v = self.mean_shape + del_v

        # Compute keypoints.
        if self.opts.use_sfm_ms:
            self.kp_verts = self.pred_v
        else:
            # Soft assignment of vertices to keypoints.
            self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp,
                                                       dim=1)
            self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts,
                                                    self.cam_pred)
        self.mask_pred = self.renderer.forward(self.pred_v, self.faces,
                                               self.cam_pred)
        # Render texture.
        if self.opts.texture and not self.opts.use_sfm_ms:
            if self.textures.size(-1) == 2:
                # Flow texture! Convert the UV flow to RGB texels by sampling
                # the input images.
                self.texture_flow = self.textures
                self.textures = geom_utils.sample_textures(
                    self.textures, self.imgs)
            if self.textures.dim() == 5:  # B x F x T x T x 3
                tex_size = self.textures.size(2)
                # Expand per-face texels to the layout the renderer expects.
                self.textures = self.textures.unsqueeze(4).repeat(
                    1, 1, 1, 1, tex_size, 1)
            # Render texture:
            # texture_pred is rendered img
            self.texture_pred = self.tex_renderer.forward(
                self.pred_v, self.faces, self.cam_pred, textures=self.textures)

            # B x 2 x H x W
            uv_flows = self.model.texture_predictor.uvimage_pred
            # B x H x W x 2
            self.uv_flows = uv_flows.permute(0, 2, 3, 1)
            # Warp the input images into UV space (visualization aid).
            self.uv_images = torch.nn.functional.grid_sample(
                self.imgs, self.uv_flows, align_corners=True)
        else:
            self.textures = None
Exemplo n.º 5
0
    def forward(self):
        """Predict shape, camera, and texture; render keypoints, mask, and
        texture; and dump debug artifacts (mean-shape OFF mesh, UV-image PNG).

        Results are stored on ``self`` rather than returned.
        """
        if self.opts.texture:
            # Texture head also returns the texture flow (B x F x T x T x 2).
            pred_codes, self.textures = self.model.forward(self.input_imgs)
        else:
            pred_codes = self.model.forward(self.input_imgs)

        self.delta_v, scale, trans, quat = pred_codes

        # Either trust the SfM camera or assemble the predicted one.
        if self.opts.use_sfm_camera:
            self.cam_pred = self.sfm_cams
        else:
            self.cam_pred = torch.cat([scale, trans, quat], 1)

        del_v = self.model.symmetrize(self.delta_v)
        # Deform mean shape:
        self.mean_shape = self.model.get_mean_shape()

        # Debug aid: export the learned mean shape for external mesh viewers.
        self._export_mean_shape_off("bird_mean_mesh.off")

        if self.opts.use_sfm_ms:
            self.pred_v = self.sfm_mean_shape
        elif self.opts.ignore_pred_delta_v:
            # Keep del_v in the graph but zero out its contribution.
            self.pred_v = self.mean_shape + del_v * 0
        else:
            self.pred_v = self.mean_shape + del_v

        # Compute keypoints.
        if self.opts.use_sfm_ms:
            self.kp_verts = self.pred_v
        else:
            # Soft assignment of vertices to keypoints.
            self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp,
                                                       dim=1)
            self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts,
                                                    self.cam_pred)
        self.mask_pred = self.renderer.forward(self.pred_v, self.faces,
                                               self.cam_pred)

        # Render texture.
        if self.opts.texture and not self.opts.use_sfm_ms:
            if self.textures.size(-1) == 2:
                # Flow texture! Convert the UV flow to RGB texels by sampling
                # the input images.
                self.texture_flow = self.textures
                self.textures = geom_utils.sample_textures(
                    self.textures, self.imgs)
            if self.textures.dim() == 5:  # B x F x T x T x 3
                tex_size = self.textures.size(2)
                # Expand per-face texels to the layout the renderer expects.
                self.textures = self.textures.unsqueeze(4).repeat(
                    1, 1, 1, 1, tex_size, 1)

            # Render texture:
            self.texture_pred = self.tex_renderer.forward(
                self.pred_v, self.faces, self.cam_pred, textures=self.textures)

            # B x 2 x H x W
            uv_flows = self.model.texture_predictor.uvimage_pred
            # B x H x W x 2
            self.uv_flows = uv_flows.permute(0, 2, 3, 1)
            self.uv_images = torch.nn.functional.grid_sample(
                self.imgs, self.uv_flows, align_corners=True)

            # Debug aid: show and save the first UV image of the batch.
            self._save_uv_image_plot(self.uv_images[0], "uv_image_test.png")
        else:
            self.textures = None

    def _export_mean_shape_off(self, path):
        """Write ``self.mean_shape`` vertices and ``self.faces[0]`` triangles
        to an OFF mesh file at *path* (debug aid)."""
        # "with" guarantees the file is closed even if a write raises
        # (the original code left the handle open on error).
        with open(path, "w") as f:
            f.write("OFF\n")
            f.write(str(len(self.mean_shape)) + " " +
                    str(len(self.faces[0])) + " 0\n")
            for vert in self.mean_shape:
                f.write(str(float(vert[0])) + " " + str(float(vert[1])) +
                        " " + str(float(vert[2])) + "\n")
            for face in self.faces[0]:
                f.write("3 " + str(int(face[0])) + " " + str(int(face[1])) +
                        " " + str(int(face[2])) + "\n")

    def _save_uv_image_plot(self, uv_image, path):
        """Save (and display) a C x H x W image tensor as an H x W x C plot."""
        import matplotlib.pyplot as plt

        # Vectorized CHW -> HWC conversion; replaces the original O(C*H*W)
        # Python triple loop and no longer hard-codes a 128 x 256 x 3 shape.
        uv_image_array = uv_image.detach().cpu().permute(1, 2, 0).numpy()
        plt.imshow(uv_image_array)
        plt.draw()
        # BUG FIX: savefig must come before show(); with most backends show()
        # consumes/clears the current figure, so the saved PNG was blank.
        plt.savefig(path)
        plt.show()