Ejemplo n.º 1
0
    if 'acpd' in errs_active:
        errs['acpd'].append(
            pose_error.acpd(pose_indis_set, pose_gt_indis_set, model))

    # Maximum Corresponding Point Distance
    if 'mcpd' in errs_active:
        errs['mcpd'].append(
            pose_error.mcpd(pose_indis_set, pose_gt_indis_set, model))

    # Average Distance of Model Points for objects with no indistinguishable views
    if 'add' in errs_active:
        errs['add'].append(pose_error.add(pose, pose_gt, model))

    # Average Distance of Model Points for objects with indistinguishable views
    if 'adi' in errs_active:
        errs['adi'].append(pose_error.adi(pose, pose_gt, model))

    # Translational Error
    if 'te' in errs_active:
        errs['te'].append(pose_error.te(pose['t'], pose_gt['t']))

    # Rotational Error
    if 're' in errs_active:
        errs['re'].append(pose_error.re(pose['R'], pose_gt['R']))

    # Complement over Union
    if 'cou' in errs_active:
        errs['cou'].append(pose_error.cou(pose, pose_gt, model, im_size, K))

# Plot the errors
for err_name in errs_active:
Ejemplo n.º 2
0
    def update(self, labels, preds):
        """Accumulate loss, pixel-error and 6D-pose metrics for one batch.

        Single-object variant: the object id is taken from ``self.classes[0]``
        and, for each image, only the first row of the detection output
        (``bb8dets[n, 0, :]``) is evaluated.

        :param labels: list whose first element is an NDArray of ground truth.
            Columns read here: ``[:, 0, 1:5]`` (box), ``[:, 0, 8:24]``
            (8 projected corners, x/y interleaved) and ``[:, 0, 24:40]``
            (flattened 4x4 pose matrix).
        :param preds: [cls_prob, loc_loss, cls_label, bb8_loss, loc_pred, bb8_pred,
                       anchors, loc_label, loc_pred_masked, loc_mae, bb8_label,
                       bb8_pred_masked, bb8_mae]

        Slots updated in ``self.sum_metric`` / ``self.num_inst``:
        0 cls cross-entropy, 1 loc smooth-L1, 2 loc MAE, 3 loc pixel error,
        4 bb8 smooth-L1, 5 bb8 MAE, 6 bb8 corner pixel error,
        7 reprojection <= 5 px, 8/9/10 ADD(-S) <= 0.1/0.3/0.5 * diameter,
        11 rotation < 5 deg, 12 translation < 5 cm, 13 both,
        14 2D IoU > 0.5, 15 2D IoU > 0.9.
        """
        labels = labels[0].asnumpy()
        # network outputs used below
        cls_prob = preds[0]
        loc_loss = preds[1].asnumpy()  # smoothL1 loss
        cls_label = preds[2].asnumpy()
        bb8_loss = preds[3].asnumpy()
        loc_pred = preds[4]
        bb8_pred = preds[5]
        anchors = preds[6]
        loc_mae = preds[9].asnumpy()
        bb8_mae = preds[12].asnumpy()

        bb8dets = BB8MultiBoxDetection(cls_prob,
                                       loc_pred,
                                       bb8_pred,
                                       anchors,
                                       nms_threshold=0.5,
                                       force_suppress=False,
                                       variances=(0.1, 0.1, 0.2, 0.2),
                                       nms_topk=400)
        bb8dets = bb8dets.asnumpy()

        # Class names look like 'obj_05'; keep the numeric suffix.
        # (str.strip('obj_') strips a character *set*, not a prefix.)
        model_id = int(self.classes[0].split('_')[-1])

        # PLY models loaded during this call, keyed by detected class index.
        # NOTE(review): assumes pose_error helpers do not mutate model['pts'].
        model_cache = {}

        for nbatch in range(bb8dets.shape[0]):
            # Every image counts toward the pose-metric denominators,
            # whether or not anything was detected in it.
            for slot in range(7, 16):
                self.num_inst[slot] += 1

            top_det = bb8dets[nbatch, 0, :]
            if top_det[0] == -1:  # no detection kept for this image
                continue

            # for LINEMOD dataset, for each image only select the first det
            pose_est = self.calculate6Dpose(instance_bb8det=top_det,
                                            model_id=model_id)
            det_cls = int(top_det[0])
            model_ply = model_cache.get(det_cls)
            if model_ply is None:
                # TODO(review): hard-coded dataset root; other variants use
                # self.LINEMOD_path.
                model_path = '/data/ZHANGXIN/DATASETS/SIXD_CHALLENGE/LINEMOD/models/' + self.classes[
                    det_cls] + '.ply'
                model_ply = inout.load_ply(model_path)
                model_ply['pts'] = model_ply['pts'] * self.scale_to_meters
                model_cache[det_cls] = model_ply

            pose_gt_transform = np.reshape(labels[nbatch, 0, 24:40],
                                           newshape=(4, 4))
            pose_gt = {
                "R": pose_gt_transform[0:3, 0:3],
                "t": pose_gt_transform[0:3, 3:4]
            }

            # absolute pose error: degrees and centimetres
            rot_error = pose_error.re(R_est=pose_est["R"],
                                      R_gt=pose_gt["R"]) / np.pi * 180.
            trans_error = pose_error.te(t_est=pose_est["t"],
                                        t_gt=pose_gt["t"]) / 0.01

            # other pose metrics
            if model_id in [10, 11]:
                # use adi when object is eggbox or glue (symmetric objects)
                add_metric = pose_error.adi(pose_est=pose_est,
                                            pose_gt=pose_gt,
                                            model=model_ply)
            else:
                # use add otherwise
                add_metric = pose_error.add(pose_est=pose_est,
                                            pose_gt=pose_gt,
                                            model=model_ply)
            reproj_metric = pose_error.reprojectionError(
                pose_est=pose_est,
                pose_gt=pose_gt,
                model=model_ply,
                K=self.cam_intrinsic[:, 0:3])
            # cou is the complement over union, i.e. 1 - 2D IoU
            cou_metric = pose_error.cou(pose_est=pose_est,
                                        pose_gt=pose_gt,
                                        model=model_ply,
                                        im_size=(640, 480),
                                        K=self.cam_intrinsic[:, 0:3])

            # models_info key is (detected class index + model_id), as in the
            # original; for the single-class case the index is 0.
            diameter = self.models_info[top_det[0] + model_id][
                'diameter'] * self.scale_to_meters

            # metric update
            if reproj_metric <= 5:  # reprojection error <= 5 pixels
                self.sum_metric[7] += 1
            if add_metric <= diameter * 0.1:  # ADD <= 0.1 * diameter
                self.sum_metric[8] += 1
            if add_metric <= diameter * 0.3:  # ADD <= 0.3 * diameter
                self.sum_metric[9] += 1
            if add_metric <= diameter * 0.5:  # ADD <= 0.5 * diameter
                self.sum_metric[10] += 1
            if rot_error < 5:  # 5 degrees
                self.sum_metric[11] += 1
            if trans_error < 5:  # 5 cm
                self.sum_metric[12] += 1
            if (rot_error < 5) and (trans_error < 5):  # 5 degrees and 5 cm
                self.sum_metric[13] += 1
            if cou_metric < 0.5:  # 2D IoU greater than 0.5
                self.sum_metric[14] += 1
            if cou_metric < 0.1:  # 2D IoU greater than 0.9
                self.sum_metric[15] += 1

        # Per-coordinate pixel error of the first detection per image.
        # NOTE(review): the factor 300 presumably maps normalized coordinates
        # to a 300x300 input, and images without a detection are still
        # included here -- "need to be refined" per the original author.
        loc_mae_pixel = np.abs((bb8dets[:, 0, 2:6] - labels[:, 0, 1:5]) *
                               300)
        bb8_err = np.abs((labels[:, 0, 8:24] - bb8dets[:, 0, 6:22]) * 300)
        # Euclidean distance per corner from interleaved (x, y) errors.
        bb8_mae_pixel = np.sqrt(
            np.square(bb8_err[:, 0::2]) + np.square(bb8_err[:, 1::2]))

        valid_count = np.sum(cls_label >= 0)
        box_count = np.sum(cls_label > 0)
        # overall accuracy & object accuracy
        label = cls_label.flatten()
        # in case you have a 'other' class
        label[np.where(label >= cls_prob.shape[1])] = 0
        mask = np.where(label >= 0)[0]
        indices = np.int64(label[mask])
        prob = cls_prob.transpose((0, 2, 1)).reshape(
            (-1, cls_prob.shape[1])).asnumpy()
        prob = prob[mask, indices]
        self.sum_metric[0] += (-np.log(prob + self.eps)).sum()
        self.num_inst[0] += valid_count
        # loc_smoothl1loss
        self.sum_metric[1] += np.sum(loc_loss)
        self.num_inst[1] += box_count * 4
        # loc_mae
        self.sum_metric[2] += np.sum(loc_mae)
        self.num_inst[2] += box_count * 4
        # loc_mae_pixel
        self.sum_metric[3] += np.sum(loc_mae_pixel)
        self.num_inst[3] += loc_mae_pixel.size
        # bb8_smoothl1loss
        self.sum_metric[4] += np.sum(bb8_loss)
        self.num_inst[4] += box_count * 16
        # bb8_mae
        self.sum_metric[5] += np.sum(bb8_mae)
        self.num_inst[5] += box_count * 16
        # bb8_mae_pixel
        self.sum_metric[6] += np.sum(bb8_mae_pixel)
        self.num_inst[6] += bb8_mae_pixel.size
