Example #1
0
    def inference(self, head, x, proposals, valid_size, img_size):
        """Run the detection head on ``proposals`` and turn its outputs into final detections.

        Parameters
        ----------
        head : torch.nn.Module
            Detection head, forwarded to ``self._head``.
        x : sequence of torch.Tensor
            Feature maps; only the ``self.levels`` entries starting at ``self.min_level`` are used.
        proposals : PackedSequence
            Per-image proposal boxes, entries can be None.
        valid_size : list of tuple of int
            Valid image sizes, used when splitting and clipping the regressed boxes.
        img_size : tuple of int
            Spatial size of the (possibly padded) network input.

        Returns
        -------
        bbx_pred, cls_pred, obj_pred : PackedSequence
            Outputs of ``self.prediction_generator``. When every proposal entry is None, three
            sequences of ``None`` with ``x[0].size(0)`` entries each (``x`` after level selection).
        """
        feats = x[self.min_level:self.min_level + self.levels]

        if proposals.all_none:
            # Nothing to refine: one None per image for each of the three outputs
            batch_size = feats[0].size(0)
            return tuple(PackedSequence([None] * batch_size) for _ in range(3))

        flat_proposals, proposal_img_idx = proposals.contiguous
        cls_logits, bbx_logits = self._head(head, feats, flat_proposals,
                                            proposal_img_idx, img_size)

        # Apply the weight-normalised regression deltas to every proposal
        reg_weights = feats[0].new(self.bbx_reg_weights)
        deltas = bbx_logits / reg_weights
        refined = shift_boxes(flat_proposals.unsqueeze(1), deltas)
        probs = torch.softmax(cls_logits, dim=1)

        # Back to per-image sequences, clipped to each image's valid area
        refined, probs = self._split_and_clip(refined, probs, proposal_img_idx,
                                              valid_size)

        bbx_pred, cls_pred, obj_pred = self.prediction_generator(refined, probs)
        return bbx_pred, cls_pred, obj_pred
Example #2
0
    def _match_to_lbl(self, proposals, bbx, cat, match):
        """Turn proposal / ground-truth matches into classification and box-regression targets.

        For every image ``i`` with ``match[i]`` not None:
          * the class target is a long tensor with one entry per proposal: 0 where
            ``match_i < 0``, otherwise ``cat_i[match_i] + 1 - self.num_stuff``;
          * the box target holds ``calculate_shift`` offsets from each matched proposal to its
            ground truth box, multiplied by ``self.bbx_reg_weights``, or None when nothing matched.
        Both targets are None when ``match[i]`` is None.

        Returns
        -------
        cls_lbl, bbx_lbl : PackedSequence
        """
        cls_targets, bbx_targets = [], []
        for props, gt_boxes, gt_cat, assignment in zip(proposals, bbx, cat, match):
            if assignment is None:
                cls_targets.append(None)
                bbx_targets.append(None)
                continue

            matched = assignment >= 0
            matched_gt = assignment[matched]

            # Classification target (0 for unmatched proposals)
            labels = props.new_zeros(props.size(0), dtype=torch.long)
            labels[matched] = gt_cat[matched_gt] + 1 - self.num_stuff

            # Box-regression target, only when at least one proposal is matched
            if matched.any().item():
                shifts = calculate_shift(props[matched], gt_boxes[matched_gt])
                shifts *= shifts.new(self.bbx_reg_weights)
            else:
                shifts = None

            cls_targets.append(labels)
            bbx_targets.append(shifts)

        return PackedSequence(cls_targets), PackedSequence(bbx_targets)
