def prepare_poses(self, data):
    """Build the ground-truth part poses and a perturbed initialization."""
    gt_part = part_model_batch_to_part(
        cvt_torch(data['meta']['nocs2camera'], self.device),
        self.num_parts, self.device)
    init_part = add_noise_to_part_dof(gt_part, self.pose_perturb_cfg)
    if 'crop_pose' in data['meta']:
        # When a crop pose is available, take its translation/scale
        # instead of the perturbed ones.
        crop_pose = part_model_batch_to_part(
            cvt_torch(data['meta']['crop_pose'], self.device),
            self.num_parts, self.device)
        for key in ['translation', 'scale']:
            init_part[key] = crop_pose[key]
    return gt_part, init_part
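# Illustrative sketch (not part of the pipeline): the pose dicts built
# above hold batched per-part tensors under 'rotation', 'translation',
# and 'scale'. The exact trailing shapes produced by
# part_model_batch_to_part are assumptions here; this hypothetical
# helper is only meant for quick shape checks in tests.
import torch

def make_dummy_part_pose(batch_size, num_parts, device='cpu'):
    return {
        'rotation': torch.eye(3, device=device).expand(batch_size, num_parts, 3, 3),  # assumed [B, P, 3, 3]
        'translation': torch.zeros(batch_size, num_parts, 3, 1, device=device),       # assumed [B, P, 3, 1]
        'scale': torch.ones(batch_size, num_parts, device=device),                    # assumed [B, P]
    }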
def eval_data(name, data, obj_info):
    """Compute per-frame pose, IoU, and joint-state errors for one track."""
    poses = cvt_torch(data['pred']['poses'], 'cpu')
    corners = cvt_torch(data['pred']['corners'], 'cpu')
    gt_poses = cvt_torch(data['gt']['poses'], 'cpu')
    gt_corners = cvt_torch(data['gt']['corners'], 'cpu')
    error_dict = {}
    sym = obj_info['sym']
    rigid = obj_info['num_parts'] == 1
    for i in range(len(poses)):
        if i == 0:  # the first frame's pose is given by initialization
            continue
        frame_key = f'{name}_{i}'
        _, per_diff = eval_part_full(gt_poses[i], poses[i],
                                     per_instance=True, yaxis_only=sym)
        error_dict[frame_key] = {
            key: float(value.numpy()) for key, value in per_diff.items()
        }
        _, per_iou = eval_single_part_iou(
            gt_corners.unsqueeze(0), corners[i].unsqueeze(0),
            {key: value.unsqueeze(0) for key, value in gt_poses[i].items()},
            {key: value.unsqueeze(0) for key, value in poses[i].items()},
            separate='both', nocs=rigid, sym=sym)
        per_iou = {
            f'iou_{j}': float(per_iou['iou'][j])
            for j in range(len(per_iou['iou']))
        }
        error_dict[frame_key].update(per_iou)
        if not rigid:
            # Articulated objects: also measure the joint-state error.
            joint_state = get_joint_state(obj_info, poses[i])
            gt_joint_state = get_joint_state(obj_info, gt_poses[i])
            joint_diff = np.abs(joint_state - gt_joint_state)
            error_dict[frame_key].update({
                f'theta_diff_{j}': joint_diff[j]
                for j in range(len(joint_diff))
            })
    return error_dict
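# Illustrative sketch: eval_data returns a flat dict keyed by
# f'{name}_{frame_idx}', each entry mapping metric names (the
# per-instance diffs from eval_part_full, plus 'iou_{part}' and, for
# articulated objects, 'theta_diff_{joint}') to plain floats. Assuming
# that layout, per-metric averages over a track could be computed as:
from collections import defaultdict
import numpy as np

def average_errors(error_dict):
    per_metric = defaultdict(list)
    for per_frame in error_dict.values():
        for metric, value in per_frame.items():
            per_metric[metric].append(value)
    return {metric: float(np.mean(values))
            for metric, values in per_metric.items()}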
def prepare_data(self, data):
    """Assemble the feed dict: observed points, initial poses, and the
    per-part delta from the initialization to the ground truth."""
    gt_part, init_part = self.prepare_poses(data)
    input = {
        'points': data['points'],
        'points_mean': data['meta']['points_mean'],
        'nocs': data['nocs'],
        'state': {'part': init_part},
        'gt_part': gt_part
    }
    input = cvt_torch(input, self.device)
    input['meta'] = data['meta']
    input['labels'] = data['labels'].long().to(self.device)
    part_pose = input['state']['part']
    canon_pose = {
        # [B, P, x] --> [B * P, x]
        key: part_pose[key].reshape((-1, ) + part_pose[key].shape[2:])
        for key in ['rotation', 'translation', 'scale']
    }
    input['canon_pose'] = canon_pose
    batch_size = len(input['gt_part']['scale'])
    part_delta = compute_parts_delta_pose(
        input['state']['part'], input['gt_part'],
        {key: value.reshape((batch_size, self.num_parts) + value.shape[1:])
         for key, value in canon_pose.items()})
    input['root_delta'] = part_delta
    self.feed_dict = input
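# Illustrative sketch of the [B, P, x] --> [B * P, x] flattening used
# for canon_pose above, with an assumed [B, P, 3, 3] rotation tensor;
# the same reshape pair round-trips back to the per-part layout:
import torch

B, P = 2, 3
rotation = torch.eye(3).expand(B, P, 3, 3)            # assumed [B, P, 3, 3]
flat = rotation.reshape((-1, ) + rotation.shape[2:])  # [B * P, 3, 3]
restored = flat.reshape((B, P) + flat.shape[1:])      # back to [B, P, 3, 3]
assert torch.equal(restored, rotation)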
def convert_subseq_frame_data(self, data):
    gt_part = part_model_batch_to_part(
        cvt_torch(data['meta']['nocs2camera'], self.device),
        self.num_parts, self.cfg['device'])
    input = {
        'points': data['points'],
        'points_mean': data['meta']['points_mean'],
        'gt_part': gt_part
    }
    if 'nocs' in data:
        input['npcs'] = data['nocs']
    input = cvt_torch(input, self.device)
    input['meta'] = data['meta']
    if 'labels' in data:
        input['labels'] = data['labels'].long().to(self.device)
    return input
def convert_init_frame_data(self, frame):
    feed_frame = {}
    for key, item in frame.items():
        if key not in ['meta', 'labels', 'points', 'nocs']:
            continue
        if key in ['meta']:
            pass  # meta stays on the CPU as-is
        elif key in ['labels']:
            item = item.long().to(self.device)
        else:
            item = item.float().to(self.device)
        feed_frame[key] = item
    gt_part = part_model_batch_to_part(
        cvt_torch(frame['meta']['nocs2camera'], self.device),
        self.num_parts, self.cfg['device'])
    feed_frame.update({'gt_part': gt_part})
    return feed_frame
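# Illustrative sketch of the casting convention shared by the two
# converters above (a simplified restatement, not a replacement):
# 'meta' stays untouched on the CPU, 'labels' go to the device as
# long, and all other tensors as float.
import torch

def cast_like_converters(frame, device):
    out = {}
    for key, item in frame.items():
        if key == 'meta':
            out[key] = item
        elif key == 'labels':
            out[key] = item.long().to(device)
        elif torch.is_tensor(item):
            out[key] = item.float().to(device)
    return out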
def forward(self, save=False):
    self.timer.tick()
    pred_poses = []
    gt_part = self.feed_dict[0]['gt_part']
    if self.gt_init:
        pred_poses.append(gt_part)
    else:
        # Initialize from a perturbed GT pose; override translation/scale
        # with the crop pose when one is available.
        part = add_noise_to_part_dof(gt_part, self.pose_perturb_cfg)
        if 'crop_pose' in self.feed_dict[0]['meta']:
            crop_pose = part_model_batch_to_part(
                cvt_torch(self.feed_dict[0]['meta']['crop_pose'],
                          self.device), self.num_parts, self.device)
            for key in ['translation', 'scale']:
                part[key] = crop_pose[key]
        pred_poses.append(part)
    self.timer.tick()
    time_dict = {'crop': 0.0, 'npcs_net': 0.0, 'rot_all': 0.0}
    frame_nums = []
    npcs_pred = []
    with torch.no_grad():
        for i, input in enumerate(self.feed_dict):
            frame_nums.append([
                path.split('.')[-2].split('/')[-1]
                for path in input['meta']['path']
            ])
            if i == 0:  # the first frame keeps its initialization
                npcs_pred.append(None)
                continue
            perturbed_part = add_noise_to_part_dof(
                self.feed_dict[i - 1]['gt_part'], self.pose_perturb_cfg)
            if 'crop_pose' in self.feed_dict[i]['meta']:
                crop_pose = part_model_batch_to_part(
                    cvt_torch(self.feed_dict[i]['meta']['crop_pose'],
                              self.device), self.num_parts, self.device)
                for key in ['translation', 'scale']:
                    perturbed_part[key] = crop_pose[key]
            last_pose = {
                key: value.clone() for key, value in pred_poses[-1].items()
            }
            self.timer.tick()
            if self.nocs_otf:
                # Re-crop the depth image on the fly around the last
                # predicted pose, then rebuild points/nocs/labels.
                center = last_pose['translation'].reshape(
                    3).detach().cpu().numpy()  # [3]
                scale = last_pose['scale'].reshape(1).detach().cpu().item()
                depth_path = input['meta']['ori_path'][0]
                category, instance = input['meta']['path'][0].split(
                    '/')[-4:-2]
                pre_fetched = input['meta']['pre_fetched']
                pre_fetched = {
                    key: value.reshape(value.shape[1:])
                    for key, value in pre_fetched.items()
                }
                pose = {
                    key: value.squeeze(0).squeeze(0).detach().cpu().numpy()
                    for key, value in input['gt_part'].items()
                }
                full_data = full_data_from_depth_image(
                    depth_path, category, instance, center,
                    self.radius * scale, pose,
                    num_points=input['points'].shape[-1],
                    device=self.device,
                    mask_from_nocs2d=self.track_cfg['nocs2d_label'],
                    nocs2d_path=self.track_cfg['nocs2d_path'],
                    pre_fetched=pre_fetched)
                points, nocs, labels = (full_data['points'],
                                        full_data['nocs'],
                                        full_data['labels'])
                points = cvt_torch(points, self.device)
                points -= self.npcs_feed_dict[i]['points_mean'].reshape(1, 3)
                input['points'] = points.transpose(-1, -2).reshape(1, 3, -1)
                input['labels'] = torch.tensor(labels).to(
                    self.device).long().reshape(1, -1)
                nocs = cvt_torch(nocs, self.device)
                self.npcs_feed_dict[i]['points'] = input['points']
                self.npcs_feed_dict[i]['labels'] = input['labels']
                self.npcs_feed_dict[i]['nocs'] = nocs.transpose(
                    -1, -2).reshape(1, 3, -1)
            time_dict['crop'] += self.timer.tick()
            state = {'part': last_pose}
            input['state'] = state
            npcs_canon_pose = {
                key: last_pose[key][:, self.root].clone()
                for key in ['rotation', 'translation', 'scale']
            }
            npcs_input = self.npcs_feed_dict[i]
            npcs_input['canon_pose'] = npcs_canon_pose
            npcs_input['init_part'] = last_pose
            cur_npcs_pred = self.npcs_net(
                npcs_input)  # seg: [B, P, N], npcs: [B, P * 3, N]
            npcs_pred.append(cur_npcs_pred)
            pred_npcs, pred_seg = cur_npcs_pred['nocs'], cur_npcs_pred['seg']
            pred_npcs = pred_npcs.reshape(len(pred_npcs), self.num_parts,
                                          3, -1)  # [B, P, 3, N]
            pred_labels = torch.max(pred_seg,
                                    dim=-2)[1]  # [B, P, N] -> [B, N]
            time_dict['npcs_net'] += self.timer.tick()
            input['pred_labels'], input['pred_nocs'] = pred_labels, pred_npcs
            input['pred_label_conf'] = pred_seg[:, 0]  # [B, P, N]
            if self.track_cfg['gt_label'] or self.track_cfg['nocs2d_label']:
                input['pred_labels'] = npcs_input['labels']
            pred_dict = self.net(input, test_mode=True)
            pred_poses.append(pred_dict['part'])
            time_dict['rot_all'] += self.timer.tick()
    self.pred_dict = {'poses': pred_poses, 'npcs_pred': npcs_pred}
    if save:
        gt_corners = self.feed_dict[0]['meta']['nocs_corners'].cpu().numpy()
        corner_list = []
        for i, pred_pose in enumerate(self.pred_dict['poses']):
            if i == 0:
                corner_list.append(None)
                continue
            pred_labels = torch.max(self.pred_dict['npcs_pred'][i]['seg'],
                                    dim=-2)[1]  # [B, P, N] -> [B, N]
            pred_nocs = choose_coord_by_label(
                self.pred_dict['npcs_pred'][i]['nocs'].transpose(-1, -2),
                pred_labels)
            pred_corners = get_pred_nocs_corners(pred_labels, pred_nocs,
                                                 self.num_parts)
            corner_list.append(pred_corners)
        gt_poses = [{
            key: value.detach().cpu().numpy()
            for key, value in frame['gt_part'].items()
        } for frame in self.feed_dict]
        save_dict = {
            'pred': {
                'poses': [{
                    key: value.detach().cpu().numpy()
                    for key, value in pred_pose.items()
                } for pred_pose in pred_poses],
                'corners': corner_list
            },
            'gt': {
                'poses': gt_poses,
                'corners': gt_corners
            },
            'frame_nums': frame_nums
        }
        save_path = pjoin(self.cfg['experiment_dir'], 'results', 'data')
        ensure_dirs([save_path])
        for i, path in enumerate(self.feed_dict[0]['meta']['path']):
            instance, track_num = path.split('.')[-2].split('/')[-3:-1]
            with open(pjoin(save_path, f'{instance}_{track_num}.pkl'),
                      'wb') as f:
                cur_dict = get_ith_from_batch(save_dict, i, to_single=False)
                pickle.dump(cur_dict, f)
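# Illustrative sketch (assuming the save layout written above): each
# results/data/{instance}_{track_num}.pkl holds one track's 'pred' and
# 'gt' poses/corners plus 'frame_nums', which matches the structure
# eval_data expects, so a saved track can be re-evaluated directly:
import pickle
from os.path import join as pjoin

def load_track_result(save_path, instance, track_num):
    with open(pjoin(save_path, f'{instance}_{track_num}.pkl'), 'rb') as f:
        return pickle.load(f)

# e.g.:
# data = load_track_result(save_path, instance, track_num)
# errors = eval_data(f'{instance}_{track_num}', data, obj_info)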