Exemple #1
0
def _dump_vis(model, pred, gt, common,
              video_list, K_list, dist_list, M_list):
    """ Writes one stitched multi-view image per sample showing prediction vs. ground truth.

        For sample i the frame id is the second element of common[i]. That frame is read from every video;
        pred[i] is drawn in red and gt[i] in green after projecting with the matching K / dist / M. The stitched
        result is written to os.path.join(dirname(video_list[0]), 'eval_vis_dump/%04d.png' % i).
    """
    from utils.plot_util import draw_skel
    from utils.StitchedImage import StitchedImage
    import utils.CamLib as cl
    from tqdm import tqdm

    def _render_view(K, dist, M, video, frame_id, xyz_pred, xyz_gt):
        # read the frame and overlay the prediction (red) first, then the ground truth (green) on top
        frame = read_vid_frame(video, frame_id)
        uv_pred = cl.project(cl.trafo_coords(xyz_pred, M), K, dist)
        canvas = draw_skel(frame.copy(), model, uv_pred, color_fixed='r', order='uv')
        uv_gt = cl.project(cl.trafo_coords(xyz_gt, M), K, dist)
        return draw_skel(canvas, model, uv_gt, color_fixed='g', order='uv')

    progress = tqdm(enumerate(common), desc='Dumping Samples', total=len(common))
    for sample_idx, (_, frame_id) in progress:
        panels = [_render_view(K, dist, M, video, frame_id, pred[sample_idx], gt[sample_idx])
                  for K, dist, M, video in zip(K_list, dist_list, M_list, video_list)]

        stitched = StitchedImage(panels)
        out_path = os.path.join(os.path.dirname(video_list[0]), 'eval_vis_dump/%04d.png' % sample_idx)
        my_mkdir(out_path, is_file=True)
        cv2.imwrite(out_path, stitched.image)
