def make_morph_image(self, src_img, src_info, erode_ks=3, dilate_ks=11):
    """
    Args:
        src_img (torch.Tensor): (bs * ns, 3, h, w)
        src_info (dict): the source information, it must contain "confidant_sil" and "outpad_sil";
        erode_ks (int): erosion kernel size applied to the confident silhouette;
        dilate_ks (int): dilation kernel size applied to the out-padded silhouette.

    Returns:
        all_morph_imgs (torch.Tensor): (bs * ns, 3, h, w)
    """
    bs = src_img.shape[0]

    if erode_ks > 0:
        confidant_sil = morph(src_info["confidant_sil"], ks=erode_ks, mode="erode")
    else:
        confidant_sil = src_info["confidant_sil"]

    if dilate_ks > 0:
        outpad_sil = morph(src_info["outpad_sil"], ks=dilate_ks, mode="dilate")
    else:
        outpad_sil = src_info["outpad_sil"]

    # outpad_sil = ((confidant_sil + (1 - src_info["cond"][:, -1:])) > 0).float()
    # outpad_sil = morph(outpad_sil, ks=dilate_ks, mode="dilate")

    # thin edges of the confident silhouette serve as the boundary points
    canny_filter = CannyFilter(device=src_img.device).to(src_img.device)
    blurred, grad_x, grad_y, grad_magnitude, grad_orientation, thin_edges = canny_filter(
        confidant_sil, 0.1, 0.9, True)

    # uncertain region: inside the dilated out-padding but outside the confident silhouette
    uncertain_sil = outpad_sil * (1 - confidant_sil)

    # visualizer.vis_named_img("silhouette", 1 - src_info["masks"])
    # visualizer.vis_named_img("confidant_sil", confidant_sil)
    # visualizer.vis_named_img("outpad_sil", outpad_sil)
    # visualizer.vis_named_img("uncertain_sil", uncertain_sil)

    all_morph_imgs = []
    for i in range(bs):
        boundaries_points = thin_edges[i, 0].nonzero(as_tuple=False)
        uncertain_points = uncertain_sil[i, 0].nonzero(as_tuple=False)
        weights, nn_pts, ids = self.cal_top_k_ids(uncertain_points, boundaries_points)
        morph_img = self.morph_image(src_img[i], weights, uncertain_points, nn_pts, confidant_sil[i])
        all_morph_imgs.append(morph_img)

    all_morph_imgs = torch.stack(all_morph_imgs, dim=0)

    return all_morph_imgs
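
# The code above depends on a `morph` helper for binary erosion/dilation whose
# implementation lives elsewhere in the repo. The function below is only a
# minimal, hypothetical stand-in (max-pooling based) that sketches the expected
# semantics: "dilate" grows a {0, 1} mask, "erode" shrinks it.
def morph_sketch(mask, ks=3, mode="erode"):
    """mask: (bs, 1, h, w) float mask in {0, 1}; ks: odd kernel size."""
    import torch.nn.functional as F

    pad = ks // 2
    if mode == "dilate":
        # a pixel becomes 1 if any pixel in its ks x ks neighborhood is 1
        return F.max_pool2d(mask, kernel_size=ks, stride=1, padding=pad)
    elif mode == "erode":
        # a pixel stays 1 only if every pixel in its ks x ks neighborhood is 1
        return 1.0 - F.max_pool2d(1.0 - mask, kernel_size=ks, stride=1, padding=pad)
    else:
        raise ValueError("mode must be 'erode' or 'dilate'")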
def forward(self, src_img, ref_img, src_smpl, ref_smpl, src_mask=None, ref_mask=None,
            links_ids=None, offsets=0, temporal=False):
    """
    Args:
        src_img (torch.tensor) : (bs, ns, 3, H, W);
        ref_img (torch.tensor) : (bs, nt, 3, H, W);
        src_smpl (torch.tensor): (bs, ns, 85);
        ref_smpl (torch.tensor): (bs, nt, 85);
        src_mask (torch.tensor): (bs, ns, 3, H, W) or None, foreground is 0, background is 1;
        ref_mask (torch.tensor): (bs, nt, 3, H, W) or None, foreground is 0, background is 1;
        links_ids (torch.tensor): (bs, ns + nt, number of verts, 2);
        offsets (torch.tensor) : (bs, nv, 3) or 0;
        temporal (bool): if True, the temporal warping flow is calculated; otherwise Ttt is None.

    Returns:
        input_G_bg (torch.tensor) : (bs, ns, 4, H, W)
        input_G_src (torch.tensor): (bs, ns, 6, H, W)
        input_G_tsf (torch.tensor): (bs, nt, 3, H, W)
        Tst (torch.tensor)        : (bs, nt, ns, H, W, 2)
        Ttt (torch.tensor)        : (bs, nt - 1, H, W, 2) if temporal is True, otherwise None
        src_mask (torch.tensor)   : (bs, ns, 1, H, W)
        tsf_mask (torch.tensor)   : (bs, nt, 1, H, W)
        head_bbox (torch.tensor)  : the head bounding boxes of the reference frames
        body_bbox (torch.tensor)  : the body bounding boxes of the reference frames
        uv_img (torch.tensor)     : (bs, 3, H, W)
    """
    bs, ns, _, h, w = src_img.shape
    bs, nt = ref_img.shape[0:2]

    input_G_bg, input_G_src, input_G_tsf, Tst, Ttt, uv_img, src_info, ref_info = super().forward(
        src_img, ref_img, src_smpl, ref_smpl, src_mask, ref_mask,
        links_ids=links_ids, offsets=offsets, temporal=temporal
    )

    # if no mask is provided, fall back to the background channel of the rendered condition map
    if src_mask is None:
        src_mask = src_info["cond"][:, -1:]
    else:
        src_mask = src_info["masks"]

    if ref_mask is None:
        tsf_mask = ref_info["cond"][:, -1:]
    else:
        tsf_mask = ref_info["masks"]

    src_mask = morph(src_mask, ks=self._opt.ft_ks, mode="erode")
    tsf_mask = morph(tsf_mask, ks=self._opt.ft_ks, mode="erode")

    src_mask = src_mask.view(bs, ns, 1, h, w)
    tsf_mask = tsf_mask.view(bs, nt, 1, h, w)

    head_bbox = self.cal_head_bbox_by_kps(ref_info["j2d"])
    body_bbox = self.cal_body_bbox_by_kps(ref_info["j2d"])

    return input_G_bg, input_G_src, input_G_tsf, Tst, Ttt, src_mask, tsf_mask, head_bbox, body_bbox, uv_img
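
# Hypothetical shape walk-through for forward() above (the 85-dim SMPL vector is
# assumed to split as cam(3) + pose(72) + shape(10); `model` is a placeholder
# for an instance of this class, which is not constructed here, so the call is
# left commented out):
#
#   bs, ns, nt, H, W = 1, 2, 1, 256, 256
#   src_img  = torch.rand(bs, ns, 3, H, W)
#   ref_img  = torch.rand(bs, nt, 3, H, W)
#   src_smpl = torch.rand(bs, ns, 85)
#   ref_smpl = torch.rand(bs, nt, 85)
#   (input_G_bg, input_G_src, input_G_tsf, Tst, Ttt,
#    src_mask, tsf_mask, head_bbox, body_bbox, uv_img) = model.forward(
#        src_img, ref_img, src_smpl, ref_smpl, temporal=False)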
def make_morph_f2pts(self, f2pts, fim, human_sil, erode_ks=0):
    """
    Args:
        f2pts (torch.Tensor): (bs, nf, 3, 2)
        fim (torch.Tensor): (bs, h, w)
        human_sil (torch.Tensor): (bs, 1, h, w)
        erode_ks (int):

    Returns:
        morphed_f2pts (torch.Tensor): (bs, nf, 3, 2)
    """
    bs = f2pts.shape[0]

    # (bs, 1, h, w)
    if erode_ks > 0:
        human_sil = morph(human_sil, ks=erode_ks, mode="erode")

    fbc = self.render.compute_barycenter(f2pts)

    morphed_f2pts = []
    for i in range(bs):
        m_f2pts = self.morph_f2pts(f2pts[i], fbc[i], fim[i], human_sil[i][0])
        morphed_f2pts.append(m_f2pts)

    morphed_f2pts = torch.stack(morphed_f2pts, dim=0)

    return morphed_f2pts
def make_uv_img(self, src_img, src_info):
    """
    Args:
        src_img (torch.tensor): (bs, ns, 3, h, w)
        src_info (dict):

    Returns:
        merge_uv (torch.tensor): (bs, 3, h, w)
    """
    bs, ns, _, h, w = src_img.shape
    bsxns = bs * ns

    ## previous
    # only_vis_src_f2pts = src_info["only_vis_f2pts"]
    # Ts2uv = self.render.cal_bc_transform(only_vis_src_f2pts, self.uv_fim[0:bsxns], self.uv_wim[0:bsxns])
    # src_warp_to_uv = F.grid_sample(src_img.view(bs * ns, 3, h, w), Ts2uv)
    # vis_warp_to_uv = F.grid_sample(self.one_map, Ts2uv)
    # merge_uv = torch.sum(src_warp_to_uv.view(bs, ns, -1, h, w), dim=1) / (
    #     torch.sum(vis_warp_to_uv.view(bs, ns, -1, h, w), dim=1) + 1e-5)

    ## current
    uv_fim = self.uv_fim[0:bsxns]
    uv_wim = self.uv_wim[0:bsxns]
    one_map = self.one_map[0:bsxns]
    only_vis_src_f2pts = src_info["only_vis_obj_f2pts"]
    src_f2pts = src_info["obj_f2pts"]

    only_vis_Ts2uv = self.render.cal_bc_transform(only_vis_src_f2pts, uv_fim, uv_wim)
    Ts2uv = self.render.cal_bc_transform(src_f2pts, uv_fim, uv_wim)

    src_warp_to_uv = F.grid_sample(src_img.view(bs * ns, 3, h, w), Ts2uv).view(bs, ns, -1, h, w)
    vis_warp_to_uv = F.grid_sample(one_map, only_vis_Ts2uv)

    # TODO, here ks=13 is hyper-parameter, might need to set it to the configuration.
    vis_warp_to_uv = morph(vis_warp_to_uv, ks=13, mode="dilate").view(bs, ns, -1, h, w)

    vis_sum = torch.sum(vis_warp_to_uv[:, 1:], dim=1)
    temp = torch.sum(src_warp_to_uv[:, 1:] * vis_warp_to_uv[:, 1:], dim=1) / (vis_sum + 1e-5)

    vis_front = vis_warp_to_uv[:, 0]
    vis_other = (vis_sum >= 1).float()
    front_invisible = (1 - vis_front) * vis_other

    merge_uv = src_warp_to_uv[:, 0] * (1 - front_invisible) + temp * front_invisible

    # merge_uv = src_warp_to_uv[:, 0]
    # noisy = torch.randn((bs, 3, h, w), dtype=torch.float32).to(src_img.device)
    # merge_uv = 0.5 * merge_uv + 0.5 * noisy
    # merge_uv = torch.clamp(merge_uv, min=-1.0, max=1.0)

    return merge_uv
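
# A toy, self-contained illustration (not part of the pipeline) of the
# visibility-based merge in make_uv_img above: UV pixels that are visible in
# the front view keep the front color; pixels invisible in the front view but
# visible in at least one other view take the visibility-weighted average of
# the other views. The values below are made up purely for demonstration.
def _uv_merge_toy_example():
    import torch

    vis_front = torch.tensor([[1.0, 0.0]])       # front view sees pixel 0 only
    vis_other = torch.tensor([[1.0, 1.0]])       # other views see both pixels
    front_color = torch.tensor([[0.9, 0.0]])     # color warped from the front view
    other_color = torch.tensor([[0.2, 0.5]])     # averaged color from the other views

    front_invisible = (1 - vis_front) * vis_other                       # -> [[0., 1.]]
    merge = front_color * (1 - front_invisible) + other_color * front_invisible
    return merge                                                        # tensor([[0.9000, 0.5000]])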
def make_bg_inputs(self, src_img, src_info):
    # bg input
    src_cond = src_info["cond"]

    if "masks" in src_info:
        bg_mask = src_info["masks"]
    else:
        bg_mask = src_cond[:, -1:, :, :]

    src_bg_mask = morph(bg_mask, ks=self._opt.bg_ks, mode="erode")
    input_G_bg = torch.cat([src_img * src_bg_mask, src_bg_mask], dim=1)

    return input_G_bg
def solve(self, obs, visualizer=None, visual_poses=None):
    """
    Args:
        obs (dict): the observations, it contains:
            --sil: the observed silhouettes;
            --cam: the camera parameters;
            --pose: the SMPL pose parameters;
            --shape: the SMPL shape parameters;
        visualizer: the visualizer, or None;
        visual_poses: the poses used for visualization, or None.

    Returns:
        offsets (nn.Parameter): (nv, 3), the optimized vertex offsets.
    """
    print("{} uses the parsed observations to tune the offsets...".format(
        self.__class__.__name__))

    with torch.no_grad():
        obs_sil = torch.tensor(obs["sil"]).float().to(self.device)
        obs_cam = torch.tensor(obs["cam"]).float().to(self.device)
        obs_pose = torch.tensor(obs["pose"]).float().to(self.device)
        obs_shape = torch.tensor(obs["shape"]).float().to(self.device)

        # close small holes in the observed silhouettes (dilate then erode)
        obs_sil = morph(obs_sil, ks=3, mode="dilate")
        obs_sil = morph(obs_sil, ks=5, mode="erode")
        obs_sil.squeeze_(dim=1)

        bs = obs_cam.shape[0]
        init_verts, _, _ = self.smpl(obs_shape, obs_pose, offsets=0, get_skin=True)
        faces = self.render.smpl_faces.repeat(bs, 1, 1)
        nv = init_verts.shape[1]

    offsets = nn.Parameter(torch.zeros((nv, 3)).to(self.device))

    # total_steps = 500
    # init_lr = 0.0002
    # alpha_reg = 1000

    total_steps = 500
    init_lr = 0.0001
    alpha_reg = 10000

    optimizer = torch.optim.Adam([offsets], lr=init_lr)
    crt_sil = nn.MSELoss()

    if visualizer is not None:
        textures = self.render.color_textures().repeat(bs, 1, 1, 1, 1, 1)
        textures = textures.to(self.device)
        visual_poses = torch.tensor(visual_poses).float().to(self.device)
        num_visuals = visual_poses.shape[0]

    for i in tqdm(range(total_steps)):
        verts, joints, Rs = self.smpl(obs_shape.detach(), obs_pose.detach(),
                                      offsets=offsets, get_skin=True)
        rd_sil = self.render.render_silhouettes(obs_cam.detach(), verts, faces=faces.detach())

        # silhouette reconstruction loss + L2 regularization on the offsets
        loss = crt_sil(rd_sil, obs_sil) + alpha_reg * torch.mean(offsets ** 2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if visualizer is not None and i % 10 == 0:
            with torch.no_grad():
                ids = np.random.choice(num_visuals, bs)
                rand_pose = visual_poses[ids]
                verts, joints, Rs = self.smpl(obs_shape, rand_pose, offsets=offsets, get_skin=True)
                rd, _ = self.visual_render.render(obs_cam, verts, textures, faces=faces, get_fim=False)

                visualizer.vis_named_img("rd_sil", rd_sil)
                visualizer.vis_named_img("obs_sil", obs_sil)
                visualizer.vis_named_img("render", rd)

            print("step = {}, loss = {:.6f}".format(i, loss.item()))

    return offsets
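
# Hypothetical usage of solve(): the keys of `obs` follow the docstring above,
# while the shapes are assumptions (bs frames, weak-perspective 3-dim camera,
# 72-dim SMPL pose, 10-dim SMPL shape), and `solver` / the numpy arrays are
# placeholders rather than names defined in this file.
#
#   obs = {
#       "sil":   sil_np,     # (bs, 1, H, W) binary silhouettes
#       "cam":   cam_np,     # (bs, 3)
#       "pose":  pose_np,    # (bs, 72)
#       "shape": shape_np,   # (bs, 10)
#   }
#   offsets = solver.solve(obs)   # nn.Parameter of shape (nv, 3), nv = 6890 for SMPL
#   offsets = offsets.detach()    # detach before reusing them as fixed offsets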
def add_rendered_f2verts_fim_wim(self, smpl_info, use_morph=False, get_uv_info=True):
    """
    Args:
        smpl_info (dict): the smpl information, it contains:
            --cam (torch.Tensor):
            --verts (torch.Tensor):
        use_morph (bool): use the morphing strategy to adjust the f2pts to the segmentation
            observation; it is typically used when processing the source information;
        get_uv_info (bool): get the information for UV; it is typically used when processing
            the source information.

    Returns:
        smpl_info (dict):
            --cam (torch.Tensor): (bs * nt, 3),
            --verts (torch.Tensor): (bs * nt, 6890, 3),
            --j2d (torch.Tensor): (bs * nt, 19, 2),
            --cond (torch.Tensor): (bs * nt, 3, h, w),
            --fim (torch.Tensor): (bs * nt, h, w),
            --wim (torch.Tensor): (bs * nt, h, w, 3),
            --f2pts (torch.Tensor): (bs * nt, 13776, 3, 2),
            --only_vis_f2pts (torch.Tensor): (bs * nt, 13776, 3, 2),
            --obj_fim (torch.Tensor): (bs * nt, h, w),
            --obj_wim (torch.Tensor): (bs * nt, h, w, 3),
            --obj_f2pts (torch.Tensor): (bs * nt, 13776, 3, 2),
            --only_vis_obj_f2pts (torch.Tensor): (bs * nt, 13776, 3, 2)
    """
    f2pts, fim, wim = self.render.render_fim_wim(
        cam=smpl_info["cam"], vertices=smpl_info["verts"], smpl_faces=True)
    cond, _ = self.render.encode_fim(smpl_info["cam"], smpl_info["verts"], fim=fim, transpose=True)

    if use_morph:
        if "masks" in smpl_info:
            human_sil = 1 - smpl_info["masks"]
        else:
            human_sil = 1 - cond[:, -1:]

        # TODO, here ks=3 is hyper-parameter, might need to set it to the configuration.
        smpl_info["confidant_sil"] = morph(human_sil, ks=3, mode="erode")

        # TODO, here ks=51 is hyper-parameter, might need to set it to the configuration.
        smpl_info["outpad_sil"] = morph(
            ((human_sil + 1 - cond[:, -1:]) > 0).float(), ks=51, mode="dilate")

        # f2pts = self.make_morph_f2pts(f2pts, fim, smpl_info["human_sil"], erode_ks=0)

    only_vis_f2pts = self.render.get_vis_f2pts(f2pts, fim)

    smpl_info["f2pts"] = f2pts
    smpl_info["only_vis_f2pts"] = only_vis_f2pts
    smpl_info["cond"] = cond
    smpl_info["fim"] = fim
    smpl_info["wim"] = wim

    if get_uv_info:
        obj_f2pts, obj_fim, obj_wim = self.render.render_fim_wim(
            cam=smpl_info["cam"], vertices=smpl_info["verts"], smpl_faces=False)

        # if use_morph:
        #     obj_f2pts = self.make_morph_f2pts(obj_f2pts, obj_fim, smpl_info["human_sil"], erode_ks=0)

        only_vis_obj_f2pts = self.render.get_vis_f2pts(obj_f2pts, obj_fim)

        smpl_info["obj_f2pts"] = obj_f2pts
        smpl_info["only_vis_obj_f2pts"] = only_vis_obj_f2pts
        smpl_info["obj_fim"] = obj_fim
        smpl_info["obj_wim"] = obj_wim

    return smpl_info