Exemple #1
0
    def validate6Dpose(self, gt_pose, instance_bb8det=None, model_id=None,
                       img_size=(640., 480.)):
        """
        Debug helper: recover a 6D pose from a single BB8 detection with EPnP
        and measure it against a ground-truth pose.

        :param gt_pose: ground-truth pose, 16 values reshaped to a 4x4 [R|t]
            matrix.
        :param instance_bb8det: one detection row; element 0 is its class id
            and elements 6:22 are the 8 box corners as (x, y) pairs in
            normalised image coordinates.
        :param model_id: offset added to the class id to index
            ``self.models_info``.
        :param img_size: (width, height) in pixels used to de-normalise the
            detected corners; defaults to 640 x 480.
        :return: dict with the estimated rotation ``"R"`` (3x3) and
            translation ``"t"`` (3x1), plus ``"rot_error"`` in degrees and
            ``"trans_error"`` (``pose_error.te`` / 0.01, i.e. centimetres if
            poses are in metres) measured against ``gt_pose``.
        """
        gt_pose = np.reshape(gt_pose, (4, 4))
        model_info = self.models_info[instance_bb8det[0] + model_id]
        scale = self.scale_to_meters

        min_x = model_info['min_x'] * scale
        min_y = model_info['min_y'] * scale
        min_z = model_info['min_z'] * scale
        max_x = min_x + model_info['size_x'] * scale
        max_y = min_y + model_info['size_y'] * scale
        max_z = min_z + model_info['size_z'] * scale

        # 3D bounding-box corners of the model. EPnP pairs Xworld[i] with
        # Ximg_pix[i], so this order must match the corner order of the
        # detection.
        corners = np.array([
            [min_x, min_y, min_z],
            [min_x, min_y, max_z],
            [min_x, max_y, max_z],
            [min_x, max_y, min_z],
            [max_x, min_y, min_z],
            [max_x, min_y, max_z],
            [max_x, max_y, max_z],
            [max_x, max_y, min_z],
        ], dtype=float)
        Xworld = corners.reshape(-1, 3, 1)

        # detected 2D corners, 8 x 2 x 1, de-normalised to pixels
        width, height = img_size
        Ximg_pix = np.reshape(instance_bb8det[6:22], (-1, 2, 1))
        Ximg_pix = Ximg_pix * np.array([[[width], [height]]], dtype=float)

        _, Rt, _, _ = EPnP().efficient_pnp_gauss(Xworld, Ximg_pix,
                                                  self.cam_intrinsic)
        out = {"R": Rt[0:3, 0:3], "t": Rt[0:3, 3:4]}

        # absolute pose error, returned so callers can actually inspect it
        out["rot_error"] = pose_error.re(
            R_est=out["R"], R_gt=gt_pose[0:3, 0:3]) / np.pi * 180.
        out["trans_error"] = pose_error.te(
            t_est=out["t"], t_gt=gt_pose[0:3, 3:4]) / 0.01

        return out
Exemple #2
0
    # Fragment of a per-pose evaluation loop (the loop header is not part of
    # this snippet). For each metric name listed in errs_active, the error
    # between the estimate `pose` and the ground truth `pose_gt` is appended
    # to errs[<name>]. `pose` / `pose_gt` are indexed with 'R' and 't' below;
    # NOTE(review): presumably the pose dict layout that pose_error expects.

    # Average Distance of Model Points for objects with no indistinguishable views
    if 'add' in errs_active:
        errs['add'].append(pose_error.add(pose, pose_gt, model))

    # Average Distance of Model Points for objects with indistinguishable views
    if 'adi' in errs_active:
        errs['adi'].append(pose_error.adi(pose, pose_gt, model))

    # Translational Error (translation vectors only)
    if 'te' in errs_active:
        errs['te'].append(pose_error.te(pose['t'], pose_gt['t']))

    # Rotational Error (rotation matrices only)
    if 're' in errs_active:
        errs['re'].append(pose_error.re(pose['R'], pose_gt['R']))

    # Complement over Union (additionally needs the image size and camera
    # matrix K)
    if 'cou' in errs_active:
        errs['cou'].append(pose_error.cou(pose, pose_gt, model, im_size, K))

# Plot the errors: one figure per active metric, with the pose index on the
# x axis and the error value on the y axis. All figures are shown at the end.
for err_name in errs_active:
    plt.figure()
    # linewidth is documented as a float (points), so pass a number
    plt.plot(errs[err_name], c='r', lw=3)
    plt.xlabel('Pose ID')
    plt.ylabel(err_name)
    plt.tick_params(labelsize=16)
    plt.tight_layout()
