def make_morph_image(self, src_img, src_info, erode_ks=3, dilate_ks=11):
        """

        Args:
            src_img (torch.Tensor): (bs * ns, 3, h, w)
            src_info (dict):
            erode_ks (int):
            dilate_ks (int):

        Returns:
            all_morph_imgs (torch.cuda.Tensor): (bs * ns, 3, h, w)
        """

        bs = src_img.shape[0]

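        # erode the confident silhouette and dilate the out-padded one so that
        # the two masks bracket the true object boundary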
        if erode_ks > 0:
            confidant_sil = morph(src_info["confidant_sil"],
                                  ks=erode_ks,
                                  mode="erode")
        else:
            confidant_sil = src_info["confidant_sil"]

        if dilate_ks > 0:
            outpad_sil = morph(src_info["outpad_sil"],
                               ks=dilate_ks,
                               mode="dilate")
        else:
            outpad_sil = src_info["outpad_sil"]

        # outpad_sil = ((confidant_sil + (1 - src_info["cond"][:, -1:])) > 0).float()
        # outpad_sil = morph(outpad_sil, ks=dilate_ks, mode="dilate")

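        # run a Canny filter on the confident silhouette; only the thinned
        # edge map (thin_edges) is used below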
        canny_filter = CannyFilter(device=src_img.device).to(src_img.device)

        blurred, grad_x, grad_y, grad_magnitude, grad_orientation, thin_edges = canny_filter(
            confidant_sil, 0.1, 0.9, True)

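        # the uncertain band: inside the dilated out-padded silhouette but
        # outside the eroded confident silhouette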
        uncertain_sil = outpad_sil * (1 - confidant_sil)

        # visualizer.vis_named_img("silhouette", 1 - src_info["masks"])
        # visualizer.vis_named_img("confidant_sil", confidant_sil)
        # visualizer.vis_named_img("outpad_sil", outpad_sil)
        # visualizer.vis_named_img("uncertain_sil", uncertain_sil)

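        # for each sample, fill every uncertain pixel from its nearest
        # (top-k weighted) boundary points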
        all_morph_imgs = []
        for i in range(bs):
            boundaries_points = thin_edges[i, 0].nonzero(as_tuple=False)
            uncertain_points = uncertain_sil[i, 0].nonzero(as_tuple=False)

            weights, nn_pts, ids = self.cal_top_k_ids(uncertain_points,
                                                      boundaries_points)

            morph_img = self.morph_image(src_img[i], weights, uncertain_points,
                                         nn_pts, confidant_sil[i])
            all_morph_imgs.append(morph_img)

        all_morph_imgs = torch.stack(all_morph_imgs, dim=0)
        return all_morph_imgs

    def forward(self, src_img, ref_img, src_smpl, ref_smpl, src_mask=None, ref_mask=None,
                links_ids=None, offsets=0, temporal=False):
        """
        Args:
            src_img (torch.tensor) : (bs, ns, 3, H, W);
            ref_img (torch.tensor) : (bs, nt, 3, H, W);
            src_smpl (torch.tensor): (bs, ns, 85);
            ref_smpl (torch.tensor): (bs, nt, 85);
            src_mask (torch.tensor): (bs, ns, 3, H, W) or None, foreground is 0, background is 1;
            ref_mask (torch.tensor): (bs, nt, 3, H, W) or None, foreground is 0, background is 1;
            links_ids (torch.tensor): (bs, ns + nt, number of verts, 2);
            offsets (torch.tensor) : (bs, nv, 3) or 0;
            temporal (bool): if True, the temporal warping flow is also calculated; otherwise Ttt is None.

        Returns:
            input_G_bg  (torch.tensor) :  (bs, ns, 4, H, W)
            input_G_src (torch.tensor) :  (bs, ns, 6, H, W)
            input_G_tsf (torch.tensor) :  (bs, nt, 3, H, W)
            Tst         (torch.tensor) :  (bs, nt, ns, H, W, 2)
            Ttt         (torch.tensor) :  (bs, nt - 1, H, W, 2) if temporal is True, otherwise None
            src_mask    (torch.tensor) :  (bs, ns, 1, H, W)
            tsf_mask    (torch.tensor) :  (bs, nt, 1, H, W)
            head_bbox   (torch.tensor) :  the head bounding boxes estimated from the reference keypoints
            body_bbox   (torch.tensor) :  the body bounding boxes estimated from the reference keypoints
            uv_img      (torch.tensor) :  (bs, 3, H, W)
        """
        bs, ns, _, h, w = src_img.shape
        bs, nt = ref_img.shape[0:2]

        input_G_bg, input_G_src, input_G_tsf, Tst, Ttt, uv_img, src_info, ref_info = super().forward(
            src_img, ref_img, src_smpl, ref_smpl, src_mask, ref_mask,
            links_ids=links_ids, offsets=offsets, temporal=temporal
        )

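        # fall back to the background channel of the rendered SMPL condition
        # map when no observed mask is given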
        if src_mask is None:
            src_mask = src_info["cond"][:, -1:]
        else:
            src_mask = src_info["masks"]

        if ref_mask is None:
            tsf_mask = ref_info["cond"][:, -1:]
        else:
            tsf_mask = ref_info["masks"]

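        # erode the background masks so that boundary pixels are not treated as background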
        src_mask = morph(src_mask, ks=self._opt.ft_ks, mode="erode")
        tsf_mask = morph(tsf_mask, ks=self._opt.ft_ks, mode="erode")

        src_mask = src_mask.view(bs, ns, 1, h, w)
        tsf_mask = tsf_mask.view(bs, nt, 1, h, w)

        head_bbox = self.cal_head_bbox_by_kps(ref_info["j2d"])
        body_bbox = self.cal_body_bbox_by_kps(ref_info["j2d"])

        return input_G_bg, input_G_src, input_G_tsf, Tst, Ttt, src_mask, tsf_mask, head_bbox, body_bbox, uv_img

    def make_morph_f2pts(self, f2pts, fim, human_sil, erode_ks=0):
        """Morph the projected face points (f2pts) toward the human silhouette.

        Args:
            f2pts (torch.Tensor): (bs, nf, 3, 2)
            fim (torch.Tensor): (bs, h, w)
            human_sil (torch.Tensor): (bs, 1, h, w)
            erode_ks (int): erosion kernel size for the human silhouette; skipped if <= 0.

        Returns:
            morphed_f2pts (torch.Tensor): (bs, nf, 3, 2)
        """

        bs = f2pts.shape[0]

        # (bs, 1, h, w)
        if erode_ks > 0:
            human_sil = morph(human_sil, ks=erode_ks, mode="erode")

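        # per-face barycenters of the projected triangles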
        fbc = self.render.compute_barycenter(f2pts)
        morphed_f2pts = []
        for i in range(bs):
            m_f2pts = self.morph_f2pts(f2pts[i], fbc[i], fim[i],
                                       human_sil[i][0])
            morphed_f2pts.append(m_f2pts)

        morphed_f2pts = torch.stack(morphed_f2pts, dim=0)

        return morphed_f2pts

    def make_uv_img(self, src_img, src_info):
        """Warp the source images into UV space and merge them into a single UV image.

        Args:
            src_img (torch.tensor): (bs, ns, 3, h, w)
            src_info (dict): the source information; must contain "obj_f2pts"
                and "only_vis_obj_f2pts".

        Returns:
            merge_uv (torch.tensor): (bs, 3, h, w)
        """

        bs, ns, _, h, w = src_img.shape
        bsxns = bs * ns

        ## previous
        # only_vis_src_f2pts = src_info["only_vis_f2pts"]
        # Ts2uv = self.render.cal_bc_transform(only_vis_src_f2pts, self.uv_fim[0:bsxns], self.uv_wim[0:bsxns])
        # src_warp_to_uv = F.grid_sample(src_img.view(bs * ns, 3, h, w), Ts2uv)
        # vis_warp_to_uv = F.grid_sample(self.one_map, Ts2uv)
        # merge_uv = torch.sum(src_warp_to_uv.view(bs, ns, -1, h, w), dim=1) / (
        #     torch.sum(vis_warp_to_uv.view(bs, ns, -1, h, w), dim=1) + 1e-5)

        ## current
        uv_fim = self.uv_fim[0:bsxns]
        uv_wim = self.uv_wim[0:bsxns]
        one_map = self.one_map[0:bsxns]
        only_vis_src_f2pts = src_info["only_vis_obj_f2pts"]
        src_f2pts = src_info["obj_f2pts"]
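        # Ts2uv warps all faces into UV space, while only_vis_Ts2uv warps only
        # the visible faces and is used to build per-source visibility maps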
        only_vis_Ts2uv = self.render.cal_bc_transform(only_vis_src_f2pts,
                                                      uv_fim, uv_wim)
        Ts2uv = self.render.cal_bc_transform(src_f2pts, uv_fim, uv_wim)

        src_warp_to_uv = F.grid_sample(src_img.view(bs * ns, 3, h, w),
                                       Ts2uv).view(bs, ns, -1, h, w)
        vis_warp_to_uv = F.grid_sample(one_map, only_vis_Ts2uv)

        # TODO: here ks=13 is a hyper-parameter; it might need to be moved into the configuration.
        vis_warp_to_uv = morph(vis_warp_to_uv, ks=13,
                               mode="dilate").view(bs, ns, -1, h, w)

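        # average the warped non-front sources over their visible UV texels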
        vis_sum = torch.sum(vis_warp_to_uv[:, 1:], dim=1)
        temp = torch.sum(src_warp_to_uv[:, 1:] * vis_warp_to_uv[:, 1:],
                         dim=1) / (vis_sum + 1e-5)

        vis_front = vis_warp_to_uv[:, 0]
        vis_other = (vis_sum >= 1).float()

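        # texels the front source cannot see but some other source can are
        # filled from the averaged non-front sources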
        front_invisible = (1 - vis_front) * vis_other
        merge_uv = src_warp_to_uv[:, 0] * (
            1 - front_invisible) + temp * front_invisible

        # merge_uv = src_warp_to_uv[:, 0]
        # noisy = torch.randn((bs, 3, h, w), dtype=torch.float32).to(src_img.device)
        # merge_uv = 0.5 * merge_uv + 0.5 * noisy
        # merge_uv = torch.clamp(merge_uv, min=-1.0, max=1.0)

        return merge_uv

    def make_bg_inputs(self, src_img, src_info):
        # bg input
        src_cond = src_info["cond"]
        if "masks" in src_info:
            bg_mask = src_info["masks"]
        else:
            bg_mask = src_cond[:, -1:, :, :]

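        # erode the background mask and keep only the background pixels of the source image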
        src_bg_mask = morph(bg_mask, ks=self._opt.bg_ks, mode="erode")
        input_G_bg = torch.cat([src_img * src_bg_mask, src_bg_mask], dim=1)

        return input_G_bg

    def solve(self, obs, visualizer=None, visual_poses=None):
        """
        Args:
            obs (dict): the observations, containing:
                --sil: the observed silhouettes;
                --cam: the camera parameters;
                --pose: the SMPL pose parameters;
                --shape: the SMPL shape parameters;
            visualizer: the visualizer, or None;
            visual_poses: the poses used for intermediate visualization, or None.

        Returns:
            offsets (nn.Parameter): (nv, 3), the optimized per-vertex offsets.
        """

        print("{} use the parse observations to tune the offsets...".format(
            self.__class__.__name__))

        with torch.no_grad():
            obs_sil = torch.tensor(obs["sil"]).float().to(self.device)
            obs_cam = torch.tensor(obs["cam"]).float().to(self.device)
            obs_pose = torch.tensor(obs["pose"]).float().to(self.device)
            obs_shape = torch.tensor(obs["shape"]).float().to(self.device)
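            # dilate then erode to fill small holes in the observed silhouettes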
            obs_sil = morph(obs_sil, ks=3, mode="dilate")
            obs_sil = morph(obs_sil, ks=5, mode="erode")
            obs_sil.squeeze_(dim=1)

            bs = obs_cam.shape[0]
            init_verts, _, _ = self.smpl(obs_shape,
                                         obs_pose,
                                         offsets=0,
                                         get_skin=True)
            faces = self.render.smpl_faces.repeat(bs, 1, 1)
            nv = init_verts.shape[1]

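        # per-vertex offsets are the only parameters being optimized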
        offsets = nn.Parameter(torch.zeros((nv, 3)).to(self.device))

        # total_steps = 500
        # init_lr = 0.0002
        # alpha_reg = 1000

        total_steps = 500
        init_lr = 0.0001
        alpha_reg = 10000

        optimizer = torch.optim.Adam([offsets], lr=init_lr)
        crt_sil = nn.MSELoss()

        if visualizer is not None:
            textures = self.render.color_textures().repeat(bs, 1, 1, 1, 1, 1)
            textures = textures.to(self.device)

            visual_poses = torch.tensor(visual_poses).float().to(self.device)
            num_visuals = visual_poses.shape[0]

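        # minimize the silhouette reprojection error plus an L2 penalty that
        # keeps the offsets small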
        for i in tqdm(range(total_steps)):
            verts, joints, Rs = self.smpl(obs_shape.detach(),
                                          obs_pose.detach(),
                                          offsets=offsets,
                                          get_skin=True)
            rd_sil = self.render.render_silhouettes(obs_cam.detach(),
                                                    verts,
                                                    faces=faces.detach())
            loss = crt_sil(rd_sil,
                           obs_sil) + alpha_reg * torch.mean(offsets**2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if visualizer is not None and i % 10 == 0:
                with torch.no_grad():
                    ids = np.random.choice(num_visuals, bs)
                    rand_pose = visual_poses[ids]
                    verts, joints, Rs = self.smpl(obs_shape,
                                                  rand_pose,
                                                  offsets=offsets,
                                                  get_skin=True)
                    rd, _ = self.visual_render.render(obs_cam,
                                                      verts,
                                                      textures,
                                                      faces=faces,
                                                      get_fim=False)
                    visualizer.vis_named_img("rd_sil", rd_sil)
                    visualizer.vis_named_img("obs_sil", obs_sil)
                    visualizer.vis_named_img("render", rd)

                print("step = {}, loss = {:.6f}".format(i, loss.item()))

        return offsets

    def add_rendered_f2verts_fim_wim(self,
                                     smpl_info,
                                     use_morph=False,
                                     get_uv_info=True):
        """
        Args:
            smpl_info (dict): the smpl information, containing:
                --cam (torch.Tensor): (bs * nt, 3);
                --verts (torch.Tensor): (bs * nt, 6890, 3);

            use_morph (bool): use the morphing strategy to adjust the f2pts to the
                segmentation observation; it might be used to process the source information.

            get_uv_info (bool): get the information for UV; it might be used to process
                the source information.

        Returns:
            smpl_info (dict): the input dict, updated with:
                --cond  (torch.Tensor): (bs * nt, 3, h, w),
                --fim   (torch.Tensor): (bs * nt, h, w),
                --wim   (torch.Tensor): (bs * nt, h, w, 3),
                --f2pts (torch.Tensor): (bs * nt, 13776, 3, 2),
                --only_vis_f2pts (torch.Tensor): (bs * nt, 13776, 3, 2),
                --confidant_sil (torch.Tensor): (bs * nt, 1, h, w), only if use_morph is True,
                --outpad_sil    (torch.Tensor): (bs * nt, 1, h, w), only if use_morph is True,
                --obj_fim   (torch.Tensor): (bs * nt, h, w), only if get_uv_info is True,
                --obj_wim   (torch.Tensor): (bs * nt, h, w, 3), only if get_uv_info is True,
                --obj_f2pts (torch.Tensor): (bs * nt, 13776, 3, 2), only if get_uv_info is True,
                --only_vis_obj_f2pts (torch.Tensor): (bs * nt, 13776, 3, 2), only if get_uv_info is True
        """

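        # rasterize the SMPL mesh: f2pts holds the projected face vertices,
        # fim the face-index map, and wim the barycentric-weight map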
        f2pts, fim, wim = self.render.render_fim_wim(
            cam=smpl_info["cam"], vertices=smpl_info["verts"], smpl_faces=True)
        cond, _ = self.render.encode_fim(smpl_info["cam"],
                                         smpl_info["verts"],
                                         fim=fim,
                                         transpose=True)

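        # build the eroded "confidant_sil" and dilated "outpad_sil" masks used
        # by the morphing strategy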
        if use_morph:
            if "masks" in smpl_info:
                human_sil = 1 - smpl_info["masks"]
            else:
                human_sil = 1 - cond[:, -1:]

            # TODO: here ks=3 is a hyper-parameter; it might need to be moved into the configuration.
            smpl_info["confidant_sil"] = morph(human_sil, ks=3, mode="erode")

            # TODO: here ks=51 is a hyper-parameter; it might need to be moved into the configuration.
            smpl_info["outpad_sil"] = morph(
                ((human_sil + 1 - cond[:, -1:]) > 0).float(),
                ks=51,
                mode="dilate")
            # f2pts = self.make_morph_f2pts(f2pts, fim, smpl_info["human_sil"], erode_ks=0)

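        # keep only the faces visible in the rendered view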
        only_vis_f2pts = self.render.get_vis_f2pts(f2pts, fim)
        smpl_info["f2pts"] = f2pts
        smpl_info["only_vis_f2pts"] = only_vis_f2pts
        smpl_info["cond"] = cond
        smpl_info["fim"] = fim
        smpl_info["wim"] = wim

        if get_uv_info:
            obj_f2pts, obj_fim, obj_wim = self.render.render_fim_wim(
                cam=smpl_info["cam"],
                vertices=smpl_info["verts"],
                smpl_faces=False)

            # if use_morph:
            #     obj_f2pts = self.make_morph_f2pts(obj_f2pts, obj_fim, smpl_info["human_sil"], erode_ks=0)

            only_vis_obj_f2pts = self.render.get_vis_f2pts(obj_f2pts, obj_fim)
            smpl_info["obj_f2pts"] = obj_f2pts
            smpl_info["only_vis_obj_f2pts"] = only_vis_obj_f2pts
            smpl_info["obj_fim"] = obj_fim
            smpl_info["obj_wim"] = obj_wim

        return smpl_info