Beispiel #1
0
    def _create_label(self, anchor, bbox):
        """Label every anchor as positive (1), negative (0) or ignored (-1).

        Labels are first assigned from the IoU statistics returned by
        ``self._calc_ious`` and then randomly subsampled so that the number
        of positives does not exceed ``int(self.pos_ratio * self.n_sample)``
        and the number of negatives does not exceed ``self.n_sample`` minus
        the number of remaining positives.

        Args:
            anchor: Anchors to label; one label is produced per row.
            bbox: Ground-truth boxes, forwarded to ``self._calc_ious``.

        Returns:
            tuple: ``(argmax_ious, label)``. ``argmax_ious`` is passed through
            unchanged from ``self._calc_ious``; ``label`` is an int32 Var
            with ``anchor.shape[0]`` entries.
        """
        n_anchor = anchor.shape[0]
        # Every anchor starts out as "don't care" (-1).
        label = -jt.ones((n_anchor, ), dtype="int32")

        argmax_ious, max_ious, gt_argmax_ious = self._calc_ious(anchor, bbox)

        # Negatives are written first so that both positive rules below can
        # overwrite them.
        label[max_ious < self.neg_iou_thresh] = 0
        # Positive rule 1: the anchor(s) with the highest IoU for each gt box.
        label[gt_argmax_ious] = 1
        # Positive rule 2: any anchor whose best IoU reaches the threshold.
        label[max_ious >= self.pos_iou_thresh] = 1

        # Too many positives: switch a random surplus back to "don't care".
        max_pos = int(self.pos_ratio * self.n_sample)
        positives = jt.where(label == 1)[0]
        if len(positives) > max_pos:
            order = np.arange(0, positives.shape[0])
            np.random.shuffle(order)
            surplus = positives.shape[0] - max_pos
            label[positives[order[:surplus]]] = -1

        # Too many negatives: same treatment, with the budget being whatever
        # the remaining positives leave free.
        max_neg = self.n_sample - jt.sum(label == 1).item()
        negatives = jt.where(label == 0)[0]
        if len(negatives) > max_neg:
            order = np.arange(0, negatives.shape[0])
            np.random.shuffle(order)
            surplus = negatives.shape[0] - max_neg
            label[negatives[order[:surplus]]] = -1
        return argmax_ious, label
Beispiel #2
0
 def test(self):
     """Check the index outputs of ``jt.where`` for 1-D and 2-D inputs."""
     # 1-D input: a single index Var holding the non-zero positions.
     assert (jt.where([0, 1, 0, 1])[0].data == [1, 3]).all()
     (idx,) = jt.where([0, 1, 0, 1])
     # Before the values are fetched the shape is reported as [-4].
     assert idx.uncertain_shape == [-4]
     # Fetching the data resolves the shape to the real element count.
     idx.data
     assert idx.uncertain_shape == [2]
     # 2-D input: one index Var per dimension.
     rows, cols = jt.where([[0, 0, 1], [1, 0, 0]])
     assert (rows.data == [0, 1]).all()
     assert (cols.data == [2, 0]).all()
Beispiel #3
0
    def detect(self, batch_idx, conf_preds, decoded_boxes, mask_data,
               inst_data):
        """Run NMS for one batch element on its confident candidates.

        Candidates are kept when their best score over the non-background
        classes (class 0 is background and is sliced away) exceeds
        ``self.conf_thresh``. The survivors are handed to the NMS variant
        selected by ``self.use_fast_nms`` / ``self.use_cross_class_nms``.

        Returns:
            None when no candidate passes the confidence threshold, otherwise
            a dict with the keys 'box', 'mask', 'class' and 'score'.
        """
        with timer.env('Slices'):
            # Drop the background row, then take the best score per candidate.
            cur_scores = conf_preds[batch_idx, 1:]
            best_scores = jt.max(cur_scores, dim=0)

            keep = jt.where(best_scores > self.conf_thresh)[0]

            scores = cur_scores[:, keep]
            boxes = decoded_boxes[keep]
            masks = mask_data[batch_idx, keep]

            if inst_data is not None:
                # Sliced for parity with the other tensors; not used below.
                inst = inst_data[batch_idx, keep]

        if scores.shape[1] == 0:
            return None

        if not self.use_fast_nms:
            boxes, masks, classes, scores = self.traditional_nms(
                boxes, masks, scores, self.nms_thresh, self.conf_thresh)
            if self.use_cross_class_nms:
                print(
                    'Warning: Cross Class Traditional NMS is not implemented.')
        elif self.use_cross_class_nms:
            boxes, masks, classes, scores = self.cc_fast_nms(
                boxes, masks, scores, self.nms_thresh, self.top_k)
        else:
            boxes, masks, classes, scores = self.fast_nms(
                boxes, masks, scores, self.nms_thresh, self.top_k)

        return {'box': boxes, 'mask': masks, 'class': classes, 'score': scores}
