def preprocess_df(self, fit_infos):
    fit_infos_df_keys = ["frame_idx", "video_id", "img_path"]
    df_dicts = []
    for fit_info in fit_infos:
        df_dict = {key: fit_info[key] for key in fit_infos_df_keys}
        df_dict["obj_paths"] = [fit_info["obj_path"]]
        df_dict["boxes"] = npt.numpify(fit_info["boxes"])
        df_dicts.append(df_dict)
    return pd.DataFrame(df_dicts)
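# Hedged usage sketch (not part of the original module): preprocess_df expects
# one dict per frame carrying at least the keys used above; "boxes" may be a
# torch tensor, which npt.numpify converts to numpy before building the
# DataFrame. The caller name `fitter` and the concrete values are illustrative
# assumptions only.
#
# example_fit_infos = [
#     {
#         "frame_idx": 0,
#         "video_id": "video_0001",
#         "img_path": "frames/video_0001/frame_0000.jpg",
#         "obj_path": "objects/obj_0001.obj",
#         "boxes": torch.tensor([100.0, 50.0, 300.0, 250.0]),
#     },
# ]
# df = fitter.preprocess_df(example_fit_infos)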
def add_lines(ax, lines, over_lines=None, labels=None, overlay_alpha=0.6):
    colors = cycle(cycle_colors)
    for line_idx, line in enumerate(lines):
        if labels is None:
            label = f"{line_idx}"
        else:
            label = labels[line_idx]
        color = next(colors)
        ax.plot(npt.numpify(line), label=label, c=color)
        if over_lines is not None:
            over_line = over_lines[line_idx]
            ax.plot(
                npt.numpify(over_line),
                "-",
                label=label,
                c=color,
                alpha=overlay_alpha,
            )
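# Minimal usage sketch for add_lines (an assumed typical use: plotting smoothed
# metric curves over their raw values); `cycle_colors` is the module-level
# color list referenced above, and the data below is synthetic.
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
raw_lines = [np.cumsum(np.random.randn(50)) for _ in range(2)]
smoothed_lines = [np.convolve(line, np.ones(5) / 5, mode="same") for line in raw_lines]
add_lines(ax, smoothed_lines, over_lines=raw_lines, labels=["x", "y"])
ax.legend()
fig.savefig("lines_demo.png")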
def add_pointsrow(
    axes,
    tensors,
    overlay_list=None,
    overlay_colors=["c", "k", "b"],
    row_idx=0,
    row_nb=1,
    point_s=1,
    point_c="r",
    axis_equal=True,
    show_axis=True,
    alpha=1,
    over_alpha=1,
):
    point_nb = len(tensors)
    points = [conv2img(tens) for tens in tensors]
    for point_idx, point in enumerate(points):
        ax = vizmp.get_axis(
            axes,
            row_idx=row_idx,
            col_idx=point_idx,
            row_nb=row_nb,
            col_nb=point_nb,
        )
        pts = npt.numpify(points[point_idx])
        if point_c == "rainbow":
            pt_nb = pts.shape[0]
            point_c = cm.rainbow(np.linspace(0, 1, pt_nb))
        ax.scatter(pts[:, 0], pts[:, 1], s=point_s, c=point_c, alpha=alpha)
        if overlay_list is not None:
            for overlay, over_color in zip(overlay_list, overlay_colors):
                over_pts = npt.numpify(overlay[point_idx])
                ax.scatter(
                    over_pts[:, 0],
                    over_pts[:, 1],
                    s=point_s,
                    c=over_color,
                    alpha=over_alpha,
                )
        if axis_equal:
            ax.axis("equal")
        if not show_axis:
            ax.axis("off")
def save_scene_clip(self, locations=["tmp.webm"], fps=5, imgs=None):
    if isinstance(locations, str):
        locations = [locations]
    with torch.no_grad():
        viz = self.forward()["scene_viz_rend"]
    # Stack views horizontally
    img_seqs = torch.cat(viz, 2)
    img_seqs = npt.numpify(img_seqs)[:, :, :, :3]
    if imgs is not None:
        # BGR2RGB
        imgs = (
            np.stack([npt.numpify(img)[:, :, ::-1] for img in imgs])[:, :, :, :3]
            / 255
        )
        img_seqs = np.concatenate([imgs, img_seqs], 2)
    # Renderings have alpha channel and are scaled in [0, 1]
    clip = mpy.ImageSequenceClip([frm * 255 for frm in img_seqs], fps=fps)
    for location in locations:
        clip.write_videofile(location)
def make_alpha(img, bg_val=1):
    img = npt.numpify(img)
    # Extend single-channel image to 3 channels
    if img.ndim == 2:
        img = img[:, :, None]
    if img.shape[2] == 1:
        img = img.repeat(3, 2)
    # Alpha is 1 wherever the image differs from the background color
    mask = (img[:, :, :3].sum(-1) != (bg_val * 3)).astype(img.dtype)
    img = np.concatenate([img, mask[:, :, None]], 2)
    return img
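# Minimal sketch of make_alpha, assuming a float rendering on a white
# (bg_val=1) background; the synthetic image below is illustrative only.
import numpy as np

render = np.ones((64, 64, 3), dtype=np.float32)  # white background
render[20:40, 20:40] = 0.2                       # dark square = foreground
rgba = make_alpha(render, bg_val=1)
assert rgba.shape == (64, 64, 4)
assert rgba[0, 0, 3] == 0 and rgba[30, 30, 3] == 1  # alpha is 0 on background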
def imagify(tensor, normalize_colors=True):
    tensor = npt.numpify(tensor)
    # Scale to [0, 1]
    if normalize_colors:
        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
    # Put channels last
    if tensor.ndim == 3 and tensor.shape[0] <= 4:
        tensor = tensor.transpose(1, 2, 0)
    if tensor.ndim == 3 and tensor.shape[2] == 1:
        tensor = tensor[:, :, 0]
    elif tensor.ndim == 3 and tensor.shape[2] < 3:
        # Pad missing channels with 0.5 so 2-channel maps can still be displayed
        tensor = np.concatenate(
            [tensor, 0.5 * np.ones_like(tensor)[:, :, : 3 - tensor.shape[2]]],
            2,
        )
    if tensor.ndim == 3 and (tensor.shape[2] != 4):
        tensor = tensor[:, :, :3]
    return tensor
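# Minimal sketch of imagify on a channels-first tensor, e.g. a CHW rendering;
# values are rescaled to [0, 1] and channels are moved last for plt.imshow.
import torch

chw = torch.rand(3, 32, 32) * 255  # arbitrary range, normalized inside imagify
img = imagify(chw)
assert img.shape == (32, 32, 3)
assert img.min() == 0.0 and img.max() == 1.0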
def add_imgrow(
    axes,
    tensors,
    row_idx=0,
    row_nb=1,
    overlays=None,
    points=None,
    over_alpha=0.5,
    show_axis=False,
    point_s=1,
    point_c="r",
    overval=1,
):
    img_nb = len(tensors)
    imgs = [conv2img(tens) for tens in tensors]
    for img_idx, img in enumerate(imgs):
        ax = vizmp.get_axis(
            axes,
            row_idx=row_idx,
            col_idx=img_idx,
            row_nb=row_nb,
            col_nb=img_nb,
        )
        if points is not None:
            pts = npt.numpify(points[img_idx])
            if pts is not None:
                ax.scatter(pts[:, 0], pts[:, 1], s=point_s, c=point_c)
        if overlays is not None:
            overlay_img = conv2img(overlays[img_idx])
            if overlay_img.ndim == 2:
                overlay_mask = (overlay_img != overval)[:, :]
            else:
                overlay_mask = (overlay_img.sum(2) != overval * 3)[:, :, np.newaxis]
            over_img = overlay_mask * img + (1 - overlay_mask) * img * over_alpha
            ax.imshow(over_img)
        else:
            ax.imshow(img)
        if not show_axis:
            ax.axis("off")
def rtsmooth(measurements, dt=0.02, order=2):
    """
    Args:
        measurements (np.array): (time, measurements_dim)
    Returns:
        data (np.array): (time, measurements_dim)
    """
    measure_dim = measurements.shape[1]
    # kf.F is the state-transition matrix built by kinematic_kf
    kf = kinematic_kf(dim=measure_dim, order=order, dt=dt)
    # kf.P is ordered with [2, 2] or [3, 3] blocks for each dimension
    # (2 if 1st order - constant velocity, 3 if 2nd order - constant acceleration)
    kf.P[::order + 1, ::order + 1] *= 1
    kf.P *= 10
    kf.Q[::order + 1, ::order + 1] *= 1
    # Forward Kalman filtering, then Rauch-Tung-Striebel smoothing
    mu, cov, _, _ = kf.batch_filter(npt.numpify(measurements))
    smoothed, _, _, _ = kf.rts_smoother(mu, cov)
    # Keep only the position components of the smoothed state
    return smoothed[:, ::order + 1, 0]
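# Minimal sketch of rtsmooth on a noisy 2D trajectory sampled at 50 Hz
# (dt=0.02); the sinusoidal measurements are synthetic and only illustrate the
# expected (time, measurements_dim) input and output shapes.
import numpy as np

t = np.linspace(0, 2, 100)
clean = np.stack([np.sin(t), np.cos(t)], 1)            # (100, 2)
noisy = clean + 0.05 * np.random.randn(*clean.shape)   # add measurement noise
smoothed = rtsmooth(noisy, dt=0.02, order=2)
assert smoothed.shape == noisy.shape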
def fitobj2mask(
    masks,
    bboxes,
    obj_paths,
    z_off=0.5,
    radius=0.1,
    faces_per_pixel=1,
    lr=0.01,
    loss_type="l2",
    iters=100,
    viz_step=1,
    save_folder="tmp/",
    viz_rows=12,
    crop_box=True,
    crop_size=(200, 200),
    rot_nb=1,
):
    # Initialize logging info
    opts = {
        "z_off": z_off,
        "loss_type": loss_type,
        "iters": iters,
        "radius": radius,
        "lr": lr,
        "obj_paths": obj_paths,
        "faces_per_pix": faces_per_pixel,
    }
    results = {"opts": opts}
    save_folder = Path(save_folder)
    print(f"Saving to {save_folder}")
    metrics = defaultdict(list)
    batch_size = len(obj_paths)

    # Load normalized objects
    batch_faces = []
    batch_verts = []
    for obj_path in obj_paths:
        verts_loc, faces_idx, _ = py3dload_obj(obj_path)
        faces = faces_idx.verts_idx
        batch_faces.append(faces.cuda())
        verts = normalize.normalize_verts(verts_loc, radius).cuda()
        batch_verts.append(verts)
    batch_verts = torch.stack(batch_verts)
    batch_faces = torch.stack(batch_faces)

    # Dummy intrinsic camera
    height, width = masks[0].shape
    focal = min(masks[0].shape)
    camintr = (
        torch.Tensor(
            [[focal, 0, width // 2], [0, focal, height // 2], [0, 0, 1]]
        )
        .cuda()
        .unsqueeze(0)
        .repeat(batch_size, 1, 1)
    )
    if crop_box:
        adaptive_loss = AdaptiveLossFunction(
            num_dims=crop_size[0] * crop_size[1],
            float_dtype=np.float32,
            device="cuda:0",
        )
    else:
        adaptive_loss = AdaptiveLossFunction(
            num_dims=height * width, float_dtype=np.float32, device="cuda:0"
        )

    # Prepare rigid parameters
    if rot_nb > 1:
        rot_mats = [special_ortho_group.rvs(3) for _ in range(rot_nb)]
        rot_vecs = torch.Tensor(
            [np.linalg.svd(rot_mat)[0][:2].reshape(-1) for rot_mat in rot_mats]
        )
        rot_vec = rot_vecs.repeat(batch_size, 1).cuda()
        # Ordering b1 rot1, b1 rot2, ..., b2 rot1, ...
    else:
        rot_vec = torch.Tensor(
            [[1, 0, 0, 0, 1, 0] for _ in range(batch_size)]
        ).cuda()
    bboxes_tight = torch.stack(bboxes)
    # trans = ops3d.trans_init_from_boxes(bboxes, camintr, (z_off, z_off)).cuda()
    trans = ops3d.trans_init_from_boxes_autodepth(
        bboxes_tight, camintr, batch_verts, z_guess=z_off
    ).cuda()
    # Repeat to match rots
    trans = repeatdim(trans, rot_nb, 1)
    bboxes = boxutils.preprocess_boxes(bboxes_tight, padding=10, squarify=True)
    if crop_box:
        camintr_crop = camutils.get_K_crop_resize(camintr, bboxes, crop_size)
        camintr_crop = repeatdim(camintr_crop, rot_nb, 1)
    trans.requires_grad = True
    rot_vec.requires_grad = True
    optim_params = [rot_vec, trans]
    if "adapt" in loss_type:
        optim_params = optim_params + list(adaptive_loss.parameters())
    # Optimize the rigid pose and, for the adaptive loss, its parameters
    optimizer = torch.optim.Adam(optim_params, lr=lr)

    ref_masks = torch.stack(masks).cuda()
    if crop_box:
        ref_masks = cropping.crops(ref_masks.float(), bboxes, crop_size)[:, 0]
    # Prepare reference mask
    if "dtf" in loss_type:
        target_masks = torch.stack(
            [torch.Tensor(dtf.distance_transform(mask)) for mask in ref_masks]
        ).cuda()
    else:
        target_masks = ref_masks
    ref_masks = repeatdim(ref_masks, rot_nb, 1)
    target_masks = repeatdim(target_masks, rot_nb, 1)
    batch_verts = repeatdim(batch_verts, rot_nb, 1)
    batch_faces = repeatdim(batch_faces, rot_nb, 1)

    col_nb = 5
    fig_res = 1.5
    # Aggregate images
    clip_data = []
    for iter_idx in tqdm(range(iters)):
        rot_mat = rotations.compute_rotation_matrix_from_ortho6d(rot_vec)
        optim_verts = batch_verts.bmm(rot_mat) + trans.unsqueeze(1)
        if crop_box:
            rendres = batch_render(
                optim_verts,
                batch_faces,
                K=camintr_crop,
                image_sizes=[(crop_size[1], crop_size[0])],
                mode="silh",
                faces_per_pixel=faces_per_pixel,
            )
        else:
            rendres = batch_render(
                optim_verts,
                batch_faces,
                K=camintr,
                image_sizes=[(width, height)],
                mode="silh",
                faces_per_pixel=faces_per_pixel,
            )
        optim_masks = rendres[:, :, :, -1]

        # Track metrics against the raw reference masks
        mask_diff = ref_masks - optim_masks
        mask_l2 = (mask_diff ** 2).mean()
        mask_l1 = mask_diff.abs().mean()
        mask_iou = lyiou.batch_mask_iou(
            (optim_masks > 0), (ref_masks > 0)
        ).mean()
        metrics["l1"].append(mask_l1.item())
        metrics["l2"].append(mask_l2.item())
        metrics["mask"].append(mask_iou.item())

        # Optimize against the (possibly distance-transformed) target masks
        optim_mask_diff = target_masks - optim_masks
        if "l2" in loss_type:
            loss = (optim_mask_diff ** 2).mean()
        elif "l1" in loss_type:
            loss = optim_mask_diff.abs().mean()
        elif "adapt" in loss_type:
            loss = adaptive_loss.lossfun(
                optim_mask_diff.view(rot_nb * batch_size, -1)
            ).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iter_idx % viz_step == 0:
            row_idxs = np.linspace(
                0, batch_size * rot_nb - 1, viz_rows
            ).astype(int)
            row_nb = viz_rows
            fig, axes = plt.subplots(
                row_nb,
                col_nb,
                figsize=(int(col_nb * fig_res), int(row_nb * fig_res)),
            )
            for row_idx in range(row_nb):
                show_idx = row_idxs[row_idx]
                ax = vizmp.get_axis(
                    axes, row_idx, 0, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(npt.numpify(optim_masks[show_idx]))
                ax.set_title("optim mask")
                ax = vizmp.get_axis(
                    axes, row_idx, 1, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(npt.numpify(ref_masks[show_idx]))
                ax.set_title("ref mask")
                ax = vizmp.get_axis(
                    axes, row_idx, 2, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(
                    npt.numpify(ref_masks[show_idx] - optim_masks[show_idx]),
                    vmin=-1,
                    vmax=1,
                )
                ax.set_title("ref masks diff")
                ax = vizmp.get_axis(
                    axes, row_idx, 3, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(npt.numpify(target_masks[show_idx]), vmin=-1, vmax=1)
                ax.set_title("target mask")
                ax = vizmp.get_axis(
                    axes, row_idx, 4, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(
                    npt.numpify(
                        target_masks[show_idx] - optim_masks[show_idx]
                    ),
                    vmin=-1,
                    vmax=1,
                )
                ax.set_title("masks diff")
            viz_folder = save_folder / "viz"
            viz_folder.mkdir(parents=True, exist_ok=True)
            data = vizmp.fig2np(fig)
            clip_data.append(data)
            fig.savefig(viz_folder / f"{iter_idx:04d}.png")
    clip = mpy.ImageSequenceClip(clip_data, fps=4)
    clip.write_videofile(str(viz_folder / "out.mp4"))
    clip.write_videofile(str(viz_folder / "out.webm"))
    results["metrics"] = metrics
    return results
def conv2img(tensor):
    if len(tensor.shape) == 3:
        if tensor.shape[0] in [3, 4]:
            tensor = tensor.permute(1, 2, 0)
    tensor = npt.numpify(tensor)
    return tensor
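# Tiny sketch of conv2img: channels-first torch tensors are moved to
# channels-last numpy images (other shapes pass through numpify unchanged).
import torch

out = conv2img(torch.rand(3, 16, 16))
assert out.shape == (16, 16, 3)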
def regress(self,
            img_original,
            hand_bbox_list,
            add_margin=False,
            debug=True,
            viz_path="tmp.png",
            K=None):
    """
    args:
        img_original: original raw image (BGR order, as read by cv2.imread)
        hand_bbox_list: [
            dict(
                left_hand = [x0, y0, w, h] or None
                right_hand = [x0, y0, w, h] or None
            )
            ...
        ]
        add_margin: whether to pad the hand bbox with a margin
    outputs:
        To be filled
    Note:
        Output elements can be None. This keeps the output size equal to the
        input bbox list size.
    """
    pred_output_list = list()
    hand_bbox_list_processed = list()

    for hand_bboxes in hand_bbox_list:
        if hand_bboxes is None:
            # Keep the same length as the input bbox list
            pred_output_list.append(None)
            hand_bbox_list_processed.append(None)
            continue

        pred_output = dict(left_hand=None, right_hand=None)
        hand_bboxes_processed = dict(left_hand=None, right_hand=None)

        for hand_type in hand_bboxes:
            bbox = hand_bboxes[hand_type]
            if bbox is None:
                continue
            img_cropped, norm_img, bbox_scale_ratio, bbox_processed = \
                self.__process_hand_bbox(img_original,
                                         hand_bboxes[hand_type],
                                         hand_type, add_margin)
            hand_bboxes_processed[hand_type] = bbox_processed

            with torch.no_grad():
                # pred_rotmat, pred_betas, pred_camera = self.model_regressor(norm_img.to(self.device))
                self.model_regressor.set_input_imgonly(
                    {'img': norm_img.unsqueeze(0)})
                self.model_regressor.test()
                pred_res = self.model_regressor.get_pred_result()

                # Output
                cam = pred_res['cams'][0, :]  # scale, tranX, tranY
                pred_verts_origin = pred_res['pred_verts'][0]
                faces = self.model_regressor.right_hand_faces_local
                pred_pose = pred_res['pred_pose_params'].copy()
                pred_joints = pred_res['pred_joints_3d'].copy()[0]

                # The regressor predicts right hands: mirror for left hands
                if hand_type == 'left_hand':
                    cam[1] *= -1
                    pred_verts_origin[:, 0] *= -1
                    faces = faces[:, ::-1]
                    pred_pose[:, 1::3] *= -1
                    pred_pose[:, 2::3] *= -1
                    pred_joints[:, 0] *= -1

                pred_output[hand_type] = dict()
                # SMPL-X hand vertices in bbox space
                pred_output[hand_type]['pred_vertices_smpl'] = pred_verts_origin
                pred_output[hand_type]['pred_joints_smpl'] = pred_joints
                pred_output[hand_type]['faces'] = faces
                pred_output[hand_type]['bbox_scale_ratio'] = bbox_scale_ratio
                pred_output[hand_type]['bbox_top_left'] = np.array(
                    bbox_processed[:2])
                pred_output[hand_type]['pred_camera'] = cam
                pred_output[hand_type]['img_cropped'] = img_cropped

                # Get global camera
                global_cams = local_to_global_cam(
                    bbox_wh_to_xy(
                        np.array([list(bbox_processed)]).astype(float)),
                    cam[None, :], max(img_original.shape))
                pred_output[hand_type]["global_cams"] = global_cams

                # Predicted hand pose & shape params & hand joints 3d
                # (1, 48): (1, 3) for hand rotation, (1, 45) for finger pose
                pred_output[hand_type]['pred_hand_pose'] = pred_pose

                # Recover PCA components
                if hand_type == "right_hand":
                    comps = torch.Tensor(
                        self.mano_model.rh_mano_pca.full_hand_components)
                elif hand_type == "left_hand":
                    comps = torch.Tensor(
                        self.mano_model.lh_mano_pca.full_hand_components)
                pca_comps = torch.einsum('bi,ij->bj', [
                    torch.Tensor(pred_pose[:, 3:]),
                    torch.inverse(comps)
                ])
                pred_output[hand_type]['pred_hand_betas'] = pred_res[
                    'pred_shape_params']  # (1, 10)
                pred_output[hand_type]['pred_pca_pose'] = pca_comps.numpy()
                hand_side = hand_type.split("_")[0]
                pred_output[hand_type]["hand_side"] = hand_side
                pred_output[hand_type]['mano_trans'] = self.mano_model.get_mano_trans(
                    mano_pose=pred_pose[:, 3:][0],
                    rot=pred_pose[:, :3][0],
                    betas=pred_res["pred_shape_params"][0],
                    ref_verts=pred_verts_origin,
                    side=hand_side)

                # Convert vertices into bbox & image space
                cam_scale = cam[0]
                cam_trans = cam[1:]
                vert_smplcoord = pred_verts_origin.copy()
                joints_smplcoord = pred_joints.copy()

                # SMPL space -> bbox space
                vert_bboxcoord = convert_smpl_to_bbox(
                    vert_smplcoord, cam_scale, cam_trans, bAppTransFirst=True)
                joints_bboxcoord = convert_smpl_to_bbox(
                    joints_smplcoord, cam_scale, cam_trans, bAppTransFirst=True)

                hand_boxScale_o2n = pred_output[hand_type]['bbox_scale_ratio']
                hand_bboxTopLeft = pred_output[hand_type]['bbox_top_left']

                # bbox space -> original image space
                vert_imgcoord = convert_bbox_to_oriIm(
                    vert_bboxcoord, hand_boxScale_o2n, hand_bboxTopLeft,
                    img_original.shape[1], img_original.shape[0])
                pred_output[hand_type]['pred_vertices_img'] = vert_imgcoord
                joints_imgcoord = convert_bbox_to_oriIm(
                    joints_bboxcoord, hand_boxScale_o2n, hand_bboxTopLeft,
                    img_original.shape[1], img_original.shape[0])
                pred_output[hand_type]['pred_joints_img'] = joints_imgcoord

                if K is not None:
                    K = npt.numpify(K)
                    K_viz = K.copy()
                    # r_hand, _ = cv2.Rodrigues(r_vec)
                    K_viz[:2] = K_viz[:2] / max(img_original.shape)
                    # Scale of verts_pixel / pred_verts_origin
                    ortho_scale_pixels = global_cams[0, 0] / 2 * max(
                        img_original.shape)
                    ortho_trans_pixels = (
                        global_cams[0, 1:] +
                        1 / global_cams[0, 0]) * ortho_scale_pixels
                    t_vec = camconvs.weakcam2persptrans(
                        np.concatenate([
                            np.array([ortho_scale_pixels]),
                            ortho_trans_pixels
                        ], 0) / max(img_original.shape), K_viz)[:, None]
                    r_hand = np.eye(3)

                    # Sanity check
                    nptype = np.float32
                    proj2d, camverts = project.proj2d(
                        pred_verts_origin.astype(nptype),
                        K.astype(nptype),
                        rot=r_hand.astype(nptype),
                        trans=t_vec[:, 0].astype(nptype))
                    pred_output[hand_type]['camverts'] = camverts
                    pred_output[hand_type]['perspective_trans'] = t_vec.transpose()
                    pred_output[hand_type]['perspective_rot'] = r_hand

        pred_output_list.append(pred_output)
        hand_bbox_list_processed.append(hand_bboxes_processed)

    assert len(hand_bbox_list_processed) == len(hand_bbox_list)
    return pred_output_list
def preprocess_supervision(self, fit_infos, grab_objects=False):
    # Initialize tar reader
    tareader = TarReader()
    # sample_masks = []
    sample_verts = []
    sample_confs = []
    sample_imgs = []
    ref_hand_rends = []
    # Regions of interest containing hands and objects
    roi_bboxes = []
    roi_valid_masks = []
    # Crops of hand and object masks
    sample_hand_masks = []
    sample_objs_masks = []

    # Create dummy intrinsic camera for supervision rendering
    focal = 200
    camintr = np.array([[focal, 0, 456 // 2], [0, focal, 256 // 2],
                        [0, 0, 1]])
    camintr_th = torch.Tensor(camintr).unsqueeze(0)

    # Modelling hand color
    print("Preprocessing sequence")
    for fit_info in tqdm(fit_infos):
        img = tareader.read_tar_frame(fit_info["img_path"])
        img_size = img.shape[:2]  # height, width
        # img = cv2.imread(fit_info["img_path"])
        hand_infos = fit_info["hands"]
        human_verts = np.zeros((self.smplx_vertex_nb, 3))
        verts_confs = np.zeros((self.smplx_vertex_nb, ))

        # Get hand vertex reference poses
        img_hand_verts = []
        img_hand_faces = []
        for side in hand_infos:
            hand_info = hand_infos[side]
            hand_verts = hand_info["verts"]
            # Aggregate hand vertices and faces for rendering
            img_hand_verts.append(
                lift_verts(
                    torch.Tensor(hand_verts).unsqueeze(0), camintr_th))
            img_hand_faces.append(
                torch.Tensor(hand_info["faces"]).unsqueeze(0))
            corresp = self.mano_corresp[f"{side}_hand"]
            human_verts[corresp] = hand_verts
            verts_confs[corresp] = 1
        has_hands = len(img_hand_verts) > 0

        # Render reference hands
        if has_hands:
            img_hand_verts, img_hand_faces, _ = catmesh.batch_cat_meshes(
                img_hand_verts, img_hand_faces)
            with torch.no_grad():
                res = py3drendutils.batch_render(
                    img_hand_verts.cuda(),
                    img_hand_faces.cuda(),
                    faces_per_pixel=2,
                    color=(1, 0.4, 0.6),
                    K=camintr_th,
                    image_sizes=[(img_size[1], img_size[0])],
                    mode="rgb",
                    shading="soft",
                )
            ref_hand_rends.append(npt.numpify(res[0, :, :, :3]))
            hand_mask = npt.numpify(res[0, :, :, 3])
        else:
            ref_hand_rends.append(np.zeros(img.shape) + 1)
            hand_mask = np.zeros((img.shape[:2]))

        obj_masks = fit_info["masks"]
        # GrabCut objects
        has_objs = len(obj_masks) > 0
        if has_objs:
            obj_masks_aggreg = (
                npt.numpify(torch.stack(obj_masks)).sum(0) > 0)
        else:
            obj_masks_aggreg = np.zeros_like(hand_mask)
        # Detect if some pseudo ground truth masks exist
        has_both_masks = (hand_mask.max() > 0) and (obj_masks_aggreg.max() > 0)
        if has_both_masks:
            xs, ys = np.where((hand_mask + obj_masks_aggreg) > 0)
            # Compute region of interest which contains hands and objects
            roi_bbox = boxutils.squarify_box(
                [xs.min(), ys.min(), xs.max(), ys.max()], scale_factor=1.5)
        else:
            rad = min(img.shape[:2])
            roi_bbox = [0, 0, rad, rad]
        roi_bbox = [int(val) for val in roi_bbox]
        roi_bboxes.append(roi_bbox)
        img_crop = cropping.crop_cv2(img, roi_bbox, resize=self.crop_size)
        # Compute region of crop which belongs to original image (vs padding)
        roi_valid_mask = cropping.crop_cv2(np.ones(img.shape[:2]),
                                           roi_bbox,
                                           resize=self.crop_size)
        roi_valid_masks.append(roi_valid_mask)

        # Crop hand and object masks
        hand_mask_crop = (cropping.crop_cv2(
            hand_mask, roi_bbox, resize=self.crop_size) > 0).astype(int)
        objs_masks_crop = cropping.crop_cv2(
            obj_masks_aggreg.astype(int),
            roi_bbox,
            resize=self.crop_size,
        ).astype(int)
        # Remove object region from hand mask
        hand_mask_crop[objs_masks_crop > 0] = 0
        # Extract skeletons
        skel_objs_masks_crop = skeletonize(objs_masks_crop.astype(np.uint8))
        skel_hand_mask_crop = skeletonize(hand_mask_crop.astype(np.uint8))

        # Removing the object region from the hand mask can cancel out the whole hand!
        if has_both_masks and hand_mask_crop.max():
            grabinfo = grabcut.grab_cut(
                img_crop.astype(np.uint8),
                mask=hand_mask_crop,
                bbox=roi_bbox,
                bgd_mask=skel_objs_masks_crop,
                fgd_mask=skel_hand_mask_crop,
                debug=self.debug,
            )
            hand_mask = grabinfo["grab_mask"]
            hand_mask[objs_masks_crop > 0] = 0
        else:
            hand_mask = hand_mask_crop
        sample_hand_masks.append(hand_mask)

        # Get crops of object masks
        obj_mask_crops = []
        for obj_mask in obj_masks:
            obj_mask_crop = cropping.crop_cv2(
                npt.numpify(obj_mask).astype(int),
                roi_bbox,
                resize=self.crop_size,
            )
            skel_obj_mask_crop = skeletonize(obj_mask_crop.astype(np.uint8))
            if grab_objects:
                raise NotImplementedError(
                    "Maybe needs also the skeleton of other objects "
                    "to be labelled as background?")
                grabinfo = grabcut.grab_cut(
                    img_crop,
                    mask=obj_mask_crop,
                    bbox=roi_bbox,
                    bgd_mask=skel_hand_mask_crop,
                    fgd_mask=skel_obj_mask_crop,
                    debug=self.debug,
                )
                obj_mask_crop = grabinfo["grab_mask"]
            obj_mask_crops.append(obj_mask_crop)
        if len(obj_mask_crops):
            sample_objs_masks.append(np.stack(obj_mask_crops))
        else:
            sample_objs_masks.append(np.zeros((1, rad, rad)))
        # sample_masks.append(torch.stack(fit_info["masks"]))
        sample_verts.append(human_verts)
        sample_confs.append(verts_confs)
        sample_imgs.append(img)

    verts = torch.Tensor(np.stack(sample_verts))
    links = [preprocess_links(info["links"]) for info in fit_infos]
    fit_data = {
        # "masks": torch.stack(sample_masks),
        "roi_bboxes": torch.Tensor(np.stack(roi_bboxes)),
        "roi_valid_masks": torch.Tensor(np.stack(roi_valid_masks)),
        "objs_masks_crops": torch.Tensor(np.stack(sample_objs_masks)),
        "hand_masks_crops": torch.Tensor(np.stack(sample_hand_masks)),
        "verts": verts,
        "verts_confs": torch.Tensor(np.stack(sample_confs)),
        "imgs": sample_imgs,
        "ref_hand_rends": ref_hand_rends,
        "links": links,
        "mano_corresp": self.mano_corresp,
    }
    return fit_data
row_nb = args.batch_size
col_nb = 3
diffs = (rendres - img_th[:, :, :, :])[:, :, :, :3]
if args.loss_type == "l1":
    loss = diffs.abs().mean()
if args.loss_type == "l2":
    loss = (diffs ** 2).sum(-1).mean()
# loss = (rendres - img_th).abs().mean()
optimizer.zero_grad()
loss.backward()
print(loss)
optimizer.step()
if iter_idx % args.viz_step == 0:
    fig, axes = plt.subplots(row_nb, col_nb)
    for row_idx in range(row_nb):
        ax = vizmp.get_axis(
            axes, row_idx=row_idx, col_idx=0, row_nb=row_nb, col_nb=col_nb
        )
        ax.imshow(egoviz.imagify(rendres[row_idx], normalize_colors=False))
        ax = vizmp.get_axis(
            axes, row_idx=row_idx, col_idx=1, row_nb=row_nb, col_nb=col_nb
        )
        ax.imshow(
            egoviz.imagify(img_th[row_idx][:, :], normalize_colors=False)
        )
        ax = vizmp.get_axis(
            axes, row_idx=row_idx, col_idx=2, row_nb=row_nb, col_nb=col_nb
        )
        ax.imshow(npt.numpify(diffs[row_idx]))
    fig.savefig(f"tmp_{iter_idx:04d}.png", bbox_inches="tight")
def compute_obj_mask_loss(self, scene_outputs, supervision):
    rend = scene_outputs["segm_rend"]
    device = rend.device
    gt_obj_masks = (
        supervision["objs_masks_crops"].permute(0, 2, 3, 1).to(device))
    gt_hand_masks = supervision["hand_masks_crops"].to(device)
    if self.mask_mode == "segm":
        pred_masks = rend[:, :, :, :-1]  # Remove alpha channel
        gt_masks = torch.cat([gt_hand_masks.unsqueeze(-1), gt_obj_masks], -1)
        sup_masks = (((gt_hand_masks.unsqueeze(-1).sum([1, 2, 3]) > 0) &
                      (gt_obj_masks.sum([1, 2, 3]) > 0)).float().view(
                          -1, 1, 1, 1))
        optim_mask_diff = gt_masks - pred_masks
        optim_mask_diff = optim_mask_diff * sup_masks
        if self.loss_obj_mask == "l1":
            loss = optim_mask_diff.abs().mean()
        elif self.loss_obj_mask == "l2":
            loss = (optim_mask_diff**2).mean()
        elif self.loss_obj_mask == "adapt":
            loss = self.mask_adaptive_loss.lossfun(
                optim_mask_diff.view(gt_masks.shape[0], -1)).mean()
        masked_diffs = npt.numpify(gt_masks) - npt.numpify(pred_masks)
    elif self.mask_mode == "segmask":
        pred_masks = rend[:, :, :, :-1]  # Remove alpha channel
        gt_masks = torch.cat([gt_hand_masks.unsqueeze(-1), gt_obj_masks], -1)
        # Get region to penalize by computing complementary from gt masks
        comp_obj_idxs = [[
            idx for idx in range(self.obj_nb) if idx != obj_idx
        ] for obj_idx in range(self.obj_nb)]
        sup_mask = torch.cat(
            [
                1 - gt_masks[:, :, :, comp_idxs].sum(
                    -1, keepdim=True).clamp(0, 1)
                for comp_idxs in comp_obj_idxs
            ],
            -1,
        )
        sup_mask = ((gt_hand_masks.unsqueeze(-1).sum([1, 2, 3]) > 0) &
                    (gt_obj_masks.sum([1, 2, 3]) > 0)).float().view(
                        -1, 1, 1, 1) * sup_mask
        masked_diffs = sup_mask * (gt_masks - pred_masks)
        if self.loss_obj_mask == "l1":
            loss = masked_diffs.abs().sum() / sup_mask.sum()
        elif self.loss_obj_mask == "l2":
            loss = (masked_diffs**2).sum() / sup_mask.sum()
        elif self.loss_obj_mask == "adapt":
            loss = self.mask_adaptive_loss.lossfun(
                masked_diffs.view(sup_mask.shape[0], -1)).mean()
    elif self.mask_mode == "mask":
        pred_obj_masks = rend[:, :, :, 1:-1]
        obj_mask_diffs = gt_obj_masks[:, :, :, :] - pred_obj_masks
        if obj_mask_diffs.shape[-1] != 1:
            raise NotImplementedError("No handling of multiple objects")
        sup_mask = (1 - gt_hand_masks).unsqueeze(-1)
        # Zero supervision on frames which do not have both hand and masks
        sup_mask = ((gt_hand_masks.unsqueeze(-1).sum([1, 2, 3]) > 0) &
                    (gt_obj_masks.sum([1, 2, 3]) > 0)).float().view(
                        -1, 1, 1, 1) * sup_mask
        masked_diffs = sup_mask * obj_mask_diffs
        if self.loss_obj_mask == "l2":
            loss = (masked_diffs**2).sum() / sup_mask.sum()
        elif self.loss_obj_mask == "l1":
            loss = (masked_diffs.abs()).sum() / sup_mask.sum()
        elif self.loss_obj_mask == "adapt":
            loss = self.mask_adaptive_loss.lossfun(
                masked_diffs.view(sup_mask.shape[0], -1)).mean()
    return (loss, {"mask_diffs": masked_diffs})
def ego_viz(
    data,
    supervision,
    scene_outputs,
    loss_metas=None,
    fig_res=2,
    step_idx=0,
    save_folder="tmp",
    sample_nb=4,
):
    # segm_rend = npt.numpify(scene_outputs["segm_rend"])
    viz_rends = [
        npt.numpify(rend) for rend in scene_outputs["scene_viz_rend"]
    ]
    ref_hand_rends = supervision["ref_hand_rends"]
    col_nb = 3 + len(viz_rends)
    fig, axes = plt.subplots(
        sample_nb,
        col_nb,
        figsize=(int(3 / 2 * col_nb * fig_res), int(sample_nb * fig_res)),
    )
    scene_size = len(supervision["imgs"])
    sample_idxs = np.linspace(0, scene_size - 1, sample_nb).astype(int)
    mask_diffs = npt.numpify(loss_metas["mask_diffs"])
    for row_idx, sample_idx in enumerate(sample_idxs):
        img = supervision["imgs"][sample_idx][:, :, ::-1]
        # obj_mask = supervision["masks"][sample_idx]
        sample_data = data[sample_idx]

        # Column 1: image and supervision
        ax = vizmp.get_axis(axes,
                            row_idx=row_idx,
                            col_idx=0,
                            row_nb=sample_nb,
                            col_nb=col_nb)
        ax.imshow(img)
        ax.axis("off")
        add_hands(ax, sample_data)

        # Column 2: rendered prediction on top of GT hands
        ax = vizmp.get_axis(axes,
                            row_idx=row_idx,
                            col_idx=1,
                            row_nb=sample_nb,
                            col_nb=col_nb)
        ax.imshow(ref_hand_rends[sample_idx])
        ax.imshow(make_alpha(viz_rends[0][sample_idx][:, :, :3]), alpha=0.8)
        ax.axis("off")

        # Column 3: rendered mask difference
        ax = vizmp.get_axis(axes,
                            row_idx=row_idx,
                            col_idx=2,
                            row_nb=sample_nb,
                            col_nb=col_nb)
        mask_diff = mask_diffs[sample_idx]
        mask_diff_img = imagify(mask_diff)
        ax.imshow(mask_diff_img)
        ax.axis("off")

        # Columns 4+: rendered predictions
        for view_idx, viz_rend in enumerate(viz_rends):
            ax = vizmp.get_axis(
                axes,
                row_idx=row_idx,
                col_idx=3 + view_idx,
                row_nb=sample_nb,
                col_nb=col_nb,
            )
            if view_idx == 0:
                # Special treatment of first view, which is the camera view
                ax.imshow(img)
                alpha = 0.8
            else:
                alpha = 1
            ax.imshow(viz_rend[sample_idx, :, :, :3], alpha=alpha)
            ax.axis("off")
    os.makedirs(save_folder, exist_ok=True)
    save_path = os.path.join(save_folder, f"tmp_{step_idx:04d}.png")
    fig.suptitle(f"optim iter : {step_idx}")
    fig.subplots_adjust(top=1,
                        bottom=0,
                        right=1,
                        left=0,
                        hspace=0.1,
                        wspace=0)
    fig.gca().xaxis.set_major_locator(ticker.NullLocator())
    fig.gca().yaxis.set_major_locator(ticker.NullLocator())
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved to {save_path}")
    return save_path