Ejemplo n.º 3
0
    def update(self, labels, preds):
        """Accumulate loss, pixel-error and per-class 6D-pose metrics for a batch.

        :param labels: list whose first element is an NDArray of ground truth
            (batchsize x 8 x 40). Rows with class id < 0 are skipped. Columns
            read: 0 class id, 1:5 box, 8:24 projected corners (x/y interleaved),
            24:40 flattened 4x4 pose.
        :param preds: [cls_prob, loc_loss, cls_label, bb8_loss, loc_pred, bb8_pred,
                       anchors, loc_label, loc_pred_masked, loc_mae, bb8_label,
                       bb8_pred_masked, bb8_mae]

        Scalar slots 0-6 of ``self.sum_metric`` / ``self.num_inst`` are updated
        (losses, MAEs and pixel errors). Pose metrics are tallied per class id
        in the dict attributes ``counts``, ``Reproj`` (list of all reprojection
        errors), ``Reproj5px``, ``ADD0_1/0_3/0_5``, ``re``, ``te``, ``re_te`` and
        ``IoU2D0_5`` .. ``IoU2D0_9``.
        """
        labels = labels[0].asnumpy()  # batchsize x 8 x 40
        # get generated multi label from network
        cls_prob = preds[0]  # batchsize x num_cls x num_anchors
        loc_loss = preds[1].asnumpy()  # smoothL1 loss
        cls_target = preds[2].asnumpy()  # batchsize x num_anchors
        bb8_loss = preds[3].asnumpy()
        loc_pred = preds[4]
        bb8_pred = preds[5]
        anchors = preds[6]
        loc_mae = preds[9].asnumpy()
        bb8_mae = preds[12].asnumpy()

        # basic evaluation
        valid_count = np.sum(cls_target >= 0)
        box_count = np.sum(cls_target > 0)
        # overall accuracy & object accuracy
        label = cls_target.flatten()
        # in case you have a 'other' class
        label[np.where(label >= cls_prob.shape[1])] = 0
        mask = np.where(label >= 0)[0]
        indices = np.int64(label[mask])
        prob = cls_prob.transpose((0, 2, 1)).reshape(
            (-1, cls_prob.shape[1])).asnumpy()
        prob = prob[mask, indices]
        self.sum_metric[0] += (-np.log(prob + self.eps)).sum()
        self.num_inst[0] += valid_count
        # loc_smoothl1loss
        self.sum_metric[1] += np.sum(loc_loss)
        self.num_inst[1] += box_count * 4
        # loc_mae
        self.sum_metric[2] += np.sum(loc_mae)
        self.num_inst[2] += box_count * 4
        # bb8_smoothl1loss
        self.sum_metric[4] += np.sum(bb8_loss)
        self.num_inst[4] += box_count * 16
        # bb8_mae
        self.sum_metric[5] += np.sum(bb8_mae)
        self.num_inst[5] += box_count * 16

        bb8dets = IndirectBB8MultiBoxDetection(cls_prob,
                                               loc_pred,
                                               bb8_pred,
                                               anchors,
                                               nms_threshold=0.5,
                                               force_suppress=False,
                                               variances=(0.1, 0.1, 0.2, 0.2),
                                               nms_topk=400)
        bb8dets = bb8dets.asnumpy()

        def bump(counter, key):
            """Increment counter[key], creating the entry at 1 if missing."""
            counter[key] = counter.get(key, 0) + 1

        # PLY models loaded during this call, keyed by class id.
        # NOTE(review): assumes pose_error helpers do not mutate model['pts'].
        model_cache = {}

        loc_mae_pixel = []
        bb8_mae_pixel = []

        # pose metrics, adapt to multi-class case
        for sample_det, sample_label in zip(bb8dets, labels):
            for instance_label in sample_label:
                if instance_label[0] < 0:  # padding row, no object
                    continue
                cid = int(instance_label[0])
                # class names look like 'obj_05'; keep the numeric suffix
                # (str.strip('obj_') strips a character set, not a prefix)
                model_id = int(self.classes[cid].split('_')[-1])
                bump(self.counts, cid)

                matches = np.where(sample_det[:, 0] == cid)[0]
                if matches.size == 0:
                    # missed detection: counted in self.counts, no error terms
                    continue
                # only consider the most confident instance of this class
                instance_det = sample_det[matches[0]]

                loc_mae_pixel.append(
                    np.abs((instance_det[2:6] - instance_label[1:5]) * 300.))
                bb8_mae_pixel.append(
                    np.abs((instance_det[6:22] - instance_label[8:24]) * 300.))

                pose_est = self.calculate6Dpose(instance_bb8det=instance_det,
                                                model_id=model_id)
                model_ply = model_cache.get(cid)
                if model_ply is None:
                    model_path = os.path.join(
                        self.LINEMOD_path, 'models',
                        '{}.ply'.format(self.classes[cid]))
                    model_ply = inout.load_ply(model_path)
                    model_ply['pts'] = model_ply['pts'] * self.scale_to_meters
                    model_cache[cid] = model_ply

                pose_gt_transform = np.reshape(instance_label[24:40],
                                               newshape=(4, 4))
                pose_gt = {
                    "R": pose_gt_transform[0:3, 0:3],
                    "t": pose_gt_transform[0:3, 3:4]
                }

                # absolute pose error: degrees and centimetres
                rot_error = pose_error.re(R_est=pose_est["R"],
                                          R_gt=pose_gt["R"]) / np.pi * 180.
                trans_error = pose_error.te(t_est=pose_est["t"],
                                            t_gt=pose_gt["t"]) / 0.01

                # other pose metrics
                if model_id in [10, 11]:
                    # use adi when object is eggbox or glue
                    add_metric = pose_error.adi(pose_est=pose_est,
                                                pose_gt=pose_gt,
                                                model=model_ply)
                else:
                    # use add otherwise
                    add_metric = pose_error.add(pose_est=pose_est,
                                                pose_gt=pose_gt,
                                                model=model_ply)
                reproj_metric = pose_error.reprojectionError(
                    pose_est=pose_est,
                    pose_gt=pose_gt,
                    model=model_ply,
                    K=self.cam_intrinsic[:, 0:3])
                # cou is the complement over union, i.e. 1 - 2D IoU
                cou_metric = pose_error.cou(pose_est=pose_est,
                                            pose_gt=pose_gt,
                                            model=model_ply,
                                            im_size=(640, 480),
                                            K=self.cam_intrinsic[:, 0:3])

                # record all the Reproj. error to plot curve
                self.Reproj.setdefault(cid, []).append(reproj_metric)

                # metric update
                if reproj_metric <= 5:  # reprojection error <= 5 pixels
                    bump(self.Reproj5px, cid)

                diameter = self.models_info[model_id][
                    'diameter'] * self.scale_to_meters
                if add_metric <= diameter * 0.1:  # ADD <= 0.1 * diameter
                    bump(self.ADD0_1, cid)
                if add_metric <= diameter * 0.3:  # ADD <= 0.3 * diameter
                    bump(self.ADD0_3, cid)
                if add_metric <= diameter * 0.5:  # ADD <= 0.5 * diameter
                    bump(self.ADD0_5, cid)

                if rot_error < 5:  # 5 degrees
                    bump(self.re, cid)
                if trans_error < 5:  # 5 cm
                    bump(self.te, cid)
                if (rot_error < 5) and (trans_error < 5):  # 5 deg and 5 cm
                    bump(self.re_te, cid)

                # cou < 1 - t  <=>  2D IoU > t
                for cou_thresh, counter in ((0.5, self.IoU2D0_5),
                                            (0.4, self.IoU2D0_6),
                                            (0.3, self.IoU2D0_7),
                                            (0.2, self.IoU2D0_8),
                                            (0.1, self.IoU2D0_9)):
                    if cou_metric < cou_thresh:
                        bump(counter, cid)

        # reshape keeps the arrays 2-D when nothing was matched in the batch;
        # np.array([]) is 1-D and the column slicing below would raise.
        loc_err = np.asarray(loc_mae_pixel).reshape(-1, 4)
        loc_mae_pixel = np.sqrt(
            np.square(loc_err[:, 0::2]) + np.square(loc_err[:, 1::2]))
        bb8_err = np.asarray(bb8_mae_pixel).reshape(-1, 16)
        bb8_mae_pixel = np.sqrt(
            np.square(bb8_err[:, 0::2]) + np.square(bb8_err[:, 1::2]))

        # loc_mae_pixel
        self.sum_metric[3] += np.sum(loc_mae_pixel)
        self.num_inst[3] += loc_mae_pixel.size
        # bb8_mae_pixel
        self.sum_metric[6] += np.sum(bb8_mae_pixel)
        self.num_inst[6] += bb8_mae_pixel.size