Example #3
0
    def _match_to_lbl(self, proposals, bbx, cat, ids, msk, match):
        """Turn proposal / ground-truth matches into class, box-regression and mask targets.

        For every image ``i`` with ``match[i]`` not None (all three targets are None otherwise):
          * class target: long tensor with one entry per proposal, 0 where ``match_i < 0`` and
            ``cat_i[match_i] + 1 - self.num_stuff`` for matched proposals;
          * box target: ``calculate_shift`` between the matched proposals and their ground truth
            boxes, multiplied by ``self.bbx_reg_weights`` (None when nothing is matched);
          * mask target: ``msk_i`` sampled with ``roi_sampling`` (nearest interpolation, size
            ``self.lbl_roi_size``) inside each matched proposal; 1 where any sampled value along
            dim 1 equals the matched instance id ``ids_i[match_i]``, 0 elsewhere, and -1 where all
            sampled values along dim 1 are 0 unless ``self.void_is_background`` (None when nothing
            is matched).

        Returns
        -------
        cls_lbl, bbx_lbl, msk_lbl : PackedSequence
        """
        cls_targets, bbx_targets, msk_targets = [], [], []
        for props, gt_boxes, gt_cat, gt_ids, gt_msk, assignment in zip(
                proposals, bbx, cat, ids, msk, match):
            if assignment is None:
                cls_targets.append(None)
                bbx_targets.append(None)
                msk_targets.append(None)
                continue

            matched = assignment >= 0
            matched_gt = assignment[matched]

            # Classification target (0 for unmatched proposals)
            labels = props.new_zeros(props.size(0), dtype=torch.long)
            labels[matched] = gt_cat[matched_gt] + 1 - self.num_stuff

            if not matched.any().item():
                box_target = None
                mask_target = None
            else:
                # Box-regression target
                box_target = calculate_shift(props[matched], gt_boxes[matched_gt])
                box_target *= box_target.new(self.bbx_reg_weights)

                instance_ids = gt_ids[matched_gt]

                # Sample the GT instance map inside every matched proposal; nearest interpolation
                # presumably keeps sampled values equal to existing ids (original note: 28*28)
                sampled = roi_sampling(
                    gt_msk.unsqueeze(0),
                    props[matched],
                    gt_msk.new_zeros(matched.long().sum().item()),
                    self.lbl_roi_size,
                    interpolation="nearest")

                # 1 where the sampled id is the proposal's matched instance, else 0
                mask_target = (sampled == instance_ids.view(-1, 1, 1, 1)).any(
                    dim=1).to(torch.long)
                if not self.void_is_background:
                    mask_target[(sampled == 0).all(dim=1)] = -1

            cls_targets.append(labels)
            bbx_targets.append(box_target)
            msk_targets.append(mask_target)

        return (PackedSequence(cls_targets), PackedSequence(bbx_targets),
                PackedSequence(msk_targets))
Example #4
0
    def __call__(self, boxes, scores, training):
        """Perform NMS-based selection of proposals

        Parameters
        ----------
        boxes : torch.Tensor
            Tensor of bounding boxes with shape N x M x 4
        scores : torch.Tensor
            Tensor of bounding box scores with shape N x M
        training : bool
            Switch between training and validation modes (selects the train / val pre- and
            post-NMS limits)

        Returns
        -------
        proposals : PackedSequence
            Sequence of N tensors of selected bounding boxes with shape M_i x 4, entries can be None
        indices : list of (torch.Tensor or None)
            For each image, the positions along the M axis of ``boxes[i]`` of the selected boxes,
            in the same order as ``proposals[i]``; None where ``proposals[i]`` is None
        """
        if training:
            num_pre_nms = self.num_pre_nms_train
            num_post_nms = self.num_post_nms_train
        else:
            num_pre_nms = self.num_pre_nms_val
            num_post_nms = self.num_post_nms_val

        proposals = []
        indices = []
        for bbx_i, obj_i in zip(boxes, scores):
            try:
                # Maps positions in the size-filtered bbx_i back to positions in boxes[i];
                # None means no size filtering was applied.
                kept_pos = None

                # Optional size pre-selection, remove very small boxes
                if self.min_size > 0:
                    bbx_size = bbx_i[:, 2:] - bbx_i[:, :2]
                    valid = (bbx_size[:, 0] >=
                             self.min_size) & (bbx_size[:, 1] >= self.min_size)

                    if valid.any().item():
                        kept_pos = valid.nonzero().view(-1)
                        bbx_i, obj_i = bbx_i[valid], obj_i[valid]
                    else:
                        raise Empty

                # Score pre-selection: keep the top num_pre_nms boxes
                obj_i, idx = obj_i.topk(min(obj_i.size(0), num_pre_nms))
                bbx_i = bbx_i[idx]

                # NMS
                _idx = nms(bbx_i, obj_i, self.nms_threshold, num_post_nms)

                if _idx.numel() == 0:
                    raise Empty
                bbx_i = bbx_i[_idx]

                # Positions of the survivors in the (filtered) bbx_i; map them back to boxes[i]
                # so callers can index the input they actually passed in.
                sel = idx[_idx]
                if kept_pos is not None:
                    sel = kept_pos[sel]

                indices.append(sel)
                proposals.append(bbx_i)

            except Empty:
                indices.append(None)
                proposals.append(None)

        return PackedSequence(proposals), indices
