def process_inst(self, classes, scores, pred_inst, img_shape, ori_shape):
    """
    Simple post-processing that generates the predictions for Things.

    The instance logits are turned into probabilities and binarised with
    ``self.inst_thres``. Each score is then multiplied by the mean
    probability inside that instance's own mask. Instances scoring below
    0.05 are dropped, and at most ``self.center_top_num`` of the rest are
    kept, highest score first. The surviving probability maps are upsampled
    by ``self.common_stride``, cropped to ``img_shape`` and resized to
    ``ori_shape``.

    Args:
        classes: predicted classes of Things, one entry per instance.
        scores: predicted scores of Things, one entry per instance. The
            caller's tensor is not modified.
        pred_inst: predicted instance logits of Things. Only index 0 of the
            leading dimension is used; what remains must be (N, h, w).
        img_shape: (height, width) of the network input; upsampled masks are
            cropped to this size.
        ori_shape: (height, width) of the original image.

    Returns:
        result_instance: ``Instances`` with the kept masks, boxes, classes
            and scores (empty lists when nothing passes the score filter).
        pred_mask: binary masks of the kept instances at ``ori_shape``. When
            nothing passes the score filter, this is instead the
            low-resolution masks of all instances, sorted by score.
        classes: kept object classes as int32, or None when nothing is kept.
        scores: kept, re-weighted object scores, or None when nothing is kept.
    """
    pred_inst = pred_inst.sigmoid()[0]
    pred_mask = pred_inst > self.inst_thres

    # Object rescore. The 1e-6 guards against division by zero for
    # instances whose mask is empty.
    sum_masks = pred_mask.sum((1, 2)).float() + 1e-6
    seg_score = (pred_inst * pred_mask.float()).sum((1, 2)) / sum_masks
    # Out-of-place on purpose: ``scores *= seg_score`` would overwrite the
    # tensor owned by the caller.
    scores = scores * seg_score

    order = torch.argsort(scores, descending=True)
    pred_inst = pred_inst[order]
    pred_mask = pred_mask[order]
    scores = scores[order]
    classes = classes[order]

    # Object score filter.
    keep = scores >= 0.05
    if keep.sum() == 0:
        result_instance = Instances(
            ori_shape, pred_masks=[], pred_boxes=[], pred_classes=[], scores=[]
        )
        return result_instance, pred_mask, None, None
    pred_inst = pred_inst[keep]
    scores = scores[keep]
    classes = classes[keep]

    # Sort and keep top-k.
    order = torch.argsort(scores, descending=True)[:self.center_top_num]
    pred_inst = pred_inst[order]
    scores = scores[order].reshape(-1)
    classes = classes[order].reshape(-1).to(torch.int32)

    # Upsample to network-input resolution, crop away padding, then resize
    # to the original image size.
    pred_inst = F.interpolate(
        pred_inst.unsqueeze(0),
        scale_factor=self.common_stride,
        mode="bilinear",
        align_corners=False,
    )[..., :img_shape[0], :img_shape[1]]
    pred_inst = F.interpolate(
        pred_inst, size=ori_shape, mode="bilinear", align_corners=False
    )[0]

    pred_mask = pred_inst > self.inst_thres
    pred_bitinst = BitMasks(pred_mask)
    result_instance = Instances(
        ori_shape,
        pred_masks=pred_bitinst,
        pred_boxes=pred_bitinst.get_bounding_boxes(),
        pred_classes=classes,
        scores=scores,
    )
    return result_instance, pred_mask, classes, scores