Ejemplo n.º 4
0
    def update(self, labels, preds):
        """
        :param preds: [cls_prob, loc_loss, cls_label, bb8_loss, loc_pred, bb8_pred,
                           anchors, loc_label, loc_pred_masked, loc_mae, bb8_label, bb8_pred_masked, bb8_mae]
        Implementation of updating metrics
        """
        def iou(x, ys):
            """
            Calculate intersection-over-union overlap
            Params:
            ----------
            x : numpy.array
                single box [xmin, ymin ,xmax, ymax]
            ys : numpy.array
                multiple box [[xmin, ymin, xmax, ymax], [...], ]
            Returns:
            -----------
            numpy.array
                [iou1, iou2, ...], size == ys.shape[0]
            """
            ixmin = np.maximum(ys[:, 0], x[0])
            iymin = np.maximum(ys[:, 1], x[1])
            ixmax = np.minimum(ys[:, 2], x[2])
            iymax = np.minimum(ys[:, 3], x[3])
            iw = np.maximum(ixmax - ixmin, 0.)
            ih = np.maximum(iymax - iymin, 0.)
            inters = iw * ih
            uni = (x[2] - x[0]) * (x[3] - x[1]) + (ys[:, 2] - ys[:, 0]) * \
                (ys[:, 3] - ys[:, 1]) - inters
            ious = inters / uni
            ious[uni < 1e-12] = 0  # in case bad boxes
            return ious

        labels = labels[0].asnumpy()  # batchsize x 8 x 40
        # get generated multi label from network
        cls_prob = preds[0]  # batchsize x num_cls x num_anchors
        loc_loss = preds[1].asnumpy()  # smoothL1 loss
        loc_loss_in_use = loc_loss[loc_loss.nonzero()]
        cls_label = preds[2].asnumpy()  # batchsize x num_anchors
        bb8_loss = preds[3].asnumpy()
        loc_pred = preds[4]
        bb8_pred = preds[5]
        anchors = preds[6]
        # anchor_in_use = anchors[anchors.nonzero()]

        # basic evaluation, adapt to multi-class
        loc_label = preds[7].asnumpy()
        loc_label_in_use = loc_label[loc_label.nonzero()]
        loc_pred_masked = preds[8].asnumpy()
        loc_pred_in_use = loc_pred_masked[loc_pred_masked.nonzero()]
        loc_mae = preds[9].asnumpy()
        loc_mae_in_use = loc_mae[loc_mae.nonzero()]
        bb8_label = preds[10].asnumpy()
        bb8_label_in_use = bb8_label[bb8_label.nonzero()]
        bb8_pred_masked = preds[11].asnumpy()
        bb8_pred_in_use = bb8_pred_masked[bb8_pred_masked.nonzero()]
        bb8_mae = preds[12].asnumpy()
        bb8_mae_in_use = bb8_mae[bb8_mae.nonzero()]

        valid_count = np.sum(cls_label >= 0)
        box_count = np.sum(cls_label > 0)
        # overall accuracy & object accuracy
        label = cls_label.flatten()
        # in case you have a 'other' class
        label[np.where(label >= cls_prob.shape[1])] = 0
        mask = np.where(label >= 0)[0]
        indices = np.int64(label[mask])
        prob = cls_prob.transpose((0, 2, 1)).reshape(
            (-1, cls_prob.shape[1])).asnumpy()
        prob = prob[mask, indices]
        self.sum_metric[0] += (-np.log(prob + self.eps)).sum()
        self.num_inst[0] += valid_count
        # loc_smoothl1loss
        self.sum_metric[1] += np.sum(loc_loss)
        self.num_inst[1] += box_count * 4
        # loc_mae
        self.sum_metric[2] += np.sum(loc_mae)
        self.num_inst[2] += box_count * 4
        # bb8_smoothl1loss
        self.sum_metric[4] += np.sum(bb8_loss)
        self.num_inst[4] += box_count * 16
        # bb8_mae
        self.sum_metric[5] += np.sum(bb8_mae)
        self.num_inst[5] += box_count * 16

        bb8dets = BB8MultiBoxDetection(cls_prob,
                                       loc_pred,
                                       bb8_pred,
                                       anchors,
                                       nms_threshold=0.5,
                                       force_suppress=False,
                                       variances=(0.1, 0.1, 0.2, 0.2),
                                       nms_topk=400)
        bb8dets = bb8dets.asnumpy()

        # model_id = int(self.classes[0].strip("obj_"))
        #
        # for nbatch in range(bb8dets.shape[0]):
        #
        #     self.num_inst[7] += 1
        #     self.num_inst[8] += 1
        #     self.num_inst[9] += 1
        #     self.num_inst[10] += 1
        #     self.num_inst[11] += 1
        #     self.num_inst[12] += 1
        #     self.num_inst[13] += 1
        #     self.num_inst[14] += 1
        #     self.num_inst[15] += 1
        #
        #     if bb8dets[nbatch, 0, 0] == -1:
        #         continue
        #     else:
        #         # for LINEMOD dataset, for each image only select the first det
        #         # self.validate6Dpose(gt_pose=labels[nbatch, 0, 24:40], instance_bb8det=bb8dets[nbatch, 0, :], model_id=model_id)
        #         pose_est = self.calculate6Dpose(instance_bb8det=bb8dets[nbatch, 0, :], model_id=model_id)
        #         model_path = '/data/ZHANGXIN/DATASETS/SIXD_CHALLENGE/LINEMOD/models/' + self.classes[int(bb8dets[nbatch, 0, 0])] + '.ply'
        #         model_ply = inout.load_ply(model_path)
        #         model_ply['pts'] = model_ply['pts'] * self.scale_to_meters
        #         pose_gt_transform = np.reshape(labels[nbatch, 0, 24:40], newshape=(4, 4))
        #         pose_gt = {"R": pose_gt_transform[0:3, 0:3],
        #         "t": pose_gt_transform[0:3, 3:4]}
        #
        #         # absolute pose error
        #         rot_error = pose_error.re(R_est=pose_est["R"], R_gt=pose_gt["R"]) / np.pi * 180.
        #         trans_error = pose_error.te(t_est=pose_est["t"], t_gt=pose_gt["t"]) / 0.01
        #
        #         # other pose metrics
        #         add_metric = pose_error.add(pose_est=pose_est, pose_gt=pose_gt, model=model_ply)    # use adi when object is eggbox or glue
        #         reproj_metric = pose_error.reprojectionError(pose_est=pose_est, pose_gt=pose_gt,
        #                                                      model=model_ply, K=self.cam_intrinsic[:, 0:3])
        #         cou_metric = pose_error.cou(pose_est=pose_est, pose_gt=pose_gt,
        #                                     model=model_ply, im_size=(640, 480), K=self.cam_intrinsic[:, 0:3])
        #
        #         # metric update
        #         if reproj_metric <= 5:  # reprojection error less than 5 pixels
        #             self.sum_metric[7] += 1
        #
        #         if add_metric <= self.models_info[bb8dets[nbatch, 0, 0] + model_id]['diameter'] * self.scale_to_meters * 0.1:   # ADD metric less than 0.1 * diameter
        #             self.sum_metric[8] += 1
        #
        #         if add_metric <= self.models_info[bb8dets[nbatch, 0, 0] + model_id]['diameter'] * self.scale_to_meters * 0.3:   # ADD metric less than 0.1 * diameter
        #             self.sum_metric[9] += 1
        #
        #         if add_metric <= self.models_info[bb8dets[nbatch, 0, 0] + model_id]['diameter'] * self.scale_to_meters * 0.5:   # ADD metric less than 0.1 * diameter
        #             self.sum_metric[10] += 1
        #
        #         if rot_error < 5:   # 5 degrees
        #             self.sum_metric[11] += 1
        #
        #         if trans_error < 5: # 5 cm
        #             self.sum_metric[12] += 1
        #
        #         if (rot_error < 5) and (trans_error < 5):   # 5 degrees and 5 cm
        #             self.sum_metric[13] += 1
        #
        #         if cou_metric < 0.5:    # 2D IoU greater than 0.5
        #             self.sum_metric[14] += 1
        #
        #         if cou_metric < 0.1:    # 2D IoU larger than 0.9
        #             self.sum_metric[15] += 1
        #

        # for each class, only consider the most confident instance
        loc_mae_pixel = []
        bb8_mae_pixel = []

        # pose metrics, adapt to multi-class case
        for sampleDet, sampleLabel in zip(bb8dets, labels):
            # calculate for each class
            for instanceLabel in sampleLabel:
                if instanceLabel[0] < 0:
                    continue
                else:
                    cid = instanceLabel[0].astype(np.int16)
                    model_id = int(self.classes[cid].strip("obj_"))
                    indices = np.where(sampleDet[:, 0] == cid)[0]

                    if cid in self.counts:
                        self.counts[cid] += 1
                    else:
                        self.counts[cid] = 1

                    if indices.size > 0:
                        instanceDet = sampleDet[indices[
                            0]]  # only consider the most confident instance

                        loc_mae_pixel.append(
                            np.abs(
                                (instanceDet[2:6] - instanceLabel[1:5]) * 300))
                        bb8_mae_pixel.append(
                            np.abs((instanceDet[6:22] - instanceLabel[8:24]) *
                                   300))

                        pose_est = self.calculate6Dpose(
                            instance_bb8det=instanceDet, model_id=model_id)
                        model_path = os.path.join(
                            self.LINEMOD_path, 'models',
                            '{}.ply'.format(self.classes[cid]))
                        model_ply = inout.load_ply(model_path)
                        model_ply[
                            'pts'] = model_ply['pts'] * self.scale_to_meters
                        pose_gt_transform = np.reshape(instanceLabel[24:40],
                                                       newshape=(4, 4))
                        pose_gt = {
                            "R": pose_gt_transform[0:3, 0:3],
                            "t": pose_gt_transform[0:3, 3:4]
                        }

                        # absolute pose error
                        rot_error = pose_error.re(
                            R_est=pose_est["R"],
                            R_gt=pose_gt["R"]) / np.pi * 180.
                        trans_error = pose_error.te(t_est=pose_est["t"],
                                                    t_gt=pose_gt["t"]) / 0.01

                        # other pose metrics
                        if model_id in [10, 11]:
                            add_metric = pose_error.adi(
                                pose_est=pose_est,
                                pose_gt=pose_gt,
                                model=model_ply
                            )  # use adi when object is eggbox or glue
                        else:
                            add_metric = pose_error.add(
                                pose_est=pose_est,
                                pose_gt=pose_gt,
                                model=model_ply
                            )  # use adi when object is eggbox or glue

                        reproj_metric = pose_error.reprojectionError(
                            pose_est=pose_est,
                            pose_gt=pose_gt,
                            model=model_ply,
                            K=self.cam_intrinsic[:, 0:3])
                        cou_metric = pose_error.cou(pose_est=pose_est,
                                                    pose_gt=pose_gt,
                                                    model=model_ply,
                                                    im_size=(640, 480),
                                                    K=self.cam_intrinsic[:,
                                                                         0:3])

                        if cid not in self.Reproj:
                            self.Reproj[cid] = [reproj_metric]
                        else:
                            assert cid in self.counts
                            self.Reproj[cid] += [reproj_metric]

                        # metric update
                        if reproj_metric <= 5:  # reprojection error less than 5 pixels
                            if cid not in self.Reproj5px:
                                self.Reproj5px[cid] = 1
                            else:
                                assert cid in self.counts
                                self.Reproj5px[cid] += 1

                        if add_metric <= self.models_info[model_id][
                                'diameter'] * self.scale_to_meters * 0.1:  # ADD metric less than 0.1 * diameter
                            if cid not in self.ADD0_1:
                                self.ADD0_1[cid] = 1
                            else:
                                assert cid in self.counts
                                self.ADD0_1[cid] += 1

                        if add_metric <= self.models_info[model_id][
                                'diameter'] * self.scale_to_meters * 0.3:  # ADD metric less than 0.3 * diameter
                            if cid not in self.ADD0_3:
                                self.ADD0_3[cid] = 1
                            else:
                                assert cid in self.counts
                                self.ADD0_3[cid] += 1

                        if add_metric <= self.models_info[model_id][
                                'diameter'] * self.scale_to_meters * 0.5:  # ADD metric less than 0.5 * diameter
                            if cid not in self.ADD0_5:
                                self.ADD0_5[cid] = 1
                            else:
                                assert cid in self.counts
                                self.ADD0_5[cid] += 1

                        if rot_error < 5:  # 5 degrees
                            if cid not in self.re:
                                self.re[cid] = 1
                            else:
                                assert cid in self.counts
                                self.re[cid] += 1

                        if trans_error < 5:  # 5 cm
                            if cid not in self.te:
                                self.te[cid] = 1
                            else:
                                assert cid in self.counts
                                self.te[cid] += 1

                        if (rot_error < 5) and (trans_error <
                                                5):  # 5 degrees and 5 cm
                            if cid not in self.re_te:
                                self.re_te[cid] = 1
                            else:
                                assert cid in self.counts
                                self.re_te[cid] += 1

                        if cou_metric < 0.5:  # 2D IoU greater than 0.5
                            if cid not in self.IoU2D0_5:
                                self.IoU2D0_5[cid] = 1
                            else:
                                assert cid in self.counts
                                self.IoU2D0_5[cid] += 1

                        if cou_metric < 0.1:  # 2D IoU larger than 0.9
                            if cid not in self.IoU2D0_9:
                                self.IoU2D0_9[cid] = 1
                            else:
                                assert cid in self.counts
                                self.IoU2D0_9[cid] += 1

        loc_mae_pixel = np.array(loc_mae_pixel)
        bb8_mae_pixel = np.array(bb8_mae_pixel)
        bb8_mae_pixel_x = bb8_mae_pixel[:, [0, 2, 4, 6, 8, 10, 12, 14]]
        bb8_mae_pixel_y = bb8_mae_pixel[:, [1, 3, 5, 7, 9, 11, 13, 15]]
        bb8_mae_pixel = np.sqrt(
            np.square(bb8_mae_pixel_x) + np.square(bb8_mae_pixel_y))

        # loc_mae_pixel
        self.sum_metric[3] += np.sum(loc_mae_pixel)
        self.num_inst[3] += loc_mae_pixel.size
        # bb8_mae_pixel
        self.sum_metric[6] += np.sum(bb8_mae_pixel)
        self.num_inst[6] += bb8_mae_pixel.size