def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores",
                nms_type='nms', vote_thresh=0.65):
    """
    Performs non-maximum suppression on a boxlist, with scores specified
    in a boxlist field via score_field.

    Arguments:
        boxlist (BoxList)
        nms_thresh (float)
        max_proposals (int): if > 0, then only the top max_proposals are kept
            after non-maximum suppression
        score_field (str)
        nms_type (str): 'nms' for hard NMS, 'vote' for box voting,
            any other value for soft box voting
        vote_thresh (float): IoU threshold used by the voting variants
    """
    if nms_thresh <= 0:
        return boxlist
    mode = boxlist.mode
    boxlist = boxlist.convert("xyxy")
    boxes = boxlist.bbox
    score = boxlist.get_field(score_field)

    if nms_type == 'nms':
        # Hard NMS: keep the highest-scoring boxes and drop overlapping ones.
        keep = _box_nms(boxes, score, nms_thresh)
        if max_proposals > 0:
            keep = keep[:max_proposals]
        boxlist = boxlist[keep]
    else:
        # Voting variants: combine overlapping detections instead of discarding
        # them; the merged boxes and scores replace the originals in place.
        if nms_type == 'vote':
            boxes_vote, scores_vote = bbox_vote(boxes, score, vote_thresh)
        else:
            boxes_vote, scores_vote = soft_bbox_vote(boxes, score, vote_thresh)
        if len(boxes_vote) > 0:
            boxlist.bbox = boxes_vote
            boxlist.extra_fields['scores'] = scores_vote
    return boxlist.convert(mode)
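
# Hedged usage sketch (not part of the original module): illustrates how the
# vote-aware boxlist_nms above might be called. The BoxList passed in is
# assumed to already carry a "scores" field, as elsewhere in this codebase;
# the helper name _example_vote_nms is hypothetical.
def _example_vote_nms(boxlist):
    # Hard NMS, keeping at most 100 detections after suppression.
    hard = boxlist_nms(boxlist, nms_thresh=0.5, max_proposals=100, nms_type='nms')
    # Box voting: overlapping detections are merged rather than discarded;
    # any other nms_type string falls through to soft_bbox_vote.
    voted = boxlist_nms(boxlist, nms_thresh=0.5, nms_type='vote', vote_thresh=0.65)
    return hard, voted
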
def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
    """
    Performs non-maximum suppression on a boxlist, with scores specified
    in a boxlist field via score_field.

    Arguments:
        boxlist (BoxList)
        nms_thresh (float)
        max_proposals (int): if > 0, then only the top max_proposals are kept
            after non-maximum suppression
        score_field (str)
    """
    if nms_thresh <= 0:
        return boxlist
    mode = boxlist.mode
    boxlist = boxlist.convert("xyxy")
    boxes = boxlist.bbox
    score = boxlist.get_field(score_field)
    keep = _box_nms(boxes, score, nms_thresh)
    if max_proposals > 0:
        keep = keep[:max_proposals]
    boxlist = boxlist[keep]
    return boxlist.convert(mode)
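
# Hedged reference sketch (not the repository's implementation): _box_nms used
# above is normally a compiled C++/CUDA extension. The pure-PyTorch function
# below shows the greedy suppression it is expected to perform and can serve as
# a dependency-free stand-in; the name _reference_nms is hypothetical.
def _reference_nms(boxes, scores, iou_threshold):
    # boxes: (N, 4) in xyxy format; scores: (N,). Returns kept indices,
    # ordered by descending score.
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        rest = order[1:]
        # IoU of the current top-scoring box against the remaining candidates.
        xy1 = torch.max(boxes[i, :2], boxes[rest, :2])
        xy2 = torch.min(boxes[i, 2:], boxes[rest, 2:])
        wh = (xy2 - xy1).clamp(min=0)
        inter = wh[:, 0] * wh[:, 1]
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_r = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + area_r - inter)
        order = rest[iou <= iou_threshold]
    return torch.as_tensor(keep, dtype=torch.long, device=boxes.device)
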
def __call__(self, locations, box_cls, box_regression, centerness, cof_preds,
             feat_mask, targets):
    """
    Arguments:
        locations (list[BoxList])
        box_cls (list[Tensor])
        box_regression (list[Tensor])
        centerness (list[Tensor])
        cof_preds (list[Tensor]): per-level mask coefficient predictions
        feat_mask (Tensor): mask feature maps, one per image
        targets (list[BoxList])

    Returns:
        cls_loss (Tensor)
        reg_loss (Tensor)
        centerness_loss (Tensor)
        loss_mask (Tensor)
    """
    N = box_cls[0].size(0)
    num_classes = box_cls[0].size(1)
    labels, reg_targets, labels_list, bbox_gt_list, gt_inds = self.prepare_targets(
        locations, targets)

    ###### decode box ########
    sampled_boxes = []
    for l, b, s in zip(locations, box_regression, self.fpn_strides):
        sampled_boxes.append(self.decode_for_single_feature_map(l, b, s))

    # Regroup per-level boxes into one (num_points, 4) tensor per image.
    flatten_sampled_boxes = [
        torch.cat([
            boxes_level_img.reshape(-1, 4)
            for boxes_level_img in sampled_boxes_per_img
        ]) for sampled_boxes_per_img in zip(*sampled_boxes)
    ]

    box_cls_flatten = []
    box_regression_flatten = []
    centerness_flatten = []
    labels_flatten = []
    reg_targets_flatten = []
    for l in range(len(labels)):
        box_cls_flatten.append(
            box_cls[l].permute(0, 2, 3, 1).reshape(-1, num_classes))
        box_regression_flatten.append(
            box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4))
        labels_flatten.append(labels[l].reshape(-1))
        reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
        centerness_flatten.append(centerness[l].reshape(-1))

    box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
    box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
    centerness_flatten = torch.cat(centerness_flatten, dim=0)
    labels_flatten = torch.cat(labels_flatten, dim=0)
    reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

    pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)

    box_regression_flatten = box_regression_flatten[pos_inds]
    reg_targets_flatten = reg_targets_flatten[pos_inds]
    centerness_flatten = centerness_flatten[pos_inds]

    num_gpus = get_num_gpus()
    # sync num_pos from all gpus
    total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
    num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

    cls_loss = self.cls_loss_func(
        box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu

    if pos_inds.numel() > 0:
        centerness_targets = self.compute_centerness_targets(reg_targets_flatten)

        # average sum_centerness_targets from all gpus,
        # which is used to normalize the centerness-weighted reg loss
        sum_centerness_targets_avg_per_gpu = \
            reduce_sum(centerness_targets.sum()).item() / float(num_gpus)

        reg_loss = self.box_reg_loss_func(
            box_regression_flatten,
            reg_targets_flatten,
            centerness_targets) / sum_centerness_targets_avg_per_gpu
        centerness_loss = self.centerness_loss_func(
            centerness_flatten,
            centerness_targets) / num_pos_avg_per_gpu
    else:
        reg_loss = box_regression_flatten.sum()
        # keep the collective reduce in step with the other GPUs even when
        # this rank has no positive samples
        reduce_sum(centerness_flatten.new_tensor([0.0]))
        centerness_loss = centerness_flatten.sum()

    ########## mask loss #################
    num_imgs = len(flatten_sampled_boxes)
    flatten_cls_scores1 = []
    for l in range(len(labels)):
        flatten_cls_scores1.append(
            box_cls[l].permute(0, 2, 3, 1).reshape(num_imgs, -1, num_classes))
    flatten_cls_scores1 = torch.cat(flatten_cls_scores1, dim=1)

    # Flatten per-level coefficient maps to (num_imgs, num_points, 128);
    # each location predicts four groups of 32 mask coefficients.
    flatten_cof_preds = [
        cof_pred.permute(0, 2, 3, 1).reshape(len(labels_list), -1, 32 * 4)
        for cof_pred in cof_preds
    ]
    flatten_cof_preds = torch.cat(flatten_cof_preds, dim=1)

    loss_mask = 0
    for i in range(num_imgs):
        labels = torch.cat(
            [labels_level.flatten() for labels_level in labels_list[i]])
        # bbox_gt = torch.cat([gt_level.reshape(-1,4) for gt_level in bbox_gt_list[i]])
        bbox_dt = flatten_sampled_boxes[i] / 2
        bbox_dt = bbox_dt.detach()
        pos_inds = labels > 0
        cof_pred = flatten_cof_preds[i][pos_inds]
        img_mask = feat_mask[i]
        mask_h = feat_mask[i].shape[1]
        mask_w = feat_mask[i].shape[2]
        idx_gt = gt_inds[i]
        bbox_dt = bbox_dt[pos_inds, :4]

        gt_masks = targets[i].get_field("masks").get_mask_tensor().to(
            dtype=torch.float32, device=feat_mask.device)
        gt_masks = gt_masks.reshape(-1, gt_masks.shape[-2], gt_masks.shape[-1])

        # drop degenerate detections before computing the mask loss
        area = (bbox_dt[:, 2] - bbox_dt[:, 0]) * (bbox_dt[:, 3] - bbox_dt[:, 1])
        bbox_dt = bbox_dt[area > 1.0, :]
        idx_gt = idx_gt[area > 1.0]
        cof_pred = cof_pred[area > 1.0]
        if bbox_dt.shape[0] == 0:
            continue

        bbox_gt = targets[i].bbox
        cls_score = flatten_cls_scores1[i, pos_inds,
                                        labels[pos_inds] - 1].sigmoid().detach()
        cls_score = cls_score[area > 1.0]
        # weight each positive by its classification score times IoU with the GT box
        ious = bbox_overlaps(bbox_gt[idx_gt] / 2, bbox_dt, is_aligned=True)
        weighting = cls_score * ious
        weighting = weighting / torch.sum(weighting) * len(weighting)

        keep = _box_nms(bbox_dt, cls_score, 0.9)
        bbox_dt = bbox_dt[keep]
        weighting = weighting[keep]
        idx_gt = idx_gt[keep]
        cof_pred = cof_pred[keep]

        gt_mask = F.interpolate(gt_masks.unsqueeze(0),
                                scale_factor=0.5,
                                mode='bilinear',
                                align_corners=False).squeeze(0)

        # pad/crop the downsampled GT masks to the mask feature resolution
        shape = np.minimum(feat_mask[i].shape, gt_mask.shape)
        gt_mask_new = gt_mask.new_zeros(gt_mask.shape[0], mask_h, mask_w)
        gt_mask_new[:gt_mask.shape[0], :shape[1], :shape[2]] = \
            gt_mask[:gt_mask.shape[0], :shape[1], :shape[2]]
        gt_mask_new = gt_mask_new.gt(0.5).float()
        gt_mask_new = torch.index_select(gt_mask_new, 0,
                                         idx_gt).permute(1, 2, 0).contiguous()

        ####### spp ###########################
        img_mask1 = img_mask.permute(1, 2, 0)
        pos_masks00 = torch.sigmoid(img_mask1 @ cof_pred[:, 0:32].t())
        pos_masks01 = torch.sigmoid(img_mask1 @ cof_pred[:, 32:64].t())
        pos_masks10 = torch.sigmoid(img_mask1 @ cof_pred[:, 64:96].t())
        pos_masks11 = torch.sigmoid(img_mask1 @ cof_pred[:, 96:128].t())
        pred_masks = torch.stack(
            [pos_masks00, pos_masks01, pos_masks10, pos_masks11], dim=0)
        pred_masks = self.crop_cuda(pred_masks, bbox_dt)
        gt_mask_crop = self.crop_gt_cuda(gt_mask_new, bbox_dt)
        pre_loss = F.binary_cross_entropy(pred_masks,
                                          gt_mask_crop,
                                          reduction='none')

        # normalize the per-instance loss by the detected box size
        pos_get_csize = center_size(bbox_dt)
        gt_box_width = pos_get_csize[:, 2]
        gt_box_height = pos_get_csize[:, 3]
        pre_loss = pre_loss.sum(
            dim=(0, 1)) / gt_box_width / gt_box_height / pos_get_csize.shape[0]
        loss_mask += torch.sum(pre_loss * weighting.detach())

    loss_mask = loss_mask / num_imgs
    if loss_mask > 1.0:
        loss_mask = loss_mask * 0.5
    return cls_loss, reg_loss, centerness_loss, loss_mask
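
# Hedged shape walkthrough (illustrative only; _demo_mask_assembly is a
# hypothetical name, not part of this repository): reproduces the coefficient /
# prototype product from the "spp" block above to make the tensor shapes
# explicit. Each positive location predicts 4 x 32 coefficients, and each
# 32-dim group is multiplied against the (H, W, 32) mask feature map to give
# one mask hypothesis per instance.
def _demo_mask_assembly(img_mask, cof_pred):
    # img_mask: (32, H, W) mask features for a single image.
    # cof_pred: (num_pos, 128) coefficients, four groups of 32 per instance.
    img_mask1 = img_mask.permute(1, 2, 0)                      # (H, W, 32)
    groups = [cof_pred[:, k * 32:(k + 1) * 32] for k in range(4)]
    # Each matmul yields an (H, W, num_pos) stack of per-instance masks.
    masks = [torch.sigmoid(img_mask1 @ g.t()) for g in groups]
    return torch.stack(masks, dim=0)                           # (4, H, W, num_pos)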