Exemple #2
0
def draw_trafo(img, M_trafo_w2l, M_cam_w2l, K_cam, dist_cam, linewidth=2, l=0.025):
    """ Draws a little coordinate frame into an image.

        Args:
            img: image that is drawn into with cv2.line.
            M_trafo_w2l: world-to-local transform of the frame to visualize; its inverse maps the axis end
                points from local into world space.
            M_cam_w2l: world-to-camera transform used to bring the world points into the camera.
            K_cam, dist_cam: camera intrinsics and distortion passed to cl.project.
            linewidth: line thickness in pixels.
            l: length of each axis, in units of the local frame.

        Returns:
            The image with the x / y / z axes drawn as lines from the frame origin, colored
            (0, 0, 255), (0, 255, 0) and (255, 0, 0) respectively (red / green / blue if img is BGR).
    """
    import utils.CamLib as cl

    M_trafo_l2w = np.linalg.inv(M_trafo_w2l)

    # points in local space we'd like to draw
    points_local = np.array([
        [0.0, 0.0, 0.0],  # origin
        [l, 0.0, 0.0],  # end x
        [0.0, l, 0.0],  # end y
        [0.0, 0.0, l]  # end z
    ])
    # trafo them to world space
    points_world = cl.trafo_coords(points_local, M_trafo_l2w)

    # transform points into image space
    p_cam = cl.trafo_coords(points_world, M_cam_w2l)
    p_uv = cl.project(p_cam, K_cam, dist_cam)
    p_uv = np.round(p_uv).astype(np.int32)

    # draw one line per axis, starting at the origin (index 0)
    bones, colors = [[0, 1], [0, 2], [0, 3]], [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
    for (pid, cid), c in zip(bones, colors):
        # cv2.line takes points as tuples of Python ints; numpy integer scalars may be rejected by some
        # OpenCV bindings, so convert explicitly
        pt_start = (int(p_uv[pid, 0]), int(p_uv[pid, 1]))
        pt_end = (int(p_uv[cid, 0]), int(p_uv[cid, 1]))
        img = cv2.line(img, pt_start, pt_end, c, thickness=linewidth)
    return img
Exemple #3
0
    def postfunc_save_samples(self, trainer):
        """ Renders a few fetched samples (prediction next to ground truth) into the summary writer.

            For each fetched sample the predicted and ground-truth xyz are projected with its intrinsics and drawn
            onto separate copies of the image, which are concatenated side by side (prediction left). At most
            trainer.config.save_sample_num such panels are stacked, run through self.merged_vis_sum and written
            at trainer.global_step_v.
        """
        from utils.plot_util import draw_hand
        import utils.CamLib as cl
        import numpy as np

        fetched = trainer.fetches_v
        sample_iter = zip(fetched[data_t.image], fetched[data_t.K],
                          fetched[data_t.pred_xyz], fetched[data_t.xyz])

        panels = []
        for image, intrinsics, pred_xyz, gt_xyz in sample_iter:
            uv_gt = cl.project(gt_xyz, intrinsics)
            uv_pred = cl.project(pred_xyz, intrinsics)

            # the (x + 1) / 2 mapping assumes network images in [-1, 1]; convert to 8 bit
            rgb = ((image + 1.0) / 2.0 * 255).round().astype(np.uint8)
            side_by_side = [draw_hand(rgb.copy(), uv, order='uv') for uv in (uv_pred, uv_gt)]
            panels.append(np.concatenate(side_by_side, 1))

            if len(panels) == trainer.config.save_sample_num:
                break

        stacked = np.stack(panels)
        summary_v = trainer.session.run(self.merged_vis_sum, {self.merged_vis: stacked})
        trainer.summary_writer.add_summary(summary_v, trainer.global_step_v)
        trainer.summary_writer.flush()
        print('Saved some samples.')
Exemple #4
0
def _calc_board_stats(points2d_obs,
                      model_point3d_coord_obs,
                      K,
                      dist,
                      img=None):
    """ Estimates the board pose with PnP and reports viewing-angle, depth and reprojection statistics.

        Args:
            points2d_obs: Nx2 observed image points.
            model_point3d_coord_obs: Nx3 corresponding points in board (model) coordinates.
            K: camera intrinsic matrix.
            dist: distortion coefficients (passed to cv2.solvePnP and cl.project).
            img: optional image; if given, the observations (green circles) and the reprojections (red crosses)
                are shown with matplotlib (the image is flipped from BGR to RGB for display).

        Returns:
            angle: angle in degrees between the board normal (model z-axis) and the camera's optical (z) axis.
            depth: z coordinate of every model point in the camera frame (p3d[:, -1]).
            err: per-point L2 reprojection error.
    """
    success, r_rel, t_rel = cv2.solvePnP(np.expand_dims(model_point3d_coord_obs, 1),
                                         np.expand_dims(points2d_obs, 1),
                                         K,
                                         distCoeffs=dist,
                                         flags=cv2.SOLVEPNP_ITERATIVE)
    # NOTE(review): `success` is not checked; a failed solve yields meaningless statistics.

    # board normal (model z-axis) expressed in camera coordinates
    R, _ = cv2.Rodrigues(r_rel)
    n = np.matmul(R, np.array([0.0, 0.0, 1.0]))
    # cos(angle) to the optical axis is dot(n, e_z) = n[2]; the sum of all components of n is not a cosine
    # once the board is tilted
    cos_angle = np.clip(n[2], -1.0, 1.0)
    angle = np.arccos(cos_angle) * 180.0 / np.pi

    # calculate points in camera frame (4x4 homogeneous transform from [R | t])
    M_w2c = np.concatenate([R, t_rel], 1)
    M_w2c = np.concatenate([M_w2c, np.array([[0.0, 0.0, 0.0, 1.0]])], 0)
    p3d = cl.trafo_coords(model_point3d_coord_obs, M_w2c)

    # reprojection error
    p2d_p = cl.project(p3d, K, dist)
    err = np.linalg.norm(p2d_p - points2d_obs, 2, -1)

    if img is not None:
        import matplotlib.pyplot as plt
        plt.imshow(img[:, :, ::-1])
        plt.plot(points2d_obs[:, 0], points2d_obs[:, 1], 'go')
        plt.plot(p2d_p[:, 0], p2d_p[:, 1], 'rx')
        plt.show()
    return angle, p3d[:, -1], err
Exemple #5
0
    def update_frame_view_points(self, use_labels=True):
        """ Redraws the keypoints of the current frame in every camera view.

            The frame view is always cleared. If the current key is in self.keys_valid, the first entry of the
            predicted 'kp_xyz' is used, or the label task's 'kp_xyz' when the key has a label task and use_labels
            is set. The points are projected into each camera, shown as keypoint items, and all keypoint items
            are made non-movable.
        """
        self.frame_view.clear()  # always start from an empty frame view
        key = self.keys[self.key_id]
        if key not in self.keys_valid:
            return  # nothing to draw without a valid annotation

        xyz = np.array(self.predictions[key]['kp_xyz'])[0]
        if key in self.label_tasks.keys() and use_labels:
            # label task results take precedence over the raw prediction
            xyz = np.array(self.label_tasks[key]['kp_xyz'])

        for view_idx, _ in enumerate(self.cam_range):
            xyz_cam = cl.trafo_coords(xyz, self.M_list[view_idx])
            kp_uv = cl.project(xyz_cam, self.K_list[view_idx], self.dist_list[view_idx])
            for kp_id, uv in enumerate(kp_uv):
                self.frame_view.update_frame_keypoint(view_idx,
                                                      self.config['keypoints'][kp_id],
                                                      uv,
                                                      is_scene_pos=False)

        # displayed points are read-only here
        for item in self.frame_view.frame_keypoints.values():
            item.setFlag(QGraphicsItem.ItemIsMovable, False)
Exemple #6
0
def augment_warp(meta, ang_range=15.0, t_f=5, s_f=0.2):
    """ Randomly rotates, zooms and shifts every camera image of a sample and updates its geometry to match.

        For every camera i:
          1. The image is rotated about the principal point K_list[i][:2, 2] by an angle drawn uniformly from
             [-ang_range / 2, ang_range / 2) degrees (cv2.getRotationMatrix2D). M_list[i] is pre-multiplied by
             the inverse of the matching rotation about the camera z-axis. uv[i] / uv_merged[i] are updated
             with _compensate2D.
          2. The image is scaled by s in [1, 1 + s_f) ('zoom in' only, offset so the image center stays fixed)
             and translated by t drawn uniformly from [-t_f, t_f) pixels per axis. K_list[i], uv[i] and
             uv_merged[i] are transformed with the same 3x3 affine matrix.
        Uncovered border pixels are filled with 128. meta is modified in place and returned.

        Args:
            meta: sample with per-camera img_list (float32 images), K_list, M_list, uv and uv_merged.
            ang_range: total width of the rotation range in degrees (default 15.0, i.e. +/- 7.5 degrees).
            t_f: maximum translation in pixels, plus/minus (default 5).
            s_f: maximum relative zoom added to a scale of 1.0 (default 0.2).
    """
    for i in range(len(meta.K_list)):
        assert meta.img_list[
            i].dtype == np.float32, 'Assumed datatype mismatch.'

        ## ROTATE AROUND PP
        angle = np.random.rand() * ang_range - ang_range / 2.0
        pp = meta.K_list[i][:2, 2]
        M = cv2.getRotationMatrix2D((pp[0], pp[1]), angle=angle, scale=1.0)
        meta.img_list[i] = cv2.warpAffine(meta.img_list[i],
                                          M,
                                          meta.img_list[i].shape[:2][::-1],
                                          borderValue=(128, 128, 128))

        # compensate 3D: rotation about the camera z-axis by the same angle
        angle *= np.pi / 180.0
        c, s = np.cos(angle), np.sin(angle)
        M_rot = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
        M_rot4 = np.eye(4)
        M_rot4[:3, :3] = M_rot
        meta.M_list[i] = np.matmul(np.linalg.inv(M_rot4), meta.M_list[i])

        # compensate 2D
        meta.uv[i] = _compensate2D(meta.uv[i], M_rot, pp)
        meta.uv_merged[i] = _compensate2D(meta.uv_merged[i], M_rot, pp)

        ## SCALE AND CROP
        # sample warp params
        s = np.random.rand() * s_f + 1.0  # scale, only 'zoom in'
        t = np.random.rand(2) * 2 * t_f - t_f  # translation, plus/minus t_f pixels

        # how much the image center is translated by scaling (compensate so the central pixel stays the same)
        off = 0.5 * (s - 1.0) * np.array(meta.img_list[i].shape[:2][::-1])

        M = np.array([[s, 0.0, t[0] - off[0]], [0.0, s, t[1] - off[1]],
                      [0.0, 0.0, 1.0]])
        meta.img_list[i] = cv2.warpAffine(meta.img_list[i],
                                          M[:2, :],
                                          meta.img_list[i].shape[:2][::-1],
                                          flags=cv2.INTER_LINEAR,
                                          borderValue=(128, 128, 128))

        # the same pixel-space affine applies to the intrinsics and the 2D labels
        meta.K_list[i] = np.matmul(M, meta.K_list[i])
        meta.uv[i] = cl.trafo_coords(meta.uv[i], M)
        meta.uv_merged[i] = cl.trafo_coords(meta.uv_merged[i], M)
    return meta
Exemple #7
0
def _trafo2local(kp_xyz):
    """ Transforms global keypoints into a rat local coordinate frame.

        The rat local system is spanned by:
            - x: Animal right  (perpendicular to ground plane normal and body axis)
            - y: The body axis (defined by the vector from tail to a point between the ears)
            - z: Animal up (perpendicular to x and y)
        And located in the point midway between the two ear keypoints.

        Keypoints 0 and 5 are used as the ears and keypoint 11 as the tail. The world 'up' direction is
        fixed to (0, -1, 0). The frame is degenerate (NaN) if the body axis is parallel to that direction.

        Args:
            kp_xyz: keypoints with 3D rows; at least 12 rows (indices 0, 5 and 11 are read).
        Returns:
            kp_xyz expressed in the local frame (via cl.trafo_coords).
    """
    mid_pt = 0.5 * (kp_xyz[5] + kp_xyz[0])  # point between ears
    body_axis = mid_pt - kp_xyz[11]  # vector from tail to mid ears, 'animal forward'
    body_axis /= np.linalg.norm(body_axis, 2, -1, keepdims=True)

    ground_up = np.array([0.0, -1.0, 0.0])  # vector pointing up
    ground_up /= np.linalg.norm(ground_up, 2)

    # pointing into the animals' right direction. The cross product of two unit vectors only has unit length
    # if they are perpendicular, so normalize to keep R a pure rotation (no scaling of local coordinates).
    animal_right = np.cross(body_axis, ground_up)
    animal_right /= np.linalg.norm(animal_right, 2, -1, keepdims=True)
    # unit length: animal_right and body_axis are perpendicular unit vectors
    animal_up = np.cross(animal_right, body_axis)

    R = np.stack([animal_right, body_axis, animal_up], 0)  # rotation matrix
    M = np.eye(4)
    M[:3, :3] = R
    M[:3, -1:] = -np.matmul(R, np.reshape(mid_pt, [3, 1]))  # trans: mid_pt maps to the origin
    kp_xyz_local = cl.trafo_coords(kp_xyz, M)
    return kp_xyz_local
Exemple #8
0
def trafo_to_coord_frame(coord_def, xyz):
    """ Calculate the transformation matrix and express xyz in the resulting local frame.

        coord_def['orientation'] must name exactly two of the axes 'x', 'y', 'z'. Each is resolved to a
        direction with _get_vec and normalized with _norm; the missing axis is the _my_cross of the two.
        orientation['prio'] picks which given axis is kept unchanged. The other given axis is recomputed as the
        _my_cross of the remaining two. coord_def['origin'] is resolved with _get_origin, and the frame is built
        with _to_mat.

        Raises:
            NotImplementedError: if 'prio' is not one of the two given axes.
    """
    orientation = coord_def['orientation']
    given = [axis for axis in ('x', 'y', 'z') if axis in orientation.keys()]
    assert len(given) == 2, 'You get to chose exactly two axes, no more, no less.'
    if len(given) != 2:
        # only reachable when assertions are disabled
        raise NotImplementedError('This should never happen.')

    # directions for the two given axes, evaluated in x -> y -> z order
    axes = {axis: _norm(_get_vec(orientation[axis], xyz)) for axis in given}
    missing = next(axis for axis in ('x', 'y', 'z') if axis not in axes)
    axes[missing] = _my_cross(**{axis: axes[axis] for axis in given})

    prio = orientation['prio']
    if prio not in given:
        raise NotImplementedError('This should never happen.')
    # re-derive the non-priority given axis from the other two so the frame is orthogonal
    redo = given[1] if prio == given[0] else given[0]
    axes[redo] = _my_cross(**{axis: axes[axis] for axis in ('x', 'y', 'z') if axis != redo})

    origin = _get_origin(coord_def['origin'], xyz)
    M = _to_mat(axes['x'], axes['y'], axes['z'], origin)
    return cl.trafo_coords(xyz, M)
Exemple #9
0
def read_img(metas):
    """ Loads the images and calibration of every sample meta and adapts them to the crop.

        For each meta this sets:
            img_list: float32 images read with cv2.imread from meta.img_paths.
            K_list / M_list: 'K' and 'M' of every camera in meta.cam_range from _get_calib; K is compensated
                for meta.scales / meta.offsets.
            uv_merged: crop-compensated version of the existing meta.uv_merged.
            uv: projection of meta.xyz into every camera (without distortion; distortion is not loaded).
        metas is modified in place and returned.

        Raises:
            IOError: if an image cannot be read (cv2.imread returned None).
    """
    for meta in metas:
        # report each path that is actually missing
        for path in meta.img_paths:
            if not os.path.exists(path):
                print('Path not found: ', path)

        # read all images; cv2.imread returns None instead of raising on failure
        img_list = list()
        for path in meta.img_paths:
            img = cv2.imread(path)
            if img is None:
                raise IOError('Could not read image: %s' % path)
            img_list.append(img.astype(np.float32))
        meta.img_list = img_list

        # read calibration
        calib = _get_calib(meta.calib_id, meta.calib_path)
        meta.K_list = [np.array(calib[cam]['K']) for cam in meta.cam_range]
        meta.M_list = [np.array(calib[cam]['M']) for cam in meta.cam_range]

        # compensate for crop
        meta.K_list = [
            compensate_crop_K(K, s, o)
            for K, s, o in zip(meta.K_list, meta.scales, meta.offsets)
        ]

        # compensate 2D coordinates for crop
        meta.uv_merged = [
            compensate_crop_coords(pts, s, o)
            for pts, s, o in zip(meta.uv_merged, meta.scales, meta.offsets)
        ]

        # project the 3D labels into every (cropped) camera
        meta.uv = [
            cl.project(cl.trafo_coords(meta.xyz, M), K)
            for K, M in zip(meta.K_list, meta.M_list)
        ]
    return metas
Exemple #10
0
def _calc_reprojection_error(cam_intrinsic, cam_dist, cam_extrinsic,
                             coord2d_obs, coord3d):
    """ Calculates the reprojection error for 2D / 3D point correspondence(s) and one camera calib.

        coord3d (shape (3,) or Nx3) is mapped into the camera with the inverse of the 4x4 cam_extrinsic, so
        cam_extrinsic is treated as a camera-to-world transform. The points are projected with cam_intrinsic,
        distorted with cl.distort_points(..., cam_intrinsic, cam_dist), and compared to coord2d_obs
        (shape (2,) or Nx2).

        Returns:
            delta_error: L2 norm over all residual components (the pixel distance for a single point).
            coord2d: Nx2 distorted projections.
    """
    if len(coord3d.shape) == 1:
        coord3d = np.expand_dims(coord3d, 0)
    if len(coord2d_obs.shape) == 1:
        coord2d_obs = np.expand_dims(coord2d_obs, 0)

    # transform into this cams frame; one homogeneous 1 per point so that N > 1 points work as well
    coord3d_h = np.concatenate([coord3d, np.ones((coord3d.shape[0], 1))], -1)
    coord3d_cam = np.matmul(coord3d_h,
                            np.transpose(np.linalg.inv(cam_extrinsic)))
    coord3d_cam = coord3d_cam[:, :3] / coord3d_cam[:, -1:]

    # calculate projection of 3D point
    coord2d = np.matmul(coord3d_cam, np.transpose(cam_intrinsic))
    coord2d = coord2d[:, :2] / coord2d[:, -1:]

    # apply distortion to the projected point
    coord2d = cl.distort_points(coord2d, cam_intrinsic, cam_dist)

    # compare with the corresponding observation(s) of this cam
    delta_error = np.sqrt(np.sum(np.square(coord2d - coord2d_obs)))
    return delta_error, coord2d
def read_img(metas):
    """ Loads images and calibration for a batch of labeled and unlabeled sample metas.

        DatasetLabeledMeta: images are read from meta.img_paths. Calibration comes from _get_calib. K_list and
            uv_merged are compensated for the crop (meta.scales / meta.offsets), and uv is the projection of
            meta.xyz. is_supervised = 1.0.
        DatasetUnlabeledMeta: frames are read with read_random_video_frames(meta.video_set) and calibration
            comes from _get_calib_videos. The images are cropped to the boxes from _get_pred_bb, which also
            updates K_list. All label fields are filled with zero dummies; 12 keypoints are hard-coded.
            is_supervised = 0.0.

        Returns:
            metas (modified in place), or None as soon as an unlabeled meta has no valid voxel root; in that case
            the whole batch is dropped.

        Raises:
            IOError: if an image of a labeled meta cannot be read.
            NotImplementedError: for any other meta type (exact type match; subclasses are not accepted).
    """
    for meta in metas:
        if type(meta) is DatasetLabeledMeta:
            # report each path that is actually missing
            for path in meta.img_paths:
                if not os.path.exists(path):
                    print('Path not found: ', path)

            meta.is_supervised = 1.0

            # read all images; cv2.imread returns None instead of raising on failure
            img_list = list()
            for path in meta.img_paths:
                img = cv2.imread(path)
                if img is None:
                    raise IOError('Could not read image: %s' % path)
                img_list.append(img.astype(np.float32))
            meta.img_list = img_list

            # read calibration (distortion is not loaded)
            calib = _get_calib(meta.calib_id, meta.calib_path)
            meta.K_list = [np.array(calib[cam]['K']) for cam in meta.cam_range]
            meta.M_list = [np.array(calib[cam]['M']) for cam in meta.cam_range]

            # compensate for crop
            meta.K_list = [compensate_crop_K(K, s, o) for K, s, o in zip(meta.K_list, meta.scales, meta.offsets)]

            # compensate 2D coordinates for crop
            meta.uv_merged = [compensate_crop_coords(pts, s, o) for pts, s, o in zip(meta.uv_merged, meta.scales, meta.offsets)]

            meta.uv = [cl.project(cl.trafo_coords(meta.xyz, M), K) for K, M in zip(meta.K_list, meta.M_list)]

        elif type(meta) is DatasetUnlabeledMeta:

            meta.is_supervised = 0.0

            # read a video frames
            fid, meta.img_list = read_random_video_frames(meta.video_set)

            # read calibration
            calib = _get_calib_videos(meta.calib)
            meta.K_list = [np.array(calib['K'][cam]) for cam in meta.cam_range]
            meta.M_list = [np.array(calib['M'][cam]) for cam in meta.cam_range]

            # read bounding box
            boxes, meta.voxel_root = _get_pred_bb(meta.pred_bb_file, fid)

            if meta.voxel_root is None:
                print('Invalid root in unlabeled sequence, Skipping.')
                return None

            # crop images according to the boxes
            meta.img_list, meta.K_list = _crop_images(meta.img_list, boxes, meta.K_list)

            # create dummy data just so its available:
            meta.img_paths = meta.video_set
            meta.xyz = np.zeros((12, 3))
            meta.uv = np.zeros((len(meta.cam_range), 12, 2))
            # separate array: uv and uv_merged are updated row by row later (e.g. during augmentation),
            # sharing one array would apply every such update twice
            meta.uv_merged = meta.uv.copy()
            meta.vis = np.zeros((12, ))
            meta.vis_merged = np.zeros((len(meta.cam_range), 12))

        else:
            raise NotImplementedError()

    return metas
Exemple #12
0
                    # if not detected use full image
                    this_box = [0.0, 1.0, 0.0, 1.0]
                box_scaled = np.array([
                    this_box[0] * w, this_box[1] * w, this_box[2] * h,
                    this_box[3] * h
                ])
                this_img = draw_bb(this_img,
                                   box_scaled,
                                   mode='lrtb',
                                   color='g',
                                   linewidth=2)

            if args.draw_root:
                # draw voxel root
                root_uv = cl.project(
                    cl.trafo_coords(np.array(predictions[idx]['xyz']),
                                    M_list[i]), K_list[i])
                this_img = cv2.circle(this_img,
                                      (int(root_uv[0, 0]), int(root_uv[0, 1])),
                                      radius=5,
                                      color=(0, 255, 255),
                                      thickness=-1)

            # draw keypoints
            if 'kp_xyz' in predictions[idx].keys():
                uv_proj = cl.project(
                    cl.trafo_coords(np.array(predictions[idx]['kp_xyz'][0]),
                                    M_list[i]), K_list[i])
                this_img = draw_skel(this_img,
                                     model,
                                     uv_proj,
Exemple #13
0
    def btn_write(self):
        """ Exports the selected frames and their 3D annotations as a new labeling task directory.

            The current label state is saved first. For every key in self.file_list_sel_full_keys the
            annotation is a copy of self.label_tasks[key], extended by per-camera 'cam%d' entries holding the
            projected 'kp_uv' and 'vis'. Frames without a label task get an all-zero placeholder instead. Frames
            are written to '<base_path>/cam%d/%08d.png', where base_path is the first non-existing
            os.path.join(dirname(self.video_list[0]), self.output_task_dir % i). The annotations are written to
            '<base_path>/anno.json'.
        """
        self.save_label_state()  # save current annotation

        num_kp = len(self.config['keypoints'])

        # assemble all info we want to write to disk
        output_data = dict()
        for k in self.file_list_sel_full_keys:
            fid = int(k)
            if k not in self.label_tasks.keys():
                # no label task for this frame: fresh all-zero placeholder
                output_data[fid] = {
                    'kp_xyz': np.zeros((num_kp, 3)),
                    'vis3d': np.zeros((num_kp, ))
                }
                continue

            # shallow copy, so the per-camera entries added below do not end up in self.label_tasks
            anno = dict(self.label_tasks[k])
            xyz = anno['kp_xyz']
            for i, cid in enumerate(self.cam_range):
                # project into frame
                kp_uv = cl.project(cl.trafo_coords(xyz, self.M_list[i]),
                                   self.K_list[i], self.dist_list[i])
                anno['cam%d' % cid] = {
                    'kp_uv': kp_uv,
                    'vis': anno['vis3d']
                }
            output_data[fid] = anno

        self.pb_start(len(output_data))

        # figure out base path: first task directory index that does not exist yet
        i = 0
        while True:
            base_path = os.path.join(os.path.dirname(self.video_list[0]),
                                     self.output_task_dir % i)
            if not os.path.exists(base_path):
                break
            i += 1

        # dump frames
        for fid in output_data:
            img_list, _, _, _ = self.precacher.get_data(fid)

            # write image frames
            for cid, img in zip(self.cam_range, img_list):
                output_path = os.path.join(base_path, 'cam%d' % cid,
                                           '%08d.png' % fid)
                my_mkdir(output_path, is_file=True)
                cv2.imwrite(output_path, img)
            self.pb_update()

        self.pb_finish()

        # dump anno
        anno_out_path = os.path.join(base_path, 'anno.json')
        my_mkdir(anno_out_path, is_file=True)
        json_dump(anno_out_path,
                  {'%08d.png' % k: v
                   for k, v in output_data.items()},
                  verbose=True)
Exemple #14
0
        fetches = [
            xyz_tf,
            uv_tf,
            pred_tensors[data_t.pred_score3d][-1]
        ]
        fetches_v = sess.run(fetches, feed_dict=feeds)

        # save predictions
        xyz_pred, uv_pred, score_pred = fetches_v
        predictions[idx]['kp_xyz'] = xyz_pred
        predictions[idx]['kp_score'] = score_pred

        if args.show:
            img_list = list()
            for i, (this_img, this_box, this_uv) in enumerate(zip(imgs, boxes, uv_pred)):
                uv_proj = cl.project(cl.trafo_coords(xyz_pred[0], M_list[i]), K_list[i])
                h, w = this_img.shape[:2]
                if np.all(np.abs(this_box) < 1e-4):
                    # if not detected use full image
                    this_box = [0.0, 1.0, 0.0, 1.0]
                box_scaled = np.array([this_box[0] * w, this_box[1] * w, this_box[2] * h, this_box[3] * h])
                this_img_box = draw_bb(this_img, box_scaled, mode='lrtb', color='g', linewidth=2)

                root_uv = cl.project(cl.trafo_coords(root, M_list[i]), K_list[i])
                this_img_box = cv2.circle(this_img_box,
                                          (int(root_uv[0, 0]), int(root_uv[0, 1])),
                                          radius=5,
                                          color=(0, 255, 255),
                                          thickness=-1)

                img_list.append(
Exemple #15
0
def triangulate_robust(kp_uv,
                       vis,
                       K_list,
                       M_list,
                       dist_list=None,
                       threshold=50.0,
                       mode=None):
    """ Given some 2D observations kp_uv and their respective validity this function finds 3D point hypotheses.
        kp_uv: NxKx2 2D points
        vis: NxK visibility (> 0.5 counts as visible)
        K_list Nx3x3 camera intrinsic
        dist_list Nx1x5 camera distortion (optional; if given it is used for triangulation and reprojection)
        M_list Nx4x4 camera extrinsic
        threshold: reprojection threshold passed to the triangulation; labels further than 2 * threshold from
            the reprojection are treated as outliers.
        mode: triangulation mode, defaults to t_triangulation.RANSAC.

        points3d: Kx3 3D point hypothesis
        points2d_proj: NxKx2 projection of points3d into the N cameras
        vis3d: K Validity of points3d (a keypoint needs >= 2 distinct visible cameras and >= 2 inliers)
        points2d_merged: NxKx2 Merged result of input observations and points2d_proj according to
                points2d_merged = points2d_proj if ~vis else kp_uv
            So it basically keeps the 2D annotations and uses the reprojection if there was no annotation,
            or if the annotation is an outlier w.r.t. the reprojection.
        vis2d_merged: NxK Validity of points2d_merged.

        Raises:
            ImportError: if the triangulation toolbox is not available.
    """
    global G_TRIANG_TOOL, G_TRIANG_TOOL_AVAIL
    if not G_TRIANG_TOOL_AVAIL:
        raise ImportError('Could not load the triangulation toolbox.')

    # create tool if necessary (module-level, created once)
    if G_TRIANG_TOOL is None:
        G_TRIANG_TOOL = TriangTool()

    if mode is None:
        mode = t_triangulation.RANSAC

    # output values
    num_cams, num_kp = kp_uv.shape[:2]
    points3d = np.zeros((num_kp, 3), dtype=np.float32)
    vis3d = np.zeros((num_kp, ), dtype=np.float32)
    points2d_proj = np.zeros((num_cams, num_kp, 2), dtype=np.float32)

    # merged result of projection and 2d annotation (uses 2d anno if avail, otherwise 3d proj)
    points2d_merged = kp_uv.copy()
    vis2d_merged = vis.copy()  # validity of points2d_merged

    # iterate over keypoints
    for kp_id in range(num_kp):
        # collect the visible observations of this keypoint and the cameras they come from
        points2d = list()
        cams = list()
        for cid in range(num_cams):
            if vis[cid, kp_id] > 0.5:
                points2d.append(kp_uv[cid, kp_id])
                cams.append(cid)

        if np.unique(cams).shape[0] < 2:
            continue  # at least two distinct views are needed

        # distortion must be selected for the same cameras as K and M
        dist_sel = None if dist_list is None else [dist_list[i] for i in cams]

        # find consistent 3D hypothesis
        point3d, inlier = G_TRIANG_TOOL.triangulate(
            [K_list[i] for i in cams], [M_list[i] for i in cams],
            np.expand_dims(np.array(points2d), 1),
            dist_list=dist_sel,
            mode=mode,
            threshold=threshold)
        if np.sum(inlier) < 2:
            continue

        points3d[kp_id] = point3d
        vis3d[kp_id] = 1.0
        for cid, (K, M) in enumerate(zip(K_list, M_list)):
            xyz_cam = cl.trafo_coords(point3d, M)
            if dist_list is None:
                points2d_proj[cid, kp_id] = cl.project(xyz_cam, K)
            else:
                # the observations live in the distorted image, so reproject with the same distortion
                points2d_proj[cid, kp_id] = cl.project(xyz_cam, K, dist_list[cid])

            is_outlier_label = np.linalg.norm(
                points2d_proj[cid, kp_id] -
                kp_uv[cid, kp_id]) >= 2 * threshold

            # fill in projection into merged, if this point was not visible before
            if vis2d_merged[cid, kp_id] < 0.5:
                # if the 2D point was not filled before set projection
                points2d_merged[cid, kp_id] = points2d_proj[cid, kp_id]
                vis2d_merged[cid, kp_id] = 1.0

            elif is_outlier_label:
                # some 2D labels are just weird outliers:
                # If the distance between a consistent 3D points proj and the label is too big we stick with the 3D pt
                points2d_merged[cid, kp_id] = points2d_proj[cid, kp_id]
                vis2d_merged[cid, kp_id] = 1.0

    return points3d, points2d_proj, vis3d, points2d_merged, vis2d_merged
        for fid in range(fid1, fid2):
            img, K, img_shape = reader.read()

            # inpaint pose
            img_p = img.copy()
            for c, pose in zip(['g', 'b'], [pose_ours, pose_dlc]):

                last_uv = None
                for t in range(5):
                    if fid - t < 0:
                        continue
                    if pose[fid - t] is None:
                        continue
                    p = np.reshape(np.array(pose[fid - t]), [-1, 3])
                    uv = cl.project(cl.trafo_coords(p, M_list[cid]),
                                    K)[args.kp_id]
                    # img_p = cv2.circle(img_p,
                    #                    (int(uv[0]), int(uv[1])),
                    #                    radius=2,
                    #                    color=(255, 0, 0) if c == 'b' else (0, 255, 0),
                    #                    thickness=-1)
                    col = (0, 255, 0) if c == 'g' else (0, 0, 255)

                    if last_uv is None:
                        img_p = cv2.circle(img_p, (int(uv[0]), int(uv[1])),
                                           radius=5,
                                           color=col,
                                           thickness=2)
                    else:
                        img_p = cv2.line(img_p, (int(uv[0]), int(uv[1])),
                                         (int(last_uv[0]), int(last_uv[1])),
    model = Model(args.model)
    df = build_dataflow(model, [args.set_name],
                        ['/misc/lmbraid18/datasets/RatPose/RatTrack_paper_resub_sessions/Rat506_200306/run046_cam%d.avi'],
                        ['/misc/lmbraid18/datasets/RatPose/RatTrack_paper_resub_sessions/Rat506_200306/pred_run046__00.json'],
                        is_train=False,
                        threaded=True, single_sample=False)

    start = None
    for idx, dp in enumerate(df.get_data()):
        if idx >= df.size():
            break

        data = df2dict(dp)
        img_rgb = np.round((data[data_t.image]+0.5)*255.0).astype(np.uint8)[:, :, :, ::-1]
        num_cams = img_rgb.shape[0]
        print('is_supervised', data[data_t.is_supervised])

        img_list = list()
        for i in range(num_cams):
            xyz_cam = cl.trafo_coords(data[data_t.xyz_nobatch][0], data[data_t.M][i])
            uv = cl.project(xyz_cam, data[data_t.K][i])
            I = draw_skel(img_rgb[i], model, data[data_t.uv][i], data[data_t.vis_nobatch][0], order='uv')
            img_list.append(I)
        xyz = data[data_t.xyz_nobatch][0]

        merge = StitchedImage(img_list, target_size=(int(0.8 * args.window_size), args.window_size))

        cv2.imshow('pose labeled', merge.image[:, :, ::-1])
        cv2.waitKey(0 if args.wait else 10)

def post_process_detections(boxes,
                            scores,
                            K_list,
                            M_list,
                            img_shape,
                            min_score_cand=0.2,
                            min_score_pick=0.1,
                            max_reproj_error=10.0,
                            verbose=True,
                            img=None,
                            logger=None):
    """ Some post processing to increase the quality of bounding box detections.

        We consider all bounding boxes scoring above min_score_cand as candidates and calculate their 2D centers.
        Using all centers we aim to find a 3D hypothesis explaining as many centers as possible (RANSAC
        triangulation with threshold max_reproj_error). This requires candidates from at least 3 distinct
        cameras. If the hypothesis has at least one inlier, it is projected into every camera. There we pick
        the box whose center has minimal distance to the projection, with the distance divided by
        sqrt(score + 0.001) to favor high scores. Otherwise the top-scoring box of each camera is used.

        Args:
            boxes: per camera, per detection box coordinates (CxBx4), normalized to the image size.
                Indices 1/3 are used as x and 0/2 as y.
            scores: CxB detection scores.
            K_list, M_list: per camera intrinsics / extrinsics used for triangulation and projection.
            img_shape: image shape; img_shape[0] scales y, img_shape[1] scales x.
            min_score_cand: boxes with score > min_score_cand are triangulation candidates.
            min_score_pick: currently unused, kept for interface compatibility.
            max_reproj_error: RANSAC threshold of the triangulation.
            verbose, logger: progress messages are logged only if verbose is set and a logger is given.
            img: optional list of per camera images. If given, the candidate centers are shown in an OpenCV
                window; the consistent center is shown as well (blocking) if verbose and logger are also set.

        Returns:
            dict with 'xyz' (the 3D hypothesis or None) and 'boxes' (one box per camera, coordinates
            reordered by [1, 3, 0, 2], i.e. x-pair followed by y-pair).
    """
    global triang_tool
    if triang_tool is None:
        triang_tool = TriangTool()

    output = {'boxes': None, 'xyz': None}
    order = [1, 3, 0, 2]  # output coordinate order of the selected boxes

    # calculate bounding box centers in pixels
    box_centers = np.stack([
        0.5 * (boxes[:, :, 1] + boxes[:, :, 3]) * img_shape[1],
        0.5 * (boxes[:, :, 0] + boxes[:, :, 2]) * img_shape[0]
    ], -1)

    if img is not None:
        # If an image was given we assume that it should be shown
        img_list = list()
        # show candidate centers
        for ind, I in enumerate(img):
            I = I.copy()
            for s, b in zip(scores[ind], box_centers[ind]):
                if s <= min_score_cand:
                    # same criterion as the candidate selection below
                    continue
                c = (int(b[0]), int(b[1]))
                I = cv2.circle(I, c, radius=5, color=(0, 0, 255), thickness=-1)
            img_list.append(I)

        from utils.StitchedImage import StitchedImage
        merge = StitchedImage(img_list)
        cv2.imshow('centers all', merge.image)
        cv2.waitKey(10)

    # collect all candidate centers together with their camera
    points2d = list()
    cams = list()
    for cid in range(boxes.shape[0]):
        for i in range(boxes.shape[1]):
            if scores[cid, i] > min_score_cand:
                points2d.append(box_centers[cid, i])
                cams.append(cid)

        if verbose and logger is not None:
            logger.log('Cam %d contributes %d points for triangulation' %
                       (cid, np.sum(scores[cid] > min_score_cand)))
    points2d = np.array(points2d)

    if np.unique(cams).shape[0] >= 3:
        # find consistent 3D hypothesis for the center of bounding boxed
        point3d, inlier = triang_tool.triangulate([K_list[i] for i in cams],
                                                  [M_list[i] for i in cams],
                                                  np.expand_dims(points2d, 1),
                                                  mode=t_triangulation.RANSAC,
                                                  threshold=max_reproj_error)
        if verbose and logger is not None:
            logger.log('Found 3D point with %d inliers' % np.sum(inlier))

            # NOTE(review): this visualization is only reached when verbose and logger are set
            if img is not None:
                img_list = list()
                for ind, (I, K, M) in enumerate(zip(img, K_list, M_list)):
                    p2d = cl.project(cl.trafo_coords(point3d, M), K)
                    c = (int(p2d[0, 0]), int(p2d[0, 1]))
                    I = cv2.circle(I.copy(),
                                   c,
                                   radius=5,
                                   color=(0, 0, 255),
                                   thickness=-1)
                    img_list.append(I)

                from utils.StitchedImage import StitchedImage
                merge = StitchedImage(img_list)
                cv2.imshow('center consistent', merge.image)
                cv2.waitKey()

        if np.sum(inlier) > 0:
            output['xyz'] = point3d

            # select optimal box wrt the center
            boxes_opti = list()
            for cid, (K, M) in enumerate(zip(K_list, M_list)):
                uv = cl.project(cl.trafo_coords(point3d, M), K)

                # find bbox with minimal distance to found center;
                # we want to pick something with low distance and high score
                diff_l2 = np.sqrt(np.sum(np.square(box_centers[cid] - uv), -1))
                diff_combined = diff_l2 / np.sqrt(scores[cid] + 0.001)
                ind = np.argmin(diff_combined)
                boxes_opti.append(boxes[cid, ind, order])
            output['boxes'] = np.stack(boxes_opti)

            return output

    # If we get here its time to use the fall back solution:
    # Use top scoring bbox in each frame independently
    boxes_opti = list()
    for box, score in zip(boxes, scores):
        ind = np.argmax(score)
        boxes_opti.append(box[ind, order])
    output['boxes'] = np.stack(boxes_opti)

    if verbose and logger is not None:
        logger.log(
            'Using fallback solution: Best scoring box from each view, because of small amount of inliers.'
        )

    return output