Beispiel #4
0
    def predict(self, images, score_thresh=0.7, nms_thresh=0.3):
        """Detect objects in a batch of images.

        Runs ``self.execute``, decodes the per-class box offsets, clamps the
        boxes to the image extent and then, per image and per class (class 0
        is skipped), keeps boxes scoring above ``score_thresh`` and applies
        NMS with ``nms_thresh``.

        Args:
            images: Image batch; dim 0 is the batch size and the last two
                dims give the clamping bounds.
            score_thresh: Minimum class probability for a box to be kept.
            nms_thresh: Threshold forwarded to ``jt.nms``.

        Returns:
            list: One ``(boxes, scores, labels)`` tuple per image.
        """
        n_images = images.shape[0]
        # Even box columns are bounded by the last image dim, odd columns by
        # the second-to-last one.
        bound_even, bound_odd = images.shape[-1], images.shape[-2]

        _, _, roi_cls_locs, roi_scores, rois, roi_indices = self.execute(images)
        roi_cls_locs = roi_cls_locs.reshape(roi_cls_locs.shape[0], -1, 4)
        probs = nn.softmax(roi_scores, dim=-1)

        # Repeat every RoI once per class so it lines up with its offsets.
        rois = rois.unsqueeze(1).repeat(1, self.n_class, 1)
        cls_bbox = loc2bbox(rois.reshape(-1, 4), roi_cls_locs.reshape(-1, 4))
        cls_bbox[:, 0::2] = jt.clamp(cls_bbox[:, 0::2], min_v=0, max_v=bound_even)
        cls_bbox[:, 1::2] = jt.clamp(cls_bbox[:, 1::2], min_v=0, max_v=bound_odd)
        cls_bbox = cls_bbox.reshape(roi_cls_locs.shape)

        results = []
        for img_id in range(n_images):
            in_image = jt.where(roi_indices == img_id)[0]
            img_probs = probs[in_image, :]
            img_bbox = cls_bbox[in_image, :, :]

            boxes, scores, labels = [], [], []
            for cls_id in range(1, self.n_class):
                cls_boxes = img_bbox[:, cls_id, :]
                cls_scores = img_probs[:, cls_id]

                confident = jt.where(cls_scores > score_thresh)[0]
                cls_boxes = cls_boxes[confident, :]
                cls_scores = cls_scores[confident]

                dets = jt.contrib.concat(
                    [cls_boxes, cls_scores.unsqueeze(1)], dim=1)
                keep = jt.nms(dets, nms_thresh)
                cls_boxes = cls_boxes[keep]
                cls_scores = cls_scores[keep]

                boxes.append(cls_boxes)
                scores.append(cls_scores)
                labels.append(jt.ones_like(cls_scores).int32() * cls_id)

            results.append((
                jt.contrib.concat(boxes, dim=0),
                jt.contrib.concat(scores, dim=0),
                jt.contrib.concat(labels, dim=0),
            ))

        return results
    
    
    
    
        
Beispiel #5
0
    def _calc_ious(self, anchor, bbox):
        """Compute IoU statistics between anchors and ground-truth boxes.

        Returns:
            tuple: ``(argmax_ious, max_ious, gt_argmax_ious)`` where the first
            two are the index and value of the maximum of
            ``bbox_iou(anchor, bbox)`` along dim 1, and the third holds the
            dim-0 indices of every entry that equals the maximum of its
            column (so ties are all included).
        """
        # Pairwise IoU table between the anchors and the gt boxes.
        iou_table = bbox_iou(anchor, bbox)
        argmax_ious, max_ious = iou_table.argmax(dim=1)

        # Only the per-column maximum is needed here; matching positions are
        # recovered with where() so that tied entries are kept as well.
        _, column_max = iou_table.argmax(dim=0)
        gt_argmax_ious = jt.where(iou_table == column_max)[0]

        return argmax_ious, max_ious, gt_argmax_ious
