import random
import warnings

import torch
import torch.nn.functional as F

# BBox, iou_mn, coords_to_target2, select, sample, expand_last_dim, nms and
# soft_nms_cpu are helpers assumed to come from the surrounding repo.


def match_anchors2(anns, a_xywh, a_ltrb, pos_thresh=0.7, neg_thresh=0.3,
                   get_label=lambda x: x['category_id'], debug=False):
    num_anchors = len(a_xywh)
    if len(anns) == 0:
        loc_t = a_xywh.new_zeros(num_anchors, 4)
        cls_t = loc_t.new_zeros(num_anchors, dtype=torch.long)
        ignore = loc_t.new_zeros(num_anchors, dtype=torch.uint8)
        return loc_t, cls_t, ignore

    bboxes = a_xywh.new_tensor([ann['bbox'] for ann in anns])
    bboxes = BBox.convert(bboxes, format=BBox.LTWH, to=BBox.XYWH, inplace=True)
    labels = a_xywh.new_tensor([get_label(ann) for ann in anns], dtype=torch.long)

    bboxes_ltrb = BBox.convert(bboxes, BBox.XYWH, BBox.LTRB)
    ious = iou_mn(bboxes_ltrb, a_ltrb)

    # Rule 1: an anchor is positive if its IoU with some ground truth exceeds
    # pos_thresh; it takes that ground truth's label and box target.
    pos = ious > pos_thresh
    cls_t, indices = (pos.long() * labels[:, None]).max(dim=0)

    loc_t_all = coords_to_target2(bboxes, a_xywh)
    loc_t = select(loc_t_all, 0, indices)

    # Rule 2: each ground truth claims its highest-IoU anchor as positive,
    # even if that IoU falls below pos_thresh.
    max_ious, max_indices = ious.max(dim=1)
    if debug:
        print(max_ious.tolist())
    loc_t[max_indices] = select(loc_t_all, 1, max_indices)
    cls_t[max_indices] = labels

    # Anchors that are not positive but overlap some ground truth by at least
    # neg_thresh are ignored rather than treated as negatives.
    ignore = (cls_t == 0) & ((ious >= neg_thresh).sum(dim=0) != 0)
    return loc_t, cls_t, ignore
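
# A minimal sketch of the two matching rules with plain torch only (the `ious`
# and `labels` tensors are made-up stand-ins; iou_mn/select above are repo
# helpers). Rule 1 marks anchors over pos_thresh as positive; rule 2
# force-matches each ground truth to its best anchor even below the threshold.
import torch

ious = torch.tensor([[0.72, 0.10, 0.40],    # GT 0 vs. 3 anchors
                     [0.20, 0.55, 0.30]])   # GT 1 vs. 3 anchors
labels = torch.tensor([3, 7])               # category ids of the two GTs

pos = ious > 0.7
cls_t, _ = (pos.long() * labels[:, None]).max(dim=0)  # rule 1 -> [3, 0, 0]
max_ious, max_anchors = ious.max(dim=1)               # best anchor per GT
cls_t[max_anchors] = labels                           # rule 2 -> [3, 7, 0]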

def match_rois2(anns, rois, pos_thresh=0.5, n_samples=64, pos_neg_ratio=1 / 3):
    num_rois = len(rois)
    if len(anns) == 0:
        loc_t = rois.new_zeros(num_rois, 4)
        cls_t = loc_t.new_zeros(num_rois, dtype=torch.long)
        # Also return an index tensor so the arity matches the main path.
        indices = torch.arange(num_rois, device=rois.device)
        return loc_t, cls_t, indices

    rois_xywh = BBox.convert(rois, BBox.LTRB, BBox.XYWH)
    bboxes = rois.new_tensor([ann['bbox'] for ann in anns])
    bboxes = BBox.convert(bboxes, format=BBox.LTWH, to=BBox.XYWH, inplace=True)
    labels = rois.new_tensor([ann['category_id'] for ann in anns], dtype=torch.long)

    bboxes_ltrb = BBox.convert(bboxes, BBox.XYWH, BBox.LTRB)
    ious = iou_mn(bboxes_ltrb, rois)

    pos = ious > pos_thresh
    cls_t, ann_indices = (pos.long() * labels[:, None]).max(dim=0)

    loc_t_all = coords_to_target2(bboxes, rois_xywh)
    loc_t = select(loc_t_all, 0, ann_indices)

    # Each ground truth also claims its highest-IoU RoI.
    max_ious, max_indices = ious.max(dim=1)
    loc_t[max_indices] = select(loc_t_all, 1, max_indices)
    cls_t[max_indices] = labels

    # Subsample a fixed-size batch of RoIs with the given positive:negative ratio.
    pos = cls_t != 0
    n_pos = int(n_samples * pos_neg_ratio / (pos_neg_ratio + 1))
    n_neg = n_samples - n_pos
    pos_indices = sample(torch.nonzero(pos).squeeze(1), n_pos)
    neg_indices = sample(torch.nonzero(~pos).squeeze(1), n_neg)

    # Localization targets are only kept for the sampled positives.
    loc_t = loc_t[pos_indices]
    indices = torch.cat([pos_indices, neg_indices], dim=0)
    cls_t = cls_t[indices]
    return loc_t, cls_t, indices
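
# A quick check of the sampling arithmetic above, plus a hypothetical stand-in
# for the repo's `sample` helper (assumed to draw without replacement and to
# return fewer items when the pool is smaller than requested).
import torch

n_samples, pos_neg_ratio = 64, 1 / 3
n_pos = int(n_samples * pos_neg_ratio / (pos_neg_ratio + 1))  # 16 positives
n_neg = n_samples - n_pos                                     # 48 negatives


def sample_sketch(indices, n):
    perm = torch.randperm(len(indices))[:n]
    return indices[perm]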

def match_rois2(anns, rois, pos_thresh=0.5, mask_size=(14, 14),
                n_samples=64, pos_neg_ratio=1 / 3):
    # Mask R-CNN variant of match_rois2: also builds per-RoI mask targets.
    num_rois = len(rois)
    if len(anns) == 0:
        loc_t = rois.new_zeros(num_rois, 4)
        cls_t = loc_t.new_zeros(num_rois, dtype=torch.long)
        # Empty mask targets and an index tensor keep the arity consistent
        # with the main path below.
        mask_t = loc_t.new_zeros(0, *mask_size)
        indices = torch.arange(num_rois, device=rois.device)
        return loc_t, cls_t, mask_t, indices

    rois_xywh = BBox.convert(rois, BBox.LTRB, BBox.XYWH)
    bboxes = rois.new_tensor([ann['bbox'] for ann in anns])
    bboxes = BBox.convert(bboxes, format=BBox.LTWH, to=BBox.XYWH, inplace=True)
    labels = rois.new_tensor([ann['category_id'] for ann in anns], dtype=torch.long)

    bboxes_ltrb = BBox.convert(bboxes, BBox.XYWH, BBox.LTRB)
    ious = iou_mn(bboxes_ltrb, rois)

    pos = ious > pos_thresh
    cls_t, ann_indices = (pos.long() * labels[:, None]).max(dim=0)

    loc_t_all = coords_to_target2(bboxes, rois_xywh)
    loc_t = select(loc_t_all, 0, ann_indices)

    max_ious, max_indices = ious.max(dim=1)
    loc_t[max_indices] = select(loc_t_all, 1, max_indices)
    cls_t[max_indices] = labels
    # Record which annotation each force-matched RoI belongs to; the mask
    # targets below are looked up through ann_indices.
    ann_indices[max_indices] = torch.arange(len(anns), device=rois.device)

    pos = cls_t != 0
    n_pos = int(n_samples * pos_neg_ratio / (pos_neg_ratio + 1))
    n_neg = n_samples - n_pos
    pos_indices = sample(torch.nonzero(pos).squeeze(1), n_pos)
    neg_indices = sample(torch.nonzero(~pos).squeeze(1), n_neg)

    loc_t = loc_t[pos_indices]
    indices = torch.cat([pos_indices, neg_indices], dim=0)
    cls_t = cls_t[indices]

    # Build mask targets for the sampled positives: crop the full-image binary
    # mask to the RoI and resize it to mask_size. Size by len(pos_indices),
    # which may be smaller than n_pos when positives are scarce.
    mask_t = loc_t.new_zeros(len(pos_indices), *mask_size)
    for i, ind in enumerate(pos_indices):
        mask = anns[ann_indices[ind]]['segmentation']
        height, width = mask.shape
        l, t, r, b = rois[ind]
        # RoIs are in normalized LTRB coordinates; scale to mask pixels.
        l = max(0, int(l * width))
        t = max(0, int(t * height))
        r = int(r * width)
        b = int(b * height)
        m = mask[t:b, l:r].float()
        m = m.view(1, 1, *m.size())
        m = F.interpolate(m, size=mask_size).squeeze()
        mask_t[i] = m
    return loc_t, cls_t, mask_t, indices
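
# A minimal sketch of the crop-and-resize step in the loop above, with a
# made-up mask and RoI (ann['segmentation'] is assumed to be a binary tensor).
import torch
import torch.nn.functional as F

gt_mask = torch.zeros(100, 100)
gt_mask[30:70, 30:70] = 1                     # a square object
roi = [0.25, 0.25, 0.75, 0.75]                # normalized LTRB

h, w = gt_mask.shape
l, t = max(0, int(roi[0] * w)), max(0, int(roi[1] * h))
r, b = int(roi[2] * w), int(roi[3] * h)
crop = gt_mask[t:b, l:r].float()
target = F.interpolate(crop[None, None], size=(14, 14)).squeeze()  # (14, 14)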

def roi_based_inference(rois, loc_p, cls_p, predict_mask,
                        iou_threshold=0.5, topk=100, nms_method='soft_nms'):
    # Best non-background class and its probability for every RoI.
    scores, labels = torch.softmax(cls_p, dim=1)[:, 1:].max(dim=1)

    num_classes = cls_p.size(1) - 1
    loc_p = expand_last_dim(loc_p, num_classes, 4)
    # Keep only the regression output of each RoI's predicted class.
    loc_p = select(loc_p, 1, labels)

    # Decode the deltas in place against the RoIs (XYWH): center offsets are
    # relative to RoI size, width and height are log-scale factors.
    loc_p[..., :2].mul_(rois[:, 2:]).add_(rois[:, :2])
    loc_p[..., 2:].exp_().mul_(rois[:, 2:])
    bboxes = loc_p

    bboxes = BBox.convert(bboxes, format=BBox.XYWH, to=BBox.LTRB, inplace=True).cpu()
    scores = scores.cpu()

    if nms_method == 'nms':
        indices = nms(bboxes, scores, iou_threshold)
        if len(indices) > topk:
            indices = indices[scores[indices].topk(topk)[1]]
        else:
            warnings.warn("Only %d RoIs left after nms rather than top %d"
                          % (len(indices), topk))
    else:
        indices = soft_nms_cpu(bboxes, scores, iou_threshold, topk)

    bboxes = BBox.convert(bboxes, format=BBox.LTRB, to=BBox.LTWH, inplace=True)

    if predict_mask is not None:
        mask_p = predict_mask(indices)
        masks = (select(mask_p, 1, labels[indices]).sigmoid_() > 0.5).cpu().numpy()

    dets = []
    for i, ind in enumerate(indices):
        det = {
            'image_id': -1,
            'category_id': labels[ind].item() + 1,
            'bbox': bboxes[ind].tolist(),
            'score': scores[ind].item(),
        }
        if predict_mask is not None:
            det['segmentation'] = masks[i]
        dets.append(det)
    return dets
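
# A minimal sketch of the in-place box decoding above, written out explicitly
# with made-up tensors: the first two channels are center offsets relative to
# RoI size, the last two are log-scale width/height factors (RoIs are XYWH).
import torch

rois = torch.tensor([[0.5, 0.5, 0.2, 0.4]])      # cx, cy, w, h
deltas = torch.tensor([[0.1, -0.05, 0.2, 0.0]])  # tx, ty, tw, th

cx = deltas[:, 0] * rois[:, 2] + rois[:, 0]  # 0.5 + 0.1 * 0.2  = 0.52
cy = deltas[:, 1] * rois[:, 3] + rois[:, 1]  # 0.5 - 0.05 * 0.4 = 0.48
w = deltas[:, 2].exp() * rois[:, 2]          # 0.2 * e^0.2 ~ 0.244
h = deltas[:, 3].exp() * rois[:, 3]          # 0.4 * e^0   = 0.4
boxes = torch.stack([cx, cy, w, h], dim=1)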

def forward(self, loc_p, cls_p, mask_p, loc_t, cls_t, mask_t,
            rpn_loc_p, rpn_cls_p, rpn_loc_t, rpn_cls_t, ignore):
    rpn_loc_p = rpn_loc_p.view(-1, 4)
    rpn_cls_p = rpn_cls_p.view(-1, rpn_cls_p.size(-1))
    rpn_loss = self.rpn_loss(rpn_loc_p, rpn_cls_p, rpn_loc_t, rpn_cls_t, ignore)

    num_classes = cls_p.size(-1) - 1
    loc_p = expand_last_dim(loc_p, num_classes, 4)
    # Only positive RoIs contribute to the box and mask losses; class indices
    # are shifted by one because 0 is the background label.
    pos = cls_t != 0
    labels = cls_t[pos] - 1
    loc_p = select(loc_p[pos], 1, labels)
    rcnn_loss = self.rcnn_loss(loc_p, cls_p, loc_t, cls_t)

    # Mask loss: binary cross-entropy on the channel of the ground-truth class.
    mask_p = select(mask_p, 1, labels)
    mask_loss = F.binary_cross_entropy_with_logits(mask_p, mask_t)
    if random.random() < self.p:
        print("mask: %.4f" % mask_loss.item())

    loss = rpn_loss + rcnn_loss + mask_loss
    return loss
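
# A minimal sketch of the per-class selection used throughout, assuming the
# repo's `select(t, 1, labels)` picks, for each row, the slice of dim 1 indexed
# by that row's label. The equivalent with plain torch.gather:
import torch

loc_p = torch.randn(8, 5, 4)                # 8 positive RoIs, 5 classes, 4 coords
labels = torch.randint(0, 5, (8,))          # ground-truth class per RoI
index = labels[:, None, None].expand(-1, 1, 4)
picked = loc_p.gather(1, index).squeeze(1)  # (8, 4): loc_p[i, labels[i], :]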