import torch
import torch.nn as nn

# Scale, BoxCoder, RetinaClsHead, RetinaRegHead and the default_* anchor
# settings are assumed to be defined elsewhere in the repository.


class RetinaHead(nn.Module):
    def __init__(self,
                 in_channel,
                 inner_channel,
                 num_cls=80,
                 num_convs=4,
                 layer_num=5,
                 anchor_sizes=None,
                 anchor_scales=None,
                 anchor_ratios=None,
                 strides=None):
        super(RetinaHead, self).__init__()
        self.num_cls = num_cls
        self.layer_num = layer_num
        if anchor_sizes is None:
            anchor_sizes = default_anchor_sizes
        self.anchor_sizes = anchor_sizes
        if anchor_scales is None:
            anchor_scales = default_anchor_scales
        self.anchor_scales = anchor_scales
        if anchor_ratios is None:
            anchor_ratios = default_anchor_ratios
        self.anchor_ratios = anchor_ratios
        if strides is None:
            strides = default_strides
        self.strides = strides
        self.anchor_nums = len(self.anchor_scales) * len(self.anchor_ratios)
        self.scales = nn.ModuleList(
            [Scale(init_val=1.0) for _ in range(self.layer_num)])
        self.anchors = [torch.zeros(size=(0, 4))] * self.layer_num
        self.box_coder = BoxCoder()
        self.cls_head = RetinaClsHead(in_channel, inner_channel,
                                      self.anchor_nums, num_cls, num_convs)
        self.reg_head = RetinaRegHead(in_channel, inner_channel,
                                      self.anchor_nums, num_convs)

    def build_anchors_delta(self, size=32.):
        """
        :param size: base anchor size for this pyramid level
        :return: [anchor_num, 4] offsets (x1, y1, x2, y2) relative to a cell center
        """
        scales = torch.tensor(self.anchor_scales).float()
        ratio = torch.tensor(self.anchor_ratios).float()
        scale_size = scales * size
        w = (scale_size[:, None] * ratio[None, :].sqrt()).view(-1) / 2
        h = (scale_size[:, None] / ratio[None, :].sqrt()).view(-1) / 2
        delta = torch.stack([-w, -h, w, h], dim=1)
        return delta

    def build_anchors(self, feature_maps):
        """
        :param feature_maps: list of FPN feature maps, one per pyramid level
        :return: list(anchor) anchor: [all, 4] (x1, y1, x2, y2)
        """
        assert self.layer_num == len(feature_maps)
        assert len(self.anchor_sizes) == len(feature_maps)
        assert len(self.anchor_sizes) == len(self.strides)
        anchors = list()
        for stride, size, feature_map in zip(self.strides, self.anchor_sizes,
                                             feature_maps):
            # anchor_delta: [9, 4]
            anchor_delta = self.build_anchors_delta(size)
            _, _, ny, nx = feature_map.shape
            yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
            # grid: [h, w, 4]
            grid = torch.stack([xv, yv, xv, yv], 2).float()
            anchor = (grid[:, :, None, :] + 0.5) * stride \
                     + anchor_delta[None, None, :, :]
            anchor = anchor.view(-1, 4)
            anchors.append(anchor)
        return anchors

    def forward(self, xs):
        cls_outputs = list()
        reg_outputs = list()
        for j, x in enumerate(xs):
            cls_outputs.append(self.cls_head(x))
            reg_outputs.append(self.scales[j](self.reg_head(x)))
        # (re)build the anchors whenever the cached ones do not match the
        # current input resolution
        if self.anchors[0] is None or \
                self.anchors[0].shape[0] != cls_outputs[0].shape[1]:
            with torch.no_grad():
                anchors = self.build_anchors(xs)
                assert len(anchors) == len(self.anchors)
                for i, anchor in enumerate(anchors):
                    self.anchors[i] = anchor.to(xs[0].device)
        if self.training:
            return cls_outputs, reg_outputs, self.anchors
        else:
            predicts_list = list()
            for cls_out, reg_out, anchor in zip(cls_outputs, reg_outputs,
                                                self.anchors):
                scale_reg = self.box_coder.decoder(reg_out, anchor)
                predicts_out = torch.cat([scale_reg, cls_out], dim=-1)
                predicts_list.append(predicts_out)
            return predicts_list
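
# A minimal, self-contained sketch of the geometry build_anchors_delta
# produces. The scale/ratio values below are the common RetinaNet defaults
# and are an assumption here; the repository's default_anchor_scales /
# default_anchor_ratios may differ.
def demo_anchor_deltas(size=32.):
    scales = torch.tensor([2 ** 0, 2 ** (1. / 3), 2 ** (2. / 3)])  # assumed defaults
    ratios = torch.tensor([0.5, 1.0, 2.0])  # assumed defaults
    scale_size = scales * size
    w = (scale_size[:, None] * ratios[None, :].sqrt()).view(-1) / 2
    h = (scale_size[:, None] / ratios[None, :].sqrt()).view(-1) / 2
    # 9 rows of (x1, y1, x2, y2) offsets around a cell center; each box has
    # area (scale * size) ** 2 and aspect ratio (width / height) == ratio
    return torch.stack([-w, -h, w, h], dim=1)

# demo_anchor_deltas(32.).shape == torch.Size([9, 4]); adding these offsets
# to every shifted cell center ((grid + 0.5) * stride) yields the per-level
# anchors built above.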
class RetinaLoss(object):
    def __init__(self,
                 iou_thresh=0.5,
                 ignore_thresh=0.4,
                 alpha=0.25,
                 gamma=2.0,
                 iou_type="giou",
                 coord_type="xyxy"):
        self.iou_thresh = iou_thresh
        self.ignore_thresh = ignore_thresh
        self.alpha = alpha
        self.gamma = gamma
        self.builder = RetinaLossBuilder(iou_thresh, ignore_thresh)
        self.iou_loss = IOULoss(iou_type, coord_type)
        self.box_coder = BoxCoder()

    def __call__(self, cls_predicts, reg_predicts, anchors, targets):
        """
        :param cls_predicts: list(cls_predict) cls_predict [bs, all, num_cls]
        :param reg_predicts: list(reg_predict) reg_predict [bs, all, 4]
        :param anchors: list(anchor) anchor [all, 4]
        :param targets: [gt_num, 7] (batch_id, weight, label_id, x1, y1, x2, y2)
        :return:
        """
        for i in range(len(cls_predicts)):
            if cls_predicts[i].dtype == torch.float16:
                cls_predicts[i] = cls_predicts[i].float()
        device = cls_predicts[0].device
        bs = cls_predicts[0].shape[0]
        flags, gt_targets, all_anchors = self.builder(bs, anchors, targets)
        cls_loss_list = list()
        reg_loss_list = list()
        pos_num_sum = 0
        for bi in range(bs):
            batch_cls_predict = torch.cat(
                [cls_item[bi] for cls_item in cls_predicts],
                dim=0).sigmoid().clamp(1e-6, 1 - 1e-6)
            batch_reg_predict = torch.cat(
                [reg_item[bi] for reg_item in reg_predicts], dim=0)
            flag = flags[bi]
            gt = gt_targets[bi]
            pos_idx = (flag == 1).nonzero(as_tuple=False).squeeze(1)
            pos_num = len(pos_idx)
            if pos_num == 0:
                # no positive anchors: every anchor contributes only the
                # negative focal-loss term
                neg_cls_loss = -(1 - self.alpha) * batch_cls_predict ** self.gamma \
                               * (1 - batch_cls_predict).log()
                cls_loss_list.append(neg_cls_loss.sum())
                continue
            pos_num_sum += pos_num
            neg_idx = (flag == 0).nonzero(as_tuple=False).squeeze(1)
            valid_idx = torch.cat([pos_idx, neg_idx])
            valid_cls_predicts = batch_cls_predict[valid_idx, :]
            cls_targets = torch.zeros(size=valid_cls_predicts.shape,
                                      device=device)
            cls_targets[range(pos_num), gt[pos_idx, 1].long()] = 1.
            pos_loss = -self.alpha * cls_targets \
                       * (1 - valid_cls_predicts) ** self.gamma \
                       * valid_cls_predicts.log()
            neg_loss = -(1 - self.alpha) * (1. - cls_targets) \
                       * valid_cls_predicts ** self.gamma \
                       * (1 - valid_cls_predicts).log()
            cls_loss = (pos_loss + neg_loss).sum()
            cls_loss_list.append(cls_loss)
            valid_reg_predicts = batch_reg_predict[pos_idx, :]
            predict_box = self.box_coder.decoder(valid_reg_predicts,
                                                 all_anchors[pos_idx])
            gt_bbox = gt[pos_idx, 2:]
            reg_loss = self.iou_loss(predict_box, gt_bbox)
            reg_loss_list.append(reg_loss.sum())
        cls_loss_sum = torch.stack(cls_loss_list).sum()
        if pos_num_sum == 0:
            total_loss = cls_loss_sum / bs
            return total_loss, torch.stack(
                [cls_loss_sum,
                 torch.tensor(data=0., device=device)]).detach(), pos_num_sum
        reg_loss_sum = torch.stack(reg_loss_list).sum()
        cls_loss_mean = cls_loss_sum / pos_num_sum
        reg_loss_mean = reg_loss_sum / pos_num_sum
        total_loss = cls_loss_mean + reg_loss_mean
        return total_loss, torch.stack(
            [cls_loss_mean, reg_loss_mean]).detach(), pos_num_sum
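
# The hand-written pos_loss / neg_loss terms above are the standard sigmoid
# focal loss. A compact reference sketch, written here only for
# sanity-checking the inline formulas; `focal_reference` is an illustrative
# name, not part of the repository:
def focal_reference(probs, one_hot_targets, alpha=0.25, gamma=2.0):
    """probs: [N, num_cls] already-sigmoided scores in (0, 1);
    one_hot_targets: [N, num_cls] 0/1 matrix with a 1 at each positive's class."""
    pos = -alpha * one_hot_targets * (1 - probs) ** gamma * probs.log()
    neg = -(1 - alpha) * (1 - one_hot_targets) * probs ** gamma * (1 - probs).log()
    return (pos + neg).sum()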
class RetinaAnchorFreeLoss(object):
    def __init__(self,
                 gamma=2.0,
                 alpha=0.25,
                 top_k=64,
                 box_iou_thresh=0.6,
                 box_reg_weight=0.75,
                 beta=1. / 9):
        super(RetinaAnchorFreeLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.top_k = top_k
        self.box_iou_thresh = box_iou_thresh
        self.box_reg_weight = box_reg_weight
        self.beta = beta
        self.box_coder = BoxCoder()

    def __call__(self, cls_predicts, box_predicts, anchors, targets):
        """
        :param cls_predicts: list(cls_predict) cls_predict [bs, all, num_cls]
        :param box_predicts: list(box_predict) box_predict [bs, all, 4]
        :param anchors: list(anchor) anchor [all, 4]
        :param targets: [gt_num, 7] (batch_id, weight, label_id, x1, y1, x2, y2)
        :return:
        """
        device = cls_predicts[0].device
        bs = cls_predicts[0].shape[0]
        cls_num = cls_predicts[0].shape[-1]
        expand_anchor = torch.cat(anchors, dim=0)
        negative_loss_list = list()
        positive_loss_list = list()
        for bi in range(bs):
            batch_cls_predicts = torch.cat(
                [cls_item[bi] for cls_item in cls_predicts],
                dim=0).sigmoid().clamp(min=1e-6, max=1 - 1e-6)
            batch_targets = targets[targets[:, 0] == bi, 1:]
            if len(batch_targets) == 0:
                negative_loss = -(1 - self.alpha) * batch_cls_predicts ** self.gamma \
                                * (1 - batch_cls_predicts).log()
                negative_loss_list.append(negative_loss.sum())
                continue
            batch_box_predicts = torch.cat(
                [box_item[bi] for box_item in box_predicts], dim=0)
            # positive loss: build a bag of the top-k anchors (by IoU) for
            # each gt box and score them jointly on cls and loc quality
            targets_anchor_iou = box_iou(batch_targets[:, 2:], expand_anchor)
            _, top_k_anchor_idx = targets_anchor_iou.topk(k=self.top_k,
                                                          dim=1,
                                                          sorted=False)
            matched_cls_prob = batch_cls_predicts[top_k_anchor_idx].gather(
                dim=-1,
                index=(batch_targets[:, [1]][:, None, :]).long().repeat(
                    1, self.top_k, 1)).squeeze(-1)
            match_box_target = self.box_coder.encoder(
                expand_anchor[top_k_anchor_idx], batch_targets[:, None, 2:])
            matched_box_prob = (
                -self.box_reg_weight *
                smooth_l1_loss(batch_box_predicts[top_k_anchor_idx],
                               match_box_target, self.beta).sum(-1)).exp()
            positive_loss = self.alpha * mean_max(
                matched_cls_prob * matched_box_prob).sum()
            positive_loss_list.append(positive_loss)

            # negative loss: estimate P(anchor in A+) from the IoU between
            # each decoded box prediction and the gt boxes (no gradient)
            with torch.no_grad():
                box_localization = self.box_coder.decoder(
                    batch_box_predicts, expand_anchor)
                target_box_iou = box_iou(batch_targets[:, 2:],
                                         box_localization)
                t1 = self.box_iou_thresh
                t2 = target_box_iou.max(dim=1,
                                        keepdim=True)[0].clamp(min=t1 + 1e-6)
                target_box_prob = ((target_box_iou - t1) /
                                   (t2 - t1)).clamp(min=0., max=1.)
            # indices: [2, num_gts]; row 0 is the gt-box index, row 1 the
            # class id of the corresponding gt box
            indices = torch.stack([
                torch.arange(len(batch_targets), device=device),
                batch_targets[:, 1]
            ], dim=0).long()
            object_cls_box_prob = torch.sparse_coo_tensor(indices,
                                                          target_box_prob,
                                                          device=device)
            cls_idx, anchor_idx = torch.sparse.sum(
                object_cls_box_prob,
                dim=0).to_dense().nonzero(as_tuple=False).t()
            if len(cls_idx) == 0:
                negative_loss = -(1 - self.alpha) * batch_cls_predicts ** self.gamma \
                                * (1 - batch_cls_predicts).log()
                negative_loss_list.append(negative_loss.sum())
                continue
            anchor_positive_max_prob = torch.where(
                batch_targets[:, [1]].long() == cls_idx,
                target_box_prob[:, anchor_idx],
                torch.tensor(data=0., device=device)).max(dim=0)[0]
            anchor_cls_assign_prob = torch.zeros(size=(len(expand_anchor),
                                                       cls_num),
                                                 device=device)
            anchor_cls_assign_prob[anchor_idx, cls_idx] = anchor_positive_max_prob
            negative_prob = batch_cls_predicts * (1 - anchor_cls_assign_prob)
            negative_loss = -(1 - self.alpha) * negative_prob ** self.gamma \
                            * (1 - negative_prob).log()
            negative_loss_list.append(negative_loss.sum())
        negative_losses = torch.stack(negative_loss_list).sum() / max(
            1, len(targets))
        if len(positive_loss_list) == 0:
            total_loss = negative_losses
            return total_loss, torch.stack(
                [negative_losses,
                 torch.tensor(data=0., device=device)]), len(targets)
        positive_losses = torch.stack(positive_loss_list).sum() / max(
            1, len(targets))
        total_loss = negative_losses + positive_losses
        return total_loss, torch.stack(
            [negative_losses, positive_losses]), len(targets)
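
# `mean_max` is referenced above but not defined in this snippet. A plausible
# sketch following the Mean-max function from the FreeAnchor paper; the
# repository's actual helper (including whether it returns the bag probability
# or its negative log) may differ:
def mean_max(x, eps=1e-12):
    """x: [num_gts, top_k] joint cls*loc probabilities per positive bag.
    Returns the per-gt negative log of the Mean-max bag probability."""
    # weights grow as x approaches 1, so the weighted mean approaches a max
    # once some anchor in the bag becomes confident
    weight = 1. / (1. - x).clamp(min=eps)
    weight = weight / weight.sum(dim=-1, keepdim=True)
    bag_prob = (weight * x).sum(dim=-1)
    return -bag_prob.clamp(min=eps).log()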
# Alternative FreeAnchor-style implementation (top_k=50, distributed-aware);
# it replaces the class above when used.
class RetinaAnchorFreeLoss(object):
    def __init__(self,
                 gamma=2.0,
                 alpha=0.25,
                 top_k=50,
                 box_iou_thresh=0.6,
                 box_reg_weight=0.75,
                 beta=1. / 9):
        super(RetinaAnchorFreeLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.top_k = top_k
        self.box_iou_thresh = box_iou_thresh
        self.box_reg_weight = box_reg_weight
        self.beta = beta
        self.positive_bag_loss_func = positive_bag_loss
        self.negative_bag_loss_func = focal_loss
        self.box_coder = BoxCoder()

    def __call__(self, cls_predicts, box_predicts, anchors, targets):
        """
        :param cls_predicts: list(cls_predict) cls_predict [bs, all, num_cls]
        :param box_predicts: list(box_predict) box_predict [bs, all, 4]
        :param anchors: list(anchor) anchor [all, 4]
        :param targets: [gt_num, 7] (batch_id, conf_score, label_id, x1, y1, x2, y2)
        :return:
        """
        device = cls_predicts[0].device
        bs = cls_predicts[0].shape[0]
        cls_num = cls_predicts[0].shape[-1]
        expand_anchor = torch.cat(anchors, dim=0)  # shape=[num_anchors, 4]

        positive_numels = 0  # number of gt boxes
        box_prob = list()  # stores P(A+); P(A-) = 1 - P(A+)
        positive_loss_list = list()
        negative_loss_list = list()
        cls_probs = list()

        for bi in range(bs):
            # cls_prob: shape=[num_anchors, num_cls]
            cls_prob = torch.cat(
                [cls_item[bi] for cls_item in cls_predicts],
                dim=0).sigmoid().clamp(min=1e-6, max=1 - 1e-6)
            # target: shape=[num_gts, 6], 6 ==> conf_score, label_id, x1, y1, x2, y2
            target = targets[targets[:, 0] == bi, 1:]
            # if no gt box exists, only the negative focal-loss term applies
            if len(target) == 0:
                negative_loss = -(cls_prob ** self.gamma) * (
                    (1 - cls_prob).clamp(min=1e-10, max=1.0 - 1e-10)
                    .log().clamp(min=-1000., max=1000.))
                negative_loss_list.append(negative_loss.sum())
                continue
            cls_probs.append(cls_prob)
            # box_regression: shape=[num_anchors, 4]
            box_regression = torch.cat(
                [box_item[bi] for box_item in box_predicts], dim=0)

            with torch.set_grad_enabled(False):
                # box_localization: a_{j}^{loc}, shape: [j, 4] (x1, y1, x2, y2)
                box_localization = self.box_coder.decoder(
                    box_regression, expand_anchor)
                # object_box_iou: IoU_{ij}^{loc}, shape: [num_gts, num_anchors]
                object_box_iou = box_iou(target[:, 2:], box_localization)
                t1 = self.box_iou_thresh
                # t2: shape=[num_gts, 1]
                t2 = object_box_iou.max(
                    dim=1, keepdim=True)[0].clamp(min=t1 + 1e-12)
                # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                object_box_prob = ((object_box_iou - t1) /
                                   (t2 - t1)).clamp(min=0, max=1.)

                # indices: shape=[2, num_gts]; row 0 holds the gt-box indices,
                # row 1 holds the class ids of the corresponding gt boxes
                indices = torch.stack(
                    [torch.arange(len(target), device=device), target[:, 1]],
                    dim=0).long()
                # object_cls_box_prob: P{a_{j} -> b_{i}},
                # shape: [num_gts, max_cls_id + 1, num_anchors], filled per
                # class: if the gt box with index gt_id has class label_id,
                # then object_cls_box_prob[gt_id, label_id] =
                # object_box_prob[gt_id]; every other entry is 0
                object_cls_box_prob = torch.sparse_coo_tensor(
                    indices, object_box_prob, device=device)

                # image_box_prob: P{a_{j} in A_{+}},
                # shape: [j, c] or [num_anchors, num_cls];
                # the confidence that an anchor matches some object,
                # regardless of class or which gt box it matches.
                # from "start" to "end" implements:
                # image_box_prob = torch.sparse.max(object_cls_box_prob, dim=0).t()
                # start
                indices = torch.sparse.sum(
                    object_cls_box_prob,
                    dim=0).to_dense().nonzero(as_tuple=False).t()  # shape=[2, N]
                if indices.numel() == 0:
                    image_box_prob = torch.zeros(
                        expand_anchor.shape[0],
                        cls_num).type_as(object_box_prob)
                else:
                    nonzero_box_prob = torch.where(
                        # (num_gts, 1) == (N,) ===> (num_gts, N)
                        target[:, 1].unsqueeze(dim=-1) == indices[0],
                        object_box_prob[:, indices[1]],
                        torch.tensor([0]).type_as(object_box_prob)
                    ).max(dim=0)[0]  # ===> (N,)
                    image_box_prob = torch.sparse_coo_tensor(
                        indices.flip([0]),
                        nonzero_box_prob,
                        size=(expand_anchor.shape[0], cls_num),
                        device=device).to_dense()
                # end
                box_prob.append(image_box_prob)

            # construct bags for objects: the top_k anchors (by IoU) per gt box
            match_quality_matrix = box_iou(target[:, 2:], expand_anchor)
            # matched: shape=[num_gts, top_k], anchor indices in [0, num_anchors)
            _, matched = torch.topk(match_quality_matrix,
                                    self.top_k,
                                    dim=1,
                                    sorted=False)
            del match_quality_matrix

            # matched_cls_prob: P_{ij}^{cls}, shape=[num_gts, top_k]; the
            # predicted probability of each matched anchor for its gt's class
            matched_cls_prob = cls_prob[matched].gather(
                dim=-1,
                index=(target[:, [1]][:, None, :]).long().repeat(
                    1, self.top_k, 1)).squeeze(-1)

            # matched_box_prob: P_{ij}^{loc}
            # matched_object_targets: shape=[num_gts, top_k, 4]
            matched_object_targets = self.box_coder.encoder(
                expand_anchor[matched], target[:, 2:].unsqueeze(dim=1))
            retinanet_regression_loss = smooth_l1_loss(
                box_regression[matched], matched_object_targets,
                self.box_reg_weight, self.beta)
            matched_box_prob = torch.exp(-retinanet_regression_loss)

            # positive_losses: -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )
            positive_numels += len(target)
            positive_loss_list.append(
                self.positive_bag_loss_func(
                    matched_cls_prob * matched_box_prob, dim=1))

        # positive_loss:
        # sum_i{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||,
        # reduced across distributed workers
        # (assumes at least one image in the batch has gt boxes)
        item1 = torch.cat(positive_loss_list).sum()
        item2 = max(1, positive_numels)
        positive_loss = reduce_sum(item1) / reduce_sum(
            torch.tensor(data=item2, device=device).float()).item()

        # box_prob: P{a_{j} in A_{+}}
        box_prob = torch.stack(box_prob, dim=0)
        cls_probs = torch.stack(cls_probs, dim=0)

        # negative_loss:
        # sum_j{ FL( (1 - P{a_{j} in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||,
        # where (1 - P_bg) <==> P_cls with shape=[num_anchors, num_cls]
        # and P(A-) <==> (1 - P{box_cls})
        if len(negative_loss_list) != 0:
            neg_loss_empty = torch.stack(negative_loss_list, dim=0).sum()
        else:
            neg_loss_empty = 0
        item3 = neg_loss_empty + self.negative_bag_loss_func(
            cls_probs * (1 - box_prob), self.gamma)
        item4 = max(1, positive_numels * self.top_k)
        negative_loss = reduce_sum(item3) / reduce_sum(
            torch.tensor(data=item4, device=device).float()).item()

        total_loss = positive_loss * self.alpha \
                     + negative_loss * (1 - self.alpha)
        return total_loss, torch.stack(
            [negative_loss, positive_loss]), positive_numels
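
# `positive_bag_loss`, `focal_loss`, and `reduce_sum` are assumed to live
# elsewhere in the repository. Plausible sketches following the FreeAnchor
# reference implementation and a standard distributed all-reduce; the actual
# definitions may differ in detail.
import torch.distributed as dist
import torch.nn.functional as F


def positive_bag_loss(matched_prob, dim):
    # Mean-max over each positive bag, then BCE against a target of 1,
    # i.e. -log(bag_prob) per gt box
    weight = 1. / (1. - matched_prob).clamp(min=1e-12)
    weight = weight / weight.sum(dim=dim, keepdim=True)
    bag_prob = (weight * matched_prob).sum(dim=dim)
    return F.binary_cross_entropy(bag_prob,
                                  torch.ones_like(bag_prob),
                                  reduction='none')


def focal_loss(prob, gamma):
    # negative-bag focal term applied to P(A-) * P_cls, summed to a scalar
    return (prob ** gamma * F.binary_cross_entropy(
        prob, torch.zeros_like(prob), reduction='none')).sum()


def reduce_sum(tensor):
    # sum a tensor across all distributed workers; identity on a single GPU
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor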