Beispiel #6
0
def nonzero(x):
    r'''
    Return the indices of the elements of the input tensor that are not
    equal to zero.

    ``jt.where(x)`` yields one index Var per dimension; each is turned into a
    column and, when there are at least two of them, the columns are
    concatenated along dim 1.
    '''
    columns = [index.unsqueeze(1) for index in jt.where(x)]
    if len(columns) >= 2:
        return jt.contrib.concat(columns, dim=1)
    return columns[0]
Beispiel #7
0
 def test_vary_shape_dep2(self):
     """Check that an uncertain shape carries through broadcast and reindex."""
     mat = jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
     (rows,) = jt.where(mat.sum(1) > 7)  # [1,2]
     rows = rows.broadcast([1, 3], dims=[1])  # [[1,1,1],[2,2,2]]
     cols = rows.index_var(1)  # [[0,1,2],[0,1,2]]
     picked = mat.reindex_var([rows, cols])
     # The shape must be checked before any data is fetched.
     assert picked.uncertain_shape == [-3, 3]
     assert (picked.data == [[4, 5, 6], [7, 8, 9]]).all()
     assert (rows.data == [[1, 1, 1], [2, 2, 2]]).all()
     assert (cols.data == [[0, 1, 2], [0, 1, 2]]).all()
Beispiel #8
0
def remove_small_boxes(boxlist, min_size):
    """
    Keep only the boxes whose width and height are both >= min_size.

    Arguments:
        boxlist (Boxlist)
        min_size (int)

    Returns:
        ``boxlist`` indexed by the positions of the boxes that are kept.
    """
    # TODO maybe add an API for querying the ws / hs
    boxes_xywh = boxlist.convert("xywh").bbox
    _x, _y, widths, heights = boxes_xywh.unbind(dim=1)
    wide_enough = widths >= min_size
    tall_enough = heights >= min_size
    keep = jt.where(jt.logical_and(wide_enough, tall_enough))[0]
    return boxlist[keep]
Beispiel #9
0
    def _forward_train(self, features, img_size, boxes, labels):
        """Run the training forward pass and return the detection losses.

        The RPN proposes RoIs; for every image the proposal-target creator
        samples RoIs with their targets and the anchor-target creator builds
        the RPN targets. The head is then evaluated on the sampled RoIs and
        the four losses are computed.

        Args:
            features: Feature batch; dim 0 is the number of images.
            img_size: Image size, forwarded to the RPN and the anchor-target
                creator.
            boxes: Per-image ground-truth boxes, indexed by image.
            labels: Per-image ground-truth labels, indexed by image.

        Returns:
            list: ``[rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss,
            total]`` where ``total`` is the sum of the first four.
        """
        batch_size = features.shape[0]
        rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(features, img_size)

        sample_rois, sample_roi_indexs = [], []
        gt_roi_locs, gt_roi_labels = [], []
        gt_rpn_locs, gt_rpn_labels = [], []
        for img_id in range(batch_size):
            img_rois = rois[jt.where(roi_indices == img_id)[0], :]
            box = boxes[img_id]
            label = labels[img_id]

            # Head targets for this image.
            sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(img_rois, box, label)
            sample_rois.append(sample_roi)
            gt_roi_labels.append(gt_roi_label)
            gt_roi_locs.append(gt_roi_loc)
            # Tag every sampled RoI with the image it came from.
            sample_roi_indexs.append(img_id * jt.ones((sample_roi.shape[0],)))

            # RPN targets for this image.
            gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(box, anchor, img_size)
            gt_rpn_locs.append(gt_rpn_loc)
            gt_rpn_labels.append(gt_rpn_label)

        sample_roi_indexs = jt.contrib.concat(sample_roi_indexs, dim=0)
        sample_rois = jt.contrib.concat(sample_rois, dim=0)
        roi_cls_loc, roi_score = self.head(features, sample_rois, sample_roi_indexs)

        # ------------------ RPN losses -------------------#
        rpn_locs = rpn_locs.reshape(-1, 4)
        rpn_scores = rpn_scores.reshape(-1, 2)
        gt_rpn_labels = jt.contrib.concat(gt_rpn_labels, dim=0)
        gt_rpn_locs = jt.contrib.concat(gt_rpn_locs, dim=0)
        rpn_loc_loss = _fast_rcnn_loc_loss(rpn_locs, gt_rpn_locs, gt_rpn_labels, self.rpn_sigma)
        # Anchors with a negative label are excluded from the classification loss.
        labelled = gt_rpn_labels >= 0
        rpn_cls_loss = nn.cross_entropy_loss(rpn_scores[labelled, :], gt_rpn_labels[labelled])

        # ------------------ ROI losses (fast rcnn loss) -------------------#
        gt_roi_locs = jt.contrib.concat(gt_roi_locs, dim=0)
        gt_roi_labels = jt.contrib.concat(gt_roi_labels, dim=0)
        n_sample = roi_cls_loc.shape[0]
        n_loc_groups = np.prod(roi_cls_loc.shape[1:]).item() // 4
        roi_cls_loc = roi_cls_loc.view(n_sample, n_loc_groups, 4)
        # For each sample pick the 4 offsets belonging to its gt label.
        roi_loc = roi_cls_loc[jt.arange(0, n_sample).int32(), gt_roi_labels]
        roi_loc_loss = _fast_rcnn_loc_loss(roi_loc, gt_roi_locs, gt_roi_labels, self.roi_sigma)
        roi_cls_loss = nn.cross_entropy_loss(roi_score, gt_roi_labels)

        losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]
        losses.append(sum(losses))
        return losses
Beispiel #10
0
    def execute(self, bbox, anchor, img_size):
        """Assign ground-truth supervision to a sampled subset of anchors.

        Only anchors that lie completely inside the image are labelled and
        given regression targets; the results are then mapped back onto the
        full anchor set with ``_unmap``.

        Args:
            bbox (array): Coordinates of the ground-truth bounding boxes.
            anchor (array): Coordinates of the anchors. Columns 0 and 1 must
                be >= 0, column 2 must be <= the image width and column 3
                must be <= the image height for an anchor to count as inside.
            img_size (tuple of ints): Unpacked as ``(width, height)``.

        Returns:
            (array, array):

            * **loc**: Offsets and scales matching the anchors to the
              ground-truth boxes, with fill value 0 for anchors outside
              the image.
            * **label**: Anchor labels (1=positive, 0=negative, -1=ignore),
              with fill value -1 for anchors outside the image.

        """
        img_W, img_H = img_size

        # Remember the full anchor count so results can be mapped back later.
        total_anchors = len(anchor)
        inside = jt.where((anchor[:, 0] >= 0) & (anchor[:, 1] >= 0)
                          & (anchor[:, 2] <= img_W)
                          & (anchor[:, 3] <= img_H))[0]
        anchor = anchor[inside]
        argmax_ious, label = self._create_label(anchor, bbox)

        # Regression targets towards the gt box selected for each anchor.
        loc = bbox2loc(anchor, bbox[argmax_ious])

        # Scatter the results back onto the original set of anchors.
        label = _unmap(label, total_anchors, inside, fill=-1)
        loc = _unmap(loc, total_anchors, inside, fill=0)

        return loc, label
Beispiel #11
0
    def process_batch(self, detections, labels):
        """
        Update the confusion matrix with one batch of detections and labels.

        Detections with confidence above ``self.conf`` are matched to labels
        whose IoU exceeds ``self.iou_thres``; each detection and each label
        takes part in at most one match (highest IoU first). The matrix is
        indexed as ``[detected class, ground-truth class]``, with index
        ``self.nc`` used for "background" on either axis.

        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
        Arguments:
            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
            labels (Array[M, 5]), class, x1, y1, x2, y2
        Returns:
            None, updates confusion matrix accordingly
        """
        detections = detections[detections[:, 4] > self.conf]
        gt_classes = labels[:, 0].int()
        detection_classes = detections[:, 5].int()
        iou = general.box_iou(labels[:, 1:], detections[:, :4])

        x = jt.where(iou > self.iou_thres)
        if x[0].shape[0]:
            # Rows of `matches` are (label index, detection index, iou).
            matches = jt.contrib.concat(
                (jt.stack(x, 1), iou[x[0], x[1]][:, None]), 1).numpy()
            if x[0].shape[0] > 1:
                # Sort by IoU (descending) and keep the best match per
                # detection, then repeat to keep the best match per label.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1],
                                            return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0],
                                            return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(np.int16)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                # Matched pair: row = detected class, column = ground-truth
                # class, consistent with the two background updates below.
                self.matrix[detection_classes[m1[j]], gc] += 1
            else:
                # Label without a match: counted in the background row.
                self.matrix[self.nc, gc] += 1

        if n:
            for i, dc in enumerate(detection_classes):
                if not any(m1 == i):
                    # Detection without a match: counted in the background
                    # column.
                    self.matrix[dc, self.nc] += 1
Beispiel #12
0
    def execute(self, roi, bbox, label):
        """Assign ground truth to sampled proposals.

        Up to :obj:`self.n_sample` RoIs are sampled from :obj:`roi`. At most
        :obj:`round(self.n_sample * self.pos_ratio)` of them are foreground
        RoIs (best IoU >= :obj:`self.pos_iou_thresh`); the remainder is
        filled with background RoIs whose best IoU lies in
        :obj:`[self.neg_iou_thresh_lo, self.neg_iou_thresh_hi)`. Fewer RoIs
        are returned when not enough candidates exist.

        Each sampled RoI gets the class label of its best-overlapping ground
        truth box (background RoIs get 0) and the offsets and scales computed
        by ``bbox2loc``. The labels are taken from :obj:`label` as they are,
        so foreground labels presumably must already be non-zero - verify
        against the caller.

        Args:
            roi (array): Regions of interest from which we sample.
            bbox (array): Coordinates of the ground-truth bounding boxes.
            label (array): Ground-truth class label of every box in
                :obj:`bbox`.

        Returns:
            (array, array, array):

            * **sample_roi**: The sampled regions of interest, foreground
              first and background after.
            * **gt_roi_loc**: Offsets and scales matching the sampled RoIs
              to their ground-truth boxes.
            * **gt_roi_label**: Labels of the sampled RoIs; 0 marks
              background.

        """
        max_pos = np.round(self.n_sample * self.pos_ratio)
        gt_assignment, max_iou = bbox_iou(roi, bbox).argmax(dim=1)
        # Class label of the best-overlapping gt box for every RoI.
        gt_roi_label = label[gt_assignment]

        # Foreground candidates: RoIs with >= pos_iou_thresh IoU, randomly
        # reduced to the allowed number.
        pos_index = jt.where(max_iou >= self.pos_iou_thresh)[0]
        n_pos = int(min(max_pos, pos_index.shape[0]))
        if pos_index.shape[0] > 0:
            perm = np.arange(0, pos_index.shape[0])
            np.random.shuffle(perm)
            pos_index = pos_index[perm[:n_pos]]

        # Background candidates: RoIs within
        # [neg_iou_thresh_lo, neg_iou_thresh_hi), randomly reduced so the
        # total does not exceed n_sample.
        in_bg_range = ((max_iou < self.neg_iou_thresh_hi)
                       & (max_iou >= self.neg_iou_thresh_lo))
        neg_index = jt.where(in_bg_range)[0]
        n_neg = int(min(self.n_sample - n_pos, neg_index.shape[0]))
        if neg_index.shape[0] > 0:
            perm = np.arange(0, neg_index.shape[0])
            np.random.shuffle(perm)
            neg_index = neg_index[perm[:n_neg]]

        # Selected indices, positives first and negatives after.
        keep_index = jt.contrib.concat((pos_index, neg_index), dim=0)
        gt_roi_label = gt_roi_label[keep_index]
        # Everything after the positives is background.
        gt_roi_label[n_pos:] = 0
        sample_roi = roi[keep_index]

        # Offsets and scales that match the sampled RoIs to their gt boxes.
        gt_roi_loc = bbox2loc(sample_roi, bbox[gt_assignment[keep_index]])

        return sample_roi, gt_roi_loc, gt_roi_label
