Example #1
    def load_image_sample(self, idx):

        with open(self.label_name[idx], 'r') as f:
            frame = json.load(f)
        pd_name = self.label_name[idx].replace('data', 'output')
        pd_name = pd_name.replace('label', 'pred')
        if os.path.isfile(pd_name):
            with open(pd_name, 'r') as f:
                frame_pd = json.load(f)
        else:
            # No prediction json file found
            frame_pd = {'prediction': []}

        n_box = len(frame['labels'])
        if n_box > self.n_box_limit:
            # print("n_box ({}) exceeds the limit {}, clipping to the "
            #       "limit.".format(n_box, self.n_box_limit))
            n_box = self.n_box_limit

        # Frame level annotations
        im_path = os.path.join(self.IM_PATH, frame['name'])
        endvid = int(idx + 1 in self.seq_accum)
        cam_rot = np.array(frame['extrinsics']['rotation'])
        cam_loc = np.array(frame['extrinsics']['location'])
        cam_calib = np.array(frame['intrinsics']['cali'])
        #cam_focal = np.array(frame['intrinsics']['focal'])
        #cam_near_clip = np.array(frame['intrinsics']['nearClip'])
        #cam_fov_h = np.array(frame['intrinsics']['fov'])
        pose = tu.Pose(cam_loc, cam_rot, not self.use_kitti)

        # Object level annotations
        if self.phase in ['train', 'val']:
            labels = frame['labels'][:n_box]
            predictions = frame_pd['prediction'][:n_box]

            # Random shuffle data
            np.random.seed(777)
            np.random.shuffle(labels)
        else:
            labels = frame['labels']
            predictions = frame_pd['prediction']

        rois_pd = bh.get_box2d_array(predictions).astype(float)
        rois_gt = bh.get_box2d_array(labels).astype(float)
        tid = bh.get_label_array(labels, ['id'], (0)).astype(int)
        # Dim: H, W, L
        dim = bh.get_label_array(labels, ['box3d', 'dimension'],
                                 (0, 3)).astype(float)
        # Alpha: -pi ~ pi
        alpha = bh.get_label_array(labels, ['box3d', 'alpha'],
                                   (0)).astype(float)
        # Location in cam coord: x-right, y-down, z-front
        location = bh.get_label_array(labels, ['box3d', 'location'],
                                      (0, 3)).astype(float)

        # Center
        # f_x,   s, cen_x, ext_x
        #   0, f_y, cen_y, ext_y
        #   0,   0,     1, ext_z
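        # Per point: u = (f_x*x + s*y + cen_x*z + ext_x) / (z + ext_z),
        #            v = (f_y*y + cen_y*z + ext_y) / (z + ext_z)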
        ext_loc = np.hstack([location, np.ones([len(location), 1])])  # (B, 4)
        proj_loc = ext_loc.dot(cam_calib.T)  # (B, 4) dot (3, 4).T => (B, 3)
        center_gt = proj_loc[:, :2] / proj_loc[:, 2:3]  # normalize

        if self.phase in ['train', 'val']:
            # For depth training
            #center_pd = center_gt.copy()
            # For center training
            cenx = (rois_gt[:, 0:1] + rois_gt[:, 2:3]) / 2
            ceny = (rois_gt[:, 1:2] + rois_gt[:, 3:4]) / 2
            center_pd = np.concatenate([cenx, ceny], axis=1)
        else:
            center_pd = bh.get_cen_array(predictions)

        # Depth
        depth = np.maximum(0, location[:, 2])

        ignore = bh.get_label_array(labels, ['attributes', 'ignore'],
                                    (0)).astype(int)
        # Get n_box_limit batch
        rois_gt = np.vstack([rois_gt, np.zeros([self.n_box_limit,
                                                5])])[:self.n_box_limit]
        if self.phase in ['train', 'val']:
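            # Use jittered GT boxes (up to 3 px of uniform noise) as the
            # "predicted" RoIs to mimic detector noise during training.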
            rois_pd = rois_gt.copy()
            rois_pd[:, :4] += np.random.rand(rois_gt.shape[0], 4) * 3
        else:
            rois_pd = np.vstack([rois_pd,
                                 np.zeros([self.n_box_limit,
                                           5])])[:self.n_box_limit]
        tid = np.hstack([tid, np.zeros(self.n_box_limit)])[:self.n_box_limit]
        alpha = np.hstack([alpha,
                           np.zeros(self.n_box_limit)])[:self.n_box_limit]
        depth = np.hstack([depth,
                           np.zeros(self.n_box_limit)])[:self.n_box_limit]
        center_pd = np.vstack([center_pd,
                               np.zeros([self.n_box_limit,
                                         2])])[:self.n_box_limit]
        center_gt = np.vstack([center_gt,
                               np.zeros([self.n_box_limit,
                                         2])])[:self.n_box_limit]
        dim = np.vstack([dim, np.zeros([self.n_box_limit,
                                        3])])[:self.n_box_limit]
        ignore = np.hstack([ignore,
                            np.zeros(self.n_box_limit)])[:self.n_box_limit]

        # Object centers in world coordinates
        loc_gt = tu.point3dcoord(center_gt, depth, cam_calib, pose)

        # Load images
        img = cv2.imread(im_path)
        assert img is not None, "Cannot read {}".format(im_path)

        h, w, _ = img.shape
        p_h = self.H - h
        p_w = self.W - w
        assert p_h >= 0, "target height - image height = {}".format(p_h)
        assert p_w >= 0, "target width - image width = {}".format(p_w)
        img = copy_border_reflect(img, p_h, p_w)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_patch = np.rollaxis(img, 2, 0)
        img_patch = img_patch.astype(float) / 255.0

        # Normalize
        if self.is_normalizing:
            img_patch = (img_patch - self.mean) / self.std

        bin_cls = np.zeros((self.n_box_limit, 2))
        bin_res = np.zeros((self.n_box_limit, 2))

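        # MultiBin-style orientation encoding: two overlapping bins centered
        # at -pi/2 (bin 0) and +pi/2 (bin 1); bin_res stores the residual
        # angle from the active bin's center.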
        for i in range(n_box):
            if alpha[i] < np.pi / 6. or alpha[i] > 5 * np.pi / 6.:
                bin_cls[i, 0] = 1
                bin_res[i, 0] = alpha[i] - (-0.5 * np.pi)

            if alpha[i] > -np.pi / 6. or alpha[i] < -5 * np.pi / 6.:
                bin_cls[i, 1] = 1
                bin_res[i, 1] = alpha[i] - (0.5 * np.pi)

        box_info = {
            'im_path': im_path,
            'rois_pd': torch.from_numpy(rois_pd).float(),
            'rois_gt': torch.from_numpy(rois_gt).float(),
            'dim_gt': torch.from_numpy(dim).float(),
            'bin_cls_gt': torch.from_numpy(bin_cls).long(),
            'bin_res_gt': torch.from_numpy(bin_res).float(),
            'alpha_gt': torch.from_numpy(alpha).float(),
            'depth_gt': torch.from_numpy(depth).float(),
            'cen_pd': torch.from_numpy(center_pd).float(),
            'cen_gt': torch.from_numpy(center_gt).float(),
            'loc_gt': torch.from_numpy(loc_gt).float(),
            'tid_gt': torch.from_numpy(tid).int(),
            'ignore': torch.from_numpy(ignore).int(),
            'n_box': n_box,
            'endvid': endvid,
            'cam_calib': torch.from_numpy(cam_calib).float(),
            'cam_rot': torch.from_numpy(pose.rotation).float(),
            'cam_loc': torch.from_numpy(pose.position).float(),
        }

        return torch.from_numpy(img_patch).float(), box_info
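
A minimal, self-contained sketch of the pinhole projection used above for center_gt, together with the (center, depth) back-projection that tu.point3dcoord is assumed to perform (identity camera pose; the calibration values below are made up):

import numpy as np

# Hypothetical 3x4 calibration [K | t]; fx, fy, cx, cy are illustrative.
cam_calib = np.array([[721.5, 0.0, 609.6, 0.0],
                      [0.0, 721.5, 172.9, 0.0],
                      [0.0, 0.0, 1.0, 0.0]])

location = np.array([[2.0, 1.5, 10.0]])  # (B, 3) points in camera coords
ext_loc = np.hstack([location, np.ones([len(location), 1])])  # (B, 4)
proj_loc = ext_loc.dot(cam_calib.T)  # (B, 3)
center = proj_loc[:, :2] / proj_loc[:, 2:3]  # (B, 2) pixel centers

# Back-project (center, depth); with the last column of cam_calib zero
# this exactly inverts the projection above.
depth = proj_loc[:, 2:3]
fx, fy = cam_calib[0, 0], cam_calib[1, 1]
cx, cy = cam_calib[0, 2], cam_calib[1, 2]
recovered = np.hstack([(center[:, 0:1] - cx) * depth / fx,
                       (center[:, 1:2] - cy) * depth / fy,
                       depth])
assert np.allclose(recovered, location)
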
Example #2
def plot_3D_box(info_gt, info_pd, args, session_name):
    if args.draw_bev:
        print("BEV: {}".format(session_name))
    if args.draw_2d or args.draw_3d:
        print("3D: {}".format(session_name))

    # Variables
    fig, ax = plt.subplots(figsize=(5, 5), dpi=100)

    # set output video
    if args.is_save:
        vid_trk = cv2.VideoWriter(
            '{}_{}_{}_{}.mp4'.format(
                session_name, args.box_key, 'tracking',
                '_'.join([str(n) for n in args.select_seq])), FOURCC, args.fps,
            (resW, resH))
        vid_bev = cv2.VideoWriter(
            '{}_{}_{}_{}.mp4'.format(
                session_name, args.box_key, 'birdsview',
                '_'.join([str(n) for n in args.select_seq])), FOURCC, args.fps,
            (resH, resH))
    else:
        vid_trk = None
        vid_bev = None

    # Iterate through all objects
    for n_seq, (pd_seq, gt_seq) in enumerate(zip(info_pd, info_gt)):

        id_to_color = {}
        cmap = pu.RandomColor(len(gt_seq['frames']))
        np.random.seed(777)

        if n_seq not in args.select_seq:
            continue

        for n_frame, (pd_boxes, gt_boxes) in enumerate(
                zip(pd_seq['frames'], gt_seq['frames'])):
            if n_frame % 100 == 0:
                print(n_frame)

            # Get objects
            if args.draw_3d or args.draw_2d:
                rawimg = cv2.imread(gt_boxes['im_path'][0])
                cv2.putText(rawimg, '{}'.format(n_frame), (0, 30), FONT, 1,
                            (0, 0, 0), 2, cv2.LINE_AA)

            if len(gt_boxes['annotations']) > 0:
                cam_coords = np.array(gt_boxes['annotations'][0]['cam_loc'])
                cam_rotation = np.array(gt_boxes['annotations'][0]['cam_rot'])
                cam_calib = np.array(gt_boxes['annotations'][0]['cam_calib'])

                cam_pose = tu.Pose(cam_coords, cam_rotation)
                boxes_pd = [
                    hypo[args.box_key] for hypo in pd_boxes['hypotheses']
                ]

            for i, anno in enumerate(gt_boxes['annotations']):

                tid = anno['id']
                box_gt = np.array(anno['box']).astype(int)
                h_gt, w_gt, l_gt = anno['dim']
                depth_gt = anno['depth']
                alpha_gt = anno['alpha']
                xc_gt, yc_gt = anno['xc'], anno['yc']

                center_gt = np.hstack([
                    (xc_gt - W // 2) * depth_gt / FOCAL_LENGTH, depth_gt
                ])

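                # alpha2rot_y converts the observation angle to global yaw:
                # rot_y = alpha + arctan2(xc - W/2, focal_length).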
                rot_y_gt = tu.alpha2rot_y(alpha_gt, xc_gt - W // 2,
                                          FOCAL_LENGTH)

                vol_box_gt = get_3d_box_from_2d(np.array([[xc_gt, yc_gt]]),
                                                np.array([depth_gt]), rot_y_gt,
                                                (h_gt, w_gt, l_gt), cam_calib,
                                                cam_pose)

                # Match gt and pd
                has_match = len(boxes_pd) != 0
                if has_match:
                    _, idx, valid = tu.matching(
                        np.array(anno['box']).reshape(-1, 4),
                        np.array(boxes_pd).reshape(-1, 5)[:, :4], 0.8)
                    has_match = has_match and valid.item()

                    hypo = pd_boxes['hypotheses'][idx[0]]

                    # Get information of gt and pd
                    box_pd = np.array(hypo[args.box_key]).astype(int)

                    h_pd, w_pd, l_pd = hypo['dim']
                    depth_pd = hypo['depth']
                    alpha_pd = hypo['alpha']
                    xc_pd, yc_pd = hypo['xc'], hypo['yc']

                    center_pd = np.hstack([
                        (xc_pd - W // 2) * depth_pd / FOCAL_LENGTH, depth_pd
                    ])

                    rot_y_pd = tu.alpha2rot_y(alpha_pd, xc_pd - W // 2,
                                              FOCAL_LENGTH)

                    vol_box_pd = get_3d_box_from_2d(np.array([[xc_pd, yc_pd]]),
                                                    np.array([depth_pd]),
                                                    rot_y_pd,
                                                    (h_pd, w_pd, l_pd),
                                                    cam_calib, cam_pose)

                # Get box color
                # color is in BGR format (for cv2), color[:-1] in RGB format
                # (for plt)
                if tid not in list(id_to_color):
                    id_to_color[tid] = [cmap.get_random_color(scale=255), 10]
                else:
                    id_to_color[tid][1] = 10
                color, life = id_to_color[tid]

                # Make rectangle
                if args.draw_3d:
                    # Make rectangle
                    rawimg = tu.draw_3d_cube(rawimg,
                                             vol_box_gt,
                                             tid,
                                             cam_calib,
                                             cam_pose,
                                             line_color=(color[0],
                                                         color[1] * 0.7,
                                                         color[2] * 0.7),
                                             line_width=2)
                    if has_match:
                        rawimg = tu.draw_3d_cube(rawimg,
                                                 vol_box_pd,
                                                 tid,
                                                 cam_calib,
                                                 cam_pose,
                                                 line_color=color)
                if args.draw_2d:
                    #text_gt = 'GT:{}° {}m'.format(
                    #    int(alpha_gt / np.pi * 180),
                    #    int(depth_gt))
                    cv2.rectangle(rawimg, (box_gt[0], box_gt[1]),
                                  (box_gt[2], box_gt[3]),
                                  (color[0], color[1] * 0.7, color[2] * 0.7),
                                  8)
                    if has_match:
                        #text_pd = 'PD:{}° {}m'.format(
                        #    int(alpha_pd / np.pi * 180),
                        #    int(depth_pd))
                        cv2.rectangle(rawimg, (box_pd[0], box_pd[1]),
                                      (box_pd[2], box_pd[3]), color, 10)
                if args.draw_bev:
                    # Change BGR to RGB
                    color_bev = [c / 255.0 for c in color[::-1]]
                    if has_match:
                        plot_bev_obj(center_pd, 'PD', rot_y_pd, l_pd, w_pd,
                                     plt, color_bev)

                    plot_bev_obj(center_gt, 'GT', rot_y_gt, l_gt, w_gt, plt,
                                 color_bev)

            if args.draw_bev:
                # Make plot
                ax.set_aspect('equal', adjustable='box')
                plt.axis([-60, 60, -10, 100])
                # plt.axis([-80, 80, -10, 150])
                plt.plot([0, 0], [0, 3], 'k-')
                plt.plot([-1, 0], [2, 3], 'k-')
                plt.plot([1, 0], [2, 3], 'k-')

            for tid in list(id_to_color):
                id_to_color[tid][1] -= 1
                if id_to_color[tid][1] < 0:
                    del id_to_color[tid]

            # Plot
            if vid_trk:
                vid_trk.write(cv2.resize(rawimg, (resW, resH)))
            elif args.draw_3d or args.draw_2d:
                #draw_img = rawimg[:, :, ::-1]

                #fig = plt.figure(figsize=(18, 9), dpi=50)
                #plt.imshow(draw_img)
                key = 0
                while (key not in [ord('q'), ord(' '), 27]):
                    cv2.imshow('preview', cv2.resize(rawimg, (resW, resH)))
                    key = cv2.waitKey(1)

                if key == 27:
                    cv2.destroyAllWindows()
                    return

            # Plot
            if vid_bev:
                fig_data = pu.fig2data(plt.gcf())
                vid_bev.write(cv2.resize(fig_data, (resH, resH)))
                plt.clf()
            elif args.draw_bev:
                plt.show()
                plt.clf()

    if args.is_save:
        vid_trk.release()
        vid_bev.release()

    def update(self, data):
        # Frame-level information: camera pose and calibration
        ret = []
        self.frame_count += 1
        self.cam_rot = data['cam_rot'].squeeze()  # In rad
        self.cam_coord = data['cam_loc'].squeeze()
        if self.frame_count == 1:
            self.init_coord = self.cam_coord.copy()
        self.cam_coord -= self.init_coord
        self.cam_calib = data['cam_calib'].squeeze()
        self.cam_pose = tu.Pose(self.cam_coord, self.cam_rot)

        # process information
        dets, feats, dims, alphas, single_depths, cens, roty, world_coords = \
            self.process_info(
                data,
                det_thresh=self.det_thresh,
                max_depth=self.max_depth,
                _nms=self.dataset == 'gta',
                _valid=self.dataset == 'gta',
                )

        # save to current frame
        self.current_frame = {
            'bbox': dets,
            'feat': feats,
            'dim': dims,
            'alpha': alphas,  # in deg
            'roty': roty,  # in rad
            'depth': single_depths,
            'center': cens,
            'location': world_coords,
            'n_obj': len(alphas),
        }

        # get predicted locations from existing trackers.
        trk_dim = self.feat_dim + self.det_dim
        trks = np.zeros((len(self.trackers), trk_dim))
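        # Each row: predicted box [x1, y1, x2, y2, 1.0] followed by the
        # tracker's appearance feature.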

        # Prediction
        for t, trk in enumerate(trks):
            if self.kf2d:
                pos = self.trackers[t].predict()[0]
            else:
                pos = self.trackers[t].predict_no_effect()[0]
            trk[:self.det_dim] = [pos[0], pos[1], pos[2], pos[3], 1.0]
            trk[self.det_dim:] = self.trackers[t].feat

        # Association
        self.affinity = np.zeros((len(dets), len(trks)), dtype=np.float32)

        for d, det in enumerate(dets):
            for t, trk in enumerate(trks):
                self.affinity[d, t] += (self.iou_affinity_weight *
                                        tu.compute_iou(det[:4], trk[:4]))
        if self.deep_sort and len(dets) * len(trks) > 0:
            self.affinity += self.feat_affinity_weight * tu.compute_cos_dis(
                feats, trks[:, self.det_dim:])

        matched, unmatched_dets, unmatched_trks = \
            tu.associate_detections_to_trackers(
                dets, trks, self.affinity,
                self.affinity_threshold)

        # update matched trackers with assigned detections
        for t, trk in enumerate(self.trackers):
            if t in unmatched_trks:
                trk.lost = True
                continue

            d = matched[np.where(matched[:, 1] == t)[0], 0]
            trk.update(dets[d, :][0])
            trk.det = dets[d, :][0]
            trk.feat = self.current_frame['feat'][d[0]]
            trk.dim = self.current_frame['dim'][d[0]]
            trk.alpha = self.current_frame['alpha'][d[0]]
            trk.depth = self.current_frame['depth'][d[0]]
            trk.cen = self.current_frame['center'][d[0]]
            trk.rot = self.current_frame['roty'][d[0]]

        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :])
            trk.det = dets[i, :]
            trk.feat = self.current_frame['feat'][i]
            trk.dim = self.current_frame['dim'][i]
            trk.alpha = self.current_frame['alpha'][i]
            trk.depth = self.current_frame['depth'][i]
            trk.cen = self.current_frame['center'][i]
            trk.rot = self.current_frame['roty'][i]
            self.trackers.append(trk)

        # Check if boxes are correct
        if self.visualize:
            img = cv2.imread(data['im_path'][0])
            h, w, _ = img.shape
            img = cv2.putText(img, str(self.frame_count), (20, 30),
                              cv2.FONT_HERSHEY_COMPLEX, 1, (200, 200, 200), 2)

            for idx, trk in enumerate(self.trackers):
                box_color = (0, 150, 150) if trk.lost else (0, 255, 0)
                box_bold = 2 if trk.lost else 4
                box = trk.get_state()[0].astype('int')
                print(
                    trk.id + 1,
                    'Lost' if trk.lost else 'Tracked',
                    #trk.aff_value,
                    trk.depth,
                )
                if trk.depth < 0 or trk.depth > self.max_depth:
                    continue
                img = cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]),
                                    box_color, box_bold - 1)
                img = cv2.putText(img, str(int(trk.id + 1)),
                                  (int(box[0]), int(box[1]) + 20),
                                  cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 2)
                img = cv2.putText(img, str(int(trk.depth)),
                                  (int(box[2]) - 14, int(box[3])),
                                  cv2.FONT_HERSHEY_COMPLEX, 0.8, (0, 0, 255),
                                  2)
                xc = int(trk.cen[0])
                yc = int(trk.cen[1])
                img = cv2.rectangle(img, (xc - 1, yc - 1), (xc + 1, yc + 1),
                                    (0, 0, 255), box_bold)

            for idx, (det, cen) in enumerate(zip(dets, cens)):
                det = det.astype('int')
                cen = cen.astype('int')
                img = cv2.rectangle(img, (det[0], det[1]), (det[2], det[3]),
                                    (255, 0, 0), 2)
                img = cv2.rectangle(img, (cen[0] - 1, cen[1] - 1),
                                    (cen[0] + 1, cen[1] + 1), (255, 0, 0), 4)

            key = 0
            while (key not in [ord('q'), ord(' '), 27]):
                cv2.imshow('preview', cv2.resize(img, (0, 0), fx=0.5, fy=0.5))
                key = cv2.waitKey(1)

            if key == 27:
                cv2.destroyAllWindows()
                exit()

        # Get output returns and remove dead tracklet
        i = len(self.trackers)
        for trk in reversed(self.trackers):
            trk_box = trk.get_state()[0]
            if (trk.time_since_update <
                    1) and (trk.hit_streak >= self.min_hits
                            or self.frame_count <= self.min_hits):
                # +1 as MOT benchmark requires positive
                height = float(trk.det[3] - trk.det[1])
                width = float(trk.det[2] - trk.det[0])
                hypo = {
                    'height': height,
                    'width': width,
                    'trk_box': trk_box.tolist(),
                    'det_box': trk.det.tolist(),
                    'id': trk.id + 1,
                    'x': int(trk.cen[0]),
                    'y': int(trk.cen[1]),
                    'dim': trk.dim.tolist(),
                    'alpha': trk.alpha.item(),
                    'roty': trk.rot.item(),
                    'depth': trk.depth.item()
                }
                ret.append(hypo)
            i -= 1
            # remove dead tracklet
            if trk.time_since_update > self.max_age:
                self.trackers.pop(i)
        return ret
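
A minimal sketch of affinity-based assignment in the spirit of tu.associate_detections_to_trackers (whose internals are not shown above; scipy is an added dependency here, and the threshold default is illustrative):

import numpy as np
from scipy.optimize import linear_sum_assignment


def associate(affinity, threshold=0.3):
    # affinity: (n_det, n_trk) matrix, higher is better. Returns matched
    # (det, trk) index pairs plus the unmatched detection/tracker indices.
    n_det, n_trk = affinity.shape
    if n_det == 0 or n_trk == 0:
        return [], list(range(n_det)), list(range(n_trk))
    rows, cols = linear_sum_assignment(-affinity)  # maximize total affinity
    matched = [(d, t) for d, t in zip(rows, cols)
               if affinity[d, t] >= threshold]
    matched_d = {d for d, _ in matched}
    matched_t = {t for _, t in matched}
    unmatched_dets = [d for d in range(n_det) if d not in matched_d]
    unmatched_trks = [t for t in range(n_trk) if t not in matched_t]
    return matched, unmatched_dets, unmatched_trks
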
    def update(self, data):
        # Frame-level information: camera pose and calibration
        ret = []
        self.frame_count += 1
        self.cam_rot = data['cam_rot'].squeeze() # In rad
        self.cam_coord = data['cam_loc'].squeeze()
        if self.frame_count == 1:
            self.init_coord = self.cam_coord.copy()
        self.cam_coord -= self.init_coord
        self.cam_calib = data['cam_calib'].squeeze()
        self.cam_pose = tu.Pose(self.cam_coord, self.cam_rot)

        # process information
        dets, feats, dims, alphas, single_depths, cens, roty, world_coords = \
            self.process_info(
                data,
                det_thresh=self.det_thresh,
                max_depth=self.max_depth,
                _nms=self.dataset == 'carla',
                _valid=self.dataset == 'carla',
                _center=self.dataset == 'kitti'
                )

        # save to current frame
        self.current_frame = {
            'bbox': dets,
            'feat': feats,
            'dim': dims,
            'alpha': alphas, # in deg
            'roty': roty, # in rad
            'depth': single_depths,
            'center': cens,
            'location': world_coords,
            'n_obj': len(alphas),
        }

        # Prediction
        # get predicted locations from existing trackers.
        trk_locs = np.zeros((len(self.trackers), 3))
        trk_dets = np.zeros((len(self.trackers), 5))
        trk_dims = np.zeros((len(self.trackers), 3))
        trk_rots = np.zeros((len(self.trackers), 1))
        trk_feats = np.zeros((len(self.trackers), self.feat_dim))
        for t in range(len(trk_locs)):
            trk_locs[t] = self.trackers[t].predict().squeeze()
            trk_dets[t] = self.trackers[t].det
            trk_dims[t] = self.trackers[t].dim
            trk_rots[t] = self.trackers[t].rot
            trk_feats[t] = self.trackers[t].feat
            trk_cen = tu.projection3d(self.cam_calib, self.cam_pose,
                                      trk_locs[t:t + 1])
            self.trackers[t].cen = trk_cen.squeeze()

        # Generate 2D boxes from 3D estimated location
        trkboxes, trkdepths, trkpoints = tu.construct2dlayout(
            trk_locs, trk_dims, trk_rots, self.cam_calib, self.cam_pose)
        detboxes, detdepths, detpoints = tu.construct2dlayout(
            world_coords, dims, roty, self.cam_calib, self.cam_pose)

        # Association
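        # Visit trackers nearest-first so each box can be tested for
        # occlusion against the closer boxes already in boxes_order.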
        idxes_order = np.argsort(trkdepths)
        boxes_order = []
        for idx in idxes_order:
            if self.use_occ:
                # Check whether this trk box is occluded by a nearer box
                if boxes_order:
                    box = trkboxes[idx]
                    ious = [tu.compute_iou(bo, box) for bo in boxes_order]
                    self.trackers[idx].occ = max(ious) > self.occ_iou_thresh
            boxes_order.append(trkboxes[idx])

        trk_depths_order = np.array(trkdepths)[idxes_order]
        trk_feats_order = trk_feats[idxes_order]
        trk_dim_order = trk_dims[idxes_order]

        coord_affinity = np.zeros((len(detboxes), len(boxes_order)),
                                  dtype=np.float32)
        feat_affinity = np.zeros((len(detboxes), len(boxes_order)),
                                 dtype=np.float32)

        if self.use_occ:
            for d, det in enumerate(detboxes):
                if len(boxes_order) != 0:
                    coord_affinity[d, :] = \
                        tu.compute_boxoverlap_with_depth(
                            dets[d],
                            [det[0], det[1], det[2], det[3], 1.0],
                            detdepths[d],
                            dims[d],
                            trk_dets[idxes_order],
                            boxes_order,
                            trk_depths_order,
                            trk_dim_order,
                            H=self.H,
                            W=self.W)
        else:
            for d, det in enumerate(detboxes):
                for t, trk in enumerate(boxes_order):
                    coord_affinity[d, t] += tu.compute_iou(trk, det[:4])

        # Filter out pairs that do not overlap at all
        location_mask = coord_affinity > 0

        if self.deep_sort and len(detboxes) * len(boxes_order) > 0:
            feat_affinity += location_mask * \
                             tu.compute_cos_dis(feats, trk_feats_order)

        self.affinity = self.coord_3d_affinity_weight * coord_affinity + \
                        self.feat_affinity_weight * feat_affinity

        # Assignment
        matched, unmatched_dets, unmatched_trks = \
            tu.associate_detections_to_trackers(
                detboxes, boxes_order, self.affinity,
                self.affinity_threshold)

        # update matched trackers with assigned detections
        for t, trkidx in enumerate(idxes_order):
            if t in unmatched_trks:
                self.trackers[trkidx].lost = True
                self.trackers[trkidx].aff_value *= 0.9
                continue

            d = matched[np.where(matched[:, 1] == t)[0], 0]
            if self.kf3d or self.lstm3d or self.lstmkf3d:
                self.trackers[trkidx].update(world_coords[d[0]])
            self.trackers[trkidx].lost = False
            self.trackers[trkidx].aff_value = self.affinity[d, t].item()
            self.trackers[trkidx].det = dets[d, :][0]
            self.trackers[trkidx].trk_box = boxes_order[t]
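            # Appearance update: the weaker the appearance affinity of this
            # match, the larger the step toward the new feature.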
            feat_alpha = 1 - feat_affinity[d, t].item()
            self.trackers[trkidx].feat += feat_alpha * (
                self.current_frame['feat'][d[0]] - self.trackers[trkidx].feat)
            self.trackers[trkidx].dim = self.current_frame['dim'][d[0]]
            self.trackers[trkidx].alpha = self.current_frame['alpha'][d[0]]
            self.trackers[trkidx].depth = self.current_frame['depth'][d[0]]
            self.trackers[trkidx].cen = self.current_frame['center'][d[0]]
            self.trackers[trkidx].rot = self.current_frame['roty'][d[0]]

        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            if self.kf3d:
                trk = KalmanBox3dTracker(world_coords[i])
            elif self.lstm3d:
                trk = LSTM3dTracker(self.device,
                                    self.lstm,
                                    world_coords[i])
            elif self.lstmkf3d:
                trk = LSTMKF3dTracker(self.device,
                                      self.lstm,
                                      world_coords[i])
            trk.det = dets[i, :]
            trk.trk_box = detboxes[i]
            trk.feat = self.current_frame['feat'][i]
            trk.dim = self.current_frame['dim'][i]
            trk.alpha = self.current_frame['alpha'][i]
            trk.depth = self.current_frame['depth'][i]
            trk.cen = self.current_frame['center'][i]
            trk.rot = self.current_frame['roty'][i]
            self.trackers.append(trk)

        # Check if boxes are correct
        if self.visualize:
            img = cv2.imread(data['im_path'][0])
            _h, _w, _ = img.shape
            img = cv2.putText(img, str(self.frame_count), (20, 30),
                              cv2.FONT_HERSHEY_COMPLEX, 1, (200, 200, 200), 2)

            lost_color = (0, 150, 150)
            occ_color = (150, 0, 150)
            trk_color = (0, 255, 0)
            det_color = (255, 0, 0)
            gt_color = (0, 0, 255)
            cb = 5
            for idx, (box, po, trk) in enumerate(
                    zip(trkboxes, trkpoints, self.trackers)):
                box_color = lost_color if trk.lost else trk_color
                box_color = occ_color if trk.occ else box_color
                box_bold = 2 if trk.lost or trk.occ else 4
                box = box.astype('int')
                print(trk.id + 1,
                      'Lost' if trk.lost else 'Tracked',
                      '{:2d}'.format(trk.time_since_update),
                      '{:.02f} {:.02f}'.format(trk.aff_value, trkdepths[idx]),
                      trk.get_history()[-1].flatten(),
                      trk.get_state())
                if trkdepths[idx] < 0 or trkdepths[idx] > self.max_depth:
                    continue
                '''
                for (ii,jj) in po:
                    img = cv2.line(img, (int(ii[0]), int(ii[1])),
                            (int(jj[0]), int(jj[1])), box_color, box_bold)
                #'''
                img = cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]),
                                    box_color, box_bold - 1)
                img = cv2.rectangle(img,
                                    (int(trk.cen[0]) - cb,
                                     int(trk.cen[1]) - cb),
                                    (int(trk.cen[0]) + cb,
                                     int(trk.cen[1]) + cb),
                                    box_color, box_bold)

                img = cv2.putText(img, '{}'.format(trk.id + 1),
                                  (int(box[0]), int(box[1]) + 20),
                                  cv2.FONT_HERSHEY_COMPLEX, 1, box_color,
                                  box_bold)
                img = cv2.putText(img, '{:.02f}'.format(trk.aff_value),
                                  (int(box[0] - 14), int(box[3]) + 20),
                                  cv2.FONT_HERSHEY_COMPLEX, 0.8, box_color,
                                  box_bold)
                img = cv2.putText(img, str(int(trkdepths[idx])),
                                  (int(box[2]) - 14, int(box[3])),
                                  cv2.FONT_HERSHEY_COMPLEX, 0.8, box_color,
                                  box_bold)

            if len(data['alpha_gt']) > 0:
                valid_rots = np.zeros_like(data['alpha_gt'])[:, np.newaxis]
                for idx, (alpha, det, center) in enumerate(
                        zip(data['alpha_gt'], data['rois_gt'],
                            data['center_gt'])):
                    valid_rots[idx] = tu.alpha2rot_y(
                        alpha,
                        center[0] - self.W // 2,
                        FOCAL_LENGTH=self.cam_calib[0][0])
                loc_gt = tu.point3dcoord(data['center_gt'], data['depth_gt'],
                                         self.cam_calib, self.cam_pose)

                if self.dataset == 'kitti':
                    loc_gt[:, 2] += data['dim_gt'][:, 0] / 2
                #bbgt, depgt, ptsgt = tu.construct2dlayout(loc_gt, data['dim_gt'], valid_rots,
                #                                 self.cam_calib,
                #                                 self.cam_pose)
                for idx, (tid, boxgt, cengt) in enumerate(
                        zip(data['tid_gt'], data['rois_gt'],
                            data['center_gt'])):
                    detgt = boxgt.astype('int')
                    cengt = cengt.astype('int')
                    img = cv2.rectangle(img, (detgt[0], detgt[1]),
                                        (detgt[2], detgt[3]), gt_color, 2)
                    img = cv2.rectangle(img, (cengt[0] - cb, cengt[1] - cb),
                                        (cengt[0] + cb, cengt[1] + cb),
                                        gt_color, 4)

                    '''
                    for (ii,jj) in ptsgt[idx]:
                        img = cv2.line(img, (int(ii[0]), int(ii[1])),
                                (int(jj[0]), int(jj[1])), gt_color, 2)
                    #'''

            for idx, (det, detbox, detpo, cen) in enumerate(
                    zip(dets, detboxes, detpoints, cens)):
                #det = det.astype('int')
                detbox = detbox.astype('int')
                cen = cen.astype('int')
                '''
                for (ii,jj) in detpo:
                    img = cv2.line(img, (int(ii[0]), int(ii[1])),
                            (int(jj[0]), int(jj[1])), det_color, 2)
                #'''
                #img = cv2.rectangle(img, (det[0], det[1]), (det[2], det[3]),
                #                    det_color, 2)
                img = cv2.rectangle(img, (detbox[0], detbox[1]),
                                    (detbox[2], detbox[3]), det_color, 2)
                img = cv2.rectangle(img, (cen[0] - cb, cen[1] - cb),
                                    (cen[0] + cb, cen[1] + cb), det_color, 4)
                img = cv2.putText(img, str(int(detdepths[idx])),
                                  (int(detbox[2]) - 14, int(detbox[3])),
                                  cv2.FONT_HERSHEY_COMPLEX, 0.8, det_color, 4)

            key = 0
            while key not in [ord('q'), ord(' '), 27]:
                _f = 0.5 if _h > 600 else 1.0
                cv2.imshow('preview', cv2.resize(img, (0, 0), fx=_f, fy=_f))
                key = cv2.waitKey(1)

            if key == 27:
                cv2.destroyAllWindows()
                exit()

        # Get output returns and remove dead tracklet
        i = len(self.trackers)
        for trk in reversed(self.trackers):
            if self.kf3d:
                dep_ = tu.worldtocamera(trk.kf.x[:3].T,
                                        self.cam_pose)[0, 2]
            elif self.lstm3d or self.lstmkf3d:
                dep_ = tu.worldtocamera(trk.x[:3][np.newaxis],
                                        self.cam_pose)[0, 2]
            if (trk.time_since_update < 1) and not trk.occ and (
                    trk.hit_streak >= self.min_hits or self.frame_count <=
                    self.min_hits):
                # +1 as MOT benchmark requires positive
                height = float(trk.det[3] - trk.det[1])
                width = float(trk.det[2] - trk.det[0])
                hypo = {
                    'height': height,
                    'width': width,
                    'trk_box': trk.trk_box.tolist(),
                    'det_box': trk.det.tolist(),
                    'id': trk.id + 1,
                    'x': int(trk.cen[0]),
                    'y': int(trk.cen[1]),
                    'dim': trk.dim.tolist(),
                    'alpha': trk.alpha.item(),
                    'roty': trk.rot.item(),
                    'depth': float(dep_)
                }
                ret.append(hypo)
            i -= 1
            # remove dead tracklet
            if dep_ <= self.occ_min_depth or dep_ >= self.occ_max_depth or \
                    (trk.time_since_update > self.max_age and not trk.occ):
                self.trackers.pop(i)

        return ret
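
The depth-ordered occlusion check from update() as a standalone sketch (box_iou and the 0.7 default are illustrative stand-ins for tu.compute_iou and self.occ_iou_thresh):

import numpy as np


def box_iou(a, b):
    # Intersection-over-union of two [x1, y1, x2, y2] boxes.
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def flag_occluded(boxes, depths, occ_iou_thresh=0.7):
    # Visit boxes nearest-first; a box is occluded when any nearer box
    # already placed overlaps it above the threshold.
    occluded = np.zeros(len(boxes), dtype=bool)
    placed = []
    for idx in np.argsort(depths):
        if placed:
            occluded[idx] = max(
                box_iou(b, boxes[idx]) for b in placed) > occ_iou_thresh
        placed.append(boxes[idx])
    return occluded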