Example #5
0
    def training(self, head, x, sem, valid_size, img_size):
        """Compute the semantic segmentation loss, confusion matrix and per-image prediction.

        Parameters
        ----------
        head : torch.nn.Module
            Module to compute semantic segmentation logits given an input feature map. Must be callable as `head(x)`
        x : torch.Tensor
            A tensor of image features with shape N x C x H x W
        sem : sequence of torch.Tensor
            A sequence of N tensors of ground truth semantic segmentations with shapes H_i x W_i
        valid_size : list of tuple of int
            List of valid image sizes in input coordinates
        img_size : tuple of int
            Spatial size of the, possibly padded, image tensor used as input to the network that calculates x

        Returns
        -------
        sem_loss : torch.Tensor
            A scalar tensor with the computed loss
        conf_mat : torch.Tensor
            A confusion matrix tensor with shape M x M, where M is the number of classes
        sem_pred : PackedSequence
            A sequence of N tensors of semantic segmentations with shapes H_i x W_i
        """
        logits = self._logits(head, x, valid_size, img_size)

        # Per-pixel prediction: index of the maximum over dim 0 of each image's logits
        predictions = []
        for logits_i in logits:
            _, labels_i = logits_i.max(dim=0)
            predictions.append(labels_i)
        sem_pred = PackedSequence(predictions)

        loss = self.loss(logits, sem)
        confusion = self._confusion_matrix(sem_pred, sem)
        return loss, confusion, sem_pred
Example #6
0
def iss_collate_fn(items):
    """Collate a list of ISS sample dicts into a dict of per-key lists.

    Keys are taken from the first item. For each key the values of all items are gathered in
    order; when the first item's value is a ``torch.Tensor`` the gathered list is wrapped in a
    ``PackedSequence``. An empty ``items`` yields an empty dict.
    """
    if len(items) == 0:
        return {}

    template = items[0]
    batch = {}
    for name in template:
        column = [sample[name] for sample in items]
        if isinstance(template[name], torch.Tensor):
            column = PackedSequence(column)
        batch[name] = column
    return batch
Example #7
0
    def inference(self, head, x, proposals, valid_size, img_size):
        """Predict boxes, classes, scores and instance masks for the given proposals.

        ``self._head`` is run twice:
          1. with flags ``True, False`` on ``proposals`` to get class and box logits, which are
             turned into refined, clipped boxes and passed to ``self.bbx_prediction_generator``;
          2. with flags ``False, True`` on the selected detections to get mask logits, which
             ``self.msk_prediction_generator`` turns into per-image masks.

        Returns
        -------
        bbx_pred, cls_pred, obj_pred, msk_pred : PackedSequence
            When ``proposals`` is all None, when no detection is selected, or whenever ``Empty``
            is raised during the computation: four sequences of None with ``x[0].size(0)``
            entries each (``x`` after level selection).
        """
        levels = x[self.min_level:self.min_level + self.levels]

        try:
            if proposals.all_none:
                raise Empty

            # First pass: classification and box regression on the input proposals
            flat_props, props_img = proposals.contiguous
            cls_logits, bbx_logits, _ = self._head(head, levels, flat_props,
                                                   props_img, img_size,
                                                   True, False)

            reg_weights = levels[0].new(self.bbx_reg_weights)
            refined = shift_boxes(flat_props.unsqueeze(1),
                                  bbx_logits / reg_weights)
            probs = torch.softmax(cls_logits, dim=1)
            refined, probs = self._split_and_clip(refined, probs, props_img,
                                                  valid_size)

            # Selection of the final detections
            bbx_pred, cls_pred, obj_pred = self.bbx_prediction_generator(
                refined, probs)
            if bbx_pred.all_none:
                raise Empty

            # Second pass: mask logits on the selected detections only
            det_boxes, det_img = bbx_pred.contiguous
            _, _, msk_logits = self._head(head, levels, det_boxes, det_img,
                                          img_size, False, True)
            msk_pred = self.msk_prediction_generator(cls_pred, msk_logits)
        except Empty:
            n_img = levels[0].size(0)
            bbx_pred, cls_pred, obj_pred, msk_pred = (
                PackedSequence([None] * n_img) for _ in range(4))

        return bbx_pred, cls_pred, obj_pred, msk_pred
Example #8
0
    def inference(self, head, x, valid_size, img_size):
        """Compute the semantic segmentation prediction from input features.

        Parameters
        ----------
        head : torch.nn.Module
            Module to compute semantic segmentation logits given an input feature map. Must be callable as `head(x)`
        x : torch.Tensor
            A tensor of image features with shape N x C x H x W
        valid_size : list of tuple of int
            List of valid image sizes in input coordinates
        img_size : tuple of int
            Spatial size of the, possibly padded, image tensor used as input to the network that calculates x

        Returns
        -------
        sem_pred : PackedSequence
            A sequence of N tensors of semantic segmentations with shapes H_i x W_i
        """
        per_image_logits = self._logits(head, x, valid_size, img_size)
        # Index of the maximum over dim 0 of each image's logits
        labels = [scores.max(dim=0)[1] for scores in per_image_logits]
        return PackedSequence(labels)
Example #9
0
    def __call__(self, cls_pred, msk_logits):
        """Select, for each detection, the mask logits of its predicted class.

        Parameters
        ----------
        cls_pred : sequence of torch.Tensor
            A sequence of N tensors with shape S_i, each containing the predicted classes of the detections selected in
            each image, entries can be None.
        msk_logits : torch.Tensor
            A tensor with shape S x C x H x W containing the class-specific mask logits predicted for the instances
            in bbx_preds. Note that S = sum_i S_i.

        Returns
        -------
        msk_pred : PackedSequence
            A sequence of N tensors with shape S_i x H x W containing the mask logits for the detections in each
            image. Entries of `msk_preds` are None for images with no instances.
        """
        masks = []
        offset = 0
        for classes_i in cls_pred:
            if classes_i is None:
                masks.append(None)
                continue

            # The rows of msk_logits belonging to this image are consecutive
            count = classes_i.numel()
            logits_i = msk_logits[offset:offset + count]
            rows = torch.arange(0, count, dtype=torch.long, device=logits_i.device)
            # Keep, for every detection, only the channel of its predicted class
            masks.append(logits_i[rows, classes_i, ...])
            offset += count

        return PackedSequence(masks)