def process_inst_onnx(self, classes, scores, pred_inst, img_shape, vis=False):
    """
    Post-process Things predictions in a form usable for export/comparison.

    The instance logits are turned into probabilities. Each score is
    multiplied by the mean probability inside that instance's own mask
    (threshold ``self.inst_thres``). Instances scoring below 0.05 are
    dropped, and at most ``self.center_top_num`` of the rest are kept,
    highest score first. The probability maps are upsampled by
    ``self.common_stride`` and cropped to ``img_shape``.

    Args:
        classes: predicted classes of Things, one entry per instance.
        scores: predicted scores of Things, one entry per instance. The
            caller's tensor is not modified.
        pred_inst: predicted instance logits of Things. Only index 0 of the
            leading dimension is used; what remains must be (N, h, w).
        img_shape: (height, width) of the network input; upsampled masks are
            cropped to this size.
        vis: when True, additionally resize the masks to a hard-coded
            720x1280 image and wrap everything in an ``Instances``.

    Returns:
        If ``vis`` is True: ``{"instances": Instances}`` sized 720x1280
        (empty lists when nothing passes the score filter).
        Otherwise: a tuple ``(pred_inst, classes, scores)``:
            - ``pred_inst`` holds the upsampled, cropped probabilities with
              shape (1, K, H, W).
            - ``classes`` is int32 with shape (K,).
            - ``scores`` has shape (K,).
            ``K`` may be 0 when nothing passes the score filter.
    """
    pred_inst = pred_inst.sigmoid()[0]
    pred_mask = pred_inst > self.inst_thres

    # Object rescore (out-of-place, so the caller's ``scores`` is untouched).
    # The 1e-6 guards against division by zero for empty masks.
    sum_masks = pred_mask.sum((1, 2)).float() + 1e-6
    seg_score = (pred_inst * pred_mask.float()).sum((1, 2)) / sum_masks
    scores = scores * seg_score

    # torch.sort is used instead of torch.argsort; only the indices matter.
    _, order = torch.sort(scores, descending=True, dim=0)
    pred_inst = pred_inst[order]
    scores = scores[order]
    classes = classes[order]

    # Hard-coded image size, only used by the ``vis`` outputs.
    ori_shape = [720, 1280]

    # Object score filter.
    keep = scores >= 0.05
    if keep.sum() == 0:
        if vis:
            result_instance = Instances(
                ori_shape, pred_masks=[], pred_boxes=[], pred_classes=[], scores=[]
            )
            return {'instances': result_instance}
        # Previously this path reached ``Instances(ori_shape, ...)`` with
        # ``ori_shape`` unassigned (UnboundLocalError). Return empty outputs
        # laid out like the normal non-vis result instead. F.interpolate
        # rejects zero-channel inputs, so build the empty tensor at the size
        # interpolation would produce and crop it the same way.
        height, width = pred_inst.shape[-2:]
        empty_inst = pred_inst.new_zeros(
            (1, 0, int(height * self.common_stride), int(width * self.common_stride))
        )[..., :img_shape[0], :img_shape[1]]
        return empty_inst, classes[:0].to(torch.int32), scores[:0]
    pred_inst = pred_inst[keep]
    scores = scores[keep]
    classes = classes[keep]

    # Sort and keep top-k.
    order = torch.argsort(scores, descending=True)[:self.center_top_num]
    pred_inst = pred_inst[order]
    scores = scores[order].reshape(-1)
    classes = classes[order].reshape(-1).to(torch.int32)

    # Upsample to network-input resolution and crop away padding.
    pred_inst = F.interpolate(
        pred_inst.unsqueeze(0),
        scale_factor=self.common_stride,
        mode="bilinear",
        align_corners=False,
    )[..., :img_shape[0], :img_shape[1]]

    if not vis:
        # Raw instance probabilities; the original author notes this output
        # should match the TensorRT engine's output.
        return pred_inst, classes, scores

    pred_inst = F.interpolate(
        pred_inst, size=ori_shape, mode="bilinear", align_corners=False
    )[0]
    pred_mask = pred_inst > self.inst_thres
    pred_bitinst = BitMasks(pred_mask)
    result_instance = Instances(
        ori_shape,
        pred_masks=pred_bitinst,
        pred_boxes=pred_bitinst.get_bounding_boxes(),
        pred_classes=classes,
        scores=scores,
    )
    return {"instances": result_instance}