def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass supporting both fully- and weakly-supervised training.

        Args:
            data: input image batch tensor (NCHW).
            im_info: per-image (height, width, scale) info blob.
            roidb: serialized roidb entries; deserialized only when training.
            **rpn_kwargs: extra blobs passed through to the RPN loss.

        Returns:
            dict: in training, 'losses' and 'metrics' sub-dicts of 1-dim
            tensors; at test time, 'rois', 'cls_score', 'bbox_pred' and
            'blob_conv'.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); restore the dicts.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        # NOTE(review): device_id is computed but never used in this method.
        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # Backbone feature extraction.
        blob_conv = self.Conv_Body(im_data)

        # RPN: proposals plus (in training) classification/regression targets.
        rpn_ret = self.RPN(blob_conv, im_info, roidb)
        # logging.info(f"roi belong to which image: shape {rpn_ret['rois'][:, 0:1].shape}\
        #  \n {rpn_ret['rois'][:, 0]}")
        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # res5_feat is additionally needed by heads that share res5.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            if self.weak_supervise:
                # Weakly-supervised output head emits a separate detection
                # score stream alongside the classification scores.
                cls_score, det_score, bbox_pred = self.Box_Outs(box_feat)
            else:
                cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training and not self.weak_supervise:
            # ---- Fully-supervised training losses ----
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One RPN loss pair per FPN pyramid level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        elif self.training and self.weak_supervise:
            # ---- Weakly-supervised image-level loss ----
            # Weak supervision image-level loss
            # logging.info(f"image-level labels: shape {rpn_ret['image_labels_vec'].shape}\n {rpn_ret['image_labels_vec']}")
            # logging.info(f"cls score: shape {cls_score.shape}\n {cls_score}")
            # logging.info(f"det score: shape {det_score.shape}\n {det_score}")
            #
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            image_loss_cls, acc_score, reg = fast_rcnn_heads.image_level_loss(
                cls_score, det_score, rpn_ret['rois'],
                rpn_ret['image_labels_vec'], self.bceloss, box_feat)
            return_dict['losses']['image_loss_cls'] = image_loss_cls
            return_dict['losses']['spatial_reg'] = reg

            return_dict['metrics']['accuracy_cls'] = acc_score

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Standard Generalized R-CNN forward pass.

        Args:
            data: input image batch tensor (NCHW).
            im_info: per-image (height, width, scale) info blob.
            roidb: serialized roidb entries; deserialized only when training.
            **rpn_kwargs: extra blobs passed through to the RPN loss.

        Returns:
            dict: in training, 'losses' and 'metrics' sub-dicts of 1-dim
            tensors; at test time, 'rois', 'cls_score', 'bbox_pred' and
            'blob_conv'.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); restore the dicts.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        # NOTE(review): device_id is computed but never used in this method.
        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # Backbone feature extraction.
        blob_conv = self.Conv_Body(im_data)

        # RPN: proposals plus (in training) classification/regression targets.
        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # res5_feat is additionally needed by heads that share res5.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(dict(
                (k, rpn_ret[k]) for k in rpn_ret.keys()
                if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
            ))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One RPN loss pair per FPN pyramid level.
                for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(res5_feat, rpn_ret,
                                               roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
# --- Example #3 (scraped example separator; original residue: "Exemple #3" / "0") ---
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Generalized R-CNN forward pass with debugging hooks.

        Beyond the standard detection flow, this variant can:
          * dump a visualization of quantized-anchor regression targets to
            disk when ``cfg.RPN.VIS_QUANT_TARGET`` is set (training only);
          * replace the box-head losses with zero-valued losses when
            ``cfg.RPN.ZEROLOSS`` is set, so only the RPN is trained;
          * dump transformed proposals/boxes to JSON at test time when
            ``cfg.TEST.PROPOSALS_OUT`` is set.

        Args:
            data: input image batch tensor (NCHW).
            im_info: per-image (height, width, scale) info blob.
            roidb: serialized roidb entries; deserialized only when training.
            **rpn_kwargs: extra blobs passed through to the RPN loss.

        Returns:
            dict: in training, 'losses' and 'metrics' sub-dicts of 1-dim
            tensors; at test time, 'rois', 'cls_score', 'bbox_pred' and
            'blob_conv'.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); restore the dicts.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        # NOTE(review): device_id is computed but never used in this method.
        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # Backbone feature extraction.
        blob_conv = self.Conv_Body(im_data)

        # RPN: proposals plus (in training) classification/regression targets.
        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # res5_feat is additionally needed by heads that share res5.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)

            if cfg.RPN.VIS_QUANT_TARGET:
                # Debug visualization: draw GT boxes (white), quantized fpn3
                # anchors with non-zero regression targets (red) and the
                # boxes obtained by applying those targets (green), then
                # save the figure to disk.
                import numpy as np
                import json
                import os
                import time

                # Collect foreground GT boxes and scale them to the network
                # input resolution using the image scale in im_info[:, 2].
                gt_boxes = []
                gt_label = roidb[0]['gt_classes']
                for inds, item in enumerate(gt_label):
                    if item != 0:
                        gt_boxes.append(roidb[0]['boxes'][inds])

                gt_boxes = np.array(gt_boxes, dtype=np.float32)
                # BUGFIX: move im_info to CPU before .numpy(); calling
                # .numpy() on a CUDA tensor raises, and every other
                # tensor->numpy conversion in this method goes through .cpu().
                gt_boxes *= im_info.detach().cpu().numpy()[:, 2]

                path = "/nfs/project/libo_i/Boosting/Targets_Info"
                if not os.path.exists(path):
                    os.makedirs(path)
                # Crop the "wide" fpn3 regression targets to the spatial
                # extent of the fpn3 logits.
                b, c, h, w = rpn_kwargs['rpn_cls_logits_fpn3'].shape
                sample_targets = rpn_kwargs[
                    'rpn_bbox_targets_wide_fpn3'][:, :, :h, :w]

                line_targets = sample_targets.detach().data.cpu().numpy()
                with open(os.path.join(path, "quant_anchors.json"), "r") as fp:
                    quant_anchors = np.array(json.load(fp), dtype=np.float32)
                    quant_anchors = quant_anchors[:h, :w]

                # Keep channels 4:8 (one anchor's 4 deltas) and flatten to
                # one (dx, dy, dw, dh) row per spatial location.
                line_targets = line_targets[:, 4:8, :, :].transpose(
                    (0, 2, 3, 1)).reshape(quant_anchors.shape)
                line_targets = line_targets.reshape(-1, 4)

                width = im_data.shape[3]
                height = im_data.shape[2]
                # Apply the target offsets to the quantized anchors here.

                line_quant_anchors = quant_anchors.reshape(-1, 4)
                pred_boxes = box_utils.onedim_bbox_transform(
                    line_quant_anchors, line_targets)
                pred_boxes = box_utils.clip_tiled_boxes(
                    pred_boxes, (height, width, 3))

                im = im_data.detach().cpu().numpy().reshape(3, height,
                                                            width).transpose(
                                                                (1, 2, 0))

                # Undo mean subtraction so the image is displayable.
                means = np.squeeze(cfg.PIXEL_MEANS)
                for i in range(3):
                    im[:, :, i] += means[i]

                im = im.astype(int)
                dpi = 200
                fig = plt.figure(frameon=False)
                fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
                ax = plt.Axes(fig, [0., 0., 1., 1.])
                ax.axis('off')
                fig.add_axes(ax)
                # BGR -> RGB for display.
                ax.imshow(im[:, :, ::-1])
                # Draw the GT boxes on the image.
                for item in gt_boxes:
                    ax.add_patch(
                        plt.Rectangle((item[0], item[1]),
                                      item[2] - item[0],
                                      item[3] - item[1],
                                      fill=False,
                                      edgecolor='white',
                                      linewidth=1,
                                      alpha=1))

                cnt = 0
                for inds, before_item in enumerate(line_quant_anchors):
                    after_item = pred_boxes[inds]
                    targets_i = line_targets[inds]
                    # Skip anchors whose targets are all zero (unassigned).
                    if np.sum(targets_i) == 0:
                        continue
                    ax.add_patch(
                        plt.Rectangle((before_item[0], before_item[1]),
                                      before_item[2] - before_item[0],
                                      before_item[3] - before_item[1],
                                      fill=False,
                                      edgecolor='r',
                                      linewidth=1,
                                      alpha=1))

                    ax.add_patch(
                        plt.Rectangle((after_item[0], after_item[1]),
                                      after_item[2] - after_item[0],
                                      after_item[3] - after_item[1],
                                      fill=False,
                                      edgecolor='g',
                                      linewidth=1,
                                      alpha=1))

                    logger.info("valid boxes: {}".format(cnt))
                    cnt += 1

                if cnt != 0:
                    # Timestamped filename to avoid clobbering earlier dumps.
                    ticks = time.time()
                    fig.savefig(
                        "/nfs/project/libo_i/Boosting/Targets_Info/{}.png".
                        format(ticks),
                        dpi=dpi)

                plt.close('all')

            if cfg.FPN.FPN_ON:
                # One RPN loss pair per FPN pyramid level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            if cfg.RPN.ZEROLOSS:
                # Replace the box-head losses with zero-valued, grad-enabled
                # scalars so only the RPN receives a training signal.
                zero_loss_bbox = torch.Tensor([0.]).squeeze().cuda()
                zero_loss_bbox.requires_grad = True
                zero_loss_cls = torch.Tensor([0.]).squeeze().cuda()
                zero_loss_cls.requires_grad = True
                return_dict['losses']['loss_bbox'] = zero_loss_bbox
                return_dict['losses']['loss_cls'] = zero_loss_cls

            else:
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['losses']['loss_cls'] = loss_cls

            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
            if cfg.TEST.PROPOSALS_OUT:
                import os
                import json
                import numpy as np

                # Apply the predicted deltas to the RoIs here and emit the
                # transformed boxes (up to the proposal budget, e.g. 1000).
                bbox_pred = bbox_pred.data.cpu().numpy().squeeze()
                box_deltas = bbox_pred.reshape([-1, bbox_pred.shape[-1]])
                shift_boxes = box_utils.bbox_transform(
                    rpn_ret['rois'][:, 1:5], box_deltas,
                    cfg.MODEL.BBOX_REG_WEIGHTS)
                shift_boxes = box_utils.clip_tiled_boxes(
                    shift_boxes,
                    im_info.data.cpu().numpy().squeeze()[0:2])

                num_classes = cfg.MODEL.NUM_CLASSES
                # BUGFIX: convert the scores to numpy once up front;
                # np.where on a CUDA tensor fails, and this also avoids a
                # device sync per class.
                cls_score_np = cls_score.data.cpu().numpy()
                onecls_pred_boxes = []
                inds_all = []
                for j in range(1, num_classes):
                    # Keep detections above the score threshold per class
                    # (class 0 is background and is skipped).
                    inds = np.where(
                        cls_score_np[:, j] > cfg.TEST.SCORE_THRESH)[0]
                    boxes_j = shift_boxes[inds, j * 4:(j + 1) * 4]
                    onecls_pred_boxes += boxes_j.tolist()
                    inds_all.extend(inds.tolist())

                # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in
                # 1.24; use the builtin int dtype alias instead.
                inds_all = np.array(inds_all, dtype=int)
                aligned_proposals = rpn_ret['rois'][:, 1:5][inds_all]
                aligned_boxes = np.array(onecls_pred_boxes, dtype=np.float32)

                assert inds_all.shape[0] == aligned_boxes.shape[0]
                assert aligned_proposals.size == aligned_boxes.size

                path = "/nfs/project/libo_i/Boosting/Anchor_Info"
                with open(os.path.join(path, "proposals.json"), "w") as fp:
                    json.dump(aligned_proposals.tolist(), fp)

                with open(os.path.join(path, "boxes.json"), "w") as fp:
                    json.dump(aligned_boxes.tolist(), fp)

        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Generalized R-CNN forward pass (annotated walkthrough variant).

        Args:
            data: input image batch tensor (NCHW).
            im_info: per-image (height, width, scale) info blob.
            roidb: serialized roidb entries; deserialized only when training.
            **rpn_kwargs: extra blobs passed through to the RPN loss.

        Returns:
            dict: in training, 'losses' and 'metrics' sub-dicts of 1-dim
            tensors; at test time, 'rois', 'cls_score', 'bbox_pred' and
            'blob_conv'.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); restore the dicts.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        return_dict = {}  # A dict to collect return variables

        # Backbone feature extraction.
        blob_conv = self.Conv_Body(im_data)
        '''
        anchors of different scale and aspect ratios will be collected
        the function self.RPN is rpn_heads.generic_rpn_outputs
        it will call FPN.fpn_rpn_outputs which is a class name, so it creates the class object
        '''
        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # res5_feat is additionally needed by heads that share res5.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                '''in faster rcnn fpn r-50 case it is fast_rcnn_heads.roi_2mlp_head'''
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One RPN loss pair per FPN pyramid level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
    def forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass (older variant with flat return keys).

        Unlike the `_forward` variants in this file, losses are returned as
        flat keys ('loss_rpn_cls', 'loss_rcnn_cls', ...) instead of nested
        'losses'/'metrics' dicts, and 'rois' is always returned as a CUDA
        Variable.

        Args:
            data: input image batch tensor (NCHW).
            im_info: per-image (height, width, scale) info blob.
            roidb: serialized roidb entries; deserialized only when training.
            **rpn_kwargs: extra blobs passed through to the RPN loss.

        Returns:
            dict with 'rois', 'cls_score', 'bbox_pred' plus, in training,
            flat loss entries.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); restore the dicts.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # Backbone feature extraction.
        blob_conv = self.Conv_Body(im_data)

        # RPN: proposals plus (in training) classification/regression targets.
        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # Wrap the numpy RoIs as a Variable on the input's device
        # (pre-0.4 PyTorch style).
        rois = Variable(torch.from_numpy(rpn_ret['rois'])).cuda(device_id)
        return_dict['rois'] = rois
        if self.training:
            return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # res5_feat is additionally needed by heads that share res5.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            # rpn loss
            rpn_kwargs.update(dict(
                (k, rpn_ret[k]) for k in rpn_ret.keys()
                if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
            ))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            return_dict['loss_rpn_cls'] = loss_rpn_cls
            return_dict['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            # NOTE(review): this variant unpacks a 2-tuple from
            # fast_rcnn_losses, while the other variants in this file unpack
            # 3 values (incl. accuracy_cls) — confirm against the
            # fast_rcnn_heads version this method is paired with.
            loss_cls, loss_bbox = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
            return_dict['loss_rcnn_cls'] = loss_cls
            return_dict['loss_rcnn_bbox'] = loss_bbox

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(res5_feat, rpn_ret,
                                               roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'])
                return_dict['loss_rcnn_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['loss_rcnn_keypoints'] = loss_keypoints

        return return_dict
    def _forward(self, data, im_info, roidb=None, rois=None, **rpn_kwargs):
        """Forward pass for a Generalized R-CNN variant with an optional
        relation head (``Rel_Outs``) and relation-based score refinement
        (``Rel_Inf`` / ``Rel_Inf_Train``).

        Args:
            data: input image batch tensor. NOTE(review): in the GEO-relation
                training mode it is used directly in place of conv-body
                features -- confirm against the caller.
            im_info: per-image size/scale info consumed by the RPN.
            roidb: list of serialized roidb entries (training only);
                deserialized in place here.
            rois: precomputed RoIs, used only when the RPN output carries no
                'rois' blob.
            **rpn_kwargs: extra blobs forwarded to the generic RPN losses.

        Returns:
            dict: during training, 'losses' and 'metrics' dicts (each value
            unsqueezed to 1-dim to work around a PyTorch 0.4 gather bug);
            during testing, the RPN blobs plus 'blob_conv', 'cls_score' and
            'bbox_pred'.
        """
        im_data = data
        if self.training:
            # roidb entries arrive serialized; restore them to dicts.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # NOTE(review): in GEO-relation training the backbone is skipped and
        # `im_data` is assumed to already be the conv feature blob -- confirm.
        if cfg.MODEL.RELATION_NET_INPUT == 'GEO' and not cfg.REL_INFER.TRAIN and cfg.MODEL.NUM_RELATIONS > 0 and self.training:
            blob_conv = im_data
        else:
            blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)
        # If rpn_ret doesn't contain 'rois', the rois should have been given
        # as an argument, so use those instead.
        if 'rois' not in rpn_ret:
            rpn_ret['rois'] = rois
        if hasattr(self, '_ignore_classes') and 'labels_int32' in rpn_ret:
            # Turn ignore classes labels to 0, because they are treated as background
            rpn_ret['labels_int32'][np.isin(rpn_ret['labels_int32'],
                                            self._ignore_classes)] = 0

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            # Expose conv features so test-time code can reuse them (e.g. for
            # mask/keypoint heads run after box detection).
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                # GEO-relation training: RoI features are not pooled; use a
                # zero placeholder with one row per sampled RoI instead.
                if cfg.MODEL.RELATION_NET_INPUT == 'GEO' and not cfg.REL_INFER.TRAIN and cfg.MODEL.NUM_RELATIONS > 0 and self.training:
                    box_feat = blob_conv.new_zeros(
                        rpn_ret['labels_int32'].shape[0],
                        self.Box_Head.dim_out)
                else:
                    box_feat = self.Box_Head(blob_conv, rpn_ret)
            if cfg.MODEL.RELATION_NET_INPUT == 'GEO' and not cfg.REL_INFER.TRAIN and cfg.MODEL.NUM_RELATIONS > 0 and self.training:
                # Placeholder classification scores and box regressions
                # (all-zero, one row per sampled RoI).
                cls_score = box_feat.new_zeros(
                    rpn_ret['labels_int32'].shape[0], cfg.MODEL.NUM_CLASSES)
                bbox_pred = box_feat.new_zeros(
                    rpn_ret['labels_int32'].shape[0],
                    4 * cfg.MODEL.NUM_CLASSES)
            else:
                cls_score, bbox_pred = self.Box_Outs(box_feat)
            if cfg.TEST.TAGGING:
                # Report per-RoI classification accuracy before the
                # relation-based refinement below, for comparison.
                rois_label = torch.from_numpy(
                    rpn_ret['labels_int32'].astype('int64')).to(
                        cls_score.device)
                accuracy_cls = cls_score.max(1)[1].eq(rois_label).float().mean(
                    dim=0)
                print('Before refine:', accuracy_cls.item())

        if cfg.MODEL.NUM_RELATIONS > 0 and self.training:
            # Relation head over sampled RoIs (training path).
            rel_scores, rel_labels = self.Rel_Outs(rpn_ret['rois'],
                                                   box_feat,
                                                   rpn_ret['labels_int32'],
                                                   roidb=roidb)
        elif cfg.MODEL.NUM_RELATIONS > 0 and cfg.TEST.USE_REL_INFER:
            if cfg.TEST.TAGGING:
                rel_scores = self.Rel_Outs(rpn_ret['rois'], box_feat)
                cls_score = self.Rel_Inf(cls_score, rel_scores, roidb=roidb)
            else:  # zero shot detection
                # Keep only RoIs that satisfy BOTH conditions (filt >= 2):
                # (1) background (class 0) is not among the RoI's top-k
                #     classes, and
                # (2) the RoI ranks among the (up to) 100 highest foreground
                #     scores.
                filt = (cls_score.topk(cfg.TEST.REL_INFER_PROPOSAL,
                                       1)[1] == 0).float().sum(1) == 0
                filt[cls_score[:,
                               1:].max(1)[0].topk(min(100, cls_score.shape[0]),
                                                  0)[1]] += 1
                filt = filt >= 2
                if filt.sum() == 0:
                    print('all background?')
                else:
                    tmp_rois = rpn_ret['rois'][filt.cpu().numpy().astype(
                        'bool')]

                    # Refine scores for the filtered subset only, writing the
                    # refined values back into cls_score in place.
                    rel_scores = self.Rel_Outs(tmp_rois, box_feat[filt])
                    tmp_cls_score = self.Rel_Inf(cls_score[filt],
                                                 rel_scores,
                                                 roidb=None)
                    cls_score[filt] = tmp_cls_score
            if cfg.TEST.TAGGING:
                # Accuracy after refinement, printed alongside the "Before
                # refine" number above.
                rois_label = torch.from_numpy(
                    rpn_ret['labels_int32'].astype('int64')).to(
                        cls_score.device)
                accuracy_cls = cls_score.max(1)[1].eq(rois_label).float().mean(
                    dim=0)
                print('After refine:', accuracy_cls.item())

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            if not cfg.MODEL.TAGGING:
                rpn_kwargs.update(
                    dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                         if (k.startswith('rpn_cls_logits')
                             or k.startswith('rpn_bbox_pred'))))
                loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                    **rpn_kwargs)
                if cfg.FPN.FPN_ON:
                    # One RPN cls/bbox loss per FPN level.
                    for i, lvl in enumerate(
                            range(cfg.FPN.RPN_MIN_LEVEL,
                                  cfg.FPN.RPN_MAX_LEVEL + 1)):
                        return_dict['losses']['loss_rpn_cls_fpn%d' %
                                              lvl] = loss_rpn_cls[i]
                        return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                              lvl] = loss_rpn_bbox[i]
                else:
                    return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                    return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            if not cfg.MODEL.RPN_ONLY:
                # bbox loss
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls

                # Some loss implementations attach a mean_similarity
                # attribute to the loss tensor; surface it when present.
                if hasattr(loss_cls, 'mean_similarity'):
                    return_dict['metrics'][
                        'mean_similarity'] = loss_cls.mean_similarity

                if cfg.FAST_RCNN.SAE_REGU:
                    # SAE regularization over foreground RoIs only
                    # (labels_int32 > 0); zero loss when there are none.
                    tmp = cls_score.sae_loss[(rpn_ret['labels_int32'] >
                                              0).tolist()]
                    return_dict['losses']['loss_sae'] = 1e-4 * tmp.mean(
                    ) if tmp.numel() > 0 else tmp.new_tensor(0)
                    # HACK: leftover interactive debugging hook for NaN SAE
                    # loss -- should be removed or turned into an error.
                    if torch.isnan(return_dict['losses']['loss_sae']):
                        import pdb
                        pdb.set_trace()

            if cfg.REL_INFER.TRAIN:
                # Auxiliary relation-inference training loss, weighted by
                # cfg.REL_INFER.TRAIN_WEIGHT.
                return_dict['losses']['loss_rel_infer'] = cfg.REL_INFER.TRAIN_WEIGHT * \
                    self.Rel_Inf_Train(rpn_ret['rois'], rpn_ret['labels_int32'], cls_score, rel_scores, roidb=roidb)

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict.update(rpn_ret)
            if not cfg.MODEL.RPN_ONLY:
                if cfg.TEST.KEEP_HIGHEST:
                    # Very sharp softmax (x * 1e10) acts as a ~one-hot mask,
                    # zeroing every class score except each RoI's maximum.
                    cls_score = F.softmax(cls_score * 1e10, dim=1) * cls_score
                return_dict['cls_score'] = cls_score
                return_dict['bbox_pred'] = bbox_pred

        return return_dict