Beispiel #13
0
    def execute(self, loc, score, anchor, img_size, scale=1.):
        """Propose RoIs from anchors and their predicted offsets.

        Inputs :obj:`loc, score, anchor` refer to the same anchor when
        indexed by the same index. The anchors are decoded with ``loc2bbox``,
        clipped to the image, filtered by minimum side length, ranked by
        score and finally reduced with NMS.

        Args:
            loc (array): Predicted offsets and scaling to anchors.
            score (array): Predicted foreground score of every anchor.
            anchor (array): Coordinates of anchors.
            img_size (tuple of ints): Image size after scaling. Element 0
                bounds box columns 0 and 2, element 1 bounds box columns 1
                and 3.
            scale (float): Scaling factor applied to :obj:`self.min_size`
                before small boxes are removed.

        Returns:
            array:
            Coordinates of the proposal boxes that survive NMS. When the
            applicable post-NMS limit (train or test setting, chosen by
            ``self.is_training()``) is positive, at most that many boxes are
            returned.

        """
        # Train and test use different pre/post-NMS limits.
        if self.is_training():
            pre_nms_limit = self.n_train_pre_nms
            post_nms_limit = self.n_train_post_nms
        else:
            pre_nms_limit = self.n_test_pre_nms
            post_nms_limit = self.n_test_post_nms

        # Convert anchors into proposals via the bbox transformation.
        roi = loc2bbox(anchor, loc)

        # Clip the predicted boxes to the image.
        for col in (0, 2):
            roi[:, col] = jt.clamp(roi[:, col], min_v=0, max_v=img_size[0])
        for col in (1, 3):
            roi[:, col] = jt.clamp(roi[:, col], min_v=0, max_v=img_size[1])

        # Drop boxes with either side shorter than the scaled threshold.
        size_floor = self.min_size * scale
        side_a = roi[:, 2] - roi[:, 0]
        side_b = roi[:, 3] - roi[:, 1]
        keep = jt.where((side_a >= size_floor) & (side_b >= size_floor))[0]
        roi = roi[keep, :]
        score = score[keep]

        # Rank by score, highest first, and keep the top candidates.
        order, _ = jt.argsort(score, descending=True)
        if pre_nms_limit > 0:
            order = order[:pre_nms_limit]
        roi = roi[order, :]
        score = score[order]

        # NMS on (box, score) rows, then keep the top survivors.
        dets = jt.contrib.concat([roi, score.unsqueeze(1)], dim=1)
        keep = jt.nms(dets, self.nms_thresh)
        if post_nms_limit > 0:
            keep = keep[:post_nms_limit]
        return roi[keep]
 def test_vary_shape_setitem(self):
     """Check assignment through the index tuple returned by ``jt.where``."""
     target = jt.array([1, 2, 3, 4, 5])
     source = jt.array([1, 2, 3, 4, 5])
     # The whole tuple returned by where() is used as the index.
     above_three = jt.where(source > 3)
     target[above_three] = 0
     assert (target.data == [1, 2, 3, 0, 0]).all()
Beispiel #15
0
 def test_vary_shape_dep(self):
     """Check that ``index_var`` inherits the uncertain shape of its input."""
     (found,) = jt.where([1, 0, 1])
     (positions,) = found.index_var()
     # Both shapes are still unresolved before any data is fetched.
     assert found.uncertain_shape == [-3]
     assert positions.uncertain_shape == [-3]
     assert (positions.data == [0, 1]).all()
