def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN (second-stage box head).

    Arguments:
        class_logits (Tensor): predicted class logits (unnormalized scores),
            shape [N, num_classes], where N is the total number of proposals
            across the batch.
        box_regression (Tensor): predicted per-class box regression; reshaped
            below to [N, box_regression.size(-1) // 4, 4]. Assumes a 2-D
            [N, num_classes * 4] layout — TODO confirm against the predictor.
        labels (List[Tensor]): ground-truth class label of each proposal, one
            tensor per image; concatenated along dim 0.
        regression_targets (List[Tensor]): ground-truth box regression targets
            of each proposal, one tensor per image; concatenated along dim 0.

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    # Classification loss over all proposals.
    classification_loss = F.cross_entropy(class_logits, labels)

    # Indices of proposals whose label is > 0 (presumably foreground, with 0
    # as background). Only these contribute to the box regression loss.
    sampled_pos_inds_subset = torch.where(torch.gt(labels, 0))[0]
    # Ground-truth class of each positive proposal. It selects which class's
    # predicted box is compared against the regression target.
    labels_pos = labels[sampled_pos_inds_subset]

    N = class_logits.shape[0]
    # Spell out the middle dimension instead of using -1: reshaping a
    # zero-element tensor (N == 0) to (0, -1, 4) raises a RuntimeError because
    # the inferred size would be ambiguous.
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    # Smooth-L1 between each positive proposal's box for its ground-truth
    # class and its target. size_average=False presumably means "sum" (see
    # det_utils). The result is normalized by the total number of proposals
    # (labels.numel()), not only the positives.
    box_loss = det_utils.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        size_average=False,
    ) / labels.numel()

    return classification_loss, box_loss
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Compute the RPN losses: an objectness (foreground vs. background) loss and
    a bounding-box regression loss.

    Arguments:
        objectness (Tensor): predicted objectness logits; flattened to 1-D here.
        pred_bbox_deltas (Tensor): predicted bbox regression deltas, indexed
            along dim 0 by the same flat anchor index as the labels.
        labels (List[Tensor]): ground-truth anchor labels (values 1, 0, -1),
            one list element per image in the batch.
        regression_targets (List[Tensor]): ground-truth bbox regression
            targets, one list element per image in the batch.

    Returns:
        objectness_loss (Tensor): objectness classification loss
        box_loss (Tensor): bounding-box regression loss
    """
    # Pick positive / negative samples per image using the sampler's
    # batch_size_per_image / positive_fraction configuration.
    pos_masks, neg_masks = self.fg_bg_sampler(labels)

    # Concatenate the per-image masks across the batch, then turn the
    # nonzero positions into flat indices.
    pos_idx = torch.where(torch.cat(pos_masks, dim=0))[0]
    neg_idx = torch.where(torch.cat(neg_masks, dim=0))[0]
    # Every sampled anchor: positives first, then negatives.
    all_idx = torch.cat([pos_idx, neg_idx], dim=0)

    flat_objectness = objectness.flatten()
    flat_labels = torch.cat(labels, dim=0)
    flat_targets = torch.cat(regression_targets, dim=0)

    # The regression loss uses positives only but is normalized by the total
    # number of sampled anchors (positives + negatives).
    num_sampled = all_idx.numel()
    box_loss = det_utils.smooth_l1_loss(
        pred_bbox_deltas[pos_idx],
        flat_targets[pos_idx],
        beta=1 / 9,
        size_average=False,
    ) / num_sampled

    # Binary objectness loss over every sampled anchor.
    objectness_loss = F.binary_cross_entropy_with_logits(
        flat_objectness[all_idx], flat_labels[all_idx]
    )

    return objectness_loss, box_loss