# Example #7 (score: 0)
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass for the InteractNet-style human-object-interaction
        (V-COCO) model: standard detection branches plus an HOI head.

        Args:
            data: input image batch tensor.
            im_info: per-image size/scale info; also passed (as numpy) to the
                precomputed-box sampling helpers.
            roidb: list of serialized roidb entries (training only); see the
                field notes below.
            **rpn_kwargs: extra blobs forwarded to the generic RPN losses.

        Returns:
            dict: during training, 'losses' and 'metrics' dicts (each value
            unsqueezed to 1-dim to work around a PyTorch 0.4 gather bug);
            during testing, 'blob_conv' plus 'rois'/'cls_score'/'bbox_pred'
            when the RPN (non-precomputed-box) path is used.
        """
        im_data = data
        if self.training:
            # roidb: list, length = batch size
            # 'has_visible_keypoints': bool
            # 'boxes' & 'gt_classes': object bboxes and classes
            # 'segms', 'seg_areas', 'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map': pass
            # 'gt_actions': num_box*26
            # 'gt_role_id': num_box*26*2
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        # Original RPN module will generate proposals, and sample 256 positive/negative
        # examples in a 1:3 ratio for r-cnn stage. For InteractNet(hoi), I set
        # cfg.TRAIN.BATCH_SIZE_PER_IM and cfg.TRAIN.FG_FRACTION big value, to save
        # every proposal in rpn_ret, then I will re-sample from rpn_ret for three branch
        # of InteractNet, see roi_data/hoi_data.py for more information.
        if not cfg.VCOCO.USE_PRECOMP_BOX:
            rpn_ret = self.RPN(blob_conv, im_info, roidb)
            if cfg.MODEL.VCOCO_ON and self.training:
                # WARNING! always sample hoi branch before detection branch when training
                hoi_blob_in = sample_for_hoi_branch(rpn_ret,
                                                    roidb,
                                                    im_info,
                                                    is_training=True)
                # Re-sampling for RCNN head, rpn_ret will be modified inplace
                sample_for_detection_branch(rpn_ret)
        elif self.training:
            # Precomputed-box training path: attach proposals to the roidb
            # and sample HOI triplets directly from it (no RPN involved).
            json_dataset.add_proposals(roidb,
                                       rois=None,
                                       im_info=im_info.data.numpy(),
                                       crowd_thresh=0)  #[:, 2]
            hoi_blob_in = sample_for_hoi_branch_precomp_box_train(
                roidb, im_info, is_training=True)
            if hoi_blob_in is None:
                # No valid HOI samples in this batch: return zero losses and
                # metrics so the training loop can continue unharmed.
                return_dict['losses'] = {}
                return_dict['metrics'] = {}
                return_dict['losses'][
                    'loss_hoi_interaction_action'] = torch.tensor(
                        [0.]).cuda(device_id)
                return_dict['metrics'][
                    'accuracy_interaction_cls'] = torch.tensor(
                        [0.]).cuda(device_id)
                return_dict['losses'][
                    'loss_hoi_interaction_affinity'] = torch.tensor(
                        [0.]).cuda(device_id)
                return_dict['metrics'][
                    'accuracy_interaction_affinity'] = torch.tensor(
                        [0.]).cuda(device_id)
                return return_dict

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            #blob_conv = blob_conv[-self.num_roi_levels:]
            if cfg.FPN.MULTILEVEL_ROIS:
                blob_conv = blob_conv[-self.num_roi_levels:]
            else:
                # Single-level RoI pooling: keep only the finest level.
                blob_conv = blob_conv[-1]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.VCOCO.USE_PRECOMP_BOX:
            if not cfg.MODEL.RPN_ONLY:
                if cfg.MODEL.SHARE_RES5 and self.training:
                    box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
                else:
                    box_feat = self.Box_Head(blob_conv, rpn_ret)
                cls_score, bbox_pred = self.Box_Outs(box_feat)
            else:
                # TODO: complete the returns for RPN only situation
                pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            if not cfg.VCOCO.USE_PRECOMP_BOX:
                rpn_kwargs.update(
                    dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                         if (k.startswith('rpn_cls_logits')
                             or k.startswith('rpn_bbox_pred'))))
                loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                    **rpn_kwargs)
                if cfg.FPN.FPN_ON:
                    # One RPN cls/bbox loss per FPN level.
                    for i, lvl in enumerate(
                            range(cfg.FPN.RPN_MIN_LEVEL,
                                  cfg.FPN.RPN_MAX_LEVEL + 1)):
                        return_dict['losses']['loss_rpn_cls_fpn%d' %
                                              lvl] = loss_rpn_cls[i]
                        return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                              lvl] = loss_rpn_bbox[i]
                else:
                    return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                    return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

                # bbox loss
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls

                if cfg.MODEL.MASK_ON:
                    if getattr(self.Mask_Head, 'SHARE_RES5', False):
                        mask_feat = self.Mask_Head(
                            res5_feat,
                            rpn_ret,
                            roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                    else:
                        mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                    mask_pred = self.Mask_Outs(mask_feat)
                    # return_dict['mask_pred'] = mask_pred
                    # mask loss
                    loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                        mask_pred, rpn_ret['masks_int32'])
                    return_dict['losses']['loss_mask'] = loss_mask

                if cfg.MODEL.KEYPOINTS_ON:
                    if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                        # No corresponding keypoint head implemented yet (Neither in Detectron)
                        # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                        kps_feat = self.Keypoint_Head(
                            res5_feat,
                            rpn_ret,
                            roi_has_keypoints_int32=rpn_ret[
                                'roi_has_keypoint_int32'])
                    else:
                        kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                    kps_pred = self.Keypoint_Outs(kps_feat)
                    # return_dict['keypoints_pred'] = kps_pred
                    # keypoints loss
                    if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                        loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                            kps_pred, rpn_ret['keypoint_locations_int32'],
                            rpn_ret['keypoint_weights'])
                    else:
                        loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                            kps_pred, rpn_ret['keypoint_locations_int32'],
                            rpn_ret['keypoint_weights'],
                            rpn_ret['keypoint_loss_normalizer'])
                    return_dict['losses']['loss_kps'] = loss_keypoints

            if cfg.MODEL.VCOCO_ON:
                # HOI branch: interaction action + affinity losses/metrics.
                hoi_blob_out = self.HOI_Head(blob_conv, hoi_blob_in)

                interaction_action_loss, interaction_affinity_loss, \
                interaction_action_accuray_cls, interaction_affinity_cls = self.HOI_Head.loss(
                    hoi_blob_out)

                return_dict['losses'][
                    'loss_hoi_interaction_action'] = interaction_action_loss
                return_dict['metrics'][
                    'accuracy_interaction_cls'] = interaction_action_accuray_cls
                return_dict['losses'][
                    'loss_hoi_interaction_affinity'] = interaction_affinity_loss
                return_dict['metrics'][
                    'accuracy_interaction_affinity'] = interaction_affinity_cls

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing: detection outputs exist only on the RPN path; the
            # precomputed-box path returns just 'blob_conv' (set above).
            if not cfg.VCOCO.USE_PRECOMP_BOX:
                return_dict['rois'] = rpn_ret['rois']
                return_dict['cls_score'] = cls_score
                return_dict['bbox_pred'] = bbox_pred

        #print('return ready')
        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)
        # rpn proposals

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        rois_certification = False
        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if rois_certification:
            lvl_min = cfg.FPN.ROI_MIN_LEVEL
            lvl_max = cfg.FPN.ROI_MAX_LEVEL
            test_rpn_ret = {'rois': rpn_ret['rois']}
            lvls = fpn_utils.map_rois_to_fpn_levels(test_rpn_ret['rois'],
                                                    lvl_min, lvl_max)
            rois_idx_order = np.empty((0, ))
            test_rois = test_rpn_ret['rois']

            for output_idx, lvl in enumerate(range(lvl_min, lvl_max + 1)):
                idx_lvl = np.where(lvls == lvl)[0]
                rois_lvl = test_rois[idx_lvl, :]
                rois_idx_order = np.concatenate((rois_idx_order, idx_lvl))
                test_rpn_ret['rois_fpn{}'.format(lvl)] = rois_lvl

            rois_idx_restore = np.argsort(rois_idx_order).astype(np.int32,
                                                                 copy=False)
            test_rpn_ret['rois_idx_restore_int32'] = rois_idx_restore

            test_feat = self.Box_Head(blob_conv, test_rpn_ret)
            test_cls_score, test_bbox_pred = self.Box_Outs(test_feat)

            test_cls_score = test_cls_score.data.cpu().numpy().squeeze()
            test_bbox_pred = test_bbox_pred.data.cpu().numpy().squeeze()

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            # bbox proposals
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        # 在这里开始计算loss
        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            import json
            if cfg.TEST.IOU_OUT:
                # 直接通过rpn_ret可以取出rois
                with open(
                        "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/raw_roi.json",
                        'w') as f:
                    json.dump((return_dict['rois'][:, 1:] /
                               im_info.numpy()[0][2]).tolist(), f)

                # 如果在FPN模式下,需要进到一个collect_and_distribute...的函数去取出分配后的scores
                # ,我直接在collect_and_distribute_fpn_rpn_proposals.py里把json输出
                # 因此这里直接考虑RPN_ONLY模式的取值。
                if not cfg.FPN.FPN_ON:
                    with open(
                            "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/rois_score.json",
                            'w') as f:
                        score_2_json = []
                        for item in rpn_ret['rpn_roi_probs']:
                            score_2_json.append(item.item())
                        json.dump(score_2_json, f)

            # 开始第二个fast_head阶段,首先通过rois和bbox_delta计算pred_box
            if cfg.FAST_RCNN.FAST_HEAD2_DEBUG:
                lvl_min = cfg.FPN.ROI_MIN_LEVEL
                lvl_max = cfg.FPN.ROI_MAX_LEVEL
                if cfg.FPN.FPN_ON:
                    im_scale = im_info.data.cpu().numpy().squeeze()[2]
                    rois = rpn_ret['rois'][:, 1:5] / im_scale
                    bbox_pred = bbox_pred.data.cpu().numpy().squeeze()
                    box_deltas = bbox_pred.reshape([-1, bbox_pred.shape[-1]])
                    shift_boxes = box_utils.bbox_transform(
                        rois, box_deltas, cfg.MODEL.BBOX_REG_WEIGHTS)
                    shift_boxes = box_utils.clip_tiled_boxes(
                        shift_boxes,
                        im_info.data.cpu().numpy().squeeze()[0:2])
                    num_classes = cfg.MODEL.NUM_CLASSES

                    onecls_pred_boxes = []
                    onecls_score = []
                    dets_cls = {}
                    count = 0
                    for j in range(1, num_classes):
                        inds = np.where(
                            cls_score[:, j] > cfg.TEST.SCORE_THRESH)[0]
                        boxes_j = shift_boxes[inds, j * 4:(j + 1) * 4]
                        score_j = cls_score[inds, j]
                        onecls_pred_boxes += boxes_j.tolist()
                        onecls_score += score_j.tolist()
                        dets_cls.update({j: []})
                        for k in range(len(boxes_j.tolist())):
                            dets_cls[j].append(count)
                            count += 1

                    assert count == len(onecls_pred_boxes)
                    stage2_rois_score = np.array(onecls_score,
                                                 dtype=np.float32)
                    stage2_rois = np.array(onecls_pred_boxes, dtype=np.float32)

                    # Redistribute stage2_rois using fpn_utils module provided functions
                    # calculate by formula
                    cls_tracker = {}
                    if not stage2_rois.tolist():
                        stage1_pred_iou = stage2_rois_score.tolist()
                        stage2_final_boxes = np.empty((0, ))
                        stage2_final_score = np.empty((0, ))

                        logger.info("Detections above threshold is null.")
                    else:
                        alter_rpn = {}
                        unresize_stage2_rois = stage2_rois * im_scale
                        # unresize_stage2_rois = np.concatenate((unresize_stage2_rois, unresized_rois[:, 1:5]))

                        lvls = fpn_utils.map_rois_to_fpn_levels(
                            unresize_stage2_rois, lvl_min, lvl_max)
                        # TAG: We might need to visualize "stage2_rois" to make sure.
                        rois_idx_order = np.empty((0, ))
                        dummy_batch = np.zeros(
                            (unresize_stage2_rois.shape[0], 1),
                            dtype=np.float32)
                        alter_rpn["rois"] = np.hstack(
                            (dummy_batch,
                             unresize_stage2_rois)).astype(np.float32,
                                                           copy=False)
                        # alter_rpn['rois'] = np.concatenate((alter_rpn['rois'], unresized_rois))

                        for output_idx, lvl in enumerate(
                                range(lvl_min, lvl_max + 1)):
                            idx_lvl = np.where(lvls == lvl)[0]
                            rois_lvl = unresize_stage2_rois[idx_lvl, :]
                            rois_idx_order = np.concatenate(
                                (rois_idx_order, idx_lvl))
                            _ = np.zeros((rois_lvl.shape[0], 1),
                                         dtype=np.float32)
                            alter_rpn['rois_fpn{}'.format(lvl)] = np.hstack(
                                (_, rois_lvl)).astype(np.float32, copy=False)

                        rois_idx_restore = np.argsort(rois_idx_order).astype(
                            np.int32, copy=False)
                        alter_rpn['rois_idx_restore_int32'] = rois_idx_restore
                        # Go through 2nd stage of FPN and fast_head
                        stage2_feat = self.Box_Head(blob_conv, alter_rpn)
                        stage2_cls_score, stage2_bbox_pred = self.Box_Outs(
                            stage2_feat)

                        # Transform shift value to original one to get final pred boxes coordinates
                        stage2_bbox_pred = stage2_bbox_pred.data.cpu().numpy(
                        ).squeeze()
                        stage2_cls_score = stage2_cls_score.data.cpu().numpy()

                        stage2_box_deltas = stage2_bbox_pred.reshape(
                            [-1, bbox_pred.shape[-1]])
                        # Add some variance to box delta
                        if cfg.FAST_RCNN.STAGE1_TURBULENCE:
                            import random
                            for i in range(len(stage2_box_deltas)):
                                for j in range(len(stage2_box_deltas[i])):
                                    stage2_box_deltas[i][j] *= random.uniform(
                                        0.9, 1.1)

                        stage2_cls_out = box_utils.bbox_transform(
                            stage2_rois, stage2_box_deltas,
                            cfg.MODEL.BBOX_REG_WEIGHTS)
                        stage2_cls_out = box_utils.clip_tiled_boxes(
                            stage2_cls_out,
                            im_info.data.cpu().numpy().squeeze()[0:2])
                        onecls_pred_boxes = []
                        onecls_score = []
                        count = 0
                        for j in range(1, num_classes):
                            inds = np.where(
                                stage2_cls_score[:,
                                                 j] > cfg.TEST.SCORE_THRESH)[0]
                            boxes_j = stage2_cls_out[inds, j * 4:(j + 1) * 4]
                            score_j = stage2_cls_score[inds, j]
                            dets_j = np.hstack(
                                (boxes_j,
                                 score_j[:, np.newaxis])).astype(np.float32,
                                                                 copy=False)
                            keep = box_utils.nms(dets_j, cfg.TEST.NMS)
                            boxes_j = boxes_j[keep]
                            score_j = score_j[keep]
                            # 用于记录每个框属于第几类
                            onecls_score += score_j.tolist()
                            onecls_pred_boxes += boxes_j.tolist()

                            for k in range(len(score_j)):
                                cls_tracker.update({count: j})
                                count += 1

                        assert count == len(onecls_score)
                        stage2_final_boxes = np.array(onecls_pred_boxes,
                                                      dtype=np.float32)
                        stage2_final_score = np.array(onecls_score,
                                                      dtype=np.float32)
                        inds = np.where(stage2_final_score > 0.3)[0]

                        # Filtered by keep index...
                        preserve_stage2_final_boxes = copy.deepcopy(
                            stage2_final_boxes)
                        preserve_stage2_final_score = copy.deepcopy(
                            stage2_final_score)
                        stage2_final_boxes = stage2_final_boxes[inds]
                        stage2_final_score = stage2_final_score[inds]

                        # if nothing left after 0.3 threshold filter, reserve whole boxes to original.
                        if stage2_final_boxes.size == 0:
                            lower_inds = np.where(
                                preserve_stage2_final_score > 0.1)[0]
                            stage2_final_boxes = preserve_stage2_final_boxes[
                                lower_inds]
                            stage2_final_score = preserve_stage2_final_score[
                                lower_inds]

                        else:
                            del preserve_stage2_final_boxes
                            del preserve_stage2_final_score

                        # if all boxes are clsfied into bg class.
                        if stage2_final_boxes.size == 0:
                            stage1_pred_iou = stage2_rois_score.tolist()
                            stage2_final_boxes = np.empty((0, ))
                            stage2_final_score = np.empty((0, ))
                            logger.info("Detections above threshold is null.")

                        else:
                            # Restore stage2_pred_boxes to match the index with stage2_rois, Compute IOU between
                            # final_boxes and stage2_rois, one by one
                            flag = "cross_product"
                            if flag == "element_wise":
                                if stage2_final_boxes.shape[
                                        0] == stage2_rois.shape[0]:
                                    restored_stage2_final_boxes = stage2_final_boxes[
                                        rois_idx_restore]
                                    stage1_pred_iou = []
                                    for ind, item in enumerate(stage2_rois):
                                        stage1 = np.array(
                                            item, dtype=np.float32).reshape(
                                                (1, 4))
                                        stage2 = np.array(
                                            restored_stage2_final_boxes[ind],
                                            dtype=np.float32).reshape((1, 4))
                                        iou = box_utils.bbox_overlaps(
                                            stage1, stage2)
                                        stage1_pred_iou.append(
                                            iou.squeeze().item())
                                else:
                                    logger.info(
                                        "Mistake while processing {}".format(
                                            str(im_info)))
                            elif flag == "cross_product":
                                iou = box_utils.bbox_overlaps(
                                    stage2_rois, stage2_final_boxes)
                                stage1_pred_iou = iou.max(axis=1).tolist()

                    # stage1_pred is another name of stage2_rois
                    assert len(stage1_pred_iou) == len(stage2_rois)
                    if cfg.FAST_RCNN.IOU_NMS:
                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/iou_stage1_score.json",
                                "w") as f:
                            json.dump(stage2_rois_score.tolist(), f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/iou_stage2_score.json",
                                "w") as f:
                            json.dump(stage2_final_score.tolist(), f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/iou_stage1_pred_boxes.json",
                                'w') as f:
                            json.dump(stage2_rois.tolist(), f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/iou_stage1_pred_iou.json",
                                'w') as f:
                            json.dump(stage1_pred_iou, f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/iou_stage2_pred_boxes.json",
                                'w') as f:
                            json.dump(stage2_final_boxes.tolist(), f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/iou_dets_cls.json",
                                'w') as f:
                            json.dump(dets_cls, f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/iou_cls_tracker.json",
                                'w') as f:
                            json.dump(cls_tracker, f)

                    elif cfg.FAST_RCNN.SCORE_NMS:
                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/score_stage1_score.json",
                                "w") as f:
                            json.dump(stage2_rois_score.tolist(), f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/score_stage2_score.json",
                                "w") as f:
                            json.dump(stage2_final_score.tolist(), f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/score_stage1_pred_boxes.json",
                                'w') as f:
                            json.dump(stage2_rois.tolist(), f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/score_stage1_pred_iou.json",
                                'w') as f:
                            json.dump(stage1_pred_iou, f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/score_stage2_pred_boxes.json",
                                'w') as f:
                            json.dump(stage2_final_boxes.tolist(), f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/score_dets_cls.json",
                                'w') as f:
                            json.dump(dets_cls, f)

                        with open(
                                "/nfs/project/libo_i/IOU.pytorch/IOU_Validation/score_cls_tracker.json",
                                'w') as f:
                            json.dump(cls_tracker, f)

                else:
                    im_scale = im_info.data.cpu().numpy().squeeze()[2]
                    rois = rpn_ret['rois'][:, 1:5] / im_scale
                    # unscale back to raw image space
                    box_deltas = bbox_pred.data.cpu().numpy().squeeze()
                    fast_stage1_score = cls_score.data.cpu().numpy().squeeze()

                    box_deltas = box_deltas.reshape([-1, bbox_pred.shape[-1]])
                    stage2_rois = box_utils.bbox_transform(
                        rois, box_deltas, cfg.MODEL.BBOX_REG_WEIGHTS)
                    stage2_rois = box_utils.clip_tiled_boxes(
                        stage2_rois,
                        im_info.data.cpu().numpy().squeeze()[0:2])

                    num_classes = cfg.MODEL.NUM_CLASSES

                    onecls_pred_boxes = []
                    onecls_cls_score = []
                    for j in range(1, num_classes):
                        inds = np.where(
                            cls_score[:, j] > cfg.TEST.SCORE_THRESH)[0]
                        boxes_j = stage2_rois[inds, j * 4:(j + 1) * 4]
                        score_j = fast_stage1_score[inds, j]
                        onecls_pred_boxes += boxes_j.tolist()
                        onecls_cls_score += score_j.tolist()

                    stage2_rois = np.array(onecls_pred_boxes, dtype=np.float32)
                    stage2_rois_score = np.array(onecls_cls_score,
                                                 dtype=np.float32)

                    assert len(stage2_rois) == len(stage2_rois_score)

                    # Send stage2 rois to next stage fast head, do ROI ALIGN again
                    # to modify rpn_ret['rois] , rpn_ret['rpn_rois'] and rpn['rois_rpn_score']

                    rpn_ret['rois'] = stage2_rois
                    rpn_ret['rpn_rois'] = stage2_rois
                    rpn_ret['rpn_roi_probs'] = stage2_rois_score
                    stage2_box_feat = self.Box_Head(blob_conv, rpn_ret)
                    stage2_cls_score, stage2_bbox_pred = self.Box_Outs(
                        stage2_box_feat)

                    stage2_bbox_pred = stage2_bbox_pred.data.cpu().numpy(
                    ).squeeze()
                    stage2_bbox_pred = stage2_bbox_pred.reshape(
                        [-1, bbox_pred.shape[-1]])

                    stage2_cls_pred_boxes = box_utils.bbox_transform(
                        stage2_rois, stage2_bbox_pred,
                        cfg.MODEL.BBOX_REG_WEIGHTS)
                    stage2_cls_pred_boxes = box_utils.clip_tiled_boxes(
                        stage2_cls_pred_boxes,
                        im_info.data.cpu().numpy().squeeze()[0:2])

                    onecls_pred_boxes = []
                    onecls_cls_score = []
                    for j in range(1, num_classes):
                        inds = np.where(
                            stage2_cls_score[:, j] > cfg.TEST.SCORE_THRESH)[0]
                        if len(inds) != 0:
                            print("KKKKK")
                        boxes_j = stage2_cls_pred_boxes[inds,
                                                        j * 4:(j + 1) * 4]
                        score_j = stage2_cls_score[inds, j]
                        onecls_pred_boxes += boxes_j.tolist()
                        onecls_cls_score += score_j.tolist()

                    stage2_bbox_pred = np.array(onecls_pred_boxes,
                                                dtype=np.float32)
                    stage2_bbox_pred_score = np.array(onecls_cls_score,
                                                      dtype=np.float32)

        # get stage2 pred_boxes here

        return_dict['cls_score'] = cls_score
        return_dict['bbox_pred'] = bbox_pred
        return return_dict
Exemple #9
0
    def forward(self, data, im_info=None, roidb=None, only_body=None, **rpn_kwargs):
        """Run the detection pipeline: backbone -> (optional) RPN -> RoI heads.

        Args:
            data: input image batch tensor; assumed to live on a CUDA device
                (``get_device()`` is called on it).
            im_info: per-image info blob consumed by the RPN.
            roidb: serialized roidb entries; deserialized only when training
                with RPN enabled.
            only_body: when not None, stop after the backbone/RPN stage and
                return early (feature-extraction-only use).
                NOTE(review): the test is ``is not None``, so passing
                ``only_body=False`` also triggers the early return — confirm
                callers only ever pass True or None.
            **rpn_kwargs: when ``cfg.RPN.RPN_ON`` is False these keyword blobs
                are used directly as the proposal dict; when training with RPN
                on, the RPN's prediction blobs are folded back into it before
                computing RPN losses.

        Returns:
            dict: in training mode, 'losses' and 'metrics' sub-dicts; in
            testing mode, 'blob_conv', 'rois', 'cls_score' and 'bbox_pred'.
        """
        im_data = data
        # roidb arrives serialized (one blob per image); unpack it only when
        # the RPN needs ground-truth info to build training targets.
        if self.training and cfg.RPN.RPN_ON:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        # With the RPN disabled, proposals must be supplied by the caller
        # through rpn_kwargs (e.g. precomputed proposals).
        if cfg.RPN.RPN_ON:
            rpn_ret = self.RPN(blob_conv, im_info, roidb)
        else:
            rpn_ret = rpn_kwargs

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        # Early exit when the caller only wants backbone features.
        if only_body is not None:
            return return_dict

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # Shared res5: Box_Head also returns the res5 feature map that
                # the mask/keypoint heads reuse below.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            # NOTE(review): with RPN_ONLY, cls_score/bbox_pred stay undefined
            # and the loss/return code below would raise NameError.
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            if cfg.RPN.RPN_ON:
                # Fold the RPN's logit/bbox prediction blobs into rpn_kwargs so
                # the generic loss helper sees predictions and targets together.
                rpn_kwargs.update(dict(
                    (k, rpn_ret[k]) for k in rpn_ret.keys()
                    if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
                ))
                loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
                if cfg.FPN.FPN_ON:
                    # FPN yields one RPN loss per pyramid level.
                    for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
                        return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                        return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
                else:
                    return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                    return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(res5_feat, rpn_ret,
                                               roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                # NOTE(review): no Mask_Outs module is applied here — the head
                # output is used directly as the prediction. With BOUNDARY_ON
                # the head returns a (mask, boundary) pair.
                if cfg.MRCNN.FUSION:
                    mask_pred = mask_feat
                elif cfg.MODEL.BOUNDARY_ON:
                    mask_pred = mask_feat[0]
                    boundary_pred = mask_feat[1]
                else:
                    mask_pred = mask_feat
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask
                if cfg.MODEL.BOUNDARY_ON:
                    loss_boundary = mask_rcnn_heads.mask_rcnn_losses_balanced(boundary_pred, rpn_ret['boundary_int32'])
                    return_dict['losses']['loss_boundary'] = loss_boundary

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints
        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
Exemple #10
0
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Detection forward pass with optional LESION position branches.

        Depending on ``cfg.LESION`` flags, the backbone may additionally emit
        a res5 feature map (USE_POSITION) or shallow position predictions
        (SHALLOW_POSITION), which feed auxiliary classification/regression
        position heads and their losses during training.

        Args:
            data: input image batch tensor; assumed to be on a CUDA device
                (``get_device()`` is called on it).
            im_info: per-image info blob consumed by the RPN.
            roidb: serialized roidb entries; deserialized when training.
            **rpn_kwargs: receives the RPN prediction blobs before being
                forwarded to the generic RPN loss helper during training.

        Returns:
            dict: 'losses'/'metrics' sub-dicts in training mode (each value
            unsqueezed to 1-dim as a pytorch 0.4 gather workaround), or
            'rois'/'cls_score'/'bbox_pred' plus 'blob_conv' in testing mode.
        """
        im_data = data
        # roidb arrives serialized (one blob per image); unpack for training.
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()
        return_dict = {}  # A dict to collect return variables

        # The backbone's return arity depends on the LESION position mode.
        if cfg.LESION.USE_POSITION:
            blob_conv, res5_feat = self.Conv_Body(im_data)
        elif cfg.LESION.SHALLOW_POSITION:
            blob_conv, pos_cls_pred, pos_reg_pred = self.Conv_Body(im_data)
        else:
            blob_conv = self.Conv_Body(im_data)

        if cfg.MODEL.LR_VIEW_ON or cfg.MODEL.GIF_ON or cfg.MODEL.LRASY_MAHA_ON:
            blob_conv = self._get_lrview_blob_conv(blob_conv)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # NOTE(review): this rebinds res5_feat; if USE_POSITION is also
                # on, the position head below then consumes Box_Head's res5
                # rather than Conv_Body's — confirm that is intended.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
            # print cls_score.shape
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

            if cfg.LESION.USE_POSITION:
                # Deep position branch: predictions from res5 features.
                position_feat = self.Position_Head(res5_feat)
                pos_cls_pred = self.Position_Cls_Outs(position_feat)
                pos_reg_pred = self.Position_Reg_Outs(position_feat)
                return_dict['pos_cls_pred'] = pos_cls_pred
                return_dict['pos_reg_pred'] = pos_reg_pred
            if cfg.LESION.SHALLOW_POSITION:
                # Shallow position branch: predictions already produced by the
                # backbone above.
                return_dict['pos_cls_pred'] = pos_cls_pred
                return_dict['pos_reg_pred'] = pos_reg_pred
            if cfg.LESION.POSITION_RCNN:
                pos_cls_rcnn, pos_reg_rcnn = self.Position_RCNN(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            # Fold the RPN prediction blobs into rpn_kwargs so the generic
            # loss helper sees predictions and targets together.
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # FPN yields one RPN loss per pyramid level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            if cfg.MODEL.FASTER_RCNN:
                # bbox loss
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls
            # RCNN Position
            if cfg.LESION.POSITION_RCNN:
                pos_cls_rcnn_loss, pos_reg_rcnn_loss, accuracy_position_rcnn = position_rcnn_losses(
                    pos_cls_rcnn, pos_reg_rcnn, roidb)
                return_dict['losses']['RcnnPosCls_loss'] = pos_cls_rcnn_loss
                return_dict['losses']['RcnnPosReg_loss'] = pos_reg_rcnn_loss
                return_dict['metrics'][
                    'accuracy_position_rcnn'] = accuracy_position_rcnn
            # Shallow Position Branch
            elif cfg.LESION.SHALLOW_POSITION:
                pos_cls_loss, pos_reg_loss, accuracy_position = position_losses(
                    pos_cls_pred, pos_reg_pred, roidb)
                #pos_reg_loss = position_reg_losses(reg_pred, roidb)
                return_dict['losses']['pos_cls_loss'] = pos_cls_loss
                return_dict['losses']['pos_reg_loss'] = pos_reg_loss
                return_dict['metrics']['accuracy_position'] = accuracy_position
            # Position Branch
            elif cfg.LESION.USE_POSITION:
                # NOTE(review): this recomputes the position head even though
                # the not-RPN_ONLY branch above may already have run it —
                # likely redundant work; confirm before deduplicating.
                position_feat = self.Position_Head(res5_feat)
                pos_cls_pred = self.Position_Cls_Outs(position_feat)
                pos_reg_pred = self.Position_Reg_Outs(position_feat)
                pos_cls_loss, pos_reg_loss, accuracy_position = position_losses(
                    pos_cls_pred, pos_reg_pred, roidb)
                #pos_reg_loss = position_reg_losses(reg_pred, roidb)
                return_dict['losses']['pos_cls_loss'] = pos_cls_loss
                return_dict['losses']['pos_reg_loss'] = pos_reg_loss
                return_dict['metrics']['accuracy_position'] = accuracy_position

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        return return_dict
    def _forward(self, data, support_data, im_info, roidb=None, **rpn_kwargs):
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))
        device_id = im_data.get_device()
        return_dict = {}  # A dict to collect return variables
        original_blob_conv = self.Conv_Body(im_data)

        if not self.training:
            all_cls = roidb[0]['support_cls']
            test_way = len(all_cls)
            test_shot = int(roidb[0]['support_boxes'].shape[0] / test_way)
        else:
            train_way = 2  #2
            train_shot = 5  #5
        support_blob_conv = self.Conv_Body(support_data.squeeze(0))

        img_num = int(support_blob_conv.shape[0])
        img_channel = int(original_blob_conv.shape[1])
        # Construct support rpn_ret
        support_rpn_ret = {
            'rois': np.insert(roidb[0]['support_boxes'][0], 0,
                              0.)[np.newaxis, :]
        }
        if img_num > 1:
            for i in range(img_num - 1):
                support_rpn_ret['rois'] = np.concatenate(
                    (support_rpn_ret['rois'],
                     np.insert(roidb[0]['support_boxes'][i + 1], 0,
                               float(i + 1))[np.newaxis, :]),
                    axis=0)
        # Get support pooled feature
        support_feature = self.roi_feature_transform(
            support_blob_conv,
            support_rpn_ret,
            blob_rois='rois',
            method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
            resolution=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION,
            spatial_scale=self.Conv_Body.spatial_scale,
            sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO)

        blob_conv = original_blob_conv
        rpn_ret = self.RPN(blob_conv, im_info, roidb)
        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        assert not cfg.MODEL.RPN_ONLY

        support_box_feat = self.Box_Head(support_blob_conv, support_rpn_ret)
        if self.training:
            support_feature_mean_0 = support_feature[:5].mean(0, True)
            support_pool_feature_0 = self.avgpool(support_feature_mean_0)
            correlation_0 = F.conv2d(original_blob_conv,
                                     support_pool_feature_0.permute(
                                         1, 0, 2, 3),
                                     groups=1024)
            rpn_ret_0 = self.RPN(correlation_0, im_info, roidb)
            box_feat_0 = self.Box_Head(blob_conv, rpn_ret_0)
            support_0 = support_box_feat[:5].mean(
                0, True)  # simple average few shot support features

            support_feature_mean_1 = support_feature[5:10].mean(0, True)
            support_pool_feature_1 = self.avgpool(support_feature_mean_1)
            correlation_1 = F.conv2d(original_blob_conv,
                                     support_pool_feature_1.permute(
                                         1, 0, 2, 3),
                                     groups=1024)
            rpn_ret_1 = self.RPN(correlation_1, im_info, roidb)
            box_feat_1 = self.Box_Head(blob_conv, rpn_ret_1)
            support_1 = support_box_feat[5:10].mean(
                0, True)  # simple average few shot support features

            cls_score_now_0, bbox_pred_now_0 = self.Box_Outs(
                box_feat_0, support_0)
            cls_score_now_1, bbox_pred_now_1 = self.Box_Outs(
                box_feat_1, support_1)

            cls_score = torch.cat([cls_score_now_0, cls_score_now_1], dim=0)
            bbox_pred = torch.cat([bbox_pred_now_0, bbox_pred_now_1], dim=0)
            rpn_ret = {}
            for key in rpn_ret_0.keys():
                if key != 'rpn_cls_logits' and key != 'rpn_bbox_pred':
                    rpn_ret[key] = rpn_ret_0[
                        key]  #np.concatenate((rpn_ret_0[key], rpn_ret_1[key]), axis=0)
                else:
                    rpn_ret[key] = torch.cat([rpn_ret_0[key], rpn_ret_1[key]],
                                             dim=0)

        else:
            for way_id in range(test_way):
                begin = way_id * test_shot
                end = (way_id + 1) * test_shot

                support_feature_mean = support_feature[begin:end].mean(0, True)
                support_pool_feature = self.avgpool(support_feature_mean)
                correlation = F.conv2d(original_blob_conv,
                                       support_pool_feature.permute(
                                           1, 0, 2, 3),
                                       groups=1024)
                rpn_ret = self.RPN(correlation, im_info, roidb)

                if not cfg.MODEL.RPN_ONLY:
                    if cfg.MODEL.SHARE_RES5 and self.training:
                        box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
                    else:
                        box_feat = self.Box_Head(blob_conv, rpn_ret)

                    support = support_box_feat[begin:end].mean(
                        0, True)  # simple average few shot support features

                    cls_score_now, bbox_pred_now = self.Box_Outs(
                        box_feat, support)
                    cls_now = cls_score_now.new_full(
                        (cls_score_now.shape[0], 1),
                        int(roidb[0]['support_cls'][way_id]))
                    cls_score_now = torch.cat([cls_score_now, cls_now], dim=1)
                    rois = rpn_ret['rois']
                    if way_id == 0:
                        cls_score = cls_score_now
                        bbox_pred = bbox_pred_now
                        rois_all = rois
                    else:
                        cls_score = torch.cat([cls_score, cls_score_now],
                                              dim=0)
                        bbox_pred = torch.cat([bbox_pred, bbox_pred_now],
                                              dim=0)
                        rois_all = np.concatenate((rois_all, rois), axis=0)
                else:
                    # TODO: complete the returns for RPN only situation
                    pass
            rpn_ret['rois'] = rois_all
        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            target_cls = roidb[0]['target_cls']
            rpn_ret['labels_obj'] = np.array(
                [int(i == target_cls) for i in rpn_ret['labels_int32']])

            # filter other class bbox targets, only supervise the target cls bbox, because the bbox_pred is from the compare feature.
            bg_idx = np.where(rpn_ret['labels_int32'] != target_cls)[0]
            rpn_ret['bbox_targets'][bg_idx] = np.full_like(
                rpn_ret['bbox_targets'][bg_idx], 0.)
            rpn_ret['bbox_inside_weights'][bg_idx] = np.full_like(
                rpn_ret['bbox_inside_weights'][bg_idx], 0.)
            rpn_ret['bbox_outside_weights'][bg_idx] = np.full_like(
                rpn_ret['bbox_outside_weights'][bg_idx], 0.)

            neg_labels_obj = np.full_like(rpn_ret['labels_obj'], 0.)
            neg_bbox_targets = np.full_like(rpn_ret['bbox_targets'], 0.)
            neg_bbox_inside_weights = np.full_like(
                rpn_ret['bbox_inside_weights'], 0.)
            neg_bbox_outside_weights = np.full_like(
                rpn_ret['bbox_outside_weights'], 0.)

            rpn_ret['labels_obj'] = np.concatenate(
                [rpn_ret['labels_obj'], neg_labels_obj], axis=0)
            rpn_ret['bbox_targets'] = np.concatenate(
                [rpn_ret['bbox_targets'], neg_bbox_targets], axis=0)
            rpn_ret['bbox_inside_weights'] = np.concatenate(
                [rpn_ret['bbox_inside_weights'], neg_bbox_inside_weights],
                axis=0)
            rpn_ret['bbox_outside_weights'] = np.concatenate(
                [rpn_ret['bbox_outside_weights'], neg_bbox_outside_weights],
                axis=0)

            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_obj'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = 1 * loss_cls
            return_dict['losses']['loss_bbox'] = 1 * loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass for a Faster R-CNN / Light-Head R-CNN detector.

        Args:
            data: input image tensor (NCHW).
            im_info: per-image [height, width, scale] info blob.
            roidb: serialized roidb entries (training only); deserialized here.
            **rpn_kwargs: extra blobs consumed by the RPN loss computation.

        Returns:
            dict: when training, 'losses' and 'metrics' sub-dicts (each value
            unsqueezed to 1-dim to work around a pytorch 0.4 gather bug);
            otherwise 'blob_conv', 'rois', 'cls_score' and 'bbox_pred'.
        """
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
                # Bug fix: this branch previously produced no cls_score/bbox_pred,
                # so the training loss section below raised NameError whenever
                # SHARE_RES5 was enabled. Presumably Box_Outs takes just
                # box_feat here, as in the FASTER_RCNN branch -- TODO confirm
                # against the Box_Outs signature used with SHARE_RES5.
                cls_score, bbox_pred = self.Box_Outs(box_feat)
                return_dict['cls_score'] = cls_score
                return_dict['bbox_pred'] = bbox_pred
            elif cfg.MODEL.FASTER_RCNN:  # hw: FAST_RCNN head predicts classes and box coordinates
                box_feat = self.Box_Head(blob_conv, rpn_ret)
                cls_score, bbox_pred = self.Box_Outs(box_feat)
                return_dict['cls_score'] = cls_score
                return_dict['bbox_pred'] = bbox_pred
            elif cfg.MODEL.LIGHT_HEAD_RCNN:  # hw: LIGHT_HEAD_RCNN head predicts classes and coordinates
                box_feat = F.relu(self.LightHead(blob_conv), inplace=True)
                cls_score, bbox_pred = self.Box_Outs(box_feat, rpn_ret)
                return_dict['cls_score'] = cls_score
                return_dict['bbox_pred'] = bbox_pred

        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss  # hw: RPN classification and box-regression loss computation
            rpn_kwargs.update(dict(
                (k, rpn_ret[k]) for k in rpn_ret.keys()
                if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
            ))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One loss entry per FPN pyramid level.
                for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss  # hw: R-CNN classification and box-regression loss computation
            if cfg.MODEL.FASTER_RCNN:
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                    rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls
            elif cfg.MODEL.LIGHT_HEAD_RCNN:  # hw
                loss_cls, loss_bbox = light_head_rcnn_heads.light_head_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                    rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(res5_feat, rpn_ret,
                                               roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
    # Exemple #13 (scrape artifact: kept as a comment so the module stays valid Python)
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass for a (optionally cascaded) R-CNN detector.

        Stage 0 is a standard RPN + Fast R-CNN box head. When
        ``cfg.FAST_RCNN.USE_CASCADE`` is set, two extra refinement stages
        (i = 1, 2) transform the previous stage's RoIs with its regressed
        deltas, redistribute the resulting proposals over FPN levels, run
        stage-specific heads (Box_Head_2/3, Box_Outs_2/3), and add their
        losses onto the stage-0 loss entries.

        Args:
            data: input image tensor (NCHW).
            im_info: per-image [height, width, scale] info blob.
            roidb: serialized roidb entries (training only); deserialized here.
            **rpn_kwargs: extra blobs consumed by the RPN loss computation.

        Returns:
            dict of losses/metrics when training; per-stage
            'rois_{i}'/'cls_score_{i}'/'bbox_pred_{i}' entries when cascading,
            otherwise plain 'rois'/'cls_score'/'bbox_pred'.
        """
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One loss entry per FPN pyramid level.
                for k, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[k]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[k]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss (stage 0)
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score,
                bbox_pred,
                rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'],
                stage=0)
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        if not cfg.FAST_RCNN.USE_CASCADE:
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        else:
            # Stage outputs are detached: later stages must not backprop
            # through earlier ones.
            return_dict['rois' + '_{}'.format(0)] = rpn_ret['rois']
            return_dict['cls_score' + '_{}'.format(0)] = cls_score.detach()
            return_dict['bbox_pred' + '_{}'.format(0)] = bbox_pred.detach()

        if cfg.FAST_RCNN.USE_CASCADE:
            # Hoisted out of the loop: these imports are loop-invariant (the
            # original re-executed them on every cascade iteration).
            import utils.boxes as box_utils
            from modeling.collect_and_distribute_fpn_rpn_proposals import CollectAndDistributeFpnRpnProposalsOp
            # Refinement stages 1 and 2. The original wrote
            # `for i in range(2): i += 1`, which iterates i = 1, 2 in a
            # confusing way; range(1, 3) is equivalent and explicit.
            for i in range(1, 3):
                pre_stage_name = '_{}'.format(i - 1)
                cls_score_cpu = cls_score.data.cpu().numpy()
                bbox_pred_cpu = bbox_pred.reshape(
                    [-1, bbox_pred.shape[-1]]).data.cpu().numpy().squeeze()
                rois = deepcopy(return_dict['rois' + pre_stage_name])

                assert cfg.MODEL.CLS_AGNOSTIC_BBOX_REG is True
                if not cfg.MODEL.CLS_AGNOSTIC_BBOX_REG:
                    # Dead branch given the assert above; kept for reference.
                    cls_loc = np.argmax(cls_score_cpu[:, 1:], axis=1) + 1
                    cls_loc = np.reshape(cls_loc, (cls_loc.shape[0], 1))
                    # Based on scores, we can select transformed rois
                    scores = np.zeros((cls_score_cpu.shape[0], 1))
                    for k in range(len(cls_loc)):
                        scores[k] = cls_score_cpu[k, cls_loc[k]]

                batch_inds = rois[:, 0]
                uni_inds = np.unique(batch_inds)

                # WE suppose the WIDTH of image is EQUAL to its HEIGHT

                # We also provide an example to show how to perform the operation batch-wise
                # Scale forward: map each image's rois back to the original
                # (unscaled) coordinate frame before applying the deltas.
                batch_se = []
                for e in range(len(uni_inds)):
                    id_min = min(np.where(batch_inds == uni_inds[e])[0])
                    id_max = max(np.where(batch_inds == uni_inds[e])[0])
                    rois[id_min:id_max + 1, 1:5] /= im_info[e][2]
                    batch_se.append([id_min, id_max])

                pred_boxes = box_utils.bbox_transform(
                    rois[:, 1:5], bbox_pred_cpu,
                    cfg.CASCADE_RCNN.BBOX_REG_WEIGHTS[i])
                # Scale back and clip to the image bounds.
                for e in range(len(uni_inds)):
                    id_min = batch_se[e][0]
                    id_max = batch_se[e][1]
                    pred_boxes[id_min:id_max + 1] *= im_info[e][2]
                    pred_boxes[id_min:id_max + 1] = box_utils.clip_tiled_boxes(
                        pred_boxes[id_min:id_max + 1], im_info[e][0:2])

                cfg_key = 'TRAIN' if self.training else 'TEST'
                min_size = cfg[cfg_key].RPN_MIN_SIZE

                if not cfg.MODEL.CLS_AGNOSTIC_BBOX_REG:
                    # Cannot use for loop here which may cause "illegal memory access"
                    # Thanks to Chen-Wei Xie !
                    rows = pred_boxes.shape[0]
                    b3 = cls_loc * 4 + np.array([0, 1, 2, 3])
                    b4 = np.array(range(rows))
                    c = pred_boxes[np.repeat(b4, 4), b3.flatten()]
                    proposals = np.reshape(c, (-1, 4))
                else:
                    # Class-agnostic regression: the foreground deltas live in
                    # columns 4:8.
                    proposals = pred_boxes[:, 4:8]

                keep = _filter_boxes(proposals, min_size, im_info[0])

                proposals = proposals[keep, :]
                batch_inds = batch_inds[keep]

                batch_inds = np.reshape(batch_inds, [len(batch_inds), 1])

                proposals = np.concatenate((batch_inds, proposals), axis=1)

                self.CollectAndDistributeFpnRpnProposals = CollectAndDistributeFpnRpnProposalsOp(
                )
                self.CollectAndDistributeFpnRpnProposals.training = self.training
                # proposals.astype('float32')
                blobs_out = self.CollectAndDistributeFpnRpnProposals(proposals,
                                                                     roidb,
                                                                     im_info,
                                                                     stage=i)
                # Update rpn_ret: keep the RPN blobs, replace the RoI blobs
                # with the redistributed proposals for this stage.
                new_rpn_ret = {}
                for key, value in rpn_ret.items():
                    if 'rpn' in key:
                        new_rpn_ret[key] = value

                new_rpn_ret.update(blobs_out)

                if not self.training:
                    return_dict['blob_conv'] = blob_conv

                if not cfg.MODEL.RPN_ONLY:
                    if i == 1:
                        if cfg.MODEL.SHARE_RES5 and self.training:
                            box_feat, res5_feat = self.Box_Head_2(
                                blob_conv, new_rpn_ret)
                        else:
                            box_feat = self.Box_Head_2(blob_conv, new_rpn_ret)

                        cls_score, bbox_pred = self.Box_Outs_2(box_feat)
                    elif i == 2:
                        if cfg.MODEL.SHARE_RES5 and self.training:
                            box_feat, res5_feat = self.Box_Head_3(
                                blob_conv, new_rpn_ret)
                        else:
                            box_feat = self.Box_Head_3(blob_conv, new_rpn_ret)

                        cls_score, bbox_pred = self.Box_Outs_3(box_feat)

                if self.training:
                    # rpn loss (accumulated onto the stage-0 entries)
                    rpn_kwargs.update(
                        dict((k, new_rpn_ret[k]) for k in new_rpn_ret.keys()
                             if (k.startswith('rpn_cls_logits')
                                 or k.startswith('rpn_bbox_pred'))))
                    loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                        **rpn_kwargs)
                    if cfg.FPN.FPN_ON:
                        for k, lvl in enumerate(
                                range(cfg.FPN.RPN_MIN_LEVEL,
                                      cfg.FPN.RPN_MAX_LEVEL + 1)):
                            return_dict['losses']['loss_rpn_cls_fpn%d' %
                                                  lvl] += loss_rpn_cls[k]
                            return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                                  lvl] += loss_rpn_bbox[k]
                    else:
                        return_dict['losses']['loss_rpn_cls'] += loss_rpn_cls
                        return_dict['losses']['loss_rpn_bbox'] += loss_rpn_bbox

                    # bbox loss (accumulated onto the stage-0 entries)
                    loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                        cls_score,
                        bbox_pred,
                        new_rpn_ret['labels_int32'],
                        new_rpn_ret['bbox_targets'],
                        new_rpn_ret['bbox_inside_weights'],
                        new_rpn_ret['bbox_outside_weights'],
                        stage=i)

                    return_dict['losses']['loss_cls'] += loss_cls
                    return_dict['losses']['loss_bbox'] += loss_bbox
                    return_dict['metrics']['accuracy_cls'] += accuracy_cls

                return_dict['rois' + '_{}'.format(i)] = deepcopy(
                    new_rpn_ret['rois'])
                return_dict['cls_score' + '_{}'.format(i)] = cls_score.detach()
                return_dict['bbox_pred' + '_{}'.format(i)] = bbox_pred.detach()

                rpn_ret = new_rpn_ret.copy()

        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass; this variant (LJY) zeroes the mask targets of images
        that carry no mask annotations so that they contribute no mask loss.

        Args:
            data: input image tensor (NCHW).
            im_info: per-image [height, width, scale] info blob.
            roidb: serialized roidb entries (training only); deserialized here.
            **rpn_kwargs: extra blobs consumed by the RPN loss computation.

        Returns:
            dict: training -> 'losses' and 'metrics' sub-dicts;
            inference -> 'blob_conv', 'rois', 'cls_score', 'bbox_pred'.
        """
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)  # (2, 1024, 50, 50), conv4

        rpn_ret = self.RPN(blob_conv, im_info, roidb)  # (N*2000, 5)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(
                    blob_conv, rpn_ret)  # box_feat: (N*512, 2014, 1, 1)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(
                box_feat)  # cls_score: (N*512, C)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))

            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)

            if cfg.FPN.FPN_ON:
                # One loss entry per FPN pyramid level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                # LJY: if the image has no mask annotations, then disable loss
                # computing for it by zeroing its mask targets.
                # Bug fix: this loop used to run unconditionally, but
                # 'mask_rois'/'masks_int32' only exist when MASK_ON is set,
                # so it raised KeyError whenever masks were disabled.
                for idx, e in enumerate(roidb):
                    has_mask = e['has_mask']
                    if has_mask:
                        continue
                    ind = rpn_ret['mask_rois'][:, 0] == idx
                    rpn_ret['masks_int32'][ind, :] = np.zeros_like(
                        rpn_ret['masks_int32'][ind, :])

                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass for 2D detection plus 3D car pose estimation.

        Pipeline: conv backbone (optionally non-local) -> RPN -> box
        head, then, in training, the optional car classification /
        rotation head, translation head, 3D-to-2D projection loss, mask
        head and keypoint head, each gated by its cfg flag.

        Args:
            data: input image blob (NCHW tensor).
            im_info: per-image (height, width, scale) info blob.
            roidb: serialized roidb entries; deserialized in training.
            **rpn_kwargs: extra target blobs passed on to the RPN loss.

        Returns:
            dict: in training, 'losses' and 'metrics' sub-dicts; in
            testing, 'rois', 'cls_score', 'bbox_pred' and 'blob_conv'
            (plus 'f_div_C' when the non-local test flag is set).
        """
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        return_dict = {}  # A dict to collect return variables
        # The non-local backbone also returns its attention map f_div_C,
        # which is exposed only for test-time inspection.
        if cfg.FPN.NON_LOCAL:
            blob_conv, f_div_C = self.Conv_Body(im_data)
            if cfg.MODEL.NON_LOCAL_TEST:
                return_dict['f_div_C'] = f_div_C
        else:
            blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            # With SHARE_RES5 the box head also returns the res5 feature
            # so the mask/keypoint heads can reuse it during training.
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            # With FPN on, the RPN losses come back as per-level lists.
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            # we only use the car cls
            # NOTE(review): presumably class index 4 is 'car' in this
            # dataset's label mapping -- confirm against the class list.
            car_cls_int = 4
            if cfg.MODEL.CAR_CLS_HEAD_ON:
                if getattr(self.car_cls_Head, 'SHARE_RES5', False):
                    # TODO: add thos shared_res5 module
                    pass
                else:
                    car_cls_rot_feat = self.car_cls_Head(blob_conv, rpn_ret)
                    car_cls_score, car_cls, rot_pred = self.car_cls_Outs(
                        car_cls_rot_feat)
                    # car classification loss, we only fine tune the labelled cars

                # we only use the car cls
                car_idx = np.where(rpn_ret['labels_int32'] == car_cls_int)
                # Optional per-class cross-entropy weighting for finetuning.
                if len(cfg.TRAIN.CE_CAR_CLS_FINETUNE_WIGHT):
                    ce_weight = np.array(cfg.TRAIN.CE_CAR_CLS_FINETUNE_WIGHT)
                else:
                    ce_weight = None

                loss_car_cls, loss_rot, accuracy_car_cls = car_3d_pose_heads.fast_rcnn_car_cls_rot_losses(
                    car_cls_score[car_idx],
                    rot_pred[car_idx],
                    car_cls[car_idx],
                    rpn_ret['car_cls_labels_int32'][car_idx],
                    rpn_ret['quaternions'][car_idx],
                    ce_weight,
                    shape_sim_mat=self.shape_sim_mat)

                return_dict['losses']['loss_car_cls'] = loss_car_cls
                return_dict['losses']['loss_rot'] = loss_rot
                return_dict['metrics']['accuracy_car_cls'] = accuracy_car_cls
                # Similarity metrics are computed on detached CPU numpy
                # copies (no gradient is needed for metrics).
                return_dict['metrics']['shape_sim'] = shape_sim(
                    car_cls[car_idx].data.cpu().numpy(), self.shape_sim_mat,
                    rpn_ret['car_cls_labels_int32'][car_idx].astype('int64'))
                return_dict['metrics']['rot_diff_degree'] = rot_sim(
                    rot_pred[car_idx].data.cpu().numpy(),
                    rpn_ret['quaternions'][car_idx])

            if cfg.MODEL.TRANS_HEAD_ON:
                # Decode predicted deltas back to image-space boxes; the
                # translation head regresses 3D translation from them.
                pred_boxes = car_3d_pose_heads.bbox_transform_pytorch(
                    rpn_ret['rois'], bbox_pred, im_info,
                    cfg.MODEL.BBOX_REG_WEIGHTS)
                car_idx = np.where(rpn_ret['labels_int32'] == car_cls_int)

                # Build translation head heres from the bounding box
                if cfg.TRANS_HEAD.INPUT_CONV_BODY:
                    pred_boxes_car = pred_boxes[:, 4 * car_cls_int:4 *
                                                (car_cls_int + 1)].squeeze(
                                                    dim=0)
                    car_trans_feat = self.car_trans_Head(
                        blob_conv, rpn_ret, pred_boxes_car)
                    car_trans_pred = self.car_trans_Outs(car_trans_feat)
                    car_trans_pred = car_trans_pred[car_idx]
                elif cfg.TRANS_HEAD.INPUT_TRIPLE_HEAD:
                    # NOTE(review): this branch reuses car_cls_rot_feat,
                    # which only exists when CAR_CLS_HEAD_ON is also set.
                    pred_boxes_car = pred_boxes[:, 4 * car_cls_int:4 *
                                                (car_cls_int + 1)].squeeze(
                                                    dim=0)
                    car_trans_feat = self.car_trans_Head(pred_boxes_car)
                    car_trans_pred = self.car_trans_Outs(
                        car_trans_feat, car_cls_rot_feat)
                    car_trans_pred = car_trans_pred[car_idx]
                else:
                    pred_boxes_car = pred_boxes[car_idx, 4 * car_cls_int:4 *
                                                (car_cls_int + 1)].squeeze(
                                                    dim=0)
                    car_trans_feat = self.car_trans_Head(pred_boxes_car)
                    car_trans_pred = self.car_trans_Outs(car_trans_feat)

                label_trans = rpn_ret['car_trans'][car_idx]
                loss_trans = car_3d_pose_heads.car_trans_losses(
                    car_trans_pred, label_trans)
                return_dict['losses']['loss_trans'] = loss_trans
                return_dict['metrics']['trans_diff_meter'], return_dict['metrics']['trans_thresh_per'] = \
                    trans_sim(car_trans_pred.data.cpu().numpy(), rpn_ret['car_trans'][car_idx],
                              cfg.TRANS_HEAD.TRANS_MEAN, cfg.TRANS_HEAD.TRANS_STD)

            # A 3D to 2D projection loss
            # NOTE(review): this block reads car_idx, car_cls_score,
            # rot_pred, car_trans_pred and label_trans, which are only
            # defined when CAR_CLS_HEAD_ON / TRANS_HEAD_ON are also
            # enabled -- LOSS_3D_2D_ON alone would raise NameError.
            if cfg.MODEL.LOSS_3D_2D_ON:
                # During the mesh generation, using GT(True) or predicted(False) Car ID
                if cfg.LOSS_3D_2D.MESH_GEN_USING_GT:
                    # Acquire car id
                    car_ids = rpn_ret['car_cls_labels_int32'][car_idx].astype(
                        'int64')
                else:
                    # Using the predicted car id
                    print("Not properly implemented for pytorch")
                    car_ids = car_cls_score[car_idx].max(dim=1)
                # Get mesh vertices and generate loss
                UV_projection_loss = plane_projection_loss(
                    car_trans_pred, label_trans, rot_pred[car_idx],
                    rpn_ret['quaternions'][car_idx], car_ids, im_info,
                    self.car_models, self.intrinsic_mat, self.car_names)

                return_dict['losses'][
                    'UV_projection_loss'] = UV_projection_loss

            if cfg.MODEL.MASK_TRAIN_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            # Some metrics (shape_sim / rot_diff_degree) are plain numpy
            # scalars with no .unsqueeze, hence the type check below.
            for k, v in return_dict['metrics'].items():
                if type(v) == np.float64:
                    return_dict['metrics'][k] = v
                else:
                    return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
    def _forward(self, data, query, im_info, query_type, roidb=None, **rpn_kwargs):
        """Forward pass for few-shot / query-based detection.

        Extracts backbone features for the target image and a list of
        query (support) images, fuses them with `self.match_net`, and
        runs the RPN and box head (and optionally mask/keypoint heads)
        on the matched features.

        Args:
            data: target image blob (NCHW tensor).
            query: list of `shot` query image blobs.
            im_info: per-image (height, width, scale) info blob.
            query_type: query category indicator forwarded to the
                Fast R-CNN loss.
            roidb: serialized roidb entries; deserialized in training.
            **rpn_kwargs: extra target blobs passed on to the RPN loss.

        Returns:
            dict: 'losses'/'metrics' plus 'rois'/'cls_score'/'bbox_pred'
            in training; 'rois'/'cls_score'/'bbox_pred' plus the image
            and query conv features in testing.
        """

        #query_type = query_type.item()

        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()  # NOTE(review): unused below

        return_dict = {}  # A dict to collect return variables

        # feed image data to base model to obtain base feature map
        blob_conv = self.Conv_Body(im_data)

        # Run every query (support) image through the shared backbone.
        query_conv = []
        shot = len(query)
        for i in range(shot):
            query_conv.append(self.Conv_Body(query[i]))

        def pooling(feats, method='avg', dim = 0):
            # Aggregate the per-shot query features into a single tensor.
            feats = torch.stack(feats)
            if method == 'avg':
                feat = torch.mean(feats, dim=dim)
            elif method == 'max':
                feat, _ = torch.max(feats, dim=dim)
            return feat

        if cfg.FPN.FPN_ON:
            # Transpose shots x levels -> levels x shots, then average
            # the query features per FPN level.
            query_conv = list(map(list, zip(*query_conv)))
            query_conv = [pooling(QC_i, method='avg', dim = 0) for QC_i in query_conv]

            rpn_feat = []
            act_feat = []
            act_aim = []
            c_weight = []

            # NOTE(review): with 5 FPN levels the first one is fed to the
            # RPN without matching -- presumably the extra RPN-only
            # level; confirm against the FPN config.
            if len(blob_conv) == 5:
                start_match_idx = 1
                #sa_blob_conv = self.sa(blob_conv[0])
                sa_blob_conv = blob_conv[0]
                rpn_feat.append(sa_blob_conv)
            else:
                start_match_idx = 0

            # Match image and query features level by level.
            for IP, QP in zip(blob_conv[start_match_idx:], query_conv[start_match_idx:]):
                _rpn_feat, _act_feat, _act_aim, _c_weight = self.match_net(IP, QP)
                rpn_feat.append(_rpn_feat)
                act_feat.append(_act_feat)
                act_aim.append(_act_aim)
                c_weight.append(_c_weight)
                """
                correlation = []
                QP_pool = self.global_avgpool(QP)
                #IP_sa = self.sa(IP)
                IP_sa = IP
                for IP_sa_batch, QP_pool_batch in zip(IP_sa, QP_pool):
                    IP_sa_batch, QP_pool_batch = IP_sa_batch.unsqueeze(0), QP_pool_batch.unsqueeze(0)
                    correlation.append(F.conv2d(IP_sa_batch, QP_pool_batch.permute(1,0,2,3), groups=IP.shape[1]).squeeze(0))
                correlation = torch.stack(correlation)
                rpn_feat.append(correlation)
                act_feat.append(correlation)
                act_aim.append(QP)
                """
        else:
            # Single-level backbone: pool queries over shots, match once.
            query_conv = pooling(query_conv)
            rpn_feat, act_feat, act_aim, c_weight = self.match_net(blob_conv, query_conv)
            """
            correlation = []
            QP_pool = self.global_avgpool(QP)
            #IP_sa = self.sa(IP)
            IP_sa = IP
            for IP_sa_batch, QP_pool_batch in zip(IP_sa, QP_pool):
                IP_sa_batch, QP_pool_batch = IP_sa_batch.unsqueeze(0), QP_pool_batch.unsqueeze(0)
                correlation.append(F.conv2d(IP_sa_batch, QP_pool_batch.permute(1,0,2,3), groups=IP.shape[1]).squeeze(0))
            correlation = torch.stack(correlation)
            rpn_feat = correlation
            act_feat = correlation
            act_aim = query_conv
            """
        rpn_ret = self.RPN(rpn_feat, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv
            return_dict['query_conv'] = query_conv

        if not cfg.MODEL.RPN_ONLY:
            # RELATION_RCNN variants feed both the image RoI features and
            # the matched query features to the box head / outputs.
            if not cfg.FPN.FPN_ON:
                if cfg.MODEL.SHARE_RES5 and self.training:
                    if cfg.RELATION_RCNN:
                        box_feat, query_box_feat, res5_feat, query_res5_feat = self.Box_Head(act_feat, act_aim, rpn_ret)
                    else:
                        box_feat, res5_feat = self.Box_Head(act_feat, rpn_ret)
                else:
                    if cfg.RELATION_RCNN:
                        box_feat, query_box_feat= self.Box_Head(act_feat, act_aim, rpn_ret)
                    else:
                        box_feat = self.Box_Head(act_feat, rpn_ret)

                if cfg.RELATION_RCNN:
                    cls_score, bbox_pred = self.Box_Outs(box_feat, query_box_feat)
                else:
                    cls_score, bbox_pred = self.Box_Outs(box_feat)
            else:
                if cfg.RELATION_RCNN:
                    box_feat, query_box_feat = self.Box_Head(act_feat, act_aim, rpn_ret)
                    cls_score, bbox_pred = self.Box_Outs(box_feat, query_box_feat)
                else:
                    box_feat = self.Box_Head(act_feat, rpn_ret)
                    cls_score, bbox_pred = self.Box_Outs(box_feat)

        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(dict(
                (k, rpn_ret[k]) for k in rpn_ret.keys()
                if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
            ))
            #rpn_kwargs.update({'query_type': query_type})
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            # NOTE(review): this project's fast_rcnn_losses takes extra
            # rois/query_type args and returns a 4th value (ignored here
            # since use_marginloss=False).
            loss_cls, loss_bbox, accuracy_cls, _ = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'], rpn_ret['rois'], query_type, use_marginloss=False)
            #return_dict['losses']['margin_loss'] = margin_loss
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            # NOTE(review): mask_pred is computed but its loss is
            # commented out below, so the mask head gets no gradient here.
            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    if cfg.RELATION_RCNN:
                        mask_feat = self.Mask_Head(res5_feat, query_res5_feat, rpn_ret,
                                                roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                    else:
                        mask_feat = self.Mask_Head(res5_feat, rpn_ret, roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    if cfg.RELATION_RCNN:
                        mask_feat = self.Mask_Head(act_feat, act_aim, rpn_ret)
                    else:
                        mask_feat = self.Mask_Head(act_feat, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                #loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'], rpn_ret['mask_rois'], query_type)
                #return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

            # Unlike the sibling variants, raw predictions are exposed
            # during training too (e.g. for caller-side inspection).
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
Exemple #17
0
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass with soft-label distillation and domain-adversarial training.

        Standard backbone -> RPN -> box head pipeline, extended in
        training with:
          * distillation against provided soft labels ('gt_scores') for
            the noisy dataset (``dataset_id == 0``) when TRAIN.GT_SCORES,
            optionally with a learned per-RoI mixing weight;
          * image-level / RoI-level domain discriminators (plus an
            optional consistency term) for domain adaptation, zeroing
            the supervised losses on the unlabeled dataset.

        Args:
            data: input image blob (NCHW tensor).
            im_info: per-image (height, width, scale) info blob.
            roidb: serialized roidb entries; deserialized in training.
            **rpn_kwargs: extra target blobs passed on to the RPN loss.

        Returns:
            dict: in training, 'losses' and 'metrics' sub-dicts; in
            testing, 'rois', 'cls_score', 'bbox_pred' and 'blob_conv'.
        """
        im_data = data

        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            # With SHARE_RES5 the box head also returns the res5 feature
            # so the mask/keypoint heads can reuse it during training.
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(dict(
                (k, rpn_ret[k]) for k in rpn_ret.keys()
                if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
            ))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            # With FPN on, the RPN losses come back as per-level lists.
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            # EDIT: target soft-labels with distillation loss
            if cfg.TRAIN.GT_SCORES:
                # Smooth labels for the noisy dataset assumed to be "dataset-0"
                if rpn_ret['dataset_id'][0] == 0:

                    if cfg.TRAIN.DISTILL_ATTN:
                        # Learn a per-RoI mixing weight (lambda) from the
                        # box features instead of a fixed DISTILL_LAMBDA.
                        pre_lambda_attn = self.AttentionRoi_Head(box_feat)
                        loss_distill = fast_rcnn_heads.smooth_label_loss(
                                        cls_score, rpn_ret['labels_int32'],
                                        rpn_ret['gt_scores'],
                                        rpn_ret['gt_source'],
                                        dist_T=cfg.TRAIN.DISTILL_TEMPERATURE,
                                        dist_lambda=F.sigmoid(pre_lambda_attn),
                                        tracker_score=cfg.TRAIN.TRACKER_SCORE)
                        loss_distill += torch.sum(torch.abs(pre_lambda_attn))  # regularize attn magnitude
                    else:
                        loss_distill = fast_rcnn_heads.smooth_label_loss(
                                        cls_score, rpn_ret['labels_int32'],
                                        rpn_ret['gt_scores'],
                                        rpn_ret['gt_source'],
                                        dist_T=cfg.TRAIN.DISTILL_TEMPERATURE,
                                        dist_lambda=cfg.TRAIN.DISTILL_LAMBDA,
                                        tracker_score=cfg.TRAIN.TRACKER_SCORE)

                    # Replace the hard-label classification loss with the
                    # distilled one for the noisy dataset.
                    return_dict['losses']['loss_cls'] = loss_distill

            # EDIT: domain adversarial losses
            if cfg.TRAIN.DOMAIN_ADAPT_IM or cfg.TRAIN.DOMAIN_ADAPT_ROI:
                if rpn_ret['dataset_id'][0] == 0:
                    # dataset-0 is assumed unlabeled - so zero out other losses
                    #   NOTE: multiplied by 0 to maintain valid autodiff graph
                    return_dict['losses']['loss_cls'] *= 0
                    return_dict['losses']['loss_bbox'] *= 0
                    # BUGFIX: zero every RPN loss that was actually
                    # recorded above. With FPN on, the keys are per-level
                    # ('loss_rpn_cls_fpn%d'), so indexing the non-FPN
                    # names directly raised KeyError.
                    for rpn_loss_key in list(return_dict['losses']):
                        if rpn_loss_key.startswith('loss_rpn_'):
                            return_dict['losses'][rpn_loss_key] *= 0

                if cfg.TRAIN.DOMAIN_ADAPT_IM:
                    # Image-level domain discriminator on backbone features.
                    da_image_pred = self.DiscriminatorImage_Head(blob_conv)
                    loss_da_im = adversarial_heads.domain_loss_im(
                                    da_image_pred, rpn_ret['dataset_id'][0])
                    return_dict['losses']['loss_da-im'] = loss_da_im

                if cfg.TRAIN.DOMAIN_ADAPT_ROI:
                    # RoI-level domain discriminator on box features.
                    da_roi_pred = self.DiscriminatorRoi_Head(box_feat)
                    loss_da_roi = adversarial_heads.domain_loss_im(
                                    da_roi_pred, rpn_ret['dataset_id'][0])
                    return_dict['losses']['loss_da-roi'] = loss_da_roi

                if cfg.TRAIN.DOMAIN_ADAPT_CST:
                    # Consistency between image- and RoI-level predictions
                    # requires both discriminators to be enabled.
                    assert cfg.TRAIN.DOMAIN_ADAPT_ROI and cfg.TRAIN.DOMAIN_ADAPT_IM
                    loss_da_cst = adversarial_heads.domain_loss_cst(
                                    da_image_pred, da_roi_pred)
                    return_dict['losses']['loss_da-cst'] = loss_da_cst

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(res5_feat, rpn_ret,
                                               roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict