示例#1
0
    def process_inst(self, classes, scores, pred_inst, img_shape, ori_shape,
                     score_thres=0.05):
        """
        Generate the final predictions for Things from raw instance logits.

        Each object's score is rescored by the mean probability inside its own
        binary mask. Objects scoring below ``score_thres`` are dropped and at
        most ``self.center_top_num`` of the best remain. Their masks are
        upsampled by ``self.common_stride``, cropped to
        ``img_shape[0] x img_shape[1]`` and resized to ``ori_shape``.

        Args:
            classes: predicted classes of Things, one entry per object.
            scores: predicted scores of Things, one entry per object. The
                caller's tensor is not modified.
            pred_inst: predicted instance logits of Things. Element ``[0]`` of
                the sigmoid output is used, so a leading batch dimension of
                size 1 is assumed.
            img_shape: input image shape; the upsampled masks are cropped to
                its first two entries.
            ori_shape: original image shape the masks are finally resized to.
            score_thres: minimum rescored score for an object to be kept
                (default 0.05).

        Returns:
            result_instance: ``Instances`` holding the kept Things (empty
                fields when nothing passes the score filter).
            pred_mask: binary masks of the kept Things at ``ori_shape``. When
                nothing is kept, this is the low-resolution masks of all
                objects, sorted by score.
            classes: kept object classes as int32, or None when nothing is kept.
            scores: kept rescored object scores, or None when nothing is kept.
        """
        probs = pred_inst.sigmoid()[0]
        low_res_mask = probs > self.inst_thres

        # Object rescore: weight each score by the mean probability inside the
        # object's own mask. The epsilon guards against empty masks.
        mask_area = low_res_mask.sum((1, 2)).float() + 1e-6
        seg_score = (probs * low_res_mask.float()).sum((1, 2)) / mask_area
        # Out-of-place so the caller's ``scores`` tensor is left untouched.
        scores = scores * seg_score

        # Sort once, best first. Boolean filtering preserves this order, so
        # the top-k selection below is a plain slice.
        order = torch.argsort(scores, descending=True)
        probs = probs[order]
        low_res_mask = low_res_mask[order]
        scores = scores[order]
        classes = classes[order]

        # Object score filter.
        keep = scores >= score_thres
        if keep.sum() == 0:
            result_instance = Instances(ori_shape, pred_masks=[], pred_boxes=[],
                                        pred_classes=[], scores=[])
            return result_instance, low_res_mask, None, None

        # Keep the top-k objects (the tensors are already sorted by score).
        top_k = self.center_top_num
        probs = probs[keep][:top_k]
        scores = scores[keep][:top_k].reshape(-1)
        classes = classes[keep][:top_k].reshape(-1).to(torch.int32)

        # Upsample by common_stride, crop to img_shape, then resize to the
        # original image resolution.
        probs = F.interpolate(probs.unsqueeze(0),
                              scale_factor=self.common_stride,
                              mode="bilinear",
                              align_corners=False)[..., :img_shape[0], :img_shape[1]]
        probs = F.interpolate(probs,
                              size=ori_shape,
                              mode="bilinear",
                              align_corners=False)[0]

        pred_mask = probs > self.inst_thres
        pred_bitinst = BitMasks(pred_mask)
        result_instance = Instances(ori_shape,
                                    pred_masks=pred_bitinst,
                                    pred_boxes=pred_bitinst.get_bounding_boxes(),
                                    pred_classes=classes,
                                    scores=scores)
        return result_instance, pred_mask, classes, scores
示例#2
0
    def process_inst_onnx(self, classes, scores, pred_inst, img_shape, vis=False,
                          ori_shape=None, score_thres=0.05):
        """
        Post-process Things for the ONNX/TensorRT path.

        Scores are rescored by the mean mask probability and objects are
        sorted best first. The instance probabilities are then upsampled by
        ``self.common_stride`` and cropped to ``img_shape[0] x img_shape[1]``.

        With ``vis=False`` nothing is filtered or thresholded. The raw sorted
        probabilities are returned, intended to match the raw output of an
        exported (e.g. TensorRT) engine.

        With ``vis=True`` the objects are filtered by ``score_thres`` and
        reduced to ``self.center_top_num``. They are then resized to
        ``ori_shape`` and packed into an ``Instances`` for visualisation.

        Args:
            classes: predicted classes of Things, one entry per object.
            scores: predicted scores of Things, one entry per object. The
                caller's tensor is not modified.
            pred_inst: predicted instance logits. Element ``[0]`` of the
                sigmoid output is used, so a leading batch dimension of size 1
                is assumed.
            img_shape: input image shape; the upsampled probabilities are
                cropped to its first two entries.
            vis: build a visualisable ``Instances`` result instead of
                returning raw tensors.
            ori_shape: output size used when ``vis`` is True. Defaults to
                ``[720, 1280]``, the value previously hard-coded here.
            score_thres: minimum rescored score kept when ``vis`` is True
                (default 0.05).

        Returns:
            ``{"instances": Instances}`` when ``vis`` is True.

            Otherwise, a tuple ``(pred_inst, classes, scores)``. Here
            ``pred_inst`` holds the upsampled, cropped probabilities of all
            objects with a leading dimension of size 1, and ``classes`` and
            ``scores`` are sorted best first.
        """
        probs = pred_inst.sigmoid()[0]
        low_res_mask = probs > self.inst_thres

        # Object rescore: weight each score by the mean probability inside the
        # object's own mask. The epsilon guards against empty masks.
        mask_area = low_res_mask.sum((1, 2)).float() + 1e-6
        seg_score = (probs * low_res_mask.float()).sum((1, 2)) / mask_area
        # Out-of-place so the caller's ``scores`` tensor is left untouched.
        scores = scores * seg_score

        # torch.sort is used instead of torch.argsort here. It was presumably
        # chosen deliberately for ONNX export (NOTE(review): confirm).
        _, order = torch.sort(scores, descending=True, dim=0)
        probs = probs[order]
        scores = scores[order]
        classes = classes[order]

        if vis:
            if ori_shape is None:
                ori_shape = [720, 1280]
            # Object score filter.
            keep = scores >= score_thres
            if keep.sum() == 0:
                result_instance = Instances(ori_shape, pred_masks=[], pred_boxes=[],
                                            pred_classes=[], scores=[])
                return {'instances': result_instance}
            # Keep the top-k objects. Boolean filtering preserves the sorted
            # order, so a slice is enough.
            top_k = self.center_top_num
            probs = probs[keep][:top_k]
            scores = scores[keep][:top_k].reshape(-1)
            classes = classes[keep][:top_k].reshape(-1).to(torch.int32)

        probs = F.interpolate(probs.unsqueeze(0),
                              scale_factor=self.common_stride,
                              mode="bilinear",
                              align_corners=False)[..., :img_shape[0], :img_shape[1]]
        if not vis:
            # Raw output: no score filtering, no thresholding, batch dim kept.
            return probs, classes, scores

        probs = F.interpolate(probs,
                              size=ori_shape,
                              mode="bilinear",
                              align_corners=False)[0]
        pred_mask = probs > self.inst_thres
        pred_bitinst = BitMasks(pred_mask)
        result_instance = Instances(ori_shape,
                                    pred_masks=pred_bitinst,
                                    pred_boxes=pred_bitinst.get_bounding_boxes(),
                                    pred_classes=classes,
                                    scores=scores)
        return {"instances": result_instance}