Example #1
    def prepare_poses(self, data):
        # Ground-truth per-part poses as tensors on the target device.
        gt_part = part_model_batch_to_part(
            cvt_torch(data['meta']['nocs2camera'], self.device),
            self.num_parts, self.device)
        # Initialize by perturbing the ground-truth pose DoFs.
        init_part = add_noise_to_part_dof(gt_part, self.pose_perturb_cfg)
        # When a crop pose is available, its translation and scale replace
        # the perturbed values.
        if 'crop_pose' in data['meta']:
            crop_pose = part_model_batch_to_part(
                cvt_torch(data['meta']['crop_pose'], self.device),
                self.num_parts, self.device)
            for key in ['translation', 'scale']:
                init_part[key] = crop_pose[key]
        return gt_part, init_part
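
For reference, a minimal sketch of the part-pose dictionary that prepare_poses perturbs. The keys ('rotation', 'translation', 'scale') appear throughout the examples below; the shapes and the Gaussian jitter standing in for add_noise_to_part_dof are assumptions for illustration only.

import torch

# Hypothetical stand-in for add_noise_to_part_dof: jitter only the
# translation; the real helper perturbs the full pose DoFs per its config.
def jitter_translation(part, sigma=0.01):
    noisy = {key: value.clone() for key, value in part.items()}
    noisy['translation'] = (noisy['translation'] +
                            sigma * torch.randn_like(noisy['translation']))
    return noisy

B, P = 2, 1  # batch size and number of parts (illustrative)
part = {
    'rotation': torch.eye(3).expand(B, P, 3, 3).contiguous(),
    'translation': torch.zeros(B, P, 3, 1),
    'scale': torch.ones(B, P),
}
init_part = jitter_translation(part)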
Example #2
# Relies on numpy as `np` and on the project helpers cvt_torch,
# eval_part_full, eval_single_part_iou, and get_joint_state being in scope.
def eval_data(name, data, obj_info):
    # Move predictions and ground truth to the CPU for evaluation.
    poses = cvt_torch(data['pred']['poses'], 'cpu')
    corners = cvt_torch(data['pred']['corners'], 'cpu')
    gt_poses = cvt_torch(data['gt']['poses'], 'cpu')
    gt_corners = cvt_torch(data['gt']['corners'], 'cpu')

    error_dict = {}
    sym = obj_info['sym']
    rigid = obj_info['num_parts'] == 1

    for i in range(len(poses)):
        if i == 0:  # the first frame's pose is given by initialization
            continue
        key = f'{name}_{i}'
        _, per_diff = eval_part_full(gt_poses[i],
                                     poses[i],
                                     per_instance=True,
                                     yaxis_only=sym)
        error_dict[key] = {
            metric: float(value.numpy())
            for metric, value in per_diff.items()
        }
        _, per_iou = eval_single_part_iou(
            gt_corners.unsqueeze(0),
            corners[i].unsqueeze(0),
            {key: value.unsqueeze(0)
             for key, value in gt_poses[i].items()},
            {key: value.unsqueeze(0)
             for key, value in poses[i].items()},
            separate='both',
            nocs=rigid,
            sym=sym)
        per_iou = {
            f'iou_{j}': float(per_iou['iou'][j])
            for j in range(len(per_iou['iou']))
        }
        error_dict[key].update(per_iou)

        if not rigid:
            joint_state = get_joint_state(obj_info, poses[i])
            gt_joint_state = get_joint_state(obj_info, gt_poses[i])

            joint_diff = np.abs(joint_state - gt_joint_state)
            error_dict[key].update({
                f'theta_diff_{j}': joint_diff[j]
                for j in range(len(joint_diff))
            })

    return error_dict
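
Since eval_data returns one flat metric dictionary per frame, a natural follow-up is to average each metric across a sequence. A minimal sketch of such an aggregation (not part of the original code):

import numpy as np

def summarize_errors(error_dict):
    # Collect every per-frame value of each metric produced by eval_data,
    # then report the per-metric mean over all evaluated frames.
    totals = {}
    for frame_metrics in error_dict.values():
        for metric, value in frame_metrics.items():
            totals.setdefault(metric, []).append(value)
    return {metric: float(np.mean(values))
            for metric, values in totals.items()}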
Example #3
    def prepare_data(self, data):
        gt_part, init_part = self.prepare_poses(data)

        input = {
            'points': data['points'],
            'points_mean': data['meta']['points_mean'],
            'nocs': data['nocs'],
            'state': {
                'part': init_part
            },
            'gt_part': gt_part
        }

        input = cvt_torch(input, self.device)
        input['meta'] = data['meta']
        input['labels'] = data['labels'].long().to(self.device)

        part_pose = input['state']['part']
        canon_pose = {
            key: part_pose[key].reshape(
                (-1, ) + part_pose[key].shape[2:])  # [B, P, x] --> [B * P, x]
            for key in ['rotation', 'translation', 'scale']
        }

        input['canon_pose'] = canon_pose

        # Per-part delta between the initialization ('state') and the ground
        # truth, with canon_pose (reshaped back to [B, P, ...]) as reference.
        batch_size = len(input['gt_part']['scale'])
        part_delta = compute_parts_delta_pose(
            input['state']['part'], input['gt_part'], {
                key:
                value.reshape((batch_size, self.num_parts) + value.shape[1:])
                for key, value in canon_pose.items()
            })
        input['root_delta'] = part_delta
        self.feed_dict = input
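
The two reshapes in prepare_data are inverses: canon_pose flattens the batch and part dimensions ([B, P, x] -> [B * P, x]), and compute_parts_delta_pose receives it restored to [B, P, ...]. A minimal round-trip sketch with illustrative shapes:

import torch

B, P = 4, 2  # batch size and number of parts (illustrative)
rotation = torch.randn(B, P, 3, 3)

flat = rotation.reshape((-1,) + rotation.shape[2:])  # [B, P, 3, 3] -> [B * P, 3, 3]
restored = flat.reshape((B, P) + flat.shape[1:])     # [B * P, 3, 3] -> [B, P, 3, 3]
assert torch.equal(rotation, restored)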
Example #4
    def convert_subseq_frame_data(self, data):
        gt_part = part_model_batch_to_part(
            cvt_torch(data['meta']['nocs2camera'], self.device),
            self.num_parts, self.device)
        input = {
            'points': data['points'],
            'points_mean': data['meta']['points_mean'],
            'gt_part': gt_part
        }

        if 'nocs' in data:
            input['npcs'] = data['nocs']
        input = cvt_torch(input, self.device)
        input['meta'] = data['meta']
        if 'labels' in data:
            input['labels'] = data['labels'].long().to(self.device)
        return input
Example #5
    def convert_init_frame_data(self, frame):
        feed_frame = {}
        for key, item in frame.items():
            if key not in ['meta', 'labels', 'points', 'nocs']:
                continue
            if key == 'labels':
                item = item.long().to(self.device)
            elif key != 'meta':  # 'points' and 'nocs' become float tensors
                item = item.float().to(self.device)
            feed_frame[key] = item
        gt_part = part_model_batch_to_part(
            cvt_torch(frame['meta']['nocs2camera'], self.device),
            self.num_parts, self.device)
        feed_frame['gt_part'] = gt_part

        return feed_frame
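
Examples #4 and #5 share a dtype convention: 'labels' become long (class-index) tensors, point and coordinate tensors become float, and 'meta' passes through untouched. A sketch of an input frame that would flow through convert_init_frame_data; the shapes and meta contents are assumptions for illustration:

import torch

N = 1024  # number of points (illustrative)
frame = {
    'points': torch.randn(1, 3, N),   # cast to float32 on the device
    'nocs': torch.rand(1, 3, N),      # cast to float32 on the device
    'labels': torch.zeros(1, N),      # cast to int64 (long) on the device
    'meta': {'nocs2camera': ...},     # passed through untouched (placeholder)
    'extra': 'dropped',               # not in the whitelist, skipped
}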
Example #6
    def forward(self, save=False):
        self.timer.tick()

        pred_poses = []
        # Frame 0 is given by initialization: either the ground-truth pose
        # itself, or a perturbed copy whose translation and scale may be
        # overridden by the crop pose (mirroring prepare_poses above).
        gt_part = self.feed_dict[0]['gt_part']
        if self.gt_init:
            pred_poses.append(gt_part)
        else:
            part = add_noise_to_part_dof(gt_part, self.pose_perturb_cfg)
            if 'crop_pose' in self.feed_dict[0]['meta']:
                crop_pose = part_model_batch_to_part(
                    cvt_torch(self.feed_dict[0]['meta']['crop_pose'],
                              self.device), self.num_parts, self.device)
                for key in ['translation', 'scale']:
                    part[key] = crop_pose[key]
            pred_poses.append(part)

        self.timer.tick()

        time_dict = {'crop': 0.0, 'npcs_net': 0.0, 'rot_all': 0.0}

        frame_nums = []
        npcs_pred = []
        with torch.no_grad():
            for i, input in enumerate(self.feed_dict):
                frame_nums.append([
                    path.split('.')[-2].split('/')[-1]
                    for path in input['meta']['path']
                ])
                if i == 0:
                    npcs_pred.append(None)
                    continue
                perturbed_part = add_noise_to_part_dof(
                    self.feed_dict[i - 1]['gt_part'], self.pose_perturb_cfg)
                if 'crop_pose' in self.feed_dict[i]['meta']:
                    crop_pose = part_model_batch_to_part(
                        cvt_torch(self.feed_dict[i]['meta']['crop_pose'],
                                  self.device), self.num_parts, self.device)
                    for key in ['translation', 'scale']:
                        perturbed_part[key] = crop_pose[key]

                last_pose = {
                    key: value.clone()
                    for key, value in pred_poses[-1].items()
                }

                self.timer.tick()
                if self.nocs_otf:
                    # Re-crop the input around the last predicted pose.
                    center = last_pose['translation'].reshape(
                        3).detach().cpu().numpy()  # [3]
                    scale = last_pose['scale'].reshape(1).detach().cpu().item()
                    depth_path = input['meta']['ori_path'][0]
                    category, instance = input['meta']['path'][0].split('/')[-4:-2]
                    pre_fetched = input['meta']['pre_fetched']
                    pre_fetched = {
                        key: value.reshape(value.shape[1:])  # drop batch dim
                        for key, value in pre_fetched.items()
                    }

                    pose = {
                        key:
                        value.squeeze(0).squeeze(0).detach().cpu().numpy()
                        for key, value in input['gt_part'].items()
                    }
                    full_data = full_data_from_depth_image(
                        depth_path,
                        category,
                        instance,
                        center,
                        self.radius * scale,
                        pose,
                        num_points=input['points'].shape[-1],
                        device=self.device,
                        mask_from_nocs2d=self.track_cfg['nocs2d_label'],
                        nocs2d_path=self.track_cfg['nocs2d_path'],
                        pre_fetched=pre_fetched)

                    points, nocs, labels = (full_data['points'],
                                            full_data['nocs'],
                                            full_data['labels'])

                    points = cvt_torch(points, self.device)
                    points -= self.npcs_feed_dict[i]['points_mean'].reshape(1, 3)
                    input['points'] = points.transpose(-1, -2).reshape(1, 3, -1)
                    input['labels'] = torch.tensor(labels).to(
                        self.device).long().reshape(1, -1)
                    nocs = cvt_torch(nocs, self.device)
                    self.npcs_feed_dict[i]['points'] = input['points']
                    self.npcs_feed_dict[i]['labels'] = input['labels']
                    self.npcs_feed_dict[i]['nocs'] = nocs.transpose(
                        -1, -2).reshape(1, 3, -1)

                    time_dict['crop'] += self.timer.tick()

                state = {'part': last_pose}
                input['state'] = state

                npcs_canon_pose = {
                    key: last_pose[key][:, self.root].clone()
                    for key in ['rotation', 'translation', 'scale']
                }
                npcs_input = self.npcs_feed_dict[i]
                npcs_input['canon_pose'] = npcs_canon_pose
                npcs_input['init_part'] = last_pose
                cur_npcs_pred = self.npcs_net(
                    npcs_input)  # seg: [B, P, N], npcs: [B, P * 3, N]
                npcs_pred.append(cur_npcs_pred)
                pred_npcs, pred_seg = cur_npcs_pred['nocs'], cur_npcs_pred['seg']
                pred_npcs = pred_npcs.reshape(len(pred_npcs), self.num_parts,
                                              3, -1)  # [B, P, 3, N]
                pred_labels = torch.max(pred_seg,
                                        dim=-2)[1]  # [B, P, N] -> [B, N]

                time_dict['npcs_net'] += self.timer.tick()

                input['pred_labels'] = pred_labels
                input['pred_nocs'] = pred_npcs
                # pred_seg: [B, P, N]; part 0's scores give [B, N]
                input['pred_label_conf'] = pred_seg[:, 0]
                if self.track_cfg['gt_label'] or self.track_cfg['nocs2d_label']:
                    input['pred_labels'] = npcs_input['labels']

                pred_dict = self.net(input, test_mode=True)
                pred_poses.append(pred_dict['part'])

                time_dict['rot_all'] += self.timer.tick()

        self.pred_dict = {'poses': pred_poses, 'npcs_pred': npcs_pred}

        if save:
            gt_corners = self.feed_dict[0]['meta']['nocs_corners'].cpu().numpy()
            corner_list = []
            for i, pred_pose in enumerate(self.pred_dict['poses']):
                if i == 0:
                    corner_list.append(None)
                    continue
                pred_labels = torch.max(self.pred_dict['npcs_pred'][i]['seg'],
                                        dim=-2)[1]  # [B, P, N] -> [B, N]
                pred_nocs = choose_coord_by_label(
                    self.pred_dict['npcs_pred'][i]['nocs'].transpose(-1, -2),
                    pred_labels)
                pred_corners = get_pred_nocs_corners(pred_labels, pred_nocs,
                                                     self.num_parts)
                corner_list.append(pred_corners)

            gt_poses = [{
                key: value.detach().cpu().numpy()
                for key, value in frame['gt_part'].items()
            } for frame in self.feed_dict]
            save_dict = {
                'pred': {
                    'poses': [{
                        key: value.detach().cpu().numpy()
                        for key, value in pred_pose.items()
                    } for pred_pose in pred_poses],
                    'corners': corner_list
                },
                'gt': {
                    'poses': gt_poses,
                    'corners': gt_corners
                },
                'frame_nums': frame_nums
            }

            save_path = pjoin(self.cfg['experiment_dir'], 'results', 'data')
            ensure_dirs([save_path])
            for i, path in enumerate(self.feed_dict[0]['meta']['path']):
                instance, track_num = path.split('.')[-2].split('/')[-3:-1]
                with open(pjoin(save_path, f'{instance}_{track_num}.pkl'),
                          'wb') as f:
                    cur_dict = get_ith_from_batch(save_dict,
                                                  i,
                                                  to_single=False)
                    pickle.dump(cur_dict, f)
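
The pickles written here mirror the layout that eval_data in Example #2 reads, so saved results can be evaluated directly after loading. A minimal sketch; the file name and obj_info values are placeholders (real files follow the f'{instance}_{track_num}.pkl' pattern above, and obj_info comes from the dataset metadata):

import pickle
from os.path import join as pjoin

obj_info = {'sym': False, 'num_parts': 2}  # placeholder metadata
result_path = pjoin('experiment_dir', 'results', 'data', 'instance_0.pkl')

with open(result_path, 'rb') as f:
    data = pickle.load(f)  # save_dict layout: 'pred', 'gt', 'frame_nums'

errors = eval_data('instance_0', data, obj_info)  # eval_data from Example #2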