Example #10
0
    def __call__(self, boxes, scores):
        """Select the final detections with per-class score filtering and NMS.

        Parameters
        ----------
        boxes : sequence of torch.Tensor
            Sequence of N tensors of class-specific bounding boxes with shapes M_i x C x 4, entries can be None
        scores : sequence of torch.Tensor
            Sequence of N tensors of class probabilities with shapes M_i x (C + 1), entries can be None

        Returns
        -------
        bbx_pred : PackedSequence
            A sequence of N tensors of bounding boxes with shapes S_i x 4, entries are None for images in which no
            detection can be kept according to the selection parameters
        cls_pred : PackedSequence
            A sequence of N tensors of thing class predictions with shapes S_i, entries are None for images in which no
            detection can be kept according to the selection parameters
        obj_pred : PackedSequence
            A sequence of N tensors of detection confidences with shapes S_i, entries are None for images in which no
            detection can be kept according to the selection parameters
        """
        bbx_pred, cls_pred, obj_pred = [], [], []
        for img_boxes, img_scores in zip(boxes, scores):
            try:
                if img_boxes is None or img_scores is None:
                    raise Empty

                kept_boxes, kept_cls, kept_scores = [], [], []
                # Score column 0 is skipped (presumably background), so box class c pairs with
                # score column c + 1 and the emitted class id is c
                per_class = zip(torch.unbind(img_boxes, dim=1),
                                torch.unbind(img_scores, dim=1)[1:])
                for cls_id, (cls_boxes, cls_scores) in enumerate(per_class):
                    # Drop boxes at or below the score threshold
                    confident = cls_scores > self.score_threshold
                    if not confident.any().item():
                        continue
                    cls_boxes, cls_scores = cls_boxes[confident], cls_scores[confident]

                    # Drop degenerate boxes (non-positive width or height)
                    non_empty = (cls_boxes[:, 2] > cls_boxes[:, 0]) & (
                        cls_boxes[:, 3] > cls_boxes[:, 1])
                    if not non_empty.any().item():
                        continue
                    cls_boxes, cls_scores = cls_boxes[non_empty], cls_scores[non_empty]

                    keep = nms(cls_boxes.contiguous(),
                               cls_scores.contiguous(),
                               threshold=self.nms_threshold,
                               n_max=-1)
                    if keep.numel() == 0:
                        continue
                    cls_boxes, cls_scores = cls_boxes[keep], cls_scores[keep]

                    kept_boxes.append(cls_boxes)
                    kept_cls.append(
                        cls_boxes.new_full((cls_boxes.size(0), ),
                                           cls_id,
                                           dtype=torch.long))
                    kept_scores.append(cls_scores)

                if not kept_boxes:
                    raise Empty
                img_bbx = torch.cat(kept_boxes, dim=0)
                img_cls = torch.cat(kept_cls, dim=0)
                img_obj = torch.cat(kept_scores, dim=0)

                # Cap the number of detections, keeping the most confident ones
                if img_bbx.size(0) > self.max_predictions:
                    _, top = img_obj.topk(self.max_predictions)
                    img_bbx, img_cls, img_obj = img_bbx[top], img_cls[top], img_obj[top]

                bbx_pred.append(img_bbx)
                cls_pred.append(img_cls)
                obj_pred.append(img_obj)
            except Empty:
                bbx_pred.append(None)
                cls_pred.append(None)
                obj_pred.append(None)

        return (PackedSequence(bbx_pred), PackedSequence(cls_pred),
                PackedSequence(obj_pred))
Example #11
0
    def __call__(self, proposals, bbx, cat, iscrowd):
        """Match proposals to ground truth boxes

        Parameters
        ----------
        proposals : PackedSequence
            A sequence of N tensors with shapes P_i x 4 containing bounding box proposals, entries can be None
        bbx : sequence of torch.Tensor
            A sequence of N tensors with shapes K_i x 4 containing ground truth bounding boxes, entries can be None
        cat : sequence of torch.Tensor
            A sequence of N tensors with shapes K_i containing ground truth instance -> category mappings, entries can
            be None
        iscrowd : sequence of torch.Tensor
            Sequence of N tensors of ground truth crowd regions (shapes H_i x W_i), or ground truth crowd bounding boxes
            (shapes K_i x 4), entries can be None

        Returns
        -------
        out_proposals : PackedSequence
            A sequence of N tensors with shapes S_i x 4 containing the selected non-void proposals (the ground truth
            boxes are prepended to the candidates), positives first. Entries are None for images without ground truth
            boxes or without any usable proposal
        match : PackedSequence
            A sequence of matching results with shape S_i, with the following semantic:
              - match[i, j] == -1: the j-th proposal in image i is negative
              - match[i, j] == k, k >= 0: the j-th proposal in image i is matched to bbx[i][k]
        """
        out_proposals = []
        match = []

        for proposals_i, bbx_i, cat_i, iscrowd_i in zip(
                proposals, bbx, cat, iscrowd):
            try:
                # Images without ground truth are skipped entirely: no negatives are produced for
                # them. Otherwise the ground truth boxes are prepended to the proposals.
                if bbx_i is None:
                    raise Empty
                if proposals_i is not None:
                    proposals_i = torch.cat([bbx_i, proposals_i], dim=0)
                else:
                    proposals_i = bbx_i

                # Optionally drop proposals that overlap too much with void / crowd regions
                if self.void_threshold != 0 and iscrowd_i is not None:
                    if iscrowd_i.dtype == torch.uint8:
                        # uint8 crowd input: treated as a region mask
                        overlap = mask_overlap(proposals_i, iscrowd_i)
                    else:
                        # Crowd boxes: keep the largest overlap of each proposal
                        overlap = bbx_overlap(proposals_i, iscrowd_i)
                        overlap, _ = overlap.max(dim=1)

                    valid = overlap < self.void_threshold
                    proposals_i = proposals_i[valid]

                if proposals_i.size(0) == 0:
                    raise Empty

                # Find positives and negatives based on IoU (bbx_i is guaranteed not None here,
                # so best_gt is always defined for the matching below)
                iou = ious(proposals_i, bbx_i)
                best_iou, best_gt = iou.max(dim=1)

                pos_idx = best_iou >= self.pos_threshold
                neg_idx = (best_iou >= self.neg_threshold_lo) & (
                    best_iou < self.neg_threshold_hi)

                # Check that there are still some non-voids and do sub-sampling
                if not pos_idx.any().item() and not neg_idx.any().item():
                    raise Empty
                pos_idx, neg_idx = self._subsample(pos_idx, neg_idx)

                # Gather selected proposals, positives first
                out_proposals_i = proposals_i[torch.cat([pos_idx, neg_idx])]

                # Save matching: -1 (negative) by default, matched GT index for the positives
                match_i = out_proposals_i.new_full((out_proposals_i.size(0), ),
                                                   -1,
                                                   dtype=torch.long)
                match_i[:pos_idx.numel()] = best_gt[pos_idx]

                # Save to output
                out_proposals.append(out_proposals_i)
                match.append(match_i)
            except Empty:
                out_proposals.append(None)
                match.append(None)

        return PackedSequence(out_proposals), PackedSequence(match)
Example #12
0
    def inference(self, head, x, rpn_proposals, valid_size, img_size):
        """Cascade inference: refine ``rpn_proposals`` through every stage and predict masks.

        For each stage ``s`` in ``range(len(self.stage_loss_weights))``:
          1. ``head[s]`` is run in box mode on the current proposals;
          2. the regressed, clipped boxes go through ``self.bbx_prediction_generator``;
          3. ``head[s]`` is run again in mask mode on the kept detections;
          4. the kept boxes are written back into the proposals at the indices returned by the
             generator, forming the next stage's input.
        The outputs of the last stage are returned. The caller's ``rpn_proposals`` tensors are
        not modified; images with no kept detection carry their proposals over unchanged.

        Returns
        -------
        bbx_pred, cls_pred, obj_pred, msk_pred : PackedSequence
            Predictions of the last stage, or four sequences of None with ``x[0].size(0)``
            entries each when there are no proposals, no stages, or ``Empty`` is raised.
        """
        x = x[self.min_level:self.min_level + self.levels]
        y = None

        try:
            # With no stages there would be nothing to return
            if rpn_proposals.all_none or len(self.stage_loss_weights) == 0:
                raise Empty

            # Loop-invariant regression weights, built with x[0].new
            bbx_reg_weights = x[0].new(self.bbx_reg_weights)

            for stage in range(len(self.stage_loss_weights)):
                current_head = head[stage]

                _proposals, _proposals_idx = rpn_proposals.contiguous

                # Box pass
                cls_logits, bbx_logits, _, _ = self._head(
                    current_head, x, y, _proposals, _proposals_idx, img_size,
                    True, False)

                boxes = shift_boxes(_proposals.unsqueeze(1),
                                    bbx_logits / bbx_reg_weights)
                scores = torch.softmax(cls_logits, dim=1)

                boxes, scores = self._split_and_clip(boxes, scores,
                                                     _proposals_idx,
                                                     valid_size)

                # Replicate indices: a (size(0) x size(1)) tensor whose row p is filled with p.
                # NOTE(review): created on the CPU and covering the whole batch in one list entry;
                # confirm this is what bbx_prediction_generator expects.
                indices = []
                idx = torch.arange(0, bbx_logits.size(0)).long()
                indices.append(idx.repeat(bbx_logits.size(1), 1).permute(1, 0))

                bbx_pred, cls_pred, obj_pred, indices_pred = self.bbx_prediction_generator(
                    boxes, scores, indices, stage)

                if bbx_pred.all_none:
                    raise Empty

                # Mask pass on the kept detections
                _proposals, _proposals_idx = bbx_pred.contiguous

                _, _, msk_logits, y = self._head(current_head, x, y,
                                                 _proposals, _proposals_idx,
                                                 img_size, False, True)

                # Finalize instance mask computation
                msk_pred = self.msk_prediction_generator(cls_pred, msk_logits)

                # Write the kept boxes back into copies of the proposals. The flip is meant to make
                # the best-scoring box win for repeated indices; NOTE(review): PyTorch leaves
                # duplicate-index assignment undefined (notably on CUDA).
                refined_proposals = []
                for _proposals_i, _indices_i, rpn_proposals_i in zip(
                        bbx_pred, indices_pred, rpn_proposals):
                    if (_proposals_i is None or _indices_i is None
                            or rpn_proposals_i is None):
                        # Nothing kept for this image: carry its proposals over unchanged
                        refined_proposals.append(rpn_proposals_i)
                        continue
                    rpn_proposals_i = rpn_proposals_i.clone()
                    rpn_proposals_i[_indices_i.flip(0)] = _proposals_i.flip(0)
                    refined_proposals.append(rpn_proposals_i)

                rpn_proposals = PackedSequence(refined_proposals)

        except Empty:
            bbx_pred = PackedSequence([None for _ in range(x[0].size(0))])
            cls_pred = PackedSequence([None for _ in range(x[0].size(0))])
            obj_pred = PackedSequence([None for _ in range(x[0].size(0))])
            msk_pred = PackedSequence([None for _ in range(x[0].size(0))])

        return bbx_pred, cls_pred, obj_pred, msk_pred


# Run head on the given proposals
# proposals, proposals_idx = proposals.contiguous
# cls_logits, bbx_logits, _ = self._head(head, x, proposals, proposals_idx, img_size, True, False)

# Shift the proposals according to the logits
# bbx_reg_weights = x[0].new(self.bbx_reg_weights)
# boxes = shift_boxes(proposals.unsqueeze(1), bbx_logits / bbx_reg_weights)
# scores = torch.softmax(cls_logits, dim=1)

# Split boxes and scores by image, clip to valid size
# boxes, scores = self._split_and_clip(boxes, scores, proposals_idx, valid_size)

# Do nms to find final predictions
# bbx_pred, cls_pred, obj_pred = self.bbx_prediction_generator(boxes, scores)

# if bbx_pred.all_none:
#    raise Empty

# Run head again on the finalized boxes to compute instance masks
# proposals, proposals_idx = bbx_pred.contiguous
# _, _, msk_logits = self._head(head, x, proposals, proposals_idx, img_size, False, True)

# Finalize instance mask computation
# msk_pred = self.msk_prediction_generator(cls_pred, msk_logits)
Example #13
0
    def training(self, head, x, rpn_proposals, bbx, cat, iscrowd, ids, msk,
                 img_size, valid_size):
        """Compute the weighted multi-stage (cascade) classification, box and mask losses.

        Proposals are first matched to the ground truth with ``self.proposal_matcher`` (stage
        argument 0). Then, for every stage ``s`` in ``range(len(self.stage_loss_weights))``, with
        head ``head[s]`` and weight ``lw = self.stage_loss_weights[s]``:
          1. the head is run in box mode and ``lw * cls_loss`` / ``lw * bbx_loss`` are accumulated;
          2. ``self._inference`` produces refined boxes, which are written into copies of the RPN
             proposals at the returned indices (images without refined boxes keep theirs);
          3. the refined RPN proposals are re-matched (stage argument ``s``);
          4. the head is run in mask mode and ``lw * msk_loss`` is accumulated.

        Returns
        -------
        cls_loss_sum, bbx_loss_sum, msk_loss_sum
            Weighted sums over the stages. If ``Empty`` is raised at any point, all three are
            replaced by ``sum(x_i.sum() for x_i in x) * 0``.
        """
        x = x[self.min_level:self.min_level + self.levels]

        cls_loss_sum, bbx_loss_sum, msk_loss_sum = 0.0, 0.0, 0.0
        y = None

        try:
            # Initial pooling
            if rpn_proposals.all_none:
                raise Empty

            # Match proposals to ground truth
            with torch.no_grad():
                # Sample the stage proposals (original note: top 512 samples)
                proposals, match, rpn_proposals, indices = self.proposal_matcher(
                    rpn_proposals, bbx, cat, iscrowd, 0)
                cls_lbl, bbx_lbl, msk_lbl = self._match_to_lbl(
                    proposals, bbx, cat, ids, msk, match)

            if proposals.all_none:
                raise Empty

            _proposals, _proposals_idx = proposals.contiguous

            for stage in range(len(self.stage_loss_weights)):
                # Current head and its weighted loss coefficient lw
                current_head = head[stage]
                lw = self.stage_loss_weights[stage]

                # Run the current head for box training & loss calculation
                set_active_group(current_head, active_group(True))

                cls_logits, bbx_logits, _, _ = self._head(
                    current_head, x, y, _proposals, _proposals_idx, img_size,
                    True, False)

                cls_loss, bbx_loss = self.bbx_loss(cls_logits, bbx_logits,
                                                   cls_lbl, bbx_lbl)

                cls_loss_sum += lw * cls_loss
                bbx_loss_sum += lw * bbx_loss

                # Get new proposals for the next stage
                _proposals, _indices = self._inference(x, _proposals,
                                                       _proposals_idx,
                                                       bbx_logits, cls_logits,
                                                       valid_size, indices,
                                                       stage)

                if _proposals.all_none:
                    raise Empty

                # Write the refined boxes into copies of the RPN proposals (the caller's tensors
                # are left untouched). The flip is meant to let the best-scoring box win for
                # repeated indices; NOTE(review): PyTorch leaves duplicate-index assignment
                # undefined (notably on CUDA).
                rpn_proposals_augument = []
                for _proposals_i, _indices_i, rpn_proposals_i in zip(
                        _proposals, _indices, rpn_proposals):
                    if (_proposals_i is None or _indices_i is None
                            or rpn_proposals_i is None):
                        # No refined boxes for this image: keep its proposals unchanged
                        rpn_proposals_augument.append(rpn_proposals_i)
                        continue
                    rpn_proposals_i = rpn_proposals_i.clone()
                    rpn_proposals_i[_indices_i.flip(0)] = _proposals_i.flip(0)
                    rpn_proposals_augument.append(rpn_proposals_i)

                # New proposals for the next stage
                rpn_proposals = PackedSequence(rpn_proposals_augument)

                if rpn_proposals.all_none:
                    raise Empty

                # Stage pooling
                # NOTE(review): the matcher receives `stage`, not `stage + 1`, although these
                # labels also feed the next stage's box loss — confirm this is intended.
                with torch.no_grad():
                    proposals, match, rpn_proposals, indices = self.proposal_matcher(
                        rpn_proposals, bbx, cat, iscrowd, stage)

                    cls_lbl, bbx_lbl, msk_lbl = self._match_to_lbl(
                        proposals, bbx, cat, ids, msk, match)

                # Run the head for mask training & loss calculation.
                # NOTE(review): the boxes here come from self._inference, while cls_lbl / msk_lbl
                # were just recomputed for `proposals` (the matcher output). Unless _inference
                # output aligns 1:1 with the matcher's sampling, logits and labels may not
                # correspond — confirm whether `proposals.contiguous` was intended.
                _proposals, _proposals_idx = _proposals.contiguous

                set_active_group(current_head, active_group(True))

                _, _, msk_logits, y = self._head(current_head, x, y,
                                                 _proposals, _proposals_idx,
                                                 img_size, False, True)

                msk_loss = self.msk_loss(msk_logits, cls_lbl, msk_lbl)
                msk_loss_sum += lw * msk_loss

        except Empty:
            # NOTE(review): the result of active_group(False) is discarded here (elsewhere it is
            # passed to set_active_group), and losses already summed for earlier stages are
            # overwritten — confirm both are intended.
            active_group(False)
            cls_loss_sum = bbx_loss_sum = msk_loss_sum = sum(x_i.sum()
                                                             for x_i in x) * 0

        return cls_loss_sum, bbx_loss_sum, msk_loss_sum