plt.show()
Exemple #3
0
    def update(self, labels, preds):
        """
        Implementation of updating metrics (single-object variant: only the
        first detection of every image is evaluated).

        :param labels: list whose first element is an NDArray of ground-truth
            labels. For every image, row 0 is read: cols 1:5 are compared
            with the detected box, cols 8:24 with the 8 detected corners, and
            cols 24:40 hold the flattened 4x4 ground-truth pose.
        :param preds: [cls_prob, loc_loss, cls_label, bb8_loss, loc_pred, bb8_pred,
                           anchors, loc_label, loc_pred_masked, loc_mae, bb8_label, bb8_pred_masked, bb8_mae]

        Accumulates into ``self.sum_metric`` / ``self.num_inst``:
            0 classification cross-entropy, 1 loc smoothL1, 2 loc MAE,
            3 loc pixel error, 4 bb8 smoothL1, 5 bb8 MAE, 6 bb8 corner pixel
            error, 7 reprojection error <= 5 px, 8/9/10 ADD (ADI) within
            0.1/0.3/0.5 x diameter, 11 rotation error < 5 deg,
            12 translation error < 5 cm, 13 both, 14 2D IoU > 0.5,
            15 2D IoU > 0.9.
        """
        labels = labels[0].asnumpy()
        # network outputs (layout documented above)
        cls_prob = preds[0]
        loc_loss = preds[1].asnumpy()  # smoothL1 loss
        cls_label = preds[2].asnumpy()
        bb8_loss = preds[3].asnumpy()
        loc_pred = preds[4]
        bb8_pred = preds[5]
        anchors = preds[6]
        loc_mae = preds[9].asnumpy()
        bb8_mae = preds[12].asnumpy()

        bb8dets = BB8MultiBoxDetection(cls_prob,
                                       loc_pred,
                                       bb8_pred,
                                       anchors,
                                       nms_threshold=0.5,
                                       force_suppress=False,
                                       variances=(0.1, 0.1, 0.2, 0.2),
                                       nms_topk=400)
        bb8dets = bb8dets.asnumpy()

        # Class names follow 'obj_{:02d}'. str.strip('obj_') would remove any
        # of the characters o, b, j, _ from both ends, so cut the prefix.
        class_name = self.classes[0]
        if class_name.startswith('obj_'):
            class_name = class_name[len('obj_'):]
        model_id = int(class_name)

        # Each PLY model is loaded (and scaled to metres) at most once per
        # call instead of once per image.
        # NOTE(review): assumes the pose_error functions do not modify it.
        models = {}

        def load_model(name):
            """Return the scaled PLY model for class `name`, loading it once."""
            if name not in models:
                model = inout.load_ply(
                    '/data/ZHANGXIN/DATASETS/SIXD_CHALLENGE/LINEMOD/models/' +
                    name + '.ply')
                model['pts'] = model['pts'] * self.scale_to_meters
                models[name] = model
            return models[name]

        K = self.cam_intrinsic[:, 0:3]

        for nbatch, dets in enumerate(bb8dets):
            # Every image counts for slots 7-15, so an image without a
            # detection is scored as a miss.
            for k in range(7, 16):
                self.num_inst[k] += 1

            # for LINEMOD dataset, for each image only select the first det
            det = dets[0]
            if det[0] == -1:
                continue  # nothing detected in this image

            pose_est = self.calculate6Dpose(instance_bb8det=det,
                                            model_id=model_id)
            model_ply = load_model(self.classes[int(det[0])])
            pose_gt_transform = np.reshape(labels[nbatch, 0, 24:40], (4, 4))
            pose_gt = {
                "R": pose_gt_transform[0:3, 0:3],
                "t": pose_gt_transform[0:3, 3:4]
            }

            # absolute pose error: rotation in degrees, translation / 0.01
            # (centimetres if pose_error.te returns metres)
            rot_error = pose_error.re(R_est=pose_est["R"],
                                      R_gt=pose_gt["R"]) / np.pi * 180.
            trans_error = pose_error.te(t_est=pose_est["t"],
                                        t_gt=pose_gt["t"]) / 0.01

            # use adi when object is eggbox (10) or glue (11), add otherwise
            add_fn = pose_error.adi if model_id in (10, 11) else pose_error.add
            add_metric = add_fn(pose_est=pose_est, pose_gt=pose_gt,
                                model=model_ply)
            reproj_metric = pose_error.reprojectionError(pose_est=pose_est,
                                                         pose_gt=pose_gt,
                                                         model=model_ply,
                                                         K=K)
            cou_metric = pose_error.cou(pose_est=pose_est,
                                        pose_gt=pose_gt,
                                        model=model_ply,
                                        im_size=(640, 480),
                                        K=K)

            diameter = self.models_info[det[0] + model_id][
                'diameter'] * self.scale_to_meters
            hits = (
                reproj_metric <= 5,                     # 7: reproj. error <= 5 px
                add_metric <= diameter * 0.1,           # 8: ADD <= 0.1 * diameter
                add_metric <= diameter * 0.3,           # 9: ADD <= 0.3 * diameter
                add_metric <= diameter * 0.5,           # 10: ADD <= 0.5 * diameter
                rot_error < 5,                          # 11: 5 degrees
                trans_error < 5,                        # 12: 5 cm
                (rot_error < 5) and (trans_error < 5),  # 13: 5 degrees and 5 cm
                cou_metric < 0.5,                       # 14: 2D IoU greater than 0.5
                cou_metric < 0.1,                       # 15: 2D IoU greater than 0.9
            )
            for k, hit in enumerate(hits, start=7):
                if hit:
                    self.sum_metric[k] += 1

        # Pixel errors of the first detection of every image; the x300 scale
        # needs to be refined.
        # NOTE(review): images without a detection (class -1) are included.
        loc_mae_pixel = np.abs((bb8dets[:, 0, 2:6] - labels[:, 0, 1:5]) *
                               300)
        bb8_mae_pixel = np.abs((labels[:, 0, 8:24] - bb8dets[:, 0, 6:22]) *
                               300)
        # per-corner Euclidean distance from the interleaved x / y columns
        bb8_mae_pixel = np.sqrt(
            np.square(bb8_mae_pixel[:, 0::2]) +
            np.square(bb8_mae_pixel[:, 1::2]))
        # TODO: multiple objects per image are not handled yet; detections
        # would have to be matched to ground truth per class (e.g. by IoU).

        valid_count = np.sum(cls_label >= 0)
        box_count = np.sum(cls_label > 0)
        # overall accuracy & object accuracy
        label = cls_label.flatten()
        # in case you have a 'other' class
        label[np.where(label >= cls_prob.shape[1])] = 0
        mask = np.where(label >= 0)[0]
        indices = np.int64(label[mask])
        prob = cls_prob.transpose((0, 2, 1)).reshape(
            (-1, cls_prob.shape[1])).asnumpy()
        prob = prob[mask, indices]
        self.sum_metric[0] += (-np.log(prob + self.eps)).sum()
        self.num_inst[0] += valid_count
        # loc_smoothl1loss
        self.sum_metric[1] += np.sum(loc_loss)
        self.num_inst[1] += box_count * 4
        # loc_mae
        self.sum_metric[2] += np.sum(loc_mae)
        self.num_inst[2] += box_count * 4
        # loc_mae_pixel
        self.sum_metric[3] += np.sum(loc_mae_pixel)
        self.num_inst[3] += loc_mae_pixel.size
        # bb8_smoothl1loss
        self.sum_metric[4] += np.sum(bb8_loss)
        self.num_inst[4] += box_count * 16
        # bb8_mae
        self.sum_metric[5] += np.sum(bb8_mae)
        self.num_inst[5] += box_count * 16
        # bb8_mae_pixel
        self.sum_metric[6] += np.sum(bb8_mae_pixel)
        self.num_inst[6] += bb8_mae_pixel.size
