Example #1
import numpy as np
import torch
import torch.nn.functional as F
# mutils refers to the project's model-utility module, which provides shem()
# and the box utilities used below; the exact import path depends on the repo layout.
import utils.model_utils as mutils


def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20):
    """
    :param anchor_matches: (n_anchors). [-1, 0, class_id] for negative, neutral, and positive matched anchors.
    :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network.
    :param shem_poolsize: int. multiplier for the top-k candidate pool from which negatives are drawn (stochastic-hard-example-mining).
    :return: loss: torch tensor.
    :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
    """
    # Positive and Negative anchors contribute to the loss,
    # but neutral anchors (match value = 0) don't.
    pos_indices = torch.nonzero(anchor_matches > 0)
    neg_indices = torch.nonzero(anchor_matches == -1)

    # get positive samples and calculate loss.
    if 0 not in pos_indices.size():
        pos_indices = pos_indices.squeeze(1)
        roi_logits_pos = class_pred_logits[pos_indices]
        targets_pos = anchor_matches[pos_indices]
        # print("roi_logits_pos.shape, targets_pos.shape : ", roi_logits_pos.shape, targets_pos.shape)
        # print("roi_logits_pos : ", roi_logits_pos)
        # print("targets_pos : ", targets_pos)
        # pos_loss = FocalLoss()(roi_logits_pos, targets_pos.long()).cuda()
        pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long())
    else:
        pos_loss = torch.FloatTensor([0]).cuda()

    # get negative samples, matching the number of positive samples but at least 1.
    # draw high-scoring negatives via stochastic-hard-example-mining.
    if 0 not in neg_indices.size():
        neg_indices = neg_indices.squeeze(1)
        roi_logits_neg = class_pred_logits[neg_indices]
        negative_count = np.max((1, pos_indices.size()[0]))
        roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
        neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
        neg_loss = F.cross_entropy(
            roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda())
        # neg_loss = FocalLoss()(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()).cuda()
        # return the indices of negative samples, which contributed to the loss (for monitoring plots).
        np_neg_ix = neg_ix.cpu().data.numpy()
    else:
        neg_loss = torch.FloatTensor([0]).cuda()
        np_neg_ix = np.array([]).astype('int32')

    loss = (pos_loss + neg_loss) / 2
    return loss, np_neg_ix
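
For context, here is a minimal usage sketch with synthetic inputs; the shapes and values are made up for illustration, and a CUDA device is assumed.

n_anchors, n_classes = 15, 3
# -1 = negative, 0 = neutral, >0 = matched class id (illustrative values).
anchor_matches = torch.from_numpy(np.array(
    [1, 2, 0, -1, -1, -1, 1, -1, 0, -1, -1, -1, 0, -1, -1])).float().cuda()
class_pred_logits = torch.randn(n_anchors, n_classes).cuda()
loss, np_neg_ix = compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=2)
# loss: scalar tensor; np_neg_ix: indices of the sampled negative anchors.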
Example #2
def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize):
    """
    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
    :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier.
    :param shem_poolsize: int. multiplier for the top-k candidate pool from which negatives are drawn (stochastic-hard-example-mining).
    :return: loss: torch tensor
    :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
    """

    # Filter out neutral anchors
    pos_indices = torch.nonzero(rpn_match == 1)
    neg_indices = torch.nonzero(rpn_match == -1)

    # loss for positive samples
    if 0 not in pos_indices.size():
        pos_indices = pos_indices.squeeze(1)
        roi_logits_pos = rpn_class_logits[pos_indices]
        pos_loss = F.cross_entropy(
            roi_logits_pos,
            torch.LongTensor([1] * pos_indices.shape[0]).cuda())
    else:
        pos_loss = torch.FloatTensor([0]).cuda()

    # loss for negative samples: draw hard negative examples (SHEM)
    # that match the number of positive samples, but at least 1.
    if 0 not in neg_indices.size():
        neg_indices = neg_indices.squeeze(1)
        roi_logits_neg = rpn_class_logits[neg_indices]
        negative_count = np.max((1, pos_indices.cpu().data.numpy().size))
        roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
        neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
        neg_loss = F.cross_entropy(
            roi_logits_neg[neg_ix],
            torch.LongTensor([0] * neg_ix.shape[0]).cuda())
        np_neg_ix = neg_ix.cpu().data.numpy()
        #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count)
    else:
        neg_loss = torch.FloatTensor([0]).cuda()
        np_neg_ix = np.array([]).astype('int32')

    loss = (pos_loss + neg_loss) / 2
    return loss, np_neg_ix
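
Both loss functions above (and the target layer in the next example) delegate negative sampling to mutils.shem, whose implementation is not part of these examples. The sketch below only illustrates the idea under the stated assumptions and is not the project's code: rank negatives by their highest foreground probability, keep a candidate pool of the negative_count * poolsize hardest ones, then draw negative_count of them uniformly at random.

def shem_sketch(roi_probs_neg, negative_count, poolsize):
    # roi_probs_neg: (n_neg, n_classes) softmax probabilities; column 0 is
    # assumed to be the background class. Hardness = max foreground probability.
    fg_scores = roi_probs_neg[:, 1:].max(dim=1)[0]
    pool_size = min(negative_count * poolsize, fg_scores.shape[0])
    # deterministic part: pool of the hardest (highest-scoring) negatives.
    _, hard_ix = torch.topk(fg_scores, pool_size)
    # stochastic part: sample uniformly from that pool.
    rand = torch.randperm(pool_size, device=hard_ix.device)[:negative_count]
    return hard_ix[rand]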
Example #3
def detection_target_layer(batch_proposals, batch_mrcnn_class_scores, batch_gt_class_ids, batch_gt_boxes, cf):
    """
    Subsamples proposals for the mrcnn losses and generates their targets. Sampling is done per batch element rather
    than over the entire batch, which appears to benefit training. Negatives are sampled via
    stochastic-hard-example-mining (SHEM): the required number of negative proposals is drawn at random from a larger
    pool of the highest-scoring proposals, which adds stochasticity. Each proposal is scored as the max over its
    foreground probabilities as returned by the mrcnn classifier (this worked better than loss-based class-balancing
    methods such as online-hard-example-mining or focal loss).

    :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs)).
    boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS.
    :param batch_mrcnn_class_scores: (n_proposals, n_classes)
    :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
    :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
    :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions.
    :return: target_class_ids: (n_sampled_rois) containing target class labels of sampled proposals.
    :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement.
    """
    # normalization of target coordinates
    if cf.dim == 2:
        h, w = cf.patch_size
        scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda()
    else:
        h, w, z = cf.patch_size
        scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda()

    positive_count = 0
    negative_count = 0
    sample_positive_indices = []
    sample_negative_indices = []
    sample_deltas = []
    sample_class_ids = []

    # loop over batch and get positive and negative sample rois.
    for b in range(len(batch_gt_class_ids)):

        gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda()
        if np.any(batch_gt_class_ids[b] > 0):  # skip roi selection for no gt images.
            gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale
        else:
            gt_boxes = torch.FloatTensor().cuda()

        # get proposals and indices of current batch element.
        proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1]
        batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1)

        # Compute overlaps matrix [proposals, gt_boxes]
        if 0 not in gt_boxes.size():
            if gt_boxes.shape[1] == 4:
                overlaps = mutils.bbox_overlaps_2D(proposals, gt_boxes)
            else:
                overlaps = mutils.bbox_overlaps_3D(proposals, gt_boxes)

            # Determine positive and negative ROIs
            roi_iou_max = torch.max(overlaps, dim=1)[0]
            # 1. Positive ROIs have IoU >= 0.5 (2D) / 0.3 (3D) with a GT box.
            positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3)
            # 2. Negative ROIs have IoU < 0.1 (2D) / 0.01 (3D) with every GT box.
            negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01)
        else:
            positive_roi_bool = torch.FloatTensor().cuda()
            negative_roi_bool = torch.from_numpy(np.array([1]*proposals.shape[0])).cuda()

        # Sample Positive ROIs
        if 0 not in torch.nonzero(positive_roi_bool).size():
            positive_indices = torch.nonzero(positive_roi_bool).squeeze(1)
            positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio)
            rand_idx = torch.randperm(positive_indices.size()[0])
            rand_idx = rand_idx[:positive_samples].cuda()
            positive_indices = positive_indices[rand_idx]
            positive_samples = positive_indices.size()[0]
            positive_rois = proposals[positive_indices, :]
            # Assign positive ROIs to GT boxes.
            positive_overlaps = overlaps[positive_indices, :]
            roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1]
            roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :]
            roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment]

            # Compute bbox refinement targets for positive ROIs
            deltas = mutils.box_refinement(positive_rois, roi_gt_boxes)
            std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda()
            deltas /= std_dev

            sample_positive_indices.append(batch_element_indices[positive_indices])
            sample_deltas.append(deltas)
            sample_class_ids.append(roi_gt_class_ids)
            positive_count += positive_samples
        else:
            positive_samples = 0

        # Negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM.
        if 0 not in torch.nonzero(negative_roi_bool).size():
            negative_indices = torch.nonzero(negative_roi_bool).squeeze(1)
            r = 1.0 / cf.roi_positive_ratio
            b_neg_count = np.max((int(r * positive_samples - positive_samples), 1))
            roi_probs_neg = batch_mrcnn_class_scores[batch_element_indices[negative_indices]]
            raw_sampled_indices = mutils.shem(roi_probs_neg, b_neg_count, cf.shem_poolsize)
            sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]])
            negative_count += raw_sampled_indices.size()[0]

    if len(sample_positive_indices) > 0:
        target_deltas = torch.cat(sample_deltas)
        target_class_ids = torch.cat(sample_class_ids)

    # Pad target information with zeros for negative ROIs.
    if positive_count > 0 and negative_count > 0:
        sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0)
        zeros = torch.zeros(negative_count).int().cuda()
        target_class_ids = torch.cat([target_class_ids, zeros], dim=0)
        zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
        target_deltas = torch.cat([target_deltas, zeros], dim=0)
    elif positive_count > 0:
        sample_indices = torch.cat(sample_positive_indices)
    elif negative_count > 0:
        sample_indices = torch.cat(sample_negative_indices)
        zeros = torch.zeros(negative_count).int().cuda()
        target_class_ids = zeros
        zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
        target_deltas = zeros
    else:
        sample_indices = torch.LongTensor().cuda()
        target_class_ids = torch.IntTensor().cuda()
        target_deltas = torch.FloatTensor().cuda()

    return sample_indices, target_class_ids, target_deltas
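
A minimal usage sketch follows. The fields on the mock cf object mirror those read by the function above, but all values and the synthetic data are illustrative assumptions; a CUDA device and the project's mutils helpers are assumed to be available.

from types import SimpleNamespace

# mock config with the fields read above; values are illustrative only.
cf = SimpleNamespace(
    dim=2, patch_size=(256, 256),
    train_rois_per_image=100, roi_positive_ratio=0.33,
    bbox_std_dev=np.array([0.1, 0.1, 0.2, 0.2]),
    shem_poolsize=10)

n_proposals = 200
# proposals in normalized (y1, x1, y2, x2) coordinates plus a batch-index column.
y1x1 = torch.rand(n_proposals, 2) * 0.6
hw = torch.rand(n_proposals, 2) * 0.3 + 0.05
boxes = torch.cat([y1x1, y1x1 + hw], dim=1).cuda()
batch_proposals = torch.cat([boxes, torch.zeros(n_proposals, 1).cuda()], dim=1)

batch_mrcnn_class_scores = F.softmax(torch.randn(n_proposals, 2).cuda(), dim=1)
batch_gt_class_ids = [np.array([1])]                   # one image, one GT object
batch_gt_boxes = [np.array([[32., 32., 120., 140.]])]  # pixel coordinates

sample_ix, target_class_ids, target_deltas = detection_target_layer(
    batch_proposals, batch_mrcnn_class_scores, batch_gt_class_ids, batch_gt_boxes, cf)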