Beispiel #16
0
    def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
        """Compute the prototype (linear-combination) mask loss for a batch.

        For every batch element the predicted masks are formed as
        ``proto_data[idx] @ coefficients^T`` for the priors selected by
        ``pos[idx]``, passed through ``cfg.mask_proto_mask_activation`` and
        compared with the ground-truth masks resized to the prototype
        resolution. Several ``cfg`` switches add cropping, re-weighting and
        normalisation on top of that.

        Args:
            pos: Per-image selection of positive priors; the non-zero
                entries of ``pos[idx]`` are the priors that are trained.
            idx_t: Per-prior index of the matched ground-truth object, used
                to index the ground-truth masks and labels.
            loc_data: Predicted box regressions; only decoded when
                ``cfg.mask_proto_crop_with_pred_box`` is set.
            mask_data: Predicted mask coefficients, indexed as
                ``[image, prior, :]``.
            priors: Prior boxes; only used together with ``loc_data``.
            proto_data: Prototype masks per image; dims 1 and 2 give the
                mask height and width.
            masks: Per-image ground-truth masks.
            gt_box_t: Per-prior matched ground-truth boxes.
            score_data: Mask scores; only read when ``cfg.use_mask_scoring``.
            inst_data: Optional instance coefficients; used for the
                diversity loss when not None.
            labels: Per-image ground-truth labels.
            interpolation_mode: Mode passed to ``nn.interpolate`` when the
                ground-truth masks are resized.

        Returns:
            When ``cfg.use_maskiou`` is false, a dict with the mask loss
            under 'M' and, if ``cfg.mask_proto_coeff_diversity_loss`` is set,
            the diversity loss under 'D'. When ``cfg.use_maskiou`` is true, a
            tuple ``(losses, extra)`` where ``extra`` is
            ``[maskiou_net_input, maskiou_t, label_t]`` or None if no masks
            were collected.
        """
        mask_h = proto_data.shape[1]
        mask_w = proto_data.shape[2]

        process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

        if cfg.mask_proto_remove_empty_masks:
            # Make sure to store a copy of this because we edit it to get rid of all-zero masks
            pos = pos.clone()

        loss_m = 0
        loss_d = 0 # Coefficient diversity loss

        maskiou_t_list = []
        maskiou_net_input_list = []
        label_t_list = []

        for idx in range(mask_data.shape[0]):
            with jt.no_grad():
                # Resize the gt masks to the prototype resolution and move
                # the mask dimension last: [mask_h, mask_w, num_masks].
                downsampled_masks = nn.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                                  mode=interpolation_mode, align_corners=False).squeeze(0)
                downsampled_masks = downsampled_masks.permute(1, 2, 0)

                if cfg.mask_proto_binarize_downsampled_gt:
                    downsampled_masks = (downsampled_masks>0.5).float()

                if cfg.mask_proto_remove_empty_masks:
                    # Get rid of gt masks that are so small they get downsampled away
                    very_small_masks = (downsampled_masks.sum(0).sum(0) <= 0.0001)
                    for i in range(very_small_masks.shape[0]):
                        if very_small_masks[i]:
                            pos[idx, idx_t[idx] == i] = 0

                if cfg.mask_proto_reweight_mask_loss:
                    # Ensure that the gt is binary
                    if not cfg.mask_proto_binarize_downsampled_gt:
                        bin_gt = (downsampled_masks>0.5).float()
                    else:
                        bin_gt = downsampled_masks

                    gt_foreground_norm = bin_gt     / (jt.sum(bin_gt,   dim=(0,1), keepdim=True) + 0.0001)
                    gt_background_norm = (1-bin_gt) / (jt.sum(1-bin_gt, dim=(0,1), keepdim=True) + 0.0001)

                    mask_reweighting   = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                    mask_reweighting  *= mask_h * mask_w

            # Indices of the positive priors of this image.
            cur_pos = pos[idx]
            cur_pos = jt.where(cur_pos)[0]
            pos_idx_t = idx_t[idx, cur_pos]

            if process_gt_bboxes:
                # Note: this is in point-form
                if cfg.mask_proto_crop_with_pred_box:
                    pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
                else:
                    pos_gt_box_t = gt_box_t[idx, cur_pos]

            # Nothing to train on for this image.
            if pos_idx_t.shape[0] == 0:
                continue

            proto_masks = proto_data[idx]
            proto_coef  = mask_data[idx, cur_pos, :]
            if cfg.use_mask_scoring:
                mask_scores = score_data[idx, cur_pos, :]

            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is not None:
                    div_coeffs = inst_data[idx, cur_pos, :]
                else:
                    div_coeffs = proto_coef

                loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

            # If we have over the allowed number of masks, select a random sample
            old_num_pos = proto_coef.shape[0]
            if old_num_pos > cfg.masks_to_train:
                perm = jt.randperm(proto_coef.shape[0])
                select = perm[:cfg.masks_to_train]

                proto_coef = proto_coef[select, :]
                pos_idx_t  = pos_idx_t[select]

                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]
                if cfg.use_mask_scoring:
                    mask_scores = mask_scores[select, :]

            num_pos = proto_coef.shape[0]
            mask_t = downsampled_masks[:, :, pos_idx_t]
            label_t = labels[idx][pos_idx_t]

            # Size: [mask_h, mask_w, num_pos]
            pred_masks = proto_masks @ proto_coef.transpose(1,0)

            pred_masks = cfg.mask_proto_mask_activation(pred_masks)

            if cfg.mask_proto_double_loss:
                # Extra loss term on the masks before they are cropped.
                if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                    pre_loss = nn.bce_loss(jt.clamp(pred_masks, 0, 1), mask_t, size_average=False)
                else:
                    pre_loss = nn.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

                loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

            if cfg.mask_proto_crop:
                pred_masks = crop(pred_masks, pos_gt_box_t)

            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = binary_cross_entropy(jt.clamp(pred_masks, 0, 1), mask_t)
            else:
                pre_loss = nn.smooth_l1_loss(pred_masks, mask_t, reduction='none')

            if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
                gt_area  = jt.sum(mask_t, dim=(0, 1), keepdims=True)
                pre_loss = pre_loss / (jt.sqrt(gt_area) + 0.0001)

            if cfg.mask_proto_reweight_mask_loss:
                pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

            if cfg.mask_proto_normalize_emulate_roi_pooling:
                weight = mask_h * mask_w if cfg.mask_proto_crop else 1
                pos_gt_csize = center_size(pos_gt_box_t)
                gt_box_width  = pos_gt_csize[:, 2] * mask_w
                gt_box_height = pos_gt_csize[:, 3] * mask_h
                pre_loss = pre_loss.sum(0).sum(0) / gt_box_width / gt_box_height * weight

            # If the number of masks were limited scale the loss accordingly
            if old_num_pos > num_pos:
                pre_loss *= old_num_pos / num_pos

            loss_m += jt.sum(pre_loss)

            if cfg.use_maskiou:
                if cfg.discard_mask_area > 0:
                    # Skip gt masks whose area is not above the threshold.
                    gt_mask_area = jt.sum(mask_t, dim=(0, 1))
                    select = gt_mask_area > cfg.discard_mask_area

                    if jt.sum(select).item() < 1:
                        continue

                    pos_gt_box_t = pos_gt_box_t[select, :]
                    pred_masks = pred_masks[:, :, select]
                    mask_t = mask_t[:, :, select]
                    label_t = label_t[select]

                maskiou_net_input = pred_masks.permute(2, 0, 1).unsqueeze(1)
                pred_masks = (pred_masks>0.5).float()
                maskiou_t = self._mask_iou(pred_masks, mask_t)

                maskiou_net_input_list.append(maskiou_net_input)
                maskiou_t_list.append(maskiou_t)
                label_t_list.append(label_t)

        losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = loss_d

        if cfg.use_maskiou:
            # discard_mask_area discarded every mask in the batch, so nothing to do here
            if len(maskiou_t_list) == 0:
                return losses, None

            maskiou_t = jt.contrib.concat(maskiou_t_list)
            label_t = jt.contrib.concat(label_t_list)
            maskiou_net_input = jt.contrib.concat(maskiou_net_input_list)

            num_samples = maskiou_t.shape[0]
            if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
                # Subsample to the mask-IoU budget. The limit must be the
                # same setting the guard above checks (maskious_to_train),
                # not the unrelated masks_to_train.
                perm = jt.randperm(num_samples)
                select = perm[:cfg.maskious_to_train]
                maskiou_t = maskiou_t[select]
                label_t = label_t[select]
                maskiou_net_input = maskiou_net_input[select]

            return losses, [maskiou_net_input, maskiou_t, label_t]

        return losses