Exemple #4
0
    def update(self, labels, preds):
        """
        Implementation of updating metrics (multi-class variant).

        For every ground-truth instance, the first detection of the same
        class (the most confident one) is used to compute pixel errors and
        6D pose metrics; the pose metrics are counted per class id.

        :param labels: list whose first element is an NDArray of ground-truth
            labels, batchsize x 8 x 40. Per row: col 0 class id (negative
            rows are skipped), cols 1:5 box, cols 8:24 projected corners,
            cols 24:40 flattened 4x4 pose.
        :param preds: [cls_prob, loc_loss, cls_label, bb8_loss, loc_pred, bb8_pred,
                           anchors, loc_label, loc_pred_masked, loc_mae, bb8_label, bb8_pred_masked, bb8_mae]
        """
        labels = labels[0].asnumpy()  # batchsize x 8 x 40
        # get generated multi label from network
        cls_prob = preds[0]  # batchsize x num_cls x num_anchors
        loc_loss = preds[1].asnumpy()  # smoothL1 loss
        cls_target = preds[2].asnumpy()  # batchsize x num_anchors
        bb8_loss = preds[3].asnumpy()
        loc_pred = preds[4]
        bb8_pred = preds[5]
        anchors = preds[6]
        loc_mae = preds[9].asnumpy()
        bb8_mae = preds[12].asnumpy()

        # basic evaluation
        valid_count = np.sum(cls_target >= 0)
        box_count = np.sum(cls_target > 0)
        # overall accuracy & object accuracy
        label = cls_target.flatten()
        # in case you have a 'other' class
        label[np.where(label >= cls_prob.shape[1])] = 0
        mask = np.where(label >= 0)[0]
        indices = np.int64(label[mask])
        prob = cls_prob.transpose((0, 2, 1)).reshape(
            (-1, cls_prob.shape[1])).asnumpy()
        prob = prob[mask, indices]
        self.sum_metric[0] += (-np.log(prob + self.eps)).sum()
        self.num_inst[0] += valid_count
        # loc_smoothl1loss
        self.sum_metric[1] += np.sum(loc_loss)
        self.num_inst[1] += box_count * 4
        # loc_mae
        self.sum_metric[2] += np.sum(loc_mae)
        self.num_inst[2] += box_count * 4
        # bb8_smoothl1loss
        self.sum_metric[4] += np.sum(bb8_loss)
        self.num_inst[4] += box_count * 16
        # bb8_mae
        self.sum_metric[5] += np.sum(bb8_mae)
        self.num_inst[5] += box_count * 16

        bb8dets = IndirectBB8MultiBoxDetection(cls_prob,
                                               loc_pred,
                                               bb8_pred,
                                               anchors,
                                               nms_threshold=0.5,
                                               force_suppress=False,
                                               variances=(0.1, 0.1, 0.2, 0.2),
                                               nms_topk=400)
        bb8dets = bb8dets.asnumpy()

        def bump(counter, key):
            """Increment counter[key], starting from 0 for a new key."""
            counter[key] = counter.get(key, 0) + 1

        # Each PLY model is loaded (and scaled to metres) at most once per
        # call instead of once per instance.
        # NOTE(review): assumes the pose_error functions do not modify it.
        models = {}

        def load_model(name):
            """Return the scaled PLY model for class `name`, loading it once."""
            if name not in models:
                model = inout.load_ply(
                    os.path.join(self.LINEMOD_path, 'models',
                                 '{}.ply'.format(name)))
                model['pts'] = model['pts'] * self.scale_to_meters
                models[name] = model
            return models[name]

        K = self.cam_intrinsic[:, 0:3]
        loc_mae_pixel = []
        bb8_mae_pixel = []

        # pose metrics, adapt to multi-class case
        for sampleDet, sampleLabel in zip(bb8dets, labels):
            for instanceLabel in sampleLabel:
                if instanceLabel[0] < 0:
                    continue  # presumably an unused (padding) label row

                # plain int so the per-class dict keys are ordinary ints
                cid = int(instanceLabel[0])
                class_name = self.classes[cid]
                # names follow 'obj_XX'; strip('obj_') would remove a set of
                # characters rather than the prefix
                model_id = int(class_name[len('obj_'):]
                               if class_name.startswith('obj_') else
                               class_name)
                indices = np.where(sampleDet[:, 0] == cid)[0]
                bump(self.counts, cid)

                if indices.size == 0:
                    # missed instance: only counted in self.counts
                    continue

                # only consider the most confident instance
                instanceDet = sampleDet[indices[0]]

                loc_mae_pixel.append(
                    np.abs((instanceDet[2:6] - instanceLabel[1:5]) * 300.))
                bb8_mae_pixel.append(
                    np.abs((instanceDet[6:22] - instanceLabel[8:24]) * 300.))

                pose_est = self.calculate6Dpose(instance_bb8det=instanceDet,
                                                model_id=model_id)
                model_ply = load_model(class_name)
                pose_gt_transform = np.reshape(instanceLabel[24:40], (4, 4))
                pose_gt = {
                    "R": pose_gt_transform[0:3, 0:3],
                    "t": pose_gt_transform[0:3, 3:4]
                }

                # absolute pose error: rotation in degrees, translation / 0.01
                # (centimetres if pose_error.te returns metres)
                rot_error = pose_error.re(R_est=pose_est["R"],
                                          R_gt=pose_gt["R"]) / np.pi * 180.
                trans_error = pose_error.te(t_est=pose_est["t"],
                                            t_gt=pose_gt["t"]) / 0.01

                # use adi when object is eggbox (10) or glue (11), add otherwise
                add_fn = (pose_error.adi if model_id in (10, 11) else
                          pose_error.add)
                add_metric = add_fn(pose_est=pose_est, pose_gt=pose_gt,
                                    model=model_ply)
                reproj_metric = pose_error.reprojectionError(
                    pose_est=pose_est, pose_gt=pose_gt, model=model_ply, K=K)
                cou_metric = pose_error.cou(pose_est=pose_est,
                                            pose_gt=pose_gt,
                                            model=model_ply,
                                            im_size=(640, 480),
                                            K=K)

                # record all the Reproj. error to plot curve
                self.Reproj.setdefault(cid, []).append(reproj_metric)

                # per-class success counters; cou = 1 - IoU, so cou < x means
                # 2D IoU greater than 1 - x
                diameter = self.models_info[model_id][
                    'diameter'] * self.scale_to_meters
                hits = (
                    (reproj_metric <= 5, self.Reproj5px),           # <= 5 px
                    (add_metric <= diameter * 0.1, self.ADD0_1),    # 0.1 * diameter
                    (add_metric <= diameter * 0.3, self.ADD0_3),    # 0.3 * diameter
                    (add_metric <= diameter * 0.5, self.ADD0_5),    # 0.5 * diameter
                    (rot_error < 5, self.re),                       # 5 degrees
                    (trans_error < 5, self.te),                     # 5 cm
                    ((rot_error < 5) and (trans_error < 5), self.re_te),
                    (cou_metric < 0.5, self.IoU2D0_5),              # IoU > 0.5
                    (cou_metric < 0.4, self.IoU2D0_6),              # IoU > 0.6
                    (cou_metric < 0.3, self.IoU2D0_7),              # IoU > 0.7
                    (cou_metric < 0.2, self.IoU2D0_8),              # IoU > 0.8
                    (cou_metric < 0.1, self.IoU2D0_9),              # IoU > 0.9
                )
                for hit, counter in hits:
                    if hit:
                        bump(counter, cid)

        # reshape keeps a 2-D (n, 4) / (n, 16) layout even when no instance
        # was matched in this batch; np.array([]) alone is 1-D and the column
        # slicing below would raise IndexError.
        loc_mae_pixel = np.array(loc_mae_pixel).reshape(-1, 4)
        loc_mae_pixel = np.sqrt(
            np.square(loc_mae_pixel[:, 0::2]) +
            np.square(loc_mae_pixel[:, 1::2]))
        bb8_mae_pixel = np.array(bb8_mae_pixel).reshape(-1, 16)
        bb8_mae_pixel = np.sqrt(
            np.square(bb8_mae_pixel[:, 0::2]) +
            np.square(bb8_mae_pixel[:, 1::2]))

        # loc_mae_pixel
        self.sum_metric[3] += np.sum(loc_mae_pixel)
        self.num_inst[3] += loc_mae_pixel.size
        # bb8_mae_pixel
        self.sum_metric[6] += np.sum(bb8_mae_pixel)
        self.num_inst[6] += bb8_mae_pixel.size
    def update(self, labels, preds):
        """
        :param preds: [cls_prob, loc_loss, cls_label, bb8_loss, loc_pred, bb8_pred,
                           anchors, loc_label, loc_pred_masked, loc_mae, bb8_label, bb8_pred_masked, bb8_mae]
        Implementation of updating metrics
        """
        def iou(x, ys):
            """
            Calculate intersection-over-union overlap
            Params:
            ----------
            x : numpy.array
                single box [xmin, ymin ,xmax, ymax]
            ys : numpy.array
                multiple box [[xmin, ymin, xmax, ymax], [...], ]
            Returns:
            -----------
            numpy.array
                [iou1, iou2, ...], size == ys.shape[0]
            """
            ixmin = np.maximum(ys[:, 0], x[0])
            iymin = np.maximum(ys[:, 1], x[1])
            ixmax = np.minimum(ys[:, 2], x[2])
            iymax = np.minimum(ys[:, 3], x[3])
            iw = np.maximum(ixmax - ixmin, 0.)
            ih = np.maximum(iymax - iymin, 0.)
            inters = iw * ih
            uni = (x[2] - x[0]) * (x[3] - x[1]) + (ys[:, 2] - ys[:, 0]) * \
                (ys[:, 3] - ys[:, 1]) - inters
            ious = inters / uni
            ious[uni < 1e-12] = 0  # in case bad boxes
            return ious

        labels = labels[0].asnumpy()  # batchsize x 8 x 40
        # get generated multi label from network
        cls_prob = preds[0]  # batchsize x num_cls x num_anchors
        loc_loss = preds[1].asnumpy()  # smoothL1 loss
        loc_loss_in_use = loc_loss[loc_loss.nonzero()]
        cls_label = preds[2].asnumpy()  # batchsize x num_anchors
        bb8_loss = preds[3].asnumpy()
        loc_pred = preds[4]
        bb8_pred = preds[5]
        anchors = preds[6]
        # anchor_in_use = anchors[anchors.nonzero()]

        # basic evaluation, adapt to multi-class
        loc_label = preds[7].asnumpy()
        loc_label_in_use = loc_label[loc_label.nonzero()]
        loc_pred_masked = preds[8].asnumpy()
        loc_pred_in_use = loc_pred_masked[loc_pred_masked.nonzero()]
        loc_mae = preds[9].asnumpy()
        loc_mae_in_use = loc_mae[loc_mae.nonzero()]
        bb8_label = preds[10].asnumpy()
        bb8_label_in_use = bb8_label[bb8_label.nonzero()]
        bb8_pred_masked = preds[11].asnumpy()
        bb8_pred_in_use = bb8_pred_masked[bb8_pred_masked.nonzero()]
        bb8_mae = preds[12].asnumpy()
        bb8_mae_in_use = bb8_mae[bb8_mae.nonzero()]

        valid_count = np.sum(cls_label >= 0)
        box_count = np.sum(cls_label > 0)
        # overall accuracy & object accuracy
        label = cls_label.flatten()
        # in case you have a 'other' class
        label[np.where(label >= cls_prob.shape[1])] = 0
        mask = np.where(label >= 0)[0]
        indices = np.int64(label[mask])
        prob = cls_prob.transpose((0, 2, 1)).reshape(
            (-1, cls_prob.shape[1])).asnumpy()
        prob = prob[mask, indices]
        self.sum_metric[0] += (-np.log(prob + self.eps)).sum()
        self.num_inst[0] += valid_count
        # loc_smoothl1loss
        self.sum_metric[1] += np.sum(loc_loss)
        self.num_inst[1] += box_count * 4
        # loc_mae
        self.sum_metric[2] += np.sum(loc_mae)
        self.num_inst[2] += box_count * 4
        # bb8_smoothl1loss
        self.sum_metric[4] += np.sum(bb8_loss)
        self.num_inst[4] += box_count * 16
        # bb8_mae
        self.sum_metric[5] += np.sum(bb8_mae)
        self.num_inst[5] += box_count * 16

        bb8dets = BB8MultiBoxDetection(cls_prob,
                                       loc_pred,
                                       bb8_pred,
                                       anchors,
                                       nms_threshold=0.5,
                                       force_suppress=False,
                                       variances=(0.1, 0.1, 0.2, 0.2),
                                       nms_topk=400)
        bb8dets = bb8dets.asnumpy()

        # model_id = int(self.classes[0].strip("obj_"))
        #
        # for nbatch in range(bb8dets.shape[0]):
        #
        #     self.num_inst[7] += 1
        #     self.num_inst[8] += 1
        #     self.num_inst[9] += 1
        #     self.num_inst[10] += 1
        #     self.num_inst[11] += 1
        #     self.num_inst[12] += 1
        #     self.num_inst[13] += 1
        #     self.num_inst[14] += 1
        #     self.num_inst[15] += 1
        #
        #     if bb8dets[nbatch, 0, 0] == -1:
        #         continue
        #     else:
        #         # for LINEMOD dataset, for each image only select the first det
        #         # self.validate6Dpose(gt_pose=labels[nbatch, 0, 24:40], instance_bb8det=bb8dets[nbatch, 0, :], model_id=model_id)
        #         pose_est = self.calculate6Dpose(instance_bb8det=bb8dets[nbatch, 0, :], model_id=model_id)
        #         model_path = '/data/ZHANGXIN/DATASETS/SIXD_CHALLENGE/LINEMOD/models/' + self.classes[int(bb8dets[nbatch, 0, 0])] + '.ply'
        #         model_ply = inout.load_ply(model_path)
        #         model_ply['pts'] = model_ply['pts'] * self.scale_to_meters
        #         pose_gt_transform = np.reshape(labels[nbatch, 0, 24:40], newshape=(4, 4))
        #         pose_gt = {"R": pose_gt_transform[0:3, 0:3],
        #         "t": pose_gt_transform[0:3, 3:4]}
        #
        #         # absolute pose error
        #         rot_error = pose_error.re(R_est=pose_est["R"], R_gt=pose_gt["R"]) / np.pi * 180.
        #         trans_error = pose_error.te(t_est=pose_est["t"], t_gt=pose_gt["t"]) / 0.01
        #
        #         # other pose metrics
        #         add_metric = pose_error.add(pose_est=pose_est, pose_gt=pose_gt, model=model_ply)    # use adi when object is eggbox or glue
        #         reproj_metric = pose_error.reprojectionError(pose_est=pose_est, pose_gt=pose_gt,
        #                                                      model=model_ply, K=self.cam_intrinsic[:, 0:3])
        #         cou_metric = pose_error.cou(pose_est=pose_est, pose_gt=pose_gt,
        #                                     model=model_ply, im_size=(640, 480), K=self.cam_intrinsic[:, 0:3])
        #
        #         # metric update
        #         if reproj_metric <= 5:  # reprojection error less than 5 pixels
        #             self.sum_metric[7] += 1
        #
        #         if add_metric <= self.models_info[bb8dets[nbatch, 0, 0] + model_id]['diameter'] * self.scale_to_meters * 0.1:   # ADD metric less than 0.1 * diameter
        #             self.sum_metric[8] += 1
        #
        #         if add_metric <= self.models_info[bb8dets[nbatch, 0, 0] + model_id]['diameter'] * self.scale_to_meters * 0.3:   # ADD metric less than 0.3 * diameter
        #             self.sum_metric[9] += 1
        #
        #         if add_metric <= self.models_info[bb8dets[nbatch, 0, 0] + model_id]['diameter'] * self.scale_to_meters * 0.5:   # ADD metric less than 0.5 * diameter
        #             self.sum_metric[10] += 1
        #
        #         if rot_error < 5:   # 5 degrees
        #             self.sum_metric[11] += 1
        #
        #         if trans_error < 5: # 5 cm
        #             self.sum_metric[12] += 1
        #
        #         if (rot_error < 5) and (trans_error < 5):   # 5 degrees and 5 cm
        #             self.sum_metric[13] += 1
        #
        #         if cou_metric < 0.5:    # 2D IoU greater than 0.5
        #             self.sum_metric[14] += 1
        #
        #         if cou_metric < 0.1:    # 2D IoU larger than 0.9
        #             self.sum_metric[15] += 1
        #

        # for each class, only consider the most confident instance
        loc_mae_pixel = []
        bb8_mae_pixel = []

        # pose metrics, adapt to multi-class case
        for sampleDet, sampleLabel in zip(bb8dets, labels):
            # calculate for each class
            for instanceLabel in sampleLabel:
                if instanceLabel[0] < 0:
                    continue
                else:
                    cid = instanceLabel[0].astype(np.int16)
                    model_id = int(self.classes[cid].strip("obj_"))
                    indices = np.where(sampleDet[:, 0] == cid)[0]

                    if cid in self.counts:
                        self.counts[cid] += 1
                    else:
                        self.counts[cid] = 1

                    if indices.size > 0:
                        instanceDet = sampleDet[indices[
                            0]]  # only consider the most confident instance

                        loc_mae_pixel.append(
                            np.abs(
                                (instanceDet[2:6] - instanceLabel[1:5]) * 300))
                        bb8_mae_pixel.append(
                            np.abs((instanceDet[6:22] - instanceLabel[8:24]) *
                                   300))

                        pose_est = self.calculate6Dpose(
                            instance_bb8det=instanceDet, model_id=model_id)
                        model_path = os.path.join(
                            self.LINEMOD_path, 'models',
                            '{}.ply'.format(self.classes[cid]))
                        model_ply = inout.load_ply(model_path)
                        model_ply[
                            'pts'] = model_ply['pts'] * self.scale_to_meters
                        pose_gt_transform = np.reshape(instanceLabel[24:40],
                                                       newshape=(4, 4))
                        pose_gt = {
                            "R": pose_gt_transform[0:3, 0:3],
                            "t": pose_gt_transform[0:3, 3:4]
                        }

                        # absolute pose error
                        rot_error = pose_error.re(
                            R_est=pose_est["R"],
                            R_gt=pose_gt["R"]) / np.pi * 180.
                        trans_error = pose_error.te(t_est=pose_est["t"],
                                                    t_gt=pose_gt["t"]) / 0.01

                        # other pose metrics
                        if model_id in [10, 11]:
                            add_metric = pose_error.adi(
                                pose_est=pose_est,
                                pose_gt=pose_gt,
                                model=model_ply
                            )  # use adi when object is eggbox or glue
                        else:
                            add_metric = pose_error.add(
                                pose_est=pose_est,
                                pose_gt=pose_gt,
                                model=model_ply
                            )  # use add for every object not listed above (model_id not in [10, 11])

                        reproj_metric = pose_error.reprojectionError(
                            pose_est=pose_est,
                            pose_gt=pose_gt,
                            model=model_ply,
                            K=self.cam_intrinsic[:, 0:3])
                        cou_metric = pose_error.cou(pose_est=pose_est,
                                                    pose_gt=pose_gt,
                                                    model=model_ply,
                                                    im_size=(640, 480),
                                                    K=self.cam_intrinsic[:,
                                                                         0:3])

                        if cid not in self.Reproj:
                            self.Reproj[cid] = [reproj_metric]
                        else:
                            assert cid in self.counts
                            self.Reproj[cid] += [reproj_metric]

                        # metric update
                        if reproj_metric <= 5:  # reprojection error less than 5 pixels
                            if cid not in self.Reproj5px:
                                self.Reproj5px[cid] = 1
                            else:
                                assert cid in self.counts
                                self.Reproj5px[cid] += 1

                        if add_metric <= self.models_info[model_id][
                                'diameter'] * self.scale_to_meters * 0.1:  # ADD metric less than 0.1 * diameter
                            if cid not in self.ADD0_1:
                                self.ADD0_1[cid] = 1
                            else:
                                assert cid in self.counts
                                self.ADD0_1[cid] += 1

                        if add_metric <= self.models_info[model_id][
                                'diameter'] * self.scale_to_meters * 0.3:  # ADD metric less than 0.3 * diameter
                            if cid not in self.ADD0_3:
                                self.ADD0_3[cid] = 1
                            else:
                                assert cid in self.counts
                                self.ADD0_3[cid] += 1

                        if add_metric <= self.models_info[model_id][
                                'diameter'] * self.scale_to_meters * 0.5:  # ADD metric less than 0.5 * diameter
                            if cid not in self.ADD0_5:
                                self.ADD0_5[cid] = 1
                            else:
                                assert cid in self.counts
                                self.ADD0_5[cid] += 1

                        if rot_error < 5:  # 5 degrees
                            if cid not in self.re:
                                self.re[cid] = 1
                            else:
                                assert cid in self.counts
                                self.re[cid] += 1

                        if trans_error < 5:  # 5 cm
                            if cid not in self.te:
                                self.te[cid] = 1
                            else:
                                assert cid in self.counts
                                self.te[cid] += 1

                        if (rot_error < 5) and (trans_error <
                                                5):  # 5 degrees and 5 cm
                            if cid not in self.re_te:
                                self.re_te[cid] = 1
                            else:
                                assert cid in self.counts
                                self.re_te[cid] += 1

                        if cou_metric < 0.5:  # 2D IoU greater than 0.5
                            if cid not in self.IoU2D0_5:
                                self.IoU2D0_5[cid] = 1
                            else:
                                assert cid in self.counts
                                self.IoU2D0_5[cid] += 1

                        if cou_metric < 0.1:  # 2D IoU larger than 0.9
                            if cid not in self.IoU2D0_9:
                                self.IoU2D0_9[cid] = 1
                            else:
                                assert cid in self.counts
                                self.IoU2D0_9[cid] += 1

        loc_mae_pixel = np.array(loc_mae_pixel)
        bb8_mae_pixel = np.array(bb8_mae_pixel)
        bb8_mae_pixel_x = bb8_mae_pixel[:, [0, 2, 4, 6, 8, 10, 12, 14]]
        bb8_mae_pixel_y = bb8_mae_pixel[:, [1, 3, 5, 7, 9, 11, 13, 15]]
        bb8_mae_pixel = np.sqrt(
            np.square(bb8_mae_pixel_x) + np.square(bb8_mae_pixel_y))

        # loc_mae_pixel
        self.sum_metric[3] += np.sum(loc_mae_pixel)
        self.num_inst[3] += loc_mae_pixel.size
        # bb8_mae_pixel
        self.sum_metric[6] += np.sum(bb8_mae_pixel)
        self.num_inst[6] += bb